Skip to content
Snippets Groups Projects
ModelGrid.py 16.5 KiB
Newer Older
#
# Reads and manages the model grid.
#
import numpy as np
import Projections as PS
from netCDF4 import Dataset
# Doc : https://spacetelescope.github.io/spherical_geometry/api/spherical_geometry.polygon.SphericalPolygon.html
from spherical_geometry import polygon
import RPPtools as RPP
import sys
#
import configparser
# Run configuration read from "run.def" in the current working directory.
# The dict passed to ConfigParser supplies defaults (a global domain) for the
# WEST_EAST and SOUTH_NORTH options read by GlobalGrid.
# NOTE(review): ConfigParser.read() silently skips a missing file; the
# config.getfloat("OverAll", ...) call below then fails with NoSectionError.
config=configparser.ConfigParser({"WEST_EAST":"-180., 180", "SOUTH_NORTH":"-90., 90."})
config.read("run.def")
#
# Logging helpers from the project module getargs: INFO and DEBUG are bound to
# log_master, ERROR and the *_ALL variants to log_world.
# NOTE(review): presumably master-process vs all-process logging -- confirm in getargs.
import getargs
log_master, log_world = getargs.getLogger(__name__)
INFO, DEBUG, ERROR = log_master.info, log_master.debug, log_world.error
INFO_ALL, DEBUG_ALL = log_world.info, log_world.debug
#
#
# Earth radius; corners() multiplies the polygon area on the unit sphere by
# EarthRadius**2, so box areas come out in the square of this value's unit.
EarthRadius=config.getfloat("OverAll", "EarthRadius")
# Offsets [dj, di] of the 8 neighbours of a grid point, in the order used by
# gatherland() to build each neighbour list (same order as in ORCHIDEE).
rose=[[-1,-1],[-1,0],[-1,+1],[0,+1],[+1,+1],[+1,0],[+1,-1],[0,-1]]
# Tolerance used by GlobalGrid when testing whether the configured region is global.
epsilon=0.00001
#
###################################################################################
#
# Find the index of the point in the list coord that is closest to the point given by lon,lat.
#
def mindist(coord,lon,lat) :
    """
    Return the index of the entry of coord closest to the point (lon, lat).

    coord is a sequence of [lon, lat] pairs. The distance is the plain
    Euclidean distance in (lon, lat) space; on ties the first matching entry
    wins (np.argmin semantics).
    """
    distances = [np.sqrt((pt[0]-lon)**2 + (pt[1]-lat)**2) for pt in coord]
    return np.argmin(distances)
#
# Function to gather all land points while also keeping the neighbour information.
#
def gatherland(lon, lat, land, indP, indFi, indFj, offsets=None) :
    """
    Gather all land points of the grid while keeping their neighbour information.

    Land points are visited with i (second array index) in the outer loop and
    j in the inner loop; this visiting order defines the land-point index.

    Parameters
    ----------
    lon, lat : 2-D arrays (nj, ni) with the coordinates of every grid point.
    land : 2-D array (nj, ni); a point is land where land > 0.
    indP : not used; kept so existing callers keep working.
    indFi, indFj : 2-D arrays (nj, ni); their values are stored for each land point.
    offsets : optional list of [dj, di] neighbour offsets. Defaults to the
        module-level ``rose`` (8 neighbours in ORCHIDEE order).

    Returns
    -------
    nbland : number of land points.
    coord : list of [lon, lat] for each land point.
    neighbours : for each land point, one entry per offset: the land-point
        index of the neighbour, -1 if the neighbour lies outside the domain,
        -2 if it is an ocean point.
    indP_land : list of [j, i] grid indices of each land point.
    indF_land : list of [indFi, indFj] values of each land point.
    """
    if offsets is None :
        offsets = rose
    nj,ni=lon.shape
    coord=[]
    indF_land=[]
    indP_land=[]
    # Grid position (j, i) -> land-point index. Neighbours are then found
    # exactly and in O(1), instead of by a nearest-coordinate search over all
    # land points (O(nbland) per neighbour, and ambiguous for duplicate coords).
    landindex={}
    for i in range(ni) :
        for j in range(nj) :
            if land[j,i] > 0 :
                landindex[(j,i)]=len(coord)
                coord.append([lon[j,i],lat[j,i]])
                indF_land.append([indFi[j,i],indFj[j,i]])
                indP_land.append([j,i])
    nbland=len(coord)
    #
    # Get the neighbours in the coord list. The same order is used as in ORCHIDEE.
    # Indices of neighbouring points are in C and thus +1 will be performed
    # for FORTRAN:
    #   -1 will become 0 and indicates the point is outside of the domain.
    #   -2 will become -1 and indicates ocean.
    #
    neighbours=[]
    for j,i in indP_land :
        ntmp=[]
        for dj,di in offsets :
            nnj=j+dj
            nni=i+di
            if 0 <= nni < ni and 0 <= nnj < nj :
                if land[nnj,nni] > 0 :
                    ntmp.append(landindex[(nnj,nni)])
                else :
                    ntmp.append(-2)
            else :
                ntmp.append(-1)
        neighbours.append(ntmp)

    return nbland, coord, neighbours, indP_land, indF_land
#
#
#
def corners(indF, proj, istart, jstart, lon, lat) :
    """
    Build the grid-box polygons of the given points and their areas.

    For each entry ij of indF, the box at grid position
    [istart+ij[0], jstart+ij[1]] is described by RPP.boxit (corners and
    mid-points of the segments), converted to lon/lat with proj.ijll and
    turned into a spherical polygon.

    Returns
    -------
    cornersll : list of the lon/lat vertex lists of each box.
    cornerspoly : list of spherical_geometry SphericalPolygon objects.
    areas : list of box areas (polygon area on the unit sphere * EarthRadius**2).
    box : [[lonmin, lonmax], [latmin, latmax]] over all box vertices.
    """
    cornersll=[]
    cornerspoly=[]
    areas=[]
    allon=[]
    allat=[]
    #
    # dlon is +1 when longitude increases along i, -1 otherwise; dlat is +1
    # when latitude decreases along j (first row north of the last), -1 otherwise.
    #
    dlon=int((lon[0,-1]-lon[0,0])/np.abs(lon[0,-1]-lon[0,0]))
    dlat=int((lat[0,0]-lat[-1,0])/np.abs(lat[0,0]-lat[-1,0]))
    #
    for ij in indF :
        #
        # Get the corners and mid-points of segments to completely describe the polygon.
        #
        polyg = RPP.boxit([istart+ij[0],jstart+ij[1]], dlon, dlat)
        polyll=proj.ijll(polyg)
        lons=[p[0] for p in polyll]
        lats=[p[1] for p in polyll]
        #
        # Kept flat so the bounding box is well defined even if boxes had
        # different numbers of vertices.
        allon.extend(lons)
        allat.extend(lats)
        #
        # No explicit center is passed: per the spherical_geometry documentation,
        # from_lonlat then uses the vector-space mean of the vertices as the
        # interior point, which is appropriate for small convex grid boxes.
        #
        sphpoly=polygon.SphericalPolygon.from_lonlat(lons, lats)
        #
        areas.append((sphpoly.area())*EarthRadius**2)
        cornerspoly.append(sphpoly)
        cornersll.append(polyll)

    box=[[np.min(allon),np.max(allon)],[np.min(allat),np.max(allat)]]

    return cornersll, cornerspoly, areas, box
#
# Extract the coordinates of the grid file. If ni < 0 and nj < 0 then the
# full grid is read.
# 
def _validate_subdomain(nbi_g, nbj_g, istart, ni, jstart, nj) :
    """
    Resolve and check the sub-domain requested from a grid of nbj_g x nbi_g points.

    If ni < 0 and nj < 0, the full grid is selected. Logs an error and exits
    when the sub-domain does not fit inside the grid. Returns the resolved
    (istart, ni, jstart, nj).
    """
    if ni < 0 and nj < 0 :
        istart, ni, jstart, nj = 0, nbi_g, 0, nbj_g
    #
    # A negative start would silently wrap around in the slicing done by the
    # caller, and a non-positive size yields an empty or reversed slice.
    #
    if istart < 0 or ni < 1 or istart+ni > nbi_g :
        ERROR("The sub-domain is not possible in longitude : Total size = "+str(nbi_g)+" istart = "+str(istart)+" ni = "+str(ni))
        sys.exit()
    if jstart < 0 or nj < 1 or jstart+nj > nbj_g :
        ERROR("The sub-domain is not possible in latitude : Total size = "+str(nbj_g)+" jstart = "+str(jstart)+" nj = "+str(nj))
        sys.exit()
    return istart, ni, jstart, nj

def getcoordinates(geo, istart, ni, jstart, nj) :
    """
    Extract the grid description and 2-D coordinates of a (sub-)domain of geo.

    The grid is a WRF geogrid ("RegXY") when geo has a DX global attribute,
    and a regular lon/lat grid ("RegLonLat") when its "lon" variable is 1-D.
    If ni < 0 and nj < 0, the full grid is read.

    Returns
    -------
    griddesc : dict with 'type' plus, for RegXY: 'dx', 'known_lon',
        'known_lat', 'truelat1', 'truelat2', 'stdlon'; for RegLonLat: 'inilon'
        and 'inilat' (first longitude/latitude of the full grid).
    lon_full, lat_full : 2-D arrays (nj, ni) of the sub-domain coordinates.
    """
    #
    # Guess the type of grid
    #
    griddesc = {}
    if "DX" in geo.ncattrs() :
        griddesc['type'] = "RegXY"
    elif len(geo.variables["lon"].shape) == 1 :
        griddesc['type'] = "RegLonLat"
    else :
        ERROR("We could not guess the grid type")
        sys.exit()
    #
    # Extract grid information
    #
    if griddesc['type'] == "RegXY" :
        # We have a geogrid file from WRF: projection parameters are global attributes.
        griddesc['dx'] = geo.DX
        griddesc['known_lon'] = geo.corner_lons[0]
        griddesc['known_lat'] = geo.corner_lats[0]
        griddesc['truelat1'] = geo.TRUELAT1
        griddesc['truelat2'] = geo.TRUELAT2
        griddesc['stdlon'] = geo.STAND_LON
        #
        _, nbj_g, nbi_g = geo.variables["XLONG_M"].shape
        istart, ni, jstart, nj = _validate_subdomain(nbi_g, nbj_g, istart, ni, jstart, nj)
        #
        # First time step of the 2-D coordinate fields.
        lon_full=np.copy(geo.variables["XLONG_M"][0,jstart:jstart+nj,istart:istart+ni])
        lat_full=np.copy(geo.variables["XLAT_M"][0,jstart:jstart+nj,istart:istart+ni])
    else :
        # We have a regular lat/lon grid with 1-D coordinate variables.
        nbi_g = geo.variables["lon"].shape[0]
        nbj_g = geo.variables["lat"].shape[0]
        istart, ni, jstart, nj = _validate_subdomain(nbi_g, nbj_g, istart, ni, jstart, nj)
        #
        # Expand the 1-D coordinates to 2-D (nj, ni) fields.
        lon_full=np.tile(np.copy(geo.variables["lon"][istart:istart+ni]),(nj,1))
        griddesc['inilon'] = np.copy(geo.variables["lon"][0])
        lat_full=np.transpose(np.tile(np.copy(geo.variables["lat"][jstart:jstart+nj]),(ni,1)))
        griddesc['inilat'] = np.copy(geo.variables["lat"][0])
    #
    return griddesc, lon_full, lat_full
#
#
#
def getland (geo, ist, ni, jst, nj) :
    """
    Return the land/sea mask of the sub-domain [jst:jst+nj, ist:ist+ni] of geo.

    Uses the variable LANDMASK (first time step) when present. Otherwise the
    mask is derived from "elevation": points holding its missing flag
    (missing_value or _FillValue attribute) become 0, all other points 1.
    The returned array keeps the dtype of the variable it was read from.
    Logs an error and exits when no usable variable or flag is found.
    """
    vn=[v.name for v in geo.variables.values()]
    if "LANDMASK" in vn :
        land=geo.variables["LANDMASK"][0,jst:jst+nj,ist:ist+ni]
    elif "elevation" in vn :
        elev=geo.variables["elevation"]
        land=elev[jst:jst+nj,ist:ist+ni]
        attrs=elev.ncattrs()
        if "missing_value" in attrs :
            missing = elev.missing_value
        elif "_FillValue" in attrs :
            missing = elev._FillValue
        else :
            ERROR("Could not find a flag for ocean points (i.e. missing)")
            sys.exit()
        #
        # Complete the mask with the missing flag. The land test is evaluated
        # once, before land is modified; evaluating the second comparison on
        # the already-overwritten array turned every point into ocean when
        # 0 <= missing <= 1.
        #
        if missing < 0 :
            island = land > missing
        else :
            island = land < missing
        land[island] = 1.0
        land[~island] = 0.0
    else :
        ERROR("We could not find a variable for computing the land mask")
        sys.exit()
    return land
#
#
#
class ModelGrid :
    def __init__(self, istart, ni, jstart, nj) :
        #
        if ni < 2 or nj < 2 :
            INFO("Found impossibleDomain too small for ModelGrid to work : "+str(ni)+str(nj))
            ERROR("Domain too small")
            sys.exit()
        #
        filename=config.get("OverAll", "ModelGridFile")
        geo=Dataset(filename,'r')
        #
        # Get the coordinates from the grid file.
        #
        griddesc, self.lon_full, self.lat_full = getcoordinates(geo, istart, ni, jstart, nj)
        #
        # Extract the land/ea mask.
        #
        self.land = getland(geo, istart, ni, jstart, nj)
        ind=np.reshape(np.array(range(self.land.shape[0]*self.land.shape[1])),self.land.shape)
        indFi=[]
        indFj=[]
        for j in range(nj) :
            for i in range(ni) :
                indFi.append([i+1])
                indFj.append([j+1])
        #
        # Define some grid variables.
        #
        self.res_lon = np.mean(np.gradient(self.lon_full, axis=1))
        self.res_lat = np.mean(np.gradient(self.lat_full, axis=0))
        self.nj,self.ni = self.lon_full.shape
        #
        # Gather all land points
        #
        self.nbland, self.coordll,self.neighbours,self.indP,indF = gatherland(self.lon_full,self.lat_full,self.land,ind,\
                                                                              np.reshape(indFi[:],self.lon_full.shape),\
                                                                              np.reshape(indFj[:],self.lon_full.shape))
        INFO("Shape of region :"+str(ni)+" x "+str(nj)+" with nbland="+str(self.nbland))
        if self.nbland < 1 or self.nbland > self.nj*self.ni :
            INFO("Found impossible number of land points : "+str(self.nbland))
            ERROR("Problem with number of land points")
            sys.exit()
        #
        # Gather some of the variables from the full grid.
        #
        self.contfrac=self.landgather(self.land)
        #
        for ip in self.neighbours[0][:] :
            if ip >= 0 :
                DEBUG("Neighbour : "+str(ip)+" P index : "+str(self.indP[ip])+" F index : "+str(indF[ip][:]))
            else :
                DEBUG("Neighbour : "+str(ip)+" Not Land")
        #
        # Define projection
        #
        if griddesc['type'] == "RegXY"  :
            proj=PS.LambertC(griddesc['dx'], griddesc['known_lon'], griddesc['known_lat'], griddesc['truelat1'], griddesc['truelat2'], griddesc['stdlon'])
        elif griddesc['type'] == "RegLonLat"  :
            proj=PS.RegLonLat(self.res_lon, self.res_lat, griddesc['inilon'], griddesc['inilat'])
        else :
           ERROR("Unknown grid type")
           sys.exit()
        #
        # Get the bounds of the grid boxes and region.
        #
        self.polyll, self.polylist, self.area, self.box_land = corners(indF, proj, istart, jstart, self.lon_full, self.lat_full)
        #
        self.box=[[np.min(self.lon_full),np.max(self.lon_full)],[np.min(self.lat_full),np.max(self.lat_full)]]
        #
        geo.close()
#
# Function to scatter variables onto the full grid.
#
    def landscatter(self, var, order='C') :
        #
        # Some arrays can be in FORTRAN convention and thus the dimension to be scattered is not the last but the first.
        #
        dims = var.shape
        transpose=False
        if len(dims) == 1 :
            if dims[0] == self.nbland :
                newdims = (self.nj,self.ni)
            else :
                ERROR("The attempt to scatter cannot succeed as the last dimension is not the number of land points")
                sys.exit()
        else :
            if order == 'C' :
                if dims[-1] == self.nbland :
                    newdims = dims[:-1]+(self.nj,self.ni)
                else :
                    ERROR("The attempt to scatter cannot succeed as the last dimension is not the number of land points")
                    sys.exit()
            elif order == 'F' :
                if dims[0] == self.nbland :
                    transpose = True
                    newdims = (dims[::-1])[:-1]+(self.nj,self.ni)
                else :
                    ERROR("The attempt to scatter cannot succeed as the last dimension is not the number of land points")
                    sys.exit()
        #
        # Actual work
        #
        varscat = np.zeros(newdims, dtype=var.dtype)
        if len(dims) == 1 :
            if str(var.dtype).find('float') > -1 :
                varscat[:,:] = RPP.FillValue
            else :
                varscat[:,:] = RPP.IntFillValue
            for i in range(self.nbland) :
                varscat[self.indP[i][0],self.indP[i][1]]=var[i]
        elif len(dims) == 2 :
            if str(var.dtype).find('float') > -1 :
                varscat[:,:,:] = RPP.FillValue
            else :
                varscat[:,:,:] = RPP.IntFillValue
            if transpose :
                for i in range(self.nbland) :
                    varscat[:,self.indP[i][0],self.indP[i][1]]=np.transpose(var)[:,i]
            else :
                for i in range(self.nbland) :
                    varscat[:,self.indP[i][0],self.indP[i][1]]=var[:,i]
        elif len(dims) == 3 :
            if str(var.dtype).find('float') > -1 :
                varscat[:,:,:,:] = RPP.FillValue
            else :
                varscat[:,:,:,:] = RPP.IntFillValue
            if transpose :
                for i in range(self.nbland) :
                    varscat[:,:,self.indP[i][0],self.indP[i][1]]=np.transpose(var)[:,:,i]
            else :
                for i in range(self.nbland) :
                    varscat[:,:,self.indP[i][0],self.indP[i][1]]=var[:,:,i]
        elif len(dims) == 4 :
            if str(var.dtype).find('float') > -1 :
                varscat[:,:,:,:,:] = RPP.FillValue
            else :
                varscat[:,:,:,:,:] = RPP.IntFillValue
            if transpose :
                for i in range(self.nbland) :
                    varscat[:,:,:,self.indP[i][0],self.indP[i][1]]=np.transpose(var)[:,:,:,i]
            else :
                for i in range(self.nbland) :
                    varscat[:,:,:,self.indP[i][0],self.indP[i][1]]=var[:,:,:,i]
        else :
            ERROR("Unforessen rank of the variable to be scattered")
            sys.exit()

        return varscat
#
# Function to gather land points
#
    def landgather(self, var) :
        dims = var.shape
        if dims[-1] == self.ni and dims[-2] == self.nj :
            newdims = dims[:-2]+(self.nbland,)
        else :
            ERROR("The attempt to gather cannot succeed as the last dimensions do not correspond to the grid size")
            sys.exit()
        #
        if len(dims) == 1 :
            ERROR("Unforessen rank of the variable to be gathered")
            sys.exit()
        elif len(dims) == 2 :
            vargat = np.zeros(newdims, dtype=var.dtype)
            for i in range(self.nbland) :
                vargat[i] = var[self.indP[i][0],self.indP[i][1]]
        elif len(dims) == 3 :
            vargat = np.zeros(newdims, dtype=var.dtype)
            for i in range(self.nbland) :
                vargat[:,i] = var[:,self.indP[i][0],self.indP[i][1]]
        elif len(dims) == 4 :
            vargat = np.zeros(newdims, dtype=var.dtype)
            for i in range(self.nbland) :
                vargat[:,:,i] = var[:,:,self.indP[i][0],self.indP[i][1]]
        else :
            ERROR("Gathering for this rank not yet possible")
            sys.exit()

        return vargat
#
# A class for extracting the basic information of the full grid
#
class GlobalGrid :
    """
    Basic information on the region of the model grid selected in run.def.

    The region is given by the OverAll/WEST_EAST and OverAll/SOUTH_NORTH
    options ("min, max" pairs). When both describe the whole globe
    (+-180 and +-90 within epsilon), the full grid is used.

    Attributes: nj, ni (size of the region), jgstart, igstart (its start
    indices in the full grid), land (its land mask) and nbland (sum of land).
    """
    def __init__(self) :

        lonrange=np.array(config.get("OverAll", "WEST_EAST").split(","),dtype=float)
        latrange=np.array(config.get("OverAll", "SOUTH_NORTH").split(","),dtype=float)

        filename=config.get("OverAll", "ModelGridFile")
        INFO("Opening :"+filename)
        geo=Dataset(filename,'r')
        # Close the file on every exit path, including sys.exit() in the readers.
        try :
            _, lon_full, lat_full = getcoordinates(geo, 0, -1, 0, -1)
            #
            # Default behaviour if global is requested in the configuration file.
            #
            if np.abs(min(lonrange)) == np.abs(max(lonrange)) and np.abs(max(lonrange)-180) < epsilon and \
               np.abs(min(latrange)) == np.abs(max(latrange)) and np.abs(max(latrange)-90) < epsilon :
                self.nj, self.ni = lon_full.shape
                self.jgstart = 0
                self.igstart = 0
            else :
                # Grid points closest to the (min lon, min lat) and (max lon, max lat) corners.
                dist=np.sqrt((lon_full-min(lonrange))**2 + (lat_full-min(latrange))**2)
                jlo,ilo=np.unravel_index(np.argmin(dist, axis=None), dist.shape)
                dist=np.sqrt((lon_full-max(lonrange))**2 + (lat_full-max(latrange))**2)
                jhi,ihi=np.unravel_index(np.argmin(dist, axis=None), dist.shape)
                #
                # Latitude (or longitude) may decrease with the index, e.g. in
                # north-to-south files; order the corner indices so the extent
                # is always positive.
                #
                self.jgstart = min(jlo,jhi)
                self.nj = abs(jhi-jlo)+1
                self.igstart = min(ilo,ihi)
                self.ni = abs(ihi-ilo)+1
                INFO("Shape : "+str(lon_full.shape)+" N = "+str(self.nj)+" "+str(self.ni)+" Start : "+str(self.jgstart)+" "+str(self.igstart))

            self.land = getland(geo, self.igstart, self.ni, self.jgstart, self.nj)
        finally :
            geo.close()
        #
        self.nbland = int(np.sum(self.land))