---
BasedOnStyle: LLVM
ColumnLimit: 120
CommentPragmas: '^\\.+'
DerivePointerAlignment: false
Language: Cpp
PointerAlignment: Left
UseTab: Never
AlignAfterOpenBracket: Align
AlignTrailingComments: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortIfStatementsOnASingleLine: true
AllowShortLoopsOnASingleLine: true
...
name: Code Formatting

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip isort
          python -m pip install pre-commit
      - name: Run pre-commit
        run: python -m pre_commit run --all-files
name: Draft PDF

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

jobs:
  paper:
    runs-on: ubuntu-latest
    name: Paper Draft
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Build draft PDF
        uses: openjournals/openjournals-draft-action@master
        with:
          journal: joss
          # This should be the path to the paper within your repo.
          paper-path: joss-paper/paper.md
      - name: Upload
        uses: actions/upload-artifact@v3
        with:
          name: paper
          # This is the output path where Pandoc will write the compiled
          # PDF. Note, this should be the same directory as the input
          # paper.md
          path: joss-paper/paper.pdf
name: Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.10.4]
    steps:
      - name: Checkout Source
        uses: actions/checkout@v2.3.1
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .[tests]
      - name: Run tests
        run: |
          cd tests
          export JAX_PLATFORM_NAME=cpu
          export XLA_FLAGS='--xla_force_host_platform_device_count=8'
          pytest test_fft.py
          pytest test_halo.py
          pytest test_transpose.py
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode/
scripts/experimental
compile_commands.json
CMakeFiles
traces*
notes.txt
[submodule "third_party/cuDecomp"]
path = third_party/cuDecomp
url = https://github.com/NVIDIA/cuDecomp.git
[submodule "pybind11"]
path = pybind11
url = https://github.com/pybind/pybind11.git
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/google/yapf
    rev: v0.40.2
    hooks:
      - id: yapf
        args: ['--parallel', '--in-place']
  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        name: isort (python)
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v18.1.4
    hooks:
      - id: clang-format
        files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$'
        exclude: '^third_party/|/pybind11/'
        name: clang-format
[style]
based_on_style = yapf
# Change log
## jaxdecomp 0.1.0

* Changes
  * Fixed a bug in the halo exchange
  * Added the JOSS paper

## jaxdecomp 0.0.1

* Changes
  * New version compatible with JAX 0.4.30
  * jaxDecomp now works in a multi-host environment
  * Added custom partitioning for FFTs
  * Added custom partitioning for halo exchange
  * Added custom partitioning for slice_pad and slice_unpad
  * Added an example of multi-host FFTs in `examples/jaxdecomp_lpt.py`

## jaxdecomp 0.0.1rc2

* Changes
  * Added a utility to run autotuning

## jaxdecomp 0.0.1rc1 (Nov. 25th 2022)

Initial pre-release; includes support for parallel FFTs and halo exchange.
cmake_minimum_required(VERSION 3.19...3.25)
project(jaxdecomp LANGUAGES CXX)
# NVCC 12 does not support C++20
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
option(JD_CUDECOMP_BACKEND "Use cuDecomp backend" OFF)
# Set default build type to Release
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()
# Check for CUDA
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER AND JD_CUDECOMP_BACKEND)
enable_language(CUDA)
# Latest JAX v0.4.26 no longer supports cuda 11.8
find_package(CUDAToolkit REQUIRED VERSION 12)
set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR})
message(STATUS "Using CUDA ${NVHPC_CUDA_VERSION}")
add_subdirectory(third_party/cuDecomp)
option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF)
option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF)
option(CUDECOMP_BUILD_EXTRAS "Build benchmark, examples, and tests" OFF)
# 70: Volta, 80: Ampere, 89: RTX 4060
set(CUDECOMP_CUDA_CC_LIST "70;80;89" CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")
# Add pybind11 and cuDecomp subdirectories
add_subdirectory(pybind11)
find_package(NVHPC REQUIRED COMPONENTS MATH MPI NCCL)
string(REPLACE "/lib64" "/include" NVHPC_MATH_INCLUDE_DIR ${NVHPC_MATH_LIBRARY_DIR})
string(REPLACE "/lib64" "/include" NVHPC_CUDA_INCLUDE_DIR ${NVHPC_CUDA_LIBRARY_DIR})
find_library(NCCL_LIBRARY
NAMES nccl
HINTS ${NVHPC_NCCL_LIBRARY_DIR}
)
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})
message(STATUS "Using NCCL library: ${NCCL_LIBRARY}")
message(STATUS "NVHPC NCCL lib dir: ${NVHPC_NCCL_LIBRARY_DIR}")
message(STATUS "NCCL include dir: ${NCCL_INCLUDE_DIR}")
# Add _jaxdecomp module
pybind11_add_module(_jaxdecomp
src/halo.cu
src/jaxdecomp.cc
src/grid_descriptor_mgr.cc
src/fft.cu
src/transpose.cu
)
set_target_properties(_jaxdecomp PROPERTIES CUDA_ARCHITECTURES "${CUDECOMP_CUDA_CC_LIST}")
target_include_directories(_jaxdecomp
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cuDecomp/include
${NVHPC_CUDA_INCLUDE_DIR}
${MPI_CXX_INCLUDE_DIRS}
${NVHPC_MATH_INCLUDE_DIR}
${NCCL_INCLUDE_DIR}
)
target_link_libraries(_jaxdecomp PRIVATE MPI::MPI_CXX)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUFFT)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA)
target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY})
target_link_libraries(_jaxdecomp PRIVATE cudecomp)
target_link_libraries(_jaxdecomp PRIVATE stdc++fs)
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)
target_compile_definitions(_jaxdecomp PRIVATE JD_CUDECOMP_BACKEND)
else()
add_subdirectory(pybind11)
pybind11_add_module(_jaxdecomp src/jaxdecomp.cc)
target_include_directories(_jaxdecomp PRIVATE ${CMAKE_CURRENT_LIST_DIR}/include)
target_compile_definitions(_jaxdecomp PRIVATE JD_JAX_BACKEND)
endif()
set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")
install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .)
# Code and Contribution guidelines
## Code formatting
Formatting is enforced using [yapf](https://github.com/google/yapf) and automatically applied using pre-commit hooks. To manually format the code, run the following command:
```shell
yapf -i -r .
```
However, we highly recommend using the pre-commit hooks to ensure consistent formatting across the codebase; see below.
### Pre-commit
We use pre-commit to enforce code formatting and quality standards. Follow these steps to install and use pre-commit:
1. Make sure you have Python installed on your system.
2. Install pre-commit by running the following command in your terminal:
```shell
pip install pre-commit
```
3. Navigate to the root directory of your project.
4. Run the following command to initialize pre-commit:
```shell
pre-commit install
```
5. Now, whenever you make a commit, pre-commit will automatically run the configured hooks on the files you modified. If any issues are found, pre-commit will prevent the commit from being made.
You can also manually run pre-commit on all files by running the following command:
```shell
pre-commit run --all-files
```
This is useful if you want to check all files before making a commit.
6. Customize the pre-commit configuration by creating a `.pre-commit-config.yaml` file in the root directory of your project. This file allows you to specify which hooks should be run and how they should be configured.
For more information on configuring pre-commit, refer to the [pre-commit documentation](https://pre-commit.com/#configuration).
That's it! You now have pre-commit set up to automatically enforce code formatting and quality standards in your project.
MIT License
Copyright (c) 2022 Differentiable Universe Initiative
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# jaxDecomp: JAX Library for 3D Domain Decomposition and Parallel FFTs
[![Code Formatting](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/formatting.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/formatting.yml)
JAX bindings for NVIDIA's [cuDecomp](https://nvidia.github.io/cuDecomp/index.html) library [(Romero et al. 2022)](https://dl.acm.org/doi/abs/10.1145/3539781.3539797), providing efficient **multi-node parallel FFTs and halo exchanges** built directly on low-level NCCL/CUDA-aware MPI, from your JAX code :tada:
## Usage
Here is an example of how to use `jaxDecomp` to perform a 3D FFT on a 3D array distributed across multiple GPUs. This example also includes a halo exchange, a common operation in many scientific computing applications.
```python
import jax
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

import jaxdecomp

# Initialize the library and optionally select a communication backend (defaults to NCCL)
jaxdecomp.config.update('halo_comm_backend', jaxdecomp.HALO_COMM_MPI)
jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)

# Initialize jax distributed to instruct jax local process which GPU to use
jax.distributed.initialize()
rank = jax.process_index()

# Setup a processor mesh (its size must match the number of processes)
pdims = (1, 4)
global_shape = [1024, 1024, 1024]

# Initialize an array with the expected global size
local_array = jax.random.normal(
    shape=[
        global_shape[0] // pdims[1], global_shape[1] // pdims[0],
        global_shape[2]
    ],
    key=jax.random.PRNGKey(rank))

# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
    local_array, mesh, P('z', 'y'))

# Forward FFT, note that the output FFT is transposed
@jax.jit
def modify_array(array):
  return 2 * array + 1

with mesh:
  # Forward FFT
  karray = jaxdecomp.fft.pfft3d(global_array)
  # Do some operation on your array
  karray = modify_array(karray)
  # Reverse FFT
  recarray = jaxdecomp.fft.pifft3d(karray).astype('float32')
  # Add halo regions to our array
  padding_width = ((32, 32), (32, 32), (32, 32))  # Has to be a tuple of tuples
  padded_array = jaxdecomp.slice_pad(recarray, padding_width, pdims)
  # Perform a halo exchange
  exchanged_array = jaxdecomp.halo_exchange(padded_array,
                                            halo_extents=(32, 32, 32),
                                            halo_periods=(True, True, True))
  # Remove the halo regions
  recarray = jaxdecomp.slice_unpad(exchanged_array, padding_width, pdims)

  # Gather the results (only if they fit in CPU memory)
  gathered_array = multihost_utils.process_allgather(recarray, tiled=True)

# Finalize the library
jax.distributed.shutdown()
```
**Note**: All these functions are jittable and have well-defined derivatives!
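For instance, a jitted and differentiable use of the distributed FFT looks like the following sketch, assuming the `global_array` and `mesh` defined in the example above (`loss_fn` here is just an illustrative scalar function, not part of the library):
```python
import jax.numpy as jnp

@jax.jit
def loss_fn(x):
  # Total power in Fourier space; pfft3d differentiates like any other JAX op
  return jnp.sum(jnp.abs(jaxdecomp.fft.pfft3d(x))**2)

with mesh:
  value = loss_fn(global_array)            # jitted, distributed forward pass
  grads = jax.grad(loss_fn)(global_array)  # gradient stays sharded like the input
```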
This script has to be run with one process per GPU, i.e. `pdims[0] x pdims[1]` processes in total (4 for the configuration above), with something like
```bash
$ mpirun -n 4 python demo.py
```
On an HPC cluster like Jean Zay, you should instead do:
```bash
$ srun python demo.py
```
Check the slurm [README](slurms/README.md) and [template](slurms/template.slurm) for more information on how to run on Jean Zay.
### Caveats
The code presented above should work, but there are a few caveats mentioned in [this issue](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/issues/1). If you need functionality that is not currently implemented, feel free to mention it on that issue.
## Install
Start by cloning this repository locally on your cluster:
```bash
$ git clone --recurse-submodules https://github.com/DifferentiableUniverseInitiative/jaxDecomp
```
#### Requirements
This install procedure assumes that the [NVIDIA HPC SDK](https://developer.nvidia.com/hpc-sdk) is available in your environment. You can either install it from the NVIDIA website, or better yet, it may be available as a module on your cluster.
Make sure all environment variables related to the SDK are properly set.
### Building jaxDecomp
From this directory, install & build jaxDecomp via pip
```bash
$ pip install --user .
```
If CMake complains that it cannot find the NVHPC SDK, you can manually specify the location
of the SDK's CMake files like so:
```bash
$ export CMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH:$NVCOMPILERS/$NVARCH/22.9/cmake
$ pip install --user .
```
### Install Notes for Specific Machines
#### IDRIS [Jean Zay](http://www.idris.fr/eng/jean-zay/cpu/jean-zay-cpu-hw-eng.html) HPE SGI 8600 supercomputer
As of September 2024, the following works.
You need to load the modules in exactly this order:
```bash
# Load NVHPC 23.9 because it has cuda 12.2
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
# Installing jax
pip install --upgrade "jax[cuda12]"
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=$NVHPC_ROOT/cmake # Not always needed
pip install .
```
#### NERSC [Perlmutter](https://docs.nersc.gov/systems/perlmutter/architecture/) HPE Cray EX supercomputer
As of Nov. 2022, the following works:
```bash
module load PrgEnv-nvhpc python
export CRAY_ACCEL_TARGET=nvidia80
# Installing mpi4py
MPICC="cc -target-accel=nvidia80 -shared" CC=nvc CFLAGS="-noswitcherror" pip install --force --no-cache-dir --no-binary=mpi4py mpi4py
# Installing jax
pip install --upgrade "jax[cuda12]"
# Installing jaxdecomp
export CMAKE_PREFIX_PATH=/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cmake
pip install .
```
## Design
Here is what works now:
```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

import jaxdecomp
from jaxdecomp.fft import pfft3, ipfft3

# Initialize jax distributed to instruct jax local process which GPU to use
jax.distributed.initialize()

pdims = (2, 4)
global_shape = (512, 512, 512)

local_array = jax.random.normal(shape=[global_shape[0] // pdims[0],
                                       global_shape[1] // pdims[1],
                                       global_shape[2]], key=jax.random.PRNGKey(0))

# Remap to the global array (this is a free call, no communication is happening)
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
    local_array, mesh, P('z', 'y'))

with mesh:
  z = pfft3(global_array)
  # If we could inspect the distribution of z, we would see that it is sliced in 2 along x, and 4 along y
  # This could also be part of a jitted function, no problem
  z_rec = ipfft3(z)
  # And z remains at all times distributed.

jaxdecomp.finalize()
jax.distributed.shutdown()
```
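Before shutting down, the distribution of the result can be checked with standard `jax.Array` attributes; a small sketch (these are plain JAX APIs, not part of jaxDecomp):
```python
print(z.sharding)                   # shows how z is partitioned over the ('z', 'y') mesh
print(z.addressable_data(0).shape)  # shape of the slice held locally by this process
```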
#### Backend configuration
We can set the default communication backend to use for cuDecomp operations either through a `config` module or through environment variables. This lets users choose the communication backend at startup (it can still be changed afterwards), making it possible to use CUDA-aware MPI or NVSHMEM as preferred.
Here is how it would look:
```python
jaxdecomp.config.update('transpose_comm_backend', 'NCCL')
# We could for instance time how long it takes to execute in this mode
%timeit pfft3(y)
# And then update the backend
jaxdecomp.config.update('transpose_comm_backend', 'MPI')
# And measure again
%timeit pfft3(y)
```
#### Autotune computational mesh
We can also make things fancier: since cuDecomp is able to autotune, we could use it to determine the best way to partition the data given the available GPUs, something like this:
```python
automesh = jaxdecomp.autotune(shape=[512,512,512])
# This is a JAX Sharding spec object, optimized for the given GPUs
# and shape of the tensor
sharding = PositionalSharding(automesh)
```
# Use-Case Examples
This directory contains examples of how to use the jaxDecomp library on a few use cases.
## Distributed LPT Cosmological Simulation
This example demonstrates the use of the 3D distributed FFT and halo exchange functions in the `jaxDecomp` library to implement a distributed LPT cosmological simulation. We provide a notebook to visualize the results of the simulation in [visualizer.ipynb](visualizer.ipynb).
To run the demo, some additional dependencies are required. You can install them by running:
```bash
pip install jax-cosmo
```
Then, you can run the example by executing the following command:
```bash
mpirun -n 4 python lpt_nbody_demo.py --nc 256 --box_size 256 --pdims 2x2 --halo_size 32 --output out
```
We also include an example of a slurm script in [submit_rusty.sbatch](submit_rusty.sbatch) that can be used to run the example on a slurm cluster with:
```bash
sbatch submit_rusty.sbatch
```
import argparse
import os
from functools import partial
from typing import Any, Callable, Hashable

Specs = Any
AxisName = Hashable

import jax

jax.config.update('jax_enable_x64', False)

import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jax._src import mesh as mesh_lib
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from scatter import scatter

import jaxdecomp


def shmap(f: Callable,
          in_specs: Specs,
          out_specs: Specs,
          check_rep: bool = True,
          auto: frozenset[AxisName] = frozenset()):
  """Helper function to create a shard_map function that extracts the mesh from the
  context."""
  mesh = mesh_lib.thread_resources.env.physical_mesh
  return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def _global_to_local_size(nc: int):
  """ Helper function to get the local size of a mesh given the global size.
  """
  pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
  return [nc // pdims[0], nc // pdims[1], nc]
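# Note: only the first two axes of the global mesh are split across the process
# grid; the last axis always stays local to each process.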
def fttk(nc: int) -> list:
  """
  Generate Fourier transform wave numbers for a given mesh.

  Args:
      nc (int): Shape of the mesh grid.

  Returns:
      list: List of wave number arrays for each dimension in
      the order [kx, ky, kz].
  """
  kd = np.fft.fftfreq(nc) * 2 * np.pi

  @partial(
      shmap,
      in_specs=(P('x'), P('y'), P(None)),
      out_specs=(P('x'), P(None, 'y'), P(None)))
  def get_kvec(ky, kz, kx):
    return (ky.reshape([-1, 1, 1]),
            kz.reshape([1, -1, 1]),
            kx.reshape([1, 1, -1]))  # yapf: disable

  ky, kz, kx = get_kvec(kd, kd, kd)  # The order of the output
  # corresponds to the order of dimensions in the transposed FFT
  # output
  return kx, ky, kz
def gravity_kernel(kx, ky, kz):
  """ Computes a Fourier kernel combining laplace and derivative
  operators to compute gravitational forces.

  Args:
      kvec (tuple of float): Wave numbers in Fourier space.

  Returns:
      tuple of jnp.ndarray: kernels for each dimension.
  """
  kk = kx**2 + ky**2 + kz**2
  laplace_kernel = jnp.where(kk == 0, 1., 1. / kk)

  grav_kernel = (laplace_kernel * 1j * kx,
                 laplace_kernel * 1j * ky,
                 laplace_kernel * 1j * kz)  # yapf: disable
  return grav_kernel


def gaussian_field_and_forces(key, nc, box_size, power_spectrum):
  """
  Generate a Gaussian field with a given power spectrum, along with gravitational forces.

  Args:
      key (int): Key for the random number generator.
      nc (int): Number of cells in the mesh.
      box_size (float): Size of the box.
      power_spectrum (callable): Power spectrum function.

  Returns:
      tuple of jnp.ndarray: The generated Gaussian field and the gravitational forces.
  """
  local_mesh_shape = _global_to_local_size(nc)

  # Create a distributed field drawn from a Gaussian distribution in real space
  delta = shmap(
      partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
      in_specs=P(None),
      out_specs=P('x', 'y'))(key)  # yapf: disable

  # Compute the Fourier transform of the field
  delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64))

  # Compute the Fourier wavenumbers of the field
  kx, ky, kz = fttk(nc)
  kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size)

  # Apply power spectrum to Fourier modes
  delta_k *= (power_spectrum(kk) * (nc / box_size)**3)**0.5

  # Compute inverse Fourier transform to recover the initial conditions in real space
  delta = jaxdecomp.fft.pifft3d(delta_k).real

  # Compute gravitational forces associated with this field
  grav_kernel = gravity_kernel(kx, ky, kz)
  forces_k = [g * delta_k for g in grav_kernel]

  # Retrieve the forces in real space by inverse Fourier transforming
  forces = jnp.stack([jaxdecomp.fft.pifft3d(f).real for f in forces_k], axis=-1)

  return delta, forces
def cic_paint(displacement, halo_size):
  """ Paints particles on a mesh using Cloud-In-Cell interpolation.

  Args:
      displacement (jnp.ndarray): Displacement of each particle.
      halo_size (int): Halo size for painting.

  Returns:
      jnp.ndarray: Density field.
  """
  local_mesh_shape = _global_to_local_size(displacement.shape[0])
  hs = halo_size

  @partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
  def cic_op(disp):
    """ CiC operation on each local slice of the mesh."""
    # Create a mesh to paint the particles on for the local slice
    mesh = jnp.zeros(disp.shape[:-1], dtype='float32')

    # Padding the mesh along the two first dimensions
    mesh = jnp.pad(mesh, [[hs, hs], [hs, hs], [0, 0]])

    # Compute the position of the particles on a regular grid
    pos_x, pos_y, pos_z = jnp.meshgrid(
        jnp.arange(local_mesh_shape[0]),
        jnp.arange(local_mesh_shape[1]),
        jnp.arange(local_mesh_shape[2]),
        indexing='ij')

    # adding an offset of size halo size
    pos = jnp.stack([pos_x + hs, pos_y + hs, pos_z], axis=-1)

    # Apply scatter operation to paint the particles on the local mesh
    field = scatter(pos.reshape([-1, 3]), disp.reshape([-1, 3]), mesh)

    return field

  # Performs painting on a padded mesh, with halos on the two first dimensions
  field = cic_op(displacement)

  # Run halo exchange to get the correct values at the boundaries
  field = jaxdecomp.halo_exchange(
      field,
      halo_extents=(hs // 2, hs // 2, 0),
      halo_periods=(True, True, True))

  @partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
  def unpad(x):
    """ Removes the padding and reduces the halo regions"""
    x = x.at[hs:hs + hs // 2].add(x[:hs // 2])
    x = x.at[-(hs + hs // 2):-hs].add(x[-hs // 2:])
    x = x.at[:, hs:hs + hs // 2].add(x[:, :hs // 2])
    x = x.at[:, -(hs + hs // 2):-hs].add(x[:, -hs // 2:])
    return x[hs:-hs, hs:-hs, :]

  # Unpad the output array
  field = unpad(field)

  return field
@partial(jax.jit, static_argnames=('nc', 'box_size', 'halo_size'))
def simulation_fn(key, nc, box_size, halo_size, a=1.0):
  """
  Run a simulation to generate initial conditions and density field using LPT.

  Args:
      key (list of int): Jax random key for the random number generator.
      nc (int): Size of the mesh grid.
      box_size (float): Size of the box.
      halo_size (int): Halo size for painting.
      a (float): Scale factor of final field.

  Returns:
      tuple of jnp.ndarray: Initial conditions and final density field.
  """
  # Build a default cosmology
  cosmo = jc.Planck15()

  # Create a small function to generate the linear matter power spectrum at arbitrary k
  k = jnp.logspace(-4, 1, 128)
  pk = jc.power.linear_matter_power(cosmo, k)
  pk_fn = jax.jit(lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).
                  reshape(x.shape))

  # Generate a Gaussian field and gravitational forces from a power spectrum
  initial_conditions, initial_forces = gaussian_field_and_forces(
      key=key, nc=nc, box_size=box_size, power_spectrum=pk_fn)

  # Compute the LPT displacement that particles initially placed on a regular grid
  # would experience at scale factor a, by simple Zeldovich approximation
  initial_displacement = jc.background.growth_factor(
      cosmo, jnp.atleast_1d(a)) * initial_forces

  # Paints the displaced particles on a mesh to obtain the density field
  final_field = cic_paint(initial_displacement, halo_size)

  return initial_conditions, final_field
def main(args):
  print(f"Running with arguments {args}")

  # Setting up distributed jax
  jax.distributed.initialize()
  rank = jax.process_index()
  size = jax.process_count()

  # Setting up distributed random numbers
  master_key = jax.random.PRNGKey(42)
  key = jax.random.split(master_key, size)[rank]

  # Create computing mesh and sharding information
  pdims = tuple(map(int, args.pdims.split('x')))
  devices = mesh_utils.create_device_mesh(pdims)
  mesh = Mesh(devices.T, axis_names=('x', 'y'))

  # Run the simulation on the compute mesh
  with mesh:
    initial_conds, final_field = simulation_fn(
        key=key, nc=args.nc, box_size=args.box_size, halo_size=args.halo_size)

  # Create output directory to save the results
  output_dir = args.output
  os.makedirs(output_dir, exist_ok=True)
  np.save(f'{output_dir}/initial_conditions_{rank}.npy',
          initial_conds.addressable_data(0))
  np.save(f'{output_dir}/field_{rank}.npy', final_field.addressable_data(0))
  print(f"Finished, saved to {output_dir}")

  # Closing distributed jax
  jax.distributed.shutdown()


if __name__ == '__main__':
  parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
  parser.add_argument(
      '--pdims', type=str, default='1x1', help="Processor grid dimensions")
  parser.add_argument(
      '--nc', type=int, default=256, help="Number of cells in the mesh")
  parser.add_argument(
      '--box_size', type=float, default=512., help="Box size in Mpc/h")
  parser.add_argument(
      '--halo_size', type=int, default=32, help="Halo size for painting")
  parser.add_argument('--output', type=str, default='out')

  args = parser.parse_args()
  main(args)
# This file is adapted from the scatter implementation of the pmwd library
# https://github.com/eelregit/pmwd/blob/master/pmwd/scatter.py
# It provides a simple way to perform a scatter operation by chunks and saves
# memory compared to a native jax.lax.scatter.
# Below is the original license of the pmwd library:
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2021, the pmwd developers
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import jax.numpy as jnp
from jax.lax import scan


def _chunk_split(ptcl_num, chunk_size, *arrays):
  """Split and reshape particle arrays into chunks and remainders, with the remainders
  preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
  chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
  remainder_size = ptcl_num % chunk_size
  chunk_num = ptcl_num // chunk_size

  remainder = None
  chunks = arrays
  if remainder_size:
    remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
    chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]

  # `scan` triggers errors in scatter and gather without the `full`
  chunks = [
      x.reshape(chunk_num, chunk_size, *x.shape[1:])
      if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
  ]

  return remainder, chunks
def enmesh(i1, d1, a1, s1, b12, a2, s2):
  """Multilinear enmeshing."""
  i1 = jnp.asarray(i1)
  d1 = jnp.asarray(d1)
  a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
  if s1 is not None:
    s1 = jnp.array(s1, dtype=i1.dtype)
  b12 = jnp.float64(b12)
  if a2 is not None:
    a2 = jnp.float64(a2)
  if s2 is not None:
    s2 = jnp.array(s2, dtype=i1.dtype)

  dim = i1.shape[1]
  neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >> jnp.arange(
      dim, dtype=i1.dtype)) & 1

  if a2 is not None:
    P = i1 * a1 + d1 - b12
    P = P[:, jnp.newaxis]  # insert neighbor axis
    i2 = P + neighbors * a2  # multilinear

    if s1 is not None:
      L = s1 * a1
      i2 %= L

    i2 //= a2
    d2 = P - i2 * a2

    if s1 is not None:
      d2 -= jnp.rint(d2 / L) * L  # also abs(d2) < a2 is expected

    i2 = i2.astype(i1.dtype)
    d2 = d2.astype(d1.dtype)
    a2 = a2.astype(d1.dtype)

    d2 /= a2
  else:
    i12, d12 = jnp.divmod(b12, a1)
    i1 -= i12.astype(i1.dtype)
    d1 -= d12.astype(d1.dtype)

    # insert neighbor axis
    i1 = i1[:, jnp.newaxis]
    d1 = d1[:, jnp.newaxis]

    # multilinear
    d1 /= a1
    i2 = jnp.floor(d1).astype(i1.dtype)
    i2 += neighbors
    d2 = d1 - i2
    i2 += i1

    if s1 is not None:
      i2 %= s1

  f2 = 1 - jnp.abs(d2)

  if s1 is None and s2 is not None:  # all i2 >= 0 if s1 is not None
    i2 = jnp.where(i2 < 0, s2, i2)

  f2 = f2.prod(axis=-1)

  return i2, f2
def scatter(pmid, disp, mesh, chunk_size=2**24, val=1., offset=0, cell_size=1.):
  ptcl_num, spatial_ndim = pmid.shape
  val = jnp.asarray(val)
  mesh = jnp.asarray(mesh)

  remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
  carry = mesh, offset, cell_size
  if remainder is not None:
    carry = _scatter_chunk(carry, remainder)[0]
  carry = scan(_scatter_chunk, carry, chunks)[0]
  mesh = carry[0]

  return mesh


def _scatter_chunk(carry, chunk):
  mesh, offset, cell_size = carry
  pmid, disp, val = chunk
  spatial_ndim = pmid.shape[1]
  spatial_shape = mesh.shape

  # multilinear mesh indices and fractions
  ind, frac = enmesh(pmid, disp, cell_size, spatial_shape, offset, cell_size,
                     spatial_shape)
  # scatter
  ind = tuple(ind[..., i] for i in range(spatial_ndim))
  mesh = mesh.at[ind].add(val * frac)

  carry = mesh, offset, cell_size
  return carry, None
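# Example usage (illustrative only): paint three particles, each with zero
# displacement, onto an 8^3 mesh; the resulting density sums to 3.
#
#   pos = jnp.array([[0, 0, 0], [3, 4, 5], [7, 7, 7]])   # integer cell indices
#   disp = jnp.zeros((3, 3))                             # CIC displacements
#   density = scatter(pos, disp, jnp.zeros((8, 8, 8)))   # density.sum() == 3.0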
#!/bin/bash -l
#SBATCH -p gpu
#SBATCH -t 0:10:00
#SBATCH -C a100
#SBATCH -N 1
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=16
#SBATCH --gpus-per-task=1
module load modules/2.3
module load gcc nvhpc python
source ~/venvs/jaxdecomp724/bin/activate
export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
export LD_LIBRARY_PATH=$NVHPC_ROOT/Linux_x86_64/24.3/cuda/12.3/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$NVHPC_ROOT/Linux_x86_64/24.3/cuda/12.3/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$NVHPC_ROOT/Linux_x86_64/24.3/comm_libs/nccl/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$NVHPC_ROOT/Linux_x86_64/24.3/math_libs/lib64:$LD_LIBRARY_PATH
mpirun python3 lpt_nbody_demo.py --pdims 2x2
#ifndef _JAX_DECOMP_CHECKS_H_
#define _JAX_DECOMP_CHECKS_H_
#include <cstdio>
typedef long HRESULT;
using namespace std;
#define SUCCEEDED(hr) (((HRESULT)(hr)) >= 0)
#define FAILED(hr) (((HRESULT)(hr)) < 0)
#define S_OK ((HRESULT)0x00000000L)
#define S_FALSE ((HRESULT)1L)
#define E_ABORT ((HRESULT)0x80004004L)
#define E_ACCESSDENIED ((HRESULT)0x80070005L)
#define E_FAIL ((HRESULT)0x80004005L)
#define E_HANDLE ((HRESULT)0x80070006L)
#define E_INVALIDARG ((HRESULT)0x80070057L)
#define E_NOINTERFACE ((HRESULT)0x80004002L)
#define E_NOTIMPL ((HRESULT)0x80004001L)
#define E_OUTOFMEMORY ((HRESULT)0x8007000EL)
#define E_POINTER ((HRESULT)0x80004003L)
#define E_UNEXPECTED ((HRESULT)0x8000FFFFL)
#define E_OUTOFMEMORY ((HRESULT)0x8007000EL)
#define E_NOTIMPL ((HRESULT)0x80004001L)
// Macro to check for CUDA errors
#define CHECK_CUDA_EXIT(call) \
do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
printf("CUDA error at %s %d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
} while (0)
// Error checking macros
#define CHECK_CUDECOMP_EXIT(call) \
do { \
cudecompResult_t err = call; \
if (CUDECOMP_RESULT_SUCCESS != err) { \
fprintf(stderr, "%s:%d CUDECOMP error. (error code %d)\n", __FILE__, __LINE__, err); \
throw exception(); \
} \
} while (false)
#define CHECK_CUFFT_EXIT(call) \
do { \
cufftResult_t err = call; \
if (CUFFT_SUCCESS != err) { \
fprintf(stderr, "%s:%d CUFFT error. (error code %d)\n", __FILE__, __LINE__, err); \
throw exception(); \
} \
} while (false)
#define CHECK_MPI_EXIT(call) \
{ \
int err = call; \
if (0 != err) { \
char error_str[MPI_MAX_ERROR_STRING]; \
int len; \
MPI_Error_string(err, error_str, &len); \
if (error_str) { \
fprintf(stderr, "%s:%d MPI error. (%s)\n", __FILE__, __LINE__, error_str); \
} else { \
fprintf(stderr, "%s:%d MPI error. (error code %d)\n", __FILE__, __LINE__, err); \
} \
exit(EXIT_FAILURE); \
} \
} \
while (false)
#define HR2STR(hr) \
((hr == S_OK) ? "S_OK" \
: (hr == S_FALSE) ? "S_FALSE" \
: (hr == E_ABORT) ? "E_ABORT" \
: (hr == E_ACCESSDENIED) ? "E_ACCESSDENIED" \
: (hr == E_FAIL) ? "E_FAIL" \
: (hr == E_HANDLE) ? "E_HANDLE" \
: (hr == E_INVALIDARG) ? "E_INVALIDARG" \
: (hr == E_NOINTERFACE) ? "E_NOINTERFACE" \
: (hr == E_NOTIMPL) ? "E_NOTIMPL" \
: (hr == E_OUTOFMEMORY) ? "E_OUTOFMEMORY" \
: (hr == E_POINTER) ? "E_POINTER" \
: (hr == E_UNEXPECTED) ? "E_UNEXPECTED" \
: "Unknown HRESULT")
#endif // _JAX_DECOMP_CHECKS_H_
#ifndef _JAX_DECOMP_FFT_H_
#define _JAX_DECOMP_FFT_H_
#include "checks.h"
#include "logger.hpp"
#include <array>
#include <cmath> // has to be included before cuda/std/complex
#include <cstddef>
#include <cstdio>
#include <cuda/std/complex>
#include <cudecomp.h>
#include <cufftXt.h>
#include <mpi.h>
static bool get_double_precision(float) { return false; }
static bool get_double_precision(double) { return true; }
static cufftType get_cufft_type_r2c(double) { return CUFFT_D2Z; }
static cufftType get_cufft_type_r2c(float) { return CUFFT_R2C; }
static cufftType get_cufft_type_c2r(double) { return CUFFT_Z2D; }
static cufftType get_cufft_type_c2r(float) { return CUFFT_C2R; }
static cufftType get_cufft_type_c2c(double) { return CUFFT_Z2Z; }
static cufftType get_cufft_type_c2c(float) { return CUFFT_C2C; }
static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex<float>) { return CUDECOMP_FLOAT_COMPLEX; }
static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex<double>) { return CUDECOMP_DOUBLE_COMPLEX; }
namespace jaxdecomp {
enum Decomposition { slab_XY = 0, slab_YZ = 1, pencil = 2, no_decomp = 3 };
static Decomposition GetDecomposition(const int pdims[2]) {
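// e.g. pdims = {1, 4} -> slab_XY, {4, 1} -> slab_YZ, {2, 4} -> pencil, {1, 1} -> no_decomp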
if (pdims[0] == 1 && pdims[1] > 1) {
return Decomposition::slab_XY;
} else if (pdims[0] > 1 && pdims[1] == 1) {
return Decomposition::slab_YZ;
} else if (pdims[0] > 1 && pdims[1] > 1) {
return Decomposition::pencil;
}
return Decomposition::no_decomp;
}
// fftDescriptor hash should be trivially computable
// because it contains only bools and integers
class fftDescriptor {
public:
bool adjoint = false;
bool contiguous = true;
bool forward = true; ///< forward or backward pass
// fft_is_forward_pass and forward are used for the execution but not for the
// hash. This way IFFT and FFT have the same plans when operating with the same
// grid and pdims
int32_t gdims[3]; ///< dimensions of global data grid
// Decomposition type is used in order to allow reusing plans
// from the XY and XZ forward pass for ZY and YZ backward pass respectively
Decomposition decomposition = Decomposition::no_decomp; ///< decomposition type
bool double_precision = false;
cudecompGridDescConfig_t config; // Descriptor for the grid
// To make it trivially copyable
fftDescriptor() = default;
fftDescriptor(const fftDescriptor& other) = default;
fftDescriptor& operator=(const fftDescriptor& other) = default;
// Create a compare operator to be used in the unordered_map (a hash is also
// created in the bottom of the file)
bool operator==(const fftDescriptor& other) const {
if (double_precision != other.double_precision || gdims[0] != other.gdims[0] || gdims[1] != other.gdims[1] ||
gdims[2] != other.gdims[2] || decomposition != other.decomposition || contiguous != other.contiguous) {
return false;
}
return true;
}
~fftDescriptor() = default;
// Initialize the descriptor from the grid configuration
// this is used for subsequent ffts to find the Executor that was already
// defined
fftDescriptor(cudecompGridDescConfig_t& config, const bool& double_precision, const bool& iForward,
const bool& iAdjoint, const bool& iContiguous, const Decomposition& iDecomposition)
: double_precision(double_precision), config(config), forward(iForward), contiguous(iContiguous),
adjoint(iAdjoint), decomposition(iDecomposition) {
gdims[0] = config.gdims[0];
gdims[1] = config.gdims[1];
gdims[2] = config.gdims[2];
this->config.transpose_axis_contiguous[0] = iContiguous;
this->config.transpose_axis_contiguous[1] = iContiguous;
this->config.transpose_axis_contiguous[2] = iContiguous;
}
};
template <typename real_t> class FourierExecutor {
using complex_t = cuda::std::complex<real_t>;
// Allow the Manager to access the private members in order to destroy the
// GridDesc
friend class GridDescriptorManager;
public:
FourierExecutor() : m_Tracer("JAXDECOMP") {}
~FourierExecutor();
HRESULT Initialize(cudecompHandle_t handle, size_t& work_size, fftDescriptor& fft_descriptor);
HRESULT forward(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, void** buffers);
HRESULT backward(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, void** buffers);
private:
AsyncLogger m_Tracer;
// GridDesc exists in GridConfig, but I rather store this way because of the C
// struct definition that cuDecomp team chose to do
// typedef struct cudecompGridDesc* cudecompGridDesc_t;
// This produces a warning of incomplete type in the forward declaration and
// To avoid this warning and having to include internal/common.h (which is not
// C++20 compliant) I chose to store the cudecompGridDescConfig_t
cudecompGridDesc_t m_GridConfig;
cudecompGridDescConfig_t m_GridDescConfig;
cudecompPencilInfo_t m_XPencilInfo;
cudecompPencilInfo_t m_YPencilInfo;
cudecompPencilInfo_t m_ZPencilInfo;
// For the sake of expressive code, plans have the name of their corresponding
// goal Instead of reusing pencils plans for slabs, or even ZY to YZ we store
// properly named plans
// For Pencils
cufftHandle m_Plan_c2c_x;
cufftHandle m_Plan_c2c_y;
cufftHandle m_Plan_c2c_z;
// For Slabs XY FFT (Y) FFT(XZ) but JAX redefines the axes to YZX as X pencil for cudecomp
// so in the end it is FFT(X) FFT(YZ)
// For Slabs XZ FFT (X) FFT(YZ)
cufftHandle m_Plan_c2c_yz;
cufftHandle m_Plan_c2c_xy;
// work size
int64_t m_WorkSize;
// Internal functions
HRESULT InitializePencils(int64_t& work_size, fftDescriptor& fft_descriptor);
HRESULT InitializeSlabXY(int64_t& work_size, fftDescriptor& fft_descriptor);
HRESULT InitializeSlabYZ(int64_t& work_size, fftDescriptor& fft_descriptor);
HRESULT forwardXY(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);
HRESULT backwardXY(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);
HRESULT forwardYZ(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);
HRESULT backwardYZ(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);
HRESULT forwardPencil(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);
HRESULT backwardPencil(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);
HRESULT clearPlans();
};
} // namespace jaxdecomp
namespace std {
template <> struct hash<jaxdecomp::fftDescriptor> {
std::size_t operator()(const jaxdecomp::fftDescriptor& descriptor) const {
// Only hash The double precision and the gdims and pdims
// If adjoint is changed then the plan should be the same
// adjoint is to be used to execute the backward plan
size_t hash = std::hash<double>()(descriptor.double_precision) ^ std::hash<int>()(descriptor.gdims[0]) ^
std::hash<int>()(descriptor.gdims[1]) ^ std::hash<int>()(descriptor.gdims[2]) ^
std::hash<int>()(descriptor.decomposition) ^ std::hash<bool>()(descriptor.contiguous);
return hash;
}
};
} // namespace std
#endif