Commit f964a39d authored by Médéric Boquien's avatar Médéric Boquien

Implement an observations manager. For now it is mainly used as a simple...

Implement an observations manager. For now it is mainly used as a simple container for the observations table, but upcoming rearchitecting will make full use of it.
parent b0f0f4d7
......@@ -4,6 +4,7 @@
### Added
- Provide the possibility not to store a given module in cache. This can be useful on computers with a limited amount of memory. The downside is that when not caching the model generation will be slower. (Médéric Boquien)
- An option `redshift_decimals` is now provided in `pdf_analysis` to indicate the number of decimals to round the observed redshifts to compute the grid of models. By default the model redshifts are rounded to two decimals but this can be insufficient at low z and/or when using narrow-band filters for instance. This only applies to the grid. The physical properties are still computed for the redshift at full precision. (Médéric Boquien)
- Bands with negative fluxes are now considered valid and are fitted as any other band. (Médéric Boquien)
### Changed
- Make the timestamp more readable when moving the out/ directory. (Médéric Boquien)
......
......@@ -136,125 +136,3 @@ def get_module(module_name):
except ImportError:
print('Module ' + module_name + ' does not exists!')
raise
def adjust_data(fluxes, errors, tolerance, lim_flag, default_error=0.1,
                systematic_deviation=0.1):
    """Adjust the fluxes and errors, replacing the invalid values by NaN and
    adding the systematic deviation in quadrature. The systematic deviation
    changes the errors to: sqrt(errors² + (fluxes*deviation)²).

    Parameters
    ----------
    fluxes: array of floats
        Observed fluxes.
    errors: array of floats
        Observational errors in the same unit as the fluxes.
    tolerance: float
        Tolerance threshold under which a flux or error is considered as 0.
    lim_flag: boolean
        Do we process upper limits (True) or treat them as no-data (False)?
    default_error: float
        Default error factor used when the provided error is under the
        tolerance threshold.
    systematic_deviation: float
        Systematic deviation added to the error.

    Returns
    -------
    fluxes: array of floats
        The fluxes, with invalid values replaced by NaN.
    errors: array of floats
        The corrected errors.

    Raises
    ------
    ValueError: when the flux and error arrays have different lengths.
    """
    # The arrays must have the same lengths.
    if len(fluxes) != len(errors):
        raise ValueError("The flux and error arrays must have the same "
                         "length.")

    # We copy the arrays so as not to modify the original ones.
    fluxes = fluxes.copy()
    errors = errors.copy()

    # We set invalid data to NaN. Values below -9990 are the conventional
    # no-data markers in the input catalogues.
    mask_invalid = np.where((fluxes <= tolerance) | (errors < -9990.))
    fluxes[mask_invalid] = np.nan
    errors[mask_invalid] = np.nan

    # Replace missing errors by the default ones (a fraction of the flux).
    mask_noerror = np.where((fluxes > tolerance) & ~np.isfinite(errors))
    errors[mask_noerror] = (default_error * fluxes[mask_noerror])

    # Upper limits are flagged with a negative error. When lim_flag is
    # False they are treated as no-data.
    if not lim_flag:
        mask_limflag = np.where((fluxes > tolerance) & (errors < tolerance))
        fluxes[mask_limflag] = np.nan
        errors[mask_limflag] = np.nan

    # Add the systematic error in quadrature to valid measurements.
    mask_ok = np.where((fluxes > tolerance) & (errors > tolerance))
    errors[mask_ok] = np.sqrt(errors[mask_ok]**2 +
                              (fluxes[mask_ok]*systematic_deviation)**2)

    return fluxes, errors
def complete_obs_table(obs_table, used_columns, filter_list, tolerance,
                       lim_flag, default_error=0.1, systematic_deviation=0.1):
    """Complete the observation table.

    For each filter:
    * If the corresponding error column is not present in the used column
      list or in the table columns, add (or replace) an error column filled
      with NaN; adjust_data() then substitutes the default error.
    * Adjust the flux and error values.

    Parameters
    ----------
    obs_table: astropy.table.Table
        The observation table.
    used_columns: list of strings
        The list of columns to use in the observation table.
    filter_list: list of strings
        The list of filters used in the analysis.
    tolerance: float
        Tolerance threshold under which a flux or error is considered as 0.
    lim_flag: boolean
        Do we process upper limits (True) or treat them as no-data (False)?
    default_error: float
        Default error factor used when the provided error is under the
        tolerance threshold.
    systematic_deviation: float
        Systematic deviation added to the error.

    Returns
    -------
    obs_table: astropy.table.Table
        The completed observation table.

    Raises
    ------
    Exception: When a filter is not present in the observation table.
    """
    # TODO Print or log a warning when an error column is in the used column
    # list but is not present in the observation table.
    for name in filter_list:
        if name not in obs_table.columns:
            raise Exception("The filter <{}> (at least) is not present in "
                            "the observation table.".format(name))

        name_err = name + "_err"
        if name_err not in obs_table.columns:
            # Insert the new error column right after its flux column.
            obs_table.add_column(Column(name=name_err,
                                        data=np.full(len(obs_table), np.nan)),
                                 index=obs_table.colnames.index(name)+1)
        elif name_err not in used_columns:
            # The error column exists but is not to be used: discard its
            # content so adjust_data() falls back to the default error.
            obs_table[name_err] = np.full(len(obs_table), np.nan)

        obs_table[name], obs_table[name_err] = adjust_data(
            obs_table[name], obs_table[name_err], tolerance, lim_flag,
            default_error, systematic_deviation)

    return obs_table
......@@ -34,13 +34,14 @@ import time
import numpy as np
from ...utils import read_table
from .. import AnalysisModule, complete_obs_table
from .. import AnalysisModule
from .utils import save_results, analyse_chi2
from ...warehouse import SedWarehouse
from .workers import sed as worker_sed
from .workers import init_sed as init_worker_sed
from .workers import init_analysis as init_worker_analysis
from .workers import analysis as worker_analysis
from ...managers.observations import ObservationsManager
from ...managers.parameters import ParametersManager
......@@ -131,9 +132,8 @@ class PdfAnalysis(AnalysisModule):
# Read the observation table and complete it by adding error where
# none is provided and by adding the systematic deviation.
obs_table = complete_obs_table(read_table(conf['data_file']),
conf['bands'], filters, 0., lim_flag)
n_obs = len(obs_table)
obs = ObservationsManager(conf)
n_obs = len(obs.table)
z = np.array(conf['sed_modules_params']['redshifting']['redshift'])
......@@ -199,19 +199,19 @@ class PdfAnalysis(AnalysisModule):
n_obs)
if conf['cores'] == 1: # Do not create a new process
init_worker_analysis(*initargs)
for idx, obs in enumerate(obs_table):
for idx, obs in enumerate(obs.table):
worker_analysis(idx, obs)
else: # Analyse observations in parallel
with mp.Pool(processes=conf['cores'],
initializer=init_worker_analysis,
initargs=initargs) as pool:
pool.starmap(worker_analysis, enumerate(obs_table))
pool.starmap(worker_analysis, enumerate(obs.table))
analyse_chi2(best_chi2_red)
print("\nSaving results...")
save_results("results", obs_table['id'], variables, analysed_averages,
save_results("results", obs.table['id'], variables, analysed_averages,
analysed_std, best_chi2, best_chi2_red, best_parameters,
best_fluxes, filters, info)
......@@ -223,8 +223,8 @@ class PdfAnalysis(AnalysisModule):
for k in save:
save[k] = False
obs_fluxes = np.array([obs_table[name] for name in filters]).T
obs_errors = np.array([obs_table[name + "_err"] for name in
obs_fluxes = np.array([obs.table[name] for name in filters]).T
obs_errors = np.array([obs.table[name + "_err"] for name in
filters]).T
mock_fluxes = obs_fluxes.copy()
bestmod_fluxes = np.ctypeslib.as_array(best_fluxes[0])
......@@ -233,9 +233,8 @@ class PdfAnalysis(AnalysisModule):
mock_fluxes[wdata] = np.random.normal(bestmod_fluxes[wdata],
obs_errors[wdata])
mock_table = obs_table.copy()
for idx, name in enumerate(filters):
mock_table[name] = mock_fluxes[:, idx]
obs.table[name] = mock_fluxes[:, idx]
initargs = (params, filters, variables, z, model_fluxes,
model_variables, time.time(), mp.Value('i', 0),
......@@ -244,17 +243,17 @@ class PdfAnalysis(AnalysisModule):
lim_flag, n_obs)
if conf['cores'] == 1: # Do not create a new process
init_worker_analysis(*initargs)
for idx, mock in enumerate(mock_table):
for idx, mock in enumerate(obs.table):
worker_analysis(idx, mock)
else: # Analyse observations in parallel
with mp.Pool(processes=conf['cores'],
initializer=init_worker_analysis,
initargs=initargs) as pool:
pool.starmap(worker_analysis, enumerate(mock_table))
pool.starmap(worker_analysis, enumerate(obs.table))
print("\nSaving results...")
save_results("results_mock", mock_table['id'], variables,
save_results("results_mock", obs.table['id'], variables,
analysed_averages, analysed_std, best_chi2,
best_chi2_red, best_parameters, best_fluxes, filters,
info)
......
# -*- coding: utf-8 -*-
# Copyright (C) 2017 Universidad de Antofagasta
# Licensed under the CeCILL-v2 licence - see Licence_CeCILL_V2-en.txt
# Author: Médéric Boquien
from astropy.table import Column
import numpy as np
from ..utils import read_table
class ObservationsManager(object):
    """Class to abstract the handling of the observations and provide a
    consistent interface for the rest of cigale to deal with observations.

    An ObservationsManager is in charge of reading the input file to memory,
    checking the consistency of the data, replacing invalid values with NaN,
    etc.
    """

    def __new__(cls, config, **kwargs):
        # Dispatch to the concrete manager: a passbands manager when an
        # observations file is provided, a virtual one otherwise.
        manager = (ObservationsManagerPassbands if config['data_file']
                   else ObservationsManagerVirtual)
        return manager(config, **kwargs)
class ObservationsManagerPassbands(object):
    """Class to generate a manager for data files providing fluxes in
    passbands.

    A class instance can be used as a sequence. In that case a row is
    returned at each iteration.
    """

    def __init__(self, config, defaulterror=0.1, modelerror=0.1,
                 threshold=-9990.):
        """
        Parameters
        ----------
        config: dict
            Run configuration. The 'data_file', 'bands', and
            'analysis_params'/'lim_flag' entries are read here.
        defaulterror: float
            Relative error adopted when no error is given for a band. By
            default 10% of the flux.
        modelerror: float
            Relative error of the models, added in quadrature to the
            observed errors. By default 10% of the flux.
        threshold: float
            Fluxes and errors below this value are considered invalid.
        """
        self.table = read_table(config['data_file'])
        # Separate the flux columns from the error columns, identified by
        # the '_err' suffix.
        self.bands = [band for band in config['bands']
                      if not band.endswith('_err')]
        self.errors = [band for band in config['bands']
                       if band.endswith('_err')]

        # Sanitise the input
        self._check_filters()
        self._check_errors(defaulterror)
        self._check_invalid(config['analysis_params']['lim_flag'],
                            threshold)
        self._add_model_error(modelerror)

    def __len__(self):
        # Number of objects in the observations table.
        return len(self.table)

    def __iter__(self):
        self.idx = 0
        self.max = len(self.table)
        return self

    def __next__(self):
        # Return the rows of the table one by one.
        if self.idx < self.max:
            obs = self.table[self.idx]
            self.idx += 1
            return obs
        raise StopIteration

    def _check_filters(self):
        """Check whether the list of filters makes sense.

        Two situations are checked:
        * If a filter to be included in the fit is missing from the data
          file, an exception is raised.
        * If a filter is given in the input file but is not to be included
          in the fit, a warning is displayed and the column is removed.
        """
        for band in self.bands + self.errors:
            if band not in self.table.colnames:
                raise Exception("{} to be taken in the fit but not present "
                                "in the observation table.".format(band))

        for band in self.table.colnames:
            if (band != 'id' and band != 'redshift' and
                    band not in self.bands + self.errors):
                self.table.remove_column(band)
                # Bug fix: the band name was missing from the warning
                # (the format() call had been left out).
                print("Warning: {} in the input file but not to be taken "
                      "into account in the fit.".format(band))

    def _check_errors(self, defaulterror=0.1):
        """Check whether the error columns are present. If not, add them.

        This happens in two cases. Either when the error column is not in
        the list of bands to be analysed or when the error column is not
        present in the data file. Note that an error cannot be included
        explicitly if it is not present in the input file. Such a case
        would be ambiguous and will have been caught by
        self._check_filters().

        We take the error as defaulterror × flux, so by default 10% of the
        flux. The absolute value of the flux is taken in case it is
        negative.

        Parameters
        ----------
        defaulterror: float
            Fraction of the flux to take as the error when the latter is
            not given in input. By default: 10%.

        Raises
        ------
        ValueError: when defaulterror is negative.
        """
        if defaulterror < 0.:
            raise ValueError("The relative default error must be positive.")

        for band in self.bands:
            banderr = band + '_err'
            if (banderr not in self.errors or
                    banderr not in self.table.colnames):
                colerr = Column(data=np.fabs(self.table[band] *
                                             defaulterror),
                                name=banderr)
                # Insert the error column right after its flux column.
                self.table.add_column(
                    colerr, index=self.table.colnames.index(band)+1)
                print("Warning: {}% of {} taken as errors.".format(
                    defaulterror * 100., band))

    def _check_invalid(self, upperlimits=False, threshold=-9990.):
        """Check whether invalid data are correctly marked as such.

        This happens in two cases:
        * Data are marked as upper limits (negative error) but the upper
          limit flag is set to False.
        * Data or errors are lower than the threshold.

        We mark invalid data as np.nan. In case an entire column is made
        of invalid data, we remove it from the table.

        Parameters
        ----------
        upperlimits: boolean
            Whether upper limits (negative errors) are processed (True) or
            treated as invalid data (False).
        threshold: float
            Threshold under which the data are considered invalid.
        """
        allinvalid = []

        for band in self.bands:
            banderr = band + '_err'
            w = np.where((self.table[band] < threshold) |
                         (self.table[banderr] < threshold))
            self.table[band][w] = np.nan
            if upperlimits is False:
                # Upper limits carry a negative error; without the upper
                # limit analysis they are no-data.
                w = np.where(self.table[banderr] < 0.)
                self.table[band][w] = np.nan
            if np.all(~np.isfinite(self.table[band])):
                allinvalid.append(band)

        for band in allinvalid:
            self.bands.remove(band)
            self.errors.remove(band + '_err')
            self.table.remove_columns([band, band + '_err'])
        print("Warning: {} removed as no valid data was "
              "found.".format(allinvalid))

    def _add_model_error(self, modelerror=0.1):
        """Add in quadrature the error of the model to the input error.

        Parameters
        ----------
        modelerror: float
            Relative error of the models relative to the flux. By default
            10%.

        Raises
        ------
        ValueError: when modelerror is negative.
        """
        if modelerror < 0.:
            raise ValueError("The relative model error must be positive.")

        for band in self.bands:
            banderr = band + '_err'
            # Only valid measurements (non-negative error) are affected;
            # upper limits keep their negative error untouched.
            w = np.where(self.table[banderr] >= 0.)
            self.table[banderr][w] = np.sqrt(
                self.table[banderr][w]**2. +
                (self.table[band][w]*modelerror)**2.)

    def generate_mock(self, fits):
        """Replaces the actual observations with a mock catalogue. It is
        computed from the best fit fluxes of a previous run. For each
        object and each band, we randomly draw a new flux from a Gaussian
        distribution centered on the best fit flux and with a standard
        deviation identical to the observed one.

        Parameters
        ----------
        fits: ResultsManager
            Best fit fluxes of a previous run.
        """
        for idx, band in enumerate(self.bands):
            banderr = band + '_err'
            self.table[band] = np.random.normal(fits.best.fluxes[:, idx],
                                                np.fabs(self.table[banderr]))

    def save(self, filename):
        """Saves the observations as seen internally by the code so it is
        easy to see what fluxes are actually used in the fit. Files are
        saved in FITS and ASCII formats.

        Parameters
        ----------
        filename: str
            Root of the filename where to save the observations.
        """
        self.table.write('out/{}.fits'.format(filename))
        self.table.write('out/{}.txt'.format(filename),
                         format='ascii.fixed_width', delimiter=None)
class ObservationsManagerVirtual(object):
    """Virtual observations manager used when no observations file is given
    as input. In that case we only rely on the list of bands given in the
    pcigale.ini file.
    """

    def __init__(self, config, **kwargs):
        allbands = config['bands']
        # Keep only the flux bands, dropping any '_err' column names.
        self.bands = [name for name in allbands
                      if not name.endswith('_err')]

        if len(self.bands) != len(allbands):
            print("Warning: error bars were given in the list of bands.")
        elif not self.bands:
            print("Warning: no band was given.")

        # The other class members make no sense without an input
        # catalogue, so we set them to None.
        self.errors = None
        self.table = None

    def __len__(self):
        """As there is no observed object the length is naturally 0.
        """
        return 0
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment