diff --git a/lisainstrument/containers.py b/lisainstrument/containers.py index 19856e5b6355b8263fc2ce32abbe66fef614bc03..2d11729aa4bd1d2dc09ea07921a3399bd3828ea8 100644 --- a/lisainstrument/containers.py +++ b/lisainstrument/containers.py @@ -20,6 +20,10 @@ logger = logging.getLogger(__name__) class ForEachObject(abc.ABC): """Abstract class which represents a dictionary holding a value for each object.""" + # Used to bypass Numpy's vectorization and use + # this class `__rmul__()` method for multiplication from both sides + __array_priority__ = 10000 + def __init__(self, values, concurrent=False): """Initialize an object with a value or a function of the mosa index. diff --git a/tests/test_containers.py b/tests/test_containers.py index 32abe7113158def261505539379e4f00e23c29bd..c033d057d2bfd621bf39f6486f276f196f4b7fbe 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -195,6 +195,33 @@ def test_multiplication(): result = object_2 * object_3 +def test_multiplication_array(): + """Check that we can multiply a ForEachObject instance with an array.""" + + array = np.random.normal(size=10) + object_1 = ForEachAB({'A': 1, 'B': 2}) + object_2 = ForEachAB({ + 'A': np.random.normal(size=10), + 'B': np.random.normal(size=10), + }) + + result = object_1 * array + assert np.all(result['A'] == array) + assert np.all(result['B'] == 2 * array) + + result = array * object_1 + assert np.all(result['A'] == array) + assert np.all(result['B'] == 2 * array) + + result = object_2 * array + assert np.all(result['A'] == object_2['A'] * array) + assert np.all(result['B'] == object_2['B'] * array) + + result = array * object_2 + assert np.all(result['A'] == object_2['A'] * array) + assert np.all(result['B'] == object_2['B'] * array) + + def test_floor_division(): """Check that we can apply floor division on two ForEachObject subclasses, or a scalar."""