From d05d568884564a407eba8195f34d70764469d685 Mon Sep 17 00:00:00 2001
From: Anthony <anthony.schrapffer@polytechnique.fr>
Date: Thu, 4 Jun 2020 15:36:33 +0200
Subject: [PATCH] Correct the outflows and inflows in the halos

---
 F90subroutines/routing_interface.f90 | 81 ++++++++++++++++++----------
 Interface.py                         | 48 ++++++++++++++---
 Partition.py                         | 69 +++++++++++++++---------
 RoutingPreProc.py                    | 10 ++--
 4 files changed, 145 insertions(+), 63 deletions(-)

diff --git a/F90subroutines/routing_interface.f90 b/F90subroutines/routing_interface.f90
index 5bbe204..a163c64 100644
--- a/F90subroutines/routing_interface.f90
+++ b/F90subroutines/routing_interface.f90
@@ -161,7 +161,7 @@ SUBROUTINE findbasins(nbpt, nb_htu, nbv, ijdimmax, nbi, nbj, trip_bx, basin_bx,
   !diaglalo(1,:) = (/ 39.6791, 2.6687 /)
   !
   WRITE(numout,*) "Memory Mgt findbasin : nbvmax, nb_htu, nbv = ", nbvmax, nb_htu, nbv
-  
+
   DO ib=1,nbpt
      CALL routing_reg_findbasins(nb_htu, nbv, ib, nbi(ib), nbj(ib), trip_bx(ib,:,:), &
           & basin_bx(ib,:,:), fac_bx(ib,:,:), hierarchy_bx(ib,:,:), &
@@ -169,7 +169,7 @@ SUBROUTINE findbasins(nbpt, nb_htu, nbv, ijdimmax, nbi, nbj, trip_bx, basin_bx,
           & nb_basin(ib), basin_inbxid(ib,:), basin_outlet(ib,:,:), basin_outtp(ib,:), basin_sz(ib,:), basin_bxout(ib,:), &
           & basin_bbout(ib,:), basin_pts(ib,:,:,:), basin_lshead(ib,:), coast_pts(ib,:), lontmp(ib,:,:), lattmp(ib,:,:))
   ENDDO
-  
+
 END SUBROUTINE findbasins
 
 SUBROUTINE globalize(nbpt, nb_htu, nbv, ijdimmax, area_bx, lon_bx, lat_bx, trip_bx, hierarchy_bx, orog_bx, floodp_bx, &
@@ -235,7 +235,7 @@ SUBROUTINE globalize(nbpt, nb_htu, nbv, ijdimmax, area_bx, lon_bx, lat_bx, trip_
   WRITE(numout,*) "Memory Mgt globalize : nbvmax, ijdimmax, nbv, nwbas, nb_htu = ", nbvmax, ijdimmax, nbv, nwbas, nb_htu
   !!
   DO ib=1,nbpt
-     CALL routing_reg_globalize(nbpt, nb_htu, nbv, ib, ijdimmax, neighbours, area_bx(ib,:,:),& 
+     CALL routing_reg_globalize(nbpt, nb_htu, nbv, ib, ijdimmax, neighbours, area_bx(ib,:,:),&
           & lon_bx(ib,:,:), lat_bx(ib,:,:), trip_bx(ib,:,:), &
           & hierarchy_bx(ib,:,:), orog_bx(ib,:,:), floodp_bx(ib,:,:), fac_bx(ib,:,:), &
           & topoind_bx(ib,:,:), min_topoind, nb_basin(ib), basin_inbxid(ib,:), basin_outlet(ib,:,:), basin_outtp(ib,:), &
@@ -284,7 +284,7 @@ SUBROUTINE linkup(nbpt, ijdimmax, nwbas, inflowmax, basin_count, basin_area, bas
   !
 
   WRITE(numout,*) "Memory Mgt Linkup : nbvmax, ijdimmax, nwbas, inflowmax = ", nbvmax, ijdimmax, nwbas, inflowmax
-  
+
   CALL routing_reg_linkup(nbpt, neighbours, nwbas, ijdimmax, inflowmax, basin_count, basin_area, basin_id, basin_flowdir, &
        & basin_lshead, basin_hierarchy, basin_fac, diaglalo, outflow_grid, outflow_basin, inflow_number, inflow_grid, &
        & inflow_basin, nbcoastal, coastal_basin, invented_basins)
@@ -595,7 +595,7 @@ SUBROUTINE killbas(nbpt, inflowmax, nbasmax, nwbas, ops, tokill, totakeover, num
   INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(inout) :: outflow_grid !! Type of outflow on the grid box (unitless)
   INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(inout) :: outflow_basin !!
   !
-  ! 
+  !
   !
   INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(inout)          :: inflow_number !!
   INTEGER(i_std), DIMENSION(nbpt,nwbas,inflowmax), INTENT(inout) :: inflow_basin !!
@@ -715,8 +715,43 @@ SUBROUTINE checkrouting(nbpt, nwbas, outflow_grid, outflow_basin, basin_count)
 
 END SUBROUTINE checkrouting
 
+SUBROUTINE correct_outflows(nbpt, nwbas,nbhalo, outflow_grid, outflow_basin, &
+                &basin_count, hg, hb, halopts)
+  !
+  USE ioipsl
+  USE grid
+  USE routing_tools
+  USE routing_reg
+  !
+  !! INPUT VARIABLES
+  INTEGER(i_std), INTENT (in) :: nbpt !! Domain size (unitless)
+  INTEGER(i_std), INTENT (in) :: nwbas !!
+  INTEGER(i_std), INTENT (in) :: nbhalo !!
+  INTEGER(i_std), DIMENSION(nbhalo) :: halopts
+  INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(inout)       :: outflow_grid !! Type of outflow on the grid box (unitless)
+  INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(inout)       :: outflow_basin !!
+  INTEGER(i_std), DIMENSION(nbpt), INTENT(in)       :: basin_count
+  INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(in)       :: hg !! Type of outflow on the grid box (unitless)
+  INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(in)       :: hb !!
+  !! LOCAL
+  INTEGER(i_std) ::ih, ig, ib, nbas
+
+! AJOUT BASIN8COUNT
+  DO ih=1,nbhalo
+    ig = halopts(ih)
+    nbas = basin_count(ig)
+    DO ib=1,nbas
+      outflow_grid(ig,ib) = hg(ig,ib)
+      outflow_basin(ig,ib) = hb(ig,ib)
+    END DO
+  END DO
 
-SUBROUTINE check_inflows(nbpt, nwbas, inflowmax, outflow_grid, outflow_basin, basin_count, inflow_number, inflow_grid, inflow_basin)
+END SUBROUTINE correct_outflows
+
+
+SUBROUTINE correct_inflows(nbpt, nwbas, inflowmax, outflow_grid,&
+            & outflow_basin, basin_count, inflow_number, inflow_grid, &
+            & inflow_basin)
 
   !
   USE ioipsl
@@ -732,14 +767,18 @@ SUBROUTINE check_inflows(nbpt, nwbas, inflowmax, outflow_grid, outflow_basin, ba
   INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(in)       :: outflow_grid !! Type of outflow on the grid box (unitless)
   INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(in)       :: outflow_basin !!
   INTEGER(i_std), DIMENSION(nbpt),       INTENT(in)       :: basin_count !!
-  INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(in)          :: inflow_number
-  INTEGER(i_std), DIMENSION(nbpt,nwbas,inflowmax), INTENT(in) :: inflow_basin
-  INTEGER(i_std), DIMENSION(nbpt,nwbas,inflowmax), INTENT(in) :: inflow_grid
+  INTEGER(i_std), DIMENSION(nbpt,nwbas), INTENT(inout)          :: inflow_number
+  INTEGER(i_std), DIMENSION(nbpt,nwbas,inflowmax), INTENT(inout) :: inflow_basin
+  INTEGER(i_std), DIMENSION(nbpt,nwbas,inflowmax), INTENT(inout) :: inflow_grid
 
   ! LOCAL
   INTEGER(i_std) :: ig, nbas, ib, og, ob, inf, found
 
-  WRITE(numout,*) "Checking if the HTUs are in the inflows of their outflow"  
+  WRITE(numout,*) "Checking if the HTUs are in the inflows of their outflow"
+
+  inflow_number(:,:) = 0
+  inflow_basin(:,:,:)=0
+  inflow_grid(:,:,:)=0
 
   DO ig=1,nbpt
     nbas = basin_count(ig)
@@ -747,28 +786,16 @@ SUBROUTINE check_inflows(nbpt, nwbas, inflowmax, outflow_grid, outflow_basin, ba
       og = outflow_grid(ig,ib)
       ob = outflow_basin(ig,ib)
       if (og .GT. 0) THEN
-        IF (inflow_number(og,ob) .EQ. 0) THEN
-          WRITE(numout,*) ig, ib, "Error : outflow has no inflow"
-          WRITE(numout,*) og, ob
-        ELSE
-          found = 0
-          DO inf = 1,inflow_number(og,ob)
-            IF ((inflow_grid(og,ob,inf) .EQ. ig) .AND. (inflow_basin(og,ob,inf) .EQ. ib)) THEN
-              found = 1
-            END IF
-          END DO
-          IF (found .EQ. 0) THEN
-            WRITE(numout,*) ig,ib, "Error, not found the inflows of its outflow"
-            WRITE(numout,*) og, ob
-          END IF
-        END IF
-      END IF 
+        inflow_number(og,ob) = inflow_number(og,ob) +1
+        inflow_basin(og,ob,inflow_number(og,ob)) = ib
+        inflow_grid(og,ob,inflow_number(og,ob)) = ig
+      END IF
     END DO
   END DO
 
 
 
-END SUBROUTINE check_inflows
+END SUBROUTINE correct_inflows
 
 SUBROUTINE checkfetch(nbpt, nwbas, fetch_basin, outflow_grid, outflow_basin, basin_count)
   !
diff --git a/Interface.py b/Interface.py
index 4d32f1c..bbb00f4 100644
--- a/Interface.py
+++ b/Interface.py
@@ -201,6 +201,10 @@ def finalfetch(part, routing_area, basin_count, route_togrid, route_tobasin, fet
         xtmp = np.hstack(part.comm.allgather(outflow_uparea[np.where(outflow_uparea > 0.0)]))
         # Precision in m^2 of the upstream areas when sorting.
         sorted_outareas = (np.unique(np.rint(np.array(xtmp)/prec))*prec)[::-1]
+        if sorted_outareas.shape[0]<largest_pos:
+           s = sorted_outareas[:]
+           sorted_outareas = np.zeros(largest_pos, dtype=np.float32, order='F')
+           sorted_outareas[:s.shape[0]] = s[:]
         # If mono-proc no need to iterate as fetch produces the full result.
         if part.size == 1 :
             maxdiff_sorted = 0.0
@@ -406,7 +410,7 @@ class HydroSuper :
             if part.size == 1 :
                 maxdiff_sorted = 0.0
             else :
-                maxdiff_sorted = np.max(np.abs(sorted_outareas[0:largest_pos]-old_sorted[0:l]))
+                maxdiff_sorted = np.max(np.abs(sorted_outareas[0:l]-old_sorted[0:l]))
                 old_sorted[:l] = sorted_outareas[0:largest_pos]
             iter_count += 1
 
@@ -437,11 +441,43 @@ class HydroSuper :
 
         return
 
-    def check_inflows(self):
+    def correct_outflows(self, part):
+        # Global index of the proc domain
+        nbpt_loc = np.zeros((self.nbpt,1)).astype(np.int32)
+        nbpt_loc[:,0] = np.arange(1, self.nbpt+1)
+        nbpt_glo = part.l2glandindex(nbpt_loc)
+        # Halo points
+        fhalo = np.array([pt+1 for pt in range(self.nbpt) if pt not in part.landcorelist], order = "F")
+        #
+        # Outflow grid in global index and send to halo
+        hg = part.l2glandindex(self.outflow_grid)
+        part.landsendtohalo(hg, order='F')
+        # Convert to local index
+        outflows = np.unique(hg)
+        outflows_out = [a for a in outflows if (a not in nbpt_glo and a>0)]
+        for a in outflows_out:
+          hg[hg == a] = 0
+        for i, b in enumerate(nbpt_glo):
+          hg[hg == b] = nbpt_loc[i]
+        # Send Outflow basin to the halo and adapt it
+        hb = np.copy(self.outflow_basin)
+        part.landsendtohalo(hb, order='F')
+        hb[hg <= 0] = 999999999
+        for ig in range(self.nbpt):
+            hb[ig,self.basin_count[ig]:] = 0
+        #
+        # Correct the routing graph in the halo
+        routing_interface.correct_outflows(nbpt = self.nbpt, nwbas = self.nwbas, nbhalo = fhalo.shape[0], \
+                    outflow_grid = self.outflow_grid, outflow_basin = self.outflow_basin, \
+                    basin_count = self.basin_count, hg = hg, hb = hb, halopts = fhalo)
+        #
+        # Correct the inflows
         nbxmax_tmp = self.inflow_grid.shape[2]
-        routing_interface.check_inflows(nbpt = self.nbpt, nwbas = self.nwbas, inflowmax = nbxmax_tmp, outflow_grid = self.outflow_grid, outflow_basin = self.outflow_basin, basin_count = self.basin_count, inflow_number = self.inflow_number, inflow_grid = self.inflow_grid, inflow_basin = self.inflow_basin)
-    #
-    #
+        routing_interface.correct_inflows(nbpt = self.nbpt, nwbas = self.nwbas, inflowmax = nbxmax_tmp, outflow_grid = self.outflow_grid, outflow_basin = self.outflow_basin, basin_count = self.basin_count, inflow_number = self.inflow_number, inflow_grid = self.inflow_grid, inflow_basin = self.inflow_basin)
+
+        return
+
+
     def killbas(self, tokill, totakeover, numops):
         ops = tokill.shape[1]
         #
@@ -483,7 +519,7 @@ class HydroSuper :
             outnf.createDimension('x', globalgrid.ni)
             outnf.createDimension('y', globalgrid.nj)
             outnf.createDimension('land', len(procgrid.area))
-            outnf.createDimension('htuext', self.nbhtuext)
+            outnf.createDimension('htuext', self.basin_id.shape[1])
             outnf.createDimension('htu', self.inflow_number.shape[1])
             outnf.createDimension('in',inflow_size )
             outnf.createDimension('bnd', nbcorners)
diff --git a/Partition.py b/Partition.py
index db43384..b262f07 100644
--- a/Partition.py
+++ b/Partition.py
@@ -1,5 +1,6 @@
 import numpy as np
 import sys
+from numba import jit
 #
 import getargs
 log_master, log_world = getargs.getLogger(__name__)
@@ -26,7 +27,7 @@ def halfpartition(partin, land) :
         new_dom = {"nbi":dom["nbi"]-hh[i],"nbj":dom["nbj"],  \
                    "istart":dom["istart"]+hh[i],"jstart":dom["jstart"]}
         dom["nbi"] = hh[i]
-        
+
     else :
         hh = [h for h in range(dom["nbj"])]
         nb = np.array([np.ma.sum(np.ma.filled(land[dom["jstart"]:dom["jstart"]+h,\
@@ -61,13 +62,13 @@ def fit_partition(partin, land):
             j0 = np.where(l2>0)[0][0]
             j1 = np.where(l2>0)[0][-1]
 
-            dom["jstart"] = j0 + dom["jstart"] 
+            dom["jstart"] = j0 + dom["jstart"]
             dom["nbj"] = j1-j0+1
-            dom["istart"] = i0 + dom["istart"] 
+            dom["istart"] = i0 + dom["istart"]
             dom["nbi"] = i1-i0+1
             dom["nbland"] = int(np.nansum(land[dom["jstart"]:dom["jstart"]+dom["nbj"],dom["istart"]:dom["istart"]+dom["nbi"]]))
 
-            partout.append(dom)        
+            partout.append(dom)
     return partout
 #
 #
@@ -240,7 +241,7 @@ def addhalo(nig, njg, part, procmap, nbh) :
             #
             if dom['ihstart'] != dom['istart'] :
                 sproc = procmap[jc,dom['ihstart']]
-                halosource[jc,dom['ihstart'],proc] = sproc 
+                halosource[jc,dom['ihstart'],proc] = sproc
                 coresend[jc,dom['ihstart'],proc] = proc
             if dom['ihstart']+dom['nbih']-1 != dom['istart']+dom['nbi']-1 :
                 sproc = procmap[jc,dom['ihstart']+dom['nbih']-1]
@@ -287,7 +288,7 @@ def haloreceivelist(halosource_map, rank) :
         halosource_g.append(np.where(halosource_map[:,:,rank]==ic))
     return receivefrom, halosource_g
 #
-# Get 1D indices for the land points 
+# Get 1D indices for the land points
 #
 def landindexmap(istart, ni, jstart, nj, land) :
     gnj,gni=land.shape
@@ -332,7 +333,31 @@ def toland_index(x,landmap) :
             xout.append([])
     return xout
 #
-# 
+@jit(nopython = True)
+def subl2glandindex2d(l2glandind, y, x, nl, nh):
+    LI = np.arange(nl, dtype=np.int32)
+    LJ = np.arange(nh, dtype=np.int32)
+    for i in LI:
+        for j in LJ:
+            # Land indices are in FORTRAN !!
+            if x[i,j] > 0:
+                y[i,j] = l2glandind[x[i,j]-1]+1
+            else:
+                y[i,j] = x[i,j]
+
+@jit(nopython = True)
+def subl2glandindex3d(l2glandind, y, x, nl, nh1, nh2):
+    LI = np.arange(nl, dtype=np.int32)
+    LJ = np.arange(nh1, dtype=np.int32)
+    LK = np.arange(nh2, dtype=np.int32)
+    for i in LI:
+        for j in LJ:
+            for k in LK:
+                if x[i,j,k] > 0:
+                    # Land indices are in Fortran
+                    y[i,j,k] = l2glandind[x[i,j,k]-1]+1
+                else:
+                    y[i,j,k] = x[i,j,k]
 #
 class partition :
     def __init__ (self, nig, njg, land, mpicomm, nbcore, halosz, rank, wunit="None") :
@@ -366,7 +391,7 @@ class partition :
         if self.size != len(part) :
             ERROR("There are too many processors for the size of the domain.")
             ERROR(str(self.size)+" processors. But partition could only achieve "+str(len(part))+" domain with land points")
-            sys.exit()        
+            sys.exit()
         #
         for i in range(self.size) :
             self.allistart.append(part[i]["istart"])
@@ -383,7 +408,7 @@ class partition :
         self.ihstart = part[rank]["ihstart"]
         self.jhstart = part[rank]["jhstart"]
         #
-        self.nbland, landindmap, self.l2glandind = landindexmap(self.ihstart, self.nih, self.jhstart, self.njh, land)        
+        self.nbland, landindmap, self.l2glandind = landindexmap(self.ihstart, self.nih, self.jhstart, self.njh, land)
         #
         if wunit != "None" :
             wunit.write("Offsets with halo (j-i) : "+str(self.jhstart)+"-"+str(self.ihstart)+'\n')
@@ -476,7 +501,7 @@ class partition :
         #
         # Working on matrices
         #
-        elif len(x.shape) == 2 :   
+        elif len(x.shape) == 2 :
             chksz = 100
             if order == 'C' :
                 chkdim = x.shape[0]
@@ -508,7 +533,7 @@ class partition :
         else :
             ERROR("Unforessen rank of the variable to be received in halo of land points")
             sys.exit()
-     
+
         return
     #
     # Gather all fields partitioned in the 2D domain onto the root proc
@@ -528,7 +553,7 @@ class partition :
             else :
                 ERROR("Unforessen rank of field to be gathered")
                 sys.exit()
-                
+
         else :
             xout = None
         #
@@ -571,31 +596,26 @@ class partition :
     #
     # Convert local index of land points to global index
     #
-    def l2glandindex(self, x) :
+    def l2glandindex(self, x, order = 'F') :
         if x.ndim == 2:
            nl,nh = x.shape
-           y = np.zeros(x.shape, dtype=x.dtype)
+           y = np.zeros(x.shape, dtype=x.dtype, order = order)
            if nl == self.nbland :
-               for i in range(nl) :
-                   for j in range(nh) :
-                       # Land indices are in FORTRAN !!
-                       y[i,j] = self.l2glandind[x[i,j]-1]+1
+               subl2glandindex2d(self.l2glandind, y, x, nl, nh)
            else :
                ERROR("The first dimension does not have the length of the number of land points")
                sys.exit()
         if x.ndim == 3:
            nl,nh1,nh2 = x.shape
-           y = np.zeros(x.shape,dtype = x.dtype)
+           y = np.zeros(x.shape,dtype = x.dtype, order = order)
            if nl == self.nbland:
-               for i in range(nl):
-                   for j in range(nh1):
-                       for k in range(nh2):
-                           # Land indices are in Fortran
-                           y[i,j,k] = self.l2glandind[x[i,j,k]-1]+1
+               subl2glandindex3d(self.l2glandind, y, x, nl, nh1, nh2)
            else:
                ERROR("The first dimension does not have the length of the number of land points")
                sys.exit()
         return y
+
+
     #
     # Set to zero all points in the core
     #
@@ -704,4 +724,3 @@ class partition :
             self.comm.send(np.pad(x, [0,maxlen-len(x)], mode='constant', constant_values=np.nan), dest=0)
         #
         return xout
-    
diff --git a/RoutingPreProc.py b/RoutingPreProc.py
index 2ae1564..b3062b8 100644
--- a/RoutingPreProc.py
+++ b/RoutingPreProc.py
@@ -96,23 +96,23 @@ hsuper.linkup(hydrodata)
 #
 
 del hoverlap
-gc.collect() 
+gc.collect()
 comm.Barrier()
-
+#
+hsuper.correct_outflows(part)
+#
 INFO("=================== Compute fetch ====================")
 t = time.time()
 hsuper.fetch(part)
 comm.Barrier()
 t1 = time.time()
-print("Time for fetch: {:0.2f} s.".format(t1-t)) 
+print("Time for fetch: {:0.2f} s.".format(t1-t))
 comm.Barrier()
 
 if DumpHydroSuper :
     INFO("Dumping HydroSuper")
     hsuper.dumpnetcdf(OutGraphFile.replace(".nc","_HydroSuper.nc"), gg, modelgrid, part)
 
-hsuper.check_inflows()
-
 INFO("=================== Truncate ====================")
 print("Rank : {0} - Basin_count Before Truncate : ".format(part.rank), hsuper.basin_count)
 hs = TR(hsuper, part, comm, modelgrid, numop = numop)
-- 
GitLab