Commit 30cfdcf7 authored by Peter Lustig's avatar Peter Lustig

Added MPI support.

parent 3dd130e8
......@@ -5,7 +5,7 @@ import os
import numpy as np
import matplotlib.pyplot as plt
from multiprocessing import Pool, cpu_count
from multiprocess import Pool, cpu_count
from functools import partial
from astropy import units as u
......@@ -18,10 +18,14 @@ from scipy.optimize import curve_fit
from nikamap import NikaMap, Jackknife
from nikamap.utils import pos_uniform
from mpi4py import MPI
from IPython import get_ipython
ipython = get_ipython()
import myfunctions as mfct
from time import clock
if '__IPYTHON__' in globals():
ipython.magic('load_ext autoreload')
ipython.magic('autoreload 2')
......@@ -31,6 +35,10 @@ if '__IPYTHON__' in globals():
#%autoreload 2
#%matplotlib tk
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
plt.ion()
def Wait():
......@@ -285,7 +293,7 @@ def completness_purity(flux, wcs=None, shape=None, nsources=8**2, within=(1 / 4,
return completness, purity
def completness_purity_2(flux, wcs=None, shape=None, nsources=8**2, within=(1 / 4, 3 / 4), nsim=5, jk_map=None):
def completness_purity_2(flux, nsim=2, wcs=None, shape=None, nsources=8**2, within=(1 / 4, 3 / 4), jk_map=None):
"""Compute completness map for a given flux"""
print(flux)
......@@ -332,6 +340,20 @@ def completness_purity_2(flux, wcs=None, shape=None, nsources=8**2, within=(1 /
def Plot_CompPur(completness, purity, threshold,nsim=None):
threshold_bins = completness.shape[-1]
fig, axes = plt.subplots(nrows=2, ncols=threshold_bins, sharex=True, sharey=True)
for i in range(threshold_bins):
axes[0, i].imshow(completness[:, :, i], vmin=0, vmax=1)
axes[1, i].imshow(purity[:, :, i], vmin=0, vmax=1)
axes[1, i].set_xlabel("thresh={:.2f}".format(threshold[i]))
if nsim is not None:
axes[0, 0].set_title("{} simulations".format(nsim))
axes[0, 0].set_ylabel("completness")
axes[1, 0].set_ylabel("purity")
plt.show(block=True)
plt.close('all')
filenames = list(Path(DATA_DIR).glob('../*/map.fits'))
......@@ -364,7 +386,7 @@ shape_4D, wcs_4D = completness_purity_wcs(data.shape, data.wcs, bins=20,
threshold_range=threshold_range, threshold_bins=threshold_bins)
print(wcs_4D)
#print(wcs_4D)
# Testing the lower edges
wcs_threshold = wcs_4D.sub([3])
assert np.all(np.abs(wcs_threshold.all_pix2world(np.arange(threshold_bins+1)-0.5, 0) - threshold_edges) < 1e-15)
......@@ -380,13 +402,66 @@ DETECTED_SOURCES = []
# This is a single run check for a single flux
nsim = 2
pool = Pool(cpu_count())
nsim = 5
#ncores = 3
allflux = [10*u.mJy]
sim_per_core = mfct.Distribute_Njobs(size, nsim)
print(sim_per_core[rank])
#completness_purity_2 = partial(completness_purity_2, **{'nsources':8**2, 'within':(0, 1),
# 'wcs':wcs_4D.sub([1, 2, 3]),
# 'shape':shape_4D[0:3],
# 'jk_map':data,
# 'nsim':2})
comm.Barrier()
if rank == 0:
T0 = clock()
for irun, flux in enumerate(allflux):
comm.Barrier()
if rank ==0:
t0 = clock()
comp, comp_n, pur, pur_n = completness_purity_2(10*u.mJy, nsim=sim_per_core[rank],
nsources=8**2, within=(0, 1),
shape=shape_4D[0:3],
wcs=wcs_4D.sub([1, 2, 3]), jk_map=data)
comm.Barrier()
all_comp = comm.gather(comp, root=0)
all_comp_n = comm.gather(comp_n, root=0)
all_pur = comm.gather(pur, root=0)
all_pur_n = comm.gather(pur_n, root=0)
if rank == 0:
all_comp = np.sum(all_comp, axis=0)
all_comp_n = np.sum(all_comp_n, axis=0)
all_pur = np.sum(all_pur, axis=0)
all_pur_n = np.sum(all_pur_n, axis=0)
print(all_comp_n.shape)
print('------------')
with np.errstate(divide='ignore', invalid='ignore'):
completness = all_comp / all_comp_n[..., np.newaxis]
purity = all_pur / all_pur_n
Plot_CompPur(completness, purity, threshold, nsim=nsim)
# Save result !!
p_completness_purity_2 = partial(completness_purity_2, **{'nsources'=8**2, 'within'=(0, 1),
'wcs'=wcs_4D.sub([1, 2, 3]),
'shape'=shape_4D[0:3], 'nsim'=nsim,
'jk_map'=data})
#completness, purity = completness_purity_2(flux=10*u.mJy)
......@@ -403,16 +478,6 @@ p_completness_purity_2 = partial(completness_purity_2, **{'nsources'=8**2, 'with
#shape=shape_4D[0:3], nsim=np.multiply(*shape_4D[0:2]) * 100,
fig, axes = plt.subplots(nrows=2, ncols=threshold_bins, sharex=True, sharey=True)
for i in range(threshold_bins):
axes[0, i].imshow(completness[:, :, i], vmin=0, vmax=1)
axes[1, i].imshow(purity[:, :, i], vmin=0, vmax=1)
axes[1, i].set_xlabel("thresh={:.2f}".format(threshold[i]))
axes[0, 0].set_title("{} simulations".format(nsim))
axes[0, 0].set_ylabel("completness")
axes[1, 0].set_ylabel("purity")
plt.show()
Wait()
'''
# To merge all the fake_sources and sources catalogs
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment