Commit e0e09b8c authored by Médéric Boquien

Simplify the interface of the analysis modules as it is needlessly restrictive...

Simplify the interface of the analysis modules as it is needlessly restrictive by forcing a precise list of parameters. Different modules do not necessarily need the same information and the need for different input parameters may emerge in the future. To make the interface future-proof, we simply pass the contents of the configuration file. Then each analysis module picks the data it needs from it.
parent ac38dd23
......@@ -8,7 +8,7 @@ import multiprocessing as mp
import sys
from .session.configuration import Configuration
from .analysis_modules import get_module as get_analysis_module
from .analysis_modules import get_module
from .analysis_modules.utils import ParametersHandler
__version__ = "0.1-alpha"
......@@ -46,18 +46,8 @@ def check(config):
def run(config):
"""Run the analysis.
"""
data_file = config.configuration['data_file']
column_list = config.configuration['column_list']
creation_modules = config.configuration['creation_modules']
creation_modules_params = config.configuration['creation_modules_params']
analysis_module = get_analysis_module(config.configuration[
'analysis_method'])
analysis_module_params = config.configuration['analysis_method_params']
cores = config.configuration['cores']
analysis_module.process(data_file, column_list, creation_modules,
creation_modules_params, analysis_module_params,
cores)
analysis_module = get_module(config.configuration['analysis_method'])
analysis_module.process(config.configuration)
def main():
......
......@@ -31,8 +31,7 @@ class AnalysisModule(object):
# module parameter.
self.parameters = kwargs
def _process(self, data_file, column_list, creation_modules,
creation_modules_params, parameters):
def _process(self, configuration):
"""Do the actual analysis
This method is responsible for the fitting / analysis process
......@@ -40,19 +39,8 @@ class AnalysisModule(object):
Parameters
----------
data_file: string
Name of the file containing the observations to be fitted.
column_list: array of strings
Names of the columns from the data file to use in the analysis.
creation_modules: array of strings
Names (in the right order) of the modules to use to build the SED.
creation_modules_params: array of array of dictionaries
Array containing all the possible combinations of configurations
for the creation_modules. Each 'inner' array has the same length as
the creation_modules array and contains the configuration
dictionary for the corresponding module.
parameters: dictionary
Configuration for the module.
configuration: dictionary
Configuration file
Returns
-------
......@@ -61,8 +49,7 @@ class AnalysisModule(object):
"""
raise NotImplementedError()
def process(self, data_file, column_list, creation_modules,
creation_modules_params, parameters):
def process(self, configuration):
"""Process with the analysis
This method is responsible for checking the module parameters before
......@@ -72,19 +59,8 @@ class AnalysisModule(object):
Parameters
----------
data_file: string
Name of the file containing the observations to be fitted.
column_list: array of strings
Names of the columns from the data file to use in the analysis.
creation_modules: array of strings
Names (in the right order) of the modules to use to build the SED.
creation_modules_params: array of array of dictionaries
Array containing all the possible combinations of configurations
for the creation_modules. Each 'inner' array has the same length as
the creation_modules array and contains the configuration
dictionary for the corresponding module.
parameters: dictionary
Configuration for the module.
configuration: dictionary
Contents of pcigale.ini in the form of a dictionary
Returns
-------
......@@ -95,6 +71,7 @@ class AnalysisModule(object):
KeyError: when not all the needed parameters are given.
"""
parameters = configuration['analysis_method_params']
# For parameters that are present on the parameter_list with a default
# value and that are not in the parameters dictionary, we add them
# with their default value.
......@@ -124,8 +101,7 @@ class AnalysisModule(object):
"expected one." + message)
# We do the actual processing
self._process(data_file, column_list, creation_modules,
creation_modules_params, parameters)
self._process(configuration)
def get_module(module_name):
......
......@@ -90,8 +90,7 @@ class PdfAnalysis(AnalysisModule):
))
])
def process(self, data_file, column_list, creation_modules,
creation_modules_params, config, cores):
def process(self, conf):
"""Process with the psum analysis.
The analysis is done in two steps which can both run on multiple
......@@ -102,19 +101,8 @@ class PdfAnalysis(AnalysisModule):
Parameters
----------
data_file: string
Name of the file containing the observations to fit.
column_list: list of strings
Name of the columns from the data file to use for the analysis.
creation_modules: list of strings
List of the module names (in the right order) to use for creating
the SEDs.
creation_modules_params: list of dictionaries
List of the parameter dictionaries for each module.
config: dictionary
Dictionary containing the configuration.
core: integer
Number of cores to run the analysis on
conf: dictionary
Contents of pcigale.ini in the form of a dictionary
"""
np.seterr(invalid='ignore')
......@@ -125,23 +113,27 @@ class PdfAnalysis(AnalysisModule):
backup_dir()
# Initalise variables from input arguments.
analysed_variables = config["analysed_variables"]
creation_modules = conf['creation_modules']
creation_modules_params = conf['creation_modules_params']
analysed_variables = conf['analysis_method_params']["analysed_variables"]
analysed_variables_nolog = [variable[:-4] if variable.endswith('_log')
else variable for variable in
analysed_variables]
n_variables = len(analysed_variables)
save = {key: config["save_{}".format(key)].lower() == "true"
save = {key: conf['analysis_method_params']["save_{}".format(key)].lower() == "true"
for key in ["best_sed", "chi2", "pdf"]}
lim_flag = config["lim_flag"].lower() == "true"
mock_flag = config["mock_flag"].lower() == "true"
lim_flag = conf['analysis_method_params']["lim_flag"].lower() == "true"
mock_flag = conf['analysis_method_params']["mock_flag"].lower() == "true"
filters = [name for name in column_list if not name.endswith('_err')]
filters = [name for name in conf['column_list'] if not
name.endswith('_err')]
n_filters = len(filters)
# Read the observation table and complete it by adding error where
# none is provided and by adding the systematic deviation.
obs_table = complete_obs_table(read_table(data_file), column_list,
filters, TOLERANCE, lim_flag)
obs_table = complete_obs_table(read_table(conf['data_file']),
conf['column_list'], filters, TOLERANCE,
lim_flag)
n_obs = len(obs_table)
w_redshifting = creation_modules.index('redshifting')
......@@ -184,12 +176,12 @@ class PdfAnalysis(AnalysisModule):
initargs = (params, filters, analysed_variables_nolog, model_fluxes,
model_variables, time.time(), mp.Value('i', 0))
if cores == 1: # Do not create a new process
if conf['cores'] == 1: # Do not create a new process
init_worker_sed(*initargs)
for idx in range(n_params):
worker_sed(idx)
else: # Compute the models in parallel
with mp.Pool(processes=cores, initializer=init_worker_sed,
with mp.Pool(processes=conf['cores'], initializer=init_worker_sed,
initargs=initargs) as pool:
pool.map(worker_sed, range(n_params))
......@@ -212,12 +204,13 @@ class PdfAnalysis(AnalysisModule):
analysed_averages, analysed_std, best_fluxes,
best_parameters, best_chi2, best_chi2_red, save, lim_flag,
n_obs)
if cores == 1: # Do not create a new process
if conf['cores'] == 1: # Do not create a new process
init_worker_analysis(*initargs)
for idx, obs in enumerate(obs_table):
worker_analysis(idx, obs)
else: # Analyse observations in parallel
with mp.Pool(processes=cores, initializer=init_worker_analysis,
with mp.Pool(processes=conf['cores'],
initializer=init_worker_analysis,
initargs=initargs) as pool:
pool.starmap(worker_analysis, enumerate(obs_table))
......
......@@ -67,8 +67,7 @@ class SaveFluxes(AnalysisModule):
))
])
def process(self, data_file, column_list, creation_modules,
creation_modules_params, parameters, cores):
def process(self, conf):
"""Process with the savedfluxes analysis.
All the possible theoretical SED are created and the fluxes in the
......@@ -77,34 +76,25 @@ class SaveFluxes(AnalysisModule):
Parameters
----------
data_file: string
Name of the file containing the observations to fit.
column_list: list of strings
Name of the columns from the data file to use for the analysis.
creation_modules: list of strings
List of the module names (in the right order) to use for creating
the SEDs.
creation_modules_params: list of dictionaries
List of the parameter dictionaries for each module.
parameters: dictionary
Dictionary containing the parameters.
cores: integer
Number of cores to run the analysis on
conf: dictionary
Contents of pcigale.ini in the form of a dictionary
"""
# Rename the output directory if it exists
backup_dir()
out_file = parameters["output_file"]
out_format = parameters["output_format"]
save_sed = parameters["save_sed"].lower() == "true"
filters = [name for name in column_list if not name.endswith('_err')]
creation_modules = conf['creation_modules']
creation_modules_params = conf['creation_modules_params']
out_file = conf['analysis_method_params']["output_file"]
out_format = conf['analysis_method_params']["output_format"]
save_sed = conf['analysis_method_params']["save_sed"].lower() == "true"
filters = [name for name in conf['column_list'] if not
name.endswith('_err')]
n_filters = len(filters)
w_redshifting = creation_modules.index('redshifting')
if list(creation_modules_params[w_redshifting]['redshift']) == ['']:
obs_table = read_table(data_file)
obs_table = read_table(conf['data_file'])
z = np.unique(np.around(obs_table['redshift'],
decimals=REDSHIFT_DECIMALS))
creation_modules_params[w_redshifting]['redshift'] = z
......@@ -118,14 +108,14 @@ class SaveFluxes(AnalysisModule):
params = ParametersHandler(creation_modules, creation_modules_params)
n_params = params.size
if parameters["variables"] == '':
if conf['analysis_method_params']["variables"] == '':
# Retrieve an arbitrary SED to obtain the list of output parameters
warehouse = SedWarehouse()
sed = warehouse.get_sed(creation_modules, params.from_index(0))
info = list(sed.info.keys())
del warehouse, sed
else:
info = parameters["variables"]
info = conf['analysis_method_params']["variables"]
n_info = len(info)
info.sort()
n_info = len(info)
......@@ -137,12 +127,13 @@ class SaveFluxes(AnalysisModule):
initargs = (params, filters, save_sed, info, model_fluxes,
model_parameters, time.time(), mp.Value('i', 0))
if cores == 1: # Do not create a new process
if conf['cores'] == 1: # Do not create a new process
init_worker_fluxes(*initargs)
for idx in range(n_params):
worker_fluxes(idx)
else: # Analyse observations in parallel
with mp.Pool(processes=cores, initializer=init_worker_fluxes,
with mp.Pool(processes=conf['cores'],
initializer=init_worker_fluxes,
initargs=initargs) as pool:
pool.map(worker_fluxes, range(n_params))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment