Commit 5534b76e authored by Clément Haëck
Browse files

Faster hist box computation

parent 39b19c22
"""Compute histograms of variables for ostia data"""
from os import path
from dask.diagnostics import ProgressBar
import numpy as np
import xarray as xr
......@@ -61,23 +61,26 @@ def main():
variables_gc = ['CHL']
boxes = lib.box.IndexBox.ReadGridText(lib.box.IndexBox.GetGridFile(args))
for i, b in enumerate(boxes):
print(str(b) + ' | {}/{}'.format(i+1, len(boxes)))
args['boxes'] = boxes
args['grid'] = [b.idx_str for b in boxes]
args['zone'] = args['grid_file']
args['zone'] = path.join(args['grid_file'], b.idx_str)
h_st = prepare_ds(st, args, variables_hi)
h_gc = prepare_ds(gc, args, variables_gc)
h_st = prepare_ds(st, args, variables_hi)
h_gc = prepare_ds(gc, args, variables_gc)
compute_hist(st, h_st, variables_hi)
compute_hist(gc, h_gc, variables_gc)
compute_hist(st, h_st, variables_hi, args)
compute_hist(gc, h_gc, variables_gc, args)
# Merge all histograms in one dataset and write to disk
ds = xr.concat([h_st, h_gc], 'variable')
# Merge all histograms in one dataset and write to disk
ds = xr.concat([h_st, h_gc], 'variable')
# Write to disk
args['Y'] = args['year']
ofile = lib.data.hists.get_filename(args)
lib.check_output_dir(ofile, file=True)
ds = ds.assign_coords(zone=[idx for idx in args['grid']])
# Write to disk
args['Y'] = args['year']
ofile = lib.data.hists.get_filename(args)
lib.check_output_dir(ofile, file=True)
with ProgressBar():
ds.to_netcdf(ofile, encoding={v: {'zlib': True} for v in ds.data_vars})
return ds
......@@ -85,8 +88,7 @@ def main():
def prepare_ds(ds, args, variables):
"""Set dimensions and coordinates for parameters and histograms."""
args_dims = ['threshold', 'zone',
'scale', 'number', 'coef']
args_dims = ['threshold', 'scale', 'number', 'coef']
h = xr.Dataset(
coords=dict(**{c: [args[c]] for c in args_dims},
variable=variables,
......@@ -104,26 +106,32 @@ def prepare_ds(ds, args, variables):
# Add hist variable
hist_dims = args_dims + ['variable', 'mask', 'time']
hist = xr.DataArray(
np.zeros((*[h[c].size for c in hist_dims], args['nbins'])),
dims=hist_dims + ['nbins'])
np.zeros((*[h[c].size for c in hist_dims],
args['nbins'], len(args['grid']))),
dims=hist_dims + ['nbins', 'zone'])
h['hist'] = hist
return h
def compute_hist(ds, hist, variables):
def compute_hist(ds, hist, variables, args):
"""Set up an apply_ufunc."""
ds['frt'] = ds['HI'] > hist['threshold'][0].values
slices = [b.get_isel(ds) for b in args['boxes']]
slices = [(slice(None), slc['lat'], slc['lon']) for slc in slices]
for var in variables:
hists = xr.apply_ufunc(
u_comp_hist, ds['frt'], ds[var],
input_core_dims=[['lat', 'lon'],
['lat', 'lon']],
output_core_dims=[['nbins'], ['nbins']],
output_core_dims=[['zone', 'nbins'], ['zone', 'nbins']],
dask='parallelized',
dask_gufunc_kwargs={'output_sizes': {'nbins': hist.nbins.size}},
kwargs=dict(range=VAR_RANGE[var],
dask_gufunc_kwargs={'output_sizes': {'nbins': hist.nbins.size,
'zone': hist.zone.size}},
kwargs=dict(slices=slices,
range=VAR_RANGE[var],
bins=hist['nbins'].size)
)
for h, mask in zip(hists, ['frt', 'bkg']):
......@@ -136,15 +144,25 @@ def compute_hist(ds, hist, variables):
return hist
def u_comp_hist(mask, values, slices=None, **kwargs):
    """Compute masked and unmasked histograms of `values` for each slice.

    For every index in `slices`, `values` and `mask` are both indexed with
    it, and two histograms are computed on the selection: one over the
    elements where `mask` is true, one over the elements where it is false.
    Non-finite values are dropped before binning.

    Parameters
    ----------
    mask : numpy.ndarray
        Used as a boolean index into `values` and inverted with ``~``;
        presumably a boolean array of the same shape as `values`.
    values : numpy.ndarray
        Data to bin.
    slices : iterable of index objects, optional
        Each entry must be a valid index for both `values` and `mask`.
        When None, the whole arrays are treated as a single zone.
    **kwargs
        Forwarded unchanged to ``numpy.histogram`` (e.g. ``bins``,
        ``range``). All slices must yield the same number of bins so the
        per-slice histograms can be stacked.

    Returns
    -------
    tuple of numpy.ndarray
        ``(masked, unmasked)``, each of shape ``(1, n_slices, n_bins)``.
    """
    if slices is None:
        # Iterating over None would raise TypeError; fall back to one
        # zone that spans the full arrays.
        slices = [Ellipsis]

    def comp(values, mask):
        sel = values[mask]  # Apply mask (result is 1-D)
        val = sel[np.isfinite(sel)]  # Remove NaN and infinities
        h, _ = np.histogram(val, **kwargs)  # Compute hist
        return h

    res_frt = [comp(values[slc], mask[slc]) for slc in slices]
    res_bkg = [comp(values[slc], ~mask[slc]) for slc in slices]

    # Stack the per-slice histograms and prepend a length-1 axis.
    # NOTE(review): presumably this axis stands for the looped dimension
    # of the chunk handed over by apply_ufunc -- confirm against caller.
    res_frt = np.expand_dims(np.array(res_frt), 0)
    res_bkg = np.expand_dims(np.array(res_bkg), 0)
    return res_frt, res_bkg
if __name__ == '__main__':
......
......@@ -10,20 +10,19 @@ fixes = dict(
coef=0,
scale=30,
threshold=15.0,
zone='box1'
zone='boxgrid_5.0_5.0'
)
variable = ['chl_ocx', 'CHL'][fixes['number'] - 1]
# variable = 'dtom'
ds = lib.data.hists.get_data(fixes=fixes)
ds = ds.sel(variable=variable, zone='0503').squeeze().load()
hist = ds['hist'].resample(time='8D').sum().to_dataset()
hist['bins'] = ds['bins']
ds = hist
ds = hist.load()
# ds = ds.squeeze()
ds = ds.sel(variable=variable).squeeze()
ds = lib.data.hists.normalize_hist(ds)
med = lib.data.hists.get_percentile(ds, .5)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment