From 9478ea8e4e88ddc2b9e040921c91cadd5a08755f Mon Sep 17 00:00:00 2001
From: Wolfgang Kastaun <wolfgang.kastaun@aei.mpg.de>
Date: Fri, 13 Dec 2024 15:23:43 +0100
Subject: [PATCH] Pulled fir filter modules from firfilters branch

---
 lisainstrument/fir_filters_dask.py  | 201 +++++++++++++++++
 lisainstrument/fir_filters_numpy.py | 336 ++++++++++++++++++++++++++++
 2 files changed, 537 insertions(+)
 create mode 100644 lisainstrument/fir_filters_dask.py
 create mode 100644 lisainstrument/fir_filters_numpy.py

diff --git a/lisainstrument/fir_filters_dask.py b/lisainstrument/fir_filters_dask.py
new file mode 100644
index 0000000..f751727
--- /dev/null
+++ b/lisainstrument/fir_filters_dask.py
@@ -0,0 +1,201 @@
+"""Functions for applying FIR filters to dask arrays.
+
+To create a fir filter operating on dask arrays, use the `make_filter_fir_dask` function.
+"""
+
+from __future__ import annotations
+
+from typing import Callable, NewType, TypeAlias
+
+import dask
+import dask.array as da
+import numpy as np
+
+from lisainstrument.fir_filters_numpy import (
+    DefFilterFIR,
+    EdgeHandling,
+    FIRCoreOp,
+    NumpyArray1D,
+    SemiLocalMapType,
+    make_numpy_array_1d,
+    semilocal_map_numpy,
+)
+
+DaskArray1D = NewType("DaskArray1D", da.Array)
+
+
+def make_dask_array_1d(x: da.Array) -> DaskArray1D:
+    """Check that dask array is 1D or raise exception
+
+    This improves static type checking by promoting array dimensionality to a type.
+    """
+    if (dims := len(x.shape)) != 1:
+        msg = f"Expected dask array with one dimension, got {dims}"
+        raise ValueError(msg)
+    return DaskArray1D(x)
+
+
+def merge_small_chunks(sizes: tuple[int, ...], threshold: int) -> tuple[int, ...]:
+    """Recursively merge chunks below threshold size with neighbors
+
+    Arguments:
+        sizes: list of chunk sizes
+        threshold: minimum required chunk size
+
+    Returns:
+        List of chunk sizes after recursive merge
+    """
+
+    if len(sizes) < 2:
+        return sizes
+    sz = list(sizes)
+
+    for i in range(len(sz) - 1):
+        if max(sz[i], sz[i + 1]) < threshold:
+            m = sz[:i] + [sz[i] + sz[i + 1]] + sz[i + 2 :]
+            return merge_small_chunks(tuple(m), threshold)
+    for i in range(len(sz) - 1):
+        if min(sz[i], sz[i + 1]) < threshold:
+            m = sz[:i] + [sz[i] + sz[i + 1]] + sz[i + 2 :]
+            return merge_small_chunks(tuple(m), threshold)
+
+    return sizes
+
+
+def semilocal_map_dask(
+    op: SemiLocalMapType,
+    bound_left: EdgeHandling,
+    bound_right: EdgeHandling,
+    data: DaskArray1D,
+) -> DaskArray1D:
+    """Apply a semi local map to dask array and employ boundary conditions.
+
+    Arguments:
+        op: the semi local mapping
+        bound_left: Boundary treatment on left side
+        bound_right: Boundary treatment on right side
+        data: the 1D array to be mapped
+
+    Returns:
+        The mapped data. The size is the same as the input if both boundary
+        conditions are ZEROPAD. A boundary condition VALID reduces the output
+        size by the corresponding margin size of the semilocal map.
+    """
+
+    if op.margin_left < 0 or op.margin_right < 0:
+        msg = (
+            f"semilocal_map_dask: mappings with negative"
+            f" margin not supported (got left={op.margin_left}, right={op.margin_right})"
+        )
+        raise RuntimeError(msg)
+
+    margin = max(op.margin_left, op.margin_right)
+
+    def whole_op(d: np.ndarray) -> np.ndarray:
+        return semilocal_map_numpy(op, bound_left, bound_right, make_numpy_array_1d(d))
+
+    if data.size < margin:
+        delay_op = dask.delayed(whole_op)
+        nsize = data.size
+        if bound_left == EdgeHandling.VALID:
+            nsize -= op.margin_left
+        if bound_right == EdgeHandling.VALID:
+            nsize -= op.margin_right
+
+        res = da.from_delayed(delay_op(data), (nsize,), data.dtype)
+        return make_dask_array_1d(res)
+
+    def ext_op(d: np.ndarray) -> np.ndarray:
+        r: np.ndarray = op(make_numpy_array_1d(d))
+        if op.margin_left < margin:
+            r = r[margin - op.margin_left :]
+        if op.margin_right < margin:
+            r = r[: -(margin - op.margin_right)]
+        return r
+
+    if min(data.chunks[0]) < margin:
+        newsizes = merge_small_chunks(data.chunks[0], margin)
+        data = data.rechunk(chunks={0: newsizes})
+
+    ext = da.map_overlap(ext_op, data, depth={0: margin}, boundary=0, trim=False)
+
+    if bound_left == EdgeHandling.VALID and op.margin_left > 0:
+        ext = ext[op.margin_left :]
+    if bound_right == EdgeHandling.VALID and op.margin_right > 0:
+        ext = ext[: -op.margin_right]
+
+    return make_dask_array_1d(ext)
+
+
+FilterFirDaskType: TypeAlias = Callable[[da.Array], DaskArray1D]
+
+
+def filter_fir_dask(
+    fdef: DefFilterFIR,
+    bound_left: EdgeHandling,
+    bound_right: EdgeHandling,
+    data: da.Array,
+) -> DaskArray1D:
+    """Apply FIR filter to 1D dask array
+
+    Arguments:
+        fdef: The definition of the FIR filter
+        bound_left: Boundary treatment on left side
+        bound_right: Boundary treatment on right side
+        data: The 1D dask array to be filtered
+
+    Returns:
+        Filtered dask array
+
+    """
+    fmap = FIRCoreOp(fdef)
+    return semilocal_map_dask(fmap, bound_left, bound_right, make_dask_array_1d(data))
+
+
+def make_filter_fir_dask(
+    fdef: DefFilterFIR, bound_left: EdgeHandling, bound_right: EdgeHandling
+) -> FilterFirDaskType:
+    """Create a function that applies a given FIR filter to dask arrays,
+    employing the specified boundary treatment.
+
+    Arguments:
+        fdef: The definition of the FIR filter
+        bound_left: Boundary treatment on left side
+        bound_right: Boundary treatment on right side
+
+    Returns:
+        Function which accepts a single 1D dask array as input and returns
+        the filtered dask array.
+    """
+
+    fmap = FIRCoreOp(fdef)
+
+    def op(data: da.Array) -> DaskArray1D:
+        return semilocal_map_dask(
+            fmap, bound_left, bound_right, make_dask_array_1d(data)
+        )
+
+    return op
+
+
+def numpyfy_dask(
+    dafunc: Callable[[da.Array], DaskArray1D], chunks: int
+) -> Callable[[np.ndarray], NumpyArray1D]:
+    """Convert function operating on dask arrays to work with numpy arrays.
+
+    The dask function will be called with a chunked dask array created from
+    the numpy array, and then evaluate the result.
+
+    Arguments:
+        dafunc: function operating on a dask array
+        chunks: chunk size to be used internally
+    Returns:
+        function operating on numpy arrays
+    """
+
+    def func(x):
+        d = da.from_array(x, chunks=chunks)
+        r = dafunc(make_dask_array_1d(d)).compute()
+        return make_numpy_array_1d(r)
+
+    return func
diff --git a/lisainstrument/fir_filters_numpy.py b/lisainstrument/fir_filters_numpy.py
new file mode 100644
index 0000000..03ce169
--- /dev/null
+++ b/lisainstrument/fir_filters_numpy.py
@@ -0,0 +1,336 @@
+"""Functions for applying FIR filters to numpy arrays
+
+To create a fir filter operating on numpy arrays, use the `make_filter_fir_numpy` function.
+"""
+
+from __future__ import annotations
+
+from enum import Enum
+from typing import TYPE_CHECKING, Callable, NewType, Protocol, TypeAlias
+
+import numpy as np
+from attrs import field, frozen
+from scipy.signal import convolve, firwin, kaiserord
+
+if TYPE_CHECKING:
+    from numpy.typing import ArrayLike
+
+NumpyArray1D = NewType("NumpyArray1D", np.ndarray)
+
+
+def make_numpy_array_1d(x: np.ndarray) -> NumpyArray1D:
+    """Check that numpy array is 1D or raise exception
+
+    This improves static type checking by promoting array dimensionality to a type.
+    """
+    if (dims := len(x.shape)) != 1:
+        msg = f"Expected numpy array with one dimension, got {dims}"
+        raise ValueError(msg)
+    return NumpyArray1D(x)
+
+
+@frozen
+class DefFilterFIR:
+    r"""This dataclass defines a FIR filter
+
+    The finite impulse response filter is given by
+
+    $$
+    y_a = \sum_{i=0}^{L-1} K_i x_{a + i + D}
+        = \sum_{k=a+D}^{a+D+L-1} x_{k} K_{k-a-D}
+    $$
+
+    Note that there are different conventions for the order of coefficients.
+    In standard convolution notation, the convolution kernel is given by the
+    coefficients above in reversed order.
+
+
+    Attributes:
+        filter_coeffs: Filter coefficients $K_i$
+        offset: Offset $D$
+    """
+
+    filter_coeffs: list[float] = field(converter=lambda lst: [float(e) for e in lst])
+
+    offset: int = field()
+
+    @filter_coeffs.validator
+    def check(self, _, value):
+        """Validate filter coefficients"""
+        if len(value) < 1:
+            msg = "FIR filter coefficients array needs at least one entry"
+            raise ValueError(msg)
+
+    @property
+    def gain(self) -> float:
+        r"""The gain factor for a constant signal
+
+        The gain factor is defined by
+        $$
+        \sum_{i=0}^{L-1} K_i
+        $$
+        """
+        return sum(self.filter_coeffs)
+
+    @property
+    def length(self) -> int:
+        """The length of the domain of dependence of a given output sample
+        on the input samples. This does not take into account zeros anywhere
+        in the coefficients. Thus the result is simply the number of
+        coefficients.
+        """
+        return len(self.filter_coeffs)
+
+    @property
+    def domain_of_dependence(self) -> tuple[int, int]:
+        r"""The domain of dependence
+
+        A point with index $a$ in the output sequence depends on
+        indices $a+D \ldots a+D+L-1$. This property provides the
+        domain of dependence for $a=0$.
+        """
+
+        return (self.offset, self.length - 1 + self.offset)
+
+    def __str__(self) -> str:
+        return (
+            f"{self.__class__.__name__} \n"
+            f"Length           = {self.length} \n"
+            f"Offset           = {self.offset} \n"
+            f"Gain             = {self.gain} \n"
+            f"Coefficients     = {self.filter_coeffs} \n"
+        )
+
+
+def make_fir_causal_kaiser(
+    fsamp: float, attenuation: float, freq1: float, freq2: float
+) -> DefFilterFIR:
+    """Create FIR filter definition for Kaiser window with given attenuation and transition band
+
+    This creates FIR coefficients from a Kaiser window specified by the
+    desired properties of attenuation and transition band. The filter offset
+    is set to describe a completely causal filter.
+
+    Arguments:
+        fsamp: Sampling rate [Hz]. Must be strictly positive.
+        attenuation: Required stop-band attenuation [dB]. Has to be greater zero.
+        freq1: Start of transition band [Hz]
+        freq2: End of transition band / start of stop band [Hz]. Has to be below Nyquist frequency.
+
+    Returns:
+        The FIR definition
+    """
+
+    if fsamp <= 0:
+        msg = f"make_fir_causal_kaiser: sample rate must be greater zero, got {fsamp=}"
+        raise ValueError(msg)
+
+    if freq2 <= freq1:
+        msg = f"make_fir_causal_kaiser: end of transition band cannot be below beginning ({freq1=}, {freq2=})"
+        raise ValueError(msg)
+
+    if freq1 <= 0:
+        msg = f"make_fir_causal_kaiser: start of transition band has to be greater zero, got {freq1=}"
+        raise ValueError(msg)
+
+    if freq2 >= fsamp / 2:
+        msg = f"make_fir_causal_kaiser: end of transition band has to be below Nyquist, got {freq1=}"
+        raise ValueError(msg)
+
+    if attenuation <= 0:
+        msg = f"make_fir_causal_kaiser: attenuation must be strictly positive, got {attenuation}"
+        raise ValueError(msg)
+
+    nyquist = fsamp / 2
+    numtaps, beta = kaiserord(attenuation, (freq2 - freq1) / nyquist)
+    taps = firwin(numtaps, (freq1 + freq2) / fsamp, window=("kaiser", beta))
+
+    return DefFilterFIR(filter_coeffs=taps, offset=1 - len(taps))
+
+
+def make_fir_causal_normal_from_coeffs(coeffs: ArrayLike) -> DefFilterFIR:
+    """Create causal, unity-gain FIR filter definition from coefficients.
+
+    This creates FIR coefficients from FIR coefficients. The latter are
+    normalized first such that the filter has unity gain. The filter offset
+    is set to describe a completely causal filter.
+
+    Arguments:
+        coeffs: filter coefficients
+
+    Returns:
+        The FIR definition
+    """
+    coeffs = np.array(coeffs)
+    norm = np.sum(coeffs)
+    if norm == 0:
+        msg = "make_fir_causal_normal_from_coeffs: cannot normalize zero-gain coefficients"
+        raise ValueError(msg)
+    normed_coeffs = coeffs / norm
+    return DefFilterFIR(filter_coeffs=normed_coeffs, offset=1 - len(coeffs))
+
+
+class SemiLocalMapType(Protocol):
+    """Protocol for semi-local maps of 1D numpy arrays
+
+    This is used to describe array operations which require boundary points
+    """
+
+    @property
+    def margin_left(self) -> int:
+        """How many points at the left boundary are missing in the output"""
+
+    @property
+    def margin_right(self) -> int:
+        """How many points at the right boundary are missing in the output"""
+
+    def __call__(self, data_in: NumpyArray1D) -> NumpyArray1D:
+        """Apply the array operation"""
+
+
+class FIRCoreOp(SemiLocalMapType):
+    """Function class applying FIR to numpy array
+
+    This does not include boundary tratment and only returns valid
+    points. It does provide the margin sizes corresponding to invalid
+    boundary points on each side.
+
+    """
+
+    def __init__(self, fdef: DefFilterFIR):
+        self._margin_left = -fdef.domain_of_dependence[0]
+        self._margin_right = fdef.domain_of_dependence[1]
+        self._convolution_kernel = np.array(fdef.filter_coeffs[::-1], dtype=np.float64)
+
+        if self._margin_left < 0 or self._margin_right < 0:
+            msg = f"FilterFirNumpy instantiated with unsupported domain of dependence {fdef.domain_of_dependence}"
+            raise ValueError(msg)
+
+    @property
+    def margin_left(self) -> int:
+        """How many points at the left boundary are missing in the output"""
+        return self._margin_left
+
+    @property
+    def margin_right(self) -> int:
+        """How many points at the right boundary are missing in the output"""
+        return self._margin_right
+
+    def __call__(self, data_in: NumpyArray1D) -> NumpyArray1D:
+        """Apply FIR filter using convolution
+
+        Only valid points are returned, i.e. points for which the filter stencil fully
+        overlaps with the data. No zero padding or similar is applied.
+
+        Arguments:
+            data_in: 1D numpy array to be filtered
+
+        Returns:
+            Filtered array. Its size is smaller than the input by `m̀argin_left+margin_right`.
+
+        """
+        return convolve(data_in, self._convolution_kernel, mode="valid")
+
+
+class EdgeHandling(Enum):
+    """Enum for various methods of handling boundaries in filters.
+
+    VALID:   Use only valid points that can be computed from the given data,
+             without zero padding or similar
+    ZEROPAD: Compute output for every input point, pad input with suitable
+             number of zeros before applying filter.
+    """
+
+    VALID = 1
+    ZEROPAD = 2
+
+
+def semilocal_map_numpy(
+    op: SemiLocalMapType,
+    bound_left: EdgeHandling,
+    bound_right: EdgeHandling,
+    data: NumpyArray1D,
+) -> NumpyArray1D:
+    """Apply a semi local map to numpy array and employ boundary conditions.
+
+    Arguments:
+        op: the semi local mapping
+        bound_left: Boundary treatment on left side
+        bound_right: Boundary treatment on right side
+        data: the 1D array to be mapped
+
+    Returns:
+        The mapped data. The size is the same as the input if both boundary
+        conditions are ZEROPAD. A boundary condition VALID reduces the output
+        size by the corresponding margin size of the semilocal map.
+    """
+
+    if op.margin_left < 0 or op.margin_right < 0:
+        msg = (
+            f"semilocal_map_numpy: mappings with negative"
+            f" margin not supported (got left={op.margin_left}, right={op.margin_right})"
+        )
+        raise RuntimeError(msg)
+
+    pad_left = op.margin_left if bound_left == EdgeHandling.ZEROPAD else 0
+    pad_right = op.margin_right if bound_right == EdgeHandling.ZEROPAD else 0
+    if pad_left != 0 or pad_right != 0:
+        data = make_numpy_array_1d(
+            np.pad(
+                data,
+                pad_width=(pad_left, pad_right),
+                mode="constant",
+                constant_values=0,
+            )
+        )
+    return op(data)
+
+
+FilterFirNumpyType: TypeAlias = Callable[[np.ndarray], NumpyArray1D]
+
+
+def filter_fir_numpy(
+    fdef: DefFilterFIR,
+    bound_left: EdgeHandling,
+    bound_right: EdgeHandling,
+    data: np.ndarray,
+) -> NumpyArray1D:
+    """Apply FIR filter to 1D numpy array
+
+    Arguments:
+        fdef: The definition of the FIR filter
+        bound_left: Boundary treatment on left side
+        bound_right: Boundary treatment on right side
+        data: The 1D numpy array to be filtered
+
+    Returns:
+        Filtered 1D numpy array.
+    """
+    fmap = FIRCoreOp(fdef)
+    return semilocal_map_numpy(fmap, bound_left, bound_right, make_numpy_array_1d(data))
+
+
+def make_filter_fir_numpy(
+    fdef: DefFilterFIR, bound_left: EdgeHandling, bound_right: EdgeHandling
+) -> FilterFirNumpyType:
+    """Create a function that applies a given FIR filter to numpy arrays,
+    employing the specified boundary treatment.
+
+    Arguments:
+        fdef: The definition of the FIR filter
+        bound_left: Boundary treatment on left side
+        bound_right: Boundary treatment on right side
+
+    Returns:
+        Function which accepts a single 1D numpy array as input and returns
+        the filtered array.
+    """
+
+    fmap = FIRCoreOp(fdef)
+
+    def op(data: np.ndarray) -> NumpyArray1D:
+        return semilocal_map_numpy(
+            fmap, bound_left, bound_right, make_numpy_array_1d(data)
+        )
+
+    return op
-- 
GitLab