Skip to content
Snippets Groups Projects
test_shift_inversion_numpy.py 3.02 KiB
Newer Older
"""Unit tests for module shift_inversion_numpy"""

import numpy as np
import pytest

from lisainstrument import dsp
from lisainstrument.shift_inversion_numpy import make_shift_inverse_lagrange_numpy


def test_shift_inversion_numpy():
    """Test basic functioning of shift_inversion_numpy on analytic coordinate transform

    The shift dx(x) = a_mod * sin(2*pi*f_mod*x) is sampled on a regular grid.
    The shift-inversion operator returns a(x). Away from the boundary margins,
    the result is checked against the fixed-point relation a(x) = dx(x - a(x)),
    evaluated analytically.
    """
    order = 31
    nsamp = 3000
    fsample = 16.0
    dt = 1 / fsample

    f_mod = 0.005
    # Amplitude chosen so that the derivative of dx is bounded by 1e-2,
    # which keeps the fixed-point iteration strongly contracting.
    a_mod = 1e-2 / (2 * np.pi * f_mod)
    max_it = 5
    tol = 1e-10

    def dx_from_x(x):
        return np.sin(2 * np.pi * f_mod * x) * a_mod

    xi = np.arange(nsamp) * dt
    dxi = dx_from_x(xi)

    op_np = make_shift_inverse_lagrange_numpy(
        order=order,
        fsample=fsample,
        max_abs_shift=a_mod * 1.01,
        max_iter=max_it,
        tolerance=tol,
    )

    ai_np = op_np(dxi)
    # Analytic evaluation of the defining relation at the computed inverse.
    ai_ex = dx_from_x(xi - ai_np)

    # Use an explicit stop index: slice(m, -0) would be empty when
    # margin_right == 0 and make the assertion pass vacuously.
    valid_range = slice(op_np.margin_left, len(ai_np) - op_np.margin_right)
    assert ai_np[valid_range] == pytest.approx(ai_ex[valid_range], abs=tol, rel=0)


def legacy_invert_scet_wrt_tps(
    scet_wrt_tps: np.ndarray,
    clockinv_tolerance: float,
    physics_fs: float,
    interpolation_order: int,
    clockinv_maxiter: int = 5,
) -> np.ndarray:
    """Legacy shift inversion algorithm taken out of Instrument class for comparison

    Slight adaptions to make it work without Instrument instance, and logging removed.

    The inverse is found by fixed-point iteration. Each step re-interpolates
    ``scet_wrt_tps`` with ``dsp.timeshift``, using the shift
    ``-inverse * physics_fs`` (presumably expressed in samples, hence the
    multiplication by the sample rate). Convergence is measured as the maximum
    absolute change between successive iterates, ignoring up to 100 samples at
    each end of the array.

    Args:
        scet_wrt_tps: array of SCETs with respect to TPS
        clockinv_tolerance: tolerance for result
        physics_fs: sample rate
        interpolation_order: Lagrange interpolation order
        clockinv_maxiter: maximum number of fixed-point iterations

    Returns:
        The iterate ``inverse`` from which the last ``next_inverse`` was computed.
        The final ``next_inverse`` itself is not returned. This is kept unchanged
        because the function serves as a reference implementation.

    Raises:
        RuntimeError: if the tolerance is not reached within ``clockinv_maxiter``
            iterations.
    """

    # Drop samples at the edges to compute error. For very short inputs
    # (fewer than 4 samples) edge <= 0 and the error window degenerates.
    edge = min(100, len(scet_wrt_tps) // 2 - 1)
    error = 0

    niter = 0
    next_inverse = scet_wrt_tps
    # `not niter` forces at least one iteration regardless of the initial error.
    while not niter or error > clockinv_tolerance:
        if niter >= clockinv_maxiter:
            msg = "Legacy fixed point iter did not converge"
            raise RuntimeError(msg)
        inverse = next_inverse

        next_inverse = dsp.timeshift(
            scet_wrt_tps, -inverse * physics_fs, interpolation_order
        )
        # originally,
        # next_inverse = self.interpolate(scet_wrt_tps, -inverse)

        error = np.max(np.abs((inverse - next_inverse)[edge:-edge]))
        niter += 1

    return inverse


def test_shift_inversion_legacy():
    """Compare shift_inversion_numpy to original algorithm

    Both implementations invert the same analytic shift
    dx(x) = a_mod * sin(2*pi*f_mod*x) with identical order, tolerance and
    iteration limit. Their results must agree to within the tolerance outside
    the boundary margins of the new operator.
    """
    order = 31
    nsamp = 3000
    fsample = 16.0
    dt = 1 / fsample

    f_mod = 0.005
    # Amplitude chosen so that the derivative of dx is bounded by 1e-2.
    a_mod = 1e-2 / (2 * np.pi * f_mod)
    max_it = 5
    tol = 1e-10

    def dx_from_x(x):
        return np.sin(2 * np.pi * f_mod * x) * a_mod

    xi = np.arange(nsamp) * dt
    dxi = dx_from_x(xi)

    op_np = make_shift_inverse_lagrange_numpy(
        order=order,
        fsample=fsample,
        max_abs_shift=a_mod * 1.01,
        max_iter=max_it,
        tolerance=tol,
    )

    ai_np = op_np(dxi)

    ai_leg = legacy_invert_scet_wrt_tps(dxi, tol, fsample, order, max_it)

    # Use an explicit stop index: slice(m, -0) would be empty when
    # margin_right == 0 and make the assertion pass vacuously.
    valid_range = slice(op_np.margin_left, len(ai_np) - op_np.margin_right)
    assert ai_np[valid_range] == pytest.approx(ai_leg[valid_range], abs=tol, rel=0)