Commit debcb618 authored by Clément Haëck's avatar Clément Haëck
Browse files

Readapt for multi threshold

parent 1c6f2eb7
......@@ -42,7 +42,7 @@ VAR_RANGE = dict(
)
VARS = ['CHL', 'sst']
MASKS = ['low', 'mid', 'hi']
MASKS_VAR = ['mask_' + m for m in MASKS]
# Parameters expanded as singleton dims and used in the output filename.
# 'threshold' was replaced by the thr_lo/thr_hi pair in this commit.
INFILE_PARAMS = ['thr_lo', 'thr_hi', 'scale', 'number', 'coef']
......@@ -68,8 +68,9 @@ def main(args):
m_next("Computing HI masks")
# Masks
ds['mask_low'] = ds.HI < args['thr_lo']
# '*' acts as a logical AND on boolean DataArrays; the parentheses are
# required because '*' binds tighter than the comparisons (the unparenthesized
# form is a chained comparison and fails on arrays).
ds['mask_mid'] = (ds.HI > args['thr_lo']) * (ds.HI < args['thr_hi'])
ds['mask_hi'] = ds.HI > args['thr_hi']
ds = ds.drop_vars(['HI'])
m_next("Computing zones datasets")
......@@ -185,8 +186,8 @@ def get_bins(variable):
if __name__ == '__main__':
def add_args(parser):
# Frt/Bkg HI threshold
parser.add_argument('-threshold', type=float, default=5.)
parser.add_argument('-thr_lo', type=float, default=5.)
parser.add_argument('-thr_hi', type=float, default=10.)
# Zone·s to compute
parser.add_argument('-zones', type=str, default='INS')
......
......@@ -2,12 +2,21 @@
Separated in 3 zones: North of jet, South of jet, South of 32N.
GS3_N, GS3_I, GS3_S (North, Intermediate, South)
All results (for different zones, masks and variables) are aggregated in one
output file. This ensures dask maximizes reuse of loaded data (for instance,
the same loaded bytes of HI are used to mask both variables).
This should be achievable by regrouping different delayed `to_netcdf` calls,
but this way is kind of simpler. It also generates less files (at the
cost of complexity of variables names).
It is possible to select only some of the I, N, and S zones for computation,
which can be useful if they should have different thresholds (note that this
would necessitate some arrangements to lib.data.hists).
"""
import operator
import boost_histogram as bh
import dask_histogram as dh
import numpy as np
import xarray as xr
import lib
......@@ -18,75 +27,107 @@ import lib.data.ostia
import lib.data.p_frt_mask
import lib.data.SN_separation
import lib.zones
from lib import m_start, m_end, m_next
# Width of SST bins
SST_STEP = 1e-2
# Number of Chl bins
CHL_NBINS = 500
# Bins extent
VAR_RANGE = dict(
    # move bounds by half a discretisation step
    sst=[b - SST_STEP/2. for b in [-5., 40.]],
    CHL=[5e-3, 20.]
)
# Variables to histogram
VARS = ['CHL', 'sst']
# HI mask categories, delimited by the thr_lo/thr_hi thresholds
MASKS = ['low', 'mid', 'hi']
MASKS_VAR = ['mask_' + m for m in MASKS]
# Parameters expanded as singleton dims and used in the output filename.
# 'threshold' no longer exists in args (replaced by thr_lo/thr_hi, which the
# CLI defines and the filename pattern expects); keeping it here would raise
# KeyError in main's expand_dims call.
INFILE_PARAMS = ['thr_lo', 'thr_hi', 'scale', 'number', 'coef']
def main(args):
"""Load data, compute the HI index and its threshold masks, then histogram
CHL and SST per zone/mask and write the result to netCDF."""
m_start("Loading data")
ds = get_data(args)
m_next("Smoothing SN separation temperature")
# Smooth SN separation
ds['threshold'] = lib.data.SN_separation.smooth(ds, time_step=8)
# Slice at LAT_IS_SEPARATION
# Remove as much data as possible early on
ds, = slice_zone(args, ds)
m_next("Computing HI")
# Compute HI
ds['HI'] = lib.data.hi.apply_coef(ds, lib.data.hi.get_coef(args))
ds = ds.drop_vars(['S', 'V', 'B'])
m_next("Applying static masks")
# Apply masks: land (enlarged), total zone, min front proba
ds['HI'] = ds.HI.where(~ds.land_large * ds.total * ds.p_frt)
ds = ds.drop_vars(['land_large', 'total', 'p_frt'])
# To keep bins.right_edge attribute
xr.set_options(keep_attrs=True)
# NOTE(review): this singular-'zone' branch conflicts with the plural
# args['zones'] logic further down; it looks like a pre-commit leftover kept
# by the diff view — confirm against the repository.
# Select N/I zone (mask using SST values)
if args['zone'] in 'NI':
op = dict(I=operator.gt, N=operator.lt)[args['zone']]
ds['HI'] = ds.HI.where(op(ds.sst, ds.threshold))
args['zone'] = 'GS3_{}'.format(args['zone'])
m_next("Computing HI masks")
# Masks
ds['mask_low'] = ds.HI < args['thr_lo']
# '*' acts as a logical AND on boolean DataArrays; the parentheses are
# required because '*' binds tighter than the comparisons (the unparenthesized
# form is a chained comparison and fails on arrays). Same form as the old
# masks dict used: (HI > thr_lo) * (HI < thr_hi).
ds['mask_mid'] = (ds.HI > args['thr_lo']) * (ds.HI < args['thr_hi'])
ds['mask_hi'] = ds.HI > args['thr_hi']
ds = ds.drop_vars(['HI'])
m_next("Computing zones datasets")
# Datasets for each zone
zones = dict()
if 'S' in args['zones']:
zones['GS3_S'] = ds.sel(lat=slice(32, None))
# For N and I, slice, then restrict masks according to SST
for zone, op in zip('IN', [operator.gt, operator.lt]):
if zone not in args['zones']:
continue
up = ds.sel(lat=slice(None, 32))
for mask in MASKS_VAR:
up[mask] = up[mask] * op(ds.sst, ds.threshold)
zones['GS3_'+zone] = up
m_next("Setting up histogram computations")
# One histogram per (variable, zone, mask); lazily built, merged at the end
hists_var = []
for var in VARS:
hists = []
bins = get_bins(var)
bins_name = 'bins_' + var
for zone, zone_ds in zones.items():
for m, mask in zip(MASKS, MASKS_VAR):
h = zone_ds[var].where(zone_ds[mask]).groupby('time').map(
hist_grp, shortcut=True, args=[bins, bins_name]
)
h = h.expand_dims(zone=[zone], mask=[m])
hists.append(h)
hist = xr.combine_by_coords(hists)
hist = hist.rename({var: 'hist_' + var})
# Coordinates carry the left bin edges; right_edge attr keeps the last one
hist = hist.assign_coords({bins_name: bins.edges[:-1]})
hist[bins_name].attrs['right_edge'] = VAR_RANGE[var][1]
hists_var.append(hist)
m_next("Merging results")
hist = xr.merge(hists_var)
hist.attrs['VARS'] = VARS
hist = hist.expand_dims({d: [args[d]] for d in INFILE_PARAMS})
# We have at most every pixel of an image in one bin,
# so approx 1000*1000 = 1e6 values. Uint32 stores up to 2**32~4e9.
# NOTE(review): the single-variable 'hist' encoding on the next line looks
# like the pre-commit version; the dict comprehension after it supersedes it.
encoding = dict(hist={'dtype': 'uint32', '_FillValue': 2**30-1})
encoding = {v: {'dtype': 'uint32', '_FillValue': 2**30-1}
for v in ['hist_' + v for v in VARS]}
# NOTE(review): everything from this masks dict down to the first to_netcdf
# appears to be the pre-commit implementation interleaved by the diff view
# (it calls prepare_hist, which the commit removed) — confirm against the
# repository.
# Masks
masks = dict(
low=ds.HI < args['thr_lo'],
hi=ds.HI > args['thr_hi'],
mid=(ds.HI > args['thr_lo']) * (ds.HI < args['thr_hi'])
)
for var in ['CHL', 'sst']:
print('variable: {}'.format(var))
hist, bins = prepare_hist(args, ds, var)
for mask, mask_arr in masks.items():
h = ds[var].where(mask_arr).groupby('time').map(
hist_grp, shortcut=True, args=[bins])
# .data gets the dask array (no eager loading normally)
# avoids troubles as h does not have a defined bins dim
# which clashes with hist dataset
hist['hist'] = xr.where(hist.mask == mask, h.data, hist['hist'])
args['var'] = var
ofile = lib.data.hists.get_filename(args)
lib.check_output_dir(ofile, file=True)
hist.to_netcdf(ofile, encoding=encoding)
m_next("Executing computations / Writing to disk")
# NOTE(review): var/zone here are loop variables leaking out of the loops
# above — verify the intended values reach get_filename.
ofile = lib.data.hists.get_filename(args, var=var, zone=zone)
lib.check_output_dir(ofile, file=True)
hist.to_netcdf(ofile, encoding=encoding)
m_end()
return hist
def hist_grp(da, bins: dh.axis.Axis, bins_name, **kwargs):
    """Compute the histogram of an array, flattened completely.

    Parameters
    ----------
    da: DataArray to histogram (its underlying dask array is raveled).
    bins: dask-histogram axis defining the bins.
    bins_name: Name of the bins dimension of the returned array; passing it
        explicitly avoids a fixed 'bins' dim clashing between variables.
    kwargs: Forwarded to `dh.factory`.

    Returns
    -------
    DataArray of counts with the single dimension `bins_name`.
    """
    h = dh.factory(da.data.ravel(), axes=[bins], **kwargs)
    # to_dask_array returns (counts, edges); only the counts are kept
    h, _ = h.to_dask_array()
    return xr.DataArray(h, dims=[bins_name])
def get_data(args):
......@@ -114,6 +155,8 @@ def get_data(args):
ds = xr.merge(data, join='inner')
ds = ds.drop_vars(['CHL_error', 'CHL_flags'])
# Check we don't have incompatibility between datasets
if any(ds.sizes[(dim := d)] == 0 for d in ds.dims):
raise IndexError(f"'{dim}' empty after merging.")
......@@ -126,68 +169,31 @@ def get_data(args):
return ds
def slice_zone(args, *datasets):
    """Select latitudinal slab.

    Keeps latitudes north of LAT_IS_SEPARATION for zones N/I, south of it
    otherwise. Datasets lacking a 'lat' coordinate are passed through
    unchanged; the slice is reversed for decreasing latitude indexes.
    """
    sep = lib.data.SN_separation.LAT_IS_SEPARATION
    keep_north = args['zone'] in 'NI'
    bounds = [sep, None] if keep_north else [None, sep]
    sliced = []
    for data in datasets:
        if 'lat' not in data.coords:
            sliced.append(data)
            continue
        lo, hi = bounds
        if data.lat.get_index('lat').is_monotonic_decreasing:
            lo, hi = hi, lo
        sliced.append(data.sel(lat=slice(lo, hi)))
    return sliced
def get_bins(variable):
    """Return the histogram bins axis for `variable`.

    Parameters
    ----------
    variable: Variable name; 'sst' gives regular bins, 'CHL' log-spaced bins.

    Returns
    -------
    A `dh.axis.Regular` axis spanning VAR_RANGE[variable].

    Raises
    ------
    ValueError: If `variable` is neither 'sst' nor 'CHL'.
    """
    if variable == 'sst':
        bounds = VAR_RANGE[variable]
        # Histogram bins are 10 times coarser than the SST discretisation
        n_bins = int((bounds[1]-bounds[0]) / (SST_STEP*10))
        bins = dh.axis.Regular(n_bins, *bounds)
    elif variable == 'CHL':
        # Chlorophyll spans orders of magnitude: log-spaced bins
        bins = dh.axis.Regular(CHL_NBINS, *VAR_RANGE[variable],
                               transform=dh.axis.transform.log)
    else:
        raise ValueError(f"'{variable}' variable not supported.")
    return bins
if __name__ == '__main__':
    def add_args(parser):
        # Low/high HI thresholds delimiting the low/mid/hi masks
        parser.add_argument('-thr_lo', type=float, default=5.)
        parser.add_argument('-thr_hi', type=float, default=10.)
        # Zone·s to compute
        parser.add_argument('-zones', type=str, default='INS')

    args = lib.get_args(['region', 'year', 'days', 'scale', 'number',
                         'coef', 'fixes'], add_args)
    # Output files are split by year
    args['fixes']['Y'] = args['year']
    args['Y'] = args['year']

    hist = main(args)
......@@ -19,8 +19,9 @@ ARGS_DIR = {'region', 'days'}
pregex = ("number_%(number:fmt=d:rgx=%I)/"
"scale_%(scale:fmt=.1f)/"
"coef_%(coef:fmt=d)/"
"hist_GS3_%(time:Y)"
"_thr_%(threshold:fmt=.2f)"
"hist_GS3"
"_thr_%(thr_lo:fmt=.2f)_%(thr_hi:fmt=.2f)"
"_%(time:Y)"
".nc")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment