Commit faee374a authored by LUSTIG Peter's avatar LUSTIG Peter

improved plots

parent e501b9ac
...@@ -22,6 +22,11 @@ from astropy.io import fits ...@@ -22,6 +22,11 @@ from astropy.io import fits
from astropy.table import Table, MaskedColumn from astropy.table import Table, MaskedColumn
import sys import sys
from mpl_toolkits.axes_grid1 import make_axes_locatable from mpl_toolkits.axes_grid1 import make_axes_locatable
import dill as pickle
from matplotlib.ticker import FormatStrFormatter
from collections import OrderedDict
import os import os
os.getcwd() os.getcwd()
...@@ -278,6 +283,23 @@ def Evaluate(flux_ds_fs_list): ...@@ -278,6 +283,23 @@ def Evaluate(flux_ds_fs_list):
return fluxval, completness, purity return fluxval, completness, purity
def Evaluate2(_flux, hdul):
print(fits.HDUList(hdul).info())
# _flux = u.Quantity(pr.header['flux{}'.format(isimu)])
fluxval = _flux.to_value(u.mJy)
sources = Table.read(hdul['DETECTED_SOURCES{}'.format(_flux)])
fake_sources = Table.read(hdul['FAKE_SOURCES{}'.format(_flux)])
print('{} data loaded'.format(_flux))
completness, purity = completness_purity(sources, fake_sources,
wcs=wcs_4D.sub([1, 2, 3]),
shape=shape_4D[0:3])
print(fluxval, completness, purity)
return fluxval, completness, purity
plt.close('all') plt.close('all')
# _data = next(Jackknife(filenames, n=None)) # _data = next(Jackknife(filenames, n=None))
...@@ -332,15 +354,15 @@ assert np.all(np.abs(wcs_flux.all_pix2world(np.arange(flux_bins+1)-0.5, 0) - flu ...@@ -332,15 +354,15 @@ assert np.all(np.abs(wcs_flux.all_pix2world(np.arange(flux_bins+1)-0.5, 0) - flu
# This is a single run check for a single flux # This is a single run check for a single flux
hdul = fits.open('/home/peter/Dokumente/Uni/Paris/Stage/FirstSteps/' hdul = fits.HDUList(fits.open('/home/peter/Dokumente/Uni/Paris/Stage/FirstSteps/'
'Completness/combined_tables_long.fits') 'Completness/combined_tables_long.fits'))
# sys.exit()
nfluxes = hdul[0].header['NFLUXES'] nfluxes = hdul[0].header['NFLUXES']
print('{} different fluxes found'.format(nfluxes)) print('{} different fluxes found'.format(nfluxes))
# Get fluxlist: # Get fluxlist:
indata = []
indata = []
for isimu in range(nfluxes): for isimu in range(nfluxes):
_FLUX = u.Quantity(hdul[0].header['flux{}'.format(isimu)]) _FLUX = u.Quantity(hdul[0].header['flux{}'.format(isimu)])
...@@ -349,10 +371,20 @@ for isimu in range(nfluxes): ...@@ -349,10 +371,20 @@ for isimu in range(nfluxes):
_FAKE_SOURCES = Table.read(hdul['FAKE_SOURCES{}' _FAKE_SOURCES = Table.read(hdul['FAKE_SOURCES{}'
.format(_FLUX)]) .format(_FLUX)])
indata.append([_FLUX, _SOURCES, _FAKE_SOURCES]) indata.append([_FLUX, _SOURCES, _FAKE_SOURCES])
'''
flux = []
for isimu in range(nfluxes):
flux.append(u.Quantity(hdul[0].header['flux{}'.format(isimu)]))
'''
# helpfunc = partial(Evaluate, **{'hdul': hdul}) # helpfunc = partial(Evaluate, **{'hdul': hdul})
p = Pool(cpu_count()) p = Pool(cpu_count())
# print(hdul)
#hdul = []
# f = partial(Evaluate2, hdul=hdul)
# print(flux)
# res = p.map(f, flux)
# print(flux)
# sys.exit()
res = p.map(Evaluate, indata) res = p.map(Evaluate, indata)
res = list(zip(*res)) res = list(zip(*res))
FLUX = np.array(res[0]) FLUX = np.array(res[0])
...@@ -365,20 +397,58 @@ COMPLETNESS = COMPLETNESS[idxsort] ...@@ -365,20 +397,58 @@ COMPLETNESS = COMPLETNESS[idxsort]
PURITY = PURITY[idxsort] PURITY = PURITY[idxsort]
midbin = int(bins/2) midbin = int(bins/2)
print(midbin) print(midbin)
# sys.exit() # sys.exit()
# %% PlotFigure # %% PlotFigure
def PlotEvaluation(data, title='', flux=[], thresh=[], **kwargs): def find_nearest(array, values):
x, y = np.meshgrid(array, values)
ev = np.abs(x - y)
return np.argmin(ev, axis=1)
def PlotEvaluation(data, title='', flux=np.array([]), thresh=[],
nfluxlabels=None, nthreshlabels=None, **kwargs):
tickfs = 20 tickfs = 20
labelfs = 25 labelfs = 25
if nfluxlabels is not None:
_label_flux = np.geomspace(flux[0], flux[-1], nfluxlabels)
_f_idx = find_nearest(flux, _label_flux)
flblpos, _flbl = _f_idx, flux[_f_idx]
else:
flblpos, _flbl = np.arange(len(flux)), flux
if nthreshlabels is not None:
_label_thresh = np.linspace(thresh[0], thresh[-1], nthreshlabels)
_t_idx = find_nearest(thresh, _label_thresh)
print(_t_idx)
tlblpos, _tlbl = _t_idx, thresh[_t_idx]
else:
tlblpos, _tlbl = np.arange(len(thresh)), thresh
flbl = []
for i in range(len(_flbl)):
flbl.append('{:.1f}'.format(_flbl[i]))
tlbl = []
for i in range(len(_tlbl)):
tlbl.append('{:.1f}'.format(_tlbl[i]))
# print(i, tlbl[i])
# print(np.array([tlblpos, tlbl]).T)
plt.figure() plt.figure()
plt.title(title, fontsize=30) plt.title(title, fontsize=30)
plt.xlabel('Detection Threshold [SNR]', fontsize=labelfs) plt.xlabel('Detection Threshold [SNR]', fontsize=labelfs)
plt.ylabel('Flux [mJy]', fontsize=labelfs) plt.ylabel('Flux [mJy]', fontsize=labelfs)
plt.xticks(np.arange(len(thresh)), thresh, fontsize=tickfs)
plt.yticks(np.arange(len(flux)), flux, fontsize=tickfs) plt.xticks(tlblpos, tlbl, fontsize=tickfs)
# ax = plt.gca()
plt.yticks(flblpos, flbl, fontsize=tickfs)
# ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
plt.imshow(data, origin='lower', **kwargs) plt.imshow(data, origin='lower', **kwargs)
cbar = plt.colorbar() cbar = plt.colorbar()
cbar.ax.tick_params(labelsize=tickfs) cbar.ax.tick_params(labelsize=tickfs)
...@@ -386,6 +456,45 @@ def PlotEvaluation(data, title='', flux=[], thresh=[], **kwargs): ...@@ -386,6 +456,45 @@ def PlotEvaluation(data, title='', flux=[], thresh=[], **kwargs):
PlotEvaluation(COMPLETNESS[:, midbin, midbin, :], title='Completness', PlotEvaluation(COMPLETNESS[:, midbin, midbin, :], title='Completness',
flux=list(FLUX), thresh=threshold, cmap='bone') flux=np.array(FLUX), thresh=threshold, cmap='bone',
nfluxlabels=10, nthreshlabels=5, aspect='auto')
# %% 2D plot
def PlotFixedThreshold(thresholds, bin, completness, allthresholds, flux,
nfluxlabels=None, hlines=None):
linestyles = ['-', '--', '-.', ':']
real_thresholds = find_nearest(allthresholds, thresholds)
for i in range(len(real_thresholds)):
_x = flux
_y = completness[:, bin[0], bin[1], real_thresholds[i]]
plt.plot(_x, _y, linestyle=linestyles[i],
label='{:.1f}'.format(allthresholds[real_thresholds[i]]))
if hlines is not None:
for i, val in enumerate(hlines):
plt.axhline(val, color='r')
plt.title('Fixed Threshold', fontsize=30, y=1.02)
plt.xlabel('Source Flux [mJy]', fontsize=25)
plt.ylabel('Completness', fontsize=25)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)
plt.subplots_adjust(left=0.12)
ax = plt.gca()
ax.set_xscale("log", nonposx='clip')
# legend = plt.legend(fontsize=25, title='SNR', loc='lower right')
legend = plt.legend(fontsize=25, title='SNR', loc='upper left',
framealpha=1)
plt.setp(legend.get_title(), fontsize=25)
plt.show(block=True)
# cmap bone hot # cmap bone hot
plt.show(block=True) # print(np.array(FLUX))
PlotFixedThreshold(np.array((3, 5, 7)), (midbin, midbin), COMPLETNESS,
threshold, np.array(FLUX), nfluxlabels=None,
hlines=[.5, .8, .9])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment