Commit 8550ae20 authored by LUSTIG Peter

added utils for evaluation to branch

parent b3fa9b4b
from __future__ import absolute_import, division, print_function
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
from multiprocessing import Pool, cpu_count
from functools import partial
from astropy import units as u
from astropy.io import ascii
from astropy.wcs import WCS
from astropy.utils.console import ProgressBar
from astropy.table import vstack
from scipy.optimize import curve_fit
from nikamap import NikaMap, Jackknife
from nikamap.utils import pos_uniform
from astropy.io import fits
from astropy.table import Table, MaskedColumn
import sys
from mpl_toolkits.axes_grid1 import make_axes_locatable
import dill as pickle
from matplotlib.ticker import FormatStrFormatter
from collections import OrderedDict
def find_nearest(array, values):
    """Return, for each entry of `values`, the index of the closest entry in `array`."""
    x, y = np.meshgrid(array, values)
    ev = np.abs(x - y)
    return np.argmin(ev, axis=1)
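# A minimal usage sketch (hypothetical values, not part of the pipeline):
# find_nearest maps requested values onto the indices of the closest entries
# of a 1D grid, which is how the plotting helpers below pick a reduced set of
# tick positions.
#
#     grid = np.array([1.0, 2.0, 4.0, 8.0])
#     find_nearest(grid, [1.2, 5.0])    # -> array([0, 2])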
def add_axis(name, range, bins, unit=None, i_axe=3, log=False):
"""Define a dictionnary for additionnal wcs axes (linear or log)"""
header = {'CTYPE{}'.format(i_axe): name,
'CRPIX{}'.format(i_axe): 1,
'CUNIT{}'.format(i_axe): unit}
if log:
# Log scale (edges definition)
log_step = (np.log(range[1]) - np.log(range[0])) / bins
center_start = np.exp(np.log(range[0]) + log_step / 2)
header['CTYPE{}'.format(i_axe)] += '-LOG'
header['CRVAL{}'.format(i_axe)] = center_start
header['CDELT{}'.format(i_axe)] = log_step * center_start
# Log scale (center definition)
# log_step = (np.log(flux_range[1]) - np.log(flux_range[0])) / (bins-1)
# center_start = range[0]
else:
# Linear scale (edges definition)
step = (range[1] - range[0]) / bins
header['CRVAL{}'.format(i_axe)] = range[0] + step / 2
header['CDELT{}'.format(i_axe)] = step
# Linear scale (center definition)
# step = (range[1] - range[0]) / (bins-1)
return header
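# A small sanity-check sketch (hypothetical numbers, mirrors the asserts in
# the __main__ block): a log axis built by add_axis places pixel -0.5 on the
# lower edge of the requested range, e.g. for range=(0.1, 10) and bins=2 the
# bin edges are 0.1, 1.0, 10 and the first bin centre (CRVAL) is sqrt(0.1 * 1.0).
#
#     hdr = fits.Header()
#     hdr['WCSAXES'] = 1
#     hdr.update(add_axis('FLUX', (0.1, 10), 2, unit='mJy', i_axe=1, log=True))
#     WCS(hdr).all_pix2world([-0.5], 0)    # -> approximately [0.1], the lower edge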
def completness_purity_wcs(shape, wcs, bins=30,
                           flux_range=(0, 1), flux_bins=10, flux_log=False,
                           threshold_range=(0, 1), threshold_bins=10,
                           threshold_log=False):
    """Build a (ra/dec/threshold/flux) wcs for the completness_purity function"""
    slice_step = np.ceil(np.asarray(shape) / bins).astype(int)
    celestial_slice = slice(0, shape[0], slice_step[0]), slice(0, shape[1],
                                                               slice_step[1])
    # [WIP]: Shall we use a 4D WCS ? (ra/dec threshold/flux)
    # [WIP]: -TAB does not seem to be very easy to do with astropy
    # Basically working...
    header = wcs[celestial_slice[0], celestial_slice[1]].to_header()
    header['WCSAXES'] = 4
    header.update(add_axis('THRESHOLD', threshold_range, threshold_bins,
                           i_axe=3, log=threshold_log))
    header.update(add_axis('FLUX', flux_range, flux_bins,
                           i_axe=4, log=flux_log))
    return (bins, bins, threshold_bins, flux_bins), WCS(header)
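# A usage sketch (dummy (300, 300) shape; `celestial_wcs` stands for any 2D
# celestial WCS; the other values mirror the __main__ block below): the
# returned shape is (bins, bins, threshold_bins, flux_bins) and individual
# axes can be selected with .sub(), as done in completness_purity().
#
#     shape_4D, wcs_4D = completness_purity_wcs((300, 300), celestial_wcs,
#                                               bins=19,
#                                               flux_range=(0.1, 10),
#                                               flux_bins=2, flux_log=True,
#                                               threshold_range=(3, 7.5),
#                                               threshold_bins=5)
#     wcs_4D.sub([3])    # threshold axis only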
def completness_worker(shape, wcs, sources, fake_sources, min_threshold=2,
max_threshold=5):
"""Compute completness from the fake source catalog
Parameters
----------
shape : tuple
the shape of the resulting image
sources : :class:`astropy.table.Table`
the detected sources
fake_sources : :class:`astropy.table.Table`
the fake sources table, with corresponding mask
min_threshold : float
the minimum SNR threshold requested
max_threshold : float
the maximum SNR threshold requested
Returns
-------
_completness, _norm_comp
corresponding 2D :class:`numpy.ndarray`
"""
    # If one wanted to use a histogramdd, one would need a threshold axis
    # covering ALL possible SNR, otherwise lose flux, or cap the thresholds...
fake_snr = np.ma.array(sources[fake_sources['find_peak'].filled(0)]['SNR'],
mask=fake_sources['find_peak'].mask)
    # As we are interested in the cumulative numbers, keep everything inside
    # the upper pixel
fake_snr[fake_snr > max_threshold] = max_threshold
# print(fake_snr)
# TODO: Consider keeping all pixels information in fake_source and source...
# This would imply to do only a simple wcs_threshold here...
xx, yy, zz = wcs.wcs_world2pix(fake_sources['ra'], fake_sources['dec'],
fake_snr.filled(min_threshold), 0)
# Number of fake sources recovered
_completness, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5,
bins=np.asarray(shape),
range=list(zip([0]*len(shape), shape)),
weights=~fake_sources['find_peak'].mask)
# Reverse cumulative sum to get all sources at the given threshold
_completness = np.cumsum(_completness[..., ::-1], axis=2)[..., ::-1]
    # Number of fake sources (independent of threshold)
_norm_comp, _, _ = np.histogram2d(xx + 0.5, yy + 0.5,
bins=np.asarray(shape[0:2]),
range=list(zip([0]*2, shape[0:2])))
return _completness, _norm_comp
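# A small illustration (made-up counts) of the reverse cumulative sum used
# above: starting from per-bin counts along the threshold axis, every bin
# ends up holding the number of sources at or above its threshold.
#
#     counts = np.array([1, 0, 2, 1])          # counts per threshold bin
#     np.cumsum(counts[::-1])[::-1]            # -> array([4, 3, 3, 1])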
def purity_worker(shape, wcs, sources, max_threshold=2):
"""Compute completness from the fake source catalog
Parameters
----------
shape : tuple
the shape of the resulting image
sources : :class:`astropy.table.Table`
the detected sources table, with corresponding match
max_threshold : float
the maximum threshold requested
Returns
-------
_completness, _norm_comp
corresponding 2D :class:`numpy.ndarray`
"""
    # Work on a copy so the SNR column of the input table is not clipped in place
    sources_snr = np.array(sources['SNR'])
    # As we are interested in the cumulative numbers, keep everything inside
    # the upper pixel
    sources_snr[sources_snr > max_threshold] = max_threshold
xx, yy, zz = wcs.wcs_world2pix(sources['ra'], sources['dec'],
sources_snr, 0)
# Number of fake sources recovered
_purity, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5,
bins=np.asarray(shape),
range=list(zip([0]*len(shape), shape)),
weights=~sources['fake_sources'].mask)
    # Reverse cumulative sum...
_purity = np.cumsum(_purity[..., ::-1], axis=2)[..., ::-1]
# Number of total detected sources at a given threshold
_norm_pur, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5,
bins=np.asarray(shape),
range=list(zip([0]*len(shape), shape)))
_norm_pur = np.cumsum(_norm_pur[..., ::-1], axis=2)[..., ::-1]
return _purity, _norm_pur
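# A small sketch (made-up mask) of the weighting trick used above: detections
# whose 'fake_sources' entry is masked (no matching injected source) get
# weight 0 in the matched histogram but still count in the normalization
# histogram, so purity becomes matched / all detections per bin.
#
#     spurious = np.array([False, True, False])    # True = unmatched detection
#     (~spurious).astype(float)                    # -> array([1., 0., 1.])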
def completness_purity(sources, fake_sources, wcs=None,
shape=None):
"""Compute completness map for a given flux"""
# print(flux)
# wcs_celestial = wcs.celestial
# Lower and upper edges ... take the center of the pixel for the upper edge
min_threshold, max_threshold = wcs.sub([3]).all_pix2world([-0.5, shape[2]-1], 0)[0]
    completness = np.zeros(shape, dtype=float)
    norm_comp = np.zeros(shape[0:2], dtype=float)
    purity = np.zeros(shape, dtype=float)
    norm_pur = np.zeros(shape, dtype=float)
# %load_ext snakeviz
    # %snakeviz the following line.... all is spent in the find_peaks /
# fit_2d_gaussian
# TODO: Change the find_peaks routine, or maybe just the
# fit_2d_gaussian to be FAST ! (Maybe look into gcntrd.pro routine
# or photutils.centroid.centroid_1dg maybe ?)
_completness, _norm_comp = completness_worker(shape, wcs, sources,
fake_sources,
min_threshold,
max_threshold)
# print(_completness)
completness += _completness
norm_comp += _norm_comp
_purity, _norm_pur = purity_worker(shape, wcs, sources, max_threshold)
purity += _purity
norm_pur += _norm_pur
# norm can be 0, so to avoid warning on invalid values...
with np.errstate(divide='ignore', invalid='ignore'):
completness /= norm_comp[..., np.newaxis]
purity /= norm_pur
    # TODO: One should probably return completness AND norm if one wants to
    # combine several fluxes
return completness, purity
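# A broadcasting note (toy shapes): norm_comp is 2D (ra/dec) while completness
# is 3D (ra/dec/threshold), so the [..., np.newaxis] above spreads the
# per-pixel normalization over every threshold bin.
#
#     cube = np.ones((2, 2, 5))
#     norm = np.full((2, 2), 4.0)
#     (cube / norm[..., np.newaxis]).shape    # -> (2, 2, 5), each value 0.25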
def Plot_CompPur(completness, purity, threshold, nsim=None, savename=None,
                 flux=None):
    """Plot the completness and purity maps side by side, one column per threshold bin."""
    threshold_bins = completness.shape[-1]
fig, axes = plt.subplots(nrows=2, ncols=threshold_bins, sharex=True,
sharey=True)
for i in range(threshold_bins):
axes[0, i].imshow(completness[:, :, i], vmin=0, vmax=1)
im = axes[1, i].imshow(purity[:, :, i], vmin=0, vmax=1)
axes[1, i].set_xlabel("thresh={:.2f}".format(threshold[i]))
if i == (threshold_bins-1):
# print('-----------')
divider = make_axes_locatable(axes[1, i])
cax = divider.append_axes('right', size='5%', pad=0.0)
fig = plt.gcf()
fig.colorbar(im, cax=cax, orientation='vertical')
if nsim is not None:
axes[0, 0].set_title("{} simulations".format(nsim))
if flux is not None:
axes[0, 1].set_title("{}".format(flux))
axes[0, 0].set_ylabel("completness")
axes[1, 0].set_ylabel("purity")
if savename is not None:
plt.savefig(savename)
def Evaluate(flux_ds_fs_list):
    """Compute completness and purity for one (flux, sources, fake_sources) entry."""
    # _flux = u.Quantity(pr.header['flux{}'.format(isimu)])
    _flux = flux_ds_fs_list[0]
    sources = flux_ds_fs_list[1]
    fake_sources = flux_ds_fs_list[2]
    fluxval = _flux.to_value(u.mJy)
'''
_flux.to_value(u.mJy)
sources = Table.read(hdul['DETECTED_SOURCES{}'.format(_flux)])
fake_sources = Table.read(hdul['FAKE_SOURCES{}'.format(_flux)])
'''
print('{} data loaded'.format(_flux))
# sources = df['DETECTED_SOURCES{}'.format(flux)]
# print(fake_sources)
# sys.exit()
completness, purity = completness_purity(sources, fake_sources,
wcs=wcs_4D.sub([1, 2, 3]),
shape=shape_4D[0:3])
return fluxval, completness, purity
def Evaluate2(_flux, hdul):
    """Same as Evaluate, but read the source tables directly from the open HDUList."""
    print(fits.HDUList(hdul).info())
    # _flux = u.Quantity(pr.header['flux{}'.format(isimu)])
    fluxval = _flux.to_value(u.mJy)
sources = Table.read(hdul['DETECTED_SOURCES{}'.format(_flux)])
fake_sources = Table.read(hdul['FAKE_SOURCES{}'.format(_flux)])
print('{} data loaded'.format(_flux))
completness, purity = completness_purity(sources, fake_sources,
wcs=wcs_4D.sub([1, 2, 3]),
shape=shape_4D[0:3])
print(fluxval, completness, purity)
return fluxval, completness, purity
def PlotEvaluation(data, title='', flux=np.array([]), thresh=[],
                   nfluxlabels=None, nthreshlabels=None, **kwargs):
    """Show completness or purity as a flux vs. threshold image with a reduced set of tick labels."""
    tickfs = 20
    labelfs = 25
if nfluxlabels is not None:
_label_flux = np.geomspace(flux[0], flux[-1], nfluxlabels)
_f_idx = find_nearest(flux, _label_flux)
flblpos, _flbl = _f_idx, flux[_f_idx]
else:
flblpos, _flbl = np.arange(len(flux)), flux
if nthreshlabels is not None:
_label_thresh = np.linspace(thresh[0], thresh[-1], nthreshlabels)
_t_idx = find_nearest(thresh, _label_thresh)
tlblpos, _tlbl = _t_idx, thresh[_t_idx]
else:
tlblpos, _tlbl = np.arange(len(thresh)), thresh
flbl = []
for i in range(len(_flbl)):
flbl.append('{:.1f}'.format(_flbl[i]))
tlbl = []
for i in range(len(_tlbl)):
tlbl.append('{:.1f}'.format(_tlbl[i]))
# print(i, tlbl[i])
# print(np.array([tlblpos, tlbl]).T)
plt.figure()
plt.title(title, fontsize=30)
plt.xlabel('Detection Threshold [SNR]', fontsize=labelfs)
plt.ylabel('Flux [mJy]', fontsize=labelfs)
plt.xticks(tlblpos, tlbl, fontsize=tickfs)
# ax = plt.gca()
plt.yticks(flblpos, flbl, fontsize=tickfs)
# ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
plt.imshow(data, origin='lower', **kwargs)
cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tickfs)
def PlotFixedThreshold(thresholds, bin, completness, allthresholds, flux,
                       nfluxlabels=None, hlines=None):
    """Plot completness versus flux at a fixed map bin for several detection thresholds."""
    linestyles = ['-', '--', '-.', ':']
real_thresholds = find_nearest(allthresholds, thresholds)
for i in range(len(real_thresholds)):
_x = flux
_y = completness[:, bin[0], bin[1], real_thresholds[i]]
plt.plot(_x, _y, linestyle=linestyles[i],
label='{:.1f}'.format(allthresholds[real_thresholds[i]]))
if hlines is not None:
for i, val in enumerate(hlines):
plt.axhline(val, color='r')
plt.title('Fixed Threshold', fontsize=30, y=1.02)
plt.xlabel('Source Flux [mJy]', fontsize=25)
plt.ylabel('Completness', fontsize=25)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)
plt.subplots_adjust(left=0.12)
ax = plt.gca()
    ax.set_xscale("log", nonpositive='clip')
# legend = plt.legend(fontsize=25, title='SNR', loc='lower right')
legend = plt.legend(fontsize=25, title='SNR', loc='upper left',
framealpha=1)
plt.setp(legend.get_title(), fontsize=25)
plt.show(block=True)
# plt.close('all')
# _data = next(Jackknife(filenames, n=None))
# # TODO: Should in principle be the same, but is not... check....
# _ = plt.hist((data.data - _data.data)[~data.mask & ~_data.mask],
# bins=1000, range=[-0.1, 0.1], log=True)
# Create the flux and threshold axes...
if __name__ == "__main__":
DATA_DIR = "/home/peter/Dokumente/Uni/Paris/Stage/data/v_1"
data = NikaMap.read(Path(DATA_DIR) / '..' / 'map.fits')
bins = 19
flux_bins = 2
flux_range = [0.1, 10]
threshold_bins = 5
threshold_range = [2, 5]
threshold_range = [3, 7.5]
# Does not really make sense... better define edges
fluxes = np.logspace(np.log10(flux_range[0]), np.log10(flux_range[1]),
flux_bins)*u.mJy
fluxes_edges = np.logspace(np.log10(flux_range[0]), np.log10(flux_range[1]),
flux_bins + 1)*u.mJy
# Does not really make sense... better define edges
threshold = np.linspace(threshold_range[0], threshold_range[1], threshold_bins)
threshold_edges = np.linspace(threshold_range[0], threshold_range[1],
threshold_bins+1)
shape_4D, wcs_4D = completness_purity_wcs(data.shape, data.wcs, bins=bins,
flux_range=flux_range,
flux_bins=flux_bins, flux_log=True,
threshold_range=threshold_range,
threshold_bins=threshold_bins)
# Testing the lower edges
wcs_threshold = wcs_4D.sub([3])
assert np.all(np.abs(wcs_threshold.all_pix2world(np.arange(threshold_bins+1)-0.5, 0) - threshold_edges) < 1e-15)
wcs_flux = wcs_4D.sub([4])
assert np.all(np.abs(wcs_flux.all_pix2world(np.arange(flux_bins+1)-0.5, 0) - fluxes_edges.value) < 1e-13)
# DEBUG :
# flux, nsources, within, wcs, shape, nsim, jk_filenames = 10*u.mJy, 8**2,
# (0, 1), wcs_4D.sub([1, 2, 3]), shape_4D[0:3], np.multiply(*shape_4D[0:2])
# * 100, filenames
# This is a single run check for a single flux
hdul = fits.HDUList(fits.open('/home/peter/Dokumente/Uni/Paris/Stage/FirstSteps/'
'Completness/combined_tables_long.fits'))
# sys.exit()
nfluxes = hdul[0].header['NFLUXES']
print('{} different fluxes found'.format(nfluxes))
# Get fluxlist:
indata = []
for isimu in range(nfluxes):
_FLUX = u.Quantity(hdul[0].header['flux{}'.format(isimu)])
_SOURCES = Table.read(hdul['DETECTED_SOURCES{}'
.format(_FLUX)])
_FAKE_SOURCES = Table.read(hdul['FAKE_SOURCES{}'
.format(_FLUX)])
indata.append([_FLUX, _SOURCES, _FAKE_SOURCES])
'''
flux = []
for isimu in range(nfluxes):
flux.append(u.Quantity(hdul[0].header['flux{}'.format(isimu)]))
'''
# helpfunc = partial(Evaluate, **{'hdul': hdul})
p = Pool(cpu_count())
res = p.map(Evaluate, indata)
res = list(zip(*res))
FLUX = np.array(res[0])
COMPLETNESS = np.array(res[1])
PURITY = np.array(res[2])
idxsort = np.argsort(FLUX)
FLUX = FLUX[idxsort]
COMPLETNESS = COMPLETNESS[idxsort]
PURITY = PURITY[idxsort]
midbin = int(bins/2)
# sys.exit()
# %% PlotFigure
PlotEvaluation(COMPLETNESS[:, midbin, midbin, :], title='Completness',
flux=np.array(FLUX), thresh=threshold, cmap='bone',
nfluxlabels=10, nthreshlabels=5, aspect='auto')
PlotEvaluation(PURITY[:, midbin, midbin, :], title='Purity',
flux=np.array(FLUX), thresh=threshold, cmap='bone',
nfluxlabels=10, nthreshlabels=5, aspect='auto')
plt.show(block=True)
# %% 2D plot
PlotFixedThreshold(np.array((3, 5, 7)), (midbin, midbin), COMPLETNESS,
threshold, np.array(FLUX), nfluxlabels=None,
hlines=[.5, .8, .9])
# %% Plot map
# Plot_CompPur(completness, purity, threshold, nsim=None, savename=None,
# flux=None):
plotidx = np.array([24, 30, 40, 49])
# plotthreshidx = np.array([0, 4, 8])
plotcomp = COMPLETNESS[plotidx]
plotpur = PURITY[plotidx]
# plotthresh = threshold[plotthreshidx]
plotflux = np.array(FLUX[plotidx])
for i in range(len(plotidx)):
_comp = plotcomp[i]
_pur = plotpur[i]
Plot_CompPur(_comp, _pur, threshold, flux=plotflux[i])
plt.show(block=True)