Commit bb0b205a authored by LUSTIG Peter's avatar LUSTIG Peter

Interpolation implemented

parent 73fe7323
......@@ -102,7 +102,7 @@ class BasicEvaluation:
self.fake_sourceslocmask = fake_sourceslocmask
class PCAreaEvaluation(BasicEvaluation):
class CompletnessArea(BasicEvaluation):
def __init__(self, sources, fake_sources=None,
shape=None, wcs=None, threshold_edges=None, flux_edges=None,
flux_centers=None,
......@@ -117,7 +117,7 @@ class PCAreaEvaluation(BasicEvaluation):
self.ChangeMask(mask)
self.threshold_edges = threshold_edges
self.flux_edges = flux_edges
self.flux_centers = flux_centers,
self.flux_centers = flux_centers
self.Completness()
def ChangeMask(self, mask):
......@@ -149,42 +149,67 @@ class PCAreaEvaluation(BasicEvaluation):
fake_sourceslocmask)
self.completness = completness / norm_completness[:, np.newaxis]
def InterpolateThresholds(self, value):
completness = self.completness
thresholds = self.threshold_edges[:-1]
flux_centers = self.flux_centers.to_value(u.mJy)
shapeflux = len(flux_centers)
shapethresh = len(thresholds)
if not np.isscalar(value):
shapeflux = (len(flux_centers), len(value))
shapethresh = (len(thresholds), len(value))
constfluxinterp = np.zeros(shapeflux)
constthreshinterp = np.zeros(shapethresh)
for i, constfluxcomp in enumerate(completness):
tmp = np.interp(value, constfluxcomp[::-1], thresholds[::-1])
constfluxinterp[i] = tmp
for i, constthreshcomp in enumerate(completness.T):
tmp = np.interp(value, constthreshcomp, flux_centers)
constthreshinterp[i] = tmp
return constfluxinterp, constthreshinterp
def PlotCompletness1D(self, idx, constant='flux', linestyles=None,
labels=None, ax=None, **kwargs):
ax=None, **kwargs):
if type(idx) is int:
idx = np.array([idx])
nplots = len(idx)
if constant == 'thresh':
completnesses = self.completness[:, idx]
x = self.flux_centers
xlabel = 'Flux'
completnesses = self.completness[:, idx].T
x = self.flux_centers.to_value(u.mJy)
xlabel = 'Flux [mJy]'
labelval = self.threshold_edges[idx]
legendtitle = 'SNR'
elif constant == 'flux':
completnesses = self.completness[idx]
x = self.threshold_edges[:-1]
xlabel = 'Detection Threshold'
labelval = self.flux_centers[idx]
labelval = self.flux_centers[idx].to_value(u.mJy)
legendtitle = 'Flux [mJy]'
else:
raise UserWarning('you must chose \'flux\' or thresh for constant')
if linestyles is None:
linestyles = ['-'] * nplots
if labels is None:
labels = np.repeat(labels, nplots)
if ax is None:
fig = plt.figure(figsize=(7, 6))
fig.subplots_adjust(left=0.16, right=0.97, bottom=0.15, top=0.9)
ax = fig.add_subplot(111)
for i, (completness, linestyle, label) in enumerate(
zip(completnesses, linestyles,
labels)):
ax.plot(x, completness, linestyle=linestyle, label=label,
**kwargs)
for i, (completness, linestyle) in enumerate(
zip(completnesses, linestyles)):
ax.plot(x, completness, linestyle=linestyle,
label='{:.2f}'.format(labelval[i]), **kwargs)
legend = ax.legend(fontsize=22, title=legendtitle, loc='upper right',
framealpha=1)
plt.setp(legend.get_title(), fontsize=20)
ax.set_title('Completness', fontsize=30, y=1.02)
ax.set_xlabel(xlabel, fontsize=25)
ax.set_ylabel('Completness', fontsize=25)
......@@ -193,21 +218,41 @@ class PCAreaEvaluation(BasicEvaluation):
def Plot2D(self, ax=None, nfluxlabels=5, nthreshlabels=5):
completness = self.completness
flux = self.flux_centers
flux = self.flux_centers.to_value(u.mJy)
print(type(flux))
thresh = self.threshold_edges
if ax is None:
fig = plt.figure(figsize=(7,7))
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111)
nflux, nthresh = completness.shape
fluxlabelidx = numpy.rint(np.linspace(0, nflux, nfluxlabels))
threshlabelidx = numpy.rint(np.linspace(0, nthresh, nthreshlabels))
fluxlabelidx = (np.rint(np.linspace(0, nflux-1, nfluxlabels))
.astype(int))
threshlabelidx = (np.rint(np.linspace(0, nthresh-1, nthreshlabels))
.astype(int))
threshlabel = np.array2string(np.round(
thresh[threshlabelidx], decimals=2),
separator=',')[1:-1].split(',')
fluxlabel = np.array2string(np.round(
flux[fluxlabelidx], decimals=2),
separator=',')[1:-1].split(',')
ax.imshow(completness, origin='lower', aspect='auto')
ax.set_xticks(threshlabelidx)
ax.set_xticklabels(threshlabel, {'fontsize': 15})
ax.set_yticks(fluxlabelidx)
ax.set_yticklabels(fluxlabel, {'fontsize': 15})
ax.set_title('Completness', fontsize=25, y=1.02)
ax.set_xlabel('Detection Threshold [SNR]', {'fontsize': 20})
ax.set_ylabel('Source Flux [mJy]', {'fontsize': 20})
fig = ax.get_figure()
fig.subplots_adjust(right=0.95, left=0.15)
return ax
class CountsAreaEvaluation(BasicEvaluation):
def __init__(self, sources,
......@@ -359,7 +404,7 @@ class Purity:
ylabel=ylabel, ax=ax)
return ax
%matplotlib tk
# %matplotlib tk
showcompletness = True
showfalsecounts = False
showrealcounts = False
......@@ -385,18 +430,24 @@ if showcompletness:
fake_sources = Table.read(hdul, 'FAKE_SOURCES')
fluxedges = LogBinEdges(fake_sources['amplitude']) * u.Jy
fluxcenters = np.sort(np.unique(fake_sources['amplitude']))
fluxcenters = u.Quantity(np.squeeze(np.sort(np.unique(
fake_sources['amplitude']))))
# %%
plt.close('all')
# %matplotlib tk
obj = PCAreaEvaluation(sources, fake_sources, shape=(501, 501), wcs=wcs,
threshold_edges=threshbins, flux_edges=fluxedges,
mask=totmask, flux_centers=fluxcenters)
obj.completness
obj.sourceslocmask
obj.fake_sourceslocmask
obj = CompletnessArea(sources, fake_sources, shape=(501, 501), wcs=wcs,
threshold_edges=threshbins, flux_edges=fluxedges,
mask=totmask, flux_centers=fluxcenters)
# obj.PlotCompletness1D(np.array([1, 25, 49]), constant='flux')
obj.Plot2D()
obj.PlotCompletness1D(np.array([1, 25, 49]), constant='thresh')
obj.completness
ax = obj.Plot2D()
obj.PlotCompletness1D(np.array([1, 25, 49]), constant='thresh')
obj.PlotCompletness1D(np.array([10, 25, 30]), constant='flux')
obj.InterpolateThresholds(0.8)
# %%
if showfalsecounts:
emptycountsfile = ('/home/peter/Dokumente/Uni/Paris/Stage/FirstSteps/'
'Completness/montecarlo_results/Newrealisation/'
......@@ -410,7 +461,7 @@ if showfalsecounts:
shape=shape, wcs=wcs, threshold_edges=threshbins,
mask=totmask, no_sim=1000)
false_counts = fc.Counts()
#%matplotlib tk
# %matplotlib tk
nax = fc.PlotCounts(label='False\nDetections')
if showrealcounts:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment