# monte_carlo.py — Monte Carlo completeness/purity simulations on jackknifed NIKA maps.
from __future__ import absolute_import, division, print_function

from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt

from multiprocessing import Pool, cpu_count
from functools import partial

from astropy import units as u
from astropy.io import ascii
from astropy.wcs import WCS
from astropy.utils.console import ProgressBar
from astropy.table import vstack, Table

from scipy.optimize import curve_fit

from nikamap import NikaMap, Jackknife
from nikamap.utils import pos_uniform
from astropy.io import fits
from time import clock
import argparse
import sys
import datetime
'''
%load_ext autoreload
%autoreload 2
%matplotlib tk
'''

plt.ion()




def fake_worker(jkiter, min_threshold=2, nsources=8**2, flux=1*u.Jy,
                within=(0, 1), cat_gen=pos_uniform, parity_threshold=1,
                **kwargs):
    """The completeness/purity worker: create a fake dataset from a
    jackknifed image and return the detected and injected catalogs.

    Parameters
    ----------
    jkiter : callable
        jackknife iterator; calling it with ``parity_threshold`` yields a
        jackknifed dataset (:class:`nikamap.NikaMap`)
    min_threshold : float
        minimum SNR threshold for the detection
    nsources : int
        number of fake sources to inject
    flux : :class:`astropy.units.Quantity` or None
        peak flux of the injected sources; ``None`` skips the injection
        entirely (pure-noise / purity run)
    within : tuple of float
        relative bounding box in which sources are injected
    cat_gen : callable
        fake-source position generator
    parity_threshold : float
        passed through to ``jkiter`` when drawing the jackknifed map
    **kwargs :
        additional arguments for source injection

    Returns
    -------
    tuple
        ``(sources, fake_sources)`` catalogs of the match-filtered map
    """
    img = jkiter(parity_threshold)

    # Renormalize the stddev so the SNR map is properly scaled.
    std = img.check_SNR()
    img.uncertainty.array *= std

    # Inject fake sources unless a pure-noise run was requested.
    # NOTE(review): injection is rather slow... maybe check the code?
    if flux is not None:
        img.add_gaussian_sources(nsources=nsources, within=within,
                                 peak_flux=flux, cat_gen=cat_gen, **kwargs)

    # ... match filter it with the map's own beam and renormalize again ...
    mf_img = img.match_filter(img.beam)
    std = mf_img.check_SNR()
    mf_img.uncertainty.array *= std

    # ... and detect sources with the lowest threshold...
    # The gaussian fit from subpixel=True would be very slow here...
    mf_img.detect_sources(threshold=min_threshold)

    return mf_img.sources, mf_img.fake_sources


plt.close('all')

# Input data locations: the first existing directory wins (server first,
# then local development machine).
DATA_DIR_SERVER = Path("/data/NIKA/Reduced/"
                       "HLS091828_common_mode_one_block/v_1")
DATA_DIR_MYPC = Path("/home/peter/Dokumente/Uni/Paris/Stage/data/v_1")

# Simulation parameters.  ``flux = [0]`` deliberately overrides the flux
# grid above and selects the pure-noise (purity) run; comment it out to run
# the completeness simulations instead.
flux = np.geomspace(1, 10, 3) * u.mJy
flux = [0]
nsim = 2                     # number of jackknife realisations per flux
min_detection_threshold = 3  # SNR threshold for source detection
nsources = 5                 # injected sources per simulation
outdir = Path("montecarlo_results/")
outdir = Path("testdir")     # deliberate override of the line above for test runs
ncores = 2                   # worker processes in the multiprocessing Pool

# Timestamp prefix so output files from different runs never collide.
timeprefix = '{date:%y%m%d_%H%M%S}_'.format(date=datetime.datetime.now())

if not outdir.exists():
    print('creating directory {}'.format(outdir))
    outdir.mkdir()


# Load maps for jackknifing.
if DATA_DIR_SERVER.exists():
    DATA_DIR = DATA_DIR_SERVER
elif DATA_DIR_MYPC.exists():
    DATA_DIR = DATA_DIR_MYPC
else:
    sys.exit("Raw data path not found. Exit.")

# Jackknife expects plain string paths, not Path objects.
jk_filenames = [str(fname) for fname in Path(DATA_DIR).glob('*/map.fits')]

# Run-summary banner printed at startup.
message = '''\
###############################################################
#                  Running Monte Carlo Tests                  #
#                                                             #
#          No Simulations per flux:    {:5d}                  #
#          Number of CPUs used:        {:5d}                  #
#          Minimum Detection Threshold:{:5.1f}                  #
#          No Different Fluxes:        {:5d}                  #
#          Number of Injected Sources: {:5d}                  #
###############################################################
'''.format(nsim, ncores, min_detection_threshold, len(flux), nsources)
# time.clock was deprecated in Python 3.3 and removed in 3.8;
# perf_counter is its documented replacement for wall-clock timing.
from time import perf_counter

print(message, "\n")
print("Creating Jackknife Object")
t0 = perf_counter()
jk_iter = Jackknife(jk_filenames, n=nsim)

# Draw one jackknifed map up front to measure the beam: injected sources
# must be separated by at least two FWHM.
nm = jk_iter()
min_source_dist = 2 * nm.beam.fwhm_pix.value

print('Done in {:.2f}s'.format(perf_counter() - t0))
print('Begin Monte Carlo')

# One entry per simulation (all sharing the same iterator object) so that
# Pool.map fans out exactly nsim jobs.
jk_iter_list = [jk_iter] * nsim
p = Pool(ncores)


# Completeness runs: inject fake sources at each flux of the grid, detect
# them in every jackknife realisation and save the merged catalogs.
if flux != [0]:
    for _flux in flux:

        helpfunc = partial(fake_worker,
                           min_threshold=min_detection_threshold,
                           nsources=nsources,
                           flux=_flux,
                           within=(0, 1),
                           cat_gen=pos_uniform,
                           dist_threshold=min_source_dist,
                           parity_threshold=0.5)

        print('Simulation with {:.2f}...'.format(_flux))

        # Parallel path by default; flip `use_pool` to False to run the
        # sequential, ProgressBar-instrumented variant (easier to debug).
        use_pool = True
        if use_pool:
            res = p.map(helpfunc, jk_iter_list)
            # Transpose [(src, fake), ...] into (sources, fake_sources).
            DETECTED_SOURCES, FAKE_SOURCES = list(zip(*res))
        else:
            DETECTED_SOURCES = []
            FAKE_SOURCES = []
            with ProgressBar(nsim) as bar:
                for iloop in range(nsim):
                    tmpsource, tmpfakesource = helpfunc(jk_iter)
                    DETECTED_SOURCES.append(tmpsource)
                    FAKE_SOURCES.append(tmpfakesource)
                    bar.update()

        # Merge the per-simulation fake_sources and sources catalogs,
        # offsetting the IDs and cross-references so they stay unique and
        # consistent in the stacked tables.
        fake_sources = Table()
        sources = Table()
        for _fake, _detected in zip(FAKE_SOURCES, DETECTED_SOURCES):
            n_fake = len(fake_sources)
            n_detected = len(sources)

            if _detected is not None:
                _detected['ID'] = _detected['ID'] + n_detected
                _detected['fake_sources'] = _detected['fake_sources'] + n_fake
                sources = vstack([sources, _detected])

            _fake['ID'] = _fake['ID'] + n_fake
            _fake['find_peak'] = _fake['find_peak'] + n_detected

            fake_sources = vstack([fake_sources, _fake])

        fname = ('flux{:.2f}mJy_thresh{}_nsim{}.fits'
                 .format(_flux.to_value(unit=u.mJy),
                         min_detection_threshold, nsim))

        fname = timeprefix + fname
        outfile = outdir / fname

        # Record the run parameters in the primary header, then append one
        # binary table per catalog.  overwrite=False: the timestamp prefix
        # should already make the filename unique.
        phdu = fits.PrimaryHDU()
        phdu.header['influx'] = '{}'.format(_flux)
        phdu.header['nsim'] = nsim
        phdu.header['sourcespersim'] = nsources
        phdu.header['dthresh'] = min_detection_threshold
        hdul = [phdu]
        if len(sources) > 0:
            hdul.append(fits.BinTableHDU(data=sources,
                                         name='Detected_Sources'))
        hdul.append(fits.BinTableHDU(data=fake_sources, name='Fake_Sources'))
        hdul = fits.HDUList(hdul)
        hdul.writeto(outfile, overwrite=False)
        print('results written to {}'.format(outfile))

# Purity run: no sources are injected, so every detection in the pure
# jackknife noise is spurious; only the detected catalog is saved.
if flux == [0]:
    helpfunc = partial(fake_worker,
                       min_threshold=min_detection_threshold,
                       flux=None,
                       within=(0, 1),
                       parity_threshold=0.5)

    print('Simulation without sources...')

    # Parallel path by default; flip `use_pool` to False to run the
    # sequential, ProgressBar-instrumented variant (easier to debug).
    use_pool = True
    if use_pool:
        res = p.map(helpfunc, jk_iter_list)
        res = list(zip(*res))
        # Only the detected-source catalogs matter here (no fakes injected).
        DETECTED_SOURCES = res[0]
    else:
        DETECTED_SOURCES = []
        with ProgressBar(nsim) as bar:
            for iloop in range(nsim):
                tmpsource, _ = helpfunc(jk_iter)
                DETECTED_SOURCES.append(tmpsource)
                bar.update()

    # Merge the per-simulation detection catalogs, offsetting the IDs so
    # they stay unique in the stacked table.
    sources = Table()
    for _detected in DETECTED_SOURCES:
        n_detected = len(sources)

        if _detected is not None:
            _detected['ID'] = _detected['ID'] + n_detected
            sources = vstack([sources, _detected])

    fname = ('nosources_thresh{}_nsim{}.fits'
             .format(min_detection_threshold, nsim))

    fname = timeprefix + fname
    outfile = outdir / fname

    # Record the run parameters in the primary header; influx and
    # sourcespersim are zero to mark the no-injection run.
    phdu = fits.PrimaryHDU()
    phdu.header['influx'] = 0
    phdu.header['nsim'] = nsim
    phdu.header['sourcespersim'] = 0
    phdu.header['dthresh'] = min_detection_threshold
    hdul = [phdu]
    if len(sources) > 0:
        hdul.append(fits.BinTableHDU(data=sources,
                                     name='Detected_Sources'))
    hdul = fits.HDUList(hdul)
    hdul.writeto(outfile, overwrite=False)
    print('results written to {}'.format(outfile))