Commit e0e09b8c authored by Médéric Boquien


Simplify the interface of the analysis modules, as it is needlessly restrictive to force a precise list of parameters. Different modules do not necessarily need the same information, and the need for different input parameters may emerge in the future. To make the interface future-proof, we simply pass the contents of the configuration file; each analysis module then picks the data it needs from it.
parent ac38dd23
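
In concrete terms, the commit collapses the long positional signature into a single configuration argument. A minimal before/after sketch of the call site in run(), summarising the diff below (both call shapes are taken from this commit):

# Before: run() unpacked every value and passed it positionally, so the
# signature of process() had to change whenever a module needed new data.
analysis_module.process(data_file, column_list, creation_modules,
                        creation_modules_params, analysis_module_params,
                        cores)

# After: run() hands over the whole configuration dictionary; each analysis
# module picks out the entries it needs, e.g. 'data_file', 'column_list',
# 'analysis_method_params' or 'cores'.
analysis_module = get_module(config.configuration['analysis_method'])
analysis_module.process(config.configuration)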
@@ -8,7 +8,7 @@ import multiprocessing as mp
 import sys

 from .session.configuration import Configuration
-from .analysis_modules import get_module as get_analysis_module
+from .analysis_modules import get_module
 from .analysis_modules.utils import ParametersHandler

 __version__ = "0.1-alpha"
@@ -46,18 +46,8 @@ def check(config):
 def run(config):
     """Run the analysis.
     """
-    data_file = config.configuration['data_file']
-    column_list = config.configuration['column_list']
-    creation_modules = config.configuration['creation_modules']
-    creation_modules_params = config.configuration['creation_modules_params']
-    analysis_module = get_analysis_module(config.configuration[
-        'analysis_method'])
-    analysis_module_params = config.configuration['analysis_method_params']
-    cores = config.configuration['cores']
-
-    analysis_module.process(data_file, column_list, creation_modules,
-                            creation_modules_params, analysis_module_params,
-                            cores)
+    analysis_module = get_module(config.configuration['analysis_method'])
+    analysis_module.process(config.configuration)

 def main():
@@ -31,8 +31,7 @@ class AnalysisModule(object):
         # module parameter.
         self.parameters = kwargs

-    def _process(self, data_file, column_list, creation_modules,
-                 creation_modules_params, parameters):
+    def _process(self, configuration):
         """Do the actual analysis

         This method is responsible for the fitting / analysis process
@@ -40,19 +39,8 @@ class AnalysisModule(object):
         Parameters
         ----------
-        data_file: string
-            Name of the file containing the observations to be fitted.
-        column_list: array of strings
-            Names of the columns from the data file to use in the analysis.
-        creation_modules: array of strings
-            Names (in the right order) of the modules to use to build the SED.
-        creation_modules_params: array of array of dictionaries
-            Array containing all the possible combinations of configurations
-            for the creation_modules. Each 'inner' array has the same length as
-            the creation_modules array and contains the configuration
-            dictionary for the corresponding module.
-        parameters: dictionary
-            Configuration for the module.
+        configuration: dictionary
+            Contents of pcigale.ini in the form of a dictionary

         Returns
         -------
@@ -61,8 +49,7 @@ class AnalysisModule(object):
         """
         raise NotImplementedError()

-    def process(self, data_file, column_list, creation_modules,
-                creation_modules_params, parameters):
+    def process(self, configuration):
         """Process with the analysis

         This method is responsible for checking the module parameters before
@@ -72,19 +59,8 @@ class AnalysisModule(object):
         Parameters
         ----------
-        data_file: string
-            Name of the file containing the observations to be fitted.
-        column_list: array of strings
-            Names of the columns from the data file to use in the analysis.
-        creation_modules: array of strings
-            Names (in the right order) of the modules to use to build the SED.
-        creation_modules_params: array of array of dictionaries
-            Array containing all the possible combinations of configurations
-            for the creation_modules. Each 'inner' array has the same length as
-            the creation_modules array and contains the configuration
-            dictionary for the corresponding module.
-        parameters: dictionary
-            Configuration for the module.
+        configuration: dictionary
+            Contents of pcigale.ini in the form of a dictionary

         Returns
         -------
@@ -95,6 +71,7 @@ class AnalysisModule(object):
         KeyError: when not all the needed parameters are given.
         """
+        parameters = configuration['analysis_method_params']
         # For parameters that are present on the parameter_list with a default
         # value and that are not in the parameters dictionary, we add them
         # with their default value.
@@ -124,8 +101,7 @@ class AnalysisModule(object):
                            "expected one." + message)

         # We do the actual processing
-        self._process(data_file, column_list, creation_modules,
-                      creation_modules_params, parameters)
+        self._process(configuration)

 def get_module(module_name):
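
With the new base class, a concrete analysis module only implements _process(self, configuration) and fetches its own inputs. A minimal sketch of a hypothetical subclass (the class name and body are illustrative, not part of this commit; the dictionary keys are the ones this commit uses):

class MinimalAnalysis(AnalysisModule):
    """Hypothetical module illustrating the single-argument interface."""

    def _process(self, configuration):
        # Pick only the entries this module needs; other modules are free
        # to read a different subset of the configuration dictionary.
        data_file = configuration['data_file']
        parameters = configuration['analysis_method_params']
        cores = configuration['cores']
        # The actual analysis would go here.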
@@ -90,8 +90,7 @@ class PdfAnalysis(AnalysisModule):
         ))
     ])

-    def process(self, data_file, column_list, creation_modules,
-                creation_modules_params, config, cores):
+    def process(self, conf):
         """Process with the psum analysis.

         The analysis is done in two steps which can both run on multiple
@@ -102,19 +101,8 @@ class PdfAnalysis(AnalysisModule):
         Parameters
         ----------
-        data_file: string
-            Name of the file containing the observations to fit.
-        column_list: list of strings
-            Name of the columns from the data file to use for the analysis.
-        creation_modules: list of strings
-            List of the module names (in the right order) to use for creating
-            the SEDs.
-        creation_modules_params: list of dictionaries
-            List of the parameter dictionaries for each module.
-        config: dictionary
-            Dictionary containing the configuration.
-        core: integer
-            Number of cores to run the analysis on
+        conf: dictionary
+            Contents of pcigale.ini in the form of a dictionary

         """
         np.seterr(invalid='ignore')
@@ -125,23 +113,27 @@ class PdfAnalysis(AnalysisModule):
         backup_dir()

         # Initialise variables from input arguments.
-        analysed_variables = config["analysed_variables"]
+        creation_modules = conf['creation_modules']
+        creation_modules_params = conf['creation_modules_params']
+        analysed_variables = conf['analysis_method_params']["analysed_variables"]
         analysed_variables_nolog = [variable[:-4] if variable.endswith('_log')
                                     else variable for variable in
                                     analysed_variables]
         n_variables = len(analysed_variables)
-        save = {key: config["save_{}".format(key)].lower() == "true"
+        save = {key: conf['analysis_method_params']["save_{}".format(key)].lower() == "true"
                 for key in ["best_sed", "chi2", "pdf"]}
-        lim_flag = config["lim_flag"].lower() == "true"
-        mock_flag = config["mock_flag"].lower() == "true"
+        lim_flag = conf['analysis_method_params']["lim_flag"].lower() == "true"
+        mock_flag = conf['analysis_method_params']["mock_flag"].lower() == "true"

-        filters = [name for name in column_list if not name.endswith('_err')]
+        filters = [name for name in conf['column_list'] if not
+                   name.endswith('_err')]
         n_filters = len(filters)

         # Read the observation table and complete it by adding error where
         # none is provided and by adding the systematic deviation.
-        obs_table = complete_obs_table(read_table(data_file), column_list,
-                                       filters, TOLERANCE, lim_flag)
+        obs_table = complete_obs_table(read_table(conf['data_file']),
+                                       conf['column_list'], filters, TOLERANCE,
+                                       lim_flag)
         n_obs = len(obs_table)

         w_redshifting = creation_modules.index('redshifting')
@@ -184,12 +176,12 @@ class PdfAnalysis(AnalysisModule):
         initargs = (params, filters, analysed_variables_nolog, model_fluxes,
                     model_variables, time.time(), mp.Value('i', 0))
-        if cores == 1:  # Do not create a new process
+        if conf['cores'] == 1:  # Do not create a new process
             init_worker_sed(*initargs)
             for idx in range(n_params):
                 worker_sed(idx)
         else:  # Compute the models in parallel
-            with mp.Pool(processes=cores, initializer=init_worker_sed,
+            with mp.Pool(processes=conf['cores'], initializer=init_worker_sed,
                          initargs=initargs) as pool:
                 pool.map(worker_sed, range(n_params))
@@ -212,12 +204,13 @@ class PdfAnalysis(AnalysisModule):
                     analysed_averages, analysed_std, best_fluxes,
                     best_parameters, best_chi2, best_chi2_red, save, lim_flag,
                     n_obs)
-        if cores == 1:  # Do not create a new process
+        if conf['cores'] == 1:  # Do not create a new process
             init_worker_analysis(*initargs)
             for idx, obs in enumerate(obs_table):
                 worker_analysis(idx, obs)
         else:  # Analyse observations in parallel
-            with mp.Pool(processes=cores, initializer=init_worker_analysis,
+            with mp.Pool(processes=conf['cores'],
+                         initializer=init_worker_analysis,
                          initargs=initargs) as pool:
                 pool.starmap(worker_analysis, enumerate(obs_table))
@@ -67,8 +67,7 @@ class SaveFluxes(AnalysisModule):
         ))
     ])

-    def process(self, data_file, column_list, creation_modules,
-                creation_modules_params, parameters, cores):
+    def process(self, conf):
         """Process with the savedfluxes analysis.

         All the possible theoretical SED are created and the fluxes in the
@@ -77,34 +76,25 @@ class SaveFluxes(AnalysisModule):
         Parameters
         ----------
-        data_file: string
-            Name of the file containing the observations to fit.
-        column_list: list of strings
-            Name of the columns from the data file to use for the analysis.
-        creation_modules: list of strings
-            List of the module names (in the right order) to use for creating
-            the SEDs.
-        creation_modules_params: list of dictionaries
-            List of the parameter dictionaries for each module.
-        parameters: dictionary
-            Dictionary containing the parameters.
-        cores: integer
-            Number of cores to run the analysis on
+        conf: dictionary
+            Contents of pcigale.ini in the form of a dictionary

         """
         # Rename the output directory if it exists
         backup_dir()

-        out_file = parameters["output_file"]
-        out_format = parameters["output_format"]
-        save_sed = parameters["save_sed"].lower() == "true"
+        creation_modules = conf['creation_modules']
+        creation_modules_params = conf['creation_modules_params']
+        out_file = conf['analysis_method_params']["output_file"]
+        out_format = conf['analysis_method_params']["output_format"]
+        save_sed = conf['analysis_method_params']["save_sed"].lower() == "true"

-        filters = [name for name in column_list if not name.endswith('_err')]
+        filters = [name for name in conf['column_list'] if not
+                   name.endswith('_err')]
         n_filters = len(filters)

         w_redshifting = creation_modules.index('redshifting')
         if list(creation_modules_params[w_redshifting]['redshift']) == ['']:
-            obs_table = read_table(data_file)
+            obs_table = read_table(conf['data_file'])
             z = np.unique(np.around(obs_table['redshift'],
                                     decimals=REDSHIFT_DECIMALS))
             creation_modules_params[w_redshifting]['redshift'] = z
@@ -118,14 +108,14 @@ class SaveFluxes(AnalysisModule):
         params = ParametersHandler(creation_modules, creation_modules_params)
         n_params = params.size

-        if parameters["variables"] == '':
+        if conf['analysis_method_params']["variables"] == '':
             # Retrieve an arbitrary SED to obtain the list of output parameters
             warehouse = SedWarehouse()
             sed = warehouse.get_sed(creation_modules, params.from_index(0))
             info = list(sed.info.keys())
             del warehouse, sed
         else:
-            info = parameters["variables"]
+            info = conf['analysis_method_params']["variables"]
             n_info = len(info)
         info.sort()
         n_info = len(info)
@@ -137,12 +127,13 @@ class SaveFluxes(AnalysisModule):
         initargs = (params, filters, save_sed, info, model_fluxes,
                     model_parameters, time.time(), mp.Value('i', 0))
-        if cores == 1:  # Do not create a new process
+        if conf['cores'] == 1:  # Do not create a new process
             init_worker_fluxes(*initargs)
             for idx in range(n_params):
                 worker_fluxes(idx)
         else:  # Analyse observations in parallel
-            with mp.Pool(processes=cores, initializer=init_worker_fluxes,
+            with mp.Pool(processes=conf['cores'],
+                         initializer=init_worker_fluxes,
                          initargs=initargs) as pool:
                 pool.map(worker_fluxes, range(n_params))
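
One possible follow-up, not part of this commit: both PdfAnalysis and SaveFluxes now repeat the conf['analysis_method_params'] lookup on many lines. Binding that sub-dictionary to a local variable once would shorten those lines without changing behaviour, for example (hypothetical tidy-up):

# Hypothetical refactoring: bind the analysis parameters once.
parameters = conf['analysis_method_params']
out_file = parameters["output_file"]
out_format = parameters["output_format"]
save_sed = parameters["save_sed"].lower() == "true"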