From d34f3a55cbca2808368d11cfdad7d1a7e7e5e8a5 Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Bayle <j2b.bayle@gmail.com> Date: Mon, 26 Dec 2022 15:31:29 +0100 Subject: [PATCH] Allow multiplication of ForEachObject with arrays --- lisainstrument/containers.py | 4 ++++ tests/test_containers.py | 27 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/lisainstrument/containers.py b/lisainstrument/containers.py index 19856e5..2d11729 100644 --- a/lisainstrument/containers.py +++ b/lisainstrument/containers.py @@ -20,6 +20,10 @@ logger = logging.getLogger(__name__) class ForEachObject(abc.ABC): """Abstract class which represents a dictionary holding a value for each object.""" + # Used to bypass Numpy's vectorization and use + # this class `__rmul__()` method for multiplication from both sides + __array_priority__ = 10000 + def __init__(self, values, concurrent=False): """Initialize an object with a value or a function of the mosa index. diff --git a/tests/test_containers.py b/tests/test_containers.py index 32abe71..c033d05 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -195,6 +195,33 @@ def test_multiplication(): result = object_2 * object_3 +def test_multiplication_array(): + """Check that we can multiply a ForEachObject instance with an array.""" + + array = np.random.normal(size=10) + object_1 = ForEachAB({'A': 1, 'B': 2}) + object_2 = ForEachAB({ + 'A': np.random.normal(size=10), + 'B': np.random.normal(size=10), + }) + + result = object_1 * array + assert np.all(result['A'] == array) + assert np.all(result['B'] == 2 * array) + + result = array * object_1 + assert np.all(result['A'] == array) + assert np.all(result['B'] == 2 * array) + + result = object_2 * array + assert np.all(result['A'] == object_2['A'] * array) + assert np.all(result['B'] == object_2['B'] * array) + + result = array * object_2 + assert np.all(result['A'] == object_2['A'] * array) + assert np.all(result['B'] == object_2['B'] * array) + + def test_floor_division(): """Check that we can apply floor division on two ForEachObject subclasses, or a scalar.""" -- GitLab