"""Functions for applying fixed real-valued shifts to dask arrays using Lagrange interpolation Use make_fixed_shift_lagrange_dask to create a Lagrange interpolator for dask arrays. """ from __future__ import annotations from typing import Final import dask import dask.array as da import numpy as np from typing_extensions import assert_never from lisainstrument.dynamic_delay_numpy import DynShiftBC from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d from lisainstrument.fixed_shift_numpy import FixedShiftFactory, FixedShiftLagrange class FixedShiftDask: # pylint: disable=too-few-public-methods """Interpolate Dask arrays to locations given by a const shift. This allows to interpolate samples in a Dask array to locations specified by a fixed shift. The shift is specified in units of the array index, i.e. there is no separate coordinate array. A positive shift refers to values left of a given sample, negative shifts to values on the right. The boundary treatment can be specified for each boundary in terms of DynShiftBC enums. The interpolation method is not fixed but provided via an interpolator instance implementing the RegularInterpMethod protocol. """ def __init__( self, left_bound: DynShiftBC, right_bound: DynShiftBC, interp_fac: FixedShiftFactory, ): """Not intended for direct use, employ named constructors instead Arguments: left_bound: boundary treatment on the left right_bound: boundary treatment on the right interp_fac: Function to create interpolator engine for given shift """ self._left_bound: Final = left_bound self._right_bound: Final = right_bound self._interp_fac: Final = interp_fac def _padding_left(self, samples: da.Array, margin_left: int) -> da.Array: if margin_left > 0: match self._left_bound: case DynShiftBC.ZEROPAD: samples = da.concatenate([da.zeros(margin_left), samples]) case DynShiftBC.FLAT: samples = da.concatenate( [da.ones(margin_left) * samples[0], samples] ) case DynShiftBC.EXCEPTION: msg = ( f"FixedShiftDask: left edge handling {self._left_bound.name} " f"impossible for given delay, would need margin of {margin_left}." ) raise RuntimeError(msg) case _ as unreachable: assert_never(unreachable) elif margin_left < 0: samples = samples[-margin_left:] return samples def _padding_right(self, samples: da.Array, margin_right: int) -> da.Array: if margin_right > 0: match self._right_bound: case DynShiftBC.ZEROPAD: samples = da.concatenate([samples, da.zeros(margin_right)]) case DynShiftBC.FLAT: samples = da.concatenate( [samples, da.ones(margin_right) * samples[-1]] ) case DynShiftBC.EXCEPTION: msg = ( f"FixedShiftDask: right edge handling {self._right_bound.name} " f"impossible for given delay, would need margin of {margin_right}." ) raise RuntimeError(msg) case _ as unreachable: assert_never(unreachable) elif margin_right < 0: samples = samples[:margin_right] return samples def __call__(self, samples: da.Array, shift: float) -> DaskArray1D: r"""Apply shift $s$ to samples in Dask array. Denoting the input data as $y_i$ with $i=0 \ldots N-1$, and the interpolated input data as $y(t)$, such that $y(i)=y_i$, the output $z_k$ is given by $z_k = y(k-s), k=0 \ ldots N - 1$. Required data outside the provided samples is created if the specified boundary condition allows it, or an exception is raised if the BC disallows it. The output has same length and chunking as the input. Arguments: samples: 1D Dask array with data samples shift: The shift $s$ Returns: Dask array with interpolated samples """ loc = -shift loc_int = int(np.floor(loc)) loc_frac = loc - loc_int interp = self._interp_fac(-loc_frac) margin_left = interp.margin_left - loc_int margin_right = interp.margin_right + loc_int samples_checked = make_dask_array_1d(samples) samples_padleft = self._padding_left(samples_checked, margin_left) samples_padded = self._padding_right(samples_padleft, margin_right) delayed_op = dask.delayed(interp.apply) results = [] pos = 0 for chunk_shape in samples.chunks[0]: n_size = interp.margin_left + chunk_shape + interp.margin_right samples_shifted = delayed_op(samples_padded[pos : pos + n_size]) delayed_chunk = da.from_delayed( samples_shifted, (chunk_shape,), samples.dtype ) results.append(delayed_chunk) pos += chunk_shape return make_dask_array_1d(da.concatenate(results, axis=0)) def make_fixed_shift_lagrange_dask( left_bound: DynShiftBC, right_bound: DynShiftBC, length: int ) -> FixedShiftDask: """Create a FixedShiftDask instance that uses Lagrange interpolator Arguments: left_bound: boundary treatment on the left right_bound: boundary treatment on the right length: number of Lagrange plolynomials (=order + 1) Returns: Fixed shift interpolator """ fac = FixedShiftLagrange.factory(length) return FixedShiftDask(left_bound, right_bound, fac)