
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: wassim/jaxDecomp
Commits on Source (90)
Showing with 580 additions and 945 deletions
name: Build and upload to PyPI
on:
workflow_dispatch:
pull_request:
push:
branches:
- main
release:
types:
- published
jobs:
build_wheels:
name: Build wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
# macos-13 is an intel runner, macos-14 is apple silicon
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v4
- name: Build wheels
uses: pypa/cibuildwheel@v2.21.3
env:
CIBW_BUILD: "cp310-* cp311-* cp312-*"
CIBW_BUILD_VERBOSITY: 2
- uses: actions/upload-artifact@v4
with:
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
path: ./wheelhouse/*.whl
build_sdist:
name: Build source distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build sdist
run: pipx run build --sdist
- uses: actions/upload-artifact@v4
with:
name: cibw-sdist
path: dist/*.tar.gz
upload_pypi:
needs: [build_wheels, build_sdist]
runs-on: ubuntu-latest
environment: pypi
permissions:
id-token: write
if: github.event_name == 'release' && github.event.action == 'published'
# or, alternatively, upload to PyPI on every tag starting with 'v' (remove on: release above to use this)
# if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/download-artifact@v4
with:
# unpacks all CIBW artifacts into dist/
pattern: cibw-*
path: dist
merge-multiple: true
- uses: pypa/gh-action-pypi-publish@release/v1
#with:
# repository-url: https://test.pypi.org/legacy/
name: Tests
on:
push:
branches:
- main
pull_request:
branches:
- main
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.10.4]
steps:
- name: Checkout Source
uses: actions/checkout@v2.3.1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install jax[cpu]
pip install .[test]
- name: Run tests
run: |
cd tests
export JAX_PLATFORM_NAME=cpu
export XLA_FLAGS='--xla_force_host_platform_device_count=8'
pytest -v
[submodule "third_party/cuDecomp"]
path = third_party/cuDecomp
url = https://github.com/NVIDIA/cuDecomp.git
[submodule "pybind11"]
path = pybind11
url = https://github.com/pybind/pybind11.git
......@@ -22,3 +22,7 @@ repos:
files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$'
exclude: '^third_party/|/pybind11/'
name: clang-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
hooks:
- id: mypy
# Change log
## jaxdecomp 0.2.0
* Changes
* jaxDecomp now works without MPI, using only JAX as the backend
* `with mesh` is no longer required (it will be deprecated by JAX)
* Added support for fftfreq
* Added testing for all functions
* Added static typing and checked with mypy
## jaxdecomp 0.1.0
* Changes
......
cmake_minimum_required(VERSION 3.19...3.25)
find_program(NVHPC_CXX_BIN "nvc++" REQUIRED)
set(CMAKE_CXX_COMPILER ${NVHPC_CXX_BIN})
find_program(NVHPC_C_BIN "nvc" REQUIRED)
set(CMAKE_C_COMPILER ${NVHPC_C_BIN})
project(jaxdecomp LANGUAGES CXX CUDA)
project(jaxdecomp LANGUAGES CXX)
# NVCC 12 does not support C++20
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
# Latest JAX v0.4.26 no longer supports cuda 11.8
find_package(CUDAToolkit REQUIRED VERSION 12)
set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR})
option(JD_CUDECOMP_BACKEND "Use cuDecomp backend" OFF)
# Set default build type to Release
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()
message(STATUS "Using CUDA ${NVHPC_CUDA_VERSION}")
# Build Release by default
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.")
set(PYBIND11_FINDPYTHON ON)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(third_party/cuDecomp)
# Check for CUDA
include(CheckLanguage)
check_language(CUDA)
option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF)
option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF)
option(CUDECOMP_BUILD_EXTRAS "Build benchmark, examples, and tests" OFF)
if(CMAKE_CUDA_COMPILER AND JD_CUDECOMP_BACKEND)
enable_language(CUDA)
# 70: Volta, 80: Ampere, 89: RTX 4060
set(CUDECOMP_CUDA_CC_LIST "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")
# Latest JAX v0.4.26 no longer supports cuda 11.8
find_package(CUDAToolkit REQUIRED VERSION 12)
set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR})
# Add pybind11 and cuDecomp subdirectories
add_subdirectory(pybind11)
message(STATUS "Using CUDA ${NVHPC_CUDA_VERSION}")
add_subdirectory(third_party/cuDecomp)
find_package(NVHPC REQUIRED COMPONENTS MATH MPI NCCL)
option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF)
option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF)
option(CUDECOMP_BUILD_EXTRAS "Build benchmark, examples, and tests" OFF)
string(REPLACE "/lib64" "/include" NVHPC_MATH_INCLUDE_DIR ${NVHPC_MATH_LIBRARY_DIR})
string(REPLACE "/lib64" "/include" NVHPC_CUDA_INCLUDE_DIR ${NVHPC_CUDA_LIBRARY_DIR})
# 70: Volta, 80: Ampere, 89: RTX 4060
set(CUDECOMP_CUDA_CC_LIST "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")
find_package(NVHPC REQUIRED COMPONENTS MATH MPI NCCL)
find_library(NCCL_LIBRARY
NAMES nccl
HINTS ${NVHPC_NCCL_LIBRARY_DIR}
string(REPLACE "/lib64" "/include" NVHPC_MATH_INCLUDE_DIR ${NVHPC_MATH_LIBRARY_DIR})
string(REPLACE "/lib64" "/include" NVHPC_CUDA_INCLUDE_DIR ${NVHPC_CUDA_LIBRARY_DIR})
find_library(NCCL_LIBRARY
NAMES nccl
HINTS ${NVHPC_NCCL_LIBRARY_DIR}
)
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})
message(STATUS "Using NCCL library: ${NCCL_LIBRARY}")
message(STATUS "NVHPC NCCL lib dir: ${NVHPC_NCCL_LIBRARY_DIR}")
message(STATUS "NCCL include dir: ${NCCL_INCLUDE_DIR}")
# Add _jaxdecomp module
pybind11_add_module(_jaxdecomp
src/csrc/halo.cu
src/csrc/jaxdecomp.cc
src/csrc/grid_descriptor_mgr.cc
src/csrc/fft.cu
src/csrc/transpose.cu
)
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})
message(STATUS "Using NCCL library: ${NCCL_LIBRARY}")
message(STATUS "NVHPC NCCL lib dir: ${NVHPC_NCCL_LIBRARY_DIR}")
message(STATUS "NCCL include dir: ${NCCL_INCLUDE_DIR}")
# Add _jaxdecomp module
pybind11_add_module(_jaxdecomp
src/halo.cu
src/jaxdecomp.cc
src/grid_descriptor_mgr.cc
src/fft.cu
src/transpose.cu
)
set_target_properties(_jaxdecomp PROPERTIES CUDA_ARCHITECTURES "${CUDECOMP_CUDA_CC_LIST}")
target_include_directories(_jaxdecomp
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cuDecomp/include
${NVHPC_CUDA_INCLUDE_DIR}
${MPI_CXX_INCLUDE_DIRS}
${NVHPC_MATH_INCLUDE_DIR}
${NCCL_INCLUDE_DIR}
)
target_link_libraries(_jaxdecomp PRIVATE MPI::MPI_CXX)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUFFT)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA)
target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY})
target_link_libraries(_jaxdecomp PRIVATE cudecomp)
target_link_libraries(_jaxdecomp PRIVATE stdc++fs)
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")
set_target_properties(_jaxdecomp PROPERTIES CUDA_ARCHITECTURES "${CUDECOMP_CUDA_CC_LIST}")
target_include_directories(_jaxdecomp
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/src/csrc/include
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cuDecomp/include
${NVHPC_CUDA_INCLUDE_DIR}
${MPI_CXX_INCLUDE_DIRS}
${NVHPC_MATH_INCLUDE_DIR}
${NCCL_INCLUDE_DIR}
)
target_link_libraries(_jaxdecomp PRIVATE MPI::MPI_CXX)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUFFT)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA)
target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY})
target_link_libraries(_jaxdecomp PRIVATE cudecomp)
target_link_libraries(_jaxdecomp PRIVATE stdc++fs)
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)
target_compile_definitions(_jaxdecomp PRIVATE JD_CUDECOMP_BACKEND)
else()
pybind11_add_module(_jaxdecomp src/csrc/jaxdecomp.cc)
target_include_directories(_jaxdecomp PRIVATE ${CMAKE_CURRENT_LIST_DIR}/src/csrc/include)
target_compile_definitions(_jaxdecomp PRIVATE JD_JAX_BACKEND)
endif()
set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")
install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .)
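# --- Added note (illustrative sketch, not part of the original CMakeLists) ---
# The cuDecomp path above is toggled by JD_CUDECOMP_BACKEND. When configuring
# directly with CMake (assuming pybind11 and the NVHPC SDK can be found, e.g.
# via CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake), the build might look like:
#   cmake -B build -DJD_CUDECOMP_BACKEND=ON -DCMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake
#   cmake --build build -j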
# jaxDecomp: JAX Library for 3D Domain Decomposition and Parallel FFTs
[![Code Formatting](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/formatting.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/formatting.yml)
[![Tests](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/tests.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/tests.yml/badge.svg)
[![MIT License](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
> [!IMPORTANT]
> Version `0.2.0` has a pure JAX backend and no longer requires MPI. MPI and NCCL backends are still available through cuDecomp.
JAX bindings for NVIDIA's [cuDecomp](https://nvidia.github.io/cuDecomp/index.html) library [(Romero et al. 2022)](https://dl.acm.org/doi/abs/10.1145/3539781.3539797), allowing for efficient **multi-node parallel FFTs and halo exchanges** directly in low-level NCCL/CUDA-aware MPI from your JAX code :tada:
......@@ -10,61 +17,83 @@ Here is an example of how to use `jaxDecomp` to perform a 3D FFT on a 3D array d
```python
import jax
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax import numpy as jnp
import jaxdecomp
# Initialise the library and optionally select a communication backend (defaults to NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)
from functools import partial
# Initialize jax.distributed to tell each local JAX process which GPU to use
jax.distributed.initialize()
rank = jax.process_index()
# Set up a processor mesh (its total size must match the number of processes)
pdims= (1,4)
global_shape=[1024,1024,1024]
pdims = (2, 4)
global_shape = (1024, 1024, 1024)
# Initialize an array with the expected global size
local_array = jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(rank))
# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
local_array, mesh, P('z', 'y'))
# Forward FFT, note that the output FFT is transposed
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2])
# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
global_array = jax.make_array_from_callback(
global_shape,
sharding,
data_callback=lambda _: jax.random.normal(
jax.random.PRNGKey(rank), local_shape))
padding_width = ((32, 32), (32, 32), (0, 0))  # Has to be a tuple of tuples
@partial(
shard_map, mesh=mesh, in_specs=(P('x', 'y'), P()), out_specs=P('x', 'y'))
def pad(arr, padding):
return jnp.pad(arr, padding)
@partial(
shard_map, mesh=mesh, in_specs=(P('x', 'y'), P()), out_specs=P('x', 'y'))
def reduce_halo(x, pad_width):
halo_x , _ = pad_width[0]
halo_y , _ = pad_width[1]
# Apply corrections along x
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
# Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
return x[halo_x:-halo_x, halo_y:-halo_y]
@jax.jit
def modify_array(array):
return 2 * array + 1
with mesh:
# Forward FFT
karray = jaxdecomp.fft.pfft3d(global_array)
# Do some operation on your array
karray = modify_array(karray)
# Reverse FFT
recarray = jaxdecomp.fft.pifft3d(karray).astype('float32')
# Add halo regions to our array
padding_width = ((32,32),(32,32),(32,32)) # Has to be a tuple of tuples
padded_array = jaxdecomp.slice_pad(recarray, padding_width , pdims)
# Perform a halo exchange
exchanged_array = jaxdecomp.halo_exchange(padded_array,
halo_extents=(32,32,32),
halo_periods=(True,True,True))
# Remove the halo regions
recarray = jaxdecomp.slice_unpad(exchanged_array, padding_width, pdims)
# Gather the results (only if it fits in CPU memory)
gathered_array = multihost_utils.process_allgather(recarray, tiled=True)
# Finalize the library
return 2 * array + 1
# Forward FFT
karray = jaxdecomp.fft.pfft3d(global_array)
# Do some operation on your array
karray = modify_array(karray)
kvec = jaxdecomp.fft.fftfreq3d(karray)
# Do a gradient in the X axis
karray_gradient = 1j * kvec[0] * karray
# Reverse FFT
recarray = jaxdecomp.fft.pifft3d(karray_gradient).real
# Add halo regions to our array
padded_array = pad(recarray, padding_width)
# Perform a halo exchange
exchanged_array = jaxdecomp.halo_exchange(
padded_array, halo_extents=(16, 16), halo_periods=(True, True))
# Reduce the halo regions and remove the padding
reduced_array = reduce_halo(exchanged_array, padding_width)
# Gather the results (only if it fits in CPU memory)
gathered_array = multihost_utils.process_allgather(recarray, tiled=True)
# Finalize the distributed JAX
jax.distributed.shutdown()
```
**Note**: All these functions are jittable and have well-defined derivatives!
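As a quick illustration of this note (a minimal sketch, assuming the `mesh` and `global_array` from the example above are still in scope), the distributed FFT composes with `jax.jit` and `jax.grad` like any other JAX function:

```python
import jax
import jax.numpy as jnp
import jaxdecomp

@jax.jit
def spectral_power(arr):
  # Toy scalar loss: total power of the forward transform
  karray = jaxdecomp.fft.pfft3d(arr)
  return jnp.sum(jnp.abs(karray)**2)

value, grads = jax.value_and_grad(spectral_power)(global_array)
```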
......@@ -72,8 +101,33 @@ jax.distributed.shutdown()
This script would have to be run on 8 GPUs in total with something like
```bash
$ mpirun -n 8 python demo.py
mpirun -n 8 python demo.py
```
or on a Slurm cluster like Jean Zay
```bash
srun -n 8 python demo.py
```
## Using cuDecomp (MPI and NCCL)
You can also use the cuDecomp backend by compiling the library with the right flag (check the [installation instructions](#install)) and setting the backend to use MPI or NCCL. Here is how you can do it:
```python
import jaxdecomp
# Initialise the library and optionally select a communication backend (defaults to NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)
# and then call the functions with the cuDecomp backends
karray = jaxdecomp.fft.pfft3d(global_array , backend='cudecomp')
recarray = jaxdecomp.fft.pifft3d(karray , backend='cudecomp')
exchanged_array = jaxdecomp.halo_exchange(
padded_array, halo_extents=(16, 16), halo_periods=(True, True), backend='cudecomp')
```
Please check the [tests](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/tree/main/tests) folder for more examples.
On an HPC cluster like Jean Zay you should do this:
......@@ -83,29 +137,40 @@ $ srun python demo.py
Check the Slurm [README](slurms/README.md) and [template](slurms/template.slurm) for more information on how to run on Jean Zay.
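For reference, a minimal Slurm batch script might look like the sketch below. This is only an illustration (not the repository's template): the partition, account, and module names depend on your cluster.

```bash
#!/bin/bash
#SBATCH --job-name=jaxdecomp-demo
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4   # one task per GPU
#SBATCH --gres=gpu:4
#SBATCH --time=00:10:00

# Load your cluster's CUDA-aware stack here (see the Jean Zay section below for a concrete example)
srun python demo.py
```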
### Caveats
## Install
The code presented above should work, but there are a few caveats mentioned in [this issue](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/issues/1). If you need a feature that is not currently implemented, feel free to mention it on that issue.
### Installing the pure JAX version (Easy)
## Install
jaxDecomp is available on PyPI and can be installed via pip:
Start by cloning this repository locally on your cluster:
First install the desired JAX version.

For GPU:
```bash
pip install -U jax[cuda12]
```
For CPU:
```bash
$ git clone --recurse-submodules https://github.com/DifferentiableUniverseInitiative/jaxDecomp
pip install -U jax[cpu]
```
Then you can install jaxdecomp with pip:
#### Requirements
```bash
pip install jaxdecomp
```
This install procedure assumes that the [NVIDIA HPC SDK](https://developer.nvidia.com/hpc-sdk) is available in your environment. You can either install it from the NVIDIA website, or better yet, it may be available as a module on your cluster.
### Installing JAX and cuDecomp (Advanced)
You need to install from this GitHub repository after installing or loading the correct modules.
Make sure all environment variables related to the SDK are properly set.
This install procedure assumes that the [NVIDIA HPC SDK](https://developer.nvidia.com/hpc-sdk) is available in your environment. You can either install it from the NVIDIA website, or better yet, it may be available as a module on your cluster.
### Building jaxDecomp
From this directory, install and build jaxDecomp via pip:
```bash
$ pip install --user .
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -CCmake.define.JD_CUDECOMP_BACKEND=ON
```
If CMake complains about not finding the NVHPC SDK, you can manually specify the location
of the SDK's CMake files like so:
```
......@@ -117,7 +182,7 @@ $ pip install --user .
#### IDRIS [Jean Zay](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-hw-eng.html) HPE SGI 8600 supercomputer
As of September 2024, the following works:
As of October 2024, the following works:
You need to load the modules in exactly this order.
```bash
......@@ -127,69 +192,26 @@ module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5
pip install --upgrade "jax[cuda12]"
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # Not always needed
pip install .
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -CCmake.define.JD_CUDECOMP_BACKEND=ON
```
__Note__: This is needed **only** if you want to use the cuDecomp backend. If you are using the pure JAX backend, you can skip the NVHPC SDK installation and just `pip install jaxdecomp` **after** installing the correct JAX version for your hardware.
#### NERSC [Perlmutter](https://docs.nersc.gov/systems/perlmutter/architecture/) HPE Cray EX supercomputer
As of Nov. 2022, the following works:
```bash
module load PrgEnv-nvhpc python
export CRAY_ACCEL_TARGET=nvidia80
# Installing mpi4py
MPICC="cc -target-accel=nvidia80 -shared" CC=nvc CFLAGS="-noswitcherror" pip install --force --no-cache-dir --no-binary=mpi4py mpi4py
# Installing jax
pip install --upgrade "jax[cuda12]"
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake
pip install .
```
## Design
Here is what works now:
```python
from jaxdecomp.fft import pfft3, ipfft3
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P
import jaxdecomp
# Initialize jax.distributed to tell each local JAX process which GPU to use
jax.distributed.initialize()
pdims = (2 , 4)
global_shape = (512 , 512 , 512 )
local_array = jax.random.normal(shape=[global_shape[0]//pdims[0],
global_shape[1]//pdims[1],
global_shape[2]], key=jax.random.PRNGKey(0))
# Remap to a global array (this is a free call, no communication happens)
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
    local_array, mesh, P('z', 'y'))
with mesh:
  z = pfft3(global_array)
  # If we could inspect the sharding of z, we would see that it is sliced in 2 along x, and 4 along y
  # This could also be part of a jitted function, no problem
  z_rec = ipfft3(z)
# And z remains distributed at all times.
jaxdecomp.finalize()
jax.distributed.shutdown()
pip install git+https://github.com/DifferentiableUniverseInitiative/jaxDecomp -CCmake.define.JD_CUDECOMP_BACKEND=ON
```
## Backend configuration (Only for cuDecomp)
#### Backend configuration
__Note__: For the JAX backend, only NCCL is available.
We can set the default communication backend for cuDecomp operations either through the `config` module or through environment variables. This lets users choose the communication backend at startup (it can still be changed afterwards), making it possible to use CUDA-aware MPI or NVSHMEM as preferred.
......@@ -205,7 +227,7 @@ jaxdecomp.config.update('transpose_comm_backend', 'MPI')
%timeit pfft3(y)
```
#### Autotune computational mesh
## Autotune computational mesh (Only for cuDecomp)
We can also make things fancier: since cuDecomp is able to autotune, we can use it to determine the best way to partition the data given the available GPUs, something like this:
```python
......
......@@ -41,7 +41,7 @@ def _global_to_local_size(nc: int):
return [nc // pdims[0], nc // pdims[1], nc]
def fttk(nc: int) -> list:
def fttk(nc: int):
"""
Generate Fourier transform wave numbers for a given mesh.
......
from functools import partial
from typing import Tuple
import jax.numpy as jnp
from jax import jit
from jax._src.api import ShapeDtypeStruct
from jax._src.core import ShapedArray
from jax._src.typing import Array, ArrayLike
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomp._src.spmd_ops import CustomParPrimitive, register_primitive
class SlicePaddingPrimitive(CustomParPrimitive):
"""
Custom primitive for slice padding operation.
Attributes
----------
name : str
The name of the primitive operation.
multiple_results : bool
Whether the operation produces multiple results.
impl_static_args : tuple
Static arguments for the implementation.
outer_pritimive : object
Outer primitive used for the operation.
"""
name: str = "slice_pad"
multiple_results: bool = False
impl_static_args: Tuple[int, int, int] = (1, 2, 3)
outer_pritimive: object = None
# The global array implementation is used purely for its abstract eval:
# at jit time, the shape of the global output array is inferred from this function.
# It is padding_width * pdims since we pad each slice;
# the pdims are needed for the global array, not for the slice.
# This function is lowered, but never executed
@staticmethod
def impl(arr: ArrayLike,
padding_width: int | tuple[int],
pdims: tuple[int],
mode: str = 'constant') -> Array:
"""
Implementation of the slice padding operation.
Parameters
----------
arr : ArrayLike
Input array to be padded.
padding_width : int | tuple[int]
Width of padding to apply.
pdims : tuple[int]
Processor mesh dimensions used to split the array into slices.
mode : str, optional
Padding mode ('constant' by default).
Returns
-------
Array
Padded array.
"""
assert arr.ndim == 3, "Only 3D arrays are supported"
assert len(pdims) == 2, "Only 2D pdims are supported"
# If padding width is an integer then pad all dimensions by the same amount
if isinstance(padding_width, int):
padding_width = ((padding_width, padding_width),) * arr.ndim
elif isinstance(padding_width, tuple):
# TODO(wassim) : support single value padding width (low and high are the same)
# Padding width (if more than one) has to be equal to the number of dimensions
assert len(padding_width) == arr.ndim
padding_width = padding_width
for dim, local_padding in enumerate(padding_width):
if isinstance(local_padding, int):
first, last = local_padding, local_padding
elif isinstance(local_padding, tuple):
first, last = local_padding
else:
raise ValueError(
"Padding width must be an integer or a tuple of integers")
if first == 0 and last == 0:
continue
match dim:
# X dimension
case 0:
slices = jnp.array_split(arr, pdims[1], axis=0)
arr = jnp.concatenate([
jnp.pad(s, ((first, last), (0, 0), (0, 0)), mode=mode)
for s in slices
],
axis=0)
case 1:
slices = jnp.array_split(arr, pdims[0], axis=1)
arr = jnp.concatenate([
jnp.pad(s, ((0, 0), (first, last), (0, 0)), mode=mode)
for s in slices
],
axis=1)
case 2:
# no distributed padding in the z dimension
arr = jnp.pad(arr, ((0, 0), (0, 0), (first, last)), mode=mode)
case _:
raise ValueError("Only 3D arrays are supported")
return arr
@staticmethod
def per_shard_impl(arr: ArrayLike,
padding_width: int | tuple[int],
mode: str = 'constant') -> Array:
"""
Per-shard implementation of the slice padding operation.
Parameters
----------
arr : ArrayLike
Input array to be padded.
padding_width : int | tuple[int]
Width of padding to apply.
mode : str, optional
Padding mode ('constant' by default).
Returns
-------
Array
Padded array.
"""
return jnp.pad(arr, padding_width, mode=mode)
@staticmethod
def infer_sharding_from_operands(padding_width: int | tuple[int],
pdims: tuple[int], mode: str, mesh: Mesh,
arg_infos: Tuple[ShapeDtypeStruct],
result_infos: Tuple[ShapedArray]):
"""
Infers sharding information from operands for the slice padding operation.
Parameters
----------
padding_width : int | tuple[int]
Width of padding to apply.
pdims : tuple[int]
Processor mesh dimensions used to split the array into slices.
mode : str
Padding mode.
mesh : Mesh
Computational mesh.
arg_infos : Tuple[ShapeDtypeStruct]
Information about operands.
result_infos : Tuple[ShapedArray]
Information about results.
Returns
-------
NamedSharding
Sharding information.
"""
input_sharding = arg_infos[0].sharding
return NamedSharding(input_sharding.mesh, P(*input_sharding.spec))
@staticmethod
def partition(padding_width: int | tuple[int], pdims: tuple[int], mode: str,
mesh: Mesh, arg_infos: Tuple[ShapeDtypeStruct],
result_infos: Tuple[ShapedArray]):
"""
Partitions the slice padding operation across a computational mesh.
Parameters
----------
padding_width : int | tuple[int]
Width of padding to apply.
pdims : tuple[int]
Processor mesh dimensions used to split the array into slices.
mode : str
Padding mode.
mesh : Mesh
Computational mesh.
arg_infos : Tuple[ShapeDtypeStruct]
Information about operands.
result_infos : Tuple[ShapedArray]
Information about results.
Returns
-------
Tuple
Mesh, implementation, output sharding, and input sharding.
"""
input_sharding = NamedSharding(mesh, P(*arg_infos[0].sharding.spec))
output_sharding = NamedSharding(mesh, P(*result_infos.sharding.spec))
impl = partial(
SlicePaddingPrimitive.per_shard_impl,
padding_width=padding_width,
mode=mode)
return mesh, impl, output_sharding, (input_sharding,)
register_primitive(SlicePaddingPrimitive)
@partial(jit, static_argnums=(1, 2, 3))
def slice_pad(x: ArrayLike,
padding_width: int | tuple[int],
pdims: tuple[int],
mode: str = 'constant') -> Array:
"""
JIT-compiled function for slice padding operation.
Parameters
----------
x : ArrayLike
Input array to be padded.
padding_width : int | tuple[int]
Width of padding to apply.
pdims : tuple[int]
Processor mesh dimensions used to split the array into slices.
mode : str, optional
Padding mode ('constant' by default).
Returns
-------
Array
Padded array.
"""
return SlicePaddingPrimitive.outer_lowering(x, padding_width, pdims, mode)
class SliceUnPaddingPrimitive(CustomParPrimitive):
"""
Custom primitive for slice unpadding operation.
Attributes
----------
name : str
The name of the primitive operation.
multiple_results : bool
Whether the operation produces multiple results.
impl_static_args : tuple
Static arguments for the implementation.
outer_pritimive : object
Outer primitive used for the operation.
"""
name: str = "slice_unpad"
multiple_results: bool = False
impl_static_args: Tuple[int, int] = (1, 2)
outer_pritimive: object = None
@staticmethod
def impl(arr: ArrayLike, padding_width: int | tuple[int],
pdims: tuple[int]) -> Array:
"""
Implementation of the slice unpadding operation.
Parameters
----------
arr : ArrayLike
Input array to be unpadded.
padding_width : int | tuple[int]
Width of padding to remove.
pdims : tuple[int]
Processor mesh dimensions used to split the array into slices.
Returns
-------
Array
Unpadded array.
"""
if isinstance(padding_width, int):
unpadding_width = ((padding_width, padding_width),) * arr.ndim
elif isinstance(padding_width, tuple):
# Unpadding width (if more than one) has to be equal to the number of dimensions
assert len(padding_width) == arr.ndim
unpadding_width = padding_width
for dim, local_padding in enumerate(unpadding_width):
if isinstance(local_padding, int):
first, last = local_padding, local_padding
elif isinstance(local_padding, tuple):
first, last = local_padding
else:
raise ValueError(
"Padding width must be an integer or a tuple of integers")
if first == 0 and last == 0:
continue
match dim:
# X dimension
case 0:
slices = jnp.array_split(arr, pdims[1], axis=0)
arr = jnp.concatenate([arr[first:-last] for arr in slices], axis=0)
case 1:
slices = jnp.array_split(arr, pdims[0], axis=1)
arr = jnp.concatenate([arr[:, first:-last] for arr in slices], axis=1)
case 2:
# no distributed padding in the z dimension
arr = arr[:, :, first:-last]
case _:
raise ValueError("Only 3D arrays are supported")
return arr
@staticmethod
def per_shard_impl(arr: ArrayLike, padding_width: int | tuple[int]) -> Array:
"""
Per-shard implementation of the slice unpadding operation.
Parameters
----------
arr : ArrayLike
Input array to be unpadded.
padding_width : int | tuple[int]
Width of padding to remove.
Returns
-------
Array
Unpadded array.
"""
if isinstance(padding_width, int):
unpadding_width = ((padding_width, padding_width),) * arr.ndim
elif isinstance(padding_width, tuple):
# Unpadding width (if more than one) has to be equal to the number of dimensions
assert len(padding_width) == arr.ndim
unpadding_width = padding_width
first_x, last_x = unpadding_width[0]
first_y, last_y = unpadding_width[1]
first_z, last_z = unpadding_width[2]
last_x = arr.shape[0] - last_x
last_y = arr.shape[1] - last_y
last_z = arr.shape[2] - last_z
return arr[first_x:last_x, first_y:last_y, first_z:last_z]
@staticmethod
def infer_sharding_from_operands(padding_width: int | tuple[int],
pdims: tuple[int], mesh: Mesh,
arg_infos: Tuple[ShapeDtypeStruct],
result_infos: Tuple[ShapedArray]):
"""
Infers sharding information from operands for the slice unpadding operation.
Parameters
----------
padding_width : int | tuple[int]
Width of padding to remove.
pdims : tuple[int]
Processor mesh dimensions used to split the array into slices.
mesh : Mesh
Computational mesh.
arg_infos : Tuple[ShapeDtypeStruct]
Information about operands.
result_infos : Tuple[ShapedArray]
Information about results.
Returns
-------
NamedSharding
Sharding information.
"""
input_sharding = arg_infos[0].sharding
return NamedSharding(input_sharding.mesh, P(*input_sharding.spec))
@staticmethod
def partition(padding_width: int | tuple[int], pdims: tuple[int], mesh: Mesh,
arg_infos: Tuple[ShapeDtypeStruct],
result_infos: Tuple[ShapedArray]):
"""
Partitions the slice unpadding operation across a computational mesh.
Parameters
----------
padding_width : int | tuple[int]
Width of padding to remove.
pdims : tuple[int]
Processor mesh dimensions used to split the array into slices.
mesh : Mesh
Computational mesh.
arg_infos : Tuple[ShapeDtypeStruct]
Information about operands.
result_infos : Tuple[ShapedArray]
Information about results.
Returns
-------
Tuple
Mesh, implementation, output sharding, and input sharding.
"""
input_sharding = NamedSharding(mesh, P(*arg_infos[0].sharding.spec))
output_sharding = NamedSharding(mesh, P(*result_infos.sharding.spec))
impl = partial(
SliceUnPaddingPrimitive.per_shard_impl, padding_width=padding_width)
return mesh, impl, output_sharding, (input_sharding,)
register_primitive(SliceUnPaddingPrimitive)
@partial(jit, static_argnums=(1, 2))
def slice_unpad(arr: ArrayLike, padding_width: int | tuple[int],
pdims: tuple[int]) -> Array:
"""
JIT-compiled function for slice unpadding operation.
Parameters
----------
arr : ArrayLike
Input array to be unpadded.
padding_width : int | tuple[int]
Width of padding to remove.
pdims : tuple[int]
Processor mesh dimensions used to split the array into slices.
Returns
-------
Array
Unpadded array.
"""
return SliceUnPaddingPrimitive.outer_lowering(arr, padding_width, pdims)
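# --- Added usage sketch (illustrative only, not part of the original module) ---
# Assumes 8 visible devices, e.g. XLA_FLAGS='--xla_force_host_platform_device_count=8'
# on CPU, and follows the same convention as the README example: axis 0 is split
# into pdims[1] slices and axis 1 into pdims[0] slices.
if __name__ == "__main__":
  import jax
  from jax.experimental import mesh_utils

  pdims = (2, 4)
  devices = mesh_utils.create_device_mesh(pdims)
  mesh = Mesh(devices.T, axis_names=('x', 'y'))
  sharding = NamedSharding(mesh, P('x', 'y'))

  x = jax.device_put(
      jnp.arange(8 * 8 * 8, dtype=jnp.float32).reshape(8, 8, 8), sharding)

  padding_width = ((2, 2), (2, 2), (0, 0))  # has to be a tuple of tuples
  padded = slice_pad(x, padding_width, pdims)           # pad each local slice
  unpadded = slice_unpad(padded, padding_width, pdims)  # remove it again
  assert unpadded.shape == x.shape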
from functools import partial
from typing import Optional, Sequence
import jax.numpy as jnp
from jax import jit
from jax._src.typing import Array, ArrayLike
from jax.lib import xla_client
from jaxdecomp._src import pfft as _pfft
Shape = Sequence[int]
__all__ = [
"pfft3d",
"pifft3d",
]
def _fft_norm(s: Array, func_name: str, norm: str) -> Array:
"""
Compute the normalization factor for FFT operations.
Parameters
----------
s : Array
Shape of the input array.
func_name : str
Name of the FFT function ("fft" or "ifft").
norm : str
Type of normalization ("backward", "ortho", or "forward").
Returns
-------
Array
Normalization factor.
Raises
------
ValueError
If an invalid norm value is provided.
"""
if norm == "backward":
return 1 / jnp.prod(s) if func_name.startswith("i") else jnp.array(1)
elif norm == "ortho":
return (1 / jnp.sqrt(jnp.prod(s)) if func_name.startswith("i") else 1 /
jnp.sqrt(jnp.prod(s)))
elif norm == "forward":
return jnp.prod(s) if func_name.startswith("i") else 1 / jnp.prod(s)**2
raise ValueError(f'Invalid norm value {norm}; should be "backward",'
'"ortho" or "forward".')
# Has to be jitted here because _fft_norm acts on a non-fully-addressable global array,
# which means it must be wrapped in jit
@partial(jit, static_argnums=(0, 1, 3))
def _do_pfft(
func_name: str,
fft_type: xla_client.FftType,
arr: ArrayLike,
norm: Optional[str],
) -> Array:
"""
Perform 3D FFT or inverse 3D FFT on the input array.
Parameters
----------
func_name : str
Name of the FFT function ("fft" or "ifft").
fft_type : xla_client.FftType
Type of FFT operation.
arr : ArrayLike
Input array to transform.
norm : Optional[str]
Type of normalization ("backward", "ortho", or "forward").
Returns
-------
Array
Transformed array after FFT or inverse FFT.
"""
transformed = _pfft(arr, fft_type)
transformed *= _fft_norm(
jnp.array(arr.shape, dtype=transformed.dtype), func_name, norm)
return transformed
def pfft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array:
"""
Perform 3D FFT on the input array.
Parameters
----------
a : ArrayLike
Input array to transform.
norm : Optional[str], optional
Type of normalization ("backward", "ortho", or "forward"), by default "backward".
Returns
-------
Array
Transformed array after 3D FFT.
"""
return _do_pfft("fft", xla_client.FftType.FFT, a, norm=norm)
def pifft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array:
"""
Perform inverse 3D FFT on the input array.
Parameters
----------
a : ArrayLike
Input array to transform.
norm : Optional[str], optional
Type of normalization ("backward", "ortho", or "forward"), by default "backward".
Returns
-------
Array
Transformed array after inverse 3D FFT.
"""
return _do_pfft("ifft", xla_client.FftType.IFFT, a, norm=norm)
Subproject commit 8b48ff878c168b51fe5ef7b8c728815b9e1a9857
[build-system]
requires = [ "scikit-build-core","pybind11"]
requires = ["scikit-build-core>=0.4.0", "pybind11>=2.9.0"]
build-backend = "scikit_build_core.build"
[project]
name = "jaxdecomp"
version = "0.1.0"
version = "0.2.0"
description = "JAX bindings for the cuDecomp library"
authors = [
{ name = "Wassim Kabalan" },
......@@ -15,13 +14,21 @@ readme = "README.md"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent"
"Operating System :: OS Independent",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
]
dependencies = [
"jaxtyping>=0.2.33",
"jax>=0.4.30",
]
dependencies = []
[project.optional-dependencies]
test = ["pytest"]
test = ["pytest>=8.0.0" , "jax[cpu]>=0.4.30"]
[tool.scikit-build]
minimum-version = "0.8"
......@@ -29,8 +36,11 @@ cmake.version = ">=3.25"
build-dir = "build/{wheel_tag}"
wheel.py-api = "py3"
cmake.build-type = "Release"
# Add any additional configurations for scikit-build if necessary
wheel.install-dir = "jaxdecomp/_src"
[tool.scikit-build.cmake.define]
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""
#[tool.cibuildwheel]
#test-extras = "test"
#test-command = "pytest {project}/tests"
......@@ -40,7 +40,7 @@ HRESULT GridDescriptorManager::createFFTExecutor(fftDescriptor& descriptor, size
}
if (hr == E_FAIL) {
hr = executor->Initialize(m_Handle, descriptor.config, work_size, descriptor);
hr = executor->Initialize(m_Handle, work_size, descriptor);
if (SUCCEEDED(hr)) { m_Descriptors64[descriptor] = executor; }
}
return hr;
......@@ -59,7 +59,7 @@ HRESULT GridDescriptorManager::createFFTExecutor(fftDescriptor& descriptor, size
}
if (hr == E_FAIL) {
hr = executor->Initialize(m_Handle, descriptor.config, work_size, descriptor);
hr = executor->Initialize(m_Handle, work_size, descriptor);
if (SUCCEEDED(hr)) { m_Descriptors32[descriptor] = executor; }
}
return hr;
......
File moved
File moved
......@@ -34,9 +34,7 @@ static Decomposition GetDecomposition(const int pdims[2]) {
} else if (pdims[0] > 1 && pdims[1] > 1) {
return Decomposition::pencil;
}
// Return pencils on one device for testing
return Decomposition::pencil;
// return Decomposition::no_decomp;
return Decomposition::no_decomp;
}
// fftDescriptor hash should be trivially computable
......@@ -44,6 +42,7 @@ static Decomposition GetDecomposition(const int pdims[2]) {
class fftDescriptor {
public:
bool adjoint = false;
bool contiguous = true;
bool forward = true; ///< forward or backward pass
// fft_is_forward_pass and forward are used for the execution but not for the
// hash. This way IFFT and FFT have the same plans when operating with the same
......@@ -63,7 +62,7 @@ public:
// created in the bottom of the file)
bool operator==(const fftDescriptor& other) const {
if (double_precision != other.double_precision || gdims[0] != other.gdims[0] || gdims[1] != other.gdims[1] ||
gdims[2] != other.gdims[2] || decomposition != other.decomposition) {
gdims[2] != other.gdims[2] || decomposition != other.decomposition || contiguous != other.contiguous) {
return false;
}
return true;
......@@ -74,14 +73,15 @@ public:
// this is used for subsequent ffts to find the Executor that was already
// defined
fftDescriptor(cudecompGridDescConfig_t& config, const bool& double_precision, const bool& iForward,
const bool& iAdjoint, const Decomposition& iDecomposition)
: double_precision(double_precision), config(config) {
const bool& iAdjoint, const bool& iContiguous, const Decomposition& iDecomposition)
: double_precision(double_precision), config(config), forward(iForward), contiguous(iContiguous),
adjoint(iAdjoint), decomposition(iDecomposition) {
gdims[0] = config.gdims[0];
gdims[1] = config.gdims[1];
gdims[2] = config.gdims[2];
forward = iForward;
adjoint = iAdjoint;
decomposition = iDecomposition;
this->config.transpose_axis_contiguous[0] = iContiguous;
this->config.transpose_axis_contiguous[1] = iContiguous;
this->config.transpose_axis_contiguous[2] = iContiguous;
}
};
......@@ -96,8 +96,7 @@ public:
FourierExecutor() : m_Tracer("JAXDECOMP") {}
~FourierExecutor();
HRESULT Initialize(cudecompHandle_t handle, cudecompGridDescConfig_t config, size_t& work_size,
fftDescriptor& fft_descriptor);
HRESULT Initialize(cudecompHandle_t handle, size_t& work_size, fftDescriptor& fft_descriptor);
HRESULT forward(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, void** buffers);
......@@ -114,6 +113,9 @@ private:
cudecompGridDesc_t m_GridConfig;
cudecompGridDescConfig_t m_GridDescConfig;
cudecompPencilInfo_t m_XPencilInfo;
cudecompPencilInfo_t m_YPencilInfo;
cudecompPencilInfo_t m_ZPencilInfo;
// For the sake of expressive code, plans are named after their corresponding
// goal. Instead of reusing pencil plans for slabs, or even ZY to YZ, we store
// properly named plans
......@@ -126,21 +128,16 @@ private:
// so in the end it is FFT(X) FFT(YZ)
// For Slabs XZ FFT (X) FFT(YZ)
cufftHandle m_Plan_c2c_yz;
cufftHandle m_Plan_c2c_xy;
// work size
int64_t m_WorkSize;
// Internal functions
HRESULT InitializePencils(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info,
cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info,
int64_t& work_size, const bool& is_contiguous);
HRESULT InitializePencils(int64_t& work_size, fftDescriptor& fft_descriptor);
HRESULT InitializeSlabXY(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info,
cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, int64_t& work_size,
const bool& is_contiguous);
HRESULT InitializeSlabXY(int64_t& work_size, fftDescriptor& fft_descriptor);
HRESULT InitializeSlabYZ(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info,
cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, int64_t& work_size,
const bool& is_contiguous);
HRESULT InitializeSlabYZ(int64_t& work_size, fftDescriptor& fft_descriptor);
HRESULT forwardXY(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);
......@@ -161,9 +158,6 @@ private:
complex_t* output, complex_t* work_buffer);
HRESULT clearPlans();
// DEBUG ONLY ... I WARN YOU
void inspect_device_array(complex_t* data, int size, cudaStream_t stream);
};
} // namespace jaxdecomp
......@@ -174,11 +168,9 @@ template <> struct hash<jaxdecomp::fftDescriptor> {
// Only hash the double precision flag, the gdims and the pdims
// If adjoint is changed then the plan should be the same
// adjoint is to be used to execute the backward plan
static const size_t xy_hash = std::hash<int>()(jaxdecomp::Decomposition::slab_XY);
size_t hash = std::hash<double>()(descriptor.double_precision) ^ std::hash<int>()(descriptor.gdims[0]) ^
std::hash<int>()(descriptor.gdims[1]) ^ std::hash<int>()(descriptor.gdims[2]) ^
std::hash<int>()(descriptor.decomposition);
std::hash<int>()(descriptor.decomposition) ^ std::hash<bool>()(descriptor.contiguous);
return hash;
}
};
......
File moved
File moved