# evaluation.py
from __future__ import absolute_import, division, print_function

import itertools
import os
import sys
from collections import OrderedDict
from functools import partial
from multiprocessing import Pool, cpu_count
from pathlib import Path

import dill as pickle
import matplotlib.pyplot as plt
import numpy as np
from astropy import units as u
from astropy.io import ascii
from astropy.io import fits
from astropy.table import Table, MaskedColumn
from astropy.table import vstack
from astropy.utils.console import ProgressBar
from astropy.wcs import WCS
from matplotlib.ticker import FormatStrFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.optimize import curve_fit

from nikamap import NikaMap, Jackknife
from nikamap.utils import pos_uniform

from utils import completness_purity_wcs, completness_worker, purity_worker
from utils import find_nearest


class PCEvaluation:
    """Completeness / purity evaluation of source-detection simulations.

    On construction the per-flux simulation results are sorted by flux, the
    binned evaluation grid is built with ``completness_purity_wcs`` and
    completeness, purity and hit maps are computed for every flux through
    ``GetCP``.

    Attributes set by ``__init__``:
        flux          -- fluxes sorted in increasing order
        sources       -- per-flux detected-source results, same order as flux
        fake_sources  -- per-flux fake-source results, same order as flux
        bins          -- number of map bins handed to the grid builder
        thresholds    -- ``threshold_bins`` values spanning threshold_range
        shape_3D      -- shape returned by ``completness_purity_wcs``
        wcs_3D        -- WCS returned by ``completness_purity_wcs``
        completness, purity, hitmap -- stacked results, one entry per flux
    """

    def __init__(self, sources, fake_sources, shape, wcs, flux=None,
                 mapbins=19, threshold_bins=5, threshold_range=(3, 5)):
        """Sort the inputs by flux and evaluate completeness and purity.

        Parameters
        ----------
        sources, fake_sources : sequence
            One entry per simulated flux; each entry is forwarded unchanged
            to ``completness_worker`` / ``purity_worker``.
        shape, wcs
            Map shape and WCS, forwarded to ``completness_purity_wcs``.
        flux : astropy Quantity
            One flux per entry of ``sources`` (convertible to mJy).
            Required although it has a ``None`` default in the signature.
        mapbins : int
            Number of bins, forwarded as ``bins`` to
            ``completness_purity_wcs``.
        threshold_bins : int
            Number of detection-threshold bins.
        threshold_range : tuple
            (min, max) detection threshold.

        Raises
        ------
        ValueError
            If ``flux`` is missing or the input lengths do not match.
        TypeError
            If ``mapbins`` is not an integer.
        """
        # Validate *before* using the inputs: sorting with mismatched
        # lengths would otherwise fail with an unrelated IndexError, and a
        # missing flux with an AttributeError.
        if flux is None:
            raise ValueError("flux must be provided (one value per "
                             "simulation result)")
        if len(sources) != len(fake_sources):
            raise ValueError("Number of results for "
                             "sources and fake "
                             "sources is not the same.")
        if len(sources) != len(flux):
            raise ValueError("Number of provided fluxes differs "
                             "from number of simulation results")
        if (isinstance(mapbins, bool)
                or not isinstance(mapbins, (int, np.integer))):
            raise TypeError("number of bins must be an integer")

        # Keep everything ordered by increasing flux.
        idxsort = np.argsort(flux.to_value(u.mJy))
        self.flux = flux[idxsort]
        self.sources = [sources[i] for i in idxsort]
        self.fake_sources = [fake_sources[i] for i in idxsort]

        self.bins = mapbins

        # NOTE(review): these are threshold_bins values spread over the full
        # range, whereas the grid below has threshold_bins + 1 edges; they
        # are therefore not the bin lower edges -- confirm which one the
        # plot labels are supposed to show.
        self.thresholds = np.linspace(threshold_range[0], threshold_range[1],
                                      threshold_bins)
        threshold_edges = np.linspace(threshold_range[0], threshold_range[1],
                                      threshold_bins + 1)

        self.shape_3D, self.wcs_3D = completness_purity_wcs(
            shape, wcs,
            bins=self.bins,
            threshold_range=threshold_range,
            threshold_bins=threshold_bins)

        # Internal sanity check: the third WCS axis must map the pixel
        # edges onto the expected threshold edges.
        wcs_threshold = self.wcs_3D.sub([3])
        assert np.all(np.abs(wcs_threshold.all_pix2world(
                                np.arange(threshold_bins + 1) - 0.5, 0)
                             - threshold_edges) < 1e-15)

        self.completness, self.purity, self.hitmap = self.GetCP()

    def GetCompletnessBin(self, xbin, ybin):
        """Return ``self.completness[:, ybin, xbin, :]``."""
        return self.completness[:, ybin, xbin, :]

    def GetPurityBin(self, xbin, ybin):
        """Return ``self.purity[:, ybin, xbin, :]``."""
        return self.purity[:, ybin, xbin, :]

    def GetHitsBin(self, xbin, ybin):
        """Return ``self.hitmap[:, ybin, xbin]``."""
        # The hit counts live in self.hitmap, not in self.purity.
        return self.hitmap[:, ybin, xbin]

    def GetCP(self, sources=None, fake_sources=None, wcs=None, shape=None,
              pool=None):
        """Compute completeness, purity and hit map for every flux.

        Parameters
        ----------
        sources, fake_sources : sequence, optional
            Per-flux results; default to the instance attributes.
        wcs, shape : optional
            Evaluation grid; default to ``self.wcs_3D`` / ``self.shape_3D``.
        pool : optional
            Object providing ``starmap`` (e.g. ``multiprocessing.Pool``).
            When given, the per-flux work is dispatched through it.

        Returns
        -------
        tuple of three numpy arrays
            Stacked completeness, purity and hit maps, one entry per flux.
        """
        if sources is None:
            sources = self.sources
        if fake_sources is None:
            fake_sources = self.fake_sources
        if wcs is None:
            wcs = self.wcs_3D
        if shape is None:
            shape = self.shape_3D

        pairs = zip(sources, fake_sources)
        if pool is not None:
            worker = partial(self.completness_purity, wcs=wcs, shape=shape)
            # starmap needs one (sources, fake_sources) tuple per flux.
            results = pool.starmap(worker, pairs)
        else:
            results = [self.completness_purity(src, fake, wcs, shape)
                       for src, fake in pairs]

        if not results:
            return np.array([]), np.array([]), np.array([])

        comp, pur, hitm = zip(*results)
        return np.array(comp), np.array(pur), np.array(hitm)

    def completness_purity(self, sources, fake_sources, wcs=None,
                           shape=None):
        """Compute completness and purity maps for a given flux.

        Parameters
        ----------
        sources, fake_sources
            Results for a single flux, forwarded to ``completness_worker``
            and ``purity_worker``.
        wcs, shape : optional
            Evaluation grid; default to ``self.wcs_3D`` / ``self.shape_3D``.

        Returns
        -------
        tuple
            ``(completness, purity, norm_comp)`` where completness and
            purity have been divided by their normalisations (entries with
            a zero normalisation end up non-finite).
        """
        if wcs is None:
            wcs = self.wcs_3D
        if shape is None:
            shape = self.shape_3D

        min_threshold, max_threshold = wcs.sub([3]).all_pix2world(
            [-0.5, shape[2] - 1], 0)[0]

        # TODO: profiling showed the time is spent in find_peaks /
        # fit_2d_gaussian; speed those up (maybe look into the gcntrd.pro
        # routine or photutils.centroid.centroid_1dg).

        completness, norm_comp = completness_worker(shape, wcs, sources,
                                                    fake_sources,
                                                    min_threshold,
                                                    max_threshold)

        purity, norm_pur = purity_worker(shape, wcs, sources, max_threshold)

        # norm can be 0, so to avoid warning on invalid values...
        with np.errstate(divide='ignore', invalid='ignore'):
            # norm_comp has one axis less than completness.
            completness /= norm_comp[..., np.newaxis]
            purity /= norm_pur

        # TODO: One should probably return completness AND norm if one want
        # to combine several fluxes
        return completness, purity, norm_comp

    def PlotBin(self, data, title='', flux=None, thresh=None,
                nfluxlabels=None, nthreshlabels=None, **kwargs):
        """Show ``data`` as a flux (rows) versus threshold (columns) image.

        Parameters
        ----------
        data : 2D array
            Image to display, e.g. the output of ``GetCompletnessBin``.
        title : str
            Figure title.
        flux, thresh : array-like, optional
            Axis values used for the tick labels; default to the instance
            fluxes (in mJy) and thresholds.
        nfluxlabels, nthreshlabels : int, optional
            If given, only that many tick labels are drawn (geometrically
            spaced in flux, linearly spaced in threshold); otherwise every
            row / column is labelled.
        **kwargs
            Forwarded to ``plt.imshow``.
        """
        tickfs = 20
        labelfs = 25

        if flux is None:
            flux = self.flux.to_value(u.mJy)
        if thresh is None:
            thresh = self.thresholds

        if nfluxlabels is not None:
            wanted_flux = np.geomspace(flux[0], flux[-1], nfluxlabels)
            flblpos = find_nearest(flux, wanted_flux)
            flux_values = flux[flblpos]
        else:
            flblpos, flux_values = np.arange(len(flux)), flux

        if nthreshlabels is not None:
            wanted_thresh = np.linspace(thresh[0], thresh[-1], nthreshlabels)
            tlblpos = find_nearest(thresh, wanted_thresh)
            thresh_values = thresh[tlblpos]
        else:
            tlblpos, thresh_values = np.arange(len(thresh)), thresh

        flbl = ['{:.1f}'.format(value) for value in flux_values]
        tlbl = ['{:.1f}'.format(value) for value in thresh_values]

        plt.figure()
        plt.title(title, fontsize=30)
        plt.xlabel('Detection Threshold [SNR]', fontsize=labelfs)
        plt.ylabel('Flux [mJy]', fontsize=labelfs)
        plt.xticks(tlblpos, tlbl, fontsize=tickfs)
        plt.yticks(flblpos, flbl, fontsize=tickfs)
        plt.imshow(data, origin='lower', aspect='auto', **kwargs)
        cbar = plt.colorbar()
        cbar.ax.tick_params(labelsize=tickfs)

    def PlotFixedThreshold(self, data, thresholds, nfluxlabels=None,
                           hlines=None, ylabel=''):
        """Plot ``data`` against flux for a few fixed detection thresholds.

        Parameters
        ----------
        data : 2D array
            Indexed as ``data[:, threshold_index]``.
        thresholds : array-like
            Requested thresholds; the nearest entries of ``self.thresholds``
            are used.
        nfluxlabels : optional
            Unused, kept for interface compatibility.
        hlines : iterable of float, optional
            Values at which red horizontal lines are drawn.
        ylabel : str
            Label of the y axis.
        """
        linestyles = ['-', '--', '-.', ':']
        real_thresholds = find_nearest(self.thresholds, thresholds)
        _x = self.flux.to_value(u.mJy)

        plt.figure()
        # Cycle through the styles so that more than four thresholds do not
        # run off the end of the list.
        for style, idx in zip(itertools.cycle(linestyles), real_thresholds):
            plt.plot(_x, data[:, idx], linestyle=style,
                     label='{:.1f}'.format(self.thresholds[idx]))
        if hlines is not None:
            for val in hlines:
                plt.axhline(val, color='r')
        plt.title('Fixed Threshold', fontsize=30, y=1.02)
        plt.xlabel('Source Flux [mJy]', fontsize=25)
        plt.ylabel(ylabel, fontsize=25)
        plt.yticks(fontsize=20)
        plt.xticks(fontsize=20)
        plt.subplots_adjust(left=0.12)
        ax = plt.gca()
        try:
            # Current matplotlib spelling of the keyword.
            ax.set_xscale("log", nonpositive='clip')
        except (TypeError, ValueError):
            # Older matplotlib only knows the axis-specific keyword.
            ax.set_xscale("log", nonposx='clip')
        legend = plt.legend(fontsize=25, title='SNR', loc='upper left',
                            framealpha=1)
        plt.setp(legend.get_title(), fontsize=25)

    def _flux_indices(self, flux):
        """Return indices into ``self.flux`` nearest to ``flux``.

        ``None`` selects every available flux. The result is always at
        least 1D so that it can be iterated over.
        """
        available = self.flux.to_value(u.mJy)
        if flux is None:
            return np.arange(len(available))
        return np.atleast_1d(find_nearest(available, flux.to_value(u.mJy)))

    def PlotOverview(self, flux=None, completness=None, purity=None,
                     threshold=None):
        """Draw completeness (top row) and purity (bottom row) maps.

        One figure is created per selected flux, with one column per
        threshold bin.

        Parameters
        ----------
        flux : astropy Quantity, optional
            Flux(es) to show; the nearest simulated fluxes are used.
            ``None`` shows every simulated flux.
        completness, purity : array, optional
            Default to the instance results.
        threshold : array-like, optional
            Threshold values used for the column labels; defaults to
            ``self.thresholds``.
        """
        if completness is None:
            completness = self.completness
        if purity is None:
            purity = self.purity
        if threshold is None:
            threshold = self.thresholds
        threshold_bins = completness.shape[-1]
        fluxidx = self._flux_indices(flux)
        realflux = self.flux[fluxidx]
        # Second column carries the title when it exists.
        title_col = min(1, threshold_bins - 1)

        for iimg, idx in enumerate(fluxidx):
            _completness = completness[idx]
            _purity = purity[idx]
            # squeeze=False keeps axes 2D even for a single threshold bin.
            fig, axes = plt.subplots(nrows=2, ncols=threshold_bins,
                                     sharex=True, sharey=True,
                                     squeeze=False)

            for i in range(threshold_bins):
                axes[0, i].imshow(_completness[:, :, i], vmin=0, vmax=1)
                im = axes[1, i].imshow(_purity[:, :, i], vmin=0, vmax=1)
                axes[1, i].set_xlabel("thresh={:.2f}".format(threshold[i]))

            # Colorbar attached to the last purity panel.
            divider = make_axes_locatable(axes[1, -1])
            cax = divider.append_axes('right', size='5%', pad=0.0)
            fig.colorbar(im, cax=cax, orientation='vertical')

            axes[0, title_col].set_title("{:.1f}".format(realflux[iimg]))
            axes[0, 0].set_ylabel("completness")
            axes[1, 0].set_ylabel("purity")

    def PlotHitmap(self, flux=None, **kwargs):
        """Show the hit map of the simulated flux(es) nearest to ``flux``.

        Parameters
        ----------
        flux : astropy Quantity, optional
            Flux(es) to show; ``None`` shows every simulated flux.
        **kwargs
            Forwarded to ``plt.imshow``.
        """
        for iidx in self._flux_indices(flux):
            plt.figure()
            plt.imshow(self.hitmap[iidx], origin='lower', **kwargs)
            # Label with the flux actually displayed, not the requested one.
            plt.title('Hitmap {:.1f}'.format(self.flux[iidx]), fontsize=30,
                      y=1.02)


def UglyLoader(filename):
    """Load fluxes and per-flux source tables from a combined FITS file.

    The primary header must provide ``NFLUXES`` and one ``flux<i>`` card per
    simulation; the file must contain the extensions
    ``DETECTED_SOURCES<i>`` and ``FAKE_SOURCES<i>`` for each ``i``.

    Parameters
    ----------
    filename : str or path-like
        FITS file to read.

    Returns
    -------
    tuple
        ``(flux, sources, fake_sources)``: an astropy Quantity with one
        flux per simulation, and two lists of astropy Tables in the same
        order.
    """
    fluxes = []
    detected = []
    injected = []

    # The context manager closes the file once every table has been read;
    # the original left the handle open.
    with fits.open(filename) as hdul:
        header = hdul[0].header
        nfluxes = header['NFLUXES']
        print('{} different fluxes found'.format(nfluxes))

        for isimu in range(nfluxes):
            fluxes.append(u.Quantity(header['flux{}'.format(isimu)]))
            detected.append(
                Table.read(hdul['DETECTED_SOURCES{}'.format(isimu)]))
            injected.append(
                Table.read(hdul['FAKE_SOURCES{}'.format(isimu)]))

    return u.Quantity(fluxes), detected, injected


if __name__ == '__main__':
    # Demo driver: evaluate the simulations on a 19-bin and a 9-bin grid and
    # show completeness / purity of the central map bin.

    DATA_DIR = "/home/peter/Dokumente/Uni/Paris/Stage/data/v_1"
    fname = ('/home/peter/Dokumente/Uni/Paris/Stage/'
             'FirstSteps/Completness/NEWcombined_tables_long.fits')

    data = NikaMap.read(Path(DATA_DIR) / '..' / 'map.fits')
    sh, wcs = data.data.shape, data.wcs

    FLUX, SOURCE, FSOURCE = UglyLoader(fname)

    # Threshold grid shared by both evaluations.
    threshold_setup = dict(threshold_bins=6, threshold_range=(2.5, 5))
    fixed_snr = (3, 5)

    # Fine grid (19 bins), central bin (9, 9).
    xx = PCEvaluation(SOURCE, FSOURCE, sh, wcs, FLUX, mapbins=19,
                      **threshold_setup)
    for getter, label in ((xx.GetCompletnessBin, 'Completness'),
                          (xx.GetPurityBin, 'Purity')):
        xx.PlotBin(getter(9, 9), nfluxlabels=10, title=label)
    for getter, label in ((xx.GetCompletnessBin, 'Completness'),
                          (xx.GetPurityBin, 'Purity')):
        xx.PlotFixedThreshold(getter(9, 9), np.array(fixed_snr),
                              ylabel=label)
    xx.PlotOverview(flux=5*u.mJy)
    xx.PlotHitmap(flux=5*u.mJy)

    # Coarse grid (9 bins), central bin (4, 4).
    yy = PCEvaluation(SOURCE, FSOURCE, sh, wcs, FLUX, mapbins=9,
                      **threshold_setup)
    for getter, label in ((yy.GetCompletnessBin, 'Completness'),
                          (yy.GetPurityBin, 'Purity')):
        yy.PlotFixedThreshold(getter(4, 4), np.array(fixed_snr),
                              ylabel=label)

    plt.show(block=True)