__init__.py 5.66 KB
Newer Older
1
2
3
4
5
6
# -*- coding: utf-8 -*-
# Copyright (C) 2015 Institute of Astronomy
# Licensed under the CeCILL-v2 licence - see Licence_CeCILL_V2-en.txt
# Author: Médéric Boquien

import argparse
7
8
9
import multiprocessing as mp
import sys

10
11
12
13
from astropy.table import Table, Column
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
14

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from pcigale.data import Database, Filter

# Version identifier of this module.
__version__ = "0.1-alpha"


def list_filters():
    """Print the list of filters in the pcigale database.
    """
    # NOTE: the docstring above is passed as-is to argparse in main() as the
    # help text of the `list` sub-command, hence it is kept to one sentence.
    #
    # The output is a table with one row per filter (name, description,
    # pivot wavelength and number of sampling points), sorted by increasing
    # pivot wavelength and printed in full, without truncation.
    with Database() as base:
        filters = {name: base.get_filter(name) for name in
                   base.get_filter_names()}

    # All the columns are built from the same sequence so that the rows stay
    # aligned.
    entries = list(filters.values())

    name = Column(data=[f.name for f in entries], name='Name')
    description = Column(data=[f.description for f in entries],
                         name='Description')
    wl = Column(data=[f.pivot_wavelength for f in entries],
                name='Pivot Wavelength', unit=u.nm, format='%d')
    # trans_table[0] is the wavelength grid, so its size is the number of
    # points sampling the transmission curve.
    samples = Column(data=[f.trans_table[0].size for f in entries],
                     name="Points")

    t = Table()
    t.add_columns([name, description, wl, samples])
    t.sort(['Pivot Wavelength'])
    # Negative limits disable the truncation of the lines and of the width.
    t.pprint(max_lines=-1, max_width=-1)


def add_filters(fnames):
    """Add filters to the pcigale database.
    """
    # NOTE: the docstring above is passed as-is to argparse in main() as the
    # help text of the `add` sub-command, hence it is kept to one sentence.
    #
    # fnames is an iterable of paths to filter files. The first three lines
    # of each file give, in this order, the name of the filter, the type of
    # its transmission ('energy' or 'photon') and its description; any
    # leading or trailing '#', space, tab and newline is stripped. The table
    # itself is read with numpy.genfromtxt (presumably the header lines start
    # with '#' so that genfromtxt skips them -- to be confirmed against the
    # filter files), with the wavelength in Å in the first column and the
    # transmission in the second one.
    #
    # A ValueError is raised if the transmission type is neither 'energy'
    # nor 'photon'.
    with Database(writable=True) as base:
        for fname in fnames:
            with open(fname, 'r') as f_fname:
                filter_name, filter_type, filter_description = [
                    f_fname.readline().strip('# \n\t') for _ in range(3)
                ]

            # The table is transposed to have table[0] containing the
            # wavelength and table[1] containing the transmission.
            filter_table = np.genfromtxt(fname).transpose()

            # We convert the wavelength from Å to nm.
            filter_table[0] *= 0.1

            # We convert to energy if needed
            if filter_type == 'photon':
                filter_table[1] *= filter_table[0]
            elif filter_type != 'energy':
                raise ValueError("Filter transmission type can only be "
                                 "'energy' or 'photon'.")

            print("Importing {}... ({} points)".format(filter_name,
                                                       filter_table.shape[1]))

            new_filter = Filter(filter_name, filter_description, filter_table)

            # We normalise the filter and compute the pivot wavelength. If the
            # filter is a pseudo-filter used to compute line fluxes, it should
            # not be normalised. In that case the pivot wavelength is set to
            # the mean of the wavelengths where the transmission is positive.
            if filter_name.startswith(('PSEUDO', 'linefilter.')):
                new_filter.pivot_wavelength = np.mean(
                    filter_table[0][filter_table[1] > 0]
                )
            else:
                new_filter.normalise()

            base.add_filter(new_filter)


def del_filters(fnames):
    """Delete filters from the pcigale database
    """
    # NOTE: the docstring above is passed as-is to argparse in main() as the
    # help text of the `del` sub-command, hence it is left untouched.
    #
    # fnames is an iterable of filter names. Names that are not in the
    # database are reported and skipped; the others are deleted one by one.
    with Database(writable=True) as db:
        known = db.get_filter_names()
        for candidate in fnames:
            if candidate not in known:
                print("Filter {} not in the database".format(candidate))
                continue
            db.del_filter(candidate)
            print("Removing filter {}".format(candidate))


def worker_plot(fname):
    """Plot the transmission curve of one filter and save it to a PDF file.

    This function is the unit of work mapped over a process pool by
    plot_filters. The figure is written to "<fname>.pdf", a path relative to
    the current working directory.

    Parameters
    ----------
    fname: string
        Name of the filter to be plotted

    """
    with Database() as base:
        _filter = base.get_filter(fname)

    # Row 0 of the transmission table is the wavelength and row 1 is the
    # transmission.
    wavelength = _filter.trans_table[0]
    transmission = _filter.trans_table[1]

    # Start from a clean figure as the current one is reused between calls.
    plt.clf()
    plt.plot(wavelength, transmission, color='k')
    # The horizontal axis spans exactly the sampled wavelength range.
    plt.xlim(wavelength[0], wavelength[-1])
    plt.minorticks_on()
    plt.xlabel('Wavelength [nm]')
    plt.ylabel('Relative transmission')
    plt.title("{} filter".format(fname))
    plt.tight_layout()
    plt.savefig("{}.pdf".format(fname))


def plot_filters(fnames):
    """Plot the filters provided as parameters. If no filter is given, then
    plot all the filters.
    """
    # NOTE: the docstring above is passed as-is to argparse in main() as the
    # help text of the `plot` sub-command, hence it is kept short.
    #
    # fnames is a sequence of filter names; when it is empty, the names of
    # all the filters in the database are used instead. Each filter is
    # plotted by worker_plot in a pool with one process per CPU.
    if len(fnames) == 0:
        with Database() as base:
            fnames = base.get_filter_names()

    with mp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(worker_plot, fnames)


def main():
    """Command-line entry point: parse the arguments and run a sub-command.

    The recognised sub-commands are `list`, `add`, `del` and `plot`, which
    call list_filters, add_filters, del_filters and plot_filters
    respectively. `add` and `del` require at least one name, `plot` accepts
    any number of them. When no sub-command is given the usage message is
    printed instead.
    """
    # multiprocessing.set_start_method only exists from Python 3.4 onwards.
    if sys.version_info[:2] >= (3, 4):
        mp.set_start_method('spawn')
    else:
        print("Could not set the multiprocessing start method to spawn. If "
              "you encounter a deadlock, please upgrade to Python≥3.4.")

    parser = argparse.ArgumentParser()

    subparsers = parser.add_subparsers(help="List of commands")

    # The docstring of each function doubles as the help text of the
    # corresponding sub-command. The `parser` default records which
    # sub-command was selected.
    list_parser = subparsers.add_parser('list', help=list_filters.__doc__)
    list_parser.set_defaults(parser='list')

    add_parser = subparsers.add_parser('add', help=add_filters.__doc__)
    add_parser.add_argument('names', nargs='+', help="List of file names")
    add_parser.set_defaults(parser='add')

    del_parser = subparsers.add_parser('del', help=del_filters.__doc__)
    del_parser.add_argument('names', nargs='+', help="List of filter names")
    del_parser.set_defaults(parser='del')

    plot_parser = subparsers.add_parser('plot', help=plot_filters.__doc__)
    plot_parser.add_argument('names', nargs='*', help="List of filter names")
    plot_parser.set_defaults(parser='plot')

    if len(sys.argv) == 1:
        parser.print_usage()
        return

    args = parser.parse_args()

    # argparse does not require a sub-command by default, so the `parser`
    # attribute may be missing even though some arguments were passed. Fall
    # back to the usage message rather than failing on the attribute access.
    command = getattr(args, 'parser', None)
    if command == 'list':
        list_filters()
    elif command == 'add':
        add_filters(args.names)
    elif command == 'del':
        del_filters(args.names)
    elif command == 'plot':
        plot_filters(args.names)
    else:
        parser.print_usage()