Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • wassim/jaxDecomp
1 result
Show changes
Commits on Source (2)
......@@ -113,9 +113,9 @@ def _halo_pencils(operand, halo_extents: HaloExtentType,
permutations_x = slice(None, None) if periodic_x else slice(None, -1)
permutations_y = slice(None, None) if periodic_y else slice(None, -1)
reverse_indexing_z = [(j, (j + 1) % z_size) for j in range(z_size)
reverse_indexing_z = [((j + 1) % z_size, j) for j in range(z_size)
][permutations_x]
forward_indexing_z = [((j + 1) % z_size, j) for j in range(z_size)
forward_indexing_z = [(j, (j + 1) % z_size) for j in range(z_size)
][permutations_x]
reverse_indexing_y = [((j + 1) % y_size, j) for j in range(y_size)
][permutations_y]
......
......@@ -79,19 +79,18 @@ def test_halo_against_cudecomp(pdims):
def pad(arr):
return jnp.pad(arr, padding, mode='linear_ramp', end_values=20)
with mesh:
# perform halo exchange
updated_array = pad(global_array)
jax_exchanged = jaxdecomp.halo_exchange(
updated_array,
halo_extents=halo_extents,
halo_periods=periodic,
backend="JAX")
cudecomp_exchanged = jaxdecomp.halo_exchange(
updated_array,
halo_extents=halo_extents,
halo_periods=periodic,
backend="CUDECOMP")
# perform halo exchange
updated_array = pad(global_array)
jax_exchanged = jaxdecomp.halo_exchange(
updated_array,
halo_extents=halo_extents,
halo_periods=periodic,
backend="JAX")
cudecomp_exchanged = jaxdecomp.halo_exchange(
updated_array,
halo_extents=halo_extents,
halo_periods=periodic,
backend="CUDECOMP")
g_array = all_gather(updated_array)
g_jax_exchanged = all_gather(jax_exchanged)
......@@ -136,20 +135,19 @@ class TestHaloExchange:
return arr
with mesh:
# perform halo exchange
padded_array = multiply(global_array)
padded_array = pad(padded_array)
periodic_exchanged_array = jaxdecomp.halo_exchange(
padded_array,
halo_extents=halo_extents,
halo_periods=periodic,
backend=backend)
exchanged_array = jaxdecomp.halo_exchange(
padded_array,
halo_extents=halo_extents,
halo_periods=periodic,
backend=backend)
# perform halo exchange
padded_array = multiply(global_array)
padded_array = pad(padded_array)
periodic_exchanged_array = jaxdecomp.halo_exchange(
padded_array,
halo_extents=halo_extents,
halo_periods=periodic,
backend=backend)
exchanged_array = jaxdecomp.halo_exchange(
padded_array,
halo_extents=halo_extents,
halo_periods=periodic,
backend=backend)
# Gather array from all processes
# gathered_array = multihost_utils.process_allgather(global_array,tiled=True)
......@@ -193,46 +191,47 @@ class TestHaloExchange:
print(f"z {z_slice} y {y_slice}")
print(f"original slice \n{original_slice[:,:,0]}")
print(f"up slice \n{up_slice[:,:,0]}")
print(f"current slice \n{current_slice[:,:,0]}")
print(f"down slice \n{down_slice[:,:,0]}")
print("--" * 40)
continue
# if up down padding check the up down slices
if pdims[0] > 1:
# if up down padding check the up down slices
if pdims[1] > 1:
# Check the upper padding
assert_array_equal(current_slice[:halo_size], up_slice[-halo_size:])
assert_array_equal(current_slice[:halo_size], up_slice[-2 * halo_size:-halo_size])
# Check the lower padding
assert_array_equal(current_slice[-halo_size:], down_slice[:halo_size])
assert_array_equal(current_slice[-halo_size:], down_slice[halo_size:halo_size * 2])
# if left right padding check the left right slices
if pdims[1] > 1:
if pdims[0] > 1:
# Check the left padding
assert_array_equal(current_slice[:, :halo_size],
left_slice[:, -halo_size:])
left_slice[:, -2 * halo_size:-halo_size:])
# Check the right padding
assert_array_equal(current_slice[:, -halo_size:],
right_slice[:, :halo_size])
right_slice[:, halo_size:halo_size * 2])
# if both padded check the corners
if pdims[0] > 1 and pdims[1] > 1:
# Check the upper left corner
assert_array_equal(current_slice[:halo_size, :halo_size],
upper_left_corner[-halo_size:, -halo_size:])
upper_left_corner[-2 * halo_size:-halo_size,-2 * halo_size:-halo_size])
# Check the upper right corner
assert_array_equal(current_slice[:halo_size, -halo_size:],
upper_right_corner[-halo_size:, :halo_size])
upper_right_corner[-2 * halo_size:-halo_size, halo_size:halo_size * 2])
# Check the lower left corner
assert_array_equal(current_slice[-halo_size:, :halo_size],
lower_left_corner[:halo_size, -halo_size:])
lower_left_corner[halo_size:halo_size * 2, -2 * halo_size:-halo_size])
# Check the lower right corner
assert_array_equal(current_slice[-halo_size:, -halo_size:],
lower_right_corner[:halo_size, :halo_size])
lower_right_corner[halo_size:halo_size * 2, halo_size:halo_size * 2])
@pytest.mark.skipif(not is_on_cluster(), reason="Only run on cluster")
@pytest.mark.parametrize("pdims", pdims)
def test_cudecomp_halo(self, pdims):
self.run_test((16, 16, 16), pdims, "CUDECOMP")
self.run_test((32, 32, 32), pdims, "CUDECOMP")
@pytest.mark.parametrize("pdims", pdims)
def test_jax_halo(self, pdims):
......