observations.py 12.3 KB
Newer Older
1
2
3
4
5
# -*- coding: utf-8 -*-
# Copyright (C) 2017 Universidad de Antofagasta
# Licensed under the CeCILL-v2 licence - see Licence_CeCILL_V2-en.txt
# Author: Médéric Boquien

6
from astropy.cosmology import WMAP7 as cosmo
7
8
from astropy.table import Column
import numpy as np
9
from scipy.constants import parsec
10
11

from ..utils import read_table
12
13
from .utils import get_info

14
15
16
17
18
19
20
21
22

class ObservationsManager(object):
    """Entry point used by the rest of cigale to obtain an observations
    manager with a uniform interface.

    Instantiating this class never yields an ObservationsManager instance:
    depending on whether a data file is configured, a passband-based manager
    (which reads and sanitises the input file) or a virtual manager (no input
    file, only the list of bands) is built and returned instead.

    """
    def __new__(cls, config, params=None, **kwargs):
        # No data file configured: nothing to read, fall back on the
        # virtual manager, which does not take the params argument.
        if not config['data_file']:
            return ObservationsManagerVirtual(config, **kwargs)

        return ObservationsManagerPassbands(config, params, **kwargs)


30
31
32
33
34
35
36
class Observation(object):
    """Class to take one row of the observations table and extract the list of
    fluxes, intensive properties, extensive properties and their errors, that
    are going to be considered in the fit.
    """
    def __init__(self, row, cls):
        """Extract the quantities of one object from a table row.

        Parameters
        ----------
        row: table row
            Row of the observations table. It must provide the 'id' and
            'redshift' entries, optionally 'distance', and one value plus one
            '<name>_err' value for every band and property listed in cls.
        cls: observations manager
            Object providing the bands, intprops and extprops lists of names
            to extract from the row.

        """
        self.redshift = row['redshift']
        self.id = row['id']

        # The distance is always stored in metres. A finite value given in
        # the 'distance' column is scaled by 1e6 × parsec (i.e. it is taken
        # to be in Mpc) and has precedence over the redshift.
        if 'distance' in row.colnames and np.isfinite(row['distance']):
            self.distance = row['distance'] * parsec * 1e6
        elif self.redshift == 0.:
            # Objects at z=0 are placed at 10 pc.
            self.distance = 10. * parsec
        elif self.redshift > 0.:
            # astropy returns the luminosity distance in Mpc; convert it to
            # metres so that all the branches use the same unit.
            self.distance = (cosmo.luminosity_distance(self.redshift).value *
                             1e6 * parsec)
        else:
            # Negative (or NaN) redshift and no usable distance.
            self.distance = np.nan

        self.fluxes = np.array([row[band] for band in cls.bands])
        self.fluxes_err = np.array([row[band + '_err'] for band in cls.bands])
        self.intprops = np.array([row[prop] for prop in cls.intprops])
        self.intprops_err = np.array([row[prop + '_err'] for prop in
                                      cls.intprops])
        self.extprops = np.array([row[prop] for prop in cls.extprops])
        self.extprops_err = np.array([row[prop + '_err'] for prop in
                                      cls.extprops])


57
58
59
60
61
62
63
64
class ObservationsManagerPassbands(object):
    """Class to generate a manager for data files providing fluxes in
    passbands.

    A class instance can be used as a sequence. In that case an Observation
    object is returned at each iteration.
    """

    def __init__(self, config, params, defaulterror=0.1, modelerror=0.1,
                 threshold=-9990.):
        """Read the data file and sanitise its content.

        Parameters
        ----------
        config: dict-like
            Configuration. The 'data_file', 'bands', 'properties' and
            'analysis_params' (for its 'lim_flag' item) entries are read.
        params: object
            Stored as self.params before get_info() is called.
            NOTE(review): presumably the models parameters manager — confirm.
        defaulterror: float
            Relative error assumed when no error is given. By default: 10%.
        modelerror: float
            Relative error of the models, added in quadrature to the
            observed errors. By default: 10%.
        threshold: float
            Value under which data are considered invalid.

        """
        # get_info() receives the instance itself, so conf and params must be
        # set before it is called.
        self.conf = config
        self.params = params
        self.allpropertiesnames, self.massproportional = get_info(self)

        self.table = read_table(config['data_file'])

        self.bands = [band for band in config['bands'] if not
                      band.endswith('_err')]
        self.bands_err = [band for band in config['bands'] if
                          band.endswith('_err')]

        # Properties are split between intensive and extensive ones, the
        # latter being those listed as proportional to the mass.
        self.intprops = [prop for prop in config['properties'] if (prop not in
                         self.massproportional and not prop.endswith('_err'))]
        self.intprops_err = [prop for prop in config['properties'] if
                             (prop.endswith('_err') and prop[:-4] not in
                              self.massproportional)]
        self.extprops = [prop for prop in config['properties'] if (prop in
                         self.massproportional and not prop.endswith('_err'))]
        self.extprops_err = [prop for prop in config['properties'] if
                             (prop.endswith('_err') and prop[:-4] in
                              self.massproportional)]

        self.tofit = self.bands + self.intprops + self.extprops
        self.tofit_err = self.bands_err + self.intprops_err + self.extprops_err

        # Sanitise the input
        self._check_filters()
        self._check_errors(defaulterror)
        self._check_invalid(config['analysis_params']['lim_flag'],
                            threshold)
        self._add_model_error(modelerror)

        # Rebuild the quantities to fit after vetting them
        self.tofit = self.bands + self.intprops + self.extprops
        self.tofit_err = self.bands_err + self.intprops_err + self.extprops_err

        self.observations = [Observation(row, self) for row in self.table]

    def __len__(self):
        """Number of observations."""
        return len(self.observations)

    def __iter__(self):
        """Reset the iteration state. The instance is its own iterator."""
        self.idx = 0
        self.max = len(self.observations)

        return self

    def __next__(self):
        """Return the next Observation, in the order of the table."""
        if self.idx < self.max:
            obs = self.observations[self.idx]
            self.idx += 1
            return obs
        raise StopIteration

    def _check_filters(self):
        """Check whether the list of filters and properties makes sense.

        Two situations are checked:
        * If a filter or property to be included in the fit is missing from
        the data file, an exception is raised.
        * If a filter or property is given in the input file but is not to be
        included in the fit, the column is removed from the table and a
        warning is displayed. The 'id', 'redshift' and 'distance' columns are
        always kept.

        Raises
        ------
        Exception
            When a quantity to fit is absent from the observation table.

        """
        requested = self.tofit + self.tofit_err

        available = set(self.table.colnames)
        for item in requested:
            if item not in available:
                raise Exception("{} to be taken in the fit but not present "
                                "in the observation table.".format(item))

        keep = set(requested) | {'id', 'redshift', 'distance'}
        for item in self.table.colnames:
            if item not in keep:
                self.table.remove_column(item)
                print("Warning: {} in the input file but not to be taken into"
                      " account in the fit.".format(item))

    def _check_errors(self, defaulterror=0.1):
        """Check whether the error columns are present. If not, add them.

        This happens in two cases. Either when the error column is not in the
        list of bands to be analysed or when the error column is not present
        in the data file. Note that an error cannot be included explicitly if
        it is not present in the input file. Such a case would be ambiguous
        and will have been caught by self._check_filters().

        We take the error as defaulterror × flux, so by default 10% of the
        flux. The absolute value of the flux is taken in case it is negative.
        No default is applied to intensive properties: their errors must be
        given explicitly.

        Parameters
        ----------
        defaulterror: float
            Fraction of the flux to take as the error when the latter is not
            given in input. By default: 10%.

        Raises
        ------
        ValueError
            When defaulterror is negative or when the error of an intensive
            property is missing.

        """
        if defaulterror < 0.:
            raise ValueError("The relative default error must be positive.")

        for item in self.tofit:
            error = item + '_err'
            if item in self.intprops:
                if (error not in self.intprops_err or
                        error not in self.table.colnames):
                    raise ValueError("Intensive properties errors must be in input file.")
            elif (error not in self.tofit_err or
                  error not in self.table.colnames):
                colerr = Column(data=np.fabs(self.table[item] * defaulterror),
                                name=error)
                # The error column is inserted right after its data column.
                self.table.add_column(colerr,
                                      index=self.table.colnames.index(item)+1)
                print("Warning: {}% of {} taken as errors.".format(defaulterror *
                                                                   100.,
                                                                   item))

    def _check_invalid(self, upperlimits=False, threshold=-9990.):
        """Check whether invalid data are correctly marked as such.

        This happens in two cases:
        * Data are marked as upper limits (negative error) but the upper
        limit flag is set to False.
        * Data or errors are lower than the threshold.

        We mark invalid data as np.nan. In case the entire column is made of
        invalid data, we remove it and its error from the table and from the
        lists of quantities to fit. Only bands and extensive properties are
        examined.

        Parameters
        ----------
        upperlimits: boolean
            When exactly False, data with a negative or null error are
            invalid. Otherwise only data with a null error are.
        threshold: float
            Threshold under which the data are considered invalid.

        """
        allinvalid = []

        for item in self.bands + self.extprops:
            error = item + '_err'
            w = np.where((self.table[item] < threshold) |
                         (self.table[error] < threshold))
            self.table[item][w] = np.nan
            self.table[error][w] = np.nan
            if upperlimits is False:
                w = np.where(self.table[error] <= 0.)
            else:
                w = np.where(self.table[error] == 0.)
            self.table[item][w] = np.nan
            if np.all(~np.isfinite(self.table[item])):
                allinvalid.append(item)

        for item in allinvalid:
            error = item + '_err'
            if item in self.bands:
                names, names_err = self.bands, self.bands_err
            else:
                names, names_err = self.extprops, self.extprops_err
            names.remove(item)
            # Error columns created by _check_errors() are not registered in
            # the lists of errors, so the name may legitimately be absent.
            if error in names_err:
                names_err.remove(error)
            self.table.remove_columns([item, error])
            print("Warning: {} removed as no valid data was found.".format(
                item))

    def _add_model_error(self, modelerror=0.1):
        """Add in quadrature the error of the model to the input error.

        Only non-negative errors of bands and extensive properties are
        modified; negative errors are left untouched.

        Parameters
        ----------
        modelerror: float
            Relative error of the models relative to the flux (or property). By
            default 10%.

        Raises
        ------
        ValueError
            When modelerror is negative.

        """
        if modelerror < 0.:
            raise ValueError("The relative model error must be positive.")

        for item in self.bands + self.extprops:
            error = item + '_err'
            w = np.where(self.table[error] >= 0.)
            self.table[error][w] = np.sqrt(self.table[error][w]**2. + (
                self.table[item][w]*modelerror)**2.)

    def generate_mock(self, fits):
        """Replaces the actual observations with a mock catalogue. It is
        computed from the best fit fluxes of a previous run. For each object
        and each band, we randomly draw a new flux from a Gaussian distribution
        centered on the best fit flux and with a standard deviation identical
        to the observed one. The same is done for the properties.

        Parameters
        ----------
        fits: ResultsManager
            Best fit fluxes of a previous run.

        """
        for idx, band in enumerate(self.bands):
            err = band + '_err'
            self.table[band] = np.random.normal(fits.best.fluxes[:, idx],
                                                np.fabs(self.table[err]))

        # Properties are located by name in the results rather than by
        # position in our own lists.
        for prop in self.intprops + self.extprops:
            err = prop + '_err'
            index = fits.best.propertiesnames.index(prop)
            self.table[prop] = np.random.normal(fits.best.properties[:, index],
                                                np.fabs(self.table[err]))

        # Observation objects hold copies of the table values, so they must
        # be rebuilt for the mock values to be seen when iterating.
        self.observations = [Observation(row, self) for row in self.table]

    def save(self, filename):
        """Saves the observations as seen internally by the code so it is easy
        to see what fluxes are actually used in the fit. Files are saved in
        FITS and ASCII formats, in the out/ directory.

        Parameters
        ----------
        filename: str
            Root of the filename where to save the observations.

        """
        self.table.write('out/{}.fits'.format(filename))
        self.table.write('out/{}.txt'.format(filename),
                         format='ascii.fixed_width', delimiter=None)


class ObservationsManagerVirtual(object):
    """Stand-in observations manager for runs without an observations file.
    The only information available is then the list of bands given in the
    pcigale.ini file.
    """

    def __init__(self, config, **kwargs):
        requested = config['bands']
        self.bands = [name for name in requested
                      if not name.endswith('_err')]

        if len(self.bands) != len(requested):
            # Some entries were dropped, hence they were error columns.
            print("Warning: error bars were given in the list of bands.")
        elif not self.bands:
            print("Warning: no band was given.")

        # Without a data file there are neither errors nor a table; the
        # corresponding members are kept but set to None.
        self.table = None
        self.bands_err = None

    def __len__(self):
        """No object is observed, hence a length of 0.
        """
        return 0