From 8da681e038b1da35236372bbcb25d30d76a1777f Mon Sep 17 00:00:00 2001
From: Alexis Lau <alexis.lau@lam.fr>
Date: Tue, 20 Jun 2023 17:12:05 +0200
Subject: [PATCH] Add jitter, phasemask and custom pupil from .ini

---
 maoppy/instrument.py | 71 +++++++++++++++++++++++++++++++++++---------
 1 file changed, 57 insertions(+), 14 deletions(-)

diff --git a/maoppy/instrument.py b/maoppy/instrument.py
index ee53877..2d32f8a 100644
--- a/maoppy/instrument.py
+++ b/maoppy/instrument.py
@@ -16,6 +16,8 @@ from scipy.interpolate import interp2d
 from maoppy.utils import circarr as _circarr
 from maoppy.utils import RAD2ARCSEC as _RAD2ARCSEC
 
+import maoppy.utils
+
 #%% INSTRUMENT CLASS
 class Instrument:
     """Represents an optical system (telescope to detector)
@@ -40,7 +42,7 @@ class Instrument:
         Pixel binning factor (default=1)
     """
 
-    def __init__(self, D=None, occ=0., res=None, Nact=0, gain=1., ron=1.):
+    def __init__(self, D=None, occ=0., res=None, Nact=0, gain=1., ron=1., custom_pupil = False, phasemask_enable = False, phasemask_path = None, jitter = None):
         
         if D is None:
             raise ValueError("Please enter keyword `D` to set Instrument's aperture diameter")
@@ -65,12 +67,14 @@ class Instrument:
         
         self.name = "default" # unique identifier name
         self.fullname = "MAOPPY Instrument" # human readable name
+        self.custom_pupil = custom_pupil
         
         # phasemask (not tested yet)
-        self.phasemask_enable = False
-        self.phasemask_path = None
+        self.phasemask_enable = phasemask_enable
+        self.phasemask_path = phasemask_path
         self.phasemask_shift = (0.0,0.0)
         self._phasemask = None
+        self.jitter = jitter
         
     def __str__(self):
         s  = "---------------------------------\n" 
@@ -102,13 +106,11 @@ class Instrument:
     
     def pupil(self,shape,wvl=None,samp=None):
         """Returns the 2D array of the pupil transmission function (complex data)"""
-        Dpix = min(shape)/2
-        pup = _circarr(shape)
         if self.phasemask_enable:
             if self._phasemask is None:
                 if self.phasemask_path is None:
                     raise ValueError('phasemask_path must be defined')
-                p = fits.open(self.phasemask_path)[0].data * 1e-9 # fits data in nm, converted here to meter
+                p = fits.getdata(self.phasemask_path)* 1e-9 # fits data in nm, converted here to meter
                 x = np.arange(p.shape[0])/p.shape[0]
                 y = np.arange(p.shape[1])/p.shape[1]
                 self._phasemask = interp2d(x,y,p)
@@ -117,11 +119,19 @@ class Instrument:
             y = np.arange(shape[1])/shape[1] - cy/shape[1]
             if wvl is None:
                 wvl = self.wvl(samp) # samp must be defined if wvl is None
-            wf = np.exp(2j*np.pi/wvl*self._phasemask(x,y))
+            wf = self._phasemask(x,y)
+            # wf = np.exp(2j*np.pi/wvl*self._phasemask(x,y))
         else:
             wf = 1.0 + 0j # complex type for output array, even if real data
-        return (pup < Dpix) * (pup >= Dpix*self.occ) * wf
-    
+            
+        if self.custom_pupil is False: 
+            Dpix = min(shape)/2
+            pup = _circarr(shape)
+            return (pup < Dpix) * (pup >= Dpix*self.occ) * wf
+        else: 
+            pup = fits.getdata(self.custom_pupil)
+            return pup, wf
+
     def samp(self,wvl):
         """Returns sampling value for the given wavelength"""
         return wvl/(self.resolution_rad*self.D)
@@ -130,22 +140,21 @@ class Instrument:
         """Returns wavelength for the given sampling"""
         return samp*(self.resolution_rad*self.D)
 
-
 #%% LOAD INSTRUMENT INSTANCES (make them attributes of this module)
 # this might be clumsy, should I make something like this:
 # from maoppy.instrument import load_instrument
 # zimpol = load_instrument("zimpol")
 
+# BUG - This would not work for pyinstaller because of the relative path issue
 def _get_data_folder():
-    folder = os.path.abspath(__file__)
+    folder = os.path.abspath(maoppy.utils.__file__)
     folder = os.sep.join(folder.split(os.sep)[0:-1])+os.sep+'data'+os.sep
+    print(folder)
     return folder
 
-
 def _get_all_ini(pth):
     return [f for f in os.listdir(pth) if f.endswith('.ini')]
 
-
 def load_ini(pth):
     """Create an Instrument instance from a path to a .ini file"""
     config = ConfigParser()
@@ -170,6 +179,7 @@ def load_ini(pth):
         pass
     # [camera]
     res_mas = float(config['camera']['res_mas'])
+
     # [filters]
     filters = {}
     if 'filters' in config.keys():
@@ -177,9 +187,38 @@ def load_ini(pth):
             s = config['filters'][filt]
             wvl_central,width = s[1:-1].split(',')
             filters[filt] = (float(wvl_central)*1e-9,float(width)*1e-9)
+
     # Make instrument
     res_rad = res_mas*1e-3 / _RAD2ARCSEC
-    instru = Instrument(D=d,occ=occ,res=res_rad,Nact=nact)
+
+    if 'pupil' in config.keys():
+        print('Loading a custom pupil')
+        custom_pupil = config['pupil']['pupil_path']
+    else: 
+        print('Normal pupil')
+        custom_pupil = False
+
+    # [phase_mask]
+    if 'phase_mask' in config.keys():
+        print('Loading a custom phase_mask')
+        phasemask_enable = True
+        phasemask_path = config['phase_mask']['path']
+    else: 
+        # phasemask (not tested yet)
+        phasemask_enable = False
+        phasemask_path = None
+
+    if 'jitter' in config.keys():
+        jitter = float(config['jitter']['jitter_mas']) *1e-3 / _RAD2ARCSEC
+    else: 
+        jitter = None
+
+    instru = Instrument(D=d,occ=occ,res=res_rad,Nact=nact, 
+        custom_pupil = custom_pupil,
+        phasemask_path = phasemask_path, 
+        phasemask_enable = phasemask_enable,
+        jitter= jitter)
+
     instru.name = tag
     instru.fullname = name
     instru.filters = filters
@@ -189,10 +228,14 @@ def load_ini(pth):
     
 
 
+# BUG - disable for now to make sure it would work for gui 
+
 _this_module = sys.modules[__name__]
+
 _d = _get_data_folder()
 for _f in _get_all_ini(_d):
     _instru = load_ini(_d+_f)
+    print(_d+_f)
     _n = _instru.name.lower().replace(" ","_") # format name
     setattr(_this_module, _n, _instru)
     
-- 
GitLab