Skip to content
Snippets Groups Projects
Commit e508f7d1 authored by Wassim Kabalan's avatar Wassim Kabalan
Browse files

fix

parent 92f477c9
No related tags found
No related merge requests found
......@@ -124,7 +124,8 @@ def get_transpose_order(fft_type: FftType, mesh: Optional[Mesh] = None) -> tuple
"""
if not jaxdecomp.config.transpose_axis_contiguous:
return (0, 1, 2)
else:
if mesh is None:
match fft_type:
case FftType.FFT | FftType.RFFT:
return (1, 2, 0)
......@@ -133,6 +134,31 @@ def get_transpose_order(fft_type: FftType, mesh: Optional[Mesh] = None) -> tuple
case _:
raise TypeError("Only complex FFTs are currently supported through pfft.")
pencil_type = get_pencil_type_from_mesh(mesh)
match fft_type:
case FftType.FFT:
match pencil_type:
case _jaxdecomp.SLAB_YZ:
return (1, 2, 0)
case _jaxdecomp.SLAB_XY | _jaxdecomp.PENCILS:
return (1, 2, 0)
case _jaxdecomp.NO_DECOMP:
return (0, 1, 2)
case _:
raise TypeError("Unknown pencil type")
case FftType.IFFT:
match pencil_type:
case _jaxdecomp.SLAB_YZ:
return (2, 0, 1)
case _jaxdecomp.SLAB_XY | _jaxdecomp.PENCILS:
return (2, 0, 1)
case _jaxdecomp.NO_DECOMP:
return (0, 1, 2)
case _:
raise TypeError("Unknown pencil type")
case _:
raise TypeError("Only complex FFTs are currently supported through pfft.")
def get_lowering_args(fft_type: FftType, global_shape: GdimsType, mesh: Mesh) -> tuple[PdimsType, GdimsType]:
"""Returns the lowering arguments based on FFT type, global shape, and mesh.
......
......@@ -77,7 +77,7 @@ class TestFFTs:
transpose_back = [1, 2, 0]
if not local_transpose:
transpose_back = [0, 1, 2]
elif jaxdecomp.config.transpose_axis_contiguous_2:
else:
transpose_back = [1, 2, 0]
# Check reconstructed array
......@@ -127,7 +127,7 @@ class TestFFTsGrad:
transpose_back = [1, 2, 0]
if not local_transpose:
transpose_back = [0, 1, 2]
elif jaxdecomp.config.transpose_axis_contiguous_2:
else:
transpose_back = [1, 2, 0]
print("*" * 80)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment