Commit 8e9816af authored by LUSTIG Peter's avatar LUSTIG Peter

fixed map error in parallel execution

parent fe06153c
...@@ -5,7 +5,7 @@ import os ...@@ -5,7 +5,7 @@ import os
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from multiprocessing import Pool, cpu_count from multiprocess import Pool, cpu_count
from functools import partial from functools import partial
from astropy import units as u from astropy import units as u
...@@ -24,6 +24,9 @@ import argparse ...@@ -24,6 +24,9 @@ import argparse
import sys import sys
import datetime import datetime
from astropy.io.fits import ImageHDU from astropy.io.fits import ImageHDU
from copy import deepcopy, copy
# import dill
''' '''
%load_ext autoreload %load_ext autoreload
%autoreload 2 %autoreload 2
...@@ -33,7 +36,28 @@ from astropy.io.fits import ImageHDU ...@@ -33,7 +36,28 @@ from astropy.io.fits import ImageHDU
plt.ion() plt.ion()
def fake_worker(jkiter, min_threshold=2, nsources=8**2, flux=1*u.Jy, def get_size(obj, seen=None):
"""Recursively finds size of objects"""
size = sys.getsizeof(obj)
if seen is None:
seen = set()
obj_id = id(obj)
if obj_id in seen:
return 0
# Important mark as seen *before* entering recursion to gracefully handle
# self-referential objects
seen.add(obj_id)
if isinstance(obj, dict):
size += sum([get_size(v, seen) for v in obj.values()])
size += sum([get_size(k, seen) for k in obj.keys()])
elif hasattr(obj, '__dict__'):
size += get_size(obj.__dict__, seen)
elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
size += sum([get_size(i, seen) for i in obj])
return size
def fake_worker(img, min_threshold=2, nsources=8**2, flux=1*u.Jy,
within=(0, 1), cat_gen=pos_uniform, parity_threshold=1, within=(0, 1), cat_gen=pos_uniform, parity_threshold=1,
retmask=True, **kwargs): retmask=True, **kwargs):
"""The completness purity worker, create a fake dataset from an jackknifed """The completness purity worker, create a fake dataset from an jackknifed
...@@ -48,7 +72,12 @@ def fake_worker(jkiter, min_threshold=2, nsources=8**2, flux=1*u.Jy, ...@@ -48,7 +72,12 @@ def fake_worker(jkiter, min_threshold=2, nsources=8**2, flux=1*u.Jy,
**kwargs : **kwargs :
arguments for source injection arguments for source injection
""" """
img = jkiter(parity_threshold) # print('using Jackknife object', jkiter)
'''
_jkiter = copy(jkiter)
print('using Jackknife object', _jkiter)
img = _jkiter(parity_threshold)
'''
# img = IMAGE # img = IMAGE
# Renormalize the stddev # Renormalize the stddev
std = img.check_SNR() std = img.check_SNR()
...@@ -87,13 +116,14 @@ flux = np.geomspace(1, 10, 3) * u.mJy ...@@ -87,13 +116,14 @@ flux = np.geomspace(1, 10, 3) * u.mJy
flux = [0] flux = [0]
flux = np.linspace(1, 3, 2)*u.mJy flux = np.linspace(1, 3, 2)*u.mJy
flux = np.array([10]) * u.mJy flux = np.array([10]) * u.mJy
nsim = 4 nsim = 6
min_detection_threshold = 3 min_detection_threshold = 3
nsources = 2 nsources = 2
nsources = 0 nsources = 3
outdir = Path("montecarlo_results/") outdir = Path("montecarlo_results/")
outdir = Path("testdir") outdir = Path("testdir")
ncores = 4 ncores = 2
parity_threshold = 1
timeprefix = '{date:%y%m%d_%H%M%S}_'.format(date=datetime.datetime.now()) timeprefix = '{date:%y%m%d_%H%M%S}_'.format(date=datetime.datetime.now())
...@@ -132,16 +162,29 @@ print(message, "\n") ...@@ -132,16 +162,29 @@ print(message, "\n")
print("Creating Jackkife Object") print("Creating Jackkife Object")
t0 = clock() t0 = clock()
jk_iter = Jackknife(jk_filenames, n=nsim) jk_iter = Jackknife(jk_filenames, n=nsim)
nm = jk_iter() print('Done in {:.2f}s'.format(clock()-t0))
min_source_dist = 2 * nm.beam.fwhm_pix.value
print('Done in {:.2f}s'.format(clock()-t0)) '''
print('Begin Monte Carlo')
jk_iter_list = [jk_iter] * nsim jk_iter_list = [jk_iter] * nsim
jk_iter_list = [jk_iter]
for i in range(1, ncores):
jk_iter_list.append(Jackknife(jk_filenames, n=nsim))
jk_iter_list = jk_iter_list * int(np.ceil(nsim / ncores))
jk_iter_list = jk_iter_list[:nsim]
print(jk_iter_list)
'''
print("Creating {} Jackkife Maps".format(nsim))
jackknifes = []
for i in range(nsim):
jackknifes.append(jk_iter(parity_threshold=parity_threshold))
print('Done in {:.2f}s'.format(clock()-t0))
min_source_dist = 2 * jackknifes[0].beam.fwhm_pix.value
p = Pool(ncores) p = Pool(ncores)
globmask = None
print('Begin Monte Carlo')
if nsources != 0: if nsources != 0:
for _flux in flux: for _flux in flux:
...@@ -156,7 +199,7 @@ if nsources != 0: ...@@ -156,7 +199,7 @@ if nsources != 0:
print('Simulation with {:.2f}...'.format(_flux)) print('Simulation with {:.2f}...'.format(_flux))
if(1): if(1):
res = p.map(helpfunc, jk_iter_list) res = p.map(helpfunc, jackknifes)
res = list(zip(*res)) res = list(zip(*res))
(DETECTED_SOURCES, (DETECTED_SOURCES,
...@@ -234,7 +277,7 @@ if nsources == 0: ...@@ -234,7 +277,7 @@ if nsources == 0:
print('Simulation without sources...') print('Simulation without sources...')
if(1): if(1):
res = p.map(helpfunc, jk_iter_list) res = p.map(helpfunc, jackknifes)
res = list(zip(*res)) res = list(zip(*res))
DETECTED_SOURCES = res[0] DETECTED_SOURCES = res[0]
...@@ -255,7 +298,6 @@ if nsources == 0: ...@@ -255,7 +298,6 @@ if nsources == 0:
matchmask.append(_matchmask) matchmask.append(_matchmask)
bar.update() bar.update()
immask = np.sum(immask, axis=0) immask = np.sum(immask, axis=0)
print(np.unique(immask)) print(np.unique(immask))
matchmask = np.sum(matchmask, axis=0) matchmask = np.sum(matchmask, axis=0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment