from __future__ import absolute_import, division, print_function from pathlib import Path import os import numpy as np import matplotlib.pyplot as plt from multiprocessing import Pool, cpu_count from functools import partial from astropy import units as u from astropy.io import ascii from astropy.wcs import WCS from astropy.utils.console import ProgressBar from astropy.table import vstack from scipy.optimize import curve_fit from nikamap import NikaMap, Jackknife from nikamap.utils import pos_uniform from astropy.io import fits from astropy.table import Table, MaskedColumn import sys from mpl_toolkits.axes_grid1 import make_axes_locatable import dill as pickle from matplotlib.ticker import FormatStrFormatter from collections import OrderedDict import warnings from astropy.modeling.models import Gaussian1D from astropy.modeling.fitting import LevMarLSQFitter def find_nearest(array, values): x, y = np.meshgrid(array, values) ev = np.abs(x - y) return np.argmin(ev, axis=1) def add_axis(name, range, bins, unit=None, i_axe=3, log=False): """Define a dictionnary for additionnal wcs axes (linear or log)""" header = {'CTYPE{}'.format(i_axe): name, 'CRPIX{}'.format(i_axe): 1, 'CUNIT{}'.format(i_axe): unit} if log: # Log scale (edges definition) log_step = (np.log(range[1]) - np.log(range[0])) / bins center_start = np.exp(np.log(range[0]) + log_step / 2) header['CTYPE{}'.format(i_axe)] += '-LOG' header['CRVAL{}'.format(i_axe)] = center_start header['CDELT{}'.format(i_axe)] = log_step * center_start # Log scale (center definition) # log_step = (np.log(flux_range[1]) - np.log(flux_range[0])) / (bins-1) # center_start = range[0] else: # Linear scale (edges definition) step = (range[1] - range[0]) / bins header['CRVAL{}'.format(i_axe)] = range[0] + step / 2 header['CDELT{}'.format(i_axe)] = step # Linear scale (center definition) # step = (range[1] - range[0]) / (bins-1) return header def completness_purity_wcs(shape, wcs, bins=30, threshold_range=(0, 1), threshold_bins=10, threshold_log=False): """Build a wcs for the completness_purity function""" slice_step = np.ceil(np.asarray(shape) / bins).astype(int) celestial_slice = slice(0, shape[0], slice_step[0]), slice(0, shape[1], slice_step[1]) # [WIP]: Shall we use a 4D WCS ? (ra/dec flux/threshold) # [WIP]: -TAB does not seems to be very easy to do with astropy # Basicaly Working... . header = wcs[celestial_slice[0], celestial_slice[1]].to_header() header['WCSAXES'] = 3 header.update(add_axis('THRESHOLD', threshold_range, threshold_bins, i_axe=3)) return (bins, bins, threshold_bins), WCS(header) def completness_worker(shape, wcs, sources, fake_sources=None, min_threshold=2, max_threshold=5): """Compute completness from the fake source catalog Parameters ---------- shape : tuple the shape of the resulting image sources : :class:`astropy.table.Table` the detected sources fake_sources : :class:`astropy.table.Table` the fake sources table, with corresponding mask min_threshold : float the minimum SNR threshold requested max_threshold : float the maximum SNR threshold requested Returns ------- _completness, _norm_comp corresponding 2D :class:`numpy.ndarray` """ # If one wanted to used a histogramdd, one would need a threshold axis # covering ALL possible SNR, otherwise loose flux, or cap the thresholds... if fake_sources is not None: fake_snr = np.ma.array(sources[fake_sources['find_peak'].filled(0)]['SNR'], mask=fake_sources['find_peak'].mask) # As we are interested by the cumulative numbers, keep all inside the # upper pixel fake_snr[fake_snr > max_threshold] = max_threshold # print(fake_snr) # TODO: Consider keeping all pixels information in fake_source and source... # This would imply to do only a simple wcs_threshold here... xx, yy, zz = wcs.wcs_world2pix(fake_sources['ra'], fake_sources['dec'], fake_snr.filled(min_threshold), 0) # Number of fake sources recovered _completness, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5, bins=np.asarray(shape), range=list(zip([0]*len(shape), shape)), weights=~fake_sources['find_peak'].mask) # Reverse cumulative sum to get all sources at the given threshold _completness = np.cumsum(_completness[..., ::-1], axis=2)[..., ::-1] # Number of fake sources (independant of threshold) _norm_comp, _, _ = np.histogram2d(xx + 0.5, yy + 0.5, bins=np.asarray(shape[0:2]), range=list(zip([0]*2, shape[0:2]))) else: _completness, _norm_comp = None, None return _completness, _norm_comp def purity_worker(shape, wcs, sources, max_threshold=2): """Compute completness from the fake source catalog Parameters ---------- shape : tuple the shape of the resulting image sources : :class:`astropy.table.Table` the detected sources table, with corresponding match max_threshold : float the maximum threshold requested Returns ------- _completness, _norm_comp corresponding 2D :class:`numpy.ndarray` """ if sources is not None: sources_snr = sources['SNR'] # As we are interested by the cumulative numbers, keep all inside the # upper pixel sources_snr[sources_snr > max_threshold] = max_threshold xx, yy, zz = wcs.wcs_world2pix(sources['ra'], sources['dec'], sources_snr, 0) ''' print(zz.shape) plt.plot(zz) plt.show(block=True) sys.exit() ''' # Number of fake sources recovered if sources is not None and 'fake_sources' in sources.keys(): _purity, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5, bins=np.asarray(shape), range=list(zip([0]*len(shape), shape)), weights=~sources['fake_sources'].mask) # Revese cumulative sum... _purity = np.cumsum(_purity[..., ::-1], axis=2)[..., ::-1] else: _purity = None if sources is not None: # Number of total detected sources at a given threshold _norm_pur, _ = np.histogramdd(np.asarray([xx, yy, zz]).T + 0.5, bins=np.asarray(shape), range=list(zip([0]*len(shape), shape))) _norm_pur = np.cumsum(_norm_pur[..., ::-1], axis=2)[..., ::-1] else: _norm_pur = None return _purity, _norm_pur def completness_purity(sources, fake_sources, wcs=None, shape=None): """Compute completness map for a given flux""" # print(flux) # wcs_celestial = wcs.celestial # Lower and upper edges ... take the center of the pixel for the upper edge min_threshold, max_threshold = wcs.sub([3]).all_pix2world([-0.5, shape[2]-1], 0)[0] completness = np.zeros(shape, dtype=np.float) norm_comp = np.zeros(shape[0:2], dtype=np.float) purity = np.zeros(shape, dtype=np.float) norm_pur = np.zeros(shape, dtype=np.float) # %load_ext snakeviz # %snakeviz the following line.... all is spend in the find_peaks / # fit_2d_gaussian # TODO: Change the find_peaks routine, or maybe just the # fit_2d_gaussian to be FAST ! (Maybe look into gcntrd.pro routine # or photutils.centroid.centroid_1dg maybe ?) _completness, _norm_comp = completness_worker(shape, wcs, sources, fake_sources, min_threshold, max_threshold) # print(_completness) completness += _completness norm_comp += _norm_comp _purity, _norm_pur = purity_worker(shape, wcs, sources, max_threshold) purity += _purity norm_pur += _norm_pur # norm can be 0, so to avoid warning on invalid values... with np.errstate(divide='ignore', invalid='ignore'): completness /= norm_comp[..., np.newaxis] purity /= norm_pur # TODO: One should probably return completness AND norm if one want to # combine several fluxes return completness, purity def Plot_CompPur(completness, purity, threshold, nsim=None, savename=None, flux=None): threshold_bins = completness.shape[-1] fig, axes = plt.subplots(nrows=2, ncols=threshold_bins, sharex=True, sharey=True) for i in range(threshold_bins): axes[0, i].imshow(completness[:, :, i], vmin=0, vmax=1) im = axes[1, i].imshow(purity[:, :, i], vmin=0, vmax=1) axes[1, i].set_xlabel("thresh={:.2f}".format(threshold[i])) if i == (threshold_bins-1): # print('-----------') divider = make_axes_locatable(axes[1, i]) cax = divider.append_axes('right', size='5%', pad=0.0) fig = plt.gcf() fig.colorbar(im, cax=cax, orientation='vertical') if nsim is not None: axes[0, 0].set_title("{} simulations".format(nsim)) if flux is not None: axes[0, 1].set_title("{}".format(flux)) axes[0, 0].set_ylabel("completness") axes[1, 0].set_ylabel("purity") if savename is not None: plt.savefig(savename) def Evaluate(flux_ds_fs_list): # _flux = u.Quantity(pr.header['flux{}'.format(isimu)]) _flux = flux_ds_fs_list[0] sources = flux_ds_fs_list[1] fake_sources = flux_ds_fs_list[2] fluxval = _flux.to_value(u.mJy) ''' _flux.to_value(u.mJy) sources = Table.read(hdul['DETECTED_SOURCES{}'.format(_flux)]) fake_sources = Table.read(hdul['FAKE_SOURCES{}'.format(_flux)]) ''' print('{} data loaded'.format(_flux)) # sources = df['DETECTED_SOURCES{}'.format(flux)] # print(fake_sources) # sys.exit() completness, purity = completness_purity(sources, fake_sources, wcs=wcs_4D.sub([1, 2, 3]), shape=shape_4D[0:3]) return fluxval, completness, purity def Evaluate2(_flux, hdul): print(fits.HDUList(hdul).info()) # _flux = u.Quantity(pr.header['flux{}'.format(isimu)]) fluxval = _flux.to_value(u.mJy) sources = Table.read(hdul['DETECTED_SOURCES{}'.format(_flux)]) fake_sources = Table.read(hdul['FAKE_SOURCES{}'.format(_flux)]) print('{} data loaded'.format(_flux)) completness, purity = completness_purity(sources, fake_sources, wcs=wcs_4D.sub([1, 2, 3]), shape=shape_4D[0:3]) print(fluxval, completness, purity) return fluxval, completness, purity def find_nearest(array, values): x, y = np.meshgrid(array, values) ev = np.abs(x - y) return np.argmin(ev, axis=1) def PlotEvaluation(data, title='', flux=np.array([]), thresh=[], nfluxlabels=None, nthreshlabels=None, **kwargs): tickfs = 20 labelfs = 25 if nfluxlabels is not None: _label_flux = np.geomspace(flux[0], flux[-1], nfluxlabels) _f_idx = find_nearest(flux, _label_flux) flblpos, _flbl = _f_idx, flux[_f_idx] else: flblpos, _flbl = np.arange(len(flux)), flux if nthreshlabels is not None: _label_thresh = np.linspace(thresh[0], thresh[-1], nthreshlabels) _t_idx = find_nearest(thresh, _label_thresh) tlblpos, _tlbl = _t_idx, thresh[_t_idx] else: tlblpos, _tlbl = np.arange(len(thresh)), thresh flbl = [] for i in range(len(_flbl)): flbl.append('{:.1f}'.format(_flbl[i])) tlbl = [] for i in range(len(_tlbl)): tlbl.append('{:.1f}'.format(_tlbl[i])) # print(i, tlbl[i]) # print(np.array([tlblpos, tlbl]).T) plt.figure() plt.title(title, fontsize=30) plt.xlabel('Detection Threshold [SNR]', fontsize=labelfs) plt.ylabel('Flux [mJy]', fontsize=labelfs) plt.xticks(tlblpos, tlbl, fontsize=tickfs) # ax = plt.gca() plt.yticks(flblpos, flbl, fontsize=tickfs) # ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f')) plt.imshow(data, origin='lower', **kwargs) cbar = plt.colorbar() cbar.ax.tick_params(labelsize=tickfs) def PlotFixedThreshold(thresholds, bin, completness, allthresholds, flux, nfluxlabels=None, hlines=None): linestyles = ['-', '--', '-.', ':'] real_thresholds = find_nearest(allthresholds, thresholds) for i in range(len(real_thresholds)): _x = flux _y = completness[:, bin[0], bin[1], real_thresholds[i]] plt.plot(_x, _y, linestyle=linestyles[i], label='{:.1f}'.format(allthresholds[real_thresholds[i]])) if hlines is not None: for i, val in enumerate(hlines): plt.axhline(val, color='r') plt.title('Fixed Threshold', fontsize=30, y=1.02) plt.xlabel('Source Flux [mJy]', fontsize=25) plt.ylabel('Completness', fontsize=25) plt.yticks(fontsize=20) plt.xticks(fontsize=20) plt.subplots_adjust(left=0.12) ax = plt.gca() ax.set_xscale("log", nonposx='clip') # legend = plt.legend(fontsize=25, title='SNR', loc='lower right') legend = plt.legend(fontsize=25, title='SNR', loc='upper left', framealpha=1) plt.setp(legend.get_title(), fontsize=25) plt.show(block=True) def CombineMeasurements(sourceslist, fakesourceslist): fake_sources = Table() sources = Table() for _fake, _detected in zip(fakesourceslist, sourceslist): n_fake = len(fake_sources) n_detected = len(sources) if _detected is not None: _detected['ID'] = _detected['ID'] + n_detected if 'fake_sources' in _detected.keys(): _detected['fake_sources'] = _detected['fake_sources'] + n_fake sources = vstack([sources, _detected]) if _fake is not None: _fake['ID'] = _fake['ID'] + n_fake if 'find_peak' in _fake.keys(): _fake['find_peak'] = _fake['find_peak'] + n_detected fake_sources = vstack([fake_sources, _fake]) return sources, fake_sources def pos_in_mask(pos, mask=None, nsources=1, extra=None, retindex=False): """Check if pos is in mask, issue warning with less than nsources STOLEN FROM NIKAMAP AND MODIFIED TO RETURN INDEX Parameters ---------- pos : array_like (N, 2) pixel indexes (y, x) to be checked in mask mask : 2D boolean array_like corresponding mask nsources : int the requested number of sources Returns ------- :class:`numpy.ndarray` the pixel indexes within the mask """ pos = np.asarray(pos) inside = np.ones(pos.shape[0], dtype=bool) if mask is not None: pos_idx = np.floor(pos + 0.5).astype(int) inside = ~mask[pos_idx[:, 0], pos_idx[:, 1]] if retindex: return ~inside pos = pos[inside] if pos.shape[0] < nsources: warnings.warn("Only {} positions".format(pos.shape[0]), UserWarning) if extra is None: return pos else: return pos, extra[inside] def Flux1D(detectedflux, catflux, bins=10, fitter=LevMarLSQFitter): labelsize = 25 ticksize = 20 textfs = 25 boxprops = dict(facecolor='white', edgecolor='black', boxstyle='round, pad=.3', alpha=.5) fluxrel = (detectedflux / catflux).decompose() - 1. ''' print(fluxrel) print(fluxrel.shape) print(type(fluxrel)) print('creating histogram') ''' title = 'Flux Resolution' try: iter(catflux) # if it works maybe add fluxrange to title except TypeError: title += ' {:.2f}'.format(catflux) plt.figure() hist, edges, patches = plt.hist(fluxrel, bins=bins, density=True) plt.title(title, fontsize=30, y=1.02)# , loc='left') plt.xlabel(r'$\frac{F_{\mathrm{det}}}{F_{\mathrm{in}}}-1$', fontsize=labelsize) plt.ylabel('Normalized Counts', fontsize=labelsize) plt.xticks(fontsize=ticksize) plt.yticks(fontsize=ticksize) center = edges[:-1] + (edges[1:] - edges[:-1]) / 2 finit = Gaussian1D(mean=0) fitf = fitter() f = fitf(finit, center, hist) x = np.linspace(edges[0], edges[-1], 200) ax = plt.gca() plt.plot(x, f(x), color='r', linewidth=3) fittext = ('Entries: {}\n'.format(len(detectedflux)) + r'$x_0={:.2f}$'.format(f.mean.value) + '\n' + r'$\sigma={:4.2f}$'.format(f.stddev.value)) # ax.text(1-0.05, 0.93, fittext, ax.text(.5, 1-0.93, fittext, fontsize=textfs, bbox=boxprops, horizontalalignment='center', verticalalignment='bottom', ma='right', transform=ax.transAxes) return f def DeletePercentiles(array, minclip=0, maxclip=100): print(type(array)) percentiles = np.percentile(array, (minclip, maxclip)) print(percentiles) mask = np.array((array < percentiles[0]) | (array > percentiles[1]), dtype=bool) return array[~mask]