test_py_unsio.py 6.32 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#!/usr/bin/python

#
# This program tests the unsio library by reading a file, saving it in different
# output file formats (gadget2, gadget3, nemo), reading each one back and
# comparing all arrays with the original ones
#
from __future__ import print_function    
import sys
import argparse
import numpy as np                # arrays are treated as numpy arrays

import os.path
#dirname, filename = os.path.split(os.path.abspath(__file__))
#sys.path.append(dirname+'../modules/')  # trick to find modules directory
LAMBERT Jean-charles's avatar
LAMBERT Jean-charles committed
15
from unsio import *
16 17
import copy
import tempfile
LAMBERT Jean-charles's avatar
LAMBERT Jean-charles committed
18 19 20 21
try:
    # Optional: IPython's embed() is only referenced from commented-out debug
    # lines, so a missing or broken IPython must not stop the test.
    from IPython import embed
except Exception:
    # Best-effort import; unlike a bare except this does not swallow
    # KeyboardInterrupt / SystemExit.
    pass
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206

class snap:
    """Plain record holding the fields of one loaded snapshot.

    Every data field defaults to None at class level and is overwritten on
    the instance once a frame has been read; ``interface`` defaults to "".
    """
    # interface type reported by the reader
    interface = ""
    # scalar values
    time   = None
    nbody  = None
    # per-particle arrays
    pos    = None
    vel    = None
    mass   = None
    hsml   = None
    rho    = None
    age    = None
    metal  = None
    id     = None
 
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# commandLine, parse the command line 
def commandLine():
    """Build the argument parser, parse sys.argv and run process() on it.

    Exactly one of --float / --double must be given. --out is accepted,
    but process() as shown here does not read it.
    """
    parser = argparse.ArgumentParser(
        description="test unsio library",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # positional arguments
    parser.add_argument('simname', help='Simulation name')
    parser.add_argument('component', help='component name')
    # precision selection: one and only one flag is mandatory
    precision = parser.add_mutually_exclusive_group(required=True)
    for flag, text in (('--float', 'floating format operation'),
                       ('--double', 'double format operation')):
        precision.add_argument(flag, help=text, action='store_true')
    parser.add_argument('--out', help='save output file name ?', default="")

    # hand the parsed options to the main function
    process(parser.parse_args())

# -----------------------------------------------------
def readSnap(simname, comp, single):
    """Load the first frame of *simname* into a new snap object.

    Parameters
    ----------
    simname : str
        Simulation/file name handed to the unsio reader.
    comp : str
        Component name used both to open the reader and to query arrays.
    single : bool
        True selects the float reader (CunsIn), False the double one (CunsInD).

    Returns
    -------
    tuple
        ``(True, snap)`` when a frame was loaded, ``(False, None)`` otherwise.
        Always a pair, so callers can unpack ``ok, s = readSnap(...)``.
    """
    verbose = False

    # Create a UNSIO object
    if single:
        print("Float format object")
        uns = CunsIn(simname, comp, "all", verbose)
    else:
        print("Double format object")
        uns = CunsInD(simname, comp, "all", verbose)

    print("simname=", simname, file=sys.stderr)

    try:
        # load frame
        if not uns.nextFrame(""):
            print("Didn't load anything....", file=sys.stderr)
            return False, None

        mysnap = snap()
        mysnap.interface = uns.getInterfaceType()
        _ok, mysnap.time = uns.getValueF("time")
        for attr in ("pos", "vel", "mass", "hsml", "rho", "age"):
            _ok, value = uns.getArrayF(comp, attr)
            setattr(mysnap, attr, value)
        ok, mysnap.metal = uns.getArrayF(comp, "metal")
        if ok:
            # replace an all-sentinel metal array by an empty one
            mysnap.metal = fixMetal(mysnap.metal)
        _ok, mysnap.id = uns.getArrayI(comp, "id")
        # Copy while the reader is still open. NOTE(review): if unsio hands
        # back views on reader-owned buffers (not verifiable from here),
        # copying after close() would read released memory.
        result = copy.deepcopy(mysnap)
    finally:
        # close the reader on every path, including the failed-load one
        uns.close()
    return True, result

# -----------------------------------------------------
def saveSnap(insnap,comp,unstype,single):
    """Write *insnap* in format *unstype* to a temporary file and read it back.

    Parameters
    ----------
    insnap : snap
        Snapshot to write (as produced by readSnap).
    comp : str
        Component name; forced to "all" for nemo output and to "gas" when
        writing gadget2 from a Nemo input.
    unstype : str
        unsio output format name ("gadget2", "gadget3", "nemo", ...).
    single : bool
        True uses the float writer/reader, False the double ones.

    Returns
    -------
    tuple
        ``(ok, snap)`` as returned by readSnap on the written file.
    """
    import shutil  # local import: only this function needs it

    if unstype == "nemo":
        comp = "all"
    if unstype == "gadget2" and insnap.interface == "Nemo":
        comp = "gas"

    # Use a private temporary directory instead of reusing the name of a
    # deleted NamedTemporaryFile: the name cannot be taken by someone else in
    # between, the writer still gets a path that does not exist yet, and
    # anything the writer creates is removed together with the directory.
    tmpdir = tempfile.mkdtemp()
    myfile = os.path.join(tmpdir, "snapshot")
    try:
        ## SAVE FILE
        if single:
            unso = CunsOut(myfile, unstype)    # output file
        else:
            unso = CunsOutD(myfile, unstype)   # output file

        print("\nSaving in ",unstype," format......")
        # time may legitimately be 0.0: test for None, not truthiness
        if insnap.time is not None:
            unso.setValueF("time", insnap.time)      # save time
        for attr in ("pos","vel","mass","age","hsml","rho","metal","id"):
            data = getattr(insnap, attr)
            if data is None or not data.size:
                continue  # nothing loaded for this field
            print("save -> ",attr)
            if attr == "id":
                unso.setArrayI(comp, attr, data)  # save integer array
            else:
                unso.setArrayF(comp, attr, data)  # save real arrays
        unso.save()
        unso.close()

        ## READ FILE BACK (readSnap already returns a private deep copy)
        ok, mysnap = readSnap(myfile, comp, single)
    finally:
        # remove the temporary file even if writing or reading back failed
        shutil.rmtree(tmpdir, ignore_errors=True)

    return ok, mysnap
  
# -----------------------------------------------------
def compareArray(CA,CB,attr):
    """Compare field *attr* of two snaps and print the verdict.

    Nothing is printed when notCompare() says the field must be skipped, nor
    when two non-time arrays are equal but empty. Otherwise prints
    "[ attr ] True/False" and, on mismatch, the leading values of both sides.
    """
    A = getattr(CA, attr)
    B = getattr(CB, attr)
    if notCompare(CA, CB, attr):
        return

    if attr == "time":
        ok = (A == B)
    else:
        # array_equal yields False on a shape mismatch, whereas A==B may
        # raise or return a scalar without .all()
        ok = bool(np.array_equal(A, B))
        if ok and not A.size:
            return

    print("[",attr,"]",ok)
    if not ok:
        if attr == "time":
            # scalars cannot be sliced
            print("\tA:",A,"\n\tB:",B)
        else:
            print("\tA:",A[0:2],"\n\tB:",B[0:2])

# -----------------------------------------------------
# do not compare in the following cases
def notCompare(CA,CB,attr):
    """Return True when field *attr* must be skipped in the comparison.

    Skipped when either snap comes from the Nemo interface and the field is
    "metal" or "age", or when the field is not "time" and CA holds an empty
    array for it. A message is printed for each reason that applies.
    """
    skip = False
    nemo_involved = "Nemo" in (CA.interface, CB.interface)
    if nemo_involved and attr in ("metal", "age"):
        skip = True
        print("<",attr,"> attribute not supported with NEMO format")

    value = getattr(CA, attr)
    if attr != "time" and value.size == 0:
        skip = True
        print("In <",attr,"> attribute not tested")
    return skip
        
# -----------------------------------------------------
def compare(CA,CB):
    """Print a header naming both interfaces, then compare every field."""
    fields = ["pos","vel","mass","age","hsml","rho","metal","id","time"]
    print("-----------------------------------------------------")
    print("Comparing : [",CA.interface,"] vs [",CB.interface,"]\n")
    for field in fields:
        compareArray(CA, CB, field)

# -----------------------------------------------------
def fixMetal(metal, missing=-1.0):
    """Drop a metal array that only holds the "missing" sentinel value.

    Parameters
    ----------
    metal : numpy.ndarray or None
        Metallicity array as returned by the reader.
    missing : float, optional
        Sentinel value meaning "no data" (default -1.0).

    Returns
    -------
    numpy.ndarray or None
        An empty array of the same dtype when every element equals
        *missing*; otherwise *metal* unchanged (including None or an
        already-empty array, for which nothing needs fixing).
    """
    if metal is None or np.size(metal) == 0:
        return metal
    if np.all(metal == missing):
        print("fixing metal....")
        return np.empty(0, dtype=metal.dtype)
    return metal

# -----------------------------------------------------
# process
def process(args):
    """Read the input snapshot, round-trip it through each output format, compare.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options; uses ``simname``, ``component`` and ``float``.

    Exits with status 1 when the input snapshot cannot be loaded.
    """
    print("simname = ",args.simname)

    ok, insnap = readSnap(args.simname, args.component, args.float)
    if not ok:
        print("Unable to load", args.simname, "- nothing to compare", file=sys.stderr)
        sys.exit(1)

    # Save every format first, then compare, so that the log keeps all the
    # "Saving in ..." sections ahead of the comparison reports.
    saved = []
    for unstype in ("gadget2", "gadget3", "nemo"):
        ok, outsnap = saveSnap(insnap, args.component, unstype, args.float)
        saved.append((unstype, ok, outsnap))

    for unstype, ok, outsnap in saved:
        if ok:
            compare(insnap, outsnap)
        else:
            print("Reading back", unstype, "file failed, comparison skipped", file=sys.stderr)

# -----------------------------------------------------
# main program: only when executed as a script, so that importing this module
# (e.g. test collection of a test_*.py file) does not parse a foreign sys.argv
if __name__ == "__main__":
    commandLine()   # parse command line
#