#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon May 27 17:31:18 2019

@author: rfetick
"""

import numpy as np
from scipy.optimize import least_squares
from astropy.io import fits
import time
from numpy.fft import fft2, fftshift, ifft2
from functools import lru_cache
from paompy.config import _EPSILON
from paompy.utils import binning

#%% FITTING FUNCTION
def lsq_flux_bck(model, data, weights, background=True, positive_bck=False):
    """Closed-form weighted least-square estimate of flux and background
    The minimized criterion is
    LS = SUM_pixels { weights*(flux*model + bck - data)² }
    
    Parameters
    ----------
    model: numpy.ndarray
        Model image
    data: numpy.ndarray
        Data image
    weights: numpy.ndarray
        Least-square weights applied to each pixel
    
    Keywords
    --------
    background: bool
        When False, the background is forced to 0 and only the flux is
        estimated (default:True)
    positive_bck : bool
        When True, a negative background estimate is replaced by 0 and the
        flux is re-estimated without background (default:False)
        
    Returns
    -------
    (amp, bck) : tuple of two scalars
        Estimated flux and background
    """
    sum_w = np.sum(weights)
    sum_mw = np.sum(model * weights)
    sum_mwd = np.sum(model * weights * data)
    sum_m2w = np.sum(weights * (model ** 2))
    sum_wd = np.sum(weights * data)

    # Flux-only problem: a single unknown, no background term
    if not background:
        return sum_mwd / sum_m2w, 0.0

    # Joint flux + background problem: solve the 2x2 normal equations
    delta = sum_mw ** 2 - sum_w * sum_m2w
    amp = 1. / delta * (sum_mw * sum_wd - sum_w * sum_mwd)
    bck = 1. / delta * (-sum_m2w * sum_wd + sum_mw * sum_mwd)

    # Fall back on the flux-only solution when a negative background is forbidden
    if positive_bck and bck < 0:
        return sum_mwd / sum_m2w, 0.0

    return amp, bck

def psffit(psf,Model,x0,weights=None,dxdy=(0.,0.),flux_bck=(True,True),
           positive_bck=False,fixed=None,**kwargs):
    """Fit a PSF with a parametric model solving the least-square problem
       epsilon(x) = SUM_pixel { weights * (amp * Model(x) + bck - psf)² }
    
    The model parameters and the PSF shift are estimated numerically with
    `scipy.optimize.least_squares`. For each set of parameters, flux and
    background are computed analytically by `lsq_flux_bck`.
    
    Parameters
    ----------
    psf : numpy.ndarray
        The experimental image to be fitted
    Model : class
        The class representing the fitting model
        It is instantiated as `Model(np.shape(psf),**kwargs)`
    x0 : tuple, list, numpy.ndarray
        Initial guess for parameters
    weights : numpy.ndarray
        Least-square weighting matrix (same shape as `psf`)
        Default: uniform weighting
    dxdy : tuple of two floats
        Eventual guess on PSF shifting
    flux_bck : tuple of two bool
        Only background can be activate/inactivated
        Flux is always activated (sorry!)
    positive_bck : bool
        Force background to be positive or null
    fixed : numpy.ndarray
        Boolean-like sequence (same length as `x0`). Parameters flagged
        True are kept at their initial value (default: None, all free)
    **kwargs :
        All keywords used to instantiate your `Model`
    
    Returns
    -------
    out.x : numpy.array
            Parameters at optimum
       .dxdy : tuple of 2 floats
           PSF shift at optimum
       .flux_bck : tuple of two floats
           Estimated image flux and background
       .psf : numpy.ndarray (dim=2)
           Image of the PSF model at optimum
       .success : bool
           Minimization success
       .status : int
           Minimization status (see scipy doc)
       .message : string
           Human readable minimization status
       .active_mask : numpy.array
           Saturated bounds
       .nfev : int
           Number of function evaluations
       .cost : float
           Value of cost function at optimum
           
    Raises
    ------
    ValueError
        If `weights` does not have the same shape as `psf`, or if `fixed`
        does not have the same length as `x0`
    """
    Model_inst = Model(np.shape(psf),**kwargs)
    if weights is None:
        weights = np.ones_like(psf)
    elif np.shape(psf)!=np.shape(weights):
        # compare full shapes: comparing `len` would only check the first axis
        raise ValueError("Keyword `weights` must have same number of elements as `psf`")
    sqW = np.sqrt(weights)
    
    if fixed is not None:
        if len(fixed)!=len(x0):
            raise ValueError("When defined, `fixed` must be same size as `x0`")
        # indices of the parameters that the minimizer is allowed to move
        INDFREE = np.where([not f for f in fixed])[0]
    
    def input2mini(x,dxdy):
        # Transform user parameters to minimizer parameters
        # (free parameters followed by the two shift values)
        xfree = x if fixed is None else np.take(x,INDFREE)
        return np.concatenate((xfree,dxdy))
    
    def mini2input(y):
        # Transform minimizer parameters to user parameters
        if fixed is None:
            xall = y[0:-2]
        else:
            # Work on a float copy: an integer-typed `x0` would otherwise
            # silently truncate the fitted values written into it
            xall = np.array(x0,dtype=float)
            xall[INDFREE] = y[:-2]
        return (xall,y[-2:])
    
    def get_bound(inst):
        # Model bounds restricted to free parameters, shift is unbounded
        b_low, b_up = inst.bounds[0], inst.bounds[1]
        if fixed is not None:
            b_low = np.take(b_low,INDFREE)
            b_up = np.take(b_up,INDFREE)
        b_low = np.concatenate((b_low,[-np.inf,-np.inf]))
        b_up = np.concatenate((b_up,[np.inf,np.inf]))
        return (b_low,b_up)
    
    class CostClass(object):
        # Callable returning the weighted residuals, with a progress display
        def __init__(self):
            self.iter = 0
        def __call__(self,y):
            # print one dash every three evaluations
            if (self.iter%3) == 0:
                print("-",end="")
            self.iter += 1
            
            x, shift = mini2input(y)
            m = Model_inst(x,dx=shift[0],dy=shift[1])
            amp, bck = lsq_flux_bck(m, psf, weights, background=flux_bck[1], positive_bck=positive_bck)
            return np.reshape(sqW * (amp * m + bck - psf), np.size(psf))
    
    cost = CostClass()
    
    result = least_squares(cost, input2mini(x0,dxdy), bounds=get_bound(Model_inst))
    
    print("") #finish line of "-"
    
    result.x, result.dxdy = mini2input(result.x) #split output between x and dxdy

    # Recompute model, flux and background at the optimum
    m = Model_inst(result.x,dx=result.dxdy[0],dy=result.dxdy[1])
    amp, bck = lsq_flux_bck(m, psf, weights, background=flux_bck[1], positive_bck=positive_bck)
    
    result.flux_bck = (amp,bck)
    result.psf = m    
    return result

#%% CLASS PARAMETRIC PSF AND ITS SUBCLASSES
class ParametricPSF(object):
    """Super-class defining parametric PSFs
    Not to be instantiated, only serves as a referent for subclasses
    
    Subclasses are expected to override `__call__` so that it returns the
    PSF image. Methods `otf`, `mtf` and `tofits` rely on it.
    """
    
    def __init__(self,Npix):
        """
        Parameters
        ----------
        Npix : tuple of two elements
            Model X and Y pixel size when called
            
        Raises
        ------
        TypeError
            If `Npix` is not a tuple
        """
        if not isinstance(Npix,tuple):
            raise TypeError("Argument `Npix` must be a tuple")
        self.Npix = Npix
        self.bounds = (-np.inf,np.inf)
    
    def __repr__(self):
        return "ParametricPSF of size (%u,%u)"%self.Npix
    
    def __call__(self,*args,**kwargs):
        """Compute the PSF. To be overridden by subclasses"""
        raise ValueError("ParametricPSF is not made to be instantiated. Better use its subclasses")
    
    
    def otf(self,*args,**kwargs):
        """Return the Optical Transfer Function (OTF)
        Arguments are forwarded unchanged to `__call__`
        """
        # unpack the arguments: they must reach `__call__` as they were given
        psf = self(*args,**kwargs)
        return fftshift(fft2(psf))
    
    def mtf(self,*args,**kwargs):
        """Return the Modulation Transfer Function (MTF)
        Arguments are forwarded unchanged to `otf`
        """
        return np.abs(self.otf(*args,**kwargs))
    
    def tofits(self,param,filename,*args,keys=None,keys_comment=None,**kwargs):
        """Compute the PSF and write it to a FITS file (overwriting it)
        
        Parameters
        ----------
        param : list, tuple, numpy.ndarray
            Parameters of the PSF, given as first argument to `__call__`
        filename : str
            Name of the FITS file to write
        *args, **kwargs :
            Extra arguments forwarded to `__call__`
        keys : list of str
            Header keyword for each parameter (see `_getfitshdr`)
        keys_comment : list of str
            Header comment for each parameter (see `_getfitshdr`)
        """
        psf = self(param,*args,**kwargs)
        hdr = self._getfitshdr(param,keys=keys,keys_comment=keys_comment)
        hdu = fits.PrimaryHDU(psf, hdr)
        hdu.writeto(filename, overwrite=True)
        
    def _getfitshdr(self,param,keys=None,keys_comment=None):
        """Build the FITS header describing the parameters
        
        Parameters
        ----------
        param : list, tuple, numpy.ndarray
            Parameters to be written in the header
        keys : list of str
            One keyword per parameter, each one is prefixed with
            "HIERARCH PARAM ". Default: "PARAM 1", "PARAM 2", ...
        keys_comment : list of str
            One comment per parameter (default: no comment)
            
        Returns
        -------
        hdr : astropy.io.fits.Header
        
        Raises
        ------
        ValueError
            If `keys` or `keys_comment` length differs from `param` length
        """
        if keys is None:
            keys = ["PARAM %u"%(i+1) for i in range(len(param))]
        if len(keys)!=len(param):
            raise ValueError("When defined, `keys` must be same size as `param`")
        if keys_comment is not None:
            if len(keys_comment)!=len(param):
                raise ValueError("When defined, `keys_comment` must be same size as `param`")
        hdr = fits.Header()
        
        hdr["HIERARCH ORIGIN"] = "PAOMPY automatic header"
        hdr["HIERARCH CREATION"] = (time.ctime(),"Date of file creation")
        for i, key in enumerate(keys):
            if keys_comment is None:
                hdr["HIERARCH PARAM "+key] = param[i]
            else:
                hdr["HIERARCH PARAM "+key] = (param[i],keys_comment[i])
        return hdr


class ConstantPSF(ParametricPSF):
    """Wrap a fixed 2D image inside the ParametricPSF formalism
    A constant PSF is handled as a parametric PSF that ignores its
    parameters and always returns the same image
    """
    def __init__(self,image_psf):
        """
        Parameters
        ----------
        image_psf : numpy.ndarray (dim=2)
            The image returned at each call
        """
        shape = np.shape(image_psf)
        super().__init__(shape)
        # bounds are replaced by an empty tuple after the parent initialization
        self.bounds = ()
        self.image_psf = image_psf
        
    def __call__(self,*args,**kwargs):
        """Return the stored image, all arguments are ignored"""
        return self.image_psf
    


def moffat(XY, param, norm=None):
    """
    Evaluate a rotated and shifted Moffat function on a meshgrid
    moff = A * Enorm * (1+u)^(-beta)
    where `u` is the quadratic form of the coordinates in the shifted and
    rotated frame
    
    Parameters
    ----------
    XY : numpy.ndarray (dim=2)
        The (X,Y) meshgrid with X = XY[0] and Y = XY[1]
    param : list, tuple, numpy.ndarray (len=7)
        param[0] - Amplitude
        param[1] - Alpha X
        param[2] - Alpha Y
        param[3] - Theta
        param[4] - Beta
        param[5] - Center X
        param[6] - Center Y
        
    Keywords
    --------
    norm : None, np.inf, float (>0), int (>0)
        Radius for energy normalization
        None      - No energy normalization
                    Enorm = 1
        np.inf    - Total energy normalization (on the whole X-Y plane)
                    Enorm = (beta-1)/(pi*ax*ay)
        float,int - Energy normalization up to the radius R defined by this value
                    Enorm = (beta-1)/(pi*ax*ay)/(1-(1+(R**2)/(ax*ay))**(1-beta))
    
    Returns
    -------
    The 2D Moffat array
    
    Raises
    ------
    ValueError
        If `param` does not have 7 elements, if beta<=1 with norm=np.inf,
        or if beta==1 with a finite normalization radius
    
    """
    if len(param)!=7:
        raise ValueError("Parameter `param` must contain exactly 7 elements, but input has %u elements"%len(param))
    
    amplitude, alpha_x, alpha_y, theta, beta, center_x, center_y = param
    
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    sin_2t = np.sin(2.0 * theta)

    # Coefficients of the quadratic form in the rotated frame
    Rxx = (cos_t / alpha_x) ** 2 + (sin_t / alpha_y) ** 2
    Ryy = (cos_t / alpha_y) ** 2 + (sin_t / alpha_x) ** 2
    Rxy = sin_2t / alpha_y ** 2 - sin_2t / alpha_x ** 2
    
    # Coordinates relative to the Moffat center
    dX = XY[0] - center_x
    dY = XY[1] - center_y
    u = Rxx * dX**2 + Rxy * dX * dY + Ryy * dY**2
    
    if norm is None:
        Enorm = 1
    elif norm == np.inf:
        if beta<=1:
            raise ValueError("Cannot compute Moffat energy for param[4]<=1")
        Enorm = (beta-1) / (np.pi*alpha_x*alpha_y)
    else:
        if beta==1:
            raise ValueError("Energy computation for param[4]=1.0 not implemented yet. Sorry!")
        Enorm = (beta-1) / (np.pi*alpha_x*alpha_y)
        # restrict the energy to the disk of radius `norm`
        Enorm = Enorm / (1 - (1 + (norm**2) / (alpha_x*alpha_y))**(1-beta))
    
    return Enorm * amplitude * (1. + u) ** (-beta)


class Moffat(ParametricPSF):
    """Parametric PSF made of a single Moffat function
    See documentation of methods "__init__" and "__call__"
    """
    
    def __init__(self,Npix,norm=np.inf):
        """
        Parameters
        ----------
        Npix : tuple of two elements
            Size of output PSF
            Must be hashable since it is used as key of the grid cache
        norm : None, np.inf, float (>0), int (>0)
            Radius for energy normalization, forwarded to `moffat`
            (default = np.inf)
        """
        # The parent initializer is not called: attributes are set directly
        self.Npix = Npix
        self.norm = norm
        # bounds are given for [alpha X, alpha Y, theta, beta]
        bounds_down = [_EPSILON,_EPSILON,-np.inf,1+_EPSILON]
        bounds_up = [np.inf]*4
        self.bounds = (bounds_down,bounds_up)
    
    @lru_cache(maxsize=5)
    def _XY(self,Npix):
        """Return the meshgrid of pixel coordinates, centered on Npix/2
        
        Returns
        -------
        YX - numpy.ndarray (dim=3)
            YX[0] varies along the first axis, YX[1] along the second axis
        """
        # Work on floats: writing a half-pixel offset into the integer grid
        # returned by `mgrid` would truncate coordinates for odd sizes
        YX = np.mgrid[0:Npix[0],0:Npix[1]].astype(float)
        # each axis is centered with its own length
        YX[0] -= Npix[0]/2
        YX[1] -= Npix[1]/2
        return YX
    
    def __call__(self,x,dx=0,dy=0):
        """
        Parameters
        ----------
        x : list, tuple, numpy.ndarray (len=4)
            x[0] - Alpha X
            x[1] - Alpha Y
            x[2] - Theta
            x[3] - Beta
        dx : float
            Moffat center along the first grid coordinate (default = 0)
        dy : float
            Moffat center along the second grid coordinate (default = 0)
            
        Returns
        -------
        tab : numpy.ndarray (dim=2)
            The Moffat PSF with unit amplitude, normalized according to `norm`
        """
        # build the 7 parameters of `moffat`: amplitude, x, center
        y = np.concatenate(([1],x,[dx,dy]))
        return moffat(self._XY(self.Npix),y,norm=self.norm)
    
    def tofits(self,param,filename,*args,keys=("ALPHA_X","ALPHA_Y","THETA","BETA"),**kwargs):
        """Write the PSF to a FITS file, with Moffat keywords by default
        See `ParametricPSF.tofits` for the arguments
        """
        # default `keys` is a tuple to avoid a shared mutable default argument
        super(Moffat,self).tofits(param,filename,*args,keys=keys,**kwargs)


class Psfao(ParametricPSF):
    """PSF model based on a parametrization of the phase PSD
    See documentation of methods "__init__" and "__call__"
    
    """
    
    def __init__(self,Npix,system=None,Lext=10.,samp=None,symmetric=False,diffotf=True):
        """
        Parameters
        ----------
        Npix : tuple
            Size of output PSF (two even numbers)
        system : OpticalSystem
            Optical system for this PSF
            Attributes `D`, `Nact`, `pupil` and `name` are used by this class
        samp : float
            Sampling at the observation wavelength
        Lext : float
            Von-Karman external scale (default = 10 m)
            Useless if Fao >> 1/Lext
        symmetric : bool
            If True, the model uses 5 parameters instead of 7
            (alpha Y = alpha X and theta = 0), see `__call__`
            (default=False)
        diffotf : bool
            Enable/disable diffraction OTF for PSF computation
            (default=True)
            
        Raises
        ------
        ValueError
            If `Npix` is not a sequence of two even numbers, or if `system`
            or `samp` is not defined
        """
        # The parent initializer is not called: attributes are set directly
        if not isinstance(Npix,(tuple,list,np.ndarray)):
            raise ValueError("Npix must be a tuple, list or numpy.ndarray")
        if len(Npix)!=2:
            raise ValueError("Npix must be of length = 2")
        if (Npix[0]%2) or (Npix[1]%2):
            raise ValueError("Each Npix component must be even")
        self.Npix = Npix
        if system is None:
            raise ValueError("Keyword `system` must be defined")
        if samp is None:
            raise ValueError("Keyword `samp` must be defined")
        self.system = system
        self.Lext = Lext
        self.samp = samp # also sets `_k` and `_samp_num`
        self.symmetric = symmetric # also sets `bounds`
        self.diffotf = diffotf
    
    @property
    def symmetric(self):
        """bool: True for the 5 parameters model, False for the 7 parameters one
        Setting this property updates `bounds` accordingly
        """
        return self._symmetric
    
    @symmetric.setter
    def symmetric(self,value):
        self._symmetric = value
        if not value:
            # [r0, C, A, alpha X, alpha Y, theta, beta]
            bounds_down = [_EPSILON,0,_EPSILON,_EPSILON,_EPSILON,-np.inf,1+_EPSILON]
            bounds_up = [np.inf]*7
        else:
            # [r0, C, A, alpha, beta]
            bounds_down = [_EPSILON,0,_EPSILON,_EPSILON,1+_EPSILON]
            bounds_up = [np.inf]*5
        self.bounds = (bounds_down,bounds_up)
    
    @property
    def samp(self):
        """float: sampling given by the user
        Setting this property updates the oversampling factor `_k` and the
        sampling `_samp_num` used for numerical computations
        """
        return self._samp
    
    @samp.setter
    def samp(self,value):
        # Manage cases of undersampling
        self._samp = value
        if value >=2:
            self._samp_num = value
            self._k = 1
        else:
            # oversample by an integer factor so that numerical sampling is >= 2
            self._k = int(np.ceil(2.0/value))
            self._samp_num = self._k * value
    
    @lru_cache(maxsize=2)
    def _freq_array(self,Nx,Ny,samp,D):
        """
        Returns
        -------
        tab - numpy.array (dim=3)
            3D array of frequencies [1/m]
        """
        pix2freq = 1.0/(D*samp)
        f2D = np.mgrid[0:Nx, 0:Ny].astype(float)
        #null frequency at [Nx//2,Ny//2] according to numpy fft convention
        f2D[0] -= Nx//2
        f2D[1] -= Ny//2
        return f2D * pix2freq
    
    @lru_cache(maxsize=2)
    def _shift_array(self,Nx,Ny):
        """Return the complex phase ramps used to shift the PSF
        
        Returns
        -------
        X, Y - numpy.ndarray (dim=2)
            X varies along the second axis (length Ny)
            Y varies along the first axis (length Nx)
        """
        Y, X = np.mgrid[0:Nx,0:Ny].astype(float)
        # each ramp is scaled with the length of the axis it varies along
        X = (X-Ny/2) * 2*np.pi*1j/Ny
        Y = (Y-Nx/2) * 2*np.pi*1j/Nx
        return X, Y
    
    @lru_cache(maxsize=5)
    def _dlFTO(self,Nx,Ny,pupfct,samp):
        """Compute the OTF from the pupil function `pupfct`
        `pupfct` is called as pupfct((NpupX,NpupY),samp=samp)
        """
        # samp as a tuple is not ready yet, don't use it
        if type(samp)==tuple:
            dlFTO = np.zeros((Nx,Ny,len(samp)))
            for i in range(len(samp)):
                dlFTO[...,i] = self._dlFTO(Nx,Ny,pupfct, samp[i])
            return np.mean(dlFTO,axis=2)
        
        NpupX = np.ceil(Nx/samp)
        NpupY = np.ceil(Ny/samp)
        # builtin `complex`: the `np.complex` alias has been removed from numpy
        tab = np.zeros((Nx, Ny), dtype=complex)
        tab[0:int(NpupX), 0:int(NpupY)] = pupfct((NpupX,NpupY),samp=samp)
        # autocorrelation of the pupil computed with FFT
        return fftshift(abs(ifft2(abs(fft2(tab)) ** 2)) / np.sum(tab))
    
    def psd(self,x0):
        """Compute the PSD model from parameters
        PSD is given in [rad²/f²] = [rad² m²]
        
        Parameters
        ----------
        x0 : numpy.array (dim=1), tuple, list
            See __call__ for more details
            
        Returns
        -------
        psd : numpy.array (dim=2)
            Array of size (Npix[0]*k, Npix[1]*k), with `k` the internal
            oversampling factor (k=1 when samp>=2)
            
        Raises
        ------
        ValueError
            If `x0` does not contain 5 or 7 elements
        """
        if len(x0)==7:
            x = x0
        elif len(x0)==5:
            # symmetric case: alpha Y = alpha X and theta = 0
            x = np.concatenate((x0[0:4],[x0[3],0],x0[4:]))
        else:
            raise ValueError("Wrong size of x0")
        
        Nx = self.Npix[0]*self._k
        Ny = self.Npix[1]*self._k
        f2D = self._freq_array(Nx,Ny,self._samp_num,self.system.D)
        F2 = f2D[0] ** 2. + f2D[1] ** 2.
        Fao = self.system.Nact/(2.0*self.system.D)
        
        # Von-Karman like term, kept only at frequencies >= Fao
        PSD = 0.0229* x[0]**(-5./3.) * ((1. / self.Lext**2.) + F2)**(-11./6.)
        PSD *= (F2 >= Fao**2.)
        
        # Constant + Moffat term, kept only at frequencies < Fao
        param = np.concatenate((x[2:],[0,0]))
        PSD += (F2 < Fao**2.) * np.abs(x[1] + moffat(f2D,param,norm=Fao))
        # Set PSD = 0 at null frequency (according to numpy fft convention)
        # The index is computed on the actual grid, which is larger than
        # `Npix` when oversampling is active
        PSD[Nx//2,Ny//2] = 0.0
        return PSD
    
    def otf(self,x0,dx=0,dy=0,_caller='user'):
        """
        See __call__ for input arguments
        Warning: result of otf will be inconsistent if undersampled!!!
        This issue is solved with oversampling + binning in __call__ but not here
        For the moment, the `_caller` keyword prevents user to misuse otf
        
        Raises
        ------
        ValueError
            If undersampled and not called from `__call__`
        """
        
        if (self._k > 1) and (_caller != 'self'):
            raise ValueError("Cannot call `Psfao.otf(...)` when undersampled (functionality not implemented yet)")
        
        PSD = self.psd(x0)
        
        Nx = self.Npix[0]*self._k
        Ny = self.Npix[1]*self._k
        
        L = self.system.D * self._samp_num
        Bg = fft2(fftshift(PSD)) / L ** 2
        Dphi = fftshift(np.real(2 * (Bg[0, 0] - Bg)))
        
        if self.diffotf:
            dlFTO = self._dlFTO(Nx,Ny,self.system.pupil,self._samp_num)
        else:
            dlFTO = 1.
        X, Y = self._shift_array(Nx,Ny)
        return np.exp(-Dphi/2.)*dlFTO*np.exp(X*dx + Y*dy)
    
    def __call__(self,x0,dx=0,dy=0):
        """
        Parameters
        ----------
        x0 : numpy.array (dim=1), tuple, list
            x[0] - Fried parameter r0 [m]
            x[1] - PSD corrected area background C [rad² m²]
            x[2] - PSD corrected area phase variance A [rad²]
            x[3] - PSD alpha X [1/m]
            x[4] - PSD alpha Y [1/m]   (not defined in symmetric case)
            x[5] - PSD theta   [rad]   (not defined in symmetric case)
            x[6] - PSD beta power law  (becomes x[4] in symmetric case)
        dx : float
            PSF X shifting [pix] (default = 0)
        dy : float
            PSF Y shifting [pix] (default = 0)
            
        Returns
        -------
        tab : numpy.ndarray (dim=2)
            The PSF computed for the given parameters
            
        Note
        ----
        The PSD integral on the corrected area is x[2]+x[1]*PI*fao²
        """
        out = np.real(fftshift(ifft2(fftshift(self.otf(x0,dx=dx,dy=dy,_caller='self')))))
        out = out/out.sum() # ensure unit energy on the field of view
        
        if self._k==1:
            return out
        # undersampled case: bin the oversampled PSF
        return binning(out,int(self._k))
        
    def tofits(self,param,filename,*args,keys=None,**kwargs):
        """Compute the PSF and write it to a FITS file (overwriting it)
        
        Parameters
        ----------
        param : list, tuple, numpy.ndarray
            Parameters of the PSF (see `__call__`)
        filename : str
            Name of the FITS file to write
        *args, **kwargs :
            Extra arguments forwarded to `__call__`
        keys : list of str
            Header keyword for each parameter
            Default: keywords and comments are chosen according to the
            number of parameters
        """
        # comments are only available with the default keys
        keys_comment = None
        if keys is None:
            if len(param)==5:
                keys = ["R0","CST","SIGMA2","ALPHA","BETA"]
                keys_comment = ["Fried parameter [m]",
                                "PSD AO area constant C [rad2]",
                                "PSD AO area Moffat variance A [rad2]",
                                "PSD AO area Moffat alpha [1/m]",
                                "PSD AO area Moffat beta"]
            else: # if not 5, then equals 7
                keys = ["R0","CST","SIGMA2","ALPHA_X","ALPHA_Y","THETA","BETA"]
                keys_comment = ["Fried parameter [m]",
                                "PSD AO area constant C [rad2]",
                                "PSD AO area Moffat variance A [rad2]",
                                "PSD AO area Moffat alpha X [1/m]",
                                "PSD AO area Moffat alpha Y [1/m]",
                                "PSD AO area Moffat theta [rad]",
                                "PSD AO area Moffat beta"]
        
        # redefine tofits() because extra hdr is required
        psf = self.__call__(param,*args,**kwargs)
        hdr = self._getfitshdr(param,keys=keys,keys_comment=keys_comment)
        
        hdr["HIERARCH SYSTEM"] = (self.system.name,"System name")
        hdr["HIERARCH SAMP"] = (self.samp,"Sampling (eg. 2 for Shannon)")
        hdr["HIERARCH LEXT"] = (self.Lext,"Von-Karman outer scale")
        hdr["HIERARCH DIFFOTF"] = (self.diffotf,"Is diffraction OTF enabled")
        
        hdu = fits.PrimaryHDU(psf, hdr)
        hdu.writeto(filename, overwrite=True)