Commit 9068535f authored by LUSTIG Peter's avatar LUSTIG Peter

works now without fake sources

parent edca8b56
......@@ -30,21 +30,32 @@ from utils import find_nearest
class PCEvaluation:
def __init__(self, sources, fake_sources, shape, wcs, flux=None,
mapbins=19, threshold_bins=5, threshold_range=(3, 5)):
def __init__(self, sources, fake_sources=None, shape=None, wcs=None,
flux=None, mapbins=19, threshold_bins=5,
threshold_range=(3, 5)):
idxsort = np.argsort(flux.to_value(u.mJy))
self.flux = flux[idxsort]
self.sources = [sources[i] for i in idxsort]
self.fake_sources = [fake_sources[i] for i in idxsort]
self.flux = flux
self.sources = sources
self.fake_sources = fake_sources
if self.flux is not None:
idxsort = np.argsort(flux.to_value(u.mJy))
self.flux = flux[idxsort]
self.sources = [sources[i] for i in idxsort]
if fake_sources is not None:
self.fake_sources = [fake_sources[i] for i in idxsort]
self.completness = None
self.purity = None
assert len(sources) == len(fake_sources), ("Number of results for "
"sources and fake "
"sources is not the same.")
assert len(sources) == len(flux), ("Number of provided fluxes differs "
"from number of simulation results")
if fake_sources is not None:
assert len(sources) == len(fake_sources), ("Number of results for "
"sources and fake "
"sources is not the "
"same.")
if flux is not None:
assert len(sources) == len(flux), ("Number of provided fluxes "
"differs from number of "
"simulation results")
assert type(mapbins) is int, "number of bins must be an integer"
self.bins = mapbins
......@@ -64,7 +75,11 @@ class PCEvaluation:
assert np.all(np.abs(wcs_threshold.all_pix2world(
np.arange(threshold_bins+1)-0.5, 0)
- threshold_edges) < 1e-15)
self.completness, self.purity, self.hitmap = self.GetCP()
(self.completness,
self.purity,
self.hitmap,
self.detectmap) = self.GetCP()
def GetCompletnessBin(self, xbin, ybin):
    """Return the completeness values for one spatial map cell.

    Selects the (ybin, xbin) cell from the 4-D ``self.completness``
    array while keeping the first and last axes intact, so the result
    contains one row per simulated flux and one column per detection
    threshold for that cell.

    NOTE(review): the axis order (flux, y, x, threshold) is inferred
    from the index pattern used here — confirm against the array
    returned by ``GetCP``.
    """
    cell_selector = (slice(None), ybin, xbin, slice(None))
    return self.completness[cell_selector]
......@@ -86,20 +101,40 @@ class PCEvaluation:
if shape is None:
shape = self.shape_3D
if type(sources) is Table:
sources = [sources]
if fake_sources is None or type(fake_sources) is Table:
fake_sources = [fake_sources]
if pool is not None:
f = partial(self.completness_purity, wcs=wcs, shape=shape)
res = pool.starmap(f, (sources, fake_sources))
res = list(zip(*res))
return res[0], res[1], res[2]
else:
comp, pur, hitm = [], [], []
comp, pur, hitm, absdet = [], [], [], []
for i in range(len(sources)):
tmpres = self.completness_purity(sources[i], fake_sources[i],
wcs, shape)
comp.append(tmpres[0])
pur.append(tmpres[1])
hitm.append(tmpres[2])
return np.array(comp), np.array(pur), np.array(hitm)
absdet.append(tmpres[3])
return (np.array(comp), np.array(pur), np.array(hitm),
np.array(absdet))
def false_counts(self, sources=None, wcs=None, shape=None):
    """Count detections per spatial bin and threshold via ``purity_worker``.

    Any argument left as ``None`` falls back to the corresponding
    attribute stored on the instance.

    Parameters
    ----------
    sources : optional
        Detected-source catalog; defaults to ``self.sources``.
    wcs : astropy.wcs.WCS, optional
        3-D WCS whose third axis is the detection threshold; defaults
        to ``self.wcs_3D``.
    shape : tuple, optional
        3-D histogram shape (x, y, threshold); defaults to
        ``self.shape_3D``.

    Returns
    -------
    The result of ``purity_worker(shape, wcs, sources, max_threshold)``.
    """
    if sources is None:
        sources = self.sources
    if wcs is None:
        wcs = self.wcs_3D
    if shape is None:
        shape = self.shape_3D
    # Convert the threshold-axis pixel range to world coordinates; only
    # the upper bound is needed, so the lower one is discarded (the
    # original bound it to an unused ``min_threshold`` local).
    # NOTE(review): the upper edge uses ``shape[2] - 1`` while the lower
    # edge uses ``-0.5`` — confirm the intended pixel-edge convention;
    # elsewhere in this file bin edges are built as
    # ``np.arange(n + 1) - 0.5``.
    _, max_threshold = wcs.sub([3]).all_pix2world([-0.5, shape[2] - 1],
                                                  0)[0]
    return purity_worker(shape, wcs, sources, max_threshold)
def completness_purity(self, sources, fake_sources, wcs=None,
shape=None):
......@@ -115,7 +150,6 @@ class PCEvaluation:
# TODO: Change the find_peaks routine, or maybe just the
# fit_2d_gaussian to be FAST ! (Maybe look into gcntrd.pro routine
# or photutils.centroid.centroid_1dg maybe ?)
completness, norm_comp = completness_worker(shape, wcs, sources,
fake_sources,
min_threshold,
......@@ -125,12 +159,12 @@ class PCEvaluation:
# norm can be 0, so to avoid warning on invalid values...
with np.errstate(divide='ignore', invalid='ignore'):
completness /= norm_comp[..., np.newaxis]
purity /= norm_pur
if completness is not None:
completness /= norm_comp[..., np.newaxis]
if purity is not None:
purity /= norm_pur
# TODO: One should probably return completness AND norm if one want to
# combine several fluxes
return completness, purity, norm_comp
return completness, purity, norm_comp, norm_pur
def PlotBin(self, data, title='', flux=None, thresh=None,
nfluxlabels=None, nthreshlabels=None, **kwargs):
......@@ -267,6 +301,7 @@ def UglyLoader(filename):
if __name__ == '__main__':
DATA_DIR = "/home/peter/Dokumente/Uni/Paris/Stage/data/v_1"
data = NikaMap.read(Path(DATA_DIR) / '..' / 'map.fits')
sh = data.data.shape
......@@ -276,7 +311,7 @@ if __name__ == '__main__':
'FirstSteps/Completness/NEWcombined_tables_long.fits')
fname = ('/home/peter/Dokumente/Uni/Paris/Stage/'
'FirstSteps/Completness/NEW_combine_fct_result.fits')
'''
FLUX, SOURCE, FSOURCE = UglyLoader(fname)
xx = PCEvaluation(SOURCE, FSOURCE, sh, wcs, FLUX, mapbins=19,
......@@ -297,3 +332,11 @@ if __name__ == '__main__':
ylabel='Purity')
plt.show(block=True)
'''
fname = ('/home/peter/Dokumente/Uni/Paris/Stage/'
'FirstSteps/Completness/montecarlo_results/nocources/'
'180417_174539_nosources_thresh2.5_nsim5000.fits')
sources = Table.read(fname, 1)
PCEvaluation(sources, None, sh, wcs, None, mapbins=19,
threshold_bins=6, threshold_range=(2.5, 5))
......@@ -85,7 +85,7 @@ def completness_purity_wcs(shape, wcs, bins=30,
return (bins, bins, threshold_bins), WCS(header)
def completness_worker(shape, wcs, sources, fake_sources, min_threshold=2,
def completness_worker(shape, wcs, sources, fake_sources=None, min_threshold=2,
max_threshold=5):
"""Compute completness from the fake source catalog
......@@ -109,31 +109,34 @@ def completness_worker(shape, wcs, sources, fake_sources, min_threshold=2,
"""
# If one wanted to used a histogramdd, one would need a threshold axis
# covering ALL possible SNR, otherwise loose flux, or cap the thresholds...
fake_snr = np.ma.array(sources[fake_sources['find_peak'].filled(0)]['SNR'],
mask=fake_sources['find_peak'].mask)
# As we are interested by the cumulative numbers, keep all inside the
# upper pixel
fake_snr[fake_snr > max_threshold] = max_threshold
# print(fake_snr)
# TODO: Consider keeping all pixels information in fake_source and source...
# This would imply to do only a simple wcs_threshold here...
xx, yy, zz = wcs.wcs_world2pix(fake_sources['ra'], fake_sources['dec'],
fake_snr.filled(min_threshold), 0)
# Number of fake sources recovered
_completness, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5,
bins=np.asarray(shape),
range=list(zip([0]*len(shape), shape)),
weights=~fake_sources['find_peak'].mask)
# Reverse cumulative sum to get all sources at the given threshold
_completness = np.cumsum(_completness[..., ::-1], axis=2)[..., ::-1]
# Number of fake sources (independant of threshold)
_norm_comp, _, _ = np.histogram2d(xx + 0.5, yy + 0.5,
bins=np.asarray(shape[0:2]),
range=list(zip([0]*2, shape[0:2])))
if fake_sources is not None:
fake_snr = np.ma.array(sources[fake_sources['find_peak'].filled(0)]['SNR'],
mask=fake_sources['find_peak'].mask)
# As we are interested by the cumulative numbers, keep all inside the
# upper pixel
fake_snr[fake_snr > max_threshold] = max_threshold
# print(fake_snr)
# TODO: Consider keeping all pixels information in fake_source and source...
# This would imply to do only a simple wcs_threshold here...
xx, yy, zz = wcs.wcs_world2pix(fake_sources['ra'], fake_sources['dec'],
fake_snr.filled(min_threshold), 0)
# Number of fake sources recovered
_completness, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5,
bins=np.asarray(shape),
range=list(zip([0]*len(shape), shape)),
weights=~fake_sources['find_peak'].mask)
# Reverse cumulative sum to get all sources at the given threshold
_completness = np.cumsum(_completness[..., ::-1], axis=2)[..., ::-1]
# Number of fake sources (independant of threshold)
_norm_comp, _, _ = np.histogram2d(xx + 0.5, yy + 0.5,
bins=np.asarray(shape[0:2]),
range=list(zip([0]*2, shape[0:2])))
else:
_completness, _norm_comp = None, None
return _completness, _norm_comp
......@@ -163,14 +166,16 @@ def purity_worker(shape, wcs, sources, max_threshold=2):
sources_snr, 0)
# Number of fake sources recovered
_purity, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5,
bins=np.asarray(shape),
range=list(zip([0]*len(shape), shape)),
weights=~sources['fake_sources'].mask)
# Revese cumulative sum...
_purity = np.cumsum(_purity[..., ::-1], axis=2)[..., ::-1]
if 'fake_sources' in sources.keys():
_purity, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5,
bins=np.asarray(shape),
range=list(zip([0]*len(shape), shape)),
weights=~sources['fake_sources'].mask)
# Revese cumulative sum...
_purity = np.cumsum(_purity[..., ::-1], axis=2)[..., ::-1]
else:
_purity = None
# Number of total detected sources at a given threshold
_norm_pur, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5,
bins=np.asarray(shape),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment