Skip to content
Snippets Groups Projects
Commit 26b8f754 authored by Wolfgang Kastaun's avatar Wolfgang Kastaun
Browse files

Added dask version of shift inversion

parent 970ea945
No related branches found
No related tags found
No related merge requests found
"""Functions for applying dynamic real-valued shifts to dask arrays using Lagrange interpolation
Use make_dynamic_shift_lagrange_dask to create a Lagrange interpolator for dask arrays.
"""
from __future__ import annotations
from typing import Final
import dask
import dask.array as da
import numpy as np
from lisainstrument.fir_filters_dask import DaskArray1D, make_dask_array_1d
from lisainstrument.fir_filters_numpy import NumpyArray1D, make_numpy_array_1d
from lisainstrument.regular_interpolators import (
RegularInterpolator,
make_regular_interpolator_lagrange,
)
from lisainstrument.shift_inversion_numpy import fixed_point_iter
class ShiftInverseDask:
"""Invert coordinate transformation given as shift"""
def __init__(
self,
max_abs_shift: float,
interp: RegularInterpolator,
max_iter: int,
tolerance: float,
):
"""Set up interpolator.
Arguments:
max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
interp: Interpolation method
max_iter: Maximum iterations before fail
tolerance: Maximum absolute error of result
"""
self._max_abs_shift: Final = int(np.ceil(max_abs_shift))
if self._max_abs_shift < 0:
msg = (
f"ShiftInverseDask: max_abs_shift must be positive, got {max_abs_shift}"
)
raise ValueError(msg)
self._interp_np: Final = interp
self._max_iter = int(max_iter)
self._tolerance = float(tolerance)
@property
def margin_left(self) -> int:
"""Left margin size.
Specifies how many samples on the left have to be added by boundary conditions.
"""
return self._interp_np.margin_left + self._max_abs_shift
@property
def margin_right(self) -> int:
"""Right margin size.
Specifies how many samples on the right have to be added by boundary conditions.
"""
return self._interp_np.margin_right + self._max_abs_shift
def _fixed_point_iter(self, dx_pad: np.ndarray, dx: np.ndarray) -> NumpyArray1D:
def f_iter(x: NumpyArray1D) -> NumpyArray1D:
return self._interp_np.apply_shift(
make_numpy_array_1d(dx_pad), -x, self.margin_left
)
def f_err(x1: NumpyArray1D, x2: NumpyArray1D) -> float:
return np.max(np.abs(x1 - x2)[self.margin_left : -self.margin_right])
return fixed_point_iter(
f_iter, f_err, make_numpy_array_1d(dx), self._tolerance, self._max_iter
)
def __call__(self, shift: da.Array, fsample: float) -> DaskArray1D:
"""Find the shift for the inverse transformation given by a shift.
Arguments:
shift: 1D dask array with shifts of the coordinate transform
fsample: sample rate of shift array
Returns:
1D dask array with shift at transformed coordinate
"""
shift_idx = shift * fsample
make_dask_array_1d(shift_idx)
dx_pad = da.pad(
shift_idx,
(self.margin_left, self.margin_right),
mode="edge",
# ~ constant_values=(shift_idx[0], shift_idx[-1]),
)
results = []
chunks = shift_idx.to_delayed()
delayed_op = dask.delayed(self._fixed_point_iter)
pos = 0
for chunk, chunk_shape in zip(chunks, shift.chunks[0], strict=True):
n_size = self.margin_left + chunk_shape + self.margin_right
n_first = pos
samples_needed = dx_pad[n_first : n_first + n_size]
samples_shifted = delayed_op(samples_needed, chunk)
delayed_chunk = da.from_delayed(
samples_shifted, (chunk_shape,), shift_idx.dtype
)
results.append(delayed_chunk)
pos += chunk_shape
shift_idx_inv = da.concatenate(results, axis=0)
shift_inv = shift_idx_inv / fsample
return make_dask_array_1d(shift_inv)
def make_shift_inverse_lagrange_dask(
order: int,
max_abs_shift: float,
max_iter: int,
tolerance: float,
) -> ShiftInverseDask:
"""Set up ShiftInverseDask instance with Lagrange interpolation method.
Arguments:
order: Order of the Lagrange polynomials
max_abs_shift: Upper limit for absolute difference between coordinate frames w.r.t index space
max_iter: Maximum iterations before fail
tolerance: Maximum absolute error of result
Returns:
Inversion function of type ShiftInverseNumpy
"""
interp = make_regular_interpolator_lagrange(order)
return ShiftInverseDask(max_abs_shift, interp, max_iter, tolerance)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment