From d34f3a55cbca2808368d11cfdad7d1a7e7e5e8a5 Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Bayle <j2b.bayle@gmail.com>
Date: Mon, 26 Dec 2022 15:31:29 +0100
Subject: [PATCH] Allow multiplication of ForEachObject with arrays

---
 lisainstrument/containers.py |  4 ++++
 tests/test_containers.py     | 27 +++++++++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/lisainstrument/containers.py b/lisainstrument/containers.py
index 19856e5..2d11729 100644
--- a/lisainstrument/containers.py
+++ b/lisainstrument/containers.py
@@ -20,6 +20,10 @@ logger = logging.getLogger(__name__)
 class ForEachObject(abc.ABC):
     """Abstract class which represents a dictionary holding a value for each object."""
 
+    # Used to bypass Numpy's vectorization and use
+    # this class `__rmul__()` method for multiplication from both sides
+    __array_priority__ = 10000
+
     def __init__(self, values, concurrent=False):
         """Initialize an object with a value or a function of the mosa index.
 
diff --git a/tests/test_containers.py b/tests/test_containers.py
index 32abe71..c033d05 100644
--- a/tests/test_containers.py
+++ b/tests/test_containers.py
@@ -195,6 +195,33 @@ def test_multiplication():
         result = object_2 * object_3
 
 
+def test_multiplication_array():
+    """Check that we can multiply a ForEachObject instance with an array."""
+
+    array = np.random.normal(size=10)
+    object_1 = ForEachAB({'A': 1, 'B': 2})
+    object_2 = ForEachAB({
+        'A': np.random.normal(size=10),
+        'B': np.random.normal(size=10),
+    })
+
+    result = object_1 * array
+    assert np.all(result['A'] == array)
+    assert np.all(result['B'] == 2 * array)
+
+    result = array * object_1
+    assert np.all(result['A'] == array)
+    assert np.all(result['B'] == 2 * array)
+
+    result = object_2 * array
+    assert np.all(result['A'] == object_2['A'] * array)
+    assert np.all(result['B'] == object_2['B'] * array)
+
+    result = array * object_2
+    assert np.all(result['A'] == object_2['A'] * array)
+    assert np.all(result['B'] == object_2['B'] * array)
+
+
 def test_floor_division():
     """Check that we can apply floor division on two ForEachObject subclasses, or a scalar."""
 
-- 
GitLab