sed.py 15.2 KB
Newer Older
1
2
3
4
5
6
7
# -*- coding: utf-8 -*-
# Copyright (C) 2013 Centre de données Astrophysiques de Marseille
# Copyright (C) 2013-2014 Yannick Roehlly
# Copyright (C) 2013 Institute of Astronomy
# Copyright (C) 2014 Laboratoire d'Astrophysique de Marseille
# Licensed under the CeCILL-v2 licence - see Licence_CeCILL_V2-en.txt
# Author: Yannick Roehlly, Médéric Boquien & Denis Burgarella
8
9
10
11
12
13
14
15
16
17
18

from itertools import repeat
from collections import OrderedDict

from astropy.table import Table
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
19
from os import path
20
21
22
import pkg_resources
from scipy.constants import c
from pcigale.data import Database
23
from utils.io import read_table
24
import matplotlib.gridspec as gridspec
25
from utils.counter import Counter
26

27
28
29
30
# Files, located in the output directory, that hold the best-model results.
# MOCK_RESULTS presumably holds the results of the mock-catalogue analysis
# (it is not read in this part of the module).
BEST_RESULTS = "results.fits"
MOCK_RESULTS = "results_mock.fits"

# Spectral components that may be requested through the ``series`` option;
# each name is matched against ``series`` when drawing a best-fit SED.
AVAILABLE_SERIES = ['stellar_attenuated', 'stellar_unattenuated', 'nebular',
                    'dust', 'agn', 'radio', 'model']


def pool_initializer(counter):
    """Initializer of the pool of processes to share variables between workers.

    The counter is stored in the module-level global ``gbl_counter`` so that
    the plotting workers (``_sed_worker``) can update it from each process.

    Parameters
    ----------
    counter: Counter
        Counter class object for the number of models plotted.
    """
    global gbl_counter
    gbl_counter = counter


def sed(config, sed_type, nologo, xrange, yrange, series, format, outdir):
    """Plot the best SED with associated observed and modelled fluxes.

    One figure is produced per object of the input catalogue. The work is
    spread over ``config.configuration['cores']`` worker processes.

    Parameters
    ----------
    config: configuration object
        Object exposing a ``configuration`` mapping that provides at least
        ``data_file``, ``bands`` and ``cores``.
    sed_type: string
        "mJy" (observed frame) or "lum" (rest frame); see ``_sed_worker``.
    nologo: boolean
        When true, do not add the CIGALE logo to the figures.
    xrange: tuple(float|boolean, float|boolean)
        Wavelength limits; False lets a limit be computed automatically.
    yrange: tuple(float|boolean, float|boolean)
        Limits of the SED panel; False lets a limit be computed automatically.
    series: list
        Names of the spectral components to draw.
    format: string
        Output image format (file extension of the saved figures).
    outdir: string
        Directory holding the fit results; the figures are written there too.
    """
    # The input catalogue lives in the parent directory of outdir.
    obs = read_table(path.join(path.dirname(outdir),
                               config.configuration['data_file']))
    mod = Table.read(path.join(outdir, BEST_RESULTS))

    with Database() as base:
        # Only plot actual photometric bands: skip the uncertainty columns
        # and the bands whose name starts with 'line'.
        filters = OrderedDict([(name, base.get_filter(name))
                               for name in config.configuration['bands']
                               if not (name.endswith('_err')
                                       or name.startswith('line'))])

    if nologo:
        logo = False
    else:
        logo = plt.imread(pkg_resources.resource_filename(
            __name__, "../resources/CIGALE.png"))

    counter = Counter(len(obs))
    # NOTE(review): rows of ``obs`` and ``mod`` are paired positionally, which
    # assumes the results file keeps the input catalogue order -- confirm.
    with mp.Pool(processes=config.configuration['cores'],
                 initializer=pool_initializer, initargs=(counter,)) as pool:
        pool.starmap(_sed_worker, zip(obs, mod, repeat(filters),
                                      repeat(sed_type), repeat(logo),
                                      repeat(xrange), repeat(yrange),
                                      repeat(series), repeat(format),
                                      repeat(outdir)))
        # Let the workers exit cleanly before the context manager terminates
        # the pool.
        pool.close()
        pool.join()


def _sed_worker(obs, mod, filters, sed_type, logo, xrange, yrange, series,
                format, outdir):
    """Plot the best SED with the associated fluxes in bands

    Parameters
    ----------
    obs: Table row
        Data from the input file regarding one object.
    mod: Table row
        Data from the best model of one object.
    filters: ordered dictionary of Filter objects
        The filters in which the fluxes are plotted, keyed by band name.
    sed_type: string
        Type of SED to plot. It can either be "mJy" (flux in mJy and observed
        frame) or "lum" (luminosity in W and rest frame). Any other value is
        reported and no plot is created.
    logo: numpy.array | boolean
        The loaded logo image data. Has shape
        (M, N) for grayscale images.
        (M, N, 3) for RGB images.
        (M, N, 4) for RGBA images.
        Do not add the logo when set to False.
    xrange: tuple(float|boolean, float|boolean)
        Wavelength limits in μm; False for an automatic limit. In "lum" mode
        they are divided by 1+z like the wavelengths.
    yrange: tuple(float|boolean, float|boolean)
        Limits of the SED panel; False for an automatic limit.
    series: list
        Spectral components to draw (e.g. 'stellar_attenuated', 'model').
    format: string
        One of png, pdf, ps, eps or svg.
    outdir: string
        The absolute path to outdir

    """
    gbl_counter.inc()

    id_best_model_file = path.join(outdir, f"{obs['id']}_best_model.fits")
    if not path.isfile(id_best_model_file):
        print(f"No SED found for {obs['id']}. No plot created.")
        return
    sed = Table.read(id_best_model_file)

    # Wavelengths are converted to μm, the unit of the x axis.
    filters_wl = np.array([filt.pivot_wavelength
                           for filt in filters.values()]) * 1e-3
    wavelength_spec = sed['wavelength'] * 1e-3
    obs_fluxes = np.array([obs[filt] for filt in filters])
    obs_fluxes_err = np.array([obs[filt + '_err'] for filt in filters])
    mod_fluxes = np.array([mod["best." + filt] for filt in filters])

    # Sort the bands by wavelength up front so that every per-band array
    # derived below (including the k-correction used for the error flags)
    # stays aligned.
    order = np.argsort(filters_wl)
    filters_wl = filters_wl[order]
    obs_fluxes = obs_fluxes[order]
    obs_fluxes_err = obs_fluxes_err[order]
    mod_fluxes = mod_fluxes[order]

    if obs['redshift'] >= 0:
        z = obs['redshift']
    else:  # Redshift mode
        z = mod['best.universe.redshift']
    zp1 = 1. + z
    surf = 4. * np.pi * mod['best.universe.luminosity_distance'] ** 2

    xmin = 0.9 * np.min(filters_wl) if xrange[0] is False else xrange[0]
    xmax = 1.1 * np.max(filters_wl) if xrange[1] is False else xrange[1]

    if sed_type == 'lum':
        # Convert the band fluxes (mJy) to luminosities in W: 4πd² c/λ.
        k_corr_SED = 1e-29 * surf * c / (filters_wl * 1e-6)
        obs_fluxes *= k_corr_SED
        obs_fluxes_err *= k_corr_SED
        mod_fluxes *= k_corr_SED
        # Scale every spectral column by its wavelength (in the unit of
        # sed['wavelength']).
        for cname in sed.colnames[1:]:
            sed[cname] *= wavelength_spec * 1e3
        # Move to the rest frame.
        filters_wl /= zp1
        wavelength_spec /= zp1
        xmin /= zp1
        xmax /= zp1
    elif sed_type == 'mJy':
        k_corr_SED = 1.
        # Convert the spectral columns to observed flux densities in mJy.
        fact = 1e29 * 1e-3 * wavelength_spec ** 2 / c / surf
        for cname in sed.colnames[1:]:
            sed[cname] *= fact
    else:
        # Without a known plot type there is no k-correction to apply.
        print("Unknown plot type")
        return

    # Log axes need 0 < xmin < xmax: fall back to one decade below xmax.
    if not xmin < xmax:
        xmin = xmax / 10.

    wsed = np.where((wavelength_spec > xmin) & (wavelength_spec < xmax))
    if not (sed.columns[1][wsed] > 0.).any():
        print(f"No valid best SED found for {obs['id']}. No plot created.")
        return

    # The figure is only created once we know it will be drawn, and always
    # closed so that pyplot does not accumulate figures in the worker.
    figure = plt.figure()
    try:
        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
        ax1 = plt.subplot(gs[0])
        ax2 = plt.subplot(gs[1])

        # Log-log top panel, even when no spectral series is requested, so
        # that it matches the logarithmic x axis of the residual panel.
        _set_loglog(ax1)
        _plot_sed_components(ax1, wavelength_spec, sed, wsed, series)

        # Stop autoscaling: the axis limits are set explicitly below.
        ax1.set_autoscale_on(False)
        ax1.scatter(filters_wl, mod_fluxes, marker='o',
                    color='xkcd:strawberry', s=8, zorder=3,
                    label="Model fluxes")

        # Negative uncertainties flag upper limits; below -9990 (scaled to
        # the plotted unit) they flag fluxes without a usable uncertainty.
        no_err_limit = -9990. * k_corr_SED
        positive = obs_fluxes > 0.
        mask_ok = positive & (obs_fluxes_err > 0.)
        mask_uplim = (positive & (obs_fluxes_err < 0.)
                      & (obs_fluxes_err > no_err_limit))
        mask_noerr = positive & (obs_fluxes_err < no_err_limit)

        ax1.errorbar(filters_wl[mask_ok], obs_fluxes[mask_ok],
                     yerr=obs_fluxes_err[mask_ok], ls='', marker='o',
                     label='Observed fluxes', markerfacecolor='None',
                     markersize=5, markeredgecolor='xkcd:pastel purple',
                     color='xkcd:light indigo', capsize=0., zorder=3, lw=1)
        if mask_uplim.any():
            # Newer Matplotlib rejects negative yerr; the magnitude draws the
            # same symmetric bar older releases drew for a negative value.
            ax1.errorbar(filters_wl[mask_uplim], obs_fluxes[mask_uplim],
                         yerr=np.abs(obs_fluxes_err[mask_uplim]), ls='',
                         marker='v', label='Observed upper limits',
                         markerfacecolor='None', markersize=6,
                         markeredgecolor='g', capsize=0.)
        if mask_noerr.any():
            ax1.errorbar(filters_wl[mask_noerr], obs_fluxes[mask_noerr],
                         ls='', marker='p', markerfacecolor='None',
                         markersize=5, markeredgecolor='r',
                         label='Observed fluxes, no errors', capsize=0.)

        # Relative residuals for every band with a positive observed flux.
        ax2.errorbar(filters_wl[positive],
                     (obs_fluxes[positive] - mod_fluxes[positive])
                     / obs_fluxes[positive],
                     yerr=np.abs(obs_fluxes_err[positive])
                     / obs_fluxes[positive],
                     marker='_', label="(Obs-Mod)/Obs", color='k',
                     capsize=0., ls='None', lw=1)
        ax2.plot([xmin, xmax], [0., 0.], ls='--', color='k')
        ax2.set_xscale('log')
        ax2.minorticks_on()

        ax1.tick_params(direction='in', axis='both', which='both', top=True,
                        left=True, right=True, bottom=False)
        ax2.tick_params(direction='in', axis='both', which='both',
                        right=True)

        figure.subplots_adjust(hspace=0., wspace=0.)

        auto_ymin, auto_ymax = _default_ylim(obs_fluxes, mod_fluxes, mask_ok,
                                             mask_uplim, sed.columns[1][wsed])
        ymin = auto_ymin if yrange[0] is False else yrange[0]
        ymax = auto_ymax if yrange[1] is False else yrange[1]
        # Log axis: keep 0 < ymin < ymax by stepping one decade below ymax.
        if not ymin < ymax:
            ymin = ymax / 10.

        ax1.set_xlim(xmin, xmax)
        ax1.set_ylim(ymin, ymax)
        ax2.set_xlim(xmin, xmax)
        ax2.set_ylim(-1.0, 1.0)

        if sed_type == 'lum':
            ax2.set_xlabel(r"Rest-frame wavelength [$\mu$m]")
            ax1.set_ylabel("Luminosity [W]")
        else:
            ax2.set_xlabel(r"Observed $\lambda$ ($\mu$m)")
            ax1.set_ylabel(r"S$_\nu$ (mJy)")
        ax2.set_ylabel("Relative\nresidual")

        ax1.legend(fontsize=6, loc='best', frameon=False)
        ax2.legend(fontsize=6, loc='best', frameon=False)

        plt.setp(ax1.get_xticklabels(), visible=False)
        # Hide the second y tick label of the top panel (presumably to avoid
        # a clash with the residual panel); guarded for very short axes.
        yticklabels = ax1.get_yticklabels()
        if len(yticklabels) > 1:
            plt.setp(yticklabels[1], visible=False)

        figure.suptitle(f"Best model for {obs['id']}\n (z={z:.3}, "
                        f"reduced χ²={mod['best.reduced_chi_square']:.2})")

        if logo is not False:
            # Multiplying the dpi by 2 is a hack so the figure is small
            # and not too pixelated
            figwidth = figure.get_figwidth() * figure.dpi * 2.
            # Right-align the logo: its width is the image's second axis.
            figure.figimage(logo, figwidth - logo.shape[1], 0,
                            origin='upper', zorder=0, alpha=1)

        figure.savefig(path.join(outdir,
                                 f"{obs['id']}_best_model.{format}"),
                       dpi=figure.dpi * 2.)
    finally:
        plt.close(figure)


def _set_loglog(ax):
    """Put both axes of *ax* in log scale, clipping non-positive y values."""
    ax.set_xscale('log')
    try:
        # Matplotlib >= 3.3 spelling.
        ax.set_yscale('log', nonpositive='clip')
    except (TypeError, ValueError):
        # Older Matplotlib only accepts the axis-specific keyword, which
        # recent releases no longer accept.
        ax.set_yscale('log', nonposy='clip')


def _plot_sed_components(ax, wavelength, sed, wsed, series):
    """Draw the spectral components of *sed* requested in *series* on *ax*.

    Only the samples selected by the index *wsed* are drawn. A component is
    drawn when its name is in *series* and, where tested below, when the SED
    table holds the columns it needs.
    """
    def total(*names):
        # Sum of the named SED columns over the plotted window, accumulated
        # left to right.
        values = sed[names[0]][wsed]
        for name in names[1:]:
            values = values + sed[name][wsed]
        return values

    x = wavelength[wsed]

    # Stellar emission
    if 'stellar_attenuated' in series:
        young = ['stellar.young', 'attenuation.stellar.young']
        old = ['stellar.old', 'attenuation.stellar.old']
        if 'nebular.absorption_young' in sed.columns:
            young.append('nebular.absorption_young')
            old.append('nebular.absorption_old')
        ax.plot(x, total(*young, *old), label="Stellar attenuated",
                color='gold', marker=None, linestyle='-', linewidth=1.0)

    if 'stellar_unattenuated' in series:
        ax.plot(x, total('stellar.old', 'stellar.young'),
                label="Stellar unattenuated", color='xkcd:deep sky blue',
                marker=None, linestyle='--', linewidth=1.0)

    # Nebular emission
    if 'nebular' in series and 'nebular.lines_young' in sed.columns:
        ax.plot(x, total('nebular.lines_young', 'nebular.lines_old',
                         'nebular.continuum_young', 'nebular.continuum_old',
                         'attenuation.nebular.lines_young',
                         'attenuation.nebular.lines_old',
                         'attenuation.nebular.continuum_young',
                         'attenuation.nebular.continuum_old'),
                label="Nebular emission", color='xkcd:true green',
                marker=None, linewidth=1.0)

    # Dust emission Draine & Li
    if 'dust' in series and 'dust.Umin_Umin' in sed.columns:
        ax.plot(x, total('dust.Umin_Umin', 'dust.Umin_Umax'),
                label="Dust emission", color='xkcd:bright red',
                marker=None, linestyle='-', linewidth=1.0)

    # Dust emission Dale
    if 'dust' in series and 'dust' in sed.columns:
        ax.plot(x, total('dust'), label="Dust emission",
                color='xkcd:bright red', marker=None, linestyle='-',
                linewidth=1.0)

    # AGN emission Fritz
    if 'agn' in series and 'agn.fritz2006_therm' in sed.columns:
        ax.plot(x, total('agn.fritz2006_therm', 'agn.fritz2006_scatt',
                         'agn.fritz2006_agn'),
                label="AGN emission", color='xkcd:apricot',
                marker=None, linestyle='-', linewidth=1.0)

    # Radio emission
    if 'radio' in series and 'radio_nonthermal' in sed.columns:
        ax.plot(x, total('radio_nonthermal'), label="Radio nonthermal",
                color='brown', marker=None, linestyle='-', linewidth=1.0)

    if 'model' in series:
        ax.plot(x, total('L_lambda_total'), label="Model spectrum",
                color='k', linestyle='-', linewidth=1.5)


def _default_ylim(obs_fluxes, mod_fluxes, mask_ok, mask_uplim, spectrum):
    """Return the automatic (ymin, ymax) of the SED panel.

    ymin is one decade below the faintest observed/model flux of the bands
    in *mask_ok*; ymax is one decade above the brightest observed/model flux
    of the bands in *mask_ok* or *mask_uplim*. When no band qualifies, the
    positive samples of *spectrum* are used instead so the limits stay
    defined.
    """
    low = np.concatenate((obs_fluxes[mask_ok], mod_fluxes[mask_ok]))
    high_mask = mask_ok | mask_uplim
    high = np.concatenate((obs_fluxes[high_mask], mod_fluxes[high_mask]))
    if low.size == 0 or high.size == 0:
        fallback = np.asarray(spectrum)
        fallback = fallback[fallback > 0.]
        if low.size == 0:
            low = fallback
        if high.size == 0:
            high = fallback
    return np.min(low) * 1e-1, np.max(high) * 1e1