from __future__ import absolute_import, division, print_function

import argparse
import datetime
import os
import sys
from copy import deepcopy, copy
from functools import partial
from pathlib import Path
# time.clock was removed in Python 3.8; perf_counter is the monotonic
# drop-in replacement, aliased so existing call sites keep working.
from time import perf_counter as clock

import matplotlib.pyplot as plt
import numpy as np
from astropy import units as u
from astropy.io import ascii, fits
from astropy.io.fits import ImageHDU
from astropy.table import vstack, Table
from astropy.utils.console import ProgressBar
from astropy.wcs import WCS
from multiprocess import Pool, cpu_count
from scipy.optimize import curve_fit

from nikamap import NikaMap, Jackknife
from nikamap.utils import pos_uniform
# import dill

# IPython conveniences — run these manually in an interactive session:
#   %load_ext autoreload
#   %autoreload 2
#   %matplotlib tk

# Interactive plotting: figures update without blocking the script.
plt.ion()


def get_size(obj, seen=None):
    """Recursively estimate the in-memory size of *obj* in bytes.

    Follows dict entries, instance ``__dict__`` attributes and generic
    iterables, tracking object ids so shared or self-referential
    structures are counted exactly once.

    Parameters
    ----------
    obj : object
        Object whose recursive size is wanted.
    seen : set or None
        Ids of objects already visited; used internally by the recursion.

    Returns
    -------
    int
        Estimated total size in bytes.
    """
    if seen is None:
        seen = set()
    # A shared or self-referential object must only contribute once;
    # mark it as seen *before* recursing.
    if id(obj) in seen:
        return 0
    seen.add(id(obj))

    total = sys.getsizeof(obj)
    if isinstance(obj, dict):
        total += sum(get_size(value, seen) for value in obj.values())
        total += sum(get_size(key, seen) for key in obj.keys())
    elif hasattr(obj, '__dict__'):
        total += get_size(obj.__dict__, seen)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        total += sum(get_size(item, seen) for item in obj)
    return total


def fake_worker(img, min_threshold=2, nsources=8**2, flux=1*u.Jy,
                within=(0, 1), cat_gen=pos_uniform, parity_threshold=1,
                retmask=True, **kwargs):
    """The completeness/purity worker: inject fake sources into a
    jackknifed map, match-filter it and run source detection.

    Parameters
    ----------
    img : :class:`nikamap.NikaMap`
        Jackknifed dataset (modified in place).
    min_threshold : float
        Minimum SNR threshold for the detection.
    nsources : int
        Number of fake sources to inject (0 disables injection).
    flux : :class:`astropy.units.Quantity` or None
        Peak flux of the injected sources (None disables injection).
    within : tuple of float
        Relative map region in which sources are injected.
    cat_gen : callable
        Position generator for the fake-source catalog.
    parity_threshold : float
        Kept for interface compatibility; unused here — the jackknife
        map itself is produced by the caller.
    retmask : bool
        If True, also return the image and match-filtered masks.
    **kwargs :
        Extra arguments forwarded to the source injection.

    Returns
    -------
    tuple
        ``(sources, fake_sources)`` and, when *retmask* is True,
        additionally ``(image_mask, matched_mask)``.
    """
    # Renormalize the uncertainty so the SNR map has unit stddev.
    img.uncertainty.array *= img.check_SNR()

    # Inject fake gaussian sources, unless explicitly disabled.
    if flux is not None and nsources != 0:
        img.add_gaussian_sources(nsources=nsources, within=within,
                                 peak_flux=flux, cat_gen=cat_gen, **kwargs)

    # Match-filter with the beam, then renormalize the SNR again.
    mf_img = img.match_filter(img.beam)
    mf_img.uncertainty.array *= mf_img.check_SNR()

    # Detect sources at the lowest threshold.
    # (The gaussian fit from subpixel=True would be very slow here.)
    mf_img.detect_sources(threshold=min_threshold)

    if retmask:
        return mf_img.sources, mf_img.fake_sources, img.mask, mf_img.mask
    return mf_img.sources, mf_img.fake_sources


plt.close('all')

# Candidate locations of the reduced data; the first one that exists
# is selected further below.
DATA_DIR_SERVER = Path("/data/NIKA/Reduced/"
                       "HLS091828_common_mode_one_block/v_1")
DATA_DIR_MYPC = Path("/home/peter/Dokumente/Uni/Paris/Stage/data/v_1")

# --- Simulation parameters -------------------------------------------
# (Earlier experimental values that were immediately overwritten have
# been removed; these are the effective settings.)
flux = np.array([10]) * u.mJy    # injected peak flux(es)
nsim = 4                         # number of jackknife realisations
min_detection_threshold = 3      # SNR threshold for the detection
nsources = 0                     # fake sources per simulation (0 = none)
outdir = Path("testdir")         # where result FITS files are written
ncores = 2                       # worker processes in the pool
parity_threshold = 1

# Timestamp prefix so successive runs never collide on file names
# (writeto below uses overwrite=False).
timeprefix = '{date:%y%m%d_%H%M%S}_'.format(date=datetime.datetime.now())

if not outdir.exists():
    print('creating directory {}'.format(outdir))
    outdir.mkdir()

# Load the maps used for jackknifing: pick whichever data directory
# exists on this machine, otherwise abort.
if DATA_DIR_SERVER.exists():
    DATA_DIR = DATA_DIR_SERVER
elif DATA_DIR_MYPC.exists():
    DATA_DIR = DATA_DIR_MYPC
else:
    sys.exit("Raw data path not found. Exit.")

# One scan map per subdirectory; Jackknife expects plain string paths.
jk_filenames = [str(path) for path in DATA_DIR.glob('*/map.fits')]

# Human-readable summary of the run configuration.
message = '''\
###############################################################
#                  Running Monte Carlo Tests                  #
#                                                             #
#          No Simulations per flux:    {:5d}                  #
#          Number of CPUs used:        {:5d}                  #
#          Minimum Detection Threshold:{:5.1f}                  #
#          No Different Fluxes:        {:5d}                  #
#          Number of Injected Sources: {:5d}                  #
###############################################################
'''.format(nsim, ncores, min_detection_threshold, len(flux), nsources)

print(message, "\n")

# Build the jackknife iterator from all scan maps and time it;
# clock() is time.perf_counter (see the import block).
print("Creating Jackkife Object")
t0 = clock()
jk_iter = Jackknife(jk_filenames, n=nsim)
print('Done in {:.2f}s'.format(clock()-t0))


# Draw all jackknife maps up front so only plain NikaMap objects —
# not the stateful iterator — are handed to the worker processes.
# (A commented-out experiment that built one Jackknife per core was
# removed here.)
print("Creating {} Jackkife Maps".format(nsim))
jackknifes = [jk_iter(parity_threshold=parity_threshold)
              for _ in range(nsim)]
print('Done in {:.2f}s'.format(clock()-t0))

# Minimum separation between injected sources: two beam FWHM (pixels).
min_source_dist = 2 * jackknifes[0].beam.fwhm_pix.value

# Worker pool reused by both Monte Carlo branches below.
p = Pool(ncores)


print('Begin Monte Carlo')

if nsources != 0:
    # One full set of nsim simulations per injected flux value.
    for _flux in flux:

        # Freeze the per-flux worker configuration; the jackknife map
        # is the only per-task argument.
        # (The former `if(1):` dead serial branch was removed; it also
        # wrongly called helpfunc(jk_iter) with the iterator instead of
        # a jackknifed map.)
        helpfunc = partial(fake_worker,
                           min_threshold=min_detection_threshold,
                           nsources=nsources,
                           flux=_flux,
                           within=(0, 1),
                           cat_gen=pos_uniform,
                           dist_threshold=min_source_dist,
                           parity_threshold=0.5)

        print('Simulation with {:.2f}...'.format(_flux))

        # Run the nsim simulations in parallel and regroup the per-task
        # 4-tuples into per-quantity sequences.
        res = p.map(helpfunc, jackknifes)
        (DETECTED_SOURCES,
         FAKE_SOURCES,
         immask,
         matchmask) = list(zip(*res))

        # To merge all the fake_sources and sources catalogs:
        # offset the IDs by the current catalog lengths so the
        # cross-references ('fake_sources' / 'find_peak') stay valid.
        fake_sources = Table()
        sources = Table()
        for _fake, _detected in zip(FAKE_SOURCES, DETECTED_SOURCES):
            n_fake = len(fake_sources)
            n_detected = len(sources)

            if _detected is not None:
                _detected['ID'] = _detected['ID'] + n_detected
                _detected['fake_sources'] = _detected['fake_sources'] + n_fake
                sources = vstack([sources, _detected])

            _fake['ID'] = _fake['ID'] + n_fake
            _fake['find_peak'] = _fake['find_peak'] + n_detected

            fake_sources = vstack([fake_sources, _fake])

        # Stack the per-simulation masks into coverage counts.
        immask = np.sum(immask, axis=0)
        print(np.unique(immask))
        matchmask = np.sum(matchmask, axis=0)

        fname = ('flux{:.2f}mJy_thresh{}_nsim{}.fits'
                 .format(_flux.to_value(unit=u.mJy),
                         min_detection_threshold, nsim))
        fname = timeprefix + fname
        outfile = outdir / fname

        # Primary HDU carries the WCS plus the simulation parameters.
        phdu = fits.PrimaryHDU(header=jackknifes[0].wcs.to_header())
        phdu.header['influx'] = '{}'.format(_flux)
        phdu.header['nsim'] = nsim
        phdu.header['sourcespersim'] = nsources
        phdu.header['dthresh'] = min_detection_threshold

        hdul = [phdu]
        if len(sources) > 0:
            hdul.append(fits.BinTableHDU(data=sources,
                                         name='Detected_Sources'))
        hdul.append(fits.BinTableHDU(data=fake_sources, name='Fake_Sources'))
        hdul.append(ImageHDU(data=immask, name='imagemask'))
        hdul.append(ImageHDU(data=matchmask, name='matchmask'))

        # overwrite=False is safe: timeprefix makes file names unique.
        hdul = fits.HDUList(hdul)
        hdul.writeto(outfile, overwrite=False)
        print('results written to {}'.format(outfile))

if nsources == 0:
    # Pure-noise runs: no injection, just detection on the jackknifes,
    # to characterize the false-detection rate.
    # (The former `if(1):` dead serial branch was removed; it also
    # wrongly called helpfunc(jk_iter) with the iterator instead of a
    # jackknifed map.)
    helpfunc = partial(fake_worker,
                       min_threshold=min_detection_threshold,
                       flux=None,
                       nsources=0,
                       within=(0, 1),
                       parity_threshold=0.5)

    print('Simulation without sources...')

    # fake_worker defaults to retmask=True, so each task returns a
    # 4-tuple; index 1 (empty fake catalogs) is not needed here.
    res = p.map(helpfunc, jackknifes)
    res = list(zip(*res))

    DETECTED_SOURCES = res[0]
    immask = res[2]
    matchmask = res[3]

    # Stack the per-simulation masks into coverage counts.
    immask = np.sum(immask, axis=0)
    print(np.unique(immask))
    matchmask = np.sum(matchmask, axis=0)

    # To merge all the detected-source catalogs, offsetting IDs so
    # they stay unique in the stacked table.
    sources = Table()
    for _detected in DETECTED_SOURCES:
        n_detected = len(sources)

        if _detected is not None:
            _detected['ID'] = _detected['ID'] + n_detected
            sources = vstack([sources, _detected])

    fname = ('nosources_thresh{}_nsim{}.fits'
             .format(min_detection_threshold, nsim))
    fname = timeprefix + fname
    outfile = outdir / fname

    # Primary HDU carries the WCS plus the simulation parameters.
    phdu = fits.PrimaryHDU(header=jackknifes[0].wcs.to_header())
    phdu.header['influx'] = 0
    phdu.header['nsim'] = nsim
    phdu.header['sourcespersim'] = 0
    phdu.header['dthresh'] = min_detection_threshold

    hdul = [phdu]
    if len(sources) > 0:
        hdul.append(fits.BinTableHDU(data=sources,
                                     name='Detected_Sources'))
    hdul.append(ImageHDU(data=immask, name='imagemask'))
    hdul.append(ImageHDU(data=matchmask, name='matchmask'))

    # overwrite=False is safe: timeprefix makes file names unique.
    hdul = fits.HDUList(hdul)
    hdul.writeto(outfile, overwrite=False)
    print('results written to {}'.format(outfile))