Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: wassim/jaxDecomp
Commits on Source (9)
name: Tests
on:
push:
branches:
- main
@@ -19,22 +19,20 @@ jobs:
steps:
- name: Checkout Source
uses: actions/checkout@v2.3.1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[tests]
pip install .[test]
- name: Run tests
run: |
cd tests
export JAX_PLATFORM_NAME=cpu
export XLA_FLAGS='--xla_force_host_platform_device_count=8'
pytest test_fft.py
pytest test_halo.py
pytest test_transpose.py
pytest -v
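The two exported variables above are what let the suite run without GPUs; a minimal sketch (assuming only a standard JAX install) of how the forced host-platform device count exposes eight emulated devices:

import os

# Both variables must be set before JAX initialises its backends, which is why
# the workflow exports them in the shell before invoking pytest.
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

print(jax.device_count())  # expected to report 8 emulated CPU devices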
[submodule "third_party/cuDecomp"]
path = third_party/cuDecomp
url = https://github.com/NVIDIA/cuDecomp.git
[submodule "pybind11"]
path = pybind11
url = https://github.com/pybind/pybind11.git
@@ -15,6 +15,9 @@ if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()
set(PYBIND11_FINDPYTHON ON)
find_package(pybind11 CONFIG REQUIRED)
# Check for CUDA
include(CheckLanguage)
check_language(CUDA)
@@ -37,10 +40,6 @@ if(CMAKE_CUDA_COMPILER AND JD_CUDECOMP_BACKEND)
# 70: Volta, 80: Ampere, 89: RTX 4060
set(CUDECOMP_CUDA_CC_LIST "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")
# Add pybind11 and cuDecomp subdirectories
add_subdirectory(pybind11)
find_package(NVHPC REQUIRED COMPONENTS MATH MPI NCCL)
string(REPLACE "/lib64" "/include" NVHPC_MATH_INCLUDE_DIR ${NVHPC_MATH_LIBRARY_DIR})
@@ -89,7 +88,6 @@ if(CMAKE_CUDA_COMPILER AND JD_CUDECOMP_BACKEND)
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)
target_compile_definitions(_jaxdecomp PRIVATE JD_CUDECOMP_BACKEND)
else()
add_subdirectory(pybind11)
pybind11_add_module(_jaxdecomp src/jaxdecomp.cc)
target_include_directories(_jaxdecomp PRIVATE ${CMAKE_CURRENT_LIST_DIR}/include)
target_compile_definitions(_jaxdecomp PRIVATE JD_JAX_BACKEND)
@@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version
from typing import Tuple
from jaxdecomp._src.pencil_utils import get_output_specs
from jaxdecomp.fft import fftfreq3d_shard, pfft3d, pifft3d
from jaxdecomp.fft import fftfreq3d, pfft3d, pifft3d, rfftfreq3d
from jaxdecomp.halo import halo_exchange
from jaxdecomp.transpose import (transposeXtoY, transposeXtoZ, transposeYtoX,
transposeYtoZ, transposeZtoX, transposeZtoY)
@@ -52,7 +52,8 @@ __all__ = [
"PENCILS",
"NO_DECOMP",
"get_output_specs",
"fftfreq3d_shard",
"fftfreq3d",
"rfftfreq3d",
]
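This hunk renames the public API: fftfreq3d_shard is dropped in favour of fftfreq3d and rfftfreq3d. A short usage sketch, mirroring the TestFFTFreq case further down; `field` is an assumed, already-distributed 3D array on a ('z', 'y') mesh, not something defined in this diff:

import jaxdecomp

def gradient_kernel(field):
    # field: 3D array sharded over a ('z', 'y') mesh (see the test helpers below)
    karray = jaxdecomp.fft.pfft3d(field)      # distributed forward FFT
    kvec = jaxdecomp.fftfreq3d(karray)        # frequency vectors sharded like karray
    return [jaxdecomp.fft.pifft3d(k * karray) for k in kvec]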
from functools import partial
from typing import List, Tuple
import jax
from jax import lax
from jax import numpy as jnp
from jax._src.api import ShapeDtypeStruct
from jax._src.core import ShapedArray
from jax._src.typing import Array
from jax.experimental.shard_alike import shard_alike
from jax.lib import xla_client
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomp._src import _jaxdecomp
from jaxdecomp._src.pencil_utils import get_transpose_order
from jaxdecomp._src.spmd_ops import (CustomParPrimitive, get_pencil_type,
register_primitive)
import jaxdecomp
FftType = xla_client.FftType
# TODO Custom partitioning for FFTFreq is to be removed
# the only working implementation is fftfreq3d_shard
class FFTFreqPrimitive(CustomParPrimitive):
name = 'jax_fftfreq'
multiple_results = True
impl_static_args: Tuple[int, ...] = ()
outer_primitive = None
@staticmethod
def impl(array, d) -> Tuple[Array, Array, Array]:
assert array.ndim == 3, "Only 3D FFTFreq is supported"
kx = jnp.fft.fftfreq(array.shape[0], d=d, dtype=array.dtype)
ky = jnp.fft.fftfreq(array.shape[1], d=d, dtype=array.dtype)
kz = jnp.fft.fftfreq(array.shape[2], d=d, dtype=array.dtype)
assert len(kx) == array.shape[0], "kx must have the same size as array"
assert len(ky) == array.shape[1], "ky must have the same size as array"
assert len(kz) == array.shape[2], "kz must have the same size as array"
kvec = (kx, ky, kz)
transpose_order = get_transpose_order(FftType.FFT)
kx, ky, kz = kvec[transpose_order[0]], kvec[transpose_order[1]], kvec[
transpose_order[2]]
kx , ky , kz = (kx.reshape([-1, 1, 1]),
ky.reshape([1, -1, 1]),
kz.reshape([1, 1, -1])) # yapf: disable
print(f"IMPL shape of kx {kx.shape} ky {ky.shape} kz {kz.shape}")
return kx, ky, kz
@staticmethod
def per_shard_impl(a: Array, kx: Array, ky: Array, kz: Array,
mesh) -> Tuple[Array, Array, Array]:
x_axis_name, y_axis_name = mesh.axis_names
assert x_axis_name is not None and y_axis_name is not None
transpose_order = get_transpose_order(FftType.FFT, mesh)
kvec = (kx, ky, kz)
kx = kvec[transpose_order[0]]
ky = kvec[transpose_order[1]]
kz = kvec[transpose_order[2]]
return kx, ky, kz
@staticmethod
def infer_sharding_from_operands(mesh: Mesh,
arg_infos: Tuple[ShapeDtypeStruct],
result_infos: Tuple[ShapedArray]):
del mesh
input_mesh = arg_infos[0].sharding.mesh
x_axis_name, y_axis_name = input_mesh.axis_names
pencil_type = get_pencil_type(input_mesh)
match pencil_type:
case _jaxdecomp.SLAB_XY | _jaxdecomp.PENCILS:
kx_sharding = NamedSharding(input_mesh, P(x_axis_name))
ky_sharding = NamedSharding(input_mesh, P(None, y_axis_name))
kz_sharding = NamedSharding(input_mesh, P(None, None, None))
case _jaxdecomp.SLAB_YZ:
kx_sharding = NamedSharding(input_mesh, P(y_axis_name))
ky_sharding = NamedSharding(input_mesh, P(None, x_axis_name))
kz_sharding = NamedSharding(input_mesh, P(None, None, None))
case _:
raise ValueError(f"Unsupported pencil type {pencil_type}")
return (kx_sharding, ky_sharding, kz_sharding)
@staticmethod
def partition(mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct],
result_infos: Tuple[ShapedArray]):
# assert isinstance(arg_infos, tuple) and len(
# arg_infos) == 2, "Arg info must be a tuple of two sharding"
# assert isinstance(result_infos, tuple) and len(
# result_infos) == 3, "Result info must be a tuple of three sharding"
print(f"arg_infos size {len(arg_infos)} and {arg_infos}")
input_sharding = arg_infos[0].sharding
print(f"args_infos[1] sharding {arg_infos[1].sharding}")
input_mesh = input_sharding.mesh
kvec_sharding = (NamedSharding(input_mesh, P(None)),) * 3
input_shardings = (input_sharding, *kvec_sharding)
print(f"input_shardings {input_shardings}")
print(f"len(input_shardings) {len(input_shardings)}")
print(f"input_mesh {input_mesh}")
output_sharding = tuple([
NamedSharding(input_mesh, P(*result_infos[i].sharding.spec))
for i in range(3)
])
impl = partial(FFTFreqPrimitive.per_shard_impl, mesh=input_mesh)
return mesh, impl, output_sharding, input_shardings

register_primitive(FFTFreqPrimitive)

@jax.jit
def fftfreq_impl(array, kx, ky, kz) -> Array:
return FFTFreqPrimitive.outer_lowering(array, kx, ky, kz)

@partial(jax.jit, static_argnums=(1, 2))
def fftfreq3d(array, d=1.0, dtype=None):
kx = jnp.fft.fftfreq(array.shape[0], d=d, dtype=dtype)
ky = jnp.fft.fftfreq(array.shape[1], d=d, dtype=dtype)
kz = jnp.fft.fftfreq(array.shape[2], d=d, dtype=dtype)
print(f"global shape of kx {kx.shape} ky {ky.shape} kz {kz.shape}")
return fftfreq_impl(array, kx, ky, kz)

@partial(jax.jit, static_argnums=(1, 2))
def rfftfreq3d(array, d=1.0, dtype=None):
kx = jnp.fft.fftfreq(array.shape[0], d=d, dtype=dtype)
ky = jnp.fft.fftfreq(array.shape[1], d=d, dtype=dtype)
kz = jnp.fft.rfftfreq(array.shape[2], d=d, dtype=dtype)
return fftfreq_impl(array, kx, ky, kz)

from jax.experimental.shard_alike import shard_alike

@partial(jax.jit, static_argnums=(1, 2))
def fftfreq3d_shard(array, d=1.0, dtype=None):
kx = jnp.fft.fftfreq(array.shape[0], d=d, dtype=dtype)
ky = jnp.fft.fftfreq(array.shape[1], d=d, dtype=dtype)
kz = jnp.fft.fftfreq(array.shape[2], d=d, dtype=dtype)
kx, _ = shard_alike(kx, array[:, 0, 0])
ky, _ = shard_alike(ky, array[0, :, 0])
kz, _ = shard_alike(kz, array[0, 0, :])
return kx, ky, kz

@partial(jax.jit, static_argnums=(1,))
def fftfreq3d(k_array, d=1.0):
# in frequency space, the order is Z pencil
# X pencil is Z Y X
# Z pencil is Y X Z
if jnp.iscomplexobj(k_array):
dtype = jnp.float32 if k_array.dtype == jnp.complex64 else jnp.float64
else:
dtype = k_array.dtype
ky = jnp.fft.fftfreq(k_array.shape[0], d=d, dtype=dtype) * 2 * jnp.pi
kx = jnp.fft.fftfreq(k_array.shape[1], d=d, dtype=dtype) * 2 * jnp.pi
kz = jnp.fft.fftfreq(k_array.shape[2], d=d, dtype=dtype) * 2 * jnp.pi
ky, _ = shard_alike(ky, k_array[:, 0, 0])
kx, _ = shard_alike(kx, k_array[0, :, 0])
kz, _ = shard_alike(kz, k_array[0, 0, :])
ky = ky.reshape([-1, 1, 1])
kx = kx.reshape([1, -1, 1])
kz = kz.reshape([1, 1, -1])
return kz, ky, kx

@partial(jax.jit, static_argnums=(1,))
def rfftfreq3d(k_array, d=1.0):
# in frequency space, the order is Z pencil
# X pencil is Z Y X
# Z pencil is Y X Z
# the first axis to be FFT'd is X so it is the real one
if jnp.iscomplexobj(k_array):
dtype = jnp.float32 if k_array.dtype == jnp.complex64 else jnp.float64
else:
dtype = k_array.dtype
ky = jnp.fft.fftfreq(k_array.shape[0], d=d, dtype=dtype) * 2 * jnp.pi
kx = jnp.fft.rfftfreq(k_array.shape[1], d=d, dtype=dtype) * 2 * jnp.pi
kz = jnp.fft.fftfreq(k_array.shape[2], d=d, dtype=dtype) * 2 * jnp.pi
ky, _ = shard_alike(ky, k_array[:, 0, 0])
kx, _ = shard_alike(kx, k_array[0, :, 0])
kz, _ = shard_alike(kz, k_array[0, 0, :])
ky = ky.reshape([-1, 1, 1])
kx = kx.reshape([1, -1, 1])
kz = kz.reshape([1, 1, -1])
return kz, ky, kx
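The replacement implementation relies on jax.experimental.shard_alike to give each 1D frequency vector the sharding of the matching axis of the input, instead of a custom partitioned primitive. A minimal standalone sketch of that mechanism (the mesh and array here are illustrative, not taken from the library):

import jax
import jax.numpy as jnp
from jax.experimental.shard_alike import shard_alike
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes several visible devices (e.g. the 8 emulated CPU devices above).
mesh = Mesh(jax.devices(), axis_names=('x',))
arr = jax.device_put(jnp.zeros((16, 16)), NamedSharding(mesh, P('x', None)))

@jax.jit
def freqs_like_axis0(a):
    k = jnp.fft.fftfreq(a.shape[0])
    k, _ = shard_alike(k, a[:, 0])  # k inherits the sharding of axis 0 of a
    return k

print(freqs_like_axis0(arr).sharding)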
@@ -174,13 +174,9 @@ def pifft3d(a: ArrayLike,
"ifft", xla_client.FftType.IFFT, a, norm=norm, backend=backend)
def fftfreq3d(array, d=1.0, dtype=None):
return _fftfreq.fftfreq3d(array, d=d, dtype=dtype)
def fftfreq3d(array, d=1.0):
return _fftfreq.fftfreq3d(array, d=d)
def rfftfreq3d(array, d=1.0, dtype=None):
return _fftfreq.rfftfreq3d(array, d=d, dtype=dtype)
def fftfreq3d_shard(array, d=1.0, dtype=None):
return _fftfreq.fftfreq3d_shard(array, d=d, dtype=dtype)
def rfftfreq3d(array, d=1.0):
return _fftfreq.rfftfreq3d(array, d=d)
Subproject commit 8b48ff878c168b51fe5ef7b8c728815b9e1a9857
from functools import partial
import jax
jax.config.update("jax_enable_x64", True)
from math import prod
import pytest
from conftest import initialize_distributed
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from numpy.testing import assert_array_equal
# Initialize jax distributed to instruct jax local process which GPU to use
initialize_distributed()
rank = jax.process_index()
size = jax.process_count()
# Initialize cuDecomp
############################################################################################################
# This test is just to make sure that multihost_utils.process_allgather works as expected
############################################################################################################
# Helper function to create a 3D array and remap it to the global array
def create_spmd_array(global_shape, pdims):
assert (len(global_shape) == 3)
assert (len(pdims) == 2)
assert (prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes"
local_array = jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(rank))
# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
local_array, mesh, P('z', 'y'))
return global_array, mesh
@pytest.mark.parametrize("pdims", [(1, size), (size, 1),
(size // 2, size // 2)]) # Slabs and Pencils
def test_empty_halo(pdims):
global_shape = (29 * size, 19 * size, 17 * size
) # These sizes are prime numbers x size of the pmesh
global_array, mesh = create_spmd_array(global_shape, pdims)
# All gather function
@partial(
shard_map,
mesh=mesh,
in_specs=P('z', 'y'),
out_specs=P(),
check_rep=False)
def sharded_allgather(arr):
gathered_z_axis = jax.lax.all_gather(arr, axis_name='z', axis=0, tiled=True)
gathered = jax.lax.all_gather(
gathered_z_axis, axis_name='y', axis=1, tiled=True)
return gathered
gathered = sharded_allgather(global_array)
process_allgather = multihost_utils.process_allgather(
global_array, tiled=True)
print(f"Shape of original array {global_array.shape}")
print(f"Shape of gathered array using double lax gather {gathered.shape}")
print(
f"Shape of gathered array using process_allgather {process_allgather.shape}"
)
assert_array_equal(gathered, process_allgather)
@@ -8,14 +8,18 @@ size = jax.device_count()
jax.config.update("jax_enable_x64", True)
from functools import partial
import jax.numpy as jnp
import pytest
from jax.experimental import multihost_utils
from jax.experimental.multihost_utils import process_allgather
from numpy.testing import assert_allclose
import jaxdecomp
from jaxdecomp._src import PENCILS, SLAB_XY, SLAB_YZ
all_gather = partial(process_allgather, tiled=True)
pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100
pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100
@@ -59,10 +63,9 @@ class TestFFTs:
# assert compare_sharding(global_array.sharding, dist_jax_rec_array.sharding)
# Check the forward FFT
gathered_array = multihost_utils.process_allgather(global_array, tiled=True)
gathered_karray = multihost_utils.process_allgather(karray, tiled=True)
gathered_rec_array = multihost_utils.process_allgather(
rec_array, tiled=True)
gathered_array = all_gather(global_array)
gathered_karray = all_gather(karray)
gathered_rec_array = all_gather(rec_array)
jax_karray = jnp.fft.fftn(gathered_array)
jax_rec_array = jnp.fft.ifftn(jax_karray, norm='forward')
@@ -167,8 +170,8 @@ class TestFFTsGrad:
array_grad = jax.grad(spmd_grad)(global_array)
print("Here is the gradient I'm getting", array_grad.shape)
gathered_array = multihost_utils.process_allgather(global_array, tiled=True)
gathered_grads = multihost_utils.process_allgather(array_grad, tiled=True)
gathered_array = all_gather(global_array)
gathered_grads = all_gather(array_grad)
jax_grad = jax.grad(local_grad)(gathered_array)
print(f"Shape of JAX array {jax_grad.shape}")
@@ -194,8 +197,7 @@ class TestFFTsGrad:
ifft_array_grad = jax.grad(inv_spmd_grad)(karray)
print("Here is the gradient I'm getting", array_grad.shape)
ifft_gathered_grads = multihost_utils.process_allgather(
ifft_array_grad, tiled=True)
ifft_gathered_grads = all_gather(ifft_array_grad)
jax_karray = jnp.fft.fftn(gathered_array).transpose(transpose_back)
ifft_jax_grad = jax.grad(inv_local_grad)(jax_karray)
@@ -230,3 +232,81 @@ class TestFFTsGrad:
global_shapes) # Test cubes, non-cubes and primes
def test_jax_grad(self, pdims, global_shape, local_transpose):
self.run_test(pdims, global_shape, local_transpose, backend="jax")
class TestFFTFreq:
def run_test(self, pdims, global_shape, local_transpose, backend):
print("*" * 80)
print(
f"Testing with pdims {pdims} and global shape {global_shape} and local transpose {local_transpose}"
)
jaxdecomp.config.update('transpose_axis_contiguous', local_transpose)
if not local_transpose:
pytest.skip(reason="Not implemented yet")
global_array, mesh = create_spmd_array(global_shape, pdims)
# Perform distributed gradient kernel
karray = jaxdecomp.fft.pfft3d(global_array, backend=backend)
kvec = jaxdecomp.fftfreq3d(karray)
k_gradients = [k * karray for k in kvec]
gradients = [
jaxdecomp.fft.pifft3d(grad, backend=backend) for grad in k_gradients
]
gathered_gradients = [all_gather(grad) for grad in gradients]
# perform local gradient kernel
gathered_array = all_gather(global_array)
jax_karray = jnp.fft.fftn(gathered_array)
kz, ky, kx = [
jnp.fft.fftfreq(jax_karray.shape[i]) * 2 * jnp.pi for i in range(3)
]
kz = kz.reshape(-1, 1, 1)
ky = ky.reshape(1, -1, 1)
kx = kx.reshape(1, 1, -1)
kvec = [kz, ky, kx]
jax_k_gradients = [k * jax_karray for k in kvec]
jax_gradients = [jnp.fft.ifftn(grad) for grad in jax_k_gradients]
# Check the gradients
for i in range(3):
assert_allclose(
jax_gradients[i], gathered_gradients[i], rtol=1e-5, atol=1e-5)
print(f"Gradient check OK!")
# Trigger rejit in case local transpose is switched
jax.clear_caches()
@pytest.mark.skipif(not is_on_cluster(), reason="Only run on cluster")
# Cartesian product tests
@pytest.mark.parametrize(
"local_transpose",
local_transpose) # Test with and without local transpose
@pytest.mark.parametrize("pdims",
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_cudecomp_fft(self, pdims, global_shape, local_transpose):
self.run_test(pdims, global_shape, local_transpose, backend="cuDecomp")
# Cartesian product tests
@pytest.mark.parametrize(
"local_transpose",
local_transpose) # Test with and without local transpose
@pytest.mark.parametrize("pdims",
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_jax_fft(self, pdims, global_shape, local_transpose):
self.run_test(pdims, global_shape, local_transpose, backend="jax")
@@ -103,9 +103,6 @@ def test_halo_against_cudecomp(pdims):
assert_array_equal(g_jax_exchanged, g_cudecomp_exchanged)
pdims = [pdims[0]]
class TestHaloExchange:
def run_test(self, global_shape, pdims, backend):
from functools import partial
from math import prod
import jax
import pytest
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from conftest import initialize_distributed
from jax import lax
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from numpy.testing import assert_array_equal
import jaxdecomp
from jaxdecomp._src.padding import slice_pad, slice_unpad
# Initialize jax distributed to instruct jax local process which GPU to use
initialize_distributed()
rank = jax.process_index()
size = jax.process_count()
# Initialize cuDecomp
# Helper function to create a 3D array and remap it to the global array
def create_spmd_array(global_shape, pdims, complex=False):
assert (len(global_shape) == 3)
assert (len(pdims) == 2)
assert (prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes"
if complex:
local_array = jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(rank)) + 1j * jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(rank + 1))
else:
local_array = jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(rank))
# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
local_array, mesh, P('z', 'y'))
return global_array, mesh
pencil_1 = (size // 2, size // (size // 2))
pencil_2 = (size // (size // 2), size // 2)
decomp = [(size, 1), (1, size), pencil_1, pencil_2]
global_shapes = [(32, 32, 32), (29 * size, 19 * size, 17 * size)]
@pytest.mark.parametrize("pdims",
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_padding(pdims, global_shape):
print("*" * 80)
print(f"Testing with pdims {pdims} and global_shape {global_shape}")
global_array, mesh = create_spmd_array(global_shape, pdims)
padding = ((32, 32), (32, 32), (0, 0))
# Reference implementation of per shard slicing
@partial(
shard_map, mesh=mesh, in_specs=(P('z', 'y'), P()), out_specs=P('z', 'y'))
def sharded_pad(arr, padding):
padded = jnp.pad(arr, pad_width=padding)
return padded
@partial(shard_map, mesh=mesh, in_specs=(P('z', 'y')), out_specs=P('z', 'y'))
def sharded_unpad(arr):
x_unpadding, y_unpadding, z_unpadding = padding[0], padding[1], padding[2]
first_x, first_y, first_z = -x_unpadding[0], -y_unpadding[0], -z_unpadding[0]
last_x, last_y, last_z = -x_unpadding[1], -y_unpadding[1], -z_unpadding[1]
return lax.pad(
arr,
padding_value=0.0,
padding_config=[(first_x, last_x, 0), (first_y, last_y, 0),
(first_z, last_z, 0)])
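# Note: sharded_unpad above relies on lax.pad with negative edge widths, which
# crops instead of padding; a tiny illustrative sketch (kept as a comment so it
# does not run inside this test):
#
#   import jax.numpy as jnp
#   from jax import lax
#
#   x = jnp.arange(6.0)                       # [0. 1. 2. 3. 4. 5.]
#   cropped = lax.pad(x, 0.0, [(-1, -2, 0)])  # drop 1 at the front, 2 at the back
#   # cropped -> [1. 2. 3.]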
# Test padding
print("-" * 40)
print(f"Testing padding")
with mesh:
padded_array = jnp.pad(global_array, padding)
padding_width = jnp.array(padding)
sharded_padded = sharded_pad(global_array, padding_width)
jaxdecomp_padded = slice_pad(global_array, padding, pdims)
first_x, last_x = padding[0]
first_y, last_y = padding[1]
# using just jnp pad will pad the entire global array and not the slices
expected_padded_shape = (global_shape[0] + first_x + last_x,
global_shape[1] + first_y + last_y, global_shape[2])
# Using a sharded jnp pad will pad the slices
expected_sharded_pad_shape = (global_shape[0] + (first_x + last_x) * pdims[1],
global_shape[1] + (first_y + last_y) * pdims[0],
global_shape[2])
print(f"Shape of global_array {global_array.shape}")
print(
f"Shape of padded_array {padded_array.shape} it should be {expected_padded_shape}"
)
print(
f"Shape of sharded_padded {sharded_padded.shape} it should be {expected_sharded_pad_shape}"
)
print(
f"Shape of jaxdecomp_padded {jaxdecomp_padded.shape} it should be {expected_sharded_pad_shape}"
)
# Using pad on a global array will pad the array (uses communication)
assert_array_equal(padded_array.shape, expected_padded_shape)
# Test slice_pad against reference sharded pad
assert_array_equal(sharded_padded.shape, expected_sharded_pad_shape)
assert_array_equal(jaxdecomp_padded.shape, expected_sharded_pad_shape)
# Test unpadding
print("-" * 40)
print(f"Testing unpadding")
with mesh:
unpadded_array = lax.pad(
padded_array,
padding_value=0.0,
padding_config=((-32, -32, 0), (-32, -32, 0), (0, 0, 0)))
sharded_unpadded = sharded_unpad(sharded_padded)
jaxdecomp_unpadded = slice_unpad(jaxdecomp_padded, padding, pdims)
print(
f"Shape of unpadded_array {unpadded_array.shape} should be {global_shape}"
)
print(
f"Shape of sharded_unpadded {sharded_unpadded.shape} should be {global_shape}"
)
print(
f"Shape of jaxdecomp_unpadded {jaxdecomp_unpadded.shape} should be {global_shape}"
)
first_x, last_x = padding[0]
first_y, last_y = padding[1]
# Using pad on a global array will unpad the array (uses communication)
assert_array_equal(unpadded_array.shape, global_shape)
# Test slice_pad against reference sharded pad
assert_array_equal(sharded_unpadded.shape, global_shape)
assert_array_equal(jaxdecomp_unpadded.shape, global_shape)
gathered_original = multihost_utils.process_allgather(
global_array, tiled=True)
gathered_unpadded = multihost_utils.process_allgather(
unpadded_array, tiled=True)
# Make sure the unpadded array is equal to the original array
assert_array_equal(gathered_original, gathered_unpadded)
@pytest.mark.parametrize("pdims",
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_complex_unpad(pdims, global_shape):
print("*" * 80)
print(f"Testing with pdims {pdims} and global_shape {global_shape}")
global_array, mesh = create_spmd_array(global_shape, pdims, complex=True)
padding = ((32, 32), (32, 32), (0, 0))
with mesh:
jaxdecomp_padded = slice_pad(global_array, padding, pdims)
jaxdecomp_unpadded = slice_unpad(jaxdecomp_padded, padding, pdims)
gathered_original = multihost_utils.process_allgather(
global_array, tiled=True)
gathered_unpadded = multihost_utils.process_allgather(
jaxdecomp_unpadded, tiled=True)
# Make sure the unpadded array is equal to the original array
assert_array_equal(gathered_original, gathered_unpadded)
@@ -9,17 +9,19 @@ size = jax.device_count()
import pytest
jax.config.update("jax_enable_x64", False)
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from numpy.testing import assert_allclose, assert_array_equal
from functools import partial
import jaxdecomp
from jaxdecomp import (transposeXtoY, transposeYtoX, transposeYtoZ,
transposeZtoY)
all_gather = partial(process_allgather, tiled=True)
pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100
pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100
@@ -77,12 +79,12 @@ class TestTransposes:
assert compare_sharding(jd_tranposed_zy.sharding, y_pencil_sharding)
assert compare_sharding(jd_tranposed_yx.sharding, original_sharding)
gathered_array = process_allgather(global_array, tiled=True)
gathered_array = all_gather(global_array)
gathered_jd_xy = process_allgather(jd_tranposed_xy, tiled=True)
gathered_jd_yz = process_allgather(jd_tranposed_yz, tiled=True)
gathered_jd_zy = process_allgather(jd_tranposed_zy, tiled=True)
gathered_jd_yx = process_allgather(jd_tranposed_yx, tiled=True)
gathered_jd_xy = all_gather(jd_tranposed_xy)
gathered_jd_yz = all_gather(jd_tranposed_yz)
gathered_jd_zy = all_gather(jd_tranposed_zy)
gathered_jd_yx = all_gather(jd_tranposed_yx)
# Explanation:
# Transposing forward is a shift axis to the right so ZYX to XZY to YXZ (2 0 1)
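# A plain-JAX illustration of that convention (shapes chosen arbitrarily, kept
# as a comment so it does not run inside the test):
#
#   import jax.numpy as jnp
#   a = jnp.zeros((4, 8, 16))             # axes ordered (Z, Y, X)
#   forward = a.transpose((2, 0, 1))      # ZYX -> XZY, one shift to the right
#   twice = forward.transpose((2, 0, 1))  # XZY -> YXZ
#   # forward.shape == (16, 4, 8), twice.shape == (8, 16, 4)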
@@ -192,8 +194,8 @@ class TestTransposesGrad:
array_grad = jax.grad(jaxdecomp_transpose)(global_array)
print("Here is the gradient I'm getting", array_grad.shape)
gathered_array = process_allgather(global_array, tiled=True)
gathered_grads = process_allgather(array_grad, tiled=True)
gathered_array = all_gather(global_array)
gathered_grads = all_gather(array_grad)
jax_grad = jax.grad(jax_transpose)(gathered_array)
print(f"Shape of JAX array {jax_grad.shape}")