Commit 46c41966 authored by Yannick Roehlly's avatar Yannick Roehlly
Browse files

Add best model plotting/saving to pdf_analysis

There is maybe a bug triggering the display of the best model figure in
a matplotlib interactive window.
parent fb159928
......@@ -34,7 +34,7 @@ from astropy.table import Table, Column
from ...utils import read_table
from .. import AnalysisModule, complete_obs_table
from ...creation_modules import get_module as get_creation_module
from .utils import gen_compute_fluxes_at_redshift, gen_pdf
from .utils import gen_compute_fluxes_at_redshift, gen_pdf, gen_best_sed_fig
from ...warehouse import SedWarehouse
from ...data import Database
......@@ -222,9 +222,10 @@ class PdfAnalysis(AnalysisModule):
with SedWarehouse(cache_type=parameters["storage_type"]) as \
sed_warehouse:
for model_index, parameters in enumerate(creation_modules_params):
for model_index, model_params in enumerate(
creation_modules_params):
sed = sed_warehouse.get_sed(creation_modules, parameters)
sed = sed_warehouse.get_sed(creation_modules, model_params)
# Cached function to compute the SED fluxes at a redshift
gen_fluxes = gen_compute_fluxes_at_redshift(
......@@ -339,6 +340,8 @@ class PdfAnalysis(AnalysisModule):
# Variable analysis #
##################################################################
print("Analysing the variables...")
# We compute the weighted average and standard deviation using the
# likelihood as weight. We first build the weight array by expanding
# the likelihood along a new axis corresponding to the analysed
......@@ -381,6 +384,8 @@ class PdfAnalysis(AnalysisModule):
# Best models #
##################################################################
print("Analysing the best models...")
# We define the best fitting model for each observation as the one
# with the least χ².
best_model_index = list(chi_squares.argmin(axis=0))
......@@ -418,6 +423,72 @@ class PdfAnalysis(AnalysisModule):
best_model_table.write(OUT_DIR + BEST_MODEL_FILE)
if plot_best_sed or save_best_sed:
print("Plotting/saving the best models...")
with SedWarehouse(cache_type=parameters["storage_type"]) as \
sed_warehouse:
for obs_index, obs_name in enumerate(obs_table["id"]):
obs_redshift = obs_table["redshift"][obs_index]
best_index = best_model_index[obs_index]
sed = sed_warehouse.get_sed(
creation_modules,
creation_modules_params[best_index]
)
if use_observation_redshift:
redshifting_module.parameters["redshift"] = \
obs_redshift
redshifting_module.process(sed)
igm_module.process(sed)
best_lambda = sed.wavelength_grid
best_fnu = sed.fnu * normalisation_factors[best_index,
obs_index]
if save_best_sed:
table = Table((
Column(best_lambda,
name="Wavelength",
unit="nm"),
Column(best_fnu,
name="Fnu density",
unit="mJy")
))
table.write(OUT_DIR + "{}_best_model.fits".format(
obs_name))
if plot_best_sed:
plot_mask = (
(best_lambda >= PLOT_L_MIN * (1 + obs_redshift)) &
(best_lambda <= PLOT_L_MAX * (1 + obs_redshift))
)
figure = gen_best_sed_fig(
best_lambda[plot_mask],
best_fnu[plot_mask],
[f.effective_wavelength for f in filters.values()],
norm_model_fluxes[best_index, obs_index, :],
[obs_table[f][obs_index] for f in filters]
)
if figure is None:
print("Can not plot best model for observation "
"{}!".format(obs_name))
else:
figure.suptitle(
u"Best model for {} - red-chi² = {}".format(
obs_name,
reduced_chi_squares[best_index, obs_index]
)
)
figure.savefig(OUT_DIR + "{}_best_model.pdf".format(
obs_name))
plt.close(figure)
##################################################################
# Probability Density Functions #
##################################################################
......
......@@ -7,6 +7,7 @@
import numpy as np
from scipy.stats import gaussian_kde
from scipy.linalg import LinAlgError
from matplotlib import pyplot as plt
from copy import deepcopy
from ...sed.cosmology import cosmology
from ...creation_modules import get_module as get_creation_module
......@@ -152,3 +153,43 @@ def gen_pdf(values, probabilities, grid):
return result
def gen_best_sed_fig(wave, fnu, filters_wave, filters_model, filters_obs):
"""Generate a figure for plotting the best models
Parameters
----------
wave : array-like of floats
The wavelength grid of the model spectrum.
fnu : array-like of floats
The Fnu spectrum of the model at each wavelength.
filters_wave : array-like of floats
The effective wavelengths of the various filters.
filters_model : array-like of floats
The model fluxes in each filter.
filters_obs : array-like of floats
The observed fluxes in each filter.
Returns
-------
A matplotlib.plt.figure.
"""
try:
figure = plt.figure()
ax = figure.add_subplot(111)
ax.loglog(wave, fnu, "-b", label="Model spectrum")
ax.loglog(filters_wave, filters_model, "ob", label="Model fluxes")
ax.loglog(filters_wave, filters_obs, "or", label="Observation fluxes")
ax.set_xlabel("Wavelength [nm]")
ax.set_ylabel("Flux [mJy]")
ax.legend(loc=0)
return figure
except ValueError:
# If the SED can't be plot in x and y logarithm scaled, that means
# that we have either negative wavelength or flux and that something
# has gone wrong.
return None
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment