Commits on Source (9)
name: Tests
- main
......@@ -19,22 +19,20 @@ jobs:
- name: Checkout Source
uses: actions/checkout@v2.3.1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[tests]
pip install .[test]
- name: Run tests
run: |
cd tests
export XLA_FLAGS='--xla_force_host_platform_device_count=8'
pytest -v
[submodule "third_party/cuDecomp"]
path = third_party/cuDecomp
url =
[submodule "pybind11"]
path = pybind11
url =
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
find_package(pybind11 CONFIG REQUIRED)
# Check for CUDA
# 70: Volta, 80: Ampere, 89: Ada Lovelace (e.g. RTX 4060)
set(CUDECOMP_CUDA_CC_LIST "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")
# Add pybind11 and cuDecomp subdirectories
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)
target_compile_definitions(_jaxdecomp PRIVATE JD_CUDECOMP_BACKEND)
pybind11_add_module(_jaxdecomp src/
target_include_directories(_jaxdecomp PRIVATE ${CMAKE_CURRENT_LIST_DIR}/include)
target_compile_definitions(_jaxdecomp PRIVATE JD_JAX_BACKEND)
......@@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version
from typing import Tuple
from jaxdecomp._src.pencil_utils import get_output_specs
from jaxdecomp.fft import fftfreq3d_shard, pfft3d, pifft3d
from jaxdecomp.fft import fftfreq3d, pfft3d, pifft3d, rfftfreq3d
from jaxdecomp.halo import halo_exchange
from jaxdecomp.transpose import (transposeXtoY, transposeXtoZ, transposeYtoX,
transposeYtoZ, transposeZtoX, transposeZtoY)
......@@ -52,7 +52,8 @@ __all__ = [
from functools import partial
from typing import List, Tuple
import jax
from jax import lax
from jax import numpy as jnp
from jax._src.api import ShapeDtypeStruct
from jax._src.core import ShapedArray
from jax._src.typing import Array
from jax.experimental.shard_alike import shard_alike
from jax.lib import xla_client
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomp._src import _jaxdecomp
from jaxdecomp._src.pencil_utils import get_transpose_order
from jaxdecomp._src.spmd_ops import (CustomParPrimitive, get_pencil_type,
import jaxdecomp
FftType = xla_client.FftType
# TODO: The custom partitioning for FFTFreq is to be removed;
# the only working implementation is fftfreq3d_shard.
class FFTFreqPrimitive(CustomParPrimitive):
name = 'jax_fftfreq'
multiple_results = True
impl_static_args: Tuple[int, ...] = ()
outer_primitive = None
# Build the three 1-D FFT sample-frequency vectors for a 3-D array.
#
# Args:
#   array: 3-D array; only its shape and dtype are read here.
#   d: sample spacing, forwarded to jnp.fft.fftfreq.
#
# Returns:
#   (kx, ky, kz), reordered according to get_transpose_order(FftType.FFT) and
#   reshaped to (-1, 1, 1), (1, -1, 1) and (1, 1, -1) respectively.
def impl(array, d) -> Tuple[Array, Array, Array]:
assert array.ndim == 3, "Only 3D FFTFreq is supported"
# One frequency vector per axis, using the array's own dtype.
kx = jnp.fft.fftfreq(array.shape[0], d=d, dtype=array.dtype)
ky = jnp.fft.fftfreq(array.shape[1], d=d, dtype=array.dtype)
kz = jnp.fft.fftfreq(array.shape[2], d=d, dtype=array.dtype)
# Sanity checks: each vector has as many entries as its axis.
assert len(kx) == array.shape[0], "kx must have the same size as array"
assert len(ky) == array.shape[1], "ky must have the same size as array"
assert len(kz) == array.shape[2], "kz must have the same size as array"
kvec = (kx, ky, kz)
transpose_order = get_transpose_order(FftType.FFT)
# Reorder the vectors by the transpose order.
# NOTE(review): the continuation of the next statement is not visible in this
# view; presumably it ends with transpose_order[2] -- confirm.
kx, ky, kz = kvec[transpose_order[0]], kvec[transpose_order[1]], kvec[
kx , ky , kz = (kx.reshape([-1, 1, 1]),
ky.reshape([1, -1, 1]),
kz.reshape([1, 1, -1])) # yapf: disable
# Debug print of the resulting shapes.
print(f"IMPL shape of kx {kx.shape} ky {ky.shape} kz {kz.shape}")
return kx, ky, kz
# Reorder the three frequency vectors for one shard.
#
# Args:
#   a: not used in the body.
#   kx, ky, kz: frequency vectors to reorder.
#   mesh: mesh whose two axis names are unpacked; also passed to
#     get_transpose_order.
#
# Returns:
#   The three vectors picked from (kx, ky, kz) in the order given by
#   get_transpose_order(FftType.FFT, mesh).
def per_shard_impl(a: Array, kx: Array, ky: Array, kz: Array,
mesh) -> Tuple[Array, Array, Array]:
# Unpacking into two names fails unless the mesh has exactly two axes.
x_axis_name, y_axis_name = mesh.axis_names
assert x_axis_name is not None and y_axis_name is not None
transpose_order = get_transpose_order(FftType.FFT, mesh)
kvec = (kx, ky, kz)
# Pick each output vector by its position in the transpose order.
kx = kvec[transpose_order[0]]
ky = kvec[transpose_order[1]]
kz = kvec[transpose_order[2]]
return kx, ky, kz
def infer_sharding_from_operands(mesh: Mesh,
arg_infos: Tuple[ShapeDtypeStruct],
result_infos: Tuple[ShapedArray]):
del mesh
input_mesh = arg_infos[0].sharding.mesh
x_axis_name, y_axis_name = input_mesh.axis_names
pencil_type = get_pencil_type(input_mesh)
match pencil_type:
case _jaxdecomp.SLAB_XY | _jaxdecomp.PENCILS:
kx_sharding = NamedSharding(input_mesh, P(x_axis_name))
ky_sharding = NamedSharding(input_mesh, P(None, y_axis_name))
kz_sharding = NamedSharding(input_mesh, P(None, None, None))
case _jaxdecomp.SLAB_YZ:
kx_sharding = NamedSharding(input_mesh, P(y_axis_name))
ky_sharding = NamedSharding(input_mesh, P(None, x_axis_name))
kz_sharding = NamedSharding(input_mesh, P(None, None, None))
case _:
raise ValueError(f"Unsupported pencil type {pencil_type}")
@partial(jax.jit, static_argnums=(1,))
def fftfreq3d(k_array, d=1.0):
return (kx_sharding, ky_sharding, kz_sharding)
# in frequency space, the order is Z pencil
# X pencil is Z Y X
# Z pencil is Y X Z
def partition(mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct],
result_infos: Tuple[ShapedArray]):
if jnp.iscomplexobj(k_array):
dtype = jnp.float32 if k_array.dtype == jnp.complex64 else jnp.float64
dtype = k_array.dtype
# assert isinstance(arg_infos, tuple) and len(
# arg_infos) == 2, "Arg info must be a tuple of two sharding"
# assert isinstance(result_infos, tuple) and len(
# result_infos) == 3, "Result info must be a tuple of three sharding"
ky = jnp.fft.fftfreq(k_array.shape[0], d=d, dtype=dtype) * 2 * jnp.pi
kx = jnp.fft.fftfreq(k_array.shape[1], d=d, dtype=dtype) * 2 * jnp.pi
kz = jnp.fft.fftfreq(k_array.shape[2], d=d, dtype=dtype) * 2 * jnp.pi
print(f"arg_infos size {len(arg_infos)} and {arg_infos}")
ky, _ = shard_alike(ky, k_array[:, 0, 0])
kx, _ = shard_alike(kx, k_array[0, :, 0])
kz, _ = shard_alike(kz, k_array[0, 0, :])
input_sharding = arg_infos[0].sharding
print(f"args_infos[1] sharding {arg_infos[1].sharding}")
input_mesh = input_sharding.mesh
kvec_sharding = (NamedSharding(input_mesh, P(None)),) * 3
input_shardings = (input_sharding, *kvec_sharding)
print(f"input_shardings {input_shardings}")
print(f"len(input_shardings) {len(input_shardings)}")
print(f"input_mesh {input_mesh}")
ky = ky.reshape([-1, 1, 1])
kx = kx.reshape([1, -1, 1])
kz = kz.reshape([1, 1, -1])
output_sharding = tuple([
NamedSharding(input_mesh, P(*result_infos[i].sharding.spec))
for i in range(3)
return kz, ky, kx
impl = partial(FFTFreqPrimitive.per_shard_impl, mesh=input_mesh)
return mesh, impl, output_sharding, input_shardings
@partial(jax.jit, static_argnums=(1,))
def rfftfreq3d(k_array, d=1.0):
# in frequency space, the order is Z pencil
# X pencil is Z Y X
# Z pencil is Y X Z
# the first axis to be FFT'd is X so it is the real one
# Forward the input array and the three frequency vectors unchanged to
# FFTFreqPrimitive.outer_lowering and return its result.
def fftfreq_impl(array, kx, ky, kz) -> Array:
return FFTFreqPrimitive.outer_lowering(array, kx, ky, kz)
# Compute FFT sample frequencies for the three axes of `array` and hand them
# to fftfreq_impl.
#
# Args:
#   array: array whose first three axis lengths set the vector sizes.
#   d: sample spacing forwarded to jnp.fft.fftfreq (static under jit).
#   dtype: dtype forwarded to jnp.fft.fftfreq (static under jit).
#
# Returns:
#   Whatever fftfreq_impl(array, kx, ky, kz) returns.
@partial(jax.jit, static_argnums=(1, 2))
def fftfreq3d(array, d=1.0, dtype=None):
kx = jnp.fft.fftfreq(array.shape[0], d=d, dtype=dtype)
ky = jnp.fft.fftfreq(array.shape[1], d=d, dtype=dtype)
kz = jnp.fft.fftfreq(array.shape[2], d=d, dtype=dtype)
# Debug print of the full (unsharded) vector shapes.
print(f"global shape of kx {kx.shape} ky {ky.shape} kz {kz.shape}")
return fftfreq_impl(array, kx, ky, kz)
# Real-FFT variant of the frequency helper: axes 0 and 1 use jnp.fft.fftfreq,
# axis 2 uses jnp.fft.rfftfreq; the vectors are then handed to fftfreq_impl.
#
# Args:
#   array: array whose first three axis lengths set the vector sizes.
#   d: sample spacing forwarded to the frequency functions (static under jit).
#   dtype: dtype forwarded to the frequency functions (static under jit).
#
# Returns:
#   Whatever fftfreq_impl(array, kx, ky, kz) returns.
@partial(jax.jit, static_argnums=(1, 2))
def rfftfreq3d(array, d=1.0, dtype=None):
kx = jnp.fft.fftfreq(array.shape[0], d=d, dtype=dtype)
ky = jnp.fft.fftfreq(array.shape[1], d=d, dtype=dtype)
# Only the last axis uses the real-input (rfftfreq) frequencies.
kz = jnp.fft.rfftfreq(array.shape[2], d=d, dtype=dtype)
return fftfreq_impl(array, kx, ky, kz)
from jax.experimental.shard_alike import shard_alike
if jnp.iscomplexobj(k_array):
dtype = jnp.float32 if k_array.dtype == jnp.complex64 else jnp.float64
dtype = k_array.dtype
@partial(jax.jit, static_argnums=(1, 2))
def fftfreq3d_shard(array, d=1.0, dtype=None):
ky = jnp.fft.fftfreq(k_array.shape[0], d=d, dtype=dtype) * 2 * jnp.pi
kx = jnp.fft.rfftfreq(k_array.shape[1], d=d, dtype=dtype) * 2 * jnp.pi
kz = jnp.fft.fftfreq(k_array.shape[2], d=d, dtype=dtype) * 2 * jnp.pi
kx = jnp.fft.fftfreq(array.shape[0], d=d, dtype=dtype)
ky = jnp.fft.fftfreq(array.shape[1], d=d, dtype=dtype)
kz = jnp.fft.fftfreq(array.shape[2], d=d, dtype=dtype)
ky, _ = shard_alike(ky, k_array[:, 0, 0])
kx, _ = shard_alike(kx, k_array[0, :, 0])
kz, _ = shard_alike(kz, k_array[0, 0, :])
kx, _ = shard_alike(kx, array[:, 0, 0])
ky, _ = shard_alike(ky, array[0, :, 0])
kz, _ = shard_alike(kz, array[0, 0, :])
ky = ky.reshape([-1, 1, 1])
kx = kx.reshape([1, -1, 1])
kz = kz.reshape([1, 1, -1])
return kx, ky, kz
return kz, ky, kx
......@@ -174,13 +174,9 @@ def pifft3d(a: ArrayLike,
"ifft", xla_client.FftType.IFFT, a, norm=norm, backend=backend)
# Public wrapper: forwards array, d and dtype to _fftfreq.fftfreq3d.
# NOTE(review): another fftfreq3d with a different signature follows directly
# in this view; presumably the two are the old/new sides of a diff -- confirm
# which one is current.
def fftfreq3d(array, d=1.0, dtype=None):
return _fftfreq.fftfreq3d(array, d=d, dtype=dtype)
# Public wrapper: forwards array and the sample spacing d to
# _fftfreq.fftfreq3d and returns its result.
def fftfreq3d(array, d=1.0):
return _fftfreq.fftfreq3d(array, d=d)
# Public wrapper: forwards array, d and dtype to _fftfreq.rfftfreq3d.
# NOTE(review): another rfftfreq3d without the dtype parameter appears a few
# lines below in this view; presumably old/new sides of a diff -- confirm.
def rfftfreq3d(array, d=1.0, dtype=None):
return _fftfreq.rfftfreq3d(array, d=d, dtype=dtype)
# Public wrapper: forwards array, d and dtype to _fftfreq.fftfreq3d_shard and
# returns its result.
def fftfreq3d_shard(array, d=1.0, dtype=None):
return _fftfreq.fftfreq3d_shard(array, d=d, dtype=dtype)
# Public wrapper: forwards array and the sample spacing d to
# _fftfreq.rfftfreq3d and returns its result.
def rfftfreq3d(array, d=1.0):
return _fftfreq.rfftfreq3d(array, d=d)
Subproject commit 8b48ff878c168b51fe5ef7b8c728815b9e1a9857
from functools import partial
import jax
jax.config.update("jax_enable_x64", True)
from math import prod
import pytest
from conftest import initialize_distributed
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from numpy.testing import assert_array_equal
# Initialize JAX distributed so that each local JAX process knows which GPU to use
rank = jax.process_index()
size = jax.process_count()
# Initialize cuDecomp
# This test is just to make sure that multihost_utils.process_allgather works as expected
# Helper function to create a 3D array and remap it to the global array
# Create a random local array and remap it into a global array on a 2-D mesh.
#
# Args:
#   global_shape: sequence of length 3.
#   pdims: sequence of length 2 whose product must equal the module-level
#     `size`.
#
# Returns:
#   (global_array, mesh), where mesh has axis names ('z', 'y').
def create_spmd_array(global_shape, pdims):
assert (len(global_shape) == 3)
assert (len(pdims) == 2)
assert (prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes"
# Local slice: the first two global dimensions are divided by pdims[1] and
# pdims[0] respectively.
# NOTE(review): the rest of this call (remaining shape entry, key, closing
# bracket) is not visible in this view.
local_array = jax.random.normal(
shape=[global_shape[0] // pdims[1], global_shape[1] // pdims[0]],
# Remap to the global array from the local slice
# The device mesh is built from pdims in reversed order.
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
local_array, mesh, P('z', 'y'))
return global_array, mesh
# Check that gathering a sharded array with two chained jax.lax.all_gather
# calls (axis 'z', then axis 'y') gives the same result as
# multihost_utils.process_allgather(..., tiled=True).
@pytest.mark.parametrize("pdims", [(1, size), (size, 1),
(size // 2, size // 2)]) # Slabs and Pencils
def test_empty_halo(pdims):
# NOTE(review): this assignment overwrites the parametrized `pdims`, so every
# parametrized case runs with (2, 2) -- confirm this is intended.
pdims = (2, 2)
global_shape = (29 * size, 19 * size, 17 * size
) # Each dimension is a prime number multiplied by `size`
global_array, mesh = create_spmd_array(global_shape, pdims)
# All gather function
# NOTE(review): the decorator that the next fragment belongs to is not fully
# visible in this view.
in_specs=P('z', 'y'),
def sharded_allgather(arr):
# Gather along mesh axis 'z' into array axis 0, then along 'y' into axis 1.
gathered_z_axis = jax.lax.all_gather(arr, axis_name='z', axis=0, tiled=True)
gathered = jax.lax.all_gather(
gathered_z_axis, axis_name='y', axis=1, tiled=True)
return gathered
gathered = sharded_allgather(global_array)
process_allgather = multihost_utils.process_allgather(
global_array, tiled=True)
print(f"Shape of original array {global_array.shape}")
print(f"Shape of gathered array using double lax gather {gathered.shape}")
f"Shape of gathered array using process_allgather {process_allgather.shape}"
# Both gathering strategies must produce identical arrays.
assert_array_equal(gathered, process_allgather)
......@@ -8,14 +8,18 @@ size = jax.device_count()
jax.config.update("jax_enable_x64", True)
from functools import partial
import jax.numpy as jnp
import pytest
from jax.experimental import multihost_utils
from jax.experimental.multihost_utils import process_allgather
from numpy.testing import assert_allclose
import jaxdecomp
from jaxdecomp._src import PENCILS, SLAB_XY, SLAB_YZ
all_gather = partial(process_allgather, tiled=True)
pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100
pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100
......@@ -59,10 +63,9 @@ class TestFFTs:
# assert compare_sharding(global_array.sharding, dist_jax_rec_array.sharding)
# Check the forward FFT
gathered_array = multihost_utils.process_allgather(global_array, tiled=True)
gathered_karray = multihost_utils.process_allgather(karray, tiled=True)
gathered_rec_array = multihost_utils.process_allgather(
rec_array, tiled=True)
gathered_array = all_gather(global_array)
gathered_karray = all_gather(karray)
gathered_rec_array = all_gather(rec_array)
jax_karray = jnp.fft.fftn(gathered_array)
jax_rec_array = jnp.fft.ifftn(jax_karray, norm='forward')
......@@ -167,8 +170,8 @@ class TestFFTsGrad:
array_grad = jax.grad(spmd_grad)(global_array)
print("Here is the gradient I'm getting", array_grad.shape)
gathered_array = multihost_utils.process_allgather(global_array, tiled=True)
gathered_grads = multihost_utils.process_allgather(array_grad, tiled=True)
gathered_array = all_gather(global_array)
gathered_grads = all_gather(array_grad)
jax_grad = jax.grad(local_grad)(gathered_array)
print(f"Shape of JAX array {jax_grad.shape}")
......@@ -194,8 +197,7 @@ class TestFFTsGrad:
ifft_array_grad = jax.grad(inv_spmd_grad)(karray)
print("Here is the gradient I'm getting", array_grad.shape)
ifft_gathered_grads = multihost_utils.process_allgather(
ifft_array_grad, tiled=True)
ifft_gathered_grads = all_gather(ifft_array_grad)
jax_karray = jnp.fft.fftn(gathered_array).transpose(transpose_back)
ifft_jax_grad = jax.grad(inv_local_grad)(jax_karray)
......@@ -230,3 +232,81 @@ class TestFFTsGrad:
global_shapes) # Test cubes, non-cubes and primes
def test_jax_grad(self, pdims, global_shape, local_transpose):
self.run_test(pdims, global_shape, local_transpose, backend="jax")
class TestFFTFreq:
def run_test(self, pdims, global_shape, local_transpose, backend):
print("*" * 80)
f"Testing with pdims {pdims} and global shape {global_shape} and local transpose {local_transpose}"
jaxdecomp.config.update('transpose_axis_contiguous', local_transpose)
if not local_transpose:
pytest.skip(reason="Not implemented yet")
global_array, mesh = create_spmd_array(global_shape, pdims)
# Perform distributed gradient kernel
karray = jaxdecomp.fft.pfft3d(global_array, backend=backend)
kvec = jaxdecomp.fftfreq3d(karray)
k_gradients = [k * karray for k in kvec]
gradients = [
jaxdecomp.fft.pifft3d(grad, backend=backend) for grad in k_gradients
gathered_gradients = [all_gather(grad) for grad in gradients]
# perform local gradient kernel
gathered_array = all_gather(global_array)
jax_karray = jnp.fft.fftn(gathered_array)
kz, ky, kx = [
jnp.fft.fftfreq(jax_karray.shape[i]) * 2 * jnp.pi for i in range(3)
kz = kz.reshape(-1, 1, 1)
ky = ky.reshape(1, -1, 1)
kx = kx.reshape(1, 1, -1)
kvec = [kz, ky, kx]
jax_k_gradients = [k * jax_karray for k in kvec]
jax_gradients = [jnp.fft.ifftn(grad) for grad in jax_k_gradients]
# Check the gradients
for i in range(3):
jax_gradients[i], gathered_gradients[i], rtol=1e-5, atol=1e-5)
print(f"Gradient check OK!")
# Trigger rejit in case local transpose is switched
@pytest.mark.skipif(not is_on_cluster(), reason="Only run on cluster")
# Cartesian product tests
local_transpose) # Test with and without local transpose
decomp) # Test with Slab and Pencil decompositions
global_shapes) # Test cubes, non-cubes and primes
def test_cudecomp_fft(self, pdims, global_shape, local_transpose):
self.run_test(pdims, global_shape, local_transpose, backend="cuDecomp")
# Cartesian product tests
local_transpose) # Test with and without local transpose
decomp) # Test with Slab and Pencil decompositions
global_shapes) # Test cubes, non-cubes and primes
def test_jax_fft(self, pdims, global_shape, local_transpose):
self.run_test(pdims, global_shape, local_transpose, backend="jax")
......@@ -103,9 +103,6 @@ def test_halo_against_cudecomp(pdims):
assert_array_equal(g_jax_exchanged, g_cudecomp_exchanged)
pdims = [pdims[0]]
class TestHaloExchange:
def run_test(self, global_shape, pdims, backend):
from functools import partial
from math import prod
import jax
import pytest
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from conftest import initialize_distributed
from jax import lax
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from numpy.testing import assert_array_equal
import jaxdecomp
from jaxdecomp._src.padding import slice_pad, slice_unpad
# Initialize JAX distributed so that each local JAX process knows which GPU to use
rank = jax.process_index()
size = jax.process_count()
# Initialize cuDecomp
# Helper function to create a 3D array and remap it to the global array
# Create a random local array (optionally complex) and remap it into a global
# array on a 2-D mesh.
#
# Args:
#   global_shape: sequence of length 3.
#   pdims: sequence of length 2 whose product must equal the module-level
#     `size`.
#   complex: when true, the local array is a real normal sample plus 1j times
#     a second normal sample.
#     NOTE(review): this parameter name shadows the builtin `complex`.
#
# Returns:
#   (global_array, mesh), where mesh has axis names ('z', 'y').
def create_spmd_array(global_shape, pdims, complex=False):
assert (len(global_shape) == 3)
assert (len(pdims) == 2)
assert (prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes"
# Real and imaginary parts use different keys: PRNGKey(rank) and
# PRNGKey(rank + 1).
if complex:
local_array = jax.random.normal(
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
key=jax.random.PRNGKey(rank)) + 1j * jax.random.normal(
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
key=jax.random.PRNGKey(rank + 1))
# NOTE(review): the `else:` branch header and parts of the call below are not
# visible in this view.
local_array = jax.random.normal(
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
# Remap to the global array from the local slice
# The device mesh is built from pdims in reversed order.
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
local_array, mesh, P('z', 'y'))
return global_array, mesh
pencil_1 = (size // 2, size // (size // 2))
pencil_2 = (size // (size // 2), size // 2)
decomp = [(size, 1), (1, size), pencil_1, pencil_2]
global_shapes = [(32, 32, 32), (29 * size, 19 * size, 17 * size)]
decomp) # Test with Slab and Pencil decompositions
global_shapes) # Test cubes, non-cubes and primes
def test_padding(pdims, global_shape):
print("*" * 80)
print(f"Testing with pdims {pdims} and global_shape {global_shape}")
global_array, mesh = create_spmd_array(global_shape, pdims)
padding = ((32, 32), (32, 32), (0, 0))
# Reference implementation of per shard slicing
shard_map, mesh=mesh, in_specs=(P('z', 'y'), P()), out_specs=P('z', 'y'))
# Reference padding helper: pads `arr` with jnp.pad using `padding` as the
# pad width and returns the padded array.
# NOTE(review): presumably wrapped by shard_map via the decorator fragment
# above, so it would run on each shard -- confirm.
def sharded_pad(arr, padding):
padded = jnp.pad(arr, pad_width=padding)
return padded
# Reference unpadding helper run through shard_map: builds a lax.pad
# configuration whose low/high amounts are the negated entries of the
# enclosing `padding` variable, with interior padding 0.
# NOTE(review): negative low/high values in lax.pad are expected to crop the
# array edges -- confirm against the JAX documentation.
@partial(shard_map, mesh=mesh, in_specs=(P('z', 'y')), out_specs=P('z', 'y'))
def sharded_unpad(arr):
# `padding` is read from the enclosing scope, one (low, high) pair per axis.
x_unpading, y_unpading, z_unpading = padding[0], padding[1], padding[2]
first_x, first_y, first_z = -x_unpading[0], -y_unpading[0], -z_unpading[0]
last_x, last_y, last_z = -x_unpading[1], -y_unpading[1], -z_unpading[1]
# NOTE(review): the operand and padding-value arguments of this call are not
# visible in this view.
return lax.pad(
padding_config=[(first_x, last_x, 0), (first_y, last_y, 0),
(first_z, last_z, 0)])
# Test padding
print("-" * 40)
print(f"Testing padding")
with mesh:
padded_array = jnp.pad(global_array, padding)
padding_width = jnp.array(padding)
sharded_padded = sharded_pad(global_array, padding_width)
jaxdecomp_padded = slice_pad(global_array, padding, pdims)
first_x, last_x = padding[0]
first_y, last_y = padding[1]
# using just jnp pad will pad the entire global array and not the slices
expected_padded_shape = (global_shape[0] + first_x + last_x,
global_shape[1] + first_y + last_y, global_shape[2])
# Using a sharded jnp pad will pad the slices
expected_sharded_pad_shape = (global_shape[0] + (first_x + last_x) * pdims[1],
global_shape[1] + (first_y + last_y) * pdims[0],
print(f"Shape of global_array {global_array.shape}")
f"Shape of padded_array {padded_array.shape} it should be {expected_padded_shape}"
f"Shape of sharded_padded {sharded_padded.shape} it should be {expected_sharded_pad_shape}"
f"Shape of jaxdecomp_padded {jaxdecomp_padded.shape} it should be {expected_sharded_pad_shape}"
# Using pad on a global array will pad the array (uses communication)
assert_array_equal(padded_array.shape, expected_padded_shape)
# Test slice_pad against the reference sharded pad
assert_array_equal(sharded_padded.shape, expected_sharded_pad_shape)
assert_array_equal(jaxdecomp_padded.shape, expected_sharded_pad_shape)
# Test unpadding
print("-" * 40)
print(f"Testing unpadding")
with mesh:
unpadded_array = lax.pad(
padding_config=((-32, -32, 0), (-32, -32, 0), (0, 0, 0)))
sharded_unpadded = sharded_unpad(sharded_padded)
jaxdecomp_unpadded = slice_unpad(jaxdecomp_padded, padding, pdims)
f"Shape of unpadded_array {unpadded_array.shape} should be {global_shape}"
f"Shape of sharded_unpadded {sharded_unpadded.shape} should be {global_shape}"
f"Shape of jaxdecomp_unpadded {jaxdecomp_unpadded.shape} should be {global_shape}"
first_x, last_x = padding[0]
first_y, last_y = padding[1]
# Using pad on a global array will unpad the array (uses communication)
assert_array_equal(unpadded_array.shape, global_shape)
# Test slice_unpad against the reference sharded unpad
assert_array_equal(sharded_unpadded.shape, global_shape)
assert_array_equal(jaxdecomp_unpadded.shape, global_shape)
gathered_original = multihost_utils.process_allgather(
global_array, tiled=True)
gathered_unpadded = multihost_utils.process_allgather(
unpadded_array, tiled=True)
# Make sure the unpadded array is equal to the original array
assert_array_equal(gathered_original, gathered_unpadded)
decomp) # Test with Slab and Pencil decompositions
global_shapes) # Test cubes, non-cubes and primes
# Round-trip test on a complex array: slice_pad followed by slice_unpad with
# the same padding must give back the original data.
#
# Args:
#   pdims: process-grid dimensions passed to create_spmd_array, slice_pad and
#     slice_unpad.
#   global_shape: global array shape passed to create_spmd_array.
def test_complex_unpad(pdims, global_shape):
print("*" * 80)
print(f"Testing with pdims {pdims} and global_shape {global_shape}")
global_array, mesh = create_spmd_array(global_shape, pdims, complex=True)
# 32 elements on each side of the first two axes, none on the third.
padding = ((32, 32), (32, 32), (0, 0))
with mesh:
jaxdecomp_padded = slice_pad(global_array, padding, pdims)
jaxdecomp_unpadded = slice_unpad(jaxdecomp_padded, padding, pdims)
# Gather both arrays so they can be compared element-wise.
gathered_original = multihost_utils.process_allgather(
global_array, tiled=True)
gathered_unpadded = multihost_utils.process_allgather(
jaxdecomp_unpadded, tiled=True)
# Make sure the unpadded array is equal to the original array
assert_array_equal(gathered_original, gathered_unpadded)
......@@ -9,17 +9,19 @@ size = jax.device_count()
import pytest
jax.config.update("jax_enable_x64", False)
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from numpy.testing import assert_allclose, assert_array_equal
from functools import partial
import jaxdecomp
from jaxdecomp import (transposeXtoY, transposeYtoX, transposeYtoZ,
all_gather = partial(process_allgather, tiled=True)
pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100
pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100
......@@ -77,12 +79,12 @@ class TestTransposes:
assert compare_sharding(jd_tranposed_zy.sharding, y_pencil_sharding)
assert compare_sharding(jd_tranposed_yx.sharding, original_sharding)
gathered_array = process_allgather(global_array, tiled=True)
gathered_array = all_gather(global_array)
gathered_jd_xy = process_allgather(jd_tranposed_xy, tiled=True)
gathered_jd_yz = process_allgather(jd_tranposed_yz, tiled=True)
gathered_jd_zy = process_allgather(jd_tranposed_zy, tiled=True)
gathered_jd_yx = process_allgather(jd_tranposed_yx, tiled=True)
gathered_jd_xy = all_gather(jd_tranposed_xy)
gathered_jd_yz = all_gather(jd_tranposed_yz)
gathered_jd_zy = all_gather(jd_tranposed_zy)
gathered_jd_yx = all_gather(jd_tranposed_yx)
# Explanation:
# Transposing forward shifts the axes to the right, so ZYX goes to XZY and then to YXZ (2 0 1)
......@@ -192,8 +194,8 @@ class TestTransposesGrad:
array_grad = jax.grad(jaxdecomp_transpose)(global_array)
print("Here is the gradient I'm getting", array_grad.shape)
gathered_array = process_allgather(global_array, tiled=True)
gathered_grads = process_allgather(array_grad, tiled=True)
gathered_array = all_gather(global_array)
gathered_grads = all_gather(array_grad)
jax_grad = jax.grad(jax_transpose)(gathered_array)
print(f"Shape of JAX array {jax_grad.shape}")