
Commit 250d140a authored by Clément Haëck

Adapt to xarray

Parallelisation not fully working yet
parent aa8c6a35
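
The scripts below drop the tomate pipeline in favour of plain xarray; per-image computations go through xr.apply_ufunc with dask='parallelized'. A minimal sketch of that pattern, with a hypothetical function and dataset (not code from this commit):

import numpy as np
import xarray as xr

def per_image(field):
    # operates on one (lat, lon) plane at a time
    return field - np.nanmean(field)

anom = xr.apply_ufunc(
    per_image,
    ds.analysed_sst.chunk({'time': 1}),  # one dask block per time step
    input_core_dims=[['lat', 'lon']],
    output_core_dims=[['lat', 'lon']],
    dask='parallelized',
    output_dtypes=[float],               # one dtype per output
)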
"""Compute stats. """
from os import path
import xarray as xr
import numpy as np
import pandas as pd
from scipy import ndimage
from lib import root_data, get_args
from lib.dask_client import make_client
import lib.zones
import lib.data.mask
import lib.data.sst
import lib.data.chl

def get_data():
    """Load low- and high-resolution datasets, aligned in time."""
    mask_lo = lib.data.mask.get_data(region, days, year, scale, number, chl=True)
    mask_lo = mask_lo.assign_coords(time=mask_lo.time.dt.floor("D"))
    chl = lib.data.chl.get_data(region, days, year)
    zone_lo = lib.zones.get_data(region, 'lo')
    ds_lo = xr.merge([mask_lo, chl, zone_lo], join='inner')

    mask_hi = lib.data.mask.get_data(region, days, year, scale, number, chl=False)
    sst = lib.data.sst.get_data(region, days, year)
    sst = sst.rename(mask='sst_mask')
    zone_hi = lib.zones.get_data(region, 'hi')
    ds_hi = xr.merge([mask_hi, sst, zone_hi], join='inner')
    ds_hi = ds_hi.assign_coords(time=ds_hi.time.dt.floor("D"))

    ds_lo = add_land('lo', ds_lo)
    ds_hi = add_land('hi', ds_hi)
    ds_lo, ds_hi = xr.align(ds_lo, ds_hi, join='inner',
                            exclude=['lat', 'lon'])
    ds_lo = to_pd_datetime(ds_lo)
    ds_hi = to_pd_datetime(ds_hi)
    return ds_lo, ds_hi

def get_stats(time):
    """Allocate an empty statistics dataset filled with NaN."""
    variables = ['mean', 'q10', 'q25', 'q50', 'q75', 'q90', 'n']
    coords = {
        'variable': ['analysed_sst', 'CHL'],
        'zone': zones_list,
        'mask': ['front', 'background'],
        'time': time
    }
    stats = xr.Dataset(
        data_vars={v: xr.DataArray(np.nan, coords, list(coords.keys()))
                   for v in variables},
        coords=coords)
    return stats
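
# `stats` is filled in place through stats[var].loc[...] by compute_stats
# below; 'n' holds the sample count, q10..q90 the distribution quantiles.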

def add_land(kind, ds):
    """Add the land mask to the dataset."""
    filename = path.join(root_data, 'land_mask_{}.nc'.format(kind))
    land = xr.open_dataset(filename)['land']
    land, _ = xr.align(land, ds, join='right')
    land = land.astype('bool')
    # Do not treat the Bermudas as land
    bermudes_lat = (32.20, 32.45)
    bermudes_lon = (-64.95, -64.60)
    land.loc[dict(lat=slice(*bermudes_lat),
                  lon=slice(*bermudes_lon))] = False
    ds['land'] = land
    return ds


def to_pd_datetime(ds):
    """Convert the time coordinate to pandas datetimes."""
    ds = ds.assign_coords(time=pd.to_datetime(ds.time.values))
    return ds

def extend_mask(mask, neighbors, repeat):
    """Dilate a binary mask with a disk-shaped kernel."""
    n = 2*neighbors + 1
    kernel = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            kernel[i, j] = (i-(n-1)/2)**2 + (j-(n-1)/2)**2 <= (n/2)**2
    for _ in range(repeat):
        mask = np.clip(ndimage.convolve(mask, kernel), 0, 1)
    return mask
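
# Example (hypothetical values): grow the front mask by a one-pixel disk,
# applied twice; each pass convolves with the kernel and clips back to [0, 1]:
#     front = extend_mask(front, neighbors=1, repeat=2)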

def compute_stats(stats, loc, da, mask):
    """Fill `stats` at `loc` with statistics of `da` where `mask` is set."""
    def add(var, func, *args):
        stats[var].loc[loc] = getattr(da.where(mask), func)(*args, dim=['lat', 'lon'])

    add('n', 'count')
    add('mean', 'mean')
    add('q10', 'quantile', 0.10)
    add('q25', 'quantile', 0.25)
    add('q50', 'quantile', 0.50)
    add('q75', 'quantile', 0.75)
    add('q90', 'quantile', 0.90)
    return stats

def compute(ds, var, stats):
    """Compute statistics of `var` for every mask and zone."""
    print(var)
    for m in ['front', 'background']:
        print(m)
        mask = ds[m] * ~ds['land']
        for zone in zones_list:
            print(zone)
            if zone == 'total':
                mask_ = mask
            else:
                mask_ = mask * ds[zone]
            loc = dict(zone=zone, mask=m, variable=var)
            stats = compute_stats(stats, loc, ds[var], mask_)

if __name__ == '__main__':
    # client = make_client()
    args = get_args(['region', 'days', 'year',
                     'scale', 'number', 'target'])
    region = str(args['region'])
    days = int(args['days'])
    year = int(args['year'])
    scale = float(args['scale'])
    number = int(args['number'])
    target = int(args['target'])

    zones_list = ['1', '4', 'total']

    ds_lo, ds_hi = get_data()
    stats = get_stats(ds_lo.time)
    compute(ds_lo, 'CHL', stats)
    compute(ds_hi, 'analysed_sst', stats)

    stats.to_netcdf(path.join(root_data, region,
                              'HI/HI_{:.1f}_{:d}'.format(scale, target),
                              'stats/{:d}days/{:d}'.format(days, year),
                              'stats.nc'))
#!/usr/bin/env bash
export PYTHONPATH="/home/chaeck/Fronts/:$PYTHONPATH"
. "$HOME/.setup-shell.sh"
conda activate py38
python "$HOME/Fronts/Compute/compute_deltas.py" \
    --kwargs region:"'GS'" days:8 year:2007 \
    scale:10. number:2
@@ -5,7 +5,6 @@ export PYTHONPATH="/home/chaeck/Fronts/:$PYTHONPATH"
 . "$HOME/.setup-shell.sh"
 conda activate py38
-python "$HOME/Fronts/Compute/compute_deltas_daily.py" \
 python "$HOME/Fronts/Compute/compute_deltas.py" \
     --kwargs region:"'GS'" days:8 year:2003 \
     scale:10. number:1 target:1 \
     level:4 step:1
"""Compute HI components."""
from lib import compute_hi, get_args
from os import path
import numpy as np
import xarray as xr
import lib.data.sst
import lib.data.hi
from lib import compute_hi_f, root_data, get_args
from lib.dask_client import make_client
from lib.xarray_utils import save_chunks_by_date

def get_scale(hi, scale):
    """Compute window half-sizes in pixels for a scale in km."""
    def d2s(d):
        """Turn a step in degrees into a number of pixels."""
        return int(np.ceil((scale_deg/d - 1) / 2))

    scale_deg = scale / 117.
    dx = hi.lon.diff('lon').mean()
    dy = hi.lat.diff('lat').mean()
    sx = d2s(dx)
    sy = d2s(dy)

    hi.attrs['scale_km'] = scale
    hi.attrs['scale_deg'] = scale_deg
    hi.attrs['sx'] = sx
    hi.attrs['sy'] = sy
    return sx, sy
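
# Worked example (hypothetical grid): for scale=10 km on a 0.01-degree grid,
# scale_deg = 10/117 ~= 0.0855 deg, so sx = ceil((0.0855/0.01 - 1) / 2) = 4,
# i.e. a 9-pixel-wide window in longitude.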

def get_components(sst, mask, sx=1, sy=1):
    """Call the compiled routine and mask invalid values."""
    values = compute_hi_f.compute_hi.get_hi(sst, mask, sx, sy)
    values[~np.isfinite(values)] = -1.
    values[values < 0] = np.nan
    return values[0], values[1], values[2]

if __name__ == '__main__':
    client = make_client(n_workers=1)

    args = get_args(['region', 'days', 'year', 'scale', 'number'])
    region = str(args['region'])
    days = int(args['days'])
    year = int(args['year'])
    scale = float(args['scale'])
    number = int(args['number'])

    sst = lib.data.sst.get_data(region, days, year)
    # Template dataset holding only coordinates and scale attributes
    hi = sst.drop_vars(list(sst.data_vars))
    hi.attrs = {}
    sx, sy = get_scale(hi, scale)
    ## Compute
    values = xr.apply_ufunc(get_components,
                            sst.analysed_sst, sst.analysed_sst.isnull(),
                            kwargs=dict(sx=sx, sy=sy),
                            dask='parallelized',
                            input_core_dims=[['lat', 'lon'],
                                             ['lat', 'lon']],
                            output_core_dims=[['lat', 'lon'],
                                              ['lat', 'lon'],
                                              ['lat', 'lon']],
                            output_dtypes=[float, float, float])
    hi['S'] = values[0]
    hi['V'] = values[1]
    hi['B'] = values[2]
    # Trim the border pixels where the sliding window is incomplete
    hi = hi.isel(lon=slice(sx, -sx), lat=slice(sy, -sy))
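    # Assumption: with dask='parallelized', lat and lon are core dimensions
    # and must each fit in a single chunk, so sst would be chunked along time
    # only, e.g. sst = sst.chunk({'time': 1, 'lat': -1, 'lon': -1}).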
    filename = path.join(
        root_data,
        '{:s}/HI/HI_{:.1f}_{:d}/HI/{:d}days/{:d}'.format(
            region, scale, number, days, year),
        'HI_{:s}.nc'
    )
    encoding = {'_all': {'zlib': True}}
    save_chunks_by_date(hi, filename=filename,
                        encoding=encoding)

    client.close()
@@ -6,5 +6,5 @@ export PYTHONPATH="/home/chaeck/Fronts/:$PYTHONPATH"
 conda activate py38
 python "$HOME/Fronts/Compute/compute_hi.py" \
-    --kwargs region:"'GS'" days:8 year:2003 \
-    scale:10. number:1
+    --kwargs region:"'GS'" days:1 year:2007 \
+    scale:10. number:2
"""Compute masks.
Week by week, find H-Index threshold for
impacted and non-impacted areas.
"""
"""Compute masks. """
from os import path
import xarray as xr
import numpy as np
from matplotlib import pyplot as plt
from dask.distributed import get_task_stream
import tomate.var_types as vt
from lib.data.hi import get_data
import lib.data.hi
from lib import root_data, get_args

## Compute
def plot_threshold(hi, threshold_min, threshold_max):
    """Plot the HI histogram with the two thresholds marked."""
    fig, ax = plt.subplots(figsize=(7, 5))
    hist, bins = np.histogram(hi.to_masked_array().compressed(),
                              bins=100, density=True)
    ax.plot(bins[:-1], hist, ds='steps-post', color='k')
    ax.set_ylim(0, None)
    ax.axvspan(bins[0], threshold_min, color='royalblue', alpha=.3)
    ax.axvspan(threshold_max, bins[-1], color='red', alpha=.3)

    s_date = hi.time[0].dt.strftime('%F').values
    fig.savefig(path.join(wd, f'thresholds_{s_date}.png'), dpi=100)
    plt.close(fig)
    return hi
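
# plot_threshold returns its input unchanged, so it can be mapped over dask
# blocks purely for its plotting side effect (see the commented-out
# map_blocks call below).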

if __name__ == '__main__':
    client = make_client()
    with get_task_stream(plot='save', filename='/home/chaeck/task-stream.html') as ts:
        plt.style.use('common')
        plt.switch_backend('agg')

        args = get_args(['region', 'days', 'year',
                         'scale', 'number', 'var',
                         'threshold_min', 'threshold_max'])
        print('Parameters: ', args)
        region = str(args['region'])
        days = int(args['days'])
        year = int(args['year'])
        scale = float(args['scale'])
        number = int(args['number'])
        var = str(args['var'])
        threshold_min = float(args['threshold_min'])
        threshold_max = float(args['threshold_max'])

        wd = path.join(root_data,
                       region,
                       'HI/HI_{:.1f}_{:d}/masks'.format(scale, number),
                       '{:d}days/{:d}'.format(days, year))
        ds = lib.data.hi.get_data(region, days, year, scale, number)
        # Record missing data before filling, so that front and background
        # can be forced to False there below.
        mask = ds[var].isnull()
        hi = ds[var].fillna(0.)
        ds = ds.drop_vars(ds.data_vars)

        # hi = hi.map_blocks(plot_threshold, args=[threshold_min, threshold_max],
        #                    template=hi)

        ds['front'] = hi > threshold_max
        ds['background'] = hi < threshold_min
        ds['mask'] = mask
        ds['front'] = ds.front.where(~ds.mask, False)
        ds['background'] = ds.background.where(~ds.mask, False)

        save_chunks_by_date(ds, path.join(wd, 'masks_{:s}.nc'),
                            encoding={'_all': {'zlib': True}})
    client.profile(filename='profile.html')
    client.close()
@@ -6,6 +6,6 @@ export PYTHONPATH="/home/chaeck/Fronts/:$PYTHONPATH"
 conda activate py38
 python "$HOME/Fronts/Compute/compute_masks.py" \
-    --kwargs region:"'GS'" days:8 year:2018 \
-    scale:10. number:1 \
+    --kwargs region:"'GS'" days:1 year:2007 \
+    scale:10. number:2 \
     var:"'V'" threshold_min:0.06 threshold_max:0.13
from os import path

import numpy as np
import xarray as xr
import pandas as pd
from dask.distributed import Client

import lib.data.pig
import lib.data.chl
from lib import root_data, get_args
from lib.xarray_utils import split_by_chunks, save_mfdataset
args = get_args(['region', 'days', 'year'])
region = str(args['region'])
days = int(args['days'])
year = int(args['year'])

def compute_pft(ds):
    """Compute phytoplankton functional type fractions from pigments."""
    pft = ds.drop_vars(list(ds.data_vars))

    # Weighted sum of diagnostic pigments
    tot = (1.41*ds['Fuco'] + 1.41*ds['Perid']
           + 1.27*ds['19HF'] + 0.35*ds['19BF']
           + 1.01*ds['Chlb'] + 0.60*ds['Allo'] + 0.86*ds['Zea'])

    pft['dtom'] = 1.41 * ds['Fuco'] / tot
    pft['dinflg'] = 1.41 * ds['Perid'] / tot
    pft['galgae'] = 1.01 * ds['Chlb'] / tot
    pft['prokrt'] = 0.86 * ds['Zea'] / tot
    pft['prchlcus'] = 0.74 * ds['DVChla'] / ds['Chla_SOM']

    X = (ds['CHL']/0.08).clip(0, 1)
    pft['prymnsio'] = (1.27*X*ds['19HF'] + .35*ds['19BF'] + .6*ds['Allo']) / tot
    pft['picoeuk'] = 1.27*(1-X)*ds['19HF'] / tot
    pft['pico'] = pft['picoeuk'] + pft['prokrt']

    for var in pft.data_vars:
        pft[var] = pft[var].clip(0, 1)
    return pft
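
# The fractions follow a diagnostic-pigment weighting: each type's share is
# its weighted pigment divided by `tot`, clipped to [0, 1]; X blends the
# 19'HF pigment between prymnesiophytes (high CHL) and picoeukaryotes (low CHL).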

if __name__ == '__main__':
    client = Client()

    pig = lib.data.pig.get_data(region, days, year)
    chl = lib.data.chl.get_data(region, days, year)
    pig = pig.assign_coords(time=pig.time.dt.floor("D"))
    ds = xr.merge([pig, chl], join='inner')
    ds = ds.assign_coords(time=pd.to_datetime(ds.time.values))

    pft = compute_pft(ds)

    # Write one file per time chunk
    datasets = list(split_by_chunks(pft))
    dates = [d.isel(time=0).time.dt.strftime('%F').values
             for d in datasets]
    filenames = [path.join(root_data,
                           '{:s}/SOM/PFT/{:d}days/{:d}/pft_{:s}.nc'.format(
                               region, days, year, d))
                 for d in dates]
    encoding = {var: dict(zlib=True) for var in pft.data_vars}
    save_mfdataset(datasets, filenames, encoding=encoding)
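
For reference, a plausible reimplementation of the split_by_chunks helper used above (hypothetical; the actual helper lives in lib.xarray_utils and may differ):

import xarray as xr

def split_by_chunks(ds):
    """Yield one sub-dataset per dask chunk along the time dimension."""
    start = 0
    for size in ds.chunks['time']:
        yield ds.isel(time=slice(start, start + size))
        start += size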
@@ -2,65 +2,24 @@
from os import path

import xarray as xr
import numpy as np

import lib.data.chl
import lib.data.mask
from lib import root_data, get_args
from lib.dask_client import make_client
from lib.xarray_utils import save_chunks_by_date
from lib.downsample_f import downsample
args = get_args(['region', 'days', 'year', 'scale', 'number'])
region = str(args['region'])
days = int(args['days'])
year = int(args['year'])
scale = float(args['scale'])
number = int(args['number'])
db_chl = lib.data.chl.get_data(region, days, year)
db_mask = lib.data.mask.get_data(region, days, year, scale, number)
wd = path.join(root_data,
               region,
               'HI/HI_{:.1f}_{:d}/masks_chl'.format(scale, number),
               '{:d}days/{:d}'.format(days, year))

def get_db_mask_chl(db_mask, db_chl):
    db_chl.select_by_value(lat=slice(*db_mask.selected.lat.get_extent()),
                           lon=slice