Commit 4bfd4b4f authored by Médéric Boquien's avatar Médéric Boquien

Rewrite the savefluxes module to make use of the new infrastructure. Though...

Rewrite the savefluxes module to make use of the new infrastructure. Though this implementation is not terribly scalable in the case of some really large arrays, it should be enough for now anyway. Famous last words.
parent e79f941d
......@@ -14,55 +14,21 @@ parameters.
The data file is used only to get the list of fluxes to be computed.
"""
import multiprocessing as mp
import os
from itertools import product, repeat
from collections import OrderedDict
import ctypes
from datetime import datetime
from astropy.table import Table
from itertools import product, repeat
import os
import multiprocessing as mp
from multiprocessing.sharedctypes import RawArray
import time
from . import AnalysisModule
from ..warehouse import SedWarehouse
from ..data import Database
from .utils import find_changed_parameters
def _worker_sed(warehouse, filters, modules, parameters, changed):
"""Internal function to parallelize the computation of fluxes.
Parameters
----------
warehouse: SedWarehouse object
SedWarehosue instance that is used to generate models. This has to be
passed in argument to benefit from the cache in a multiprocessing
context. Because processes are forked, there is no issue with cache
consistency.
filters: list
List of filters
modules: list
List of modules
parameters: list of dictionaries
List of parameters for each module
"""
sed = warehouse.get_sed(modules, parameters)
warehouse.partial_clear_cache(changed)
row = []
# Add the parameter values to the row. Some parameters are array
# so we must join their content.
for module_param in parameters:
for value in module_param.values():
if type(value) == list:
value = ".".join(value)
row.append(value)
# Add the flux in each filter to the row
row += [sed.compute_fnu(filter_.trans_table,
filter_.effective_wavelength)
for filter_ in filters]
return row
from .utils import ParametersHandler, backup_dir, save_fluxes
from ..warehouse import SedWarehouse
from .workers import init_fluxes as init_worker_fluxes
from .workers import fluxes as worker_fluxes
class SaveFluxes(AnalysisModule):
......@@ -118,47 +84,58 @@ class SaveFluxes(AnalysisModule):
"""
# Rename the output directory if it exists
backup_dir()
out_file = parameters["output_file"]
out_format = parameters["output_format"]
# If the output file already exists make a copy.
if os.path.isfile(out_file):
new_name = datetime.now().strftime("%Y%m%d%H%M") + "_" + out_file
os.rename(out_file, new_name)
print("The existing {} file was renamed to {}".format(
out_file,
new_name
))
# Get the filters in the database
filter_names = [name for name in column_list
if not name.endswith('_err')]
# The parameters handler allows us to retrieve the models parameters
# from a 1D index. This is useful in that we do not have to create
# a list of parameters as they are computed on-the-fly. It also has
# nice goodies such as finding the index of the first parameter to
# have changed between two indices or the number of models.
params = ParametersHandler(creation_modules, creation_modules_params)
n_params = params.size
# Get the needed filters in the pcigale database. We use an ordered
# dictionary because we need the keys to always be returned in the
# same order. We also put the filters in the shared modules as they
# are needed to compute the fluxes during the models generation.
with Database() as base:
filter_list = [base.get_filter(name) for name in filter_names]
# Columns of the output table
out_columns = []
for module_param_list in zip(creation_modules,
creation_modules_params[0]):
for module_param in product([module_param_list[0]],
module_param_list[1].keys()):
out_columns.append(".".join(module_param))
out_columns += filter_names
# Parallel computation of the fluxes
with SedWarehouse(cache_type=parameters["storage_type"]) as warehouse,\
mp.Pool(processes=cores) as pool:
changed_pars = find_changed_parameters(creation_modules_params)
out_rows = pool.starmap(_worker_sed,
zip(repeat(warehouse),
repeat(filter_list),
repeat(creation_modules),
creation_modules_params,
changed_pars))
# The zip call is to convert the list of rows to a list of columns.
out_table = Table(list(zip(*out_rows)), names=out_columns)
out_table.write(out_file, format=out_format)
filters = OrderedDict([(name, base.get_filter(name))
for name in column_list
if not name.endswith('_err')])
n_filters = len(filters)
# Retrieve an arbitrary SED to obtain the list of output parameters
warehouse = SedWarehouse(cache_type=parameters["storage_type"])
sed = warehouse.get_sed(creation_modules, params.from_index(0))
info = sed.info
n_info = len(sed.info)
del warehouse, sed
model_fluxes = (RawArray(ctypes.c_double,
n_params * n_filters),
(n_params, n_filters))
model_parameters = (RawArray(ctypes.c_double,
n_params * n_info),
(n_params, n_info))
initargs = (params, filters, model_fluxes, model_parameters,
time.time(), mp.Value('i', 0))
if cores == 1: # Do not create a new process
init_worker_fluxes(*initargs)
for idx in range(n_params):
worker_fluxes(idx)
else: # Analyse observations in parallel
with mp.Pool(processes=cores, initializer=init_worker_fluxes,
initargs=initargs) as pool:
pool.map(worker_fluxes, range(n_params))
save_fluxes(model_fluxes, model_parameters, filters, info, out_file,
out_format=out_format)
# AnalysisModule subclass to be returned by get_module().
Module = SaveFluxes
......@@ -13,6 +13,7 @@ import itertools
import os
import numpy as np
from astropy.table import Table, Column
# Directory where the output files are stored
OUT_DIR = "out/"
......@@ -133,3 +134,37 @@ def backup_dir(directory=OUT_DIR):
new_name
))
os.mkdir(directory)
def save_fluxes(model_fluxes, model_parameters, filters, names, filename,
                directory=OUT_DIR, out_format='ascii'):
    """Save fluxes and associated parameters into a table.

    Parameters
    ----------
    model_fluxes: RawArray
        Contains the fluxes of each model, as a (shared flat array, shape)
        tuple; the shape is (n_models, n_filters).
    model_parameters: RawArray
        Contains the parameters associated to each model, as a (shared
        flat array, shape) tuple; the shape is (n_models, n_parameters).
    filters: OrderedDict
        Contains the filters; the keys become the flux column names.
    names: List
        Contains the parameters names.
    filename: str
        Name under which the file should be saved.
    directory: str
        Directory under which the file should be saved.
    out_format: str
        Format of the output file.

    """
    # Expose the shared RawArrays as 2D numpy arrays without copying.
    out_fluxes = np.ctypeslib.as_array(model_fluxes[0])
    out_fluxes = out_fluxes.reshape(model_fluxes[1])

    out_params = np.ctypeslib.as_array(model_parameters[0])
    out_params = out_params.reshape(model_parameters[1])

    # One row per model: the fluxes first, then the parameter values.
    out_table = Table(np.hstack((out_fluxes, out_params)),
                      names=list(filters.keys()) + list(names))

    # os.path.join avoids the double separator that "{}/{}".format()
    # produces when the directory (e.g. the default "out/") already ends
    # with a slash, and is portable across platforms.
    out_table.write(os.path.join(directory, filename), format=out_format)
# -*- coding: utf-8 -*-
# Copyright (C) 2013 Centre de données Astrophysiques de Marseille
# Copyright (C) 2013-2014 Institute of Astronomy
# Copyright (C) 2014 Yannick Roehlly <yannick@iaora.eu>
# Licensed under the CeCILL-v2 licence - see Licence_CeCILL_V2-en.txt
# Author: Yannick Roehlly & Médéric Boquien
import time
import numpy as np
from ..warehouse import SedWarehouse
def init_fluxes(params, filters, fluxes, info, t_begin, n_computed):
    """Initializer of the pool of processes. It is mostly used to convert
    RawArrays into numpy arrays. The latter are defined as global variables to
    be accessible from the workers.

    Parameters
    ----------
    params: ParametersHandler
        Handles the parameters from a 1D index.
    filters: OrderedDict
        Contains filters to compute the fluxes.
    fluxes: RawArray and tuple containing the shape
        Fluxes of individual models. Shared among workers.
    info: RawArray and tuple containing the shape
        Parameters associated with each model. Shared among workers.
    t_begin: float
        Time of the beginning of the computation.
    n_computed: Value
        Number of computed models. Shared among workers.

    """
    global gbl_model_fluxes, gbl_model_info, gbl_n_computed, gbl_t_begin
    global gbl_params, gbl_previous_idx, gbl_filters, gbl_warehouse

    # Wrap the shared flat buffers in numpy arrays (no copy) and give them
    # their 2D shape, carried as the second element of each tuple.
    gbl_model_fluxes = np.ctypeslib.as_array(fluxes[0])
    gbl_model_fluxes = gbl_model_fluxes.reshape(fluxes[1])

    gbl_model_info = np.ctypeslib.as_array(info[0])
    gbl_model_info = gbl_model_info.reshape(info[1])

    gbl_n_computed = n_computed
    gbl_t_begin = t_begin

    gbl_params = params

    # Index of the previously computed model; -1 means no model computed
    # yet. Used by the workers to clear only stale cache entries.
    gbl_previous_idx = -1

    gbl_filters = filters
    # Per-process warehouse; each forked worker gets its own memory cache.
    gbl_warehouse = SedWarehouse(cache_type="memory")
def fluxes(idx):
    """Worker process to retrieve a SED and affect the relevant data to shared
    RawArrays.

    Parameters
    ----------
    idx: int
        Index of the model to retrieve its parameters from the parameters
        handler.

    """
    global gbl_previous_idx
    # Drop the cache entries invalidated by the parameters that changed
    # between the previously computed model and this one.
    if gbl_previous_idx > -1:
        gbl_warehouse.partial_clear_cache(
            gbl_params.index_module_changed(gbl_previous_idx, idx))
    gbl_previous_idx = idx

    sed = gbl_warehouse.get_sed(gbl_params.modules,
                                gbl_params.from_index(idx))

    # Models older than the universe are flagged with -99 fluxes instead
    # of being computed.
    if 'age' in sed.info and sed.info['age'] > sed.info['universe.age']:
        model_fluxes = -99. * np.ones(len(gbl_filters))
    else:
        model_fluxes = np.array([sed.compute_fnu(filter_.trans_table,
                                                 filter_.effective_wavelength)
                                 for filter_ in gbl_filters.values()])

    # Store the results in the shared arrays at this model's row.
    gbl_model_fluxes[idx, :] = model_fluxes
    gbl_model_info[idx, :] = list(sed.info.values())

    # Progress report: increment the shared counter under its lock, then
    # print every 100 models and at the very end.
    with gbl_n_computed.get_lock():
        gbl_n_computed.value += 1
        n_computed = gbl_n_computed.value
    if n_computed % 100 == 0 or n_computed == gbl_params.size:
        t_elapsed = time.time() - gbl_t_begin
        print("{}/{} models computed in {} seconds ({} models/s)".
              format(n_computed, gbl_params.size,
                     np.around(t_elapsed, decimals=1),
                     np.around(n_computed/t_elapsed, decimals=1)),
              end="\r")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment