# -*- coding: utf-8 -*-
# Copyright (C) 2013 Centre de données Astrophysiques de Marseille
# Copyright (C) 2013-2014 Institute of Astronomy
# Copyright (C) 2013-2014 Yannick Roehlly <yannick@iaora.eu>
# Licensed under the CeCILL-v2 licence - see Licence_CeCILL_V2-en.txt
# Author: Yannick Roehlly & Médéric Boquien

"""
Probability Density Function analysis module
============================================

This module builds the probability density functions (PDF) of the SED
parameters to compute their moments.

The models corresponding to all possible combinations of parameters are
computed and their fluxes in the same filters as the observations are
integrated. These fluxes are compared to the observed ones to compute the
χ² value of the fitting. This χ² gives a probability that is associated with
the model values for the parameters.

At the end, for each parameter, the probability-weighted mean and standard
deviation are computed and the best fitting model (the one with the least
reduced χ²) is given for each observation.

"""

from collections import OrderedDict
28
import ctypes
29
import multiprocessing as mp
30 31 32 33 34
from multiprocessing.sharedctypes import RawArray
import time

import numpy as np

35 36
from ...utils import read_table
from .. import AnalysisModule, complete_obs_table
37
from .utils import save_table_analysis, save_table_best
38 39
from ...warehouse import SedWarehouse
from ...data import Database
40
from .workers import sed as worker_sed
41 42
from .workers import init_sed as init_worker_sed
from .workers import init_analysis as init_worker_analysis
43
from .workers import analysis as worker_analysis
44
from ..utils import ParametersHandler, backup_dir
45 46

# Tolerance threshold under which any flux or error is considered as 0.
# It is handed to complete_obs_table() when the observation table is read.
TOLERANCE = 1e-12

# Limit the redshift to this number of decimals. Used when the redshift grid
# is derived from the observations (np.around(..., decimals=...)).
REDSHIFT_DECIMALS = 2
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66


class PdfAnalysis(AnalysisModule):
    """PDF analysis module.

    Computes the fluxes of every model of the parameter grid, compares them
    with the observations and stores, for each observation, the averages and
    standard deviations of the analysed variables as well as the best fitting
    model.
    """

    # Each entry maps a parameter name to a (type, description, default)
    # tuple.
    parameter_list = OrderedDict([
        ("analysed_variables", (
            "array of strings",
            "List of the variables (in the SEDs info dictionaries) for which "
            "the statistical analysis will be done.",
            ["sfr", "average_sfr"]
        )),
        ("save_best_sed", (
            "boolean",
            "If true, save the best SED for each observation to a file.",
            False
        )),
        ("save_chi2", (
            "boolean",
            "If true, for each observation and each analysed variable save "
            "the reduced chi2.",
            False
        )),
        ("save_pdf", (
            "boolean",
            "If true, for each observation and each analysed variable save "
            "the probability density function.",
            False
        )),
        ("storage_type", (
            "string",
            "Type of storage used to cache the generate SED.",
            "memory"
        ))
    ])

    def process(self, data_file, column_list, creation_modules,
                creation_modules_params, config, cores):
        """Run the PDF analysis.

        The analysis is done in two steps which can both run on multiple
        processors to run faster. The first step is to compute all the fluxes
        associated with each model as well as ancillary data such as the SED
        information. The second step is to carry out the analysis of each
        object, considering all models at once.

        Parameters
        ----------
        data_file: string
            Name of the file containing the observations to fit.
        column_list: list of strings
            Name of the columns from the data file to use for the analysis.
        creation_modules: list of strings
            List of the module names (in the right order) to use for creating
            the SEDs. It must contain 'redshifting'.
        creation_modules_params: list of dictionaries
            List of the parameter dictionaries for each module. When the
            'redshift' entry of the redshifting module is [''], it is
            replaced in place by the rounded unique redshifts of the
            observations.
        config: dictionary
            Dictionary containing the configuration. The keys read here are
            "analysed_variables", "save_best_sed", "save_chi2", "save_pdf"
            and "storage_type".
        cores: integer
            Number of cores to run the analysis on. With 1, everything runs
            in the current process and no pool is created.

        """

        print("Initialising the analysis module... ")

        # Rename the output directory if it exists
        backup_dir()

        # Initialise variables from input arguments.
        analysed_variables = config["analysed_variables"]
        n_variables = len(analysed_variables)
        # The save_* options may be strings ("True"/"False", e.g. when read
        # from a configuration file) or the boolean defaults declared in
        # parameter_list. Going through str() handles both; calling .lower()
        # directly on a bool would raise AttributeError.
        save = {key: str(config["save_{}".format(key)]).lower() == "true"
                for key in ["best_sed", "chi2", "pdf"]}

        # Get the needed filters in the pcigale database. We use an ordered
        # dictionary because we need the keys to always be returned in the
        # same order. We also put the filters in the shared modules as they
        # are needed to compute the fluxes during the models generation.
        with Database() as base:
            filters = OrderedDict([(name, base.get_filter(name))
                                   for name in column_list
                                   if not name.endswith('_err')])
        n_filters = len(filters)

        # Read the observation table and complete it by adding error where
        # none is provided and by adding the systematic deviation.
        obs_table = complete_obs_table(read_table(data_file), column_list,
                                       filters, TOLERANCE)
        n_obs = len(obs_table)

        # When no redshift is given for the redshifting module (['']), build
        # the redshift grid from the observations, rounded so that close
        # redshifts share the same models.
        w_redshifting = creation_modules.index('redshifting')
        if creation_modules_params[w_redshifting]['redshift'] == ['']:
            z = np.unique(np.around(obs_table['redshift'],
                                    decimals=REDSHIFT_DECIMALS))
            creation_modules_params[w_redshifting]['redshift'] = z
            del z

        # The parameters handler allows us to retrieve the models parameters
        # from a 1D index. This is useful in that we do not have to create
        # a list of parameters as they are computed on-the-fly. It also has
        # nice goodies such as finding the index of the first parameter to
        # have changed between two indices or the number of models.
        params = ParametersHandler(creation_modules, creation_modules_params)
        n_params = params.size

        # Retrieve an arbitrary SED to obtain the list of output parameters
        warehouse = SedWarehouse(cache_type=config["storage_type"])
        sed = warehouse.get_sed(creation_modules, params.from_index(0))
        info = sed.info
        n_info = len(info)
        del warehouse, sed

        print("Computing the models fluxes...")

        # Arrays where we store the data related to the models. For memory
        # efficiency reasons, we use RawArrays that will be passed in argument
        # to the pool. Each worker will fill a part of the RawArrays. It is
        # important that there is no conflict and that two different workers do
        # not write on the same section.
        # We put the shape in a tuple along with the RawArray because workers
        # need to know the shape to create the numpy array from the RawArray.
        # Note the trailing comma for 1D shapes: (n) is just an int, (n,) is
        # the tuple announced above.
        model_redshifts = (RawArray(ctypes.c_double, n_params),
                           (n_params,))
        model_fluxes = (RawArray(ctypes.c_double,
                                 n_params * n_filters),
                        (n_params, n_filters))
        model_variables = (RawArray(ctypes.c_double,
                                    n_params * n_variables),
                           (n_params, n_variables))

        # The last two items are the current time and a shared integer
        # initialised to 0 (presumably used by the workers to report
        # progress; see the workers module).
        initargs = (params, filters, analysed_variables, model_redshifts,
                    model_fluxes, model_variables, time.time(),
                    mp.Value('i', 0))
        if cores == 1:  # Do not create a new process
            init_worker_sed(*initargs)
            for idx in range(n_params):
                worker_sed(idx)
        else:  # Compute the models in parallel
            with mp.Pool(processes=cores, initializer=init_worker_sed,
                         initargs=initargs) as pool:
                pool.map(worker_sed, range(n_params))

        print('\nAnalysing models...')

        # We use RawArrays for the same reason as previously
        analysed_averages = (RawArray(ctypes.c_double, n_obs * n_variables),
                             (n_obs, n_variables))
        analysed_std = (RawArray(ctypes.c_double, n_obs * n_variables),
                        (n_obs, n_variables))
        best_fluxes = (RawArray(ctypes.c_double, n_obs * n_filters),
                       (n_obs, n_filters))
        best_parameters = (RawArray(ctypes.c_double, n_obs * n_info),
                           (n_obs, n_info))
        best_chi2 = (RawArray(ctypes.c_double, n_obs), (n_obs,))
        best_chi2_red = (RawArray(ctypes.c_double, n_obs), (n_obs,))

        initargs = (params, filters, analysed_variables, model_redshifts,
                    model_fluxes, model_variables, time.time(),
                    mp.Value('i', 0), analysed_averages, analysed_std,
                    best_fluxes, best_parameters, best_chi2, best_chi2_red,
                    save, n_obs)
        if cores == 1:  # Do not create a new process
            init_worker_analysis(*initargs)
            for idx, obs in enumerate(obs_table):
                worker_analysis(idx, obs)
        else:  # Analyse observations in parallel
            with mp.Pool(processes=cores, initializer=init_worker_analysis,
                         initargs=initargs) as pool:
                pool.starmap(worker_analysis, enumerate(obs_table))

        print("\nSaving results...")

        save_table_analysis(obs_table['id'], analysed_variables,
                            analysed_averages, analysed_std)
        save_table_best(obs_table['id'], best_chi2, best_chi2_red,
                        best_parameters, best_fluxes, filters, info)

        print("Run completed!")
229

230 231
# AnalysisModule to be returned by get_module: the analysis class is exposed
# under the generic module-level name `Module`.
Module = PdfAnalysis