Skip to content
Snippets Groups Projects
Commit 06f26bb6 authored by Jean-Baptiste Bayle's avatar Jean-Baptiste Bayle
Browse files

Add index transformation routines

parent 546a83a2
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !46. Comments created here will be created in the context of that merge request.
...@@ -7,6 +7,7 @@ Authors: ...@@ -7,6 +7,7 @@ Authors:
Jean-Baptiste Bayle <j2b.bayle@gmail.com> Jean-Baptiste Bayle <j2b.bayle@gmail.com>
""" """
import numpy as np
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -91,3 +92,80 @@ def sc2index(mosa): ...@@ -91,3 +92,80 @@ def sc2index(mosa):
SC1 = sc2index(1) SC1 = sc2index(1)
SC2 = sc2index(2) SC2 = sc2index(2)
SC3 = sc2index(3) SC3 = sc2index(3)
def transform(x, mapping):
"""Transform indices or arrays with a mapping of SC.
If `x` is an MOSA or SC index, it is transformed according to the mapping of SC indices.
If `x` is a MOSA or SC array, the first axis is re-ordered according to the mapping.
Args:
x: MOSA or SC index (as a string or integer), or MOSA or SC array
mapping: dictionary describing the mapping from SC indices to new SC indices
"""
# Check that we have a complete mapping
sc_set = set(SC)
mapping_str = {str(key): str(mapping[key]) for key in mapping}
if set(mapping_str.keys()) != sc_set:
raise ValueError(f"incomplete mapping '{mapping}', should have '{sc_set}' as keys")
if set(mapping_str.values()) != sc_set:
raise ValueError(f"incomplete mapping '{mapping}', should have '{sc_set}' as values")
# Transform MOSA array
if isinstance(x, np.ndarray) and x.shape[0] == 6:
transformed_mosas = [mosa2index(transform(mosa, mapping)) for mosa in MOSAS]
return x[transformed_mosas]
# Transform SC array
if isinstance(x, np.ndarray) and x.shape[0] == 3:
transformed_sc = [sc2index(transform(sc, mapping)) for sc in SC]
return x[transformed_sc]
# Raise an error if we have an array of invalid shape
if isinstance(x, np.ndarray):
raise TypeError(f"invalid MOSA or SC array shape '{x.shape}'")
# Check that index is correct
index = str(x)
if index not in MOSAS and index not in SC:
raise IndexError(f"invalid MOSA or SC index '{index}'")
# Transform index
mapped_chars = [mapping_str[char] for char in index]
return ''.join(mapped_chars)
def rotate(x, nrot=1):
"""Rotate SC indices in clockwise direction.
A rotation is a circular permutation of indices.
Args:
x: MOSA or SC index (as a string or integer), or MOSA or SC array
nrot: number of 120-degree rotations
"""
nrot = nrot % 3
if nrot == 0:
mapping = {1: 1, 2: 2, 3: 3}
elif nrot == 1:
mapping = {1: 2, 2: 3, 3: 1}
elif nrot == 2:
mapping = {1: 3, 2: 1, 3: 2}
return transform(x, mapping)
def reflect(x, axis):
"""Reflect SC indices around an axis.
A reflection leaves the axis unchanged, and swaps the others.
Args:
axis: SC index around which the reflection is performed
"""
if axis == 1:
mapping = {1: 1, 2: 3, 3: 2}
elif axis == 2:
mapping = {1: 3, 2: 2, 3: 1}
elif axis == 3:
mapping = {1: 2, 2: 1, 3: 3}
else:
raise ValueError(f"invalid SC index '{axis}'")
return transform(x, mapping)
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# pylint: disable=missing-module-docstring # pylint: disable=missing-module-docstring
import os
import numpy as np
from pytest import raises from pytest import raises
from lisainstrument.indexing import * from lisainstrument.indexing import *
...@@ -11,6 +9,7 @@ from lisainstrument.indexing import * ...@@ -11,6 +9,7 @@ from lisainstrument.indexing import *
def test_index2mosa(): def test_index2mosa():
"""Test that we return the correct MOSA, or raise an error.""" """Test that we return the correct MOSA, or raise an error."""
assert index2mosa(0) == '12' assert index2mosa(0) == '12'
assert index2mosa(1) == '23' assert index2mosa(1) == '23'
assert index2mosa(2) == '31' assert index2mosa(2) == '31'
...@@ -26,6 +25,7 @@ def test_index2mosa(): ...@@ -26,6 +25,7 @@ def test_index2mosa():
def test_mosa2index(): def test_mosa2index():
"""Test that we return the correct array index, or raise an error.""" """Test that we return the correct array index, or raise an error."""
assert mosa2index('12') == 0 assert mosa2index('12') == 0
assert mosa2index('23') == 1 assert mosa2index('23') == 1
assert mosa2index('31') == 2 assert mosa2index('31') == 2
...@@ -46,3 +46,100 @@ def test_mosa2index(): ...@@ -46,3 +46,100 @@ def test_mosa2index():
mosa2index('15') mosa2index('15')
with raises(ValueError): with raises(ValueError):
mosa2index(0) mosa2index(0)
def test_transform_mapping():
"""Test that we check that mapping is complete."""
transform(1, {1: 1, 2: 2, 3: 3})
transform(1, {1: 2, 2: 1, 3: 3})
with raises(ValueError):
transform(1, {1: 1, 2: 1, 3: 3})
with raises(ValueError):
transform(1, {1: 1, 2: 2})
with raises(ValueError):
transform(1, {1: 1, 2: 2, 3: 3, 4: 4})
def test_transform_check_index():
"""Test that we check that index is correct."""
transform(1, {1: 1, 2: 2, 3: 3})
transform('1', {1: 2, 2: 1, 3: 3})
transform(12, {1: 2, 2: 1, 3: 3})
transform('12', {1: 2, 2: 1, 3: 3})
with raises(IndexError):
transform(0, {1: 1, 2: 2, 3: 3})
with raises(IndexError):
transform(4, {1: 1, 2: 2, 3: 3})
with raises(IndexError):
transform('4', {1: 1, 2: 2, 3: 3})
with raises(IndexError):
transform('something', {1: 1, 2: 2, 3: 3})
with raises(IndexError):
transform(11, {1: 1, 2: 2, 3: 3})
with raises(IndexError):
transform(114, {1: 1, 2: 2, 3: 3})
with raises(IndexError):
transform('11', {1: 1, 2: 2, 3: 3})
def test_transform_check_array():
"""Test that we check that array has correct shape."""
transform(np.zeros((3, )), {1: 1, 2: 2, 3: 3})
transform(np.zeros((3, 10)), {1: 1, 2: 2, 3: 3})
transform(np.zeros((3, 10, 5)), {1: 1, 2: 2, 3: 3})
transform(np.zeros((6, )), {1: 1, 2: 2, 3: 3})
transform(np.zeros((6, 10)), {1: 1, 2: 2, 3: 3})
transform(np.zeros((6, 10, 5)), {1: 1, 2: 2, 3: 3})
with raises(TypeError):
transform(np.zeros((1, )), {1: 1, 2: 2, 3: 3})
with raises(TypeError):
transform(np.zeros((10, )), {1: 1, 2: 2, 3: 3})
with raises(TypeError):
transform(np.zeros((10, 3, )), {1: 1, 2: 2, 3: 3})
with raises(TypeError):
transform(np.zeros((10, 6)), {1: 1, 2: 2, 3: 3})
def test_rotate_index():
"""Test that we can rotate indices."""
assert rotate('1', nrot=0) == '1'
assert rotate('1', nrot=1) == '2'
assert rotate('1', nrot=2) == '3'
assert rotate('1', nrot=3) == '1'
assert rotate('2', nrot=0) == '2'
assert rotate('2', nrot=1) == '3'
assert rotate('2', nrot=2) == '1'
assert rotate('2', nrot=3) == '2'
assert rotate('3', nrot=0) == '3'
assert rotate('3', nrot=1) == '1'
assert rotate('3', nrot=2) == '2'
assert rotate('3', nrot=3) == '3'
assert rotate('12', nrot=0) == '12'
assert rotate('12', nrot=1) == '23'
assert rotate('12', nrot=2) == '31'
assert rotate('12', nrot=3) == '12'
assert rotate('23', nrot=1) == '31'
assert rotate('21', nrot=2) == '13'
def test_rotate_array():
"""Test that we can rotate arrays."""
assert np.all(rotate(np.array([1, 2, 3]), nrot=0) == [1, 2, 3])
assert np.all(rotate(np.array([1, 2, 3]), nrot=1) == [2, 3, 1])
assert np.all(rotate(np.array([1, 2, 3]), nrot=2) == [3, 1, 2])
assert np.all(rotate(np.array([1, 2, 3]), nrot=3) == [1, 2, 3])
assert np.all(rotate(np.array([12, 23, 31, 13, 32, 21]), nrot=0) == [12, 23, 31, 13, 32, 21])
assert np.all(rotate(np.array([12, 23, 31, 13, 32, 21]), nrot=1) == [23, 31, 12, 21, 13, 32])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment