Skip to content
Snippets Groups Projects
Commit ba3778f7 authored by Wassim KABALAN's avatar Wassim KABALAN
Browse files

add fft tests

parent e7c74097
No related branches found
No related tags found
No related merge requests found
......@@ -232,3 +232,81 @@ class TestFFTsGrad:
global_shapes) # Test cubes, non-cubes and primes
def test_jax_grad(self, pdims, global_shape, local_transpose):
self.run_test(pdims, global_shape, local_transpose, backend="jax")
class TestFFTFreq:
def run_test(self, pdims, global_shape, local_transpose, backend):
print("*" * 80)
print(
f"Testing with pdims {pdims} and global shape {global_shape} and local transpose {local_transpose}"
)
jaxdecomp.config.update('transpose_axis_contiguous', local_transpose)
if not local_transpose:
pytest.skip(reason="Not implemented yet")
global_array, mesh = create_spmd_array(global_shape, pdims)
# Perform distributed gradient kernel
karray = jaxdecomp.fft.pfft3d(global_array, backend=backend)
kvec = jaxdecomp.fftfreq3d(karray)
k_gradients = [k * karray for k in kvec]
gradients = [
jaxdecomp.fft.pifft3d(grad, backend=backend) for grad in k_gradients
]
gathered_gradients = [all_gather(grad) for grad in gradients]
# perform local gradient kernel
gathered_array = all_gather(global_array)
jax_karray = jnp.fft.fftn(gathered_array)
kz, ky, kx = [
jnp.fft.fftfreq(jax_karray.shape[i]) * 2 * jnp.pi for i in range(3)
]
kz = kz.reshape(-1, 1, 1)
ky = ky.reshape(1, -1, 1)
kx = kx.reshape(1, 1, -1)
kvec = [kz, ky, kx]
jax_k_gradients = [k * jax_karray for k in kvec]
jax_gradients = [jnp.fft.ifftn(grad) for grad in jax_k_gradients]
# Check the gradients
for i in range(3):
assert_allclose(
jax_gradients[i], gathered_gradients[i], rtol=1e-5, atol=1e-5)
print(f"Gradient check OK!")
# Trigger rejit in case local transpose is switched
jax.clear_caches()
@pytest.mark.skipif(not is_on_cluster(), reason="Only run on cluster")
# Cartesian product tests
@pytest.mark.parametrize(
"local_transpose",
local_transpose) # Test with and without local transpose
@pytest.mark.parametrize("pdims",
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_cudecomp_fft(self, pdims, global_shape, local_transpose):
self.run_test(pdims, global_shape, local_transpose, backend="cuDecomp")
# Cartesian product tests
@pytest.mark.parametrize(
"local_transpose",
local_transpose) # Test with and without local transpose
@pytest.mark.parametrize("pdims",
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_jax_fft(self, pdims, global_shape, local_transpose):
self.run_test(pdims, global_shape, local_transpose, backend="jax")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment