From 236eef8c5f9709de2aa862d696ea6fe4e39ab062 Mon Sep 17 00:00:00 2001
From: Anthony <anthony.schrapffer@polytechnique.fr>
Date: Tue, 28 Apr 2020 14:27:52 +0200
Subject: [PATCH] Optimize reading of Weights.nc

---
 RPPtools.py | 68 +++++++++++++++++++++++++++++++++--------------------
 1 file changed, 43 insertions(+), 25 deletions(-)

diff --git a/RPPtools.py b/RPPtools.py
index 63ff806..d5678e1 100644
--- a/RPPtools.py
+++ b/RPPtools.py
@@ -11,6 +11,7 @@ vinttyp=np.int32
 from netCDF4 import Dataset
 import pickle
 from spherical_geometry import vector
+import time
 #
 # Configuration
 #
@@ -174,12 +175,17 @@ class compweights :
         else :
             #
             INFO("Reading weights from "+wfile)
+            t = time.time() 
             self.fetchweight(wfile, part, modelgrid, hydrogrid)
             self.hpts = [l.shape[0] for l in self.lon]
             self.maxhpts = max(self.hpts)
+            t1 = time.time()
+            INFO("Proc {0}, time : {1}".format(part.rank,round(t1-t,2 ))) 
+            """
             for icell in range(len(self.lon)) :
                 INFO(str(icell)+" Sum of overlap "+str(np.sum(self.area[icell])/modelgrid.area[icell])+
                      " Nb of Hydro grid overlaping : "+str(self.hpts[icell]))
+            """
             #
             ##self.printtest([15,19], part, modelgrid)
             #
@@ -189,10 +195,7 @@ class compweights :
     #
     def fetchweight(self, weightfile, part, modelgrid, hydrogrid) :
         #
-        self.area = []
-        self.lon = []
-        self.lat = []
-        self.hpts = []
+        debug_time = False
         #
         innf=Dataset(weightfile, 'r')
         #
@@ -200,34 +203,44 @@ class compweights :
         indexfill = innf.variables["global_index"]._FillValue
         landind = part.toglobal_index(modelgrid.indP)
         #
-        for pts in landind :
-            i = self.findinfile(pts, gindex)
-            maxhpts = np.array(np.where(innf.variables["hydro_index"][0,:,i] >= 0)).shape[1]
-            self.area.append(innf.variables["hydro_area"][0:maxhpts,i])
-            self.lon.append(innf.variables["hydro_lon"][0:maxhpts,i])
-            self.lat.append(innf.variables["hydro_lat"][0:maxhpts,i])
-            self.hpts.append(maxhpts)
+        t = time.time()
+        locinfile = [int(self.findinfile(pts, gindex)) for pts in landind]
+        var = innf.variables["hydro_index"]
+        self.hpts = [int(np.array(np.where(var[0,:,i]>= 0)).shape[1]) for i in locinfile ]
+        t1 = time.time()
+
+        if debug_time:INFO("Proc {0}, step 1, time : {1}".format(part.rank,round(t1-t,2 )))
+        #
+        var = innf.variables["hydro_area"]
+        self.area = [var[0:maxhpts,i] for i,maxhpts in zip(locinfile, self.hpts)]
+        var = innf.variables["hydro_lon"]
+        self.lon = [var[0:maxhpts,i] for i,maxhpts in zip(locinfile, self.hpts)]
+        var = innf.variables["hydro_lat"]
+        self.lat = [var[0:maxhpts,i] for i,maxhpts in zip(locinfile, self.hpts)]
+        t2 = time.time()
+        if debug_time: INFO("Proc {0}, step 2, time : {1}".format(part.rank,round(t2-t1,2 )))
         #
         innf.close()
         # Find the indicis of the points on the HydroGrid using the coordinates.
-        self.index = self.findhindex(self.lon, self.lat, hydrogrid)
+        self.index = self.findhindex(hydrogrid)
+        t3 = time.time()
+        if debug_time: INFO("Proc {0}, step 3, time : {1}".format(part.rank,round(t3-t2,2 )))
         #
         return
     #
     # Function to find the indicis of the points on the hydrological grid for the current
     # partition of the domain.
     #
-    def findhindex(self, lon, lat, hydrogrid) :
-        ind=[]
-        for im in range(len(lon)) :
-            maxhpts = lon[im].shape[0]
-            hind=np.zeros((2,maxhpts))
-            for ih in range(maxhpts) :
-                dist=np.sqrt((hydrogrid.lon-lon[im][ih])**2+(hydrogrid.lat-lat[im][ih])**2)
-                hind[0,ih],hind[1,ih] = np.unravel_index(np.argmin(dist), dist.shape)
-            #
-            ind.append(np.array(hind, dtype=np.int32))
+    def findhindex(self,  hydrogrid): 
+        ind=[self.subhindex(im, hydrogrid) for im in range(len(self.lon))]
         return ind
+
+    def subhindex(self, im, hydrogrid ):
+        lon_loc = [np.where(hydrogrid.lon[0,:] == self.lon[im][ih])[0][0] for ih in range(self.hpts[im])]
+        lat_loc = [np.where(hydrogrid.lat[:,0] == self.lat[im][ih])[0][0] for ih in range(self.hpts[im])]
+        pts_loc = [[a, b] for a,b in zip(lat_loc,lon_loc)]
+        hind = np.transpose(pts_loc).astype(np.int32)
+        return hind
     #
     # TEST function to be deleted later
     #
@@ -250,9 +263,14 @@ class compweights :
     #
     def findinfile(self, pts, indlist) :
         ib = IntFillValue
-        for i in range(indlist.shape[1]) :
-                if all(indlist[:,i] == pts) :
-                    ib = i
+        A = np.array([indlist[0,i] for i in range(indlist.shape[1])])
+        B = np.array([indlist[1,i] for i in range(indlist.shape[1])])
+        
+        C = np.where((A==pts[0])*(B==pts[1]))[0]
+        if len(C) ==1:
+            ib = C[0]
+        else:
+            INFO("ERREUR C {0}".format(len(C)))
         return ib
     #
     # Function to dump all the weights into a netCDF file
-- 
GitLab