Skip to content
Snippets Groups Projects
Commit 94a1b710 authored by Wassim KABALAN's avatar Wassim KABALAN
Browse files

quick fixes

parent f9e598a9
No related branches found
No related tags found
No related merge requests found
......@@ -41,7 +41,7 @@ def _global_to_local_size(nc: int):
return [nc // pdims[0], nc // pdims[1], nc]
def fttk(nc: int) -> list:
def fttk(nc: int):
"""
Generate Fourier transform wave numbers for a given mesh.
......
......@@ -23,16 +23,12 @@ from ._jaxdecomp import (HALO_COMM_MPI, HALO_COMM_MPI_BLOCKING, HALO_COMM_NCCL,
for name, fn in _jaxdecomp.registrations().items():
xla_client.register_custom_call_target(name, fn, platform="gpu")
from .padding import slice_pad, slice_unpad
__all__ = [
"init",
"finalize",
"get_pencil_info",
"get_autotuned_config",
"make_config",
"slice_pad",
"slice_unpad",
"TRANSPOSE_XY",
"TRANSPOSE_YX",
"TRANSPOSE_YZ",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment