From ab75b24c052b7c88e6f5b55fd6464b602984a30b Mon Sep 17 00:00:00 2001 From: LUSTIG Peter Date: Mon, 16 Apr 2018 18:03:02 +0200 Subject: [PATCH] some changes --- evaluation.py | 78 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 17 deletions(-) diff --git a/evaluation.py b/evaluation.py index e0b3eea..2878fb1 100644 --- a/evaluation.py +++ b/evaluation.py @@ -26,6 +26,7 @@ import dill as pickle from matplotlib.ticker import FormatStrFormatter from collections import OrderedDict from utils import completness_purity_wcs, completness_worker, purity_worker +from utils import find_nearest @@ -70,6 +71,8 @@ class PCEvaluation: bins=self.bins, threshold_range=threshold_range, threshold_bins=threshold_bins) + #print(self.wcs_3D) + #sys.exit() print('wcs created') # Testing the lower edges @@ -77,6 +80,8 @@ class PCEvaluation: assert np.all(np.abs(wcs_threshold.all_pix2world( np.arange(threshold_bins+1)-0.5, 0) - threshold_edges) < 1e-15) + self.completness, self.purity, self.hitmap = self.GetCP() + print(self.completness.shape) def GetCP(self, sources=None, fake_sources=None, wcs=None, shape=None, pool=None): @@ -102,7 +107,7 @@ class PCEvaluation: comp.append(tmpres[0]) pur.append(tmpres[1]) hitm.append(tmpres[2]) - return comp, pur, hitm + return np.array(comp), np.array(pur), np.array(hitm) def completness_purity(self, sources, fake_sources, wcs=None, shape=None): @@ -136,6 +141,52 @@ class PCEvaluation: return completness, purity, norm_comp + def GetBinResults(self, ix, iy): + #print(self.completness) + return (self.completness[:, iy, ix, :], + self.purity[:, iy, ix, :], + self.hitmap[:, iy, ix]) + + + def PlotBin(self, data, title='', flux=np.array([]), thresh=[], + nfluxlabels=None, nthreshlabels=None, **kwargs): + tickfs = 20 + labelfs = 25 + + if nfluxlabels is not None: + _label_flux = np.geomspace(flux[0], flux[-1], nfluxlabels) + _f_idx = find_nearest(flux, _label_flux) + flblpos, _flbl = _f_idx, flux[_f_idx] + else: + flblpos, _flbl = np.arange(len(flux)), flux + + if nthreshlabels is not None: + _label_thresh = np.linspace(thresh[0], thresh[-1], nthreshlabels) + _t_idx = find_nearest(thresh, _label_thresh) + print(_t_idx) + tlblpos, _tlbl = _t_idx, thresh[_t_idx] + else: + tlblpos, _tlbl = np.arange(len(thresh)), thresh + + flbl = [] + for i in range(len(_flbl)): + flbl.append('{:.1f}'.format(_flbl[i])) + + tlbl = [] + for i in range(len(_tlbl)): + tlbl.append('{:.1f}'.format(_tlbl[i])) + + plt.figure() + plt.title(title, fontsize=30) + plt.xlabel('Detection Threshold [SNR]', fontsize=labelfs) + plt.ylabel('Flux [mJy]', fontsize=labelfs) + plt.xticks(tlblpos, tlbl, fontsize=tickfs) + plt.yticks(flblpos, flbl, fontsize=tickfs) + plt.imshow(data, origin='lower', **kwargs) + cbar = plt.colorbar() + cbar.ax.tick_params(labelsize=tickfs) + + plt.close('all') DATA_DIR = "/home/peter/Dokumente/Uni/Paris/Stage/data/v_1" @@ -156,7 +207,7 @@ SOURCE = [] FSOURCE = [] # for isimu in range(nfluxes): -for isimu in range(3): +for isimu in range(6): FLUX.append(u.Quantity(hdul[0].header['flux{}'.format(isimu)])) SOURCE.append(Table.read(hdul['DETECTED_SOURCES{}' @@ -168,20 +219,19 @@ xx = PCEvaluation(SOURCE, FSOURCE, sh, wcs, u.Quantity(FLUX)) print('done') # %% testfunctions -res = xx.GetCP() +# p = Pool(2) +# res = xx.GetCP() +dd = xx.GetBinResults(9, 9) +dd = xx.completness +print('comp shape', dd.shape) +xx.PlotBin(xx.completness[:, 9, 9, :]) +plt.show(block=True) +sys.exit() # %% # midbin = int(bins / 2) - - -def find_nearest(array, values): - x, y = np.meshgrid(array, values) - ev = np.abs(x - y) - return np.argmin(ev, axis=1) - - def PlotEvaluation(data, title='', flux=np.array([]), thresh=[], nfluxlabels=None, nthreshlabels=None, **kwargs): tickfs = 20 @@ -209,19 +259,13 @@ def PlotEvaluation(data, title='', flux=np.array([]), thresh=[], tlbl = [] for i in range(len(_tlbl)): tlbl.append('{:.1f}'.format(_tlbl[i])) - # print(i, tlbl[i]) - # print(np.array([tlblpos, tlbl]).T) plt.figure() plt.title(title, fontsize=30) plt.xlabel('Detection Threshold [SNR]', fontsize=labelfs) plt.ylabel('Flux [mJy]', fontsize=labelfs) - plt.xticks(tlblpos, tlbl, fontsize=tickfs) - # ax = plt.gca() plt.yticks(flblpos, flbl, fontsize=tickfs) - # ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f')) - plt.imshow(data, origin='lower', **kwargs) cbar = plt.colorbar() cbar.ax.tick_params(labelsize=tickfs) -- GitLab