Commit 1be38e55 authored by Yannick Roehlly's avatar Yannick Roehlly

Initialise module parameters at instantiation

In order to speed-up module re-use, get the module needs (ssp,
filters...) out of pcigale database at module instantiation instead of
doing this everytime "process" is called.

To do that:
- we must give the module parameters (and proceed with their checking)
at
  module instantiation.
- we must add the possibility for each module to have its own
  initialisation code.

In addition, to create the configuration file, we must be able to load
the modules without any parameter to query them for their parameter
list.

The bc03, dh2002 and m2005 modules were changed to use this new
behavior.
parent 10165bcc
......@@ -58,6 +58,14 @@ class Module(common.SEDCreationModule):
"b_912": "Amplitude of Lyman discontinuity"
}
def _init_code(self):
"""Read the SSP from the database."""
imf = self.parameters["imf"]
metallicity = float(self.parameters["metallicity"])
database = Database()
self.ssp = database.get_ssp_bc03(imf, metallicity)
database.session.close_all()
def _process(self, sed, parameters):
"""Add the convolution of a Bruzual and Charlot SSP to the SED
......@@ -74,15 +82,11 @@ class Module(common.SEDCreationModule):
metallicity = float(self.parameters["metallicity"])
separation_age = int(self.parameters["separation_age"])
sfh_time, sfh_sfr = sed.sfh
ssp = self.ssp
# Age of the galaxy at each time of the SFH
sfh_age = np.max(sfh_time) - sfh_time
# First, we take the SSP out of the database.
database = Database()
ssp = database.get_ssp_bc03(imf, metallicity)
database.session.close_all()
# First, we process the young population (age lower than the
# separation age.)
young_sfh = np.copy(sfh_sfr)
......
......@@ -8,6 +8,64 @@ Licensed under the CeCILL-v2 licence - see Licence_CeCILL_V2-en.txt
"""
def complete_parameters(given_parameters, parameter_list):
"""Complete the given parameter list with the default values
Complete the given_parameters dictionary with missing parameters that have
a default value in the parameter_list. If a parameter from parameter_list
have no default value and is not present in given_parameters, raises an
error. If a parameter is present in given_parameters and not in
parameter_list, an exception is also raised.
Parameters
----------
given_parameters : dictionary
Parameter dictionary used to configure the module.
parameter_list : dictionary
Parameter list from the module.
Returns
-------
parameters : dictionary
Dictionary combining the given parameters with the default values for
the missing ones.
Raises
------
KeyError when the given parameters are different from the expected ones.
"""
# For parameters that are present on the parameter_list with a default
# value and that are not in the giver_parameters dictionary, we add them
# with their default value.
for key in parameter_list:
if (not key in given_parameters) and (
parameter_list[key][2] is not None):
given_parameters[key] = parameter_list[key][2]
# If the keys of the parameters dictionary are different from the one
# of the parameter_list dictionary, we raises a KeyError. That means
# that a parameter is missing (and has no default value) or that an
# unexpected one was given.
if not set(given_parameters.keys()) == set(parameter_list.keys()):
missing_parameters = (set(parameter_list.keys())
- set(given_parameters.keys()))
unexpected_parameters = (set(given_parameters.keys())
- set(parameter_list.keys()))
message = ""
if missing_parameters:
message += ("Missing parameters: " +
", ".join(missing_parameters) +
". ")
if unexpected_parameters:
message += ("Unexpected parameters: " +
", ".join(unexpected_parameters) +
".")
raise KeyError("The parameters passed are different from the "
"expected one." + message)
return given_parameters
class SEDCreationModule(object):
"""Abstract class, the pCigale SED creation modules are based on.
"""
......@@ -34,20 +92,55 @@ class SEDCreationModule(object):
# instructions for the configuration.
comments = ""
def __init__(self, name=None, **kwargs):
def __init__(self, name=None, blank=False, **kwargs):
"""Instantiate a SED creation module
A name can be given to the module. This can be useful when a same
module is used several times with different parameters in the SED
creation process.
The module parameters values can be passed as keyworded paramatres.
The module parameters must be passed as keyworded parameters. If a
parameter is not given but exists in the parameter_list with a default
value, this value is used. If a parameter is missing or if an
unexpected parameter is given, an error will be raised.
Parameters
----------
name : string
Name of the module.
blank : boolean
If true, return a non-parameterised module that will be used only
to query the module parameter list.
The module parameters must be given as keyworded parameters.
Raises
------
KeyError : when not all the needed parameters are given or when an
unexpected parameter is given.
"""
self.name = name
# parameters is a dictionary containing the actual values for each
# module parameter.
self.parameters = kwargs
if not blank:
# Parameters given in constructor.
parameters = kwargs
# Complete the parameter dictionary and "export" it to the module
self.parameters = complete_parameters(parameters,
self.parameter_list)
# Run the initialisation code specific to the module.
self._init_code()
def _init_code(self):
"""Initialisation code specific to the module.
For instance, a module taking data in the database can use this method
to do so, only one time when the module instantiates.
"""
pass
def _process(self, sed, parameters):
"""Do the actual processing of the module on a SED object
......@@ -64,72 +157,21 @@ class SEDCreationModule(object):
"""
raise NotImplementedError()
def process(self, sed, parameters=None):
def process(self, sed):
"""Process a SED object with the module
This method is responsible for checking the module parameters (whether
they are given in the method call or are taken from parameters class
attribute) before doing the actual processing (_process method). If a
parameter is not given but exists in the parameter_list with a default
value, this value is used.
The SED object is updated during the process, one must take care of
copying it before, if needed.
Parameters
----------
sed : pcigale.sed.SED object
parameters : dictionary
Dictionary containing the module parameter values, if it is not
given, the module parameter values are used
Raises
------
KeyError : when not all the needed parameters are given.
"""
self._process(sed, self.parameters)
# If the parameter dictionary is not passed, use the module one
if not parameters:
parameters = self.parameters
# For parameters that are present on the parameter_list with a default
# value and that are not in the parameters dictionary, we add them
# with their default value.
for key in self.parameter_list:
if (not key in parameters) and (
self.parameter_list[key][2] is not None):
parameters[key] = self.parameter_list[key][2]
# If the keys of the parameters dictionary are different from the one
# of the parameter_list dictionary, we raises a KeyError. That means
# that a parameter is missing (and has no default value) or that an
# unexpected one was given.
if not set(parameters.keys()) == set(self.parameter_list.keys()):
missing_parameters = (set(self.parameter_list.keys())
- set(parameters.keys()))
unexpected_parameters = (set(parameters.keys())
- set(self.parameter_list.keys()))
message = ""
if missing_parameters:
message += ("Missing parameters: " +
", ".join(missing_parameters) +
".")
if unexpected_parameters:
message += ("Unexpected parameters: " +
", ".join(unexpected_parameters) +
".")
raise KeyError("The parameters passed are different from the "
"expected one." + message)
# TODO: We should also check that all parameters is from the right
# type.
# We do the actual processing.
self._process(sed, parameters)
def get_module(name):
def get_module(name, **kwargs):
"""Get a SED creation module from its name
Parameters
......@@ -144,7 +186,6 @@ def get_module(name):
-------
a pcigale.sed.modules.Module instance
"""
# Determine the real module name by removing the dotted prefix.
module_name = name.split('.')[0]
......@@ -152,7 +193,7 @@ def get_module(name):
# TODO Find a better way to do dynamic import
import_string = 'from . import ' + module_name + ' as module'
exec import_string
return module.Module(name=name)
return module.Module(name=name, **kwargs)
except ImportError:
print('Module ' + module_name + ' does not exists!')
raise
......@@ -42,6 +42,12 @@ class Module(common.SEDCreationModule):
out_parameter_list = {'alpha': 'Alpha slope.'}
def _init_code(self):
"""Get the template set out of the database"""
database = Database()
self.dh2002 = database.get_dh2002_infrared_templates()
database.session.close_all()
def _process(self, sed, parameters):
"""Add the IR re-emission contributions
......@@ -54,11 +60,7 @@ class Module(common.SEDCreationModule):
alpha = float(parameters["alpha"])
attenuation_value_names = parameters["attenuation_value_names"]
# Get the template set out of the database
database = Database()
dh2002 = database.get_dh2002_infrared_templates()
database.session.close_all()
dh2002 = self.dh2002
ir_template = dh2002.get_template(alpha)
# Base name for adding information to the SED.
......
......@@ -105,6 +105,14 @@ class Module(common.SEDCreationModule):
'(young population).'
}
def _init_code(self):
"""Read the SSP from the database."""
imf = self.parameters["imf"]
metallicity = float(self.parameters["metallicity"])
database = Database()
self.ssp = database.get_ssp_m2005(imf, metallicity)
database.session.close_all()
def _process(self, sed, parameters):
"""Add the convolution of a Maraston 2005 SSP to the SED
......@@ -121,15 +129,11 @@ class Module(common.SEDCreationModule):
metallicity = float(self.parameters["metallicity"])
separation_age = int(self.parameters["separation_age"])
sfh_time, sfh_sfr = sed.sfh
ssp = self.ssp
# Age of the galaxy at each time of the SFH
sfh_age = np.max(sfh_time) - sfh_time
# First, we take the SSP out of the database.
database = Database()
ssp = database.get_ssp_m2005(imf, metallicity)
database.session.close_all()
# First, we process the young population (age lower than the
# separation age.)
young_sfh = np.copy(sfh_sfr)
......
......@@ -188,14 +188,15 @@ class Configuration(object):
sub_config = self.config["sed_creation_modules"][module_name]
for name, (typ, description, default) in \
modules.get_module(module_name).parameter_list.items():
modules.get_module(module_name,
blank=True).parameter_list.items():
if default is None:
default = ''
sub_config[name] = default
sub_config.comments[name] = wrap(description)
self.config['sed_creation_modules'].comments[module_name] = [
modules.get_module(module_name).comments]
modules.get_module(module_name, blank=True).comments]
# Configuration for the redshift module
self.config['redshift_configuration'] = {}
......@@ -205,7 +206,8 @@ class Configuration(object):
"each.")
module_name = self.config['redshift_module']
for name, (typ, desc, default) in \
modules.get_module(module_name).parameter_list.items():
modules.get_module(module_name,
blank=True).parameter_list.items():
if default is None:
default = ''
self.config['redshift_configuration'][name] = default
......@@ -217,7 +219,8 @@ class Configuration(object):
"Configuration of the statistical analysis method.")
module_name = self.config['analysis_method']
for name, (typ, desc, default) in \
analysis.get_module(module_name).parameter_list.items():
analysis.get_module(module_name,
blank=True).parameter_list.items():
if default is None:
default = ''
self.config['analysis_configuration'][name] = default
......
......@@ -171,8 +171,8 @@ class Module(common.AnalysisModule):
base.close()
# We get the redshift module.
redshift_module = get_module(redshift_module_name)
redshift_module.parameters = redshift_configuration
redshift_module = get_module(redshift_module_name,
**redshift_configuration)
# Read the observation table and complete it by adding error where
# none is provided and by adding the systematic deviation.
......
......@@ -66,8 +66,8 @@ class SedWarehouse(object):
sed = self.storage.get(sed_key)
if not sed:
mod = sed_modules.get_module(module_list.pop())
mod.parameters = parameter_list.pop()
mod = sed_modules.get_module(module_list.pop(),
**parameter_list.pop())
if (len(module_list) == 0):
sed = SED()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment