observations.py 12.3 KB
Newer Older
1 2 3 4 5
# -*- coding: utf-8 -*-
# Copyright (C) 2017 Universidad de Antofagasta
# Licensed under the CeCILL-v2 licence - see Licence_CeCILL_V2-en.txt
# Author: Médéric Boquien

6
from astropy.cosmology import WMAP7 as cosmo
7 8
from astropy.table import Column
import numpy as np
9
from scipy.constants import parsec
10 11

from ..utils import read_table
12 13
from .utils import get_info

14 15 16 17 18 19 20 21 22

class ObservationsManager(object):
    """Factory abstracting the handling of the observations so that the rest
    of cigale can deal with them through a consistent interface.

    An observations manager is in charge of reading the input file to memory,
    checking the consistency of the data, replacing invalid values with NaN,
    etc.

    Instantiating this class never yields an ObservationsManager: __new__
    dispatches to ObservationsManagerPassbands when config['data_file'] is
    truthy and to ObservationsManagerVirtual otherwise. As the returned object
    is not an instance of this class, ObservationsManager.__init__ is never
    invoked.

    """

    def __new__(cls, config, params=None, **kwargs):
        # Without an input file there is nothing to read: fall back to the
        # virtual manager, which only knows the list of bands.
        if not config['data_file']:
            return ObservationsManagerVirtual(config, **kwargs)
        return ObservationsManagerPassbands(config, params, **kwargs)


30 31 32 33 34 35 36
class Observation(object):
    """Single object extracted from one row of the observations table.

    It holds the identifier, redshift and distance of the object together
    with the arrays of fluxes, intensive properties and extensive properties
    (and their respective errors) that are going to be considered in the fit.
    """
    def __init__(self, row, cls):
        """
        Parameters
        ----------
        row: table row
            Row providing 'id', 'redshift', optionally 'distance', and one
            column per band/property together with its '_err' counterpart.
        cls: observations manager
            Object exposing the `bands`, `intprops` and `extprops` lists that
            select which columns of `row` are gathered.

        """
        self.redshift = row['redshift']
        self.id = row['id']

        has_distance = ('distance' in row.colnames and
                        np.isfinite(row['distance']))
        if has_distance:
            # Input distance presumably in Mpc; scipy's parsec is in metres.
            self.distance = row['distance'] * parsec * 1e6
        elif self.redshift == 0.:
            # Conventional 10 pc for objects at rest, in metres.
            self.distance = 10. * parsec
        elif self.redshift > 0:
            # NOTE(review): astropy returns luminosity distances in Mpc, so
            # `.value` is presumably in Mpc here, unlike the metres of the two
            # branches above. Confirm which unit the consumers of `distance`
            # expect before changing it.
            self.distance = cosmo.luminosity_distance(self.redshift).value
        else:
            # Negative (or NaN) redshift without a valid distance.
            self.distance = np.nan

        def collect(names, suffix=''):
            """Gather row[name + suffix] for every name into an array."""
            return np.array([row[name + suffix] for name in names])

        self.fluxes = collect(cls.bands)
        self.fluxes_err = collect(cls.bands, '_err')
        self.intprops = collect(cls.intprops)
        self.intprops_err = collect(cls.intprops, '_err')
        self.extprops = collect(cls.extprops)
        self.extprops_err = collect(cls.extprops, '_err')


57 58 59 60 61 62 63 64
class ObservationsManagerPassbands(object):
    """Class to generate a manager for data files providing fluxes in
    passbands.

    A class instance can be used as a sequence. In that case an Observation
    (built from one row of the table) is returned at each iteration.
    """

    def __init__(self, config, params, defaulterror=0.1, modelerror=0.1,
                 threshold=-9990.):
        """Read the observations file, vet its content and build the list of
        Observation objects.

        Parameters
        ----------
        config: dict-like
            Configuration. The 'data_file', 'bands', 'properties' and
            'analysis_params' (its 'lim_flag' entry) keys are read here.
        params: object
            Stored as self.params; presumably consumed by get_info(), which
            receives the manager itself (TODO confirm).
        defaulterror: float
            Fraction of the value used as the error when no error column is
            provided. By default 10%.
        modelerror: float
            Relative model error added in quadrature to the input errors. By
            default 10%.
        threshold: float
            Values or errors below this threshold are considered invalid.

        """
        self.conf = config
        self.params = params
        self.allpropertiesnames, self.massproportional = get_info(self)

        self.table = read_table(config['data_file'])
        self.bands = [band for band in config['bands']
                      if not band.endswith('_err')]
        self.bands_err = [band for band in config['bands']
                          if band.endswith('_err')]
        # Mass-proportional properties are extensive, the others intensive.
        # An '_err' entry is classified according to its base name.
        self.intprops = [prop for prop in config['properties']
                         if prop not in self.massproportional and
                         not prop.endswith('_err')]
        self.intprops_err = [prop for prop in config['properties']
                             if prop.endswith('_err') and
                             prop[:-4] not in self.massproportional]
        self.extprops = [prop for prop in config['properties']
                         if prop in self.massproportional and
                         not prop.endswith('_err')]
        self.extprops_err = [prop for prop in config['properties']
                             if prop.endswith('_err') and
                             prop[:-4] in self.massproportional]

        self.tofit = self.bands + self.intprops + self.extprops
        self.tofit_err = self.bands_err + self.intprops_err + self.extprops_err

        # Sanitise the input
        self._check_filters()
        self._check_errors(defaulterror)
        self._check_invalid(config['analysis_params']['lim_flag'],
                            threshold)
        self._add_model_error(modelerror)

        # Rebuild the quantities to fit after vetting them, as
        # _check_invalid() may have removed some bands or properties.
        self.tofit = self.bands + self.intprops + self.extprops
        self.tofit_err = self.bands_err + self.intprops_err + self.extprops_err

        self.observations = [Observation(row, self) for row in self.table]

    def __len__(self):
        return len(self.observations)

    def __iter__(self):
        # The iteration state is stored on the instance itself, so nested
        # loops over the same manager interfere with each other.
        self.idx = 0
        self.max = len(self.observations)

        return self

    def __next__(self):
        if self.idx < self.max:
            obs = self.observations[self.idx]
            self.idx += 1
            return obs
        raise StopIteration

    def _check_filters(self):
        """Check whether the list of filters and properties makes sense.

        Two situations are checked:
        * If a filter or property to be included in the fit is missing from
        the data file, an exception is raised.
        * If a filter or property is given in the input file but is not to be
        included in the fit, a warning is displayed and the column is removed
        from the table. The 'id', 'redshift' and 'distance' columns are always
        kept.

        """
        present = set(self.table.colnames)
        for item in self.tofit + self.tofit_err:
            if item not in present:
                raise Exception("{} to be taken in the fit but not present "
                                "in the observation table.".format(item))

        keep = set(self.tofit + self.tofit_err) | {'id', 'redshift',
                                                   'distance'}
        # Iterate over a copy of the names as columns are removed on the way.
        for item in list(self.table.colnames):
            if item not in keep:
                self.table.remove_column(item)
                print("Warning: {} in the input file but not to be taken into"
                      " account in the fit.".format(item))

    def _check_errors(self, defaulterror=0.1):
        """Check whether the error columns are present. If not, add them.

        This happens in two cases. Either when the error column is not in the
        list of bands to be analysed or when the error column is not present
        in the data file. Note that an error cannot be included explicitly if
        it is not present in the input file. Such a case would be ambiguous
        and will have been caught by self._check_filters().

        We take the error as defaulterror × flux, so by default 10% of the
        flux. The absolute value of the flux is taken in case it is negative.
        Intensive properties are never given a default error: theirs must be
        present in the input file.

        Parameters
        ----------
        defaulterror: float
            Fraction of the flux to take as the error when the latter is not
            given in input. By default: 10%.

        Raises
        ------
        ValueError
            If defaulterror is negative or if the error of an intensive
            property is missing.

        """
        if defaulterror < 0.:
            raise ValueError("The relative default error must be positive.")

        for item in self.tofit:
            error = item + '_err'
            if item in self.intprops:
                if (error not in self.intprops_err or
                        error not in self.table.colnames):
                    raise ValueError("Intensive properties errors must be in "
                                     "input file.")
            elif (error not in self.tofit_err or
                  error not in self.table.colnames):
                # The generated column is added to the table only; it is not
                # registered in self.bands_err / self.extprops_err.
                colerr = Column(data=np.fabs(self.table[item] * defaulterror),
                                name=error)
                self.table.add_column(colerr,
                                      index=self.table.colnames.index(item)+1)
                print("Warning: {}% of {} taken as errors.".format(
                    defaulterror * 100., item))

    def _check_invalid(self, upperlimits=False, threshold=-9990.):
        """Check whether invalid data are correctly marked as such.

        This happens in two cases:
        * Data are marked as upper limits (negative error) but the upper
        limit flag is set to False.
        * Data or errors are lower than the threshold (-9990 by default).

        We mark invalid data as np.nan. In case the entire column is made of
        invalid data, we remove it (and its error) from the table and from the
        lists of bands/extensive properties.

        Parameters
        ----------
        upperlimits: bool
            Whether negative errors denote upper limits. When it is False,
            values with an error <= 0 are invalidated; otherwise only values
            with an error exactly equal to 0 are.
        threshold: float
            Threshold under which the data are considered invalid.

        """
        allinvalid = []

        for item in self.bands + self.extprops:
            error = item + '_err'
            w = np.where((self.table[item] < threshold) |
                         (self.table[error] < threshold))
            self.table[item][w] = np.nan
            self.table[error][w] = np.nan
            if upperlimits is False:
                w = np.where(self.table[error] <= 0.)
            else:
                w = np.where(self.table[error] == 0.)
            self.table[item][w] = np.nan
            if np.all(~np.isfinite(self.table[item])):
                allinvalid.append(item)

        for item in allinvalid:
            error = item + '_err'
            # Errors generated by _check_errors() are not listed in the
            # *_err lists, so only remove the error name when it is there.
            if item in self.bands:
                self.bands.remove(item)
                if error in self.bands_err:
                    self.bands_err.remove(error)
            elif item in self.extprops:
                self.extprops.remove(item)
                if error in self.extprops_err:
                    self.extprops_err.remove(error)
            self.table.remove_columns([item, error])
            print("Warning: {} removed as no valid data was found.".format(
                item))

    def _add_model_error(self, modelerror=0.1):
        """Add in quadrature the error of the model to the input error.

        Only bands and extensive properties with a non-negative error are
        affected; negative errors (upper limits) and NaN are left untouched.

        Parameters
        ----------
        modelerror: float
            Relative error of the models relative to the flux (or property). By
            default 10%.

        Raises
        ------
        ValueError
            If modelerror is negative.

        """
        if modelerror < 0.:
            raise ValueError("The relative model error must be positive.")

        for item in self.bands + self.extprops:
            error = item + '_err'
            w = np.where(self.table[error] >= 0.)
            self.table[error][w] = np.sqrt(self.table[error][w]**2. + (
                self.table[item][w]*modelerror)**2.)

    def generate_mock(self, fits):
        """Replaces the actual observations with a mock catalogue. It is
        computed from the best fit fluxes of a previous run. For each object
        and each band, we randomly draw a new flux from a Gaussian distribution
        centered on the best fit flux and with a standard deviation identical
        to the observed one. Properties are drawn the same way from the best
        fit properties.

        The list of Observation objects is rebuilt afterwards so that
        iterating over the manager yields the mock values.

        Parameters
        ----------
        fits: ResultsManager
            Best fit fluxes of a previous run.

        """
        for idx, band in enumerate(self.bands):
            err = band + '_err'
            self.table[band] = np.random.normal(fits.best.fluxes[:, idx],
                                                np.fabs(self.table[err]))
        for prop in self.intprops + self.extprops:
            err = prop + '_err'
            index = fits.best.propertiesnames.index(prop)
            self.table[prop] = np.random.normal(fits.best.properties[:, index],
                                                np.fabs(self.table[err]))

        # Observation objects hold copies of the row values: rebuild them.
        self.observations = [Observation(row, self) for row in self.table]

    def save(self, filename):
        """Saves the observations as seen internally by the code so it is easy
        to see what fluxes are actually used in the fit. Files are saved in
        FITS and ASCII formats.

        Parameters
        ----------
        filename: str
            Root of the filename where to save the observations. The files are
            written as out/<filename>.fits and out/<filename>.txt.

        """
        self.table.write('out/{}.fits'.format(filename))
        self.table.write('out/{}.txt'.format(filename),
                         format='ascii.fixed_width', delimiter=None)


class ObservationsManagerVirtual(object):
    """Virtual observations manager when there is no observations file given
    as input. In that case we only use the list of bands given in the
    pcigale.ini file.
    """

    def __init__(self, config, **kwargs):
        """
        Parameters
        ----------
        config: dict-like
            Configuration; only its 'bands' entry is read.
        **kwargs:
            Accepted for interface compatibility with the file-based manager
            and ignored.

        """
        self.bands = [band for band in config['bands']
                      if not band.endswith('_err')]

        # The two checks are independent: a list made only of error columns
        # deserves both warnings.
        if len(self.bands) != len(config['bands']):
            print("Warning: error bars were given in the list of bands.")
        if not self.bands:
            print("Warning: no band was given.")

        # We set the other class members to None as they do not make sense in
        # this situation
        self.bands_err = None
        self.table = None

    def __len__(self):
        """As there is no observed object the length is naturally 0.
        """
        return 0