# evaluation.py
from __future__ import absolute_import, division, print_function

from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt

from multiprocessing import Pool, cpu_count
from functools import partial

from astropy import units as u
from astropy.io import ascii
from astropy.wcs import WCS
from astropy.utils.console import ProgressBar
from astropy.table import vstack

from scipy.optimize import curve_fit

from nikamap import NikaMap, Jackknife
from nikamap.utils import pos_uniform
from astropy.io import fits
from astropy.table import Table, MaskedColumn
import sys
from mpl_toolkits.axes_grid1 import make_axes_locatable
import dill as pickle
from matplotlib.ticker import FormatStrFormatter
from collections import OrderedDict
from utils import completness_purity_wcs, completness_worker, purity_worker
from utils import find_nearest



import os

# IPython magics used when this file is run interactively.  They are kept as
# a comment (instead of a bare string expression) so that plain Python never
# evaluates them:
#   %load_ext autoreload
#   %autoreload 2
#   %matplotlib tk

# Switch matplotlib to interactive mode; the explicit plt.show(block=True)
# calls further down are what actually block on the figures.
plt.ion()


class PCEvaluation:
    def __init__(self, sources, fake_sources, shape, wcs, flux=None,
                 mapbins=19, threshold_bins=5, threshold_range=(3, 5)):

        idxsort = np.argsort(flux.to_value(u.mJy))
        self.flux = flux[idxsort]
LUSTIG Peter's avatar
LUSTIG Peter committed
50 51
        self.sources = [sources[i] for i in idxsort]
        self.fake_sources = [fake_sources[i] for i in idxsort]
LUSTIG Peter's avatar
LUSTIG Peter committed
52 53 54 55 56 57 58 59 60 61 62

        self.completness = None
        self.purity = None
        assert len(sources) == len(fake_sources), ("Number of results for "
                                                   "sources and fake "
                                                   "sources is not the same.")
        assert len(sources) == len(flux), ("Number of provided fluxes differs "
                                           "from number of simulation results")
        assert type(mapbins) is int, "number of bins must be an integer"
        self.bins = mapbins

63 64
        self.thresholds = np.linspace(threshold_range[0], threshold_range[1],
                                      threshold_bins)
LUSTIG Peter's avatar
LUSTIG Peter committed
65 66 67 68 69 70 71 72 73 74 75 76 77 78
        threshold_edges = np.linspace(threshold_range[0], threshold_range[1],
                                      threshold_bins+1)

        self.shape_3D, self.wcs_3D = completness_purity_wcs(
                                        shape, wcs,
                                        bins=self.bins,
                                        threshold_range=threshold_range,
                                        threshold_bins=threshold_bins)

        # Testing the lower edges
        wcs_threshold = self.wcs_3D.sub([3])
        assert np.all(np.abs(wcs_threshold.all_pix2world(
                                np.arange(threshold_bins+1)-0.5, 0)
                             - threshold_edges) < 1e-15)
LUSTIG Peter's avatar
LUSTIG Peter committed
79
        self.completness, self.purity, self.hitmap = self.GetCP()

    def GetCP(self, sources=None, fake_sources=None, wcs=None, shape=None,
                pool=None):
        if sources is None:
            sources = self.sources
        if fake_sources is None:
            fake_sources = self.fake_sources
        if wcs is None:
            wcs = self.wcs_3D
        if shape is None:
            shape = self.shape_3D

        if pool is not None:
            f = partial(self.completness_purity, wcs=wcs, shape=shape)
            res = pool.starmap(f, (sources, fake_sources))
            res = list(zip(*res))
            return res[0], res[1], res[2]
        else:
            comp, pur, hitm = [], [], []
            for i in range(len(sources)):
LUSTIG Peter's avatar
LUSTIG Peter committed
100
                tmpres = self.completness_purity(sources[i], fake_sources[i],
LUSTIG Peter's avatar
LUSTIG Peter committed
101 102 103 104
                                                 wcs, shape)
                comp.append(tmpres[0])
                pur.append(tmpres[1])
                hitm.append(tmpres[2])
LUSTIG Peter's avatar
LUSTIG Peter committed
105
            return np.array(comp), np.array(pur), np.array(hitm)

    def completness_purity(self, sources, fake_sources, wcs=None,
                           shape=None):
        """Compute the completness and purity maps for a given flux.

        Parameters
        ----------
        sources, fake_sources
            Detected and injected catalogues, passed on to
            ``completness_worker`` / ``purity_worker``.
        wcs, shape : optional
            3D WCS and shape of the result cubes; default to
            ``self.wcs_3D`` / ``self.shape_3D``.

        Returns
        -------
        tuple
            ``(completness, purity, norm_comp)`` where the first two are
            already divided by their normalisation.
        """
        # Fall back on the instance values, as GetCP does, so that the
        # ``None`` defaults of the signature are actually usable.
        if wcs is None:
            wcs = self.wcs_3D
        if shape is None:
            shape = self.shape_3D

        # World values of the third WCS axis at pixel -0.5 and at pixel
        # shape[2]-1.
        min_threshold, max_threshold = wcs.sub([3]).all_pix2world(
                                                        [-0.5, shape[2]-1],
                                                        0)[0]

        # %load_ext snakeviz
        # %snakeviz the following line.... all is spend in the find_peaks /
        # fit_2d_gaussian
        # TODO: Change the find_peaks routine, or maybe just the
        # fit_2d_gaussian to be FAST ! (Maybe look into gcntrd.pro routine
        # or photutils.centroid.centroid_1dg maybe ?)

        completness, norm_comp = completness_worker(shape, wcs, sources,
                                                    fake_sources,
                                                    min_threshold,
                                                    max_threshold)

        purity, norm_pur = purity_worker(shape, wcs, sources, max_threshold)

        # norm can be 0, so to avoid warning on invalid values...
        with np.errstate(divide='ignore', invalid='ignore'):
            completness /= norm_comp[..., np.newaxis]
            purity /= norm_pur

        # TODO: One should probably return completness AND norm if one want to
        # combine several fluxes
        return completness, purity, norm_comp
    def PlotBin(self, data, title='', flux=None, thresh=None,
                nfluxlabels=None, nthreshlabels=None, **kwargs):
        """Show ``data`` as an image with threshold on x and flux on y.

        Parameters
        ----------
        data : 2D array
            Image handed to ``plt.imshow``; rows are labelled with ``flux``
            and columns with ``thresh``.
        title : str
            Figure title.
        flux : array, optional
            Values for the y tick labels; defaults to ``self.flux`` in mJy.
        thresh : array, optional
            Values for the x tick labels; defaults to ``self.thresholds``.
        nfluxlabels : int, optional
            If given, label only this many geometrically spaced fluxes
            (snapped to the available values with ``find_nearest``).
        nthreshlabels : int, optional
            If given, label only this many linearly spaced thresholds
            (snapped to the available values with ``find_nearest``).
        **kwargs
            Forwarded to ``plt.imshow``.
        """
        tickfs = 20
        labelfs = 25

        if flux is None:
            flux = self.flux.to_value(u.mJy)
        if thresh is None:
            # the attribute set in __init__ is ``thresholds`` (plural)
            thresh = self.thresholds

        # Allow plain sequences; asanyarray leaves ndarray subclasses as is
        flux = np.asanyarray(flux)
        thresh = np.asanyarray(thresh)

        if nfluxlabels is not None:
            wanted = np.geomspace(flux[0], flux[-1], nfluxlabels)
            flblpos = find_nearest(flux, wanted)
            flux_values = flux[flblpos]
        else:
            flblpos, flux_values = np.arange(len(flux)), flux

        if nthreshlabels is not None:
            wanted = np.linspace(thresh[0], thresh[-1], nthreshlabels)
            tlblpos = find_nearest(thresh, wanted)
            thresh_values = thresh[tlblpos]
        else:
            tlblpos, thresh_values = np.arange(len(thresh)), thresh

        flbl = ['{:.1f}'.format(value) for value in flux_values]
        tlbl = ['{:.1f}'.format(value) for value in thresh_values]

        plt.figure()
        plt.title(title, fontsize=30)
        plt.xlabel('Detection Threshold [SNR]', fontsize=labelfs)
        plt.ylabel('Flux [mJy]', fontsize=labelfs)
        plt.xticks(tlblpos, tlbl, fontsize=tickfs)
        plt.yticks(flblpos, flbl, fontsize=tickfs)
        plt.imshow(data, origin='lower', aspect='auto', **kwargs)
        cbar = plt.colorbar()
        cbar.ax.tick_params(labelsize=tickfs)

    def PlotFixedThreshold(self, data, thresholds, nfluxlabels=None,
                           hlines=None):
        """Plot ``data`` against source flux, one curve per threshold.

        Parameters
        ----------
        data : 2D array
            Indexed as ``data[:, threshold_index]``; the first axis is
            plotted against ``self.flux`` in mJy.
        thresholds : array
            Requested thresholds; each one is snapped to an entry of
            ``self.thresholds`` with ``find_nearest``.
        nfluxlabels : optional
            Unused; kept for interface compatibility.
        hlines : iterable, optional
            y values at which a red horizontal line is drawn.

        The call blocks in ``plt.show(block=True)`` until the figure is
        closed.
        """
        linestyles = ['-', '--', '-.', ':']
        real_thresholds = np.atleast_1d(
            find_nearest(self.thresholds, thresholds))
        _x = self.flux.to_value(u.mJy)
        for i, thresh_idx in enumerate(real_thresholds):
            # cycle through the styles so that more than four thresholds do
            # not raise an IndexError
            plt.plot(_x, data[:, thresh_idx],
                     linestyle=linestyles[i % len(linestyles)],
                     label='{:.1f}'.format(self.thresholds[thresh_idx]))
        if hlines is not None:
            for val in hlines:
                plt.axhline(val, color='r')
        plt.title('Fixed Threshold', fontsize=30, y=1.02)
        plt.xlabel('Source Flux [mJy]', fontsize=25)
        plt.ylabel('Completness', fontsize=25)
        plt.yticks(fontsize=20)
        plt.xticks(fontsize=20)
        plt.subplots_adjust(left=0.12)
        ax = plt.gca()
        try:
            # current keyword name
            ax.set_xscale("log", nonpositive='clip')
        except (TypeError, ValueError):
            # older Matplotlib releases only know the per-axis spelling
            ax.set_xscale("log", nonposx='clip')
        # legend = plt.legend(fontsize=25, title='SNR', loc='lower right')
        legend = plt.legend(fontsize=25, title='SNR', loc='upper left',
                            framealpha=1)
        plt.setp(legend.get_title(), fontsize=25)
        plt.show(block=True)

    def PlotOverview(self, completness=None, purity=None, threshold=None,
                     flux=None):
        """Show completness and purity maps for every threshold bin.

        One figure is created per selected flux: the top row holds the
        completness maps, the bottom row the purity maps, one column per
        threshold bin, with a shared colour scale from 0 to 1.

        Parameters
        ----------
        completness, purity : array, optional
            Indexed as ``[flux_index][:, :, threshold_index]``; default to
            ``self.completness`` / ``self.purity``.
        threshold : array, optional
            Values used for the column labels; defaults to
            ``self.thresholds``.
        flux : astropy.units.Quantity, optional
            Scalar or array of fluxes to display; each one is snapped to an
            entry of ``self.flux`` with ``find_nearest``.  When omitted,
            every flux is shown and the figures carry no flux title.
        """
        if completness is None:
            completness = self.completness
        if purity is None:
            purity = self.purity
        if threshold is None:
            threshold = self.thresholds
        threshold_bins = completness.shape[-1]

        if flux is None:
            fluxidx = np.arange(len(self.flux))
        else:
            # atleast_1d on both sides so that a scalar flux works as well
            wanted = np.atleast_1d(flux.to_value(u.mJy))
            fluxidx = np.atleast_1d(
                find_nearest(self.flux.to_value(u.mJy), wanted))
        realflux = self.flux[fluxidx]

        for iimg, idx in enumerate(fluxidx):
            _completness = completness[idx]
            _purity = purity[idx]
            # squeeze=False keeps ``axes`` two-dimensional even when there is
            # a single threshold bin
            fig, axes = plt.subplots(nrows=2, ncols=threshold_bins,
                                     sharex=True, sharey=True,
                                     squeeze=False)

            for i in range(threshold_bins):
                axes[0, i].imshow(_completness[:, :, i], vmin=0, vmax=1)
                im = axes[1, i].imshow(_purity[:, :, i], vmin=0, vmax=1)
                axes[1, i].set_xlabel("thresh={:.2f}".format(threshold[i]))

            # A single colour bar next to the last purity panel
            divider = make_axes_locatable(axes[1, threshold_bins-1])
            cax = divider.append_axes('right', size='5%', pad=0.0)
            fig.colorbar(im, cax=cax, orientation='vertical')

            if flux is not None:
                # second column when there is one, else the only column
                title_col = min(1, threshold_bins-1)
                axes[0, title_col].set_title(
                    "{:.1f}".format(realflux[iimg]))
            axes[0, 0].set_ylabel("completness")
            axes[1, 0].set_ylabel("purity")


# ---------------------------------------------------------------------------
# Driver: read the map and the simulation tables, evaluate completness and
# purity, and display an overview.  The statements are kept at module level
# so that the file can still be run cell by cell (see the "# %%" marker).
# ---------------------------------------------------------------------------

DATA_DIR = "/home/peter/Dokumente/Uni/Paris/Stage/data/v_1"
data = NikaMap.read(Path(DATA_DIR) / '..' / 'map.fits')
sh = data.data.shape
wcs = data.wcs

FLUX = []
SOURCE = []
FSOURCE = []

# Open the combined tables with a context manager so that the file handle is
# released once all tables have been read.
with fits.open('/home/peter/Dokumente/Uni/Paris/Stage/'
               'FirstSteps/Completness/'
               'combined_tables_long.fits') as hdul:
    nfluxes = hdul[0].header['NFLUXES']
    print('{} different fluxes found'.format(nfluxes))

    for isimu in range(nfluxes):
        # The primary header stores one 'flux<i>' card per simulation; the
        # table extensions are named after the formatted flux.
        FLUX.append(u.Quantity(hdul[0].header['flux{}'.format(isimu)]))

        SOURCE.append(Table.read(hdul['DETECTED_SOURCES{}'
                                      .format(FLUX[isimu])]))
        FSOURCE.append(Table.read(hdul['FAKE_SOURCES{}'
                                       .format(FLUX[isimu])]))

xx = PCEvaluation(SOURCE, FSOURCE, sh, wcs, u.Quantity(FLUX))

# %% testfunctions

dd = xx.completness
print('comp shape', dd.shape)
# xx.PlotBin(xx.completness[:, 9, 9, :])
# xx.PlotFixedThreshold(xx.completness[:, 9, 9, :], np.array([3, 5]))
xx.PlotOverview(flux=5*u.mJy)
plt.show(block=True)
# NOTE(review): sys.exit with a string prints it to stderr and exits with
# status 1; confirm that a non-zero status is intended for a normal run.
sys.exit('Done')