Skip to content
Snippets Groups Projects
psfmodel.py 19.7 KiB
Newer Older
  • Learn to ignore specific revisions
  • #!/usr/bin/env python3
    # -*- coding: utf-8 -*-
    """
    Created on Mon May 27 17:31:18 2019
    
    @author: rfetick
    """
    
    
    FETICK Romain's avatar
    FETICK Romain committed
    9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 
531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578
    import numpy as np
    from scipy.optimize import least_squares
    from astropy.io import fits
    import time
    from numpy.fft import fft2, fftshift, ifft2
    from functools import lru_cache
    from paompy.config import _EPSILON
    from paompy.utils import binning
    
    #%% FITTING FUNCTION
    def lsq_flux_bck(model, data, weights, background=True, positive_bck=False):
        """Solve analytically for the flux and background minimizing
            LS = SUM_pixels { weights*(flux*model + bck - data)² }
        
        Parameters
        ----------
        model: numpy.ndarray
            Model image
        data: numpy.ndarray
            Observed image
        weights: numpy.ndarray
            Per-pixel least-square weights
        
        Keywords
        --------
        background: bool
            Fit a constant background together with the flux (default: True)
        positive_bck : bool
            If the fitted background is negative, discard it and refit the
            flux alone with a null background (default: False)
        
        Returns
        -------
        amp, bck : float, float
            Flux and background at the least-square optimum
        """
        sum_w = np.sum(weights)
        sum_mw = np.sum(model * weights)
        sum_mwd = np.sum(model * weights * data)
        sum_m2w = np.sum(weights * (model ** 2))
        sum_wd = np.sum(weights * data)
    
        if background:
            # 2x2 normal equations solved by Cramer's rule
            delta = sum_mw ** 2 - sum_w * sum_m2w
            bck = 1. / delta * (-sum_m2w * sum_wd + sum_mw * sum_mwd)
            if not (positive_bck and bck < 0):
                amp = 1. / delta * (sum_mw * sum_wd - sum_w * sum_mwd)
                return amp, bck
    
        # No background requested, or a negative one rejected: flux-only solution
        return sum_mwd / sum_m2w, 0.0
    
    def psffit(psf,Model,x0,weights=None,dxdy=(0.,0.),flux_bck=(True,True),
               positive_bck=False,fixed=None,**kwargs):
        """Fit a PSF with a parametric model solving the least-square problem
           epsilon(x) = SUM_pixel { weights * (amp * Model(x) + bck - psf)² }
        
        Parameters
        ----------
        psf : numpy.ndarray
            The experimental image to be fitted
        Model : class
            The class representing the fitting model
        x0 : tuple, list, numpy.ndarray
            Initial guess for parameters (converted to a float array)
        weights : numpy.ndarray
            Least-square weighting matrix (same shape as `psf`)
            Default: uniform weighting
        dxdy : tuple of two floats
            Eventual guess on PSF shifting
        flux_bck : tuple of two bool
            Only background can be activate/inactivated
            Flux is always activated (sorry!)
        positive_bck : bool
            Force background to be positive or null
        fixed : numpy.ndarray
            Boolean-like mask with the length of `x0`; truthy entries are kept
            at their initial value (default: None)
        **kwargs :
            All keywords used to instantiate your `Model`
        
        Returns
        -------
        out.x : numpy.array
                Parameters at optimum
           .dxdy : tuple of 2 floats
               PSF shift at optimum
           .flux_bck : tuple of two floats
               Estimated image flux and background
           .psf : numpy.ndarray (dim=2)
               Image of the PSF model at optimum
           .success : bool
               Minimization success
           .status : int
               Minimization status (see scipy doc)
           .message : string
               Human readable minimization status
           .active_mask : numpy.array
               Saturated bounds
           .nfev : int
               Number of function evaluations
           .cost : float
               Value of cost function at optimum
        
        Raises
        ------
        ValueError
            If `weights` does not have the shape of `psf`, or if `fixed` does
            not have the length of `x0`
        """
        # Work on a float copy: fixed parameters are restored from a copy of x0
        # into which optimizer floats are written, which would be truncated if
        # x0 had been given as integers
        x0 = np.asarray(x0, dtype=float)
        Model_inst = Model(np.shape(psf),**kwargs)
        if weights is None:
            weights = np.ones_like(psf)
        elif np.shape(weights) != np.shape(psf):
            # Compare full shapes: len() only checks the first axis and would let
            # e.g. a (N,1) weight array broadcast silently
            raise ValueError("Keyword `weights` must have same shape as `psf`")
        sqW = np.sqrt(weights)
        
        if fixed is not None:
            if len(fixed)!=len(x0):
                raise ValueError("When defined, `fixed` must be same size as `x0`")
            # Indices of the parameters left free for the minimizer
            INDFREE = np.array([i for i, f in enumerate(fixed) if not f], dtype=int)
        
        def input2mini(x,dxdy):
            """Pack user parameters (free ones only) and shift into the minimizer vector"""
            xfree = x if fixed is None else np.take(x,INDFREE)
            return np.concatenate((xfree,dxdy))
        
        def mini2input(y):
            """Unpack the minimizer vector into (all user parameters, shift)"""
            if fixed is None:
                xall = y[:-2]
            else:
                xall = np.copy(x0)
                xall[INDFREE] = y[:-2]
            return (xall,y[-2:])
        
        def get_bound(inst):
            """Model bounds restricted to free parameters, plus unbounded shifts"""
            b_low = inst.bounds[0]
            b_up = inst.bounds[1]
            if fixed is not None:
                b_low = np.take(b_low,INDFREE)
                b_up = np.take(b_up,INDFREE)
            b_low = np.concatenate((b_low,[-np.inf,-np.inf]))
            b_up = np.concatenate((b_up,[np.inf,np.inf]))
            return (b_low,b_up)
        
        class CostClass(object):
            """Weighted residual vector; prints a '-' every 3 evaluations"""
            def __init__(self):
                self.iter = 0
            def __call__(self,y):
                if (self.iter%3) == 0:
                    print("-",end="")
                self.iter += 1
                
                x, dxdy = mini2input(y)
                m = Model_inst(x,dx=dxdy[0],dy=dxdy[1])
                # Flux and background are solved analytically at each evaluation
                amp, bck = lsq_flux_bck(m, psf, weights, background=flux_bck[1], positive_bck=positive_bck)
                return np.reshape(sqW * (amp * m + bck - psf), np.size(psf))
        
        cost = CostClass()
        
        result = least_squares(cost, input2mini(x0,dxdy), bounds=get_bound(Model_inst))
        
        print("") #finish line of "-"
        
        result.x, result.dxdy = mini2input(result.x) #split output between x and dxdy
    
        m = Model_inst(result.x,dx=result.dxdy[0],dy=result.dxdy[1])
        amp, bck = lsq_flux_bck(m, psf, weights, background=flux_bck[1], positive_bck=positive_bck)
        
        result.flux_bck = (amp,bck)
        result.psf = m    
        return result
    
    #%% CLASS PARAMETRIC PSF AND ITS SUBCLASSES
    class ParametricPSF(object):
        """Super-class defining parametric PSFs
        Not to be instantiated, only serves as a referent for subclasses
        """
        
        def __init__(self,Npix):
            """
            Parameters
            ----------
            Npix : tuple of two elements
                Model X and Y pixel size when called
            
            Raises
            ------
            TypeError
                If `Npix` is not a tuple
            """
            if not isinstance(Npix, tuple):
                raise TypeError("Argument `Npix` must be a tuple")
            self.Npix = Npix
            # Unbounded by default; subclasses typically overwrite the bounds
            self.bounds = (-np.inf,np.inf)
        
        def __repr__(self):
            return "ParametricPSF of size (%u,%u)"%self.Npix
        
        def __call__(self,*args,**kwargs):
            raise ValueError("ParametricPSF is not made to be instantiated. Better use its subclasses")
        
        def otf(self,*args,**kwargs):
            """Return the Optical Transfer Function (OTF)
            
            All arguments are forwarded to `__call__`; the OTF is the centred
            (fftshift-ed) 2D FFT of the resulting PSF image.
            """
            # Arguments must be unpacked: passing `args` and `kwargs` as two
            # positional values would hand a tuple and a dict to the model
            psf = self.__call__(*args,**kwargs)
            return fftshift(fft2(psf))
        
        def mtf(self,*args,**kwargs):
            """Return the Modulation Transfer Function (MTF), i.e. |OTF|"""
            return np.abs(self.otf(*args,**kwargs))
        
        def tofits(self,param,filename,*args,keys=None,keys_comment=None,**kwargs):
            """Write the PSF computed from `param` to a FITS file (overwritten)
            
            Extra `args`/`kwargs` are forwarded to `__call__`; `keys` and
            `keys_comment` name and comment the parameters in the header.
            """
            psf = self.__call__(param,*args,**kwargs)
            hdr = self._getfitshdr(param,keys=keys,keys_comment=keys_comment)
            hdu = fits.PrimaryHDU(psf, hdr)
            hdu.writeto(filename, overwrite=True)
            
        def _getfitshdr(self,param,keys=None,keys_comment=None):
            """Build a FITS header storing each parameter under "PARAM <key>"
            
            Raises
            ------
            ValueError
                If `keys` or `keys_comment` do not have the length of `param`
            """
            # NOTE(review): with the default keys the header entries read
            # "PARAM PARAM 1", ... since the "PARAM " prefix is added twice
            if keys is None:
                keys = ["PARAM %u"%(i+1) for i in range(len(param))]
            if len(keys)!=len(param):
                raise ValueError("When defined, `keys` must be same size as `param`")
            if keys_comment is not None:
                if len(keys_comment)!=len(param):
                    raise ValueError("When defined, `keys_comment` must be same size as `param`")
            hdr = fits.Header()
            
            hdr["HIERARCH ORIGIN"] = "PAOMPY automatic header"
            hdr["HIERARCH CREATION"] = (time.ctime(),"Date of file creation")
            for i in range(len(param)):
                if keys_comment is None:
                    hdr["HIERARCH PARAM "+keys[i]] = param[i]
                else:
                    hdr["HIERARCH PARAM "+keys[i]] = (param[i],keys_comment[i])
            return hdr
    
    
    class ConstantPSF(ParametricPSF):
        """Wrap a fixed 2D PSF image in the ParametricPSF interface
        
        The model has no free parameter: whatever arguments `__call__`
        receives, the stored image is returned unchanged.
        """
        def __init__(self,image_psf):
            """
            Parameters
            ----------
            image_psf : array-like
                The 2D PSF image returned on every call
            """
            self.image_psf = image_psf
            super().__init__(np.shape(image_psf))
            self.bounds = ()  # nothing to fit
        
        def __call__(self,*args,**kwargs):
            # Arguments are accepted for interface compatibility and ignored
            return self.image_psf
        
    
    
    def moffat(XY, param, norm=None):
        """
        Compute a Moffat function on a meshgrid
        moff = A * Enorm * (1+u)^(-beta)
        with `u` the quadratic coordinates in the shifted and rotated frame
        
        Parameters
        ----------
        XY : numpy.ndarray (dim=2)
            The (X,Y) meshgrid with X = XY[0] and Y = XY[1]
        param : list, tuple, numpy.ndarray (len=7)
            param[0] - Amplitude
            param[1] - Alpha X
            param[2] - Alpha Y
            param[3] - Theta
            param[4] - Beta
            param[5] - Center X
            param[6] - Center Y
            
        Keywords
        --------
        norm : None, np.inf, float (>0), int (>0)
            Radius for energy normalization
            None      - No energy normalization
                        Enorm = 1.0
            np.inf    - Total energy normalization (on the whole X-Y plane)
                        Enorm = (beta-1)/(pi*ax*ay)
            float,int - Energy normalization up to the radius defined by this value
                        Enorm = (beta-1)/(pi*ax*ay)/(1-(1+(R**2)/(ax*ay))**(1-beta))
        
        Returns
        -------
        The 2D Moffat array
        
        Raises
        ------
        ValueError
            If `param` does not have 7 elements, if beta<=1 with total energy
            normalization, if beta==1 with finite-radius normalization, or if
            the finite normalization radius is not strictly positive
        """
        if len(param)!=7:
            raise ValueError("Parameter `param` must contain exactly 7 elements, but input has %u elements"%len(param))
        
        amp, ax, ay, theta, beta, xc, yc = param
        
        c = np.cos(theta)
        s = np.sin(theta)
        s2 = np.sin(2.0 * theta)
    
        # Quadratic form of the rotated ellipse.
        # NOTE: with this sign of the cross term, the `Alpha X` axis lies at
        # angle -theta from the XY[0] axis
        Rxx = (c / ax) ** 2 + (s / ay) ** 2
        Ryy = (c / ay) ** 2 + (s / ax) ** 2
        Rxy = s2 / ay ** 2 - s2 / ax ** 2
        
        Xc = XY[0] - xc
        Yc = XY[1] - yc
        u = Rxx * Xc**2 + Rxy * Xc * Yc + Ryy * Yc**2
        
        if norm is None:
            Enorm = 1
        elif norm == np.inf:
            if beta<=1:
                raise ValueError("Cannot compute Moffat energy for param[4]<=1")
            Enorm = (beta-1) / (np.pi*ax*ay)
        else:
            if beta==1:
                raise ValueError("Energy computation for param[4]=1.0 not implemented yet. Sorry!")
            if norm <= 0:
                # A null radius encloses no energy: the normalization would be infinite
                raise ValueError("Normalization radius `norm` must be strictly positive")
            # Energy enclosed within radius `norm`, divided out of the total-energy factor
            Enorm = (beta-1) / (np.pi*ax*ay)
            Enorm = Enorm / (1 - (1 + (norm**2) / (ax*ay))**(1-beta))
        
        return Enorm * amp * (1. + u) ** (-beta)
    
    
    class Moffat(ParametricPSF):
        """Elliptical Moffat PSF model built on the `moffat` function
        
        The fitted parameters are (alpha_x, alpha_y, theta, beta); the amplitude
        is fixed to 1 and the centre offset is given by the (dx, dy) shift.
        """
        
        def __init__(self,Npix,norm=np.inf):
            """
            Parameters
            ----------
            Npix : tuple of two int
                Model pixel size (rows, columns) when called
            norm : None, np.inf, float
                Energy normalization forwarded to `moffat` (default: np.inf)
            """
            # ParametricPSF.__init__ is not called (it rejects non-tuple sizes);
            # Npix is stored as a tuple so the cached `_XY` can hash it and
            # `__repr__` can format it
            self.Npix = tuple(Npix)
            self.norm = norm
            bounds_down = [_EPSILON,_EPSILON,-np.inf,1+_EPSILON]
            bounds_up = [np.inf for i in range(4)]
            self.bounds = (bounds_down,bounds_up)
        
        @lru_cache(maxsize=5)
        def _XY(self,Npix):
            """Centred coordinate grids, shape (2, Npix[0], Npix[1])
            
            Each axis is offset by its own half-size (Npix[i]//2, the numpy fft
            centre convention). The array is built as float so the offset is
            never truncated. The cached array is shared between calls and must
            not be modified in place.
            """
            YX = np.mgrid[0:Npix[0],0:Npix[1]].astype(float)
            YX[0] -= Npix[0]//2
            YX[1] -= Npix[1]//2
            return YX
        
        def __call__(self,x,dx=0,dy=0):
            """
            Parameters
            ----------
            x : list, tuple, numpy.ndarray (len=4)
                x[0] - Alpha X
                x[1] - Alpha Y
                x[2] - Theta
                x[3] - Beta
            dx, dy : float
                Centre offset along the first and second grid axes [pix]
            
            Returns
            -------
            numpy.ndarray of shape `Npix`
            """
            # Unit amplitude, then the 4 shape parameters, then the centre
            y = np.concatenate(([1],x,[dx,dy]))
            return moffat(self._XY(self.Npix),y,norm=self.norm)
        
        def tofits(self,param,filename,*args,keys=("ALPHA_X","ALPHA_Y","THETA","BETA"),**kwargs):
            """Write the PSF to a FITS file, naming parameters after the Moffat model"""
            # Immutable tuple default: avoids the shared-mutable-default pitfall
            super(Moffat,self).tofits(param,filename,*args,keys=keys,**kwargs)
    
    
    class Psfao(ParametricPSF):
        """PSF model based on a parametrization of the phase PSD
        See documentation of methods "__init__" and "__call__"
        
        """
        
        def __init__(self,Npix,system=None,Lext=10.,samp=None,symmetric=False,diffotf=True):
            """
            Parameters
            ----------
            Npix : tuple, list or numpy.ndarray of two even ints
                Size of output PSF
            system : OpticalSystem
                Optical system for this PSF (mandatory); its `D`, `Nact`,
                `pupil` and `name` attributes are used
            Lext : float
                Von-Karman external scale (default = 10 m)
                Useless if Fao >> 1/Lext
            samp : float
                Sampling at the observation wavelength (mandatory)
            symmetric : bool
                Use the 5-parameter symmetric model instead of the 7-parameter
                one (default=False)
            diffotf : bool
                Enable/disable diffraction OTF for PSF computation
                (default=True)
            
            Raises
            ------
            ValueError
                If `Npix` is malformed, or `system` or `samp` is missing
            """
            # ParametricPSF.__init__ is not called: it only accepts tuples,
            # whereas lists and arrays are accepted here
            if not isinstance(Npix,(tuple,list,np.ndarray)):
                raise ValueError("Npix must be a tuple, list or numpy.ndarray")
            if len(Npix)!=2:
                raise ValueError("Npix must be of length = 2")
            if (Npix[0]%2) or (Npix[1]%2):
                raise ValueError("Each Npix component must be even")
            self.Npix = Npix
            if system is None:
                raise ValueError("Keyword `system` must be defined")
            if samp is None:
                raise ValueError("Keyword `samp` must be defined")
            self.system = system
            self.Lext = Lext
            self.samp = samp
            self.symmetric = symmetric
            self.diffotf = diffotf
        
        @property
        def symmetric(self):
            return self._symmetric
        
        @symmetric.setter
        def symmetric(self,value):
            # Setting the model type also selects the matching parameter bounds
            self._symmetric = value
            if not value:
                bounds_down = [_EPSILON,0,_EPSILON,_EPSILON,_EPSILON,-np.inf,1+_EPSILON]
                bounds_up = [np.inf for i in range(7)]
            else:
                bounds_down = [_EPSILON,0,_EPSILON,_EPSILON,1+_EPSILON]
                bounds_up = [np.inf for i in range(5)]
            self.bounds = (bounds_down,bounds_up)
        
        @property
        def samp(self):
            return self._samp
        
        @samp.setter
        def samp(self,value):
            # Manage cases of undersampling: below 2, computations run on a grid
            # oversampled by the integer factor _k (so _samp_num >= 2), and
            # __call__ bins the result back by _k
            self._samp = value
            if value >=2:
                self._samp_num = value
                self._k = 1
            else:
                self._k = int(np.ceil(2.0/value))
                self._samp_num = self._k * value
        
        @lru_cache(maxsize=2)
        def _freq_array(self,Nx,Ny,samp,D):
            """
            Returns
            -------
            tab - numpy.array (dim=3)
                3D array of frequencies [1/m]
            """
            pix2freq = 1.0/(D*samp)
            f2D = np.mgrid[0:Nx, 0:Ny].astype(float)
            #null frequency at [Nx//2,Ny//2] according to numpy fft convention
            f2D[0] -= Nx//2
            f2D[1] -= Ny//2
            return f2D * pix2freq
        
        @lru_cache(maxsize=2)
        def _shift_array(self,Nx,Ny):
            """Complex phase ramps (X along columns, Y along rows) for sub-pixel shifts"""
            Y, X = np.mgrid[0:Nx,0:Ny].astype(float)
            # X spans the Ny columns and Y the Nx rows: each ramp is centred and
            # scaled with the size of its own axis (matters for non-square grids)
            X = (X-Ny/2) * 2*np.pi*1j/Ny
            Y = (Y-Nx/2) * 2*np.pi*1j/Nx
            return X, Y
        
        @lru_cache(maxsize=5)
        def _dlFTO(self,Nx,Ny,pupfct,samp):
            """Diffraction OTF: autocorrelation of the pupil, divided by the pupil sum"""
            # samp as a tuple is not ready yet, don't use it
            if isinstance(samp,tuple):
                dlFTO = np.zeros((Nx,Ny,len(samp)))
                for i in range(len(samp)):
                    dlFTO[...,i] = self._dlFTO(Nx,Ny,pupfct, samp[i])
                return np.mean(dlFTO,axis=2)
            
            NpupX = np.ceil(Nx/samp)
            NpupY = np.ceil(Ny/samp)
            # builtin `complex` instead of `np.complex` (alias removed in NumPy 1.24)
            tab = np.zeros((Nx, Ny), dtype=complex)
            tab[0:int(NpupX), 0:int(NpupY)] = pupfct((NpupX,NpupY),samp=samp)
            return fftshift(abs(ifft2(abs(fft2(tab)) ** 2)) / np.sum(tab))
        
        def psd(self,x0):
            """Compute the PSD model from parameters
            PSD is given in [rad²/f²] = [rad² m²]
            
            Parameters
            ----------
            x0 : numpy.array (dim=1), tuple, list
                See __call__ for more details
                
            Returns
            -------
            psd : numpy.array (dim=2)
                Computed on the (possibly oversampled) grid of size Npix*_k
            
            Raises
            ------
            ValueError
                If `x0` has neither 5 nor 7 elements
            """
            if len(x0)==7:
                x = x0
            elif len(x0)==5:
                # Symmetric case: alpha_y = alpha_x and theta = 0
                x = np.concatenate((x0[0:4],[x0[3],0],x0[4:]))
            else:
                raise ValueError("Wrong size of x0")
            
            Nx = self.Npix[0]*self._k
            Ny = self.Npix[1]*self._k
            f2D = self._freq_array(Nx,Ny,self._samp_num,self.system.D)
            F2 = f2D[0] ** 2. + f2D[1] ** 2.
            Fao = self.system.Nact/(2.0*self.system.D)
            
            # Von-Karman turbulent PSD outside the AO-corrected area
            PSD = 0.0229* x[0]**(-5./3.) * ((1. / self.Lext**2.) + F2)**(-11./6.)
            PSD *= (F2 >= Fao**2.)
            
            # Constant + Moffat inside the corrected area
            param = np.concatenate((x[2:],[0,0]))
            PSD += (F2 < Fao**2.) * np.abs(x[1] + moffat(f2D,param,norm=Fao))
            # Set PSD = 0 at null frequency (according to numpy fft convention),
            # located on the computation grid of size (Nx, Ny), not Npix
            PSD[Nx//2,Ny//2] = 0.0
            return PSD
        
        def otf(self,x0,dx=0,dy=0,_caller='user'):
            """
            See __call__ for input arguments
            Warning: result of otf will be inconsistent if undersampled!!!
            This issue is solved with oversampling + binning in __call__ but not here
            For the moment, the `_caller` keyword prevents user to misuse otf
            
            Raises
            ------
            ValueError
                If called by the user while undersampled
            """
            
            if (self._k > 1) and (_caller != 'self'):
                raise ValueError("Cannot call `Psfao.otf(...)` when undersampled (functionality not implemented yet)")
            
            PSD = self.psd(x0)
            Nx = self.Npix[0]*self._k
            Ny = self.Npix[1]*self._k
            
            # Phase structure function from the PSD
            L = self.system.D * self._samp_num
            Bg = fft2(fftshift(PSD)) / L ** 2
            Dphi = fftshift(np.real(2 * (Bg[0, 0] - Bg)))
            
            if self.diffotf:
                dlFTO = self._dlFTO(Nx,Ny,self.system.pupil,self._samp_num)
            else:
                dlFTO = 1.
            X, Y = self._shift_array(Nx,Ny)
            return np.exp(-Dphi/2.)*dlFTO*np.exp(X*dx + Y*dy)
        
        def __call__(self,x0,dx=0,dy=0):
            """
            Parameters
            ----------
            x0 : numpy.array (dim=1), tuple, list
                x[0] - Fried parameter r0 [m]
                x[1] - PSD corrected area background C [rad² m²]
                x[2] - PSD corrected area phase variance A [rad²]
                x[3] - PSD alpha X [1/m]
                x[4] - PSD alpha Y [1/m]   (not defined in symmetric case)
                x[5] - PSD theta   [rad]   (not defined in symmetric case)
                x[6] - PSD beta power law  (becomes x[4] in symmetric case)
            dx : float
                PSF X shifting [pix] (default = 0)
            dy : float
                PSF Y shifting [pix] (default = 0)
                
            Returns
            -------
            tab : numpy.ndarray (dim=2)
                The PSF computed for the given parameters
                
            Note
            ----
            The PSD integral on the corrected area is x[2]+x[1]*PI*fao²
            """
            out = np.real(fftshift(ifft2(fftshift(self.otf(x0,dx=dx,dy=dy,_caller='self')))))
            out = out/out.sum() # ensure unit energy on the field of view
            
            if self._k==1:
                return out
            else:
                # Undersampled case: bin the oversampled PSF back to Npix
                return binning(out,int(self._k))
            
        def tofits(self,param,filename,*args,keys=None,keys_comment=None,**kwargs):
            """Write the PSF computed from `param` to a FITS file (overwritten)
            
            When `keys` is None, header keys default to the Psfao parameter
            names (5 for the symmetric model, 7 otherwise) with explanatory
            comments. `keys_comment` may be given to override or supply the
            comments; if `keys` is given without it, no comments are written.
            System name, sampling, outer scale and diffraction flag are added.
            """
            if keys is None:
                if len(param)==5:
                    keys = ["R0","CST","SIGMA2","ALPHA","BETA"]
                    default_comment = ["Fried parameter [m]",
                                       "PSD AO area constant C [rad2]",
                                       "PSD AO area Moffat variance A [rad2]",
                                       "PSD AO area Moffat alpha [1/m]",
                                       "PSD AO area Moffat beta"]
                else: # if not 5, then equals 7
                    keys = ["R0","CST","SIGMA2","ALPHA_X","ALPHA_Y","THETA","BETA"]
                    default_comment = ["Fried parameter [m]",
                                       "PSD AO area constant C [rad2]",
                                       "PSD AO area Moffat variance A [rad2]",
                                       "PSD AO area Moffat alpha X [1/m]",
                                       "PSD AO area Moffat alpha Y [1/m]",
                                       "PSD AO area Moffat theta [rad]",
                                       "PSD AO area Moffat beta"]
                if keys_comment is None:
                    keys_comment = default_comment
            
            # redefine tofits() because extra hdr is required
            psf = self.__call__(param,*args,**kwargs)
            hdr = self._getfitshdr(param,keys=keys,keys_comment=keys_comment)
            
            hdr["HIERARCH SYSTEM"] = (self.system.name,"System name")
            hdr["HIERARCH SAMP"] = (self.samp,"Sampling (eg. 2 for Shannon)")
            hdr["HIERARCH LEXT"] = (self.Lext,"Von-Karman outer scale")
            hdr["HIERARCH DIFFOTF"] = (self.diffotf,"Is diffraction OTF enabled")
            
            hdu = fits.PrimaryHDU(psf, hdr)
            hdu.writeto(filename, overwrite=True)