Commit f3f9b0df authored by Médéric Boquien's avatar Médéric Boquien
Browse files

Allow savefluxes to use multiple cores.

parent 077253d2
...@@ -15,17 +15,55 @@ The data file is used only to get the list of fluxes to be computed. ...@@ -15,17 +15,55 @@ The data file is used only to get the list of fluxes to be computed.
""" """
import multiprocessing as mp
import os import os
from itertools import product from itertools import product, repeat
from collections import OrderedDict from collections import OrderedDict
from datetime import datetime from datetime import datetime
from astropy.table import Table from astropy.table import Table
from progressbar import ProgressBar
from . import AnalysisModule from . import AnalysisModule
from ..warehouse import SedWarehouse from ..warehouse import SedWarehouse
from ..data import Database from ..data import Database
def _worker_sed(warehouse, filters, modules, parameters):
"""Internal function to parallelize the computation of fluxes.
Parameters
----------
warehouse: SedWarehouse object
SedWarehosue instance that is used to generate models. This has to be
passed in argument to benefit from the cache in a multiprocessing
context. Because processes are forked, there is no issue with cache
consistency.
filters: list
List of filters
modules: list
List of modules
parameters: list of dictionaries
List of parameters for each module
"""
sed = warehouse.get_sed(modules, parameters)
row = []
# Add the parameter values to the row. Some parameters are array
# so we must join their content.
for module_param in parameters:
for value in module_param.values():
if type(value) == list:
value = ".".join(value)
row.append(value)
# Add the flux in each filter to the row
row += [sed.compute_fnu(filter_.trans_table,
filter_.effective_wavelength)
for filter_ in filters]
return row
class SaveFluxes(AnalysisModule): class SaveFluxes(AnalysisModule):
"""Save fluxes analysis module """Save fluxes analysis module
...@@ -54,7 +92,7 @@ class SaveFluxes(AnalysisModule): ...@@ -54,7 +92,7 @@ class SaveFluxes(AnalysisModule):
]) ])
def process(self, data_file, column_list, creation_modules, def process(self, data_file, column_list, creation_modules,
creation_modules_params, parameters): creation_modules_params, parameters, cores):
"""Process with the savedfluxes analysis. """Process with the savedfluxes analysis.
All the possible theoretical SED are created and the fluxes in the All the possible theoretical SED are created and the fluxes in the
...@@ -74,6 +112,8 @@ class SaveFluxes(AnalysisModule): ...@@ -74,6 +112,8 @@ class SaveFluxes(AnalysisModule):
List of the parameter dictionaries for each module. List of the parameter dictionaries for each module.
parameters: dictionary parameters: dictionary
Dictionary containing the parameters. Dictionary containing the parameters.
cores: integer
Number of cores to run the analysis on
""" """
...@@ -95,40 +135,26 @@ class SaveFluxes(AnalysisModule): ...@@ -95,40 +135,26 @@ class SaveFluxes(AnalysisModule):
with Database() as base: with Database() as base:
filter_list = [base.get_filter(name) for name in filter_names] filter_list = [base.get_filter(name) for name in filter_names]
# Content of the output table. # Columns of the output table
# In the output table, we put the content of the sed.info dictionary out_columns = []
# plus the flux in all the filters. As all the SEDs are made with the for module_param_list in zip(creation_modules,
# same pipeline, they should have the same sed.info dictionary keys. creation_modules_params[0]):
output = [] for module_param in product([module_param_list[0]],
module_param_list[1].keys()):
# Open the warehouse out_columns.append(".".join(module_param))
sed_warehouse = SedWarehouse( out_columns += filter_names
cache_type=parameters["storage_type"])
# Parallel computation of the fluxes
# We loop over all the possible theoretical SEDs with SedWarehouse(cache_type=parameters["storage_type"]) as warehouse,\
progress_bar = ProgressBar(maxval=len(creation_modules_params)).start() mp.Pool(processes=cores) as pool:
for model_index, parameters in enumerate(creation_modules_params): out_rows = pool.starmap(_worker_sed,
sed = sed_warehouse.get_sed(creation_modules, parameters) zip(repeat(warehouse),
repeat(filter_list),
# Take the content of the sed info dictionary. repeat(creation_modules),
row = list(sed.info.values()) creation_modules_params))
# Add the flux in each filter to the row
row += [sed.compute_fnu(filter_.trans_table,
filter_.effective_wavelength)
for filter_ in filter_list]
output.append(row)
progress_bar.update(model_index + 1)
progress_bar.finish()
# We take the names of the columns from the last computed SED.
out_columns = list(sed.info.keys()) + filter_names
# The zip call is to convert the list of rows to a list of columns. # The zip call is to convert the list of rows to a list of columns.
out_table = Table(list(zip(*output)), names=out_columns) out_table = Table(list(zip(*out_rows)), names=out_columns)
out_table.write(out_file, format=out_format) out_table.write(out_file, format=out_format)
# AnalysisModule to be returned by get_module # AnalysisModule to be returned by get_module
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment