Commit e501b9ac authored by LUSTIG Peter's avatar LUSTIG Peter

Evaluation is now parallelised. Loading data should be modified

parent 8132caa4
from __future__ import absolute_import, division, print_function
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
from multiprocessing import Pool, cpu_count
from functools import partial
from astropy import units as u
from astropy.io import ascii
from astropy.wcs import WCS
from astropy.utils.console import ProgressBar
from astropy.table import vstack
from scipy.optimize import curve_fit
from nikamap import NikaMap, Jackknife
from nikamap.utils import pos_uniform
from astropy.io import fits
from astropy.table import Table, MaskedColumn
import sys
from mpl_toolkits.axes_grid1 import make_axes_locatable
import os
# The return value is discarded, so this call has no effect on the script.
os.getcwd()
# The bare string below is never used by the program: it only keeps the
# IPython magics to paste when running the script in an interactive session.
'''
%load_ext autoreload
%autoreload 2
%matplotlib tk
'''
# Switch matplotlib to interactive mode.
plt.ion()
def Plot_CompPur(completness, purity, threshold, nsim=None, savename=None,
                 flux=None):
    """Plot completeness (top row) and purity (bottom row) for each threshold.

    Parameters
    ----------
    completness : array
        cube read as ``completness[:, :, i]``; its last axis is the
        threshold axis
    purity : array
        cube read the same way as ``completness``
    threshold : sequence of float
        the threshold of each plane, used as x label of the bottom row
    nsim : optional
        number of simulations, used as title of the first column
    savename : optional
        when given, the figure is saved under that name
    flux : optional
        flux label, used as title of the second column (of the only column
        when there is a single threshold bin, replacing the ``nsim`` title)
    """
    threshold_bins = completness.shape[-1]
    # squeeze=False keeps ``axes`` two-dimensional even for a single
    # threshold bin; otherwise plt.subplots returns a 1D array and the
    # ``axes[row, column]`` indexing below fails.
    fig, axes = plt.subplots(nrows=2, ncols=threshold_bins, sharex=True,
                             sharey=True, squeeze=False)

    for i in range(threshold_bins):
        axes[0, i].imshow(completness[:, :, i], vmin=0, vmax=1)
        im = axes[1, i].imshow(purity[:, :, i], vmin=0, vmax=1)
        axes[1, i].set_xlabel("thresh={:.2f}".format(threshold[i]))

    # A single colorbar, attached to the last purity panel; all panels
    # share the same [0, 1] colour scale.
    divider = make_axes_locatable(axes[1, threshold_bins - 1])
    cax = divider.append_axes('right', size='5%', pad=0.0)
    fig.colorbar(im, cax=cax, orientation='vertical')

    if nsim is not None:
        axes[0, 0].set_title("{} simulations".format(nsim))
    if flux is not None:
        # Fall back on the first column when there is no second one.
        flux_column = min(1, threshold_bins - 1)
        axes[0, flux_column].set_title("{}".format(flux))

    axes[0, 0].set_ylabel("completness")
    axes[1, 0].set_ylabel("purity")

    if savename is not None:
        fig.savefig(savename)
def add_axis(name, range, bins, unit=None, i_axe=3, log=False):
    """Define a dictionary for an additional wcs axis (linear or log).

    The axis is described by its edges: ``range`` is split into ``bins``
    pixels and the reference pixel (``CRPIX = 1``) sits at the center of
    the first one.

    Parameters
    ----------
    name : str
        the axis type, stored as ``CTYPEn`` (``'-LOG'`` is appended for a
        logarithmic axis)
    range : sequence of 2 numbers
        the lower and upper edges of the axis
    bins : int
        the number of pixels along the axis
    unit : optional
        stored as is in ``CUNITn``
    i_axe : int
        the index ``n`` of the axis in the header keywords
    log : bool
        build a logarithmic axis instead of a linear one

    Returns
    -------
    dict
        the ``CTYPEn``, ``CRPIXn``, ``CUNITn``, ``CRVALn`` and ``CDELTn``
        items

    Raises
    ------
    ValueError
        if ``log`` is set and one of the edges is not strictly positive
    """
    def key(prefix):
        return '{}{}'.format(prefix, i_axe)

    lower, upper = range[0], range[1]
    header = {key('CTYPE'): name,
              key('CRPIX'): 1,
              key('CUNIT'): unit}

    if log:
        # The logarithm of a non positive edge would silently propagate
        # inf/NaN into CRVAL and CDELT, so refuse it explicitly.
        if not (lower > 0 and upper > 0):
            raise ValueError(
                "a logarithmic axis needs a strictly positive range, "
                "got {!r}".format(tuple(range)))
        # Log scale (edges definition)
        log_step = (np.log(upper) - np.log(lower)) / bins
        center_start = np.exp(np.log(lower) + log_step / 2)
        header[key('CTYPE')] = name + '-LOG'
        header[key('CRVAL')] = center_start
        # The step is scaled by the reference value of the axis
        header[key('CDELT')] = log_step * center_start
    else:
        # Linear scale (edges definition)
        step = (upper - lower) / bins
        header[key('CRVAL')] = lower + step / 2
        header[key('CDELT')] = step

    return header
def completness_purity_wcs(shape, wcs, bins=30,
                           flux_range=(0, 1), flux_bins=10, flux_log=False,
                           threshold_range=(0, 1), threshold_bins=10,
                           threshold_log=False):
    """Build a wcs for the completness_purity function.

    The first two axes come from ``wcs`` sliced with a regular step, a
    THRESHOLD axis is added as axis 3 and a FLUX axis as axis 4.

    Parameters
    ----------
    shape : sequence of 2 int
        the shape of the map described by ``wcs``
    wcs : :class:`astropy.wcs.WCS`
        the wcs of the map; it must support slicing and ``to_header``
    bins : int
        the requested number of bins along each of the two map axes
    flux_range, flux_bins, flux_log
        edges, number of bins and log flag of the FLUX axis
    threshold_range, threshold_bins, threshold_log
        edges, number of bins and log flag of the THRESHOLD axis

    Returns
    -------
    shape_4D, wcs_4D
        the tuple ``(bins, bins, threshold_bins, flux_bins)`` and the
        corresponding 4 axes :class:`astropy.wcs.WCS`

    Notes
    -----
    ``flux_log`` and ``threshold_log`` are forwarded to ``add_axis``; they
    used to be ignored, the FLUX axis being always logarithmic and the
    THRESHOLD axis always linear.
    """
    slice_step = np.ceil(np.asarray(shape) / bins).astype(int)
    # NOTE(review): the returned shape assumes that slicing with this step
    # gives exactly ``bins`` pixels per axis - confirm for map sizes which
    # are not a multiple of ``bins``.
    first_slice = slice(0, shape[0], slice_step[0])
    second_slice = slice(0, shape[1], slice_step[1])

    # [WIP]: Shall we use a 4D WCS ? (ra/dec flux/threshold)
    # [WIP]: -TAB does not seem to be very easy to do with astropy
    header = wcs[first_slice, second_slice].to_header()
    header['WCSAXES'] = 4
    header.update(add_axis('THRESHOLD', threshold_range, threshold_bins,
                           i_axe=3, log=threshold_log))
    header.update(add_axis('FLUX', flux_range, flux_bins,
                           i_axe=4, log=flux_log))

    return (bins, bins, threshold_bins, flux_bins), WCS(header)
def completness_worker(shape, wcs, sources, fake_sources, min_threshold=2,
                       max_threshold=5):
    """Count the recovered and the injected fake sources in each pixel.

    Parameters
    ----------
    shape : tuple
        the shape of the resulting cube
    wcs : object providing ``wcs_world2pix``
        converts (ra, dec, SNR) into pixel coordinates of the cube
    sources : :class:`astropy.table.Table`
        the detected sources, read through their ``'SNR'`` column
    fake_sources : :class:`astropy.table.Table`
        the fake sources table; its masked ``'find_peak'`` column indexes
        ``sources`` and is masked where no detection was matched
    min_threshold : float
        the minimum SNR threshold requested
    max_threshold : float
        the maximum SNR threshold requested

    Returns
    -------
    _completness, _norm_comp
        the recovered fake sources, cumulated over the threshold axis, and
        the 2D number of fake sources
    """
    # A histogramdd over the raw SNR would need a threshold axis covering
    # ALL possible SNR; instead the SNR is capped into the last pixel.
    match = fake_sources['find_peak']
    recovered = ~match.mask

    # SNR of the detection matched to each fake source (masked if none)
    snr = np.ma.array(sources[match.filled(0)]['SNR'], mask=match.mask)
    # Only cumulative numbers matter: pile up everything above the
    # maximum threshold inside the upper pixel
    snr[snr > max_threshold] = max_threshold

    # TODO: Consider keeping all pixels information in fake_source and
    # source... This would imply to do only a simple wcs_threshold here...
    px, py, pz = wcs.wcs_world2pix(fake_sources['ra'], fake_sources['dec'],
                                   snr.filled(min_threshold), 0)

    # Number of fake sources recovered; unmatched ones carry a zero weight
    positions = np.asarray([px, py, pz]).T + 0.5
    cube_range = [(0, size) for size in shape]
    recovered_counts, _ = np.histogramdd(positions,
                                         bins=np.asarray(shape),
                                         range=cube_range,
                                         weights=recovered)

    # Reverse cumulative sum: all sources found at or above each threshold
    descending = recovered_counts[..., ::-1]
    _completness = np.cumsum(descending, axis=2)[..., ::-1]

    # Number of fake sources (independent of threshold)
    plane_shape = shape[0:2]
    plane_range = [(0, size) for size in plane_shape]
    _norm_comp, _, _ = np.histogram2d(px + 0.5, py + 0.5,
                                      bins=np.asarray(plane_shape),
                                      range=plane_range)

    return _completness, _norm_comp
def purity_worker(shape, wcs, sources, max_threshold=2):
    """Compute purity counts from the detected source catalog.

    Parameters
    ----------
    shape : tuple
        the shape of the resulting cube
    wcs : object providing ``wcs_world2pix``
        converts (ra, dec, SNR) into pixel coordinates of the cube
    sources : :class:`astropy.table.Table`
        the detected sources table; its masked ``'fake_sources'`` column
        is masked where a detection has no match. The table is left
        untouched.
    max_threshold : float
        the maximum threshold requested

    Returns
    -------
    _purity, _norm_pur
        the number of matched detections and the total number of
        detections, both cumulated over the threshold axis
    """
    # Work on a copy: clipping the column in place would silently modify
    # the SNR of the caller's table.
    sources_snr = sources['SNR'].copy()

    # As we are interested by the cumulative numbers, keep all inside the
    # upper pixel
    sources_snr[sources_snr > max_threshold] = max_threshold

    xx, yy, zz = wcs.wcs_world2pix(sources['ra'], sources['dec'],
                                   sources_snr, 0)
    positions = np.asarray([xx, yy, zz]).T + 0.5
    cube_range = [(0, size) for size in shape]

    # Number of detections matched to a fake source
    _purity, _ = np.histogramdd(positions,
                                bins=np.asarray(shape),
                                range=cube_range,
                                weights=~sources['fake_sources'].mask)
    # Reverse cumulative sum...
    _purity = np.cumsum(_purity[..., ::-1], axis=2)[..., ::-1]

    # Number of total detected sources at a given threshold
    _norm_pur, _ = np.histogramdd(positions,
                                  bins=np.asarray(shape),
                                  range=cube_range)
    _norm_pur = np.cumsum(_norm_pur[..., ::-1], axis=2)[..., ::-1]

    return _purity, _norm_pur
def completness_purity(sources, fake_sources, wcs=None,
                       shape=None):
    """Compute the completness and purity cubes for a given flux.

    Parameters
    ----------
    sources : :class:`astropy.table.Table`
        the detected sources
    fake_sources : :class:`astropy.table.Table`
        the fake sources
    wcs : wcs with 3 axes
        the first two axes locate the sources, the third one is the
        threshold axis
    shape : tuple of 3 int
        the shape of the resulting cubes

    Returns
    -------
    completness, purity
        two float arrays of shape ``shape``; pixels without any source to
        normalize with are NaN

    Raises
    ------
    ValueError
        if ``wcs`` or ``shape`` is missing
    """
    if wcs is None or shape is None:
        raise ValueError("completness_purity requires both 'wcs' and 'shape'")

    # Lower and upper edges ... take the center of the pixel for the upper edge
    min_threshold, max_threshold = wcs.sub([3]).all_pix2world(
        [-0.5, shape[2] - 1], 0)[0]

    # TODO: Change the find_peaks routine, or maybe just the
    # fit_2d_gaussian to be FAST ! (Maybe look into gcntrd.pro routine
    # or photutils.centroid.centroid_1dg maybe ?)
    _completness, _norm_comp = completness_worker(shape, wcs, sources,
                                                  fake_sources,
                                                  min_threshold,
                                                  max_threshold)
    _purity, _norm_pur = purity_worker(shape, wcs, sources, max_threshold)

    # The builtin float replaces np.float, an alias removed from recent
    # NumPy releases.
    recovered = np.asarray(_completness, dtype=float)
    injected = np.asarray(_norm_comp, dtype=float)
    matched = np.asarray(_purity, dtype=float)
    detected = np.asarray(_norm_pur, dtype=float)

    # norm can be 0, so to avoid warning on invalid values...
    with np.errstate(divide='ignore', invalid='ignore'):
        completness = recovered / injected[..., np.newaxis]
        purity = matched / detected

    # TODO: One should probably return completness AND norm if one want to
    # combine several fluxes
    return completness, purity
def Evaluate(flux_ds_fs_list):
    """Compute completness and purity for one simulated flux.

    Parameters
    ----------
    flux_ds_fs_list : sequence
        ``[flux, detected sources, fake sources]``; the flux must provide
        ``to_value``

    Returns
    -------
    fluxval, completness, purity
        the flux value in mJy and the two arrays returned by
        ``completness_purity``

    Notes
    -----
    Relies on the module level ``wcs_4D`` and ``shape_4D``.
    """
    _flux = flux_ds_fs_list[0]
    detected, injected = flux_ds_fs_list[1], flux_ds_fs_list[2]
    fluxval = _flux.to_value(u.mJy)

    print('{} data loaded'.format(_flux))

    # Only the first three axes are needed for a single flux
    cube_wcs = wcs_4D.sub([1, 2, 3])
    cube_shape = shape_4D[0:3]
    completness, purity = completness_purity(detected, injected,
                                             wcs=cube_wcs,
                                             shape=cube_shape)
    return fluxval, completness, purity
plt.close('all')
# _data = next(Jackknife(filenames, n=None))
# # TODO: Should in principle be the same, but is not... check....
# _ = plt.hist((data.data - _data.data)[~data.mask & ~_data.mask],
# bins=1000, range=[-0.1, 0.1], log=True)

# Create the flux and threshold axes...
DATA_DIR = "/home/peter/Dokumente/Uni/Paris/Stage/data/v_1"
data = NikaMap.read(Path(DATA_DIR) / '..' / 'map.fits')

bins = 9
flux_bins = 2
flux_range = [0.1, 10]
threshold_bins = 10
threshold_range = [3, 7.5]

# Does not really make sense... better define edges
fluxes = np.logspace(np.log10(flux_range[0]), np.log10(flux_range[1]),
                     flux_bins)*u.mJy
fluxes_edges = np.logspace(np.log10(flux_range[0]), np.log10(flux_range[1]),
                           flux_bins + 1)*u.mJy

# Does not really make sense... better define edges
threshold = np.linspace(threshold_range[0], threshold_range[1], threshold_bins)
threshold_edges = np.linspace(threshold_range[0], threshold_range[1],
                              threshold_bins+1)

shape_4D, wcs_4D = completness_purity_wcs(data.shape, data.wcs, bins=bins,
                                          flux_range=flux_range,
                                          flux_bins=flux_bins, flux_log=True,
                                          threshold_range=threshold_range,
                                          threshold_bins=threshold_bins)

# Testing the lower edges
wcs_threshold = wcs_4D.sub([3])
assert np.all(np.abs(wcs_threshold.all_pix2world(np.arange(threshold_bins+1)-0.5, 0) - threshold_edges) < 1e-15)
wcs_flux = wcs_4D.sub([4])
assert np.all(np.abs(wcs_flux.all_pix2world(np.arange(flux_bins+1)-0.5, 0) - fluxes_edges.value) < 1e-13)

# DEBUG :
# flux, nsources, within, wcs, shape, nsim, jk_filenames = 10*u.mJy, 8**2,
# (0, 1), wcs_4D.sub([1, 2, 3]), shape_4D[0:3], np.multiply(*shape_4D[0:2])
# * 100, filenames

# The FITS file stays open until the workers are done, as the tables are
# read from it; both the file and the pool are released on exit of their
# ``with`` block, even if the evaluation fails.
with fits.open('/home/peter/Dokumente/Uni/Paris/Stage/FirstSteps/'
               'Completness/combined_tables_long.fits') as hdul:
    nfluxes = hdul[0].header['NFLUXES']
    print('{} different fluxes found'.format(nfluxes))

    # One [flux, detected sources, fake sources] entry per simulated flux
    indata = []
    for isimu in range(nfluxes):
        _FLUX = u.Quantity(hdul[0].header['flux{}'.format(isimu)])
        _SOURCES = Table.read(hdul['DETECTED_SOURCES{}'
                                   .format(_FLUX)])
        _FAKE_SOURCES = Table.read(hdul['FAKE_SOURCES{}'
                                        .format(_FLUX)])
        indata.append([_FLUX, _SOURCES, _FAKE_SOURCES])

    # Evaluate reads the module level wcs_4D and shape_4D defined above
    with Pool(cpu_count()) as p:
        res = p.map(Evaluate, indata)

# Transpose the list of (flux, completness, purity) results
res = list(zip(*res))
FLUX = np.array(res[0])
COMPLETNESS = np.array(res[1])
PURITY = np.array(res[2])

# Sort everything by increasing flux
idxsort = np.argsort(FLUX)
FLUX = FLUX[idxsort]
COMPLETNESS = COMPLETNESS[idxsort]
PURITY = PURITY[idxsort]

# Index of the central bin of the map axes
midbin = bins // 2
print(midbin)
# sys.exit()
# %% PlotFigure
def PlotEvaluation(data, title='', flux=None, thresh=None, **kwargs):
    """Show a (flux, threshold) image and block until the window is closed.

    Parameters
    ----------
    data : 2D array
        the image to display; rows are labelled with ``flux`` and columns
        with ``thresh``
    title : str
        the title of the figure
    flux : sequence, optional
        the labels of the y ticks (none by default)
    thresh : sequence, optional
        the labels of the x ticks (none by default)
    **kwargs
        forwarded to ``plt.imshow``
    """
    # None replaces the former mutable ``[]`` defaults
    flux = [] if flux is None else flux
    thresh = [] if thresh is None else thresh

    tickfs = 20
    labelfs = 25

    plt.figure()
    plt.title(title, fontsize=30)
    plt.xlabel('Detection Threshold [SNR]', fontsize=labelfs)
    plt.ylabel('Flux [mJy]', fontsize=labelfs)
    # One tick per pixel, labelled with the corresponding value
    plt.xticks(np.arange(len(thresh)), thresh, fontsize=tickfs)
    plt.yticks(np.arange(len(flux)), flux, fontsize=tickfs)
    plt.imshow(data, origin='lower', **kwargs)

    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=tickfs)
    plt.show(block=True)
# Completness as a function of flux and threshold, taken at the central
# bin of the two map axes
PlotEvaluation(
    COMPLETNESS[:, midbin, midbin, :],
    title='Completness',
    flux=list(FLUX),
    thresh=threshold,
    cmap='bone',
)
# Candidate colormaps: bone, hot
plt.show(block=True)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment