# -*- coding: utf-8 -*-
# Copyright (C) 2013 Centre de données Astrophysiques de Marseille
# Copyright (C) 2013-2014 Yannick Roehlly
# Copyright (C) 2013 Institute of Astronomy
# Copyright (C) 2014 Laboratoire d'Astrophysique de Marseille
# Licensed under the CeCILL-v2 licence - see Licence_CeCILL_V2-en.txt
# Author: Yannick Roehlly, Médéric Boquien & Denis Burgarella

from itertools import repeat
from collections import OrderedDict

from astropy.table import Table
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
from os import path
import pkg_resources
from scipy.constants import c
from pcigale.data import Database
from pcigale.utils import read_table
import matplotlib.gridspec as gridspec
from pcigale.analysis_modules.utils import Counter, nothread

# Name of the file containing the best models information
BEST_RESULTS = "results.fits"
# Name of the results file presumably written by the mock analysis; it is
# not referenced in this part of the module.
MOCK_RESULTS = "results_mock.fits"

# Names of the SED components that can be requested through the `series`
# argument; each one is checked by name in the per-object plotting code.
AVAILABLE_SERIES = [
    'stellar_attenuated',
    'stellar_unattenuated',
    'nebular',
    'dust',
    'agn',
    'radio',
    'model'
]


def pool_initializer(counter):
    """Prepare a worker process of the plotting pool.

    Parameters
    ----------
    counter: Counter
        Shared counter of the objects handled so far; every worker
        increments it once per object it processes.
    """
    # The worker function reaches the counter through this module global.
    global gbl_counter

    # If MKL is in use, restrict it to a single thread so that the pool of
    # processes does not oversubscribe the CPU/RAM.
    nothread()

    gbl_counter = counter


def sed(config, sed_type, nologo, xrange, yrange, series, format, outdir):
    """Plot the best SED of every object with its observed and model fluxes.

    One figure per object is produced by a pool of worker processes.

    Parameters
    ----------
    config: configuration object
        Its ``configuration`` mapping provides ``data_file``, ``bands`` and
        ``cores``.
    sed_type: string
        Either "mJy" (flux in mJy, observed frame) or "lum" (luminosity in W,
        rest frame).
    nologo: boolean
        When True, the CIGALE logo is not drawn on the figures.
    xrange, yrange: tuple(float|boolean, float|boolean)
        Axis limits; a bound set to False is computed automatically.
    series: list
        Names of the SED components to draw (see AVAILABLE_SERIES).
    format: string
        Output image format, e.g. png, pdf, ps, eps or svg.
    outdir: string
        Directory holding the fit results; the figures are written there.

    Raises
    ------
    ValueError
        If sed_type is neither "mJy" nor "lum".
    """
    # Fail fast: the per-object worker only handles these two plot types.
    if sed_type not in ('mJy', 'lum'):
        raise ValueError(f"Unknown plot type: {sed_type!r}. "
                         "Expected 'mJy' or 'lum'.")

    # The input data file is looked up in the parent directory of outdir.
    obs = read_table(path.join(path.dirname(outdir),
                               config.configuration['data_file']))
    mod = Table.read(path.join(outdir, BEST_RESULTS))

    # Keep only the photometric bands: skip error columns and line fluxes.
    with Database() as base:
        filters = OrderedDict([(name, base.get_filter(name))
                               for name in config.configuration['bands']
                               if not (name.endswith('_err') or
                                       name.startswith('line'))])

    if nologo is True:
        logo = False
    else:
        logo = plt.imread(pkg_resources.resource_filename(
            __name__, "../resources/CIGALE.png"))

    counter = Counter(len(obs))
    # Observed rows are paired with best-model rows by position; this
    # assumes both tables list the objects in the same order.
    with mp.Pool(processes=config.configuration['cores'],
                 initializer=pool_initializer, initargs=(counter,)) as pool:
        pool.starmap(_sed_worker, zip(obs, mod, repeat(filters),
                                      repeat(sed_type), repeat(logo),
                                      repeat(xrange), repeat(yrange),
                                      repeat(series), repeat(format),
                                      repeat(outdir)))
        # Let the workers exit cleanly before the context manager
        # terminates the pool.
        pool.close()
        pool.join()


def _sed_worker(obs, mod, filters, sed_type, logo, xrange, yrange, series,
                format, outdir):
    """Plot the best SED of one object with the associated band fluxes.

    Parameters
    ----------
    obs: Table row
        Data from the input file regarding one object.
    mod: Table row
        Data from the best model of one object.
    filters: ordered dictionary of Filter objects
        Filters of the photometric bands, keyed by band name. Their pivot
        wavelengths give the abscissa of the band fluxes.
    sed_type: string
        Type of SED to plot. It can either be "mJy" (flux in mJy and observed
        frame) or "lum" (luminosity in W and rest frame).
    logo: numpy.array | boolean
        The logo image data as read by plt.imread. Has shape
        (M, N) for grayscale images,
        (M, N, 3) for RGB images,
        (M, N, 4) for RGBA images.
        Do not add the logo when set to False.
    xrange: tuple(float|boolean, float|boolean)
        Limits of the wavelength axis; a bound set to False is derived from
        the filter pivot wavelengths. In "lum" mode both bounds, including
        user-given ones, are divided by (1 + z).
    yrange: tuple(float|boolean, float|boolean)
        Limits of the flux axis; a bound set to False is derived from the
        band fluxes.
    series: list
        Names of the SED components to draw (see AVAILABLE_SERIES).
    format: string
        One of png, pdf, ps, eps or svg.
    outdir: string
        The absolute path to outdir. The object's best-model FITS file is
        read from it and the figure is written to it.

    Raises
    ------
    ValueError
        If the best-model file exists and sed_type is neither "mJy" nor
        "lum".
    """
    gbl_counter.inc()

    id_best_model_file = path.join(outdir, f"{obs['id']}_best_model.fits")
    if not path.isfile(id_best_model_file):
        print(f"No SED found for {obs['id']}. No plot created.")
        return

    sed = Table.read(id_best_model_file)

    # The 1e-3 factor brings the wavelengths to the unit of the μm labels.
    filters_wl = np.array([filt.pivot_wavelength
                           for filt in filters.values()]) * 1e-3
    wavelength_spec = sed['wavelength'] * 1e-3
    obs_fluxes = np.array([obs[filt] for filt in filters.keys()])
    obs_fluxes_err = np.array([obs[filt+'_err'] for filt in filters.keys()])
    mod_fluxes = np.array([mod["best."+filt] for filt in filters.keys()])
    if obs['redshift'] >= 0:
        z = obs['redshift']
    else:  # Redshift mode
        z = mod['best.universe.redshift']
    zp1 = 1. + z
    surf = 4. * np.pi * mod['best.universe.luminosity_distance'] ** 2

    xmin = 0.9 * np.min(filters_wl) if xrange[0] is False else xrange[0]
    xmax = 1.1 * np.max(filters_wl) if xrange[1] is False else xrange[1]

    if sed_type == 'lum':
        # One conversion factor per band, applied to the band fluxes.
        k_corr_SED = 1e-29 * surf * c / (filters_wl * 1e-6)
        obs_fluxes *= k_corr_SED
        obs_fluxes_err *= k_corr_SED
        mod_fluxes *= k_corr_SED

        for cname in sed.colnames[1:]:
            sed[cname] *= wavelength_spec * 1e3

        # Move to the rest frame, axis limits included.
        filters_wl /= zp1
        wavelength_spec /= zp1
        xmin /= zp1
        xmax /= zp1
    elif sed_type == 'mJy':
        # Band fluxes are left unchanged; k_corr_SED is kept per band so the
        # error thresholds below are built the same way in both modes.
        k_corr_SED = np.ones_like(filters_wl)

        fact = 1e29 * 1e-3 * wavelength_spec**2 / c / surf
        for cname in sed.colnames[1:]:
            sed[cname] *= fact
    else:
        # Previously this only printed a message and then crashed later on
        # the undefined k_corr_SED.
        raise ValueError(f"Unknown plot type: {sed_type}")

    wsed = np.where((wavelength_spec > xmin) & (wavelength_spec < xmax))
    # Plot only if the second column of the best-model file has positive
    # values in the displayed range (presumably the total spectrum; this
    # relies on the column order of the file).
    if not (sed.columns[1][wsed] > 0.).any():
        print(f"No valid best SED found for {obs['id']}. No plot created.")
        return

    clip = _log_clip_kwargs()
    figure = plt.figure()
    try:
        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
        ax1 = plt.subplot(gs[0])
        ax2 = plt.subplot(gs[1])

        _plot_sed_components(ax1, wavelength_spec, sed, wsed, series, clip)

        ax1.set_autoscale_on(False)
        # Sort the bands by wavelength. k_corr_SED is per band, so it must
        # follow the same order as the fluxes it is compared with below.
        s = np.argsort(filters_wl)
        filters_wl = filters_wl[s]
        mod_fluxes = mod_fluxes[s]
        obs_fluxes = obs_fluxes[s]
        obs_fluxes_err = obs_fluxes_err[s]
        k_corr_SED = k_corr_SED[s]

        ax1.scatter(filters_wl, mod_fluxes, marker='o', color='r', s=8,
                    zorder=3, label="Model fluxes")
        mask_ok = np.logical_and(obs_fluxes > 0., obs_fluxes_err > 0.)
        ax1.errorbar(filters_wl[mask_ok], obs_fluxes[mask_ok],
                     yerr=obs_fluxes_err[mask_ok], ls='', marker='s',
                     label='Observed fluxes', markerfacecolor='None',
                     markersize=6, markeredgecolor='b', capsize=0.)

        # Negative errors mark upper limits; errors below -9990 (in the
        # input units, hence the per-band scaling) mark fluxes without error.
        noerr_threshold = -9990. * k_corr_SED
        mask_uplim = np.logical_and(
            np.logical_and(obs_fluxes > 0., obs_fluxes_err < 0.),
            obs_fluxes_err > noerr_threshold)
        if mask_uplim.any():
            # Recent matplotlib rejects negative errors; the absolute value
            # draws the same vertical segment.
            ax1.errorbar(filters_wl[mask_uplim], obs_fluxes[mask_uplim],
                         yerr=np.abs(obs_fluxes_err[mask_uplim]), ls='',
                         marker='v', label='Observed upper limits',
                         markerfacecolor='None', markersize=6,
                         markeredgecolor='g', capsize=0.)
        mask_noerr = np.logical_and(obs_fluxes > 0.,
                                    obs_fluxes_err < noerr_threshold)
        if mask_noerr.any():
            ax1.errorbar(filters_wl[mask_noerr], obs_fluxes[mask_noerr],
                         ls='', marker='s', markerfacecolor='None',
                         markersize=6, markeredgecolor='r',
                         label='Observed fluxes, no errors', capsize=0.)

        mask = np.where(obs_fluxes > 0.)
        ax2.errorbar(filters_wl[mask],
                     (obs_fluxes[mask]-mod_fluxes[mask])/obs_fluxes[mask],
                     yerr=np.abs(obs_fluxes_err[mask])/obs_fluxes[mask],
                     marker='_', label="(Obs-Mod)/Obs", color='k',
                     capsize=0., ls='None')
        ax2.plot([xmin, xmax], [0., 0.], ls='--', color='k')
        ax2.set_xscale('log')
        ax2.minorticks_on()

        ax1.tick_params(direction='in', axis='both', which='both', top=True,
                        left=True, right=True, bottom=False)
        ax2.tick_params(direction='in', axis='both', which='both',
                        right=True)

        figure.subplots_adjust(hspace=0., wspace=0.)

        ax1.set_xlim(xmin, xmax)
        ymin, ymax = _auto_yrange(yrange, obs_fluxes, mod_fluxes, mask_ok,
                                  mask_uplim)

        xmin = xmin if xmin < xmax else xmax - 1e1
        ymin = ymin if ymin < ymax else ymax - 1e1

        ax1.set_ylim(ymin, ymax)
        ax2.set_xlim(xmin, xmax)
        ax2.set_ylim(-1.0, 1.0)
        if sed_type == 'lum':
            ax2.set_xlabel(r"Rest-frame wavelength [$\mu$m]")
            ax1.set_ylabel("Luminosity [W]")
        else:
            ax2.set_xlabel(r"Observed $\lambda$ ($\mu$m)")
            ax1.set_ylabel(r"S$_\nu$ (mJy)")
        ax2.set_ylabel("Relative\nresidual")
        ax1.legend(fontsize=6, loc='best', frameon=False)
        ax2.legend(fontsize=6, loc='best', frameon=False)
        plt.setp(ax1.get_xticklabels(), visible=False)
        plt.setp(ax1.get_yticklabels()[1], visible=False)
        figure.suptitle(f"Best model for {obs['id']}\n (z={z:.3}, "
                        f"reduced χ²={mod['best.reduced_chi_square']:.2})")
        if logo is not False:
            # Multiplying the dpi by 2 is a hack so the figure is small
            # and not too pixelated
            figwidth = figure.get_figwidth() * figure.dpi * 2.
            # figimage offsets are in pixels: shift by the logo width (its
            # number of columns) to right-align it.
            figure.figimage(logo, figwidth - logo.shape[1], 0,
                            origin='upper', zorder=0, alpha=1)

        figure.savefig(path.join(outdir,
                                 f"{obs['id']}_best_model.{format}"),
                       dpi=figure.dpi * 2.)
    finally:
        # Always release the figure: workers handle many objects in a row.
        plt.close(figure)


def _log_clip_kwargs():
    """Return the loglog keyword that clips non-positive values.

    Matplotlib 3.3 renamed ``nonposy`` to ``nonpositive`` and later releases
    dropped the old name, so pick the spelling of the installed version.
    """
    parts = matplotlib.__version__.split('.')
    try:
        version = (int(parts[0]), int(parts[1]))
    except (IndexError, ValueError):
        # Unparsable version string: assume a recent matplotlib.
        return {'nonpositive': 'clip'}
    if version >= (3, 3):
        return {'nonpositive': 'clip'}
    return {'nonposy': 'clip'}


def _plot_sed_components(ax, wavelength, sed, wsed, series, clip):
    """Draw the SED components requested in ``series`` on ``ax`` (log-log).

    Only the samples selected by ``wsed`` are drawn. The nebular, dust, AGN
    and radio components are skipped when their columns are absent from
    ``sed``. ``clip`` holds the keyword making loglog clip non-positive
    values.
    """
    wl = wavelength[wsed]

    def total(*names):
        # Sum the named columns, in the given order, over the wsed samples.
        flux = sed[names[0]][wsed]
        for name in names[1:]:
            flux = flux + sed[name][wsed]
        return flux

    # Stellar emission
    if 'stellar_attenuated' in series:
        if 'nebular.absorption_young' in sed.columns:
            names = ('stellar.young', 'attenuation.stellar.young',
                     'nebular.absorption_young', 'stellar.old',
                     'attenuation.stellar.old', 'nebular.absorption_old')
        else:
            names = ('stellar.young', 'attenuation.stellar.young',
                     'stellar.old', 'attenuation.stellar.old')
        ax.loglog(wl, total(*names), label="Stellar attenuated ",
                  color='orange', marker=None, linestyle='-', linewidth=0.5,
                  **clip)

    if 'stellar_unattenuated' in series:
        ax.loglog(wl, total('stellar.old', 'stellar.young'),
                  label="Stellar unattenuated", color='b', marker=None,
                  linestyle='--', linewidth=0.5, **clip)

    # Nebular emission
    if 'nebular' in series and 'nebular.lines_young' in sed.columns:
        ax.loglog(wl, total('nebular.lines_young', 'nebular.lines_old',
                            'nebular.continuum_young',
                            'nebular.continuum_old',
                            'attenuation.nebular.lines_young',
                            'attenuation.nebular.lines_old',
                            'attenuation.nebular.continuum_young',
                            'attenuation.nebular.continuum_old'),
                  label="Nebular emission", color='y', marker=None,
                  linewidth=.5, **clip)

    # Dust emission Draine & Li
    if 'dust' in series and 'dust.Umin_Umin' in sed.columns:
        ax.loglog(wl, total('dust.Umin_Umin', 'dust.Umin_Umax'),
                  label="Dust emission", color='r', marker=None,
                  linestyle='-', linewidth=0.5, **clip)

    # Dust emission Dale
    if 'dust' in series and 'dust' in sed.columns:
        ax.loglog(wl, sed['dust'][wsed], label="Dust emission", color='r',
                  marker=None, linestyle='-', linewidth=0.5, **clip)

    # AGN emission Fritz
    if 'agn' in series and 'agn.fritz2006_therm' in sed.columns:
        ax.loglog(wl, total('agn.fritz2006_therm', 'agn.fritz2006_scatt',
                            'agn.fritz2006_agn'),
                  label="AGN emission", color='g', marker=None,
                  linestyle='-', linewidth=0.5, **clip)

    # Radio emission
    if 'radio' in series and 'radio_nonthermal' in sed.columns:
        ax.loglog(wl, sed['radio_nonthermal'][wsed],
                  label="Radio nonthermal", color='brown', marker=None,
                  linestyle='-', linewidth=0.5, **clip)

    if 'model' in series:
        ax.loglog(wl, sed['L_lambda_total'][wsed], label="Model spectrum",
                  color='k', linestyle='-', linewidth=1.5, **clip)


def _auto_yrange(yrange, obs_fluxes, mod_fluxes, mask_ok, mask_uplim):
    """Return the (ymin, ymax) limits of the SED panel.

    A bound given in ``yrange`` is used unchanged. A bound set to False is
    placed one decade below the smallest (or above the largest) observed or
    model flux of the reference bands. For ymin these are the bands with a
    positive flux and error; for ymax, those bands plus the upper limits.
    If no band qualifies, every band with a positive observed flux is used
    instead, and failing that the model fluxes alone, so that an empty array
    is never reduced.
    """
    def extent(mask, reduce):
        if not mask.any():
            mask = obs_fluxes > 0.
        if mask.any():
            return reduce(np.concatenate((obs_fluxes[mask],
                                          mod_fluxes[mask])))
        return reduce(mod_fluxes)

    if yrange[0] is not False:
        ymin = yrange[0]
    else:
        ymin = extent(mask_ok, np.min) * 1e-1

    if yrange[1] is not False:
        ymax = yrange[1]
    else:
        ymax = extent(np.logical_or(mask_ok, mask_uplim), np.max) * 1e1

    return ymin, ymax