name: Build and upload to PyPI
on:
workflow_dispatch:
pull_request:
push:
branches:
- main
release:
types:
- published
types: [published]
jobs:
build_wheels:
@@ -25,7 +19,7 @@ jobs:
- name: Build wheels
uses: pypa/cibuildwheel@v2.21.3
env:
CIBW_BUILD: "cp310-* cp311-* cp312-*"
CIBW_BUILD: "cp310-manylinux_x86_64 cp311-manylinux_x86_64 cp312-manylinux_x86_64"
CIBW_BUILD_VERBOSITY: 2
- uses: actions/upload-artifact@v4
with:
@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.10.4]
python-version: ["3.10" , "3.11" , "3.12"]
steps:
- name: Checkout Source
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: 'v5.0.0'
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
args: ['--parallel', '--in-place']
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.2
hooks:
- id: ruff-format
- id: ruff
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.4
hooks:
@@ -94,4 +94,4 @@ else()
endif()
set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")
install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .)
install(TARGETS _jaxdecomp LIBRARY DESTINATION jaxdecomplib PUBLIC_HEADER DESTINATION jaxdecomplib)
# jaxDecomp: JAX Library for 3D Domain Decomposition and Parallel FFTs
[![Build](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/github-deploy.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/github-deploy.yml)
[![Code Formatting](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/formatting.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/formatting.yml)
[![Tests](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/tests.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/tests.yml)
[![MIT License](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
> [!IMPORTANT]
> Version `0.2.0` introduces a **pure JAX backend** that **no longer requires MPI**. For multi-node runs, MPI and NCCL backends are still available through **cuDecomp**.
`jaxDecomp` provides JAX bindings for NVIDIA's [cuDecomp](https://nvidia.github.io/cuDecomp/index.html) library [(Romero et al. 2022)](https://dl.acm.org/doi/abs/10.1145/3539781.3539797), enabling **multi-node parallel FFTs and halo exchanges** directly in low-level NCCL/CUDA-aware MPI from your JAX code.
---
## Usage
Below is a simple example of performing a **3D FFT** on a 3D array distributed across multiple devices, followed by a halo exchange, a common operation in many scientific computing applications. For demonstration purposes, we force 8 CPU devices via environment variables:
```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax import numpy as jnp
import jaxdecomp
from functools import partial
# Initialize JAX distributed so each process knows which device to use
# (only needed for multi-process runs)
jax.distributed.initialize()
rank = jax.process_index()

# Create a 2x4 mesh of devices
pdims = (2, 4)
global_shape = (1024, 1024, 1024)
# Local slice shape corresponding to the expected global size
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2])
# Build the device mesh and sharding spec
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

# Remap to the global array from the local slices
global_array = jax.make_array_from_callback(
global_shape,
sharding,
data_callback=lambda _: jax.random.normal(
jax.random.PRNGKey(rank), local_shape))
padding_width = ((32, 32), (32, 32), (0, 0))  # must be a tuple of tuples
@partial(
shard_map, mesh=mesh, in_specs=(P('x', 'y'), P()), out_specs=P('x', 'y'))
def pad(arr, padding):
return jnp.pad(arr, padding)
@partial(
shard_map, mesh=mesh, in_specs=(P('x', 'y'), P()), out_specs=P('x', 'y'))
def reduce_halo(x, pad_width):
halo_x , _ = pad_width[0]
halo_y , _ = pad_width[1]
# Apply corrections along x
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
# Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
return x[halo_x:-halo_x, halo_y:-halo_y]
@jax.jit
def modify_array(array):
return 2 * array + 1
# Forward FFT
karray = jaxdecomp.fft.pfft3d(global_array)
# Do some operation on your array
karray = modify_array(karray)
kvec = jaxdecomp.fft.fftfreq3d(karray)
# Do a gradient in the X axis
karray_gradient = 1j * kvec[0] * karray
# Reverse FFT
recarray = jaxdecomp.fft.pifft3d(karray_gradient).real
# Add halo regions to our array
padded_array = pad(recarray, padding_width)
# Perform a halo exchange
exchanged_array = jaxdecomp.halo_exchange(
padded_array, halo_extents=(16, 16), halo_periods=(True, True))
# Reduce the halo regions and remove the padding
reduced_array = reduce_halo(exchanged_array, padding_width)
# Gather the results (only if it fits on CPU memory)
gathered_array = multihost_utils.process_allgather(recarray, tiled=True)
# Alternatively, create a random 3D array and enforce the sharding directly
a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024, 1024))
a = jax.lax.with_sharding_constraint(a, sharding)

# Parallel FFTs
k_array = jaxdecomp.fft.pfft3d(a)
rec_array = jaxdecomp.fft.pifft3d(k_array)

# Parallel halo exchange
exchanged = jaxdecomp.halo_exchange(a, halo_extents=(16, 16), halo_periods=(True, True))

# Finalize the distributed JAX
jax.distributed.shutdown()
```
**Note**: All these functions are **JIT**-compatible and have well-defined derivatives, i.e. they support **automatic differentiation** (with [some caveats](docs/02_caveats.md)).

To run this example on 8 GPUs instead of 8 CPU devices, drop the `XLA_FLAGS`/`JAX_PLATFORM_NAME` environment variables and launch the script as described in the next section.
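For instance, here is a minimal sketch of differentiating through the distributed FFT (it reuses `sharding` and `global_array` from the example above; the explicit sharding constraint is the annotation recommended in the caveats):

```python
@jax.jit
def total_power(x):
    # Keep the input sharding visible to the AD machinery (see docs/02_caveats.md)
    x = jax.lax.with_sharding_constraint(x, sharding)
    k = jaxdecomp.fft.pfft3d(x)
    return (k * jnp.conj(k)).real.sum()

# Gradient of a scalar loss with respect to the distributed input array
grad_array = jax.grad(total_power)(global_array)
```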
See also:
- [Basic Usage](docs/01-basic_usage.md)
- [Distributed LPT Example](examples/lpt_nbody_demo.py)
---
## Running on an HPC Cluster
On HPC clusters (e.g., Jean Zay, Perlmutter), you typically launch your script with:

```bash
mpirun -n 8 python demo.py
```

or, on a Slurm-based cluster such as Jean Zay:

```bash
srun -n 8 python demo.py
```
See the Slurm [README](slurms/README.md) and [template script](slurms/template.slurm) for more details.
---
## Using cuDecomp (MPI and NCCL)
For **multi-node** runs or other advanced features, you can use the cuDecomp backend: compile and install the library with cuDecomp enabled (see the [installation instructions](#install)), then select the communication backend and pass `backend='cudecomp'` to the calls:
```python
import jaxdecomp
# Optionally select a communication backend (defaults to NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)

# Then specify backend='cudecomp' in your FFT or halo calls
karray = jaxdecomp.fft.pfft3d(global_array, backend='cudecomp')
recarray = jaxdecomp.fft.pifft3d(karray, backend='cudecomp')

exchanged_array = jaxdecomp.halo_exchange(
    padded_array, halo_extents=(16, 16), halo_periods=(True, True), backend='cudecomp')
```
Please check the [tests](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/tree/main/tests) folder for more examples.
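As a quick sanity check, a forward/inverse round trip with the cuDecomp backend should reproduce the input up to floating-point error. A minimal sketch, assuming a cuDecomp-enabled build and the `global_array` from the usage example above:

```python
import jax
import jax.numpy as jnp
import jaxdecomp

karray = jaxdecomp.fft.pfft3d(global_array, backend='cudecomp')
recarray = jaxdecomp.fft.pifft3d(karray, backend='cudecomp').real

# Maximum pointwise reconstruction error, computed on the distributed arrays
diff = jax.jit(lambda x, y: jnp.abs(x - y).max())(recarray, global_array)
if jax.process_index() == 0:
    print('max reconstruction error:', diff)
```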
## Install

### 1. Pure JAX Version (Easy / Recommended)

`jaxDecomp` is on PyPI:
1. **Install the appropriate JAX wheel**:
- **GPU**:
```bash
pip install --upgrade "jax[cuda]"
```
- **CPU**:
```bash
pip install --upgrade "jax[cpu]"
```
2. **Install `jaxdecomp`**:
```bash
pip install jaxdecomp
```
This setup uses the pure-JAX backend: **no** MPI required.
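After installation, a quick check that the package is importable and which version you got:

```python
import jaxdecomp

print(jaxdecomp.__version__)
```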
### 2. JAX + cuDecomp Backend (Advanced)

If you need **multi-node** support, you can build from GitHub with cuDecomp enabled. This requires the [NVIDIA HPC SDK](https://developer.nvidia.com/hpc-sdk) or a similar environment providing a CUDA-aware MPI toolchain.

```bash
pip install -U pip
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON
```
- If CMake cannot find NVHPC, set:
```bash
export CMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH:$NVCOMPILERS/$NVARCH/22.9/cmake
```
and then install again.
---

## Machine-Specific Notes

### IDRIS Jean Zay (HPE SGI 8600)

As of October 2024, loading modules **in this exact order** works:

```bash
# NVHPC 23.9 ships CUDA 12.2
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake

# Install JAX
pip install --upgrade "jax[cuda]"

# Install jaxDecomp with cuDecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake  # sometimes needed
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON
```

**Note**: The NVHPC SDK is needed **only** for the cuDecomp backend. If you are using the pure JAX backend, you can skip it and just `pip install jaxdecomp` after installing the correct JAX version for your hardware.
### NERSC Perlmutter (HPE Cray EX)

As of November 2022, the following works:

```bash
module load PrgEnv-nvhpc python
export CRAY_ACCEL_TARGET=nvidia80

# Install JAX
pip install --upgrade "jax[cuda]"

# Install jaxDecomp with cuDecomp
export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -Ccmake.define.JD_CUDECOMP_BACKEND=ON
```
---

## Backend Configuration (cuDecomp Only)

**Note**: The pure JAX backend always uses NCCL; the settings below apply only to cuDecomp.

By default, cuDecomp also uses NCCL for inter-device communication. You can switch to CUDA-aware MPI or NVSHMEM at runtime through the `config` module:
```python
import jaxdecomp

# Choose MPI or NVSHMEM for halo and transpose operations
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
```
The backends can also be selected through environment variables, as described in the [docs](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/tree/main/docs).

---

## Autotune Computational Mesh (cuDecomp Only)

cuDecomp can **autotune** the processor grid, telling you the best way to partition the data across the available GPUs for a given tensor shape:
```python
automesh = jaxdecomp.autotune(shape=[512,512,512])
# 'automesh' is a partition layout optimized for the given GPUs and tensor shape.
# You can then build a JAX sharding from it:
from jax.sharding import PositionalSharding
sharding = PositionalSharding(automesh)
```
---
**License**: This project is licensed under the [MIT License](https://opensource.org/licenses/MIT).
For more details, see the [examples](examples/) directory and the [documentation](docs/). Contributions and issues are welcome!
# Basic Usage
This example demonstrates how to run a JAX-based script in a **distributed** setup with `jaxDecomp`.
You can launch this script with **8 processes** (and thus 8 GPUs) via:
```bash
mpirun -n 8 python demo.py
```
or on a Slurm-based cluster (e.g., Jean Zay) using:
```bash
srun -n 8 python demo.py
```
Below is a full example script illustrating:
1. **Initializing JAX distributed** across multiple GPUs
2. **Creating a globally sharded 3D array**
3. **Performing a parallel FFT**
4. **Applying a halo exchange**
5. **Gathering results** back to a single process
```python
import jax
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax import numpy as jnp
import jaxdecomp
from functools import partial
# -----------------------------
# 1. Initialize JAX distributed
# -----------------------------
# This instructs JAX which GPU to use per process.
jax.distributed.initialize()
rank = jax.process_index()
# -----------------------------
# 2. Create a globally sharded array
# -----------------------------
# Suppose we have 8 total processes. We'll create a processor mesh
# of shape (2,4). Adjust these as needed for your environment.
pdims = (2, 4)
global_shape = (1024, 1024, 1024)
# Compute local slice sizes
local_shape = (
global_shape[0] // pdims[1],
global_shape[1] // pdims[0],
global_shape[2]
)
# Create a mesh of devices based on pdims
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
# Define the sharding spec
sharding = NamedSharding(mesh, P('x', 'y'))
# Create a distributed global array
global_array = jax.make_array_from_callback(
global_shape,
sharding,
data_callback=lambda _: jax.random.normal(
jax.random.PRNGKey(rank), local_shape)
)
# -----------------------------
# 3. Perform a parallel FFT
# -----------------------------
# We will also demonstrate applying a halo exchange afterwards.
padding_width = ((32, 32), (32, 32), (0, 0)) # must be a tuple of tuples
# Shard-map helper to pad an array
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P()), out_specs=P('x', 'y'))
def pad(arr, padding):
return jnp.pad(arr, padding)
# Shard-map helper to remove the padded halo
@partial(shard_map, mesh=mesh, in_specs=(P('x', 'y'), P()), out_specs=P('x', 'y'))
def reduce_halo(x, pad_width):
halo_x, _ = pad_width[0]
halo_y, _ = pad_width[1]
# Apply corrections along x
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
# Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
return x[halo_x:-halo_x, halo_y:-halo_y]
# A simple JITed function to modify an array
@jax.jit
def modify_array(array):
return 2 * array + 1
# Forward FFT
karray = jaxdecomp.fft.pfft3d(global_array)
# Apply some operation (e.g., scale + offset)
karray = modify_array(karray)
# Obtain frequency grid
kvec = jaxdecomp.fft.fftfreq3d(karray)
# Demonstration: compute a gradient in the x-axis in Fourier space
karray_gradient = 1j * kvec[0] * karray
# Inverse FFT
recarray = jaxdecomp.fft.pifft3d(karray_gradient).real
# -----------------------------
# 4. Perform a halo exchange
# -----------------------------
# Example: pad the array, exchange halos, then remove the padding
padded_array = pad(recarray, padding_width)
# Exchange halo across processes
exchanged_array = jaxdecomp.halo_exchange(
padded_array,
halo_extents=(16, 16),
halo_periods=(True, True)
)
# Remove the halo paddings after exchange
reduced_array = reduce_halo(exchanged_array, padding_width)
# -----------------------------
# 5. Gather results (optional)
# -----------------------------
# Only do this if the final array can fit in CPU memory.
gathered_array = multihost_utils.process_allgather(recarray, tiled=True)
# -----------------------------
# Finalize distributed JAX
# -----------------------------
jax.distributed.shutdown()
```
When you run this script, each MPI process (or Slurm task) will create its local slice of the global array. The FFT and halo operations are orchestrated in parallel using JAX and `jaxDecomp`.
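If you want to see what each process actually holds, a small sketch (a hypothetical addition just before the shutdown call in the script above) inspects the locally addressable shard and the sharding object:

```python
# Each process only stores its own slice of the global array
local_slice = global_array.addressable_data(0)
print(f"process {rank}: local slice shape = {local_slice.shape}")
print(f"process {rank}: sharding = {global_array.sharding}")
```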
# Caveats and Workarounds: Autodiff + SPMD Sharding with `jaxDecomp`
This page explains some **known caveats** when using JAX’s automatic differentiation (AD) with the distributed FFT routines in `jaxDecomp`. Specifically, you may encounter errors when combining **SPMD sharding** and **AD transforms** such as `jax.grad`, `jax.jacfwd`, or `jax.jacrev`. Below, we show how to annotate your code to avoid these issues.
---
## 1. Background
- **SPMD Sharding in JAX**: When you run JAX on multiple devices (e.g., multiple GPUs or CPU devices), you can specify how arrays should be partitioned across those devices using a mesh and a sharding specification (`NamedSharding`, `PartitionSpec`, etc.).
- **AD Transforms**: JAX’s `jax.grad`, `jax.jacfwd`, and `jax.jacrev` automatically compute derivatives of your functions. Under the hood, JAX sometimes rewrites your function into a new function that can cause changes to sharding or lead to “unsharded” arrays.
In certain scenarios, JAX’s AD transformations might **lose** the sharding specification if the function’s first operation is a parallel operation (like `pfft3d`). This can trigger errors like:
```
Input sharding was found to be None while lowering the SPMD rule.
You are likely calling jacfwd with pfft as the first function.
```
---
## 2. `jacfwd` with Parallel FFT
### Problem
Consider the following function, which calls `pfft3d` immediately:
```python
def forward(a):
return jaxdecomp.fft.pfft3d(a).real
```
If we attempt:
```python
jax.jacfwd(forward)(a)
```
we will encounter this error:
```
Input sharding was found to be None while lowering the SPMD rule.
You are likely calling jacfwd with pfft as the first function.
due to a bug in JAX, the sharding is not correctly passed to the SPMD rule.
```
### Workaround
By **annotating** the input array’s sharding *explicitly* within the function we differentiate, we ensure JAX does not lose the sharding information. For instance:
```python
import jax
import jax.numpy as jnp
from jax import lax
import jaxdecomp
# Suppose we have a sharding object named `sharding`.
# In your real code, you might do something like:
# mesh = jax.make_mesh((1, 8), axis_names=('x','y'))
# sharding = NamedSharding(mesh, P('x', 'y'))
def annotated_forward(a):
# explicitly ensure 'a' is recognized as sharded
a = lax.with_sharding_constraint(a, sharding)
return jaxdecomp.fft.pfft3d(a).real
# Now jacfwd works without losing the sharding:
jax.jacfwd(annotated_forward)(a)
```
---
## 3. `jacrev` with Parallel FFT
### Problem
When computing reverse-mode Jacobians (`jax.jacrev`), a similar issue can arise. If our function is:
```python
def forward(a):
return jaxdecomp.fft.pfft3d(a).real
```
Then:
```python
jax.jacrev(forward)(a)
```
can cause JAX to replicate the array or fail the sharding constraint. We might see an unexpected result like a fully replicated array (`SingleDeviceSharding`), or an error about “Input sharding was found to be None ...”.
### Workaround
Again, we can **annotate** the function:
```python
def annotated_forward(a):
a = lax.with_sharding_constraint(a, sharding)
return jaxdecomp.fft.pfft3d(a).real
# Now jacrev retains correct sharding
rev_jac = jax.jacrev(annotated_forward)(a)
```
You can verify the resulting array’s sharding with:
```python
print(rev_jac.sharding)
```
---
## 4. `grad` of a Scalar-Reduced FFT
### Problem
When your function returns a scalar (e.g., via `jnp.sum` of the FFT output), the gradient pipeline might fail with the same “Input sharding was found to be None” error. For example:
```python
def fft_reduce(a):
return jaxdecomp.fft.pfft3d(a).real.sum()
jax.grad(fft_reduce)(a)
```
can fail for the same reason: the initial pfft step is ambiguous to JAX’s SPMD rule.
### Workaround
1. **Perform `pfft3d`**,
2. **Annotate** the output array’s new sharding,
3. Then reduce.
Example:
```python
def fft_reduce_with_annotation(a):
# Perform FFT
res = jaxdecomp.fft.pfft3d(a).real
# Annotate the resulting array with the sharding that pfft3d produces:
out_sharding = jaxdecomp.get_fft_output_sharding(sharding)
res = lax.with_sharding_constraint(res, out_sharding)
# Now reduce to scalar
return res.sum()
# This will now run successfully
grad_val = jax.grad(fft_reduce_with_annotation)(a)
```
---
## 5. Summary of Best Practices
1. **Annotate Inputs**
If your function starts with `pfft3d(...)`, insert a `lax.with_sharding_constraint(input_array, sharding)` to ensure JAX retains the correct distribution info during AD transforms.
2. **Annotate Outputs**
For scalar-reduction patterns (`.sum()`, `.mean()`, etc.), or any time the output shape differs significantly from the input, use `lax.with_sharding_constraint(output_array, new_sharding)` to ensure the partial derivatives keep correct partitioning.
3. **Check Sharding**
Inspect the `.sharding` attribute of returned arrays after `jax.jacrev`, `jax.jacfwd`, or `jax.grad` to confirm that the output is still sharded the way you intend.
---
## 6. Conclusion
Due to a **bug** in how JAX’s AD transforms currently interact with SPMD partitioning, you may need to explicitly annotate sharding constraints around FFT calls. By applying `lax.with_sharding_constraint` or by retrieving the FFT’s “expected” output sharding (via `jaxdecomp.get_fft_output_sharding`), you can ensure that your distributed computations remain partitioned as expected.
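Putting the three practices together, a minimal sketch (it assumes the same `sharding` object and distributed array `a` used in the examples above):

```python
import jax
import jaxdecomp
from jax import lax


def loss(x):
    # 1. Annotate the input so AD transforms keep the sharding
    x = lax.with_sharding_constraint(x, sharding)
    res = jaxdecomp.fft.pfft3d(x).real

    # 2. Annotate the output with the sharding pfft3d produces
    out_sharding = jaxdecomp.get_fft_output_sharding(sharding)
    res = lax.with_sharding_constraint(res, out_sharding)
    return res.sum()


grad_val = jax.grad(loss)(a)

# 3. Check that the gradient is still distributed as intended
print(grad_val.sharding)
```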
Feel free to open an issue on GitHub if you encounter other scenarios where sharding + AD transforms produce unexpected results!
import argparse
import os
from functools import partial
from typing import Any, Callable, Hashable
from typing import Any, Callable
from collections.abc import Hashable
Specs = Any
AxisName = Hashable
import jax
jax.config.update('jax_enable_x64', False)
jax.config.update("jax_enable_x64", False)
import jax.numpy as jnp
import jax_cosmo as jc
@@ -23,26 +24,27 @@ from scatter import scatter
import jaxdecomp
def shmap(f: Callable,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = True,
auto: frozenset[AxisName] = frozenset()):
"""Helper function to create a shard_map function that extracts the mesh from the
def shmap(
f: Callable,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = True,
auto: frozenset[AxisName] = frozenset(),
):
"""Helper function to create a shard_map function that extracts the mesh from the
context."""
mesh = mesh_lib.thread_resources.env.physical_mesh
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
mesh = mesh_lib.thread_resources.env.physical_mesh
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def _global_to_local_size(nc: int):
""" Helper function to get the local size of a mesh given the global size.
"""
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
return [nc // pdims[0], nc // pdims[1], nc]
"""Helper function to get the local size of a mesh given the global size."""
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
return [nc // pdims[0], nc // pdims[1], nc]
def fttk(nc: int):
"""
"""
Generate Fourier transform wave numbers for a given mesh.
Args:
@@ -51,25 +53,27 @@ def fttk(nc: int):
Returns:
list: List of wave number arrays for each dimension in
the order [kx, ky, kz].
"""
kd = np.fft.fftfreq(nc) * 2 * np.pi
@partial(
shmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)))
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
ky, kz, kx = get_kvec(kd, kd, kd) # The order of the output
# corresponds to the order of dimensions in the transposed FFT
# output
return kx, ky, kz
"""
kd = np.fft.fftfreq(nc) * 2 * np.pi
@partial(
shmap,
in_specs=(P("x"), P("y"), P(None)),
out_specs=(P("x"), P(None, "y"), P(None)),
)
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
ky, kz, kx = get_kvec(kd, kd, kd) # The order of the output
# corresponds to the order of dimensions in the transposed FFT
# output
return kx, ky, kz
def gravity_kernel(kx, ky, kz):
""" Computes a Fourier kernel combining laplace and derivative
"""Computes a Fourier kernel combining laplace and derivative
operators to compute gravitational forces.
Args:
@@ -77,18 +81,18 @@ def gravity_kernel(kx, ky, kz):
Returns:
tuple of jnp.ndarray: kernels for each dimension.
"""
kk = kx**2 + ky**2 + kz**2
laplace_kernel = jnp.where(kk == 0, 1., 1. / kk)
"""
kk = kx**2 + ky**2 + kz**2
laplace_kernel = jnp.where(kk == 0, 1.0, 1.0 / kk)
grav_kernel = (laplace_kernel * 1j * kx,
laplace_kernel * 1j * ky,
laplace_kernel * 1j * kz) # yapf: disable
return grav_kernel
grav_kernel = (laplace_kernel * 1j * kx,
laplace_kernel * 1j * ky,
laplace_kernel * 1j * kz) # yapf: disable
return grav_kernel
def gaussian_field_and_forces(key, nc, box_size, power_spectrum):
"""
"""
Generate a Gaussian field with a given power spectrum, along with gravitational forces.
Args:
@@ -99,40 +103,40 @@ def gaussian_field_and_forces(key, nc, box_size, power_spectrum):
Returns:
tuple of jnp.ndarray: The generated Gaussian field and the gravitational forces.
"""
local_mesh_shape = _global_to_local_size(nc)
"""
local_mesh_shape = _global_to_local_size(nc)
# Create a distributed field drawn from a Gaussian distribution in real space
delta = shmap(
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
in_specs=P(None),
out_specs=P('x', 'y'))(key) # yapf: disable
# Create a distributed field drawn from a Gaussian distribution in real space
delta = shmap(
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
in_specs=P(None),
out_specs=P('x', 'y'))(key) # yapf: disable
# Compute the Fourier transform of the field
delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64))
# Compute the Fourier transform of the field
delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64))
# Compute the Fourier wavenumbers of the field
kx, ky, kz = fttk(nc)
kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size)
# Compute the Fourier wavenumbers of the field
kx, ky, kz = fttk(nc)
kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size)
# Apply power spectrum to Fourier modes
delta_k *= (power_spectrum(kk) * (nc / box_size)**3)**0.5
# Apply power spectrum to Fourier modes
delta_k *= (power_spectrum(kk) * (nc / box_size) ** 3) ** 0.5
# Compute inverse Fourier transform to recover the initial conditions in real space
delta = jaxdecomp.fft.pifft3d(delta_k).real
# Compute inverse Fourier transform to recover the initial conditions in real space
delta = jaxdecomp.fft.pifft3d(delta_k).real
# Compute gravitational forces associated with this field
grav_kernel = gravity_kernel(kx, ky, kz)
forces_k = [g * delta_k for g in grav_kernel]
# Compute gravitational forces associated with this field
grav_kernel = gravity_kernel(kx, ky, kz)
forces_k = [g * delta_k for g in grav_kernel]
# Retrieve the forces in real space by inverse Fourier transforming
forces = jnp.stack([jaxdecomp.fft.pifft3d(f).real for f in forces_k], axis=-1)
# Retrieve the forces in real space by inverse Fourier transforming
forces = jnp.stack([jaxdecomp.fft.pifft3d(f).real for f in forces_k], axis=-1)
return delta, forces
return delta, forces
def cic_paint(displacement, halo_size):
""" Paints particles on a mesh using Cloud-In-Cell interpolation.
"""Paints particles on a mesh using Cloud-In-Cell interpolation.
Args:
displacement (jnp.ndarray): Displacement of each particle.
@@ -140,60 +144,58 @@ def cic_paint(displacement, halo_size):
Returns:
jnp.ndarray: Density field.
"""
local_mesh_shape = _global_to_local_size(displacement.shape[0])
hs = halo_size
@partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
def cic_op(disp):
""" CiC operation on each local slice of the mesh."""
# Create a mesh to paint the particles on for the local slice
mesh = jnp.zeros(disp.shape[:-1], dtype='float32')
# Padding the mesh along the two first dimensions
mesh = jnp.pad(mesh, [[hs, hs], [hs, hs], [0, 0]])
# Compute the position of the particles on a regular grid
pos_x, pos_y, pos_z = jnp.meshgrid(
jnp.arange(local_mesh_shape[0]),
jnp.arange(local_mesh_shape[1]),
jnp.arange(local_mesh_shape[2]),
indexing='ij')
# adding an offset of size halo size
pos = jnp.stack([pos_x + hs, pos_y + hs, pos_z], axis=-1)
# Apply scatter operation to paint the particles on the local mesh
field = scatter(pos.reshape([-1, 3]), disp.reshape([-1, 3]), mesh)
"""
local_mesh_shape = _global_to_local_size(displacement.shape[0])
hs = halo_size
@partial(shmap, in_specs=(P("x", "y"),), out_specs=P("x", "y"))
def cic_op(disp):
"""CiC operation on each local slice of the mesh."""
# Create a mesh to paint the particles on for the local slice
mesh = jnp.zeros(disp.shape[:-1], dtype="float32")
# Padding the mesh along the two first dimensions
mesh = jnp.pad(mesh, [[hs, hs], [hs, hs], [0, 0]])
# Compute the position of the particles on a regular grid
pos_x, pos_y, pos_z = jnp.meshgrid(
jnp.arange(local_mesh_shape[0]),
jnp.arange(local_mesh_shape[1]),
jnp.arange(local_mesh_shape[2]),
indexing="ij",
)
# adding an offset of size halo size
pos = jnp.stack([pos_x + hs, pos_y + hs, pos_z], axis=-1)
# Apply scatter operation to paint the particles on the local mesh
field = scatter(pos.reshape([-1, 3]), disp.reshape([-1, 3]), mesh)
return field
# Performs painting on a padded mesh, with halos on the two first dimensions
field = cic_op(displacement)
# Run halo exchange to get the correct values at the boundaries
field = jaxdecomp.halo_exchange(field, halo_extents=(hs // 2, hs // 2, 0), halo_periods=(True, True, True))
@partial(shmap, in_specs=(P("x", "y"),), out_specs=P("x", "y"))
def unpad(x):
"""Removes the padding and reduce the halo regions"""
x = x.at[hs : hs + hs // 2].add(x[: hs // 2])
x = x.at[-(hs + hs // 2) : -hs].add(x[-hs // 2 :])
x = x.at[:, hs : hs + hs // 2].add(x[:, : hs // 2])
x = x.at[:, -(hs + hs // 2) : -hs].add(x[:, -hs // 2 :])
return x[hs:-hs, hs:-hs, :]
# Unpad the output array
field = unpad(field)
return field
# Performs painting on a padded mesh, with halos on the two first dimensions
field = cic_op(displacement)
# Run halo exchange to get the correct values at the boundaries
field = jaxdecomp.halo_exchange(
field,
halo_extents=(hs // 2, hs // 2, 0),
halo_periods=(True, True, True))
@partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
def unpad(x):
""" Removes the padding and reduce the halo regions"""
x = x.at[hs:hs + hs // 2].add(x[:hs // 2])
x = x.at[-(hs + hs // 2):-hs].add(x[-hs // 2:])
x = x.at[:, hs:hs + hs // 2].add(x[:, :hs // 2])
x = x.at[:, -(hs + hs // 2):-hs].add(x[:, -hs // 2:])
return x[hs:-hs, hs:-hs, :]
# Unpad the output array
field = unpad(field)
return field
@partial(jax.jit, static_argnames=('nc', 'box_size', 'halo_size'))
@partial(jax.jit, static_argnames=("nc", "box_size", "halo_size"))
def simulation_fn(key, nc, box_size, halo_size, a=1.0):
"""
"""
Run a simulation to generate initial conditions and density field using LPT.
Args:
@@ -205,76 +207,67 @@ def simulation_fn(key, nc, box_size, halo_size, a=1.0):
Returns:
tuple of jnp.ndarray: Initial conditions and final density field.
"""
# Build a default cosmology
cosmo = jc.Planck15()
"""
# Build a default cosmology
cosmo = jc.Planck15()
# Create a small function to generate the linear matter power spectrum at arbitrary k
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
pk_fn = jax.jit(lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).
reshape(x.shape))
# Create a small function to generate the linear matter power spectrum at arbitrary k
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
pk_fn = jax.jit(lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape))
# Generate a Gaussian field and gravitational forces from a power spectrum
intial_conditions, initial_forces = gaussian_field_and_forces(
key=key, nc=nc, box_size=box_size, power_spectrum=pk_fn)
# Generate a Gaussian field and gravitational forces from a power spectrum
intial_conditions, initial_forces = gaussian_field_and_forces(key=key, nc=nc, box_size=box_size, power_spectrum=pk_fn)
# Compute the LPT displacement that particles initially placed on a regular grid
# would experience at scale factor a, by simple Zeldovich approximation
initial_displacement = jc.background.growth_factor(
cosmo, jnp.atleast_1d(a)) * initial_forces
# Compute the LPT displacement that particles initially placed on a regular grid
# would experience at scale factor a, by simple Zeldovich approximation
initial_displacement = jc.background.growth_factor(cosmo, jnp.atleast_1d(a)) * initial_forces
# Paints the displaced particles on a mesh to obtain the density field
final_field = cic_paint(initial_displacement, halo_size)
# Paints the displaced particles on a mesh to obtain the density field
final_field = cic_paint(initial_displacement, halo_size)
return intial_conditions, final_field
return intial_conditions, final_field
def main(args):
print(f"Running with arguments {args}")
# Setting up distributed jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
# Setting up distributed random numbers
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]
# Create computing mesh and sharding information
pdims = tuple(map(int, args.pdims.split('x')))
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
# Run the simulation on the compute mesh
with mesh:
initial_conds, final_field = simulation_fn(
key=key, nc=args.nc, box_size=args.box_size, halo_size=args.halo_size)
# Create output directory to save the results
output_dir = args.output
os.makedirs(output_dir, exist_ok=True)
np.save(f'{output_dir}/initial_conditions_{rank}.npy',
initial_conds.addressable_data(0))
np.save(f'{output_dir}/field_{rank}.npy', final_field.addressable_data(0))
print(f"Finished saved to {output_dir}")
# Closing distributed jax
jax.distributed.shutdown()
if __name__ == '__main__':
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
parser.add_argument(
'--pdims', type=str, default='1x1', help="Processor grid dimensions")
parser.add_argument(
'--nc', type=int, default=256, help="Number of cells in the mesh")
parser.add_argument(
'--box_size', type=float, default=512., help="Box size in Mpc/h")
parser.add_argument(
'--halo_size', type=int, default=32, help="Halo size for painting")
parser.add_argument('--output', type=str, default='out')
args = parser.parse_args()
main(args)
print(f"Running with arguments {args}")
# Setting up distributed jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
# Setting up distributed random numbers
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]
# Create computing mesh and sharding information
pdims = tuple(map(int, args.pdims.split("x")))
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=("x", "y"))
# Run the simulation on the compute mesh
with mesh:
initial_conds, final_field = simulation_fn(key=key, nc=args.nc, box_size=args.box_size, halo_size=args.halo_size)
# Create output directory to save the results
output_dir = args.output
os.makedirs(output_dir, exist_ok=True)
np.save(f"{output_dir}/initial_conditions_{rank}.npy", initial_conds.addressable_data(0))
np.save(f"{output_dir}/field_{rank}.npy", final_field.addressable_data(0))
print(f"Finished saved to {output_dir}")
# Closing distributed jax
jax.distributed.shutdown()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
parser.add_argument("--pdims", type=str, default="1x1", help="Processor grid dimensions")
parser.add_argument("--nc", type=int, default=256, help="Number of cells in the mesh")
parser.add_argument("--box_size", type=float, default=512.0, help="Box size in Mpc/h")
parser.add_argument("--halo_size", type=int, default=32, help="Halo size for painting")
parser.add_argument("--output", type=str, default="out")
args = parser.parse_args()
main(args)
@@ -38,121 +38,116 @@ from jax.lax import scan
def _chunk_split(ptcl_num, chunk_size, *arrays):
"""Split and reshape particle arrays into chunks and remainders, with the remainders
"""Split and reshape particle arrays into chunks and remainders, with the remainders
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
remainder_size = ptcl_num % chunk_size
chunk_num = ptcl_num // chunk_size
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
remainder_size = ptcl_num % chunk_size
chunk_num = ptcl_num // chunk_size
remainder = None
chunks = arrays
if remainder_size:
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
remainder = None
chunks = arrays
if remainder_size:
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
# `scan` triggers errors in scatter and gather without the `full`
chunks = [
x.reshape(chunk_num, chunk_size, *x.shape[1:])
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
]
# `scan` triggers errors in scatter and gather without the `full`
chunks = [x.reshape(chunk_num, chunk_size, *x.shape[1:]) if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks]
return remainder, chunks
return remainder, chunks
def enmesh(i1, d1, a1, s1, b12, a2, s2):
"""Multilinear enmeshing."""
i1 = jnp.asarray(i1)
d1 = jnp.asarray(d1)
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
if s1 is not None:
s1 = jnp.array(s1, dtype=i1.dtype)
b12 = jnp.float64(b12)
if a2 is not None:
a2 = jnp.float64(a2)
if s2 is not None:
s2 = jnp.array(s2, dtype=i1.dtype)
dim = i1.shape[1]
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >> jnp.arange(
dim, dtype=i1.dtype)) & 1
if a2 is not None:
P = i1 * a1 + d1 - b12
P = P[:, jnp.newaxis] # insert neighbor axis
i2 = P + neighbors * a2 # multilinear
"""Multilinear enmeshing."""
i1 = jnp.asarray(i1)
d1 = jnp.asarray(d1)
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
if s1 is not None:
L = s1 * a1
i2 %= L
s1 = jnp.array(s1, dtype=i1.dtype)
b12 = jnp.float64(b12)
if a2 is not None:
a2 = jnp.float64(a2)
if s2 is not None:
s2 = jnp.array(s2, dtype=i1.dtype)
i2 //= a2
d2 = P - i2 * a2
dim = i1.shape[1]
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >> jnp.arange(dim, dtype=i1.dtype)) & 1
if s1 is not None:
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
if a2 is not None:
P = i1 * a1 + d1 - b12
P = P[:, jnp.newaxis] # insert neighbor axis
i2 = P + neighbors * a2 # multilinear
i2 = i2.astype(i1.dtype)
d2 = d2.astype(d1.dtype)
a2 = a2.astype(d1.dtype)
if s1 is not None:
L = s1 * a1
i2 %= L
d2 /= a2
else:
i12, d12 = jnp.divmod(b12, a1)
i1 -= i12.astype(i1.dtype)
d1 -= d12.astype(d1.dtype)
i2 //= a2
d2 = P - i2 * a2
# insert neighbor axis
i1 = i1[:, jnp.newaxis]
d1 = d1[:, jnp.newaxis]
if s1 is not None:
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
# multilinear
d1 /= a1
i2 = jnp.floor(d1).astype(i1.dtype)
i2 += neighbors
d2 = d1 - i2
i2 += i1
i2 = i2.astype(i1.dtype)
d2 = d2.astype(d1.dtype)
a2 = a2.astype(d1.dtype)
if s1 is not None:
i2 %= s1
d2 /= a2
else:
i12, d12 = jnp.divmod(b12, a1)
i1 -= i12.astype(i1.dtype)
d1 -= d12.astype(d1.dtype)
# insert neighbor axis
i1 = i1[:, jnp.newaxis]
d1 = d1[:, jnp.newaxis]
# multilinear
d1 /= a1
i2 = jnp.floor(d1).astype(i1.dtype)
i2 += neighbors
d2 = d1 - i2
i2 += i1
if s1 is not None:
i2 %= s1
f2 = 1 - jnp.abs(d2)
f2 = 1 - jnp.abs(d2)
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
i2 = jnp.where(i2 < 0, s2, i2)
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
i2 = jnp.where(i2 < 0, s2, i2)
f2 = f2.prod(axis=-1)
f2 = f2.prod(axis=-1)
return i2, f2
return i2, f2
def scatter(pmid, disp, mesh, chunk_size=2**24, val=1., offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
val = jnp.asarray(val)
mesh = jnp.asarray(mesh)
def scatter(pmid, disp, mesh, chunk_size=2**24, val=1.0, offset=0, cell_size=1.0):
ptcl_num, spatial_ndim = pmid.shape
val = jnp.asarray(val)
mesh = jnp.asarray(mesh)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh, offset, cell_size
if remainder is not None:
carry = _scatter_chunk(carry, remainder)[0]
carry = scan(_scatter_chunk, carry, chunks)[0]
mesh = carry[0]
carry = mesh, offset, cell_size
if remainder is not None:
carry = _scatter_chunk(carry, remainder)[0]
carry = scan(_scatter_chunk, carry, chunks)[0]
mesh = carry[0]
return mesh
return mesh
def _scatter_chunk(carry, chunk):
mesh, offset, cell_size = carry
pmid, disp, val = chunk
spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape
# multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, spatial_shape, offset, cell_size,
spatial_shape)
# scatter
ind = tuple(ind[..., i] for i in range(spatial_ndim))
mesh = mesh.at[ind].add(val * frac)
carry = mesh, offset, cell_size
return carry, None
mesh, offset, cell_size = carry
pmid, disp, val = chunk
spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape
# multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, spatial_shape, offset, cell_size, spatial_shape)
# scatter
ind = tuple(ind[..., i] for i in range(spatial_ndim))
mesh = mesh.at[ind].add(val * frac)
carry = mesh, offset, cell_size
return carry, None
@@ -70,8 +70,8 @@
}
],
"source": [
"folder = '../out'\n",
"pdims=(4,4)\n",
"folder = \"../out\"\n",
"pdims = (4, 4)\n",
"\n",
"init_field_slices = []\n",
"field_slices = []\n",
@@ -79,11 +79,11 @@
"for i in range(pdims[0]):\n",
" row_init_field = []\n",
" row_field = []\n",
" \n",
"\n",
" for j in range(pdims[1]):\n",
" slice_index = i * pdims[1] + j \n",
" row_field.append(np.load(f'{folder}/field_{slice_index}.npy'))\n",
" row_init_field.append(np.load(f'{folder}/initial_conditions_{slice_index}.npy'))\n",
" slice_index = i * pdims[1] + j\n",
" row_field.append(np.load(f\"{folder}/field_{slice_index}.npy\"))\n",
" row_init_field.append(np.load(f\"{folder}/initial_conditions_{slice_index}.npy\"))\n",
"\n",
" field_slices.append(np.vstack(row_field))\n",
" init_field_slices.append(np.vstack(row_init_field))\n",
@@ -148,23 +148,31 @@
" slicing[proj_axis] = slice(None, sum_over)\n",
" slicing = tuple(slicing)\n",
"\n",
"\n",
" fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
"\n",
" # Flatten axes for easy indexing\n",
" axes = axes.flatten()\n",
"\n",
" # Plot initial conditions\n",
" axes[0].imshow(initial_conditions[slicing].sum(axis=proj_axis), cmap='magma', extent=[0, box_size + 5, 0, box_size + 5])\n",
" axes[0].set_xlabel('Mpc/h')\n",
" axes[0].set_ylabel('Mpc/h')\n",
" axes[0].set_title('Initial conditions')\n",
" axes[0].imshow(\n",
" initial_conditions[slicing].sum(axis=proj_axis),\n",
" cmap=\"magma\",\n",
" extent=[0, box_size + 5, 0, box_size + 5],\n",
" )\n",
" axes[0].set_xlabel(\"Mpc/h\")\n",
" axes[0].set_ylabel(\"Mpc/h\")\n",
" axes[0].set_title(\"Initial conditions\")\n",
"\n",
" # Plot LPT density field at z=0\n",
" axes[1].imshow(field[slicing].sum(axis=proj_axis), cmap='magma', extent=[0, box_size + 5, 0, box_size + 5])\n",
" axes[1].set_xlabel('Mpc/h')\n",
" axes[1].set_ylabel('Mpc/h')\n",
" axes[1].set_title('LPT density field at z=0')\n",
" axes[1].imshow(\n",
" field[slicing].sum(axis=proj_axis),\n",
" cmap=\"magma\",\n",
" extent=[0, box_size + 5, 0, box_size + 5],\n",
" )\n",
" axes[1].set_xlabel(\"Mpc/h\")\n",
" axes[1].set_ylabel(\"Mpc/h\")\n",
" axes[1].set_title(\"LPT density field at z=0\")\n",
"\n",
"\n",
"for i in range(3):\n",
" plot(i)"
@@ -201,10 +209,14 @@
"box_size = initial_conditions.shape[0] + 5\n",
"plt.figure(figsize=(20, 20))\n",
"# Generate the plot\n",
"plt.imshow(np.log10(field[:16].sum(axis=proj_axis) + 1), cmap='magma', extent=[0, box_size, 0, box_size])\n",
"plt.xlabel('Mpc/h')\n",
"plt.ylabel('Mpc/h')\n",
"plt.title('LPT density field at z=0')\n",
"plt.imshow(\n",
" np.log10(field[:16].sum(axis=proj_axis) + 1),\n",
" cmap=\"magma\",\n",
" extent=[0, box_size, 0, box_size],\n",
")\n",
"plt.xlabel(\"Mpc/h\")\n",
"plt.ylabel(\"Mpc/h\")\n",
"plt.title(\"LPT density field at z=0\")\n",
"\n",
"# Display the plot\n",
"plt.show()"
@@ -3,7 +3,7 @@ requires = ["scikit-build-core>=0.4.0", "pybind11>=2.9.0"]
build-backend = "scikit_build_core.build"
[project]
name = "jaxdecomp"
version = "0.2.0"
version = "0.2.4"
description = "JAX bindings for the cuDecomp library"
authors = [
{ name = "Wassim Kabalan" },
@@ -23,24 +23,49 @@ classifiers = [
"Intended Audience :: Science/Research",
]
dependencies = [
"jaxtyping>=0.2.33",
"jax>=0.4.30",
"jaxtyping>=0.2.0",
"jax>=0.4.35",
"typing-extensions; python_version < '3.11'",
]
[project.optional-dependencies]
test = ["pytest>=8.0.0" , "jax[cpu]>=0.4.30"]
test = ["pytest>=8.0.0"]
[tool.scikit-build]
minimum-version = "0.8"
cmake.version = ">=3.25"
build-dir = "build/{wheel_tag}"
wheel.py-api = "py3"
cmake.build-type = "Release"
wheel.install-dir = "jaxdecomp/_src"
[tool.scikit-build.cmake.define]
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""
#[tool.cibuildwheel]
#test-extras = "test"
#test-command = "pytest {project}/tests"
[tool.cibuildwheel]
test-extras = "test"
test-command = "pytest {project}/tests"
[tool.ruff]
line-length = 150
src = ["src"]
exclude = ["third_party"]
[tool.ruff.lint]
select = [
# pycodestyle
'E',
# pyflakes
'F',
# pyupgrade
'UP',
# flake8-debugger
'T10',
]
ignore = [
'E402', # module level import not at top of file
'E203',
'E731',
'E701',
'E741',
'E722',
'UP037', # conflicts with jaxtyping Array annotations
]
@@ -17,11 +17,10 @@ config.halo_comm_backend = jaxdecomp.config.halo_comm_backend
config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend
# Run autotune
tuned_config = jaxdecomp.get_autotuned_config(config, False, False, True, True,
(32, 32, 32), (True, True, True))
tuned_config = jaxdecomp.get_autotuned_config(config, False, False, True, True, (32, 32, 32), (True, True, True))
if rank == 0:
print(rank, "*** Results of optimization ***")
print(rank, "pdims", tuned_config.pdims)
print(rank, "halo_comm_backend", tuned_config.halo_comm_backend)
print(rank, "transpose_comm_backend", tuned_config.transpose_comm_backend)
print(rank, "*** Results of optimization ***")
print(rank, "pdims", tuned_config.pdims)
print(rank, "halo_comm_backend", tuned_config.halo_comm_backend)
print(rank, "transpose_comm_backend", tuned_config.transpose_comm_backend)
import time
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
@@ -13,45 +12,42 @@ jaxdecomp.init()
jax.distributed.initialize()
rank = jax.process_index()
#jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_P2P_PL)
# jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_P2P_PL)
pdims = (2, 2)
global_shape = (1024, 1024, 1024)
# Initialize a local slice of the global array
array = jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(0))
shape=[global_shape[0] // pdims[1], global_shape[1] // pdims[0], global_shape[2]],
key=jax.random.PRNGKey(0),
)
# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
array, mesh, P('z', 'y'))
mesh = Mesh(devices, axis_names=("z", "y"))
global_array = multihost_utils.host_local_array_to_global_array(array, mesh, P("z", "y"))
@jax.jit
def do_fft(x):
return jaxdecomp.fft.pfft3d(x)
return jaxdecomp.fft.pfft3d(x)
with mesh:
do_fft(global_array)
before = time.time()
karray = do_fft(global_array).block_until_ready()
after = time.time()
print(rank, 'took', after - before, 's')
do_fft(global_array)
before = time.time()
karray = do_fft(global_array).block_until_ready()
after = time.time()
print(rank, "took", after - before, "s")
# And now, let's do the inverse FFT
rec_array = jaxdecomp.fft.pifft3d(karray)
# And now, let's do the inverse FFT
rec_array = jaxdecomp.fft.pifft3d(karray)
diff = jax.jit(lambda x, y: abs(x - y).max())(rec_array, global_array)
diff = jax.jit(lambda x, y: abs(x - y).max())(rec_array, global_array)
# Let's test if things are like we expect
if rank == 0:
print('maximum reconstruction difference', diff)
print("maximum reconstruction difference", diff)
jaxdecomp.finalize()
jax.distributed.shutdown()
@@ -7,7 +7,7 @@ A quick guide on how to run jaxDecomp (this was only tested on Idris's Jean-Zay)
```bash
# if on A100
module load cpuarch/amd
module load arch/a100
# then do this in this exact order
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
# Then create your python env
@@ -18,14 +18,10 @@ conda deactivate
# Install dependencies
source venv/bin/activate
pip cache purge
# Installing mpi4py
CFLAGS=-noswitcherror pip install --no-cache-dir mpi4py
# Installing jax
pip install --no-cache-dir --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# or
pip install --no-cache-dir --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install jax[cuda]
# Then install jaxDecomp
pip install .
pip install jaxdecomp
```
For an interactive use
@@ -53,7 +49,7 @@ Make sure to load the exact modules you used when you installed jaxDecomp
```bash
# if on A100
module load cpuarch/amd
module load arch/a100
# then
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
module load python/3.10.4 && conda deactivate
from dataclasses import dataclass
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple
from jaxdecomp._src.pencil_utils import get_output_specs
from jaxdecomp._src.pencil_utils import get_output_specs, get_fft_output_sharding
from jaxdecomp.fft import fftfreq3d, pfft3d, pifft3d, rfftfreq3d
from jaxdecomp.halo import halo_exchange
from jaxdecomp.transpose import (transposeXtoY, transposeXtoZ, transposeYtoX,
transposeYtoZ, transposeZtoX, transposeZtoY)
from jaxdecomp.transpose import (
transposeXtoY,
transposeXtoZ,
transposeYtoX,
transposeYtoZ,
transposeZtoX,
transposeZtoY,
)
from ._src import (HALO_COMM_MPI, HALO_COMM_MPI_BLOCKING, HALO_COMM_NCCL,
HALO_COMM_NVSHMEM, HALO_COMM_NVSHMEM_BLOCKING, NO_DECOMP,
PENCILS, SLAB_XY, SLAB_YZ, TRANSPOSE_COMM_MPI_A2A,
TRANSPOSE_COMM_MPI_P2P, TRANSPOSE_COMM_MPI_P2P_PL,
TRANSPOSE_COMM_NCCL, TRANSPOSE_COMM_NCCL_PL,
TRANSPOSE_COMM_NVSHMEM, TRANSPOSE_COMM_NVSHMEM_PL,
TRANSPOSE_XY, TRANSPOSE_YX, TRANSPOSE_YZ, TRANSPOSE_ZY,
HaloCommBackend, TransposeCommBackend, finalize,
get_autotuned_config, get_pencil_info, init, make_config)
from ._src import (
HALO_COMM_MPI,
HALO_COMM_MPI_BLOCKING,
HALO_COMM_NCCL,
HALO_COMM_NVSHMEM,
HALO_COMM_NVSHMEM_BLOCKING,
NO_DECOMP,
PENCILS,
SLAB_XY,
SLAB_YZ,
TRANSPOSE_COMM_MPI_A2A,
TRANSPOSE_COMM_MPI_P2P,
TRANSPOSE_COMM_MPI_P2P_PL,
TRANSPOSE_COMM_NCCL,
TRANSPOSE_COMM_NCCL_PL,
TRANSPOSE_COMM_NVSHMEM,
TRANSPOSE_COMM_NVSHMEM_PL,
TRANSPOSE_XY,
TRANSPOSE_YX,
TRANSPOSE_YZ,
TRANSPOSE_ZY,
HaloCommBackend,
TransposeCommBackend,
finalize,
get_autotuned_config,
get_pencil_info,
init,
make_config,
)
try:
__version__ = version("jaxDecomp")
__version__ = version("jaxDecomp")
except PackageNotFoundError:
# package is not installed
pass
# package is not installed
pass
__all__ = [
"config",
@@ -48,25 +73,36 @@ __all__ = [
"SLAB_YZ",
"PENCILS",
"NO_DECOMP",
"get_fft_output_sharding",
"get_output_specs",
"fftfreq3d",
"rfftfreq3d",
"HALO_COMM_MPI",
"HALO_COMM_MPI_BLOCKING",
"HALO_COMM_NVSHMEM",
"HALO_COMM_NVSHMEM_BLOCKING",
"TRANSPOSE_COMM_MPI_A2A",
"TRANSPOSE_COMM_MPI_P2P",
"TRANSPOSE_COMM_MPI_P2P_PL",
"TRANSPOSE_COMM_NCCL_PL",
"TRANSPOSE_COMM_NVSHMEM",
"TRANSPOSE_COMM_NVSHMEM_PL",
]
@dataclass
class JAXDecompConfig:
"""Class for storing the configuration state of the library."""
halo_comm_backend: HaloCommBackend = HALO_COMM_NCCL
transpose_comm_backend: TransposeCommBackend = TRANSPOSE_COMM_NCCL
transpose_axis_contiguous: bool = True
transpose_axis_contiguous_2: bool = True
"""Class for storing the configuration state of the library."""
def update(self, key, value):
if hasattr(self, key):
setattr(self, key, value)
else:
raise ValueError("key %s is not a valid configuration key" % key)
halo_comm_backend: HaloCommBackend = HALO_COMM_NCCL
transpose_comm_backend: TransposeCommBackend = TRANSPOSE_COMM_NCCL
transpose_axis_contiguous: bool = True
def update(self, key, value):
if hasattr(self, key):
setattr(self, key, value)
else:
raise ValueError(f"key {key} is not a valid configuration key")
# Declare the global configuration object
from jax.lib import xla_client
from . import _jaxdecomp
from jaxdecomplib import _jaxdecomp
init = _jaxdecomp.init
finalize = _jaxdecomp.finalize
@@ -9,19 +8,34 @@ get_autotuned_config = _jaxdecomp.get_autotuned_config
make_config = _jaxdecomp.GridConfig
# Loading the comm configuration flags at the top level
from ._jaxdecomp import TransposeCommBackend # yapf: disable
from ._jaxdecomp import (HALO_COMM_MPI, HALO_COMM_MPI_BLOCKING, HALO_COMM_NCCL,
HALO_COMM_NVSHMEM, HALO_COMM_NVSHMEM_BLOCKING,
NO_DECOMP, PENCILS, SLAB_XY, SLAB_YZ,
TRANSPOSE_COMM_MPI_A2A, TRANSPOSE_COMM_MPI_P2P,
TRANSPOSE_COMM_MPI_P2P_PL, TRANSPOSE_COMM_NCCL,
TRANSPOSE_COMM_NCCL_PL, TRANSPOSE_COMM_NVSHMEM,
TRANSPOSE_COMM_NVSHMEM_PL, TRANSPOSE_XY, TRANSPOSE_YX,
TRANSPOSE_YZ, TRANSPOSE_ZY, HaloCommBackend)
from jaxdecomplib._jaxdecomp import HaloCommBackend # yapf: disable
from jaxdecomplib._jaxdecomp import TransposeCommBackend # yapf: disable
from jaxdecomplib._jaxdecomp import ( # dummy line to avoid yapf reformatting
HALO_COMM_MPI,
HALO_COMM_MPI_BLOCKING,
HALO_COMM_NCCL,
HALO_COMM_NVSHMEM,
HALO_COMM_NVSHMEM_BLOCKING,
NO_DECOMP,
PENCILS,
SLAB_XY,
SLAB_YZ,
TRANSPOSE_COMM_MPI_A2A,
TRANSPOSE_COMM_MPI_P2P,
TRANSPOSE_COMM_MPI_P2P_PL,
TRANSPOSE_COMM_NCCL,
TRANSPOSE_COMM_NCCL_PL,
TRANSPOSE_COMM_NVSHMEM,
TRANSPOSE_COMM_NVSHMEM_PL,
TRANSPOSE_XY,
TRANSPOSE_YX,
TRANSPOSE_YZ,
TRANSPOSE_ZY,
)
# Registering ops for XLA
for name, fn in _jaxdecomp.registrations().items():
xla_client.register_custom_call_target(name, fn, platform="gpu")
xla_client.register_custom_call_target(name, fn, platform="gpu")
__all__ = [
"init",
@@ -37,4 +51,18 @@ __all__ = [
"SLAB_YZ",
"PENCILS",
"NO_DECOMP",
"HALO_COMM_MPI",
"HALO_COMM_MPI_BLOCKING",
"HALO_COMM_NCCL",
"HALO_COMM_NVSHMEM",
"HALO_COMM_NVSHMEM_BLOCKING",
"TRANSPOSE_COMM_MPI_A2A",
"TRANSPOSE_COMM_MPI_P2P",
"TRANSPOSE_COMM_MPI_P2P_PL",
"TRANSPOSE_COMM_NCCL",
"TRANSPOSE_COMM_NCCL_PL",
"TRANSPOSE_COMM_NVSHMEM",
"TRANSPOSE_COMM_NVSHMEM_PL",
"HaloCommBackend",
"TransposeCommBackend",
]
def error_during_jacfwd(function_name):
raise ValueError(f"""
Input sharding was found to be none while lowering the SPMD rule.
You are likely calling jacfwd with pfft as the first function.
due to a bug in JAX, the sharding is not correctly passed to the SPMD rule.
You need to annotate the sharding before calling {function_name}.
please check the caveat documentation, jacfwd section
""")
def error_during_jacrev(function_name):
raise ValueError(f"""
Input sharding was found to be none while lowering the SPMD rule.
You are likely calling jacrev with pfft as the first function.
due to a bug in JAX, the sharding is not correctly passed to the SPMD rule.
You need to annotate the sharding After calling {function_name}.
please check the caveat documentation, grad section
""")