Commit 3ea2300b authored by Yannick Roehlly's avatar Yannick Roehlly

Implement memory caching in the SED warehouse

parent bd120196
# -*- coding: utf-8 -*-
"""
Copyright (C) 2012 Centre de données Astrophysiques de Marseille
Copyright (C) 2012, 2013 Centre de données Astrophysiques de Marseille
Licensed under the CeCILL-v2 licence - see Licence_CeCILL_V2-en.txt
@author: Yannick Roehlly <yannick.roehlly@oamp.fr>
"""
from copy import deepcopy
from json import JSONEncoder
from . import SED
from .modules import common as sed_modules
def create_sed(module_list, parameter_list):
"""Create a new SED using the given modules and parameters
class SedWarehouse(object):
"""Create, cache and store SED
Parameters
----------
module_list : list
List of module names in the order they have to be used to create
the SED.
parameter_list : list of dictionaries
List of the parameter dictionaries corresponding to each module of
the module_list list.
This object is responsible for creating SED and storing them in a memory
cache or a database.
"""
Returns
-------
sed : pcigale.sed
The SED made from the given modules with the given parameters.
def __init__(self, cache_type="memory"):
"""Instantiate a SED warehouse
"""
Parameters
----------
cache_type : string
Type of cache used. For now, only in memory caching.
"""
if cache_type == "memory":
dict_cache = {}
def is_cached(key):
"""Return true is the key is in the cache.
Parameters
----------
key : any immutable
Returns
-------
boolean
"""
return (key in dict_cache)
def get_from_cache(key):
"""Return the value corresponding to the key in the cache.
If the key is not in the cache, returns None.
Parameters
----------
key : any immutable
Returns
-------
object
"""
# We return a copy not to modify the stored object.
return deepcopy(dict_cache.get(key))
def add_to_cache(key, value):
"""Add a new key, value pair to the cache.
Parameters
----------
key : any immutable
value : object
"""
# We store a copy not to modify the stored object.
dict_cache[key] = deepcopy(value)
self.is_cached = is_cached
self.get_from_cache = get_from_cache
self.add_to_cache = add_to_cache
def get_sed(self, module_list, parameter_list):
"""Get the SED corresponding to the module and parameter lists
If the SED was cached, get it from the cache. If it is not, create it
and add it the the cache. The method is recursive to permit caching
partial SED.
Parameters
----------
module_list : iterable
List of module names in the order they have to be used to
create the SED.
parameter_list : iterable
List of the parameter dictionaries corresponding to each
module of the module_list list.
Returns
-------
sed : pcigale.sed
The SED made from the given modules with the given parameters.
"""
module_list = list(module_list)
parameter_list = list(parameter_list)
# JSon representation of the tuple (module_list, parameter_list)
# used as a key for storing the SED in the cache.
encoder = JSONEncoder()
sed_key = encoder.encode((module_list, parameter_list))
sed = self.get_from_cache(sed_key)
if not sed:
mod = sed_modules.get_module(module_list.pop())
mod.parameters = parameter_list.pop()
# We start from an empty SED.
sed = SED()
if (len(module_list) == 0):
sed = SED()
else:
sed = self.get_sed(module_list, parameter_list)
for (module, parameters) in zip(module_list, parameter_list):
mod = sed_modules.get_module(module)
mod.parameters = parameters
mod.process(sed)
mod.process(sed)
self.add_to_cache(sed_key, sed)
return sed
return sed
......@@ -26,7 +26,7 @@ from scipy import stats
from progressbar import ProgressBar
from matplotlib import pyplot as plt
from . import common
from ..sed.warehouse import create_sed
from ..sed.warehouse import SedWarehouse
from ..sed.modules.common import get_module
from ..data import Database
......@@ -128,6 +128,9 @@ class Module(common.AnalysisModule):
"it yet exists.".format(OUT_DIR))
sys.exit()
# Open the warehouse
sed_warehouse = SedWarehouse()
# Get the parameters
analysed_variables = parameters["analysed_variables"]
save_best_sed = parameters["save_best_sed"]
......@@ -191,7 +194,7 @@ class Module(common.AnalysisModule):
# We loop over all the possible theoretical SEDs
progress_bar = ProgressBar(maxval=len(sed_modules_params)).start()
for model_index, parameters in enumerate(sed_modules_params):
sed = create_sed(sed_modules, parameters)
sed = sed_warehouse.get_sed(sed_modules, parameters)
# Compute the reduced Chi-square, the galaxy mass (normalisation
# factor) and probability for each observed SEDs. Add these and
......@@ -254,7 +257,7 @@ class Module(common.AnalysisModule):
best_chi2 = comp_table[best_index, obs_index, 0]
best_norm_factor = comp_table[best_index, obs_index, 2]
best_params = sed_modules_params[best_index]
best_sed = create_sed(sed_modules, best_params)
best_sed = sed_warehouse.get_sed(sed_modules, best_params)
# Save best SED
# TODO: For now, we only save the lambda vs fnu table. Once
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment