From faee374af4799eefa0c358399d5343d56e7596d0 Mon Sep 17 00:00:00 2001 From: LUSTIG Peter Date: Mon, 16 Apr 2018 12:37:42 +0200 Subject: [PATCH] improved plots --- evaluation.py | 131 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 120 insertions(+), 11 deletions(-) diff --git a/evaluation.py b/evaluation.py index 0164098..095fd47 100644 --- a/evaluation.py +++ b/evaluation.py @@ -22,6 +22,11 @@ from astropy.io import fits from astropy.table import Table, MaskedColumn import sys from mpl_toolkits.axes_grid1 import make_axes_locatable +import dill as pickle +from matplotlib.ticker import FormatStrFormatter +from collections import OrderedDict + + import os os.getcwd() @@ -278,6 +283,23 @@ def Evaluate(flux_ds_fs_list): return fluxval, completness, purity +def Evaluate2(_flux, hdul): + print(fits.HDUList(hdul).info()) + # _flux = u.Quantity(pr.header['flux{}'.format(isimu)]) + fluxval = _flux.to_value(u.mJy) + sources = Table.read(hdul['DETECTED_SOURCES{}'.format(_flux)]) + fake_sources = Table.read(hdul['FAKE_SOURCES{}'.format(_flux)]) + + print('{} data loaded'.format(_flux)) + + completness, purity = completness_purity(sources, fake_sources, + wcs=wcs_4D.sub([1, 2, 3]), + shape=shape_4D[0:3]) + + print(fluxval, completness, purity) + return fluxval, completness, purity + + plt.close('all') # _data = next(Jackknife(filenames, n=None)) @@ -332,15 +354,15 @@ assert np.all(np.abs(wcs_flux.all_pix2world(np.arange(flux_bins+1)-0.5, 0) - flu # This is a single run check for a single flux -hdul = fits.open('/home/peter/Dokumente/Uni/Paris/Stage/FirstSteps/' - 'Completness/combined_tables_long.fits') +hdul = fits.HDUList(fits.open('/home/peter/Dokumente/Uni/Paris/Stage/FirstSteps/' + 'Completness/combined_tables_long.fits')) +# sys.exit() nfluxes = hdul[0].header['NFLUXES'] print('{} different fluxes found'.format(nfluxes)) # Get fluxlist: -indata = [] - +indata = [] for isimu in range(nfluxes): _FLUX = u.Quantity(hdul[0].header['flux{}'.format(isimu)]) @@ -349,10 +371,20 @@ for isimu in range(nfluxes): _FAKE_SOURCES = Table.read(hdul['FAKE_SOURCES{}' .format(_FLUX)]) indata.append([_FLUX, _SOURCES, _FAKE_SOURCES]) - - +''' +flux = [] +for isimu in range(nfluxes): + flux.append(u.Quantity(hdul[0].header['flux{}'.format(isimu)])) +''' # helpfunc = partial(Evaluate, **{'hdul': hdul}) p = Pool(cpu_count()) +# print(hdul) +#hdul = [] +# f = partial(Evaluate2, hdul=hdul) +# print(flux) +# res = p.map(f, flux) +# print(flux) +# sys.exit() res = p.map(Evaluate, indata) res = list(zip(*res)) FLUX = np.array(res[0]) @@ -365,20 +397,58 @@ COMPLETNESS = COMPLETNESS[idxsort] PURITY = PURITY[idxsort] midbin = int(bins/2) + print(midbin) # sys.exit() # %% PlotFigure -def PlotEvaluation(data, title='', flux=[], thresh=[], **kwargs): +def find_nearest(array, values): + x, y = np.meshgrid(array, values) + ev = np.abs(x - y) + return np.argmin(ev, axis=1) + + +def PlotEvaluation(data, title='', flux=np.array([]), thresh=[], + nfluxlabels=None, nthreshlabels=None, **kwargs): tickfs = 20 labelfs = 25 + + if nfluxlabels is not None: + _label_flux = np.geomspace(flux[0], flux[-1], nfluxlabels) + _f_idx = find_nearest(flux, _label_flux) + flblpos, _flbl = _f_idx, flux[_f_idx] + else: + flblpos, _flbl = np.arange(len(flux)), flux + + if nthreshlabels is not None: + _label_thresh = np.linspace(thresh[0], thresh[-1], nthreshlabels) + _t_idx = find_nearest(thresh, _label_thresh) + print(_t_idx) + tlblpos, _tlbl = _t_idx, thresh[_t_idx] + else: + tlblpos, _tlbl = np.arange(len(thresh)), thresh + + flbl = [] + for i in range(len(_flbl)): + flbl.append('{:.1f}'.format(_flbl[i])) + + tlbl = [] + for i in range(len(_tlbl)): + tlbl.append('{:.1f}'.format(_tlbl[i])) + # print(i, tlbl[i]) + + # print(np.array([tlblpos, tlbl]).T) plt.figure() plt.title(title, fontsize=30) plt.xlabel('Detection Threshold [SNR]', fontsize=labelfs) plt.ylabel('Flux [mJy]', fontsize=labelfs) - plt.xticks(np.arange(len(thresh)), thresh, fontsize=tickfs) - plt.yticks(np.arange(len(flux)), flux, fontsize=tickfs) + + plt.xticks(tlblpos, tlbl, fontsize=tickfs) + # ax = plt.gca() + plt.yticks(flblpos, flbl, fontsize=tickfs) + # ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f')) + plt.imshow(data, origin='lower', **kwargs) cbar = plt.colorbar() cbar.ax.tick_params(labelsize=tickfs) @@ -386,6 +456,45 @@ def PlotEvaluation(data, title='', flux=[], thresh=[], **kwargs): PlotEvaluation(COMPLETNESS[:, midbin, midbin, :], title='Completness', - flux=list(FLUX), thresh=threshold, cmap='bone') + flux=np.array(FLUX), thresh=threshold, cmap='bone', + nfluxlabels=10, nthreshlabels=5, aspect='auto') + + +# %% 2D plot + + +def PlotFixedThreshold(thresholds, bin, completness, allthresholds, flux, + nfluxlabels=None, hlines=None): + + linestyles = ['-', '--', '-.', ':'] + real_thresholds = find_nearest(allthresholds, thresholds) + for i in range(len(real_thresholds)): + _x = flux + _y = completness[:, bin[0], bin[1], real_thresholds[i]] + plt.plot(_x, _y, linestyle=linestyles[i], + label='{:.1f}'.format(allthresholds[real_thresholds[i]])) + + if hlines is not None: + for i, val in enumerate(hlines): + plt.axhline(val, color='r') + plt.title('Fixed Threshold', fontsize=30, y=1.02) + plt.xlabel('Source Flux [mJy]', fontsize=25) + plt.ylabel('Completness', fontsize=25) + plt.yticks(fontsize=20) + plt.xticks(fontsize=20) + plt.subplots_adjust(left=0.12) + ax = plt.gca() + ax.set_xscale("log", nonposx='clip') + # legend = plt.legend(fontsize=25, title='SNR', loc='lower right') + legend = plt.legend(fontsize=25, title='SNR', loc='upper left', + framealpha=1) + plt.setp(legend.get_title(), fontsize=25) + plt.show(block=True) + # cmap bone hot -plt.show(block=True) +# print(np.array(FLUX)) + + +PlotFixedThreshold(np.array((3, 5, 7)), (midbin, midbin), COMPLETNESS, + threshold, np.array(FLUX), nfluxlabels=None, + hlines=[.5, .8, .9]) -- GitLab