Commit eba8c7d1 authored by Clément Haëck

Regrouped hist scripts for 1thr and 2thr

parent 8317d67d
@@ -21,11 +21,11 @@ import xarray as xr
 import lib
 import lib.data.globcolour
-import lib.data.hi
+import lib.data.hi as lhi
 import lib.data.hists as lh
 import lib.data.ostia
 import lib.data.p_frt_mask
-import lib.data.SN_separation
+import lib.data.SN_separation as lsep
 import lib.zones
 # Width of SST discretization step
@@ -41,10 +41,10 @@ VAR_RANGE = dict(
 )
 VARS = ['CHL', 'sst']
-MASKS = ['frt', 'bkg']
-MASKS_VAR = ['mask_' + m for m in MASKS]
-INFILE_PARAMS = ['threshold', 'scale', 'number', 'coef']
+MASKS = {
+    '1thr': ['frt', 'bkg'],
+    '2thr': ['low', 'mid', 'hi']
+}
 def main(args):
@@ -53,11 +53,11 @@ def main(args):
     lgr.msg("Smoothing SN separation temperature")
     # Smooth SN separation
-    ds['threshold'] = lib.data.SN_separation.smooth(ds, time_step=8)
+    ds['threshold'] = lsep.smooth(ds, time_step=8)
     lgr.msg("Computing HI")
     # Compute HI
-    ds['HI'] = lib.data.hi.apply_coef(ds, lib.data.hi.get_coef(args))
+    ds['HI'] = lhi.apply_coef(ds, lhi.get_coef(args))
     ds = ds.drop_vars(['S', 'V', 'B'])
     lgr.msg("Applying static masks")
@@ -70,9 +70,16 @@ def main(args):
     lgr.msg("Computing HI masks")
     # Masks
-    ds['mask_frt'] = ds.HI > args['threshold']
-    ds['mask_bkg'] = ds.HI < args['threshold']
+    masks = MASKS[args['kind']]
+    masks_var = ['mask_' + m for m in masks]
+    if args['kind'] == '1thr':
+        ds['mask_frt'] = ds.HI > args['threshold']
+        ds['mask_bkg'] = ds.HI < args['threshold']
+    elif args['kind'] == '2thr':
+        ds['mask_low'] = ds.HI < args['thr_lo']
+        ds['mask_hi'] = ds.HI > args['thr_hi']
+        ds['mask_mid'] = (ds.HI > args['thr_lo']) * (ds.HI < args['thr_hi'])
     ds = ds.drop_vars(['HI'])
     lgr.msg("Computing zones datasets")
@@ -85,7 +92,7 @@ def main(args):
         if zone not in args['zones']:
             continue
         up = ds.sel(lat=slice(None, 32))
-        for mask in MASKS_VAR:
+        for mask in masks_var:
             up[mask] = up[mask] * op(ds.sst, ds.threshold)
         zones['GS3_'+zone] = up
@@ -94,9 +101,9 @@ def main(args):
     for var in VARS:
         hists = []
         bins = get_bins(var)
-        bins_name = 'bins_' + var
+        bins_name = lh.var_name('bins', var)
         for zone, zone_ds in zones.items():
-            for m, mask in zip(MASKS, MASKS_VAR):
+            for m, mask in zip(masks, masks_var):
                 lgr.progress()
                 h = zone_ds[var].where(zone_ds[mask]).groupby('time').map(
                     hist_grp, shortcut=True, args=[bins, bins_name]
@@ -118,7 +125,13 @@ def main(args):
     hist.attrs['VARS'] = VARS
     hist.attrs.pop('VAR')
-    hist = hist.expand_dims({d: [args[d]] for d in INFILE_PARAMS})
+    infile_params = ['scale', 'number', 'coef']
+    if args['kind'] == '1thr':
+        infile_params.append('threshold')
+    elif args['kind'] == '2thr':
+        infile_params += ['thr_lo', 'thr_hi']
+    hist = hist.expand_dims({d: [args[d]] for d in infile_params})
     # We have at most every pixel of an image in one bin,
     # so approx 1000*1000 = 1e6 values. Uint32 stores up to 2**32~4e9.
@@ -127,7 +140,7 @@ def main(args):
                 for v in VARS}
     lgr.msg("Executing computations / Writing to disk")
-    ofile = lib.data.hists.get_filename(args)
+    ofile = lh.get_filename(args)
     lib.check_output_dir(ofile, file=True)
     lib.setup_metadata(hist, args)
     hist.to_netcdf(ofile, encoding=encoding)
@@ -152,7 +165,7 @@ def get_data(args):
     data = []
     data.append(lib.data.ostia.get_data(args))
     data.append(lib.data.globcolour.get_data(args))
-    data.append(lib.data.hi.get_data(args))
+    data.append(lhi.get_data(args))
     data.append(lib.data.SN_separation.get_data(args))
     data = [lib.fix_time_daily(ds) for ds in data]
@@ -202,9 +215,13 @@ def add_args(parser):
 if __name__ == '__main__':
     args = lib.get_args(['region', 'year', 'days', 'scale', 'number',
-                         'coef', 'fixes', 'threshold', 'mask'], add_args)
-    args['kind'] = '1thr'
+                         'coef', 'fixes', 'mask',
+                         'kind', 'threshold', 'thr_lo', 'thr_hi'],
+                        add_args)
     args['fixes']['Y'] = args['year']
     args['Y'] = args['year']
+    if (kind := args['kind']) not in ['1thr', '2thr']:
+        raise KeyError("kind '{}' not supported.".format(kind))
     hist = main(args)
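The `hist.expand_dims({d: [args[d]] for d in infile_params})` call above is what allows histograms from separate runs to be concatenated later: each input parameter becomes a length-1 dimension carrying its value as a coordinate. A minimal sketch of the pattern, with a made-up array and parameter values:

```python
import numpy as np
import xarray as xr

# A stand-in histogram; the name, size, and parameter values are hypothetical.
hist = xr.DataArray(np.zeros(10, dtype='uint32'),
                    dims=['bins_sst'], name='hist_sst')
params = {'threshold': 15.0, 'scale': 10.0, 'number': 2, 'coef': 0}

# Each parameter becomes a size-1 dimension whose coordinate is its value.
hist = hist.expand_dims({d: [v] for d, v in params.items()})

print(hist.dims)  # ('threshold', 'scale', 'number', 'coef', 'bins_sst')
# Files written from runs with other parameter values can then be merged,
# e.g. with xr.combine_by_coords or xr.open_mfdataset.
```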
"""Compute histograms of variables for ostia data in the Gulf-Stream.
Separated into 3 zones: North of jet, South of jet, South of 32N.
GS3_N, GS3_I, GS3_S (North, Intermediate, South)
All results (for different zones, masks and variables) are aggregated in one
output file. This ensures dask maximizes reuse of loaded data (for instance,
the same loaded bytes of HI are used to mask both variables).
This should be achievable by regrouping different delayed `to_netcdf` calls,
but this way is simpler. It also generates fewer files (at the
cost of more complex variable names).
It is possible to select only some of the I, N, and S zones for computation,
which can be useful if they need different thresholds (note that this
would require some adjustments to lib.data.hists).
"""
import operator
import dask_histogram as dh
import xarray as xr
import lib
import lib.data.globcolour
import lib.data.hi
import lib.data.hists as lh
import lib.data.ostia
import lib.data.p_frt_mask
import lib.data.SN_separation
import lib.zones
# Width of SST bins
SST_STEP = 1e-2
# Number of Chl bins
CHL_NBINS = 500
# Bins extent
VAR_RANGE = dict(
# move bounds by half a discretisation step
sst=[b - SST_STEP/2. for b in [-5., 40.]],
CHL=[5e-3, 20.]
)
VARS = ['CHL', 'sst']
MASKS = ['low', 'mid', 'hi']
MASKS_VAR = ['mask_' + m for m in MASKS]
INFILE_PARAMS = ['thr_lo', 'thr_hi', 'scale', 'number', 'coef']
def main(args):
lgr = lib.Logger("Loading data")
ds = get_data(args)
lgr.msg("Smoothing SN separation temperature")
# Smooth SN separation
ds['threshold'] = lib.data.SN_separation.smooth(ds, time_step=8)
lgr.msg("Computing HI")
# Compute HI
ds['HI'] = lib.data.hi.apply_coef(ds, lib.data.hi.get_coef(args))
ds = ds.drop_vars(['S', 'V', 'B'])
lgr.msg("Applying static masks")
# Apply masks: land (enlarged), total zone, min front proba
static = ~ds.land_large * ds.total
if args['mask']:
static = static * ds.p_frt
ds['HI'] = ds.HI.where(static)
ds = ds.drop_vars(['land_large', 'total', 'p_frt'])
lgr.msg("Computing HI masks")
# Masks
ds['mask_low'] = ds.HI < args['thr_lo']
ds['mask_hi'] = ds.HI > args['thr_hi']
ds['mask_mid'] = (ds.HI > args['thr_lo']) * (ds.HI < args['thr_hi'])
ds = ds.drop_vars(['HI'])
lgr.msg("Computing zones datasets")
# Datasets for each zone
zones = dict()
if 'S' in args['zones']:
zones['GS3_S'] = ds.sel(lat=slice(32, None))
# For N and I, slice, then restrict masks according to SST
for zone, op in zip('IN', [operator.gt, operator.lt]):
if zone not in args['zones']:
continue
up = ds.sel(lat=slice(None, 32))
for mask in MASKS_VAR:
up[mask] = up[mask] * op(ds.sst, ds.threshold)
zones['GS3_'+zone] = up
lgr.msg("Setting up histogram computations")
hists_var = []
for var in VARS:
hists = []
bins = get_bins(var)
bins_name = lh.var_name('bins', var)
for zone, zone_ds in zones.items():
for m, mask in zip(MASKS, MASKS_VAR):
lgr.progress()
h = zone_ds[var].where(zone_ds[mask]).groupby('time').map(
hist_grp, shortcut=True, args=[bins, bins_name]
)
h = h.expand_dims(zone=[zone], mask=[m])
hists.append(h)
hist = xr.combine_by_coords(hists)
# Get a DataArray with the correct name
hist = hist[var].rename(lh.var_name('hist', var))
hist = hist.assign_coords({bins_name: bins.edges[:-1]})
hist[bins_name].attrs['right_edge'] = VAR_RANGE[var][1]
hist[bins_name].attrs['VAR'] = var
hist.attrs['VAR'] = var
hists_var.append(hist)
lgr.msg("Merging results")
hist = xr.merge(hists_var)
hist.attrs['VARS'] = VARS
hist.attrs.pop('VAR')
hist = hist.expand_dims({d: [args[d]] for d in INFILE_PARAMS})
# We have at most every pixel of an image in one bin,
# so approx 1000*1000 = 1e6 values. Uint32 stores up to 2**32~4e9.
encoding = {v: {'dtype': 'uint32', '_FillValue': 2**30-1}
for v in ['hist_' + v for v in VARS]}
lgr.msg("Executing computations / Writing to disk")
ofile = lib.data.hists.get_filename(args)
lib.check_output_dir(ofile, file=True)
lib.setup_metadata(hist, args)
hist.to_netcdf(ofile, encoding=encoding)
lgr.end()
return hist
def hist_grp(da, bins: dh.axis.Axis, bins_name, **kwargs):
"""Compute histogram of an array.
Flatten the array completely.
Uses dask_histogram.
"""
h = dh.factory(da.data.ravel(), axes=[bins], **kwargs)
h, _ = h.to_dask_array()
return xr.DataArray(h, dims=[bins_name])
def get_data(args):
"""Load all data in a single dataset."""
data = []
data.append(lib.data.ostia.get_data(args))
data.append(lib.data.globcolour.get_data(args))
data.append(lib.data.hi.get_data(args))
data.append(lib.data.SN_separation.get_data(args))
data = [lib.fix_time_daily(ds) for ds in data]
args['grid'] = lib.data.ostia.grid
data.append(lib.zones.get_data(args).total)
data.append(lib.zones.get_land(args).land_large)
data.append(lib.data.p_frt_mask.get_data(
args, fixes=dict(
threshold=lib.data.p_frt_mask.default_threshold
)).rename(mask='p_frt'))
ds = xr.merge(data, join='inner')
ds = ds.drop_vars(['CHL_error', 'CHL_flags'])
# Check we don't have incompatibilities between datasets
if any(ds.sizes[(dim := d)] == 0 for d in ds.dims):
raise IndexError(f"'{dim}' empty after merging.")
# Check if SST scale factor matches expected step
if (sf := ds['sst'].encoding['scale_factor']) != SST_STEP:
raise ValueError("Scale factor ({}) different from expected ({})"
.format(sf, SST_STEP))
return ds
def get_bins(variable):
"""Return bins axis."""
if variable == 'sst':
bounds = VAR_RANGE[variable]
n_bins = int((bounds[1]-bounds[0]) / (SST_STEP*10))
bins = dh.axis.Regular(n_bins, *bounds)
elif variable == 'CHL':
bins = dh.axis.Regular(CHL_NBINS, *VAR_RANGE[variable],
transform=dh.axis.transform.log)
else:
raise ValueError(f"'{variable}' variable not supported.")
return bins
if __name__ == '__main__':
def add_args(parser):
# Zone·s to compute
parser.add_argument('-zones', type=str, default='INS')
args = lib.get_args(['region', 'year', 'days', 'scale', 'number',
'coef', 'fixes', 'thr_lo', 'thr_hi', 'mask'],
add_args)
args['kind'] = '2thr'
args['fixes']['Y'] = args['year']
args['Y'] = args['year']
hist = main(args)
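The docstring of this standalone 2thr script (deleted here in favour of the regrouped one) mentions that the single-file aggregation should also be achievable by regrouping delayed `to_netcdf` calls. For reference, a minimal sketch of that alternative using xarray's `compute=False`, with made-up datasets and file names:

```python
import dask
import numpy as np
import xarray as xr

# Hypothetical per-variable results, chunked so they are dask-backed.
results = {
    'sst': xr.Dataset({'hist_sst': ('bins_sst', np.zeros(450))}).chunk(),
    'CHL': xr.Dataset({'hist_CHL': ('bins_CHL', np.zeros(500))}).chunk(),
}

# compute=False returns one dask Delayed per file instead of writing eagerly.
writes = [ds.to_netcdf(f'hist_{v}.nc', compute=False)
          for v, ds in results.items()]

# A single compute pass lets dask share loaded inputs between the writes.
dask.compute(*writes)
```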
@@ -62,11 +62,17 @@ coef: int, 0
 # 0 is standard. Above are variations (like twice as much bimodality
 # for instance).
-threshold: float, 15.0
+kind: str, 1thr
+# How we compute histograms
+# 1thr: One threshold, bkg and frt masks
+# 2thr: Two thresholds, low, mid and hi masks
+# 2d: A 2D histogram (against HI values)
+threshold: float, 5.0
 # HI threshold to discriminate fronts.
-thr_lo: float, 6.0
-thr_hi: float, 15.0
+thr_lo: float, 5.0
+thr_hi: float, 10.0
 # HI thresholds to discriminate fronts.
 # Under lo: no front
 # Above hi: intense mesoscale front
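Both comparisons in the 2thr masks are strict, so a pixel whose HI is exactly thr_lo or thr_hi falls into none of the three masks. A minimal sketch of the partition, applying the default thresholds above to made-up HI values:

```python
import numpy as np

HI = np.array([3.0, 5.0, 7.5, 10.0, 12.0])  # made-up sample values
thr_lo, thr_hi = 5.0, 10.0                  # defaults above

mask_low = HI < thr_lo                      # no front
mask_mid = (HI > thr_lo) & (HI < thr_hi)    # moderate front
mask_hi = HI > thr_hi                       # intense mesoscale front

# HI == 5.0 and HI == 10.0 land in no mask (strict inequalities).
print(mask_low, mask_mid, mask_hi, sep='\n')
```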
@@ -12,7 +12,6 @@ import lib.data
 ARGS = {'region', 'days', 'kind', 'mask'}
-defaults = dict(kind='2thr', mask=True)
 """
 kind:
@@ -48,7 +47,7 @@ def ROOT(args):
     return lib.data.get_default_directory(args, 'Hists')
-lib.data.create_dataset(__name__, PREGEX, ROOT, ARGS, defaults)
+lib.data.create_dataset(__name__, PREGEX, ROOT, ARGS)
 def var_name(name, var):