Commits on Source (2). Showing changes with 2398 additions and 2164 deletions.
name: Build and upload to PyPI
on:
workflow_dispatch:
pull_request:
push:
branches:
- main
release:
types:
- published
types: [published]
jobs:
build_wheels:
@@ -25,7 +19,7 @@ jobs:
- name: Build wheels
uses: pypa/cibuildwheel@v2.21.3
env:
CIBW_BUILD: "cp310-* cp311-* cp312-*"
CIBW_BUILD: "cp310-manylinux_x86_64 cp311-manylinux_x86_64 cp312-manylinux_x86_64"
CIBW_BUILD_VERBOSITY: 2
- uses: actions/upload-artifact@v4
with:
......
@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.10.4]
python-version: ["3.10" , "3.11" , "3.12"]
steps:
- name: Checkout Source
......
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
rev: 'v5.0.0'
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/google/yapf
rev: v0.40.2
hooks:
- id: yapf
args: ['--parallel', '--in-place']
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.2
hooks:
- id: ruff-format
- id: ruff
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.4
hooks:
......
@@ -94,4 +94,4 @@ else()
endif()
set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")
install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .)
install(TARGETS _jaxdecomp LIBRARY DESTINATION jaxdecomplib PUBLIC_HEADER DESTINATION jaxdecomplib)
# jaxDecomp: JAX Library for 3D Domain Decomposition and Parallel FFTs
[![Build](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/github-deploy.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/github-deploy.yml)
[![Code Formatting](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/formatting.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/formatting.yml)
[![Tests](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/tests.yml/badge.svg)](https://github.com/DifferentiableUniverseInitiative/jaxDecomp/actions/workflows/tests.yml)
[![MIT License](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
......
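As a quick orientation (added for context here, not part of the README text shown above), the public API mirrors the benchmark script further down in this diff. The sketch below is illustrative only and assumes one JAX process per GPU and a 2x2 processor grid:

```python
# Illustrative sketch; follows the FFT benchmark script shown later in this diff.
import jax
import jaxdecomp
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

jax.distributed.initialize()  # one process per GPU, e.g. under SLURM or mpirun

pdims = (2, 2)
global_shape = (512, 512, 512)

# Each process holds only its local slice of the global array
local = jax.random.normal(
    jax.random.PRNGKey(0),
    (global_shape[0] // pdims[1], global_shape[1] // pdims[0], global_shape[2]),
)
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=("z", "y"))
global_array = multihost_utils.host_local_array_to_global_array(local, mesh, P("z", "y"))

with mesh:
    k_array = jaxdecomp.fft.pfft3d(global_array)  # distributed forward FFT
    rec = jaxdecomp.fft.pifft3d(k_array).real     # distributed inverse FFT
```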
import argparse
import os
from functools import partial
from typing import Any, Callable, Hashable
from typing import Any, Callable
from collections.abc import Hashable
Specs = Any
AxisName = Hashable
import jax
jax.config.update('jax_enable_x64', False)
jax.config.update("jax_enable_x64", False)
import jax.numpy as jnp
import jax_cosmo as jc
@@ -23,26 +24,27 @@ from scatter import scatter
import jaxdecomp
def shmap(f: Callable,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = True,
auto: frozenset[AxisName] = frozenset()):
"""Helper function to create a shard_map function that extracts the mesh from the
def shmap(
f: Callable,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = True,
auto: frozenset[AxisName] = frozenset(),
):
"""Helper function to create a shard_map function that extracts the mesh from the
context."""
mesh = mesh_lib.thread_resources.env.physical_mesh
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
mesh = mesh_lib.thread_resources.env.physical_mesh
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def _global_to_local_size(nc: int):
""" Helper function to get the local size of a mesh given the global size.
"""
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
return [nc // pdims[0], nc // pdims[1], nc]
"""Helper function to get the local size of a mesh given the global size."""
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
return [nc // pdims[0], nc // pdims[1], nc]
def fttk(nc: int):
"""
"""
Generate Fourier transform wave numbers for a given mesh.
Args:
@@ -51,25 +53,27 @@ def fttk(nc: int):
Returns:
list: List of wave number arrays for each dimension in
the order [kx, ky, kz].
"""
kd = np.fft.fftfreq(nc) * 2 * np.pi
@partial(
shmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)))
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
ky, kz, kx = get_kvec(kd, kd, kd) # The order of the output
# corresponds to the order of dimensions in the transposed FFT
# output
return kx, ky, kz
"""
kd = np.fft.fftfreq(nc) * 2 * np.pi
@partial(
shmap,
in_specs=(P("x"), P("y"), P(None)),
out_specs=(P("x"), P(None, "y"), P(None)),
)
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
ky, kz, kx = get_kvec(kd, kd, kd) # The order of the output
# corresponds to the order of dimensions in the transposed FFT
# output
return kx, ky, kz
def gravity_kernel(kx, ky, kz):
""" Computes a Fourier kernel combining laplace and derivative
"""Computes a Fourier kernel combining laplace and derivative
operators to compute gravitational forces.
Args:
@@ -77,18 +81,18 @@ def gravity_kernel(kx, ky, kz):
Returns:
tuple of jnp.ndarray: kernels for each dimension.
"""
kk = kx**2 + ky**2 + kz**2
laplace_kernel = jnp.where(kk == 0, 1., 1. / kk)
"""
kk = kx**2 + ky**2 + kz**2
laplace_kernel = jnp.where(kk == 0, 1.0, 1.0 / kk)
grav_kernel = (laplace_kernel * 1j * kx,
laplace_kernel * 1j * ky,
laplace_kernel * 1j * kz) # yapf: disable
return grav_kernel
grav_kernel = (laplace_kernel * 1j * kx,
laplace_kernel * 1j * ky,
laplace_kernel * 1j * kz) # yapf: disable
return grav_kernel
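# Added note (not in the original script): in Fourier space the Poisson equation gives
# phi_k = -delta_k / k^2 (up to constants), so the force along axis j is
# F_j(k) = -i * k_j * phi_k = (i * k_j / k^2) * delta_k, which is exactly the
# laplace_kernel * 1j * k_j factor returned above; the k == 0 mode is clamped to 1
# so the mean mode does not divide by zero.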
def gaussian_field_and_forces(key, nc, box_size, power_spectrum):
"""
"""
Generate a Gaussian field with a given power spectrum, along with gravitational forces.
Args:
@@ -99,40 +103,40 @@ def gaussian_field_and_forces(key, nc, box_size, power_spectrum):
Returns:
tuple of jnp.ndarray: The generated Gaussian field and the gravitational forces.
"""
local_mesh_shape = _global_to_local_size(nc)
"""
local_mesh_shape = _global_to_local_size(nc)
# Create a distributed field drawn from a Gaussian distribution in real space
delta = shmap(
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
in_specs=P(None),
out_specs=P('x', 'y'))(key) # yapf: disable
# Create a distributed field drawn from a Gaussian distribution in real space
delta = shmap(
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
in_specs=P(None),
out_specs=P('x', 'y'))(key) # yapf: disable
# Compute the Fourier transform of the field
delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64))
# Compute the Fourier transform of the field
delta_k = jaxdecomp.fft.pfft3d(delta.astype(jnp.complex64))
# Compute the Fourier wavenumbers of the field
kx, ky, kz = fttk(nc)
kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size)
# Compute the Fourier wavenumbers of the field
kx, ky, kz = fttk(nc)
kk = jnp.sqrt(kx**2 + ky**2 + kz**2) * (nc / box_size)
# Apply power spectrum to Fourier modes
delta_k *= (power_spectrum(kk) * (nc / box_size)**3)**0.5
# Apply power spectrum to Fourier modes
delta_k *= (power_spectrum(kk) * (nc / box_size) ** 3) ** 0.5
# Compute inverse Fourier transform to recover the initial conditions in real space
delta = jaxdecomp.fft.pifft3d(delta_k).real
# Compute inverse Fourier transform to recover the initial conditions in real space
delta = jaxdecomp.fft.pifft3d(delta_k).real
# Compute gravitational forces associated with this field
grav_kernel = gravity_kernel(kx, ky, kz)
forces_k = [g * delta_k for g in grav_kernel]
# Compute gravitational forces associated with this field
grav_kernel = gravity_kernel(kx, ky, kz)
forces_k = [g * delta_k for g in grav_kernel]
# Retrieve the forces in real space by inverse Fourier transforming
forces = jnp.stack([jaxdecomp.fft.pifft3d(f).real for f in forces_k], axis=-1)
# Retrieve the forces in real space by inverse Fourier transforming
forces = jnp.stack([jaxdecomp.fft.pifft3d(f).real for f in forces_k], axis=-1)
return delta, forces
return delta, forces
def cic_paint(displacement, halo_size):
""" Paints particles on a mesh using Cloud-In-Cell interpolation.
"""Paints particles on a mesh using Cloud-In-Cell interpolation.
Args:
displacement (jnp.ndarray): Displacement of each particle.
@@ -140,60 +144,58 @@ def cic_paint(displacement, halo_size):
Returns:
jnp.ndarray: Density field.
"""
local_mesh_shape = _global_to_local_size(displacement.shape[0])
hs = halo_size
@partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
def cic_op(disp):
""" CiC operation on each local slice of the mesh."""
# Create a mesh to paint the particles on for the local slice
mesh = jnp.zeros(disp.shape[:-1], dtype='float32')
# Padding the mesh along the two first dimensions
mesh = jnp.pad(mesh, [[hs, hs], [hs, hs], [0, 0]])
# Compute the position of the particles on a regular grid
pos_x, pos_y, pos_z = jnp.meshgrid(
jnp.arange(local_mesh_shape[0]),
jnp.arange(local_mesh_shape[1]),
jnp.arange(local_mesh_shape[2]),
indexing='ij')
# adding an offset of size halo size
pos = jnp.stack([pos_x + hs, pos_y + hs, pos_z], axis=-1)
# Apply scatter operation to paint the particles on the local mesh
field = scatter(pos.reshape([-1, 3]), disp.reshape([-1, 3]), mesh)
"""
local_mesh_shape = _global_to_local_size(displacement.shape[0])
hs = halo_size
@partial(shmap, in_specs=(P("x", "y"),), out_specs=P("x", "y"))
def cic_op(disp):
"""CiC operation on each local slice of the mesh."""
# Create a mesh to paint the particles on for the local slice
mesh = jnp.zeros(disp.shape[:-1], dtype="float32")
# Padding the mesh along the two first dimensions
mesh = jnp.pad(mesh, [[hs, hs], [hs, hs], [0, 0]])
# Compute the position of the particles on a regular grid
pos_x, pos_y, pos_z = jnp.meshgrid(
jnp.arange(local_mesh_shape[0]),
jnp.arange(local_mesh_shape[1]),
jnp.arange(local_mesh_shape[2]),
indexing="ij",
)
# adding an offset of size halo size
pos = jnp.stack([pos_x + hs, pos_y + hs, pos_z], axis=-1)
# Apply scatter operation to paint the particles on the local mesh
field = scatter(pos.reshape([-1, 3]), disp.reshape([-1, 3]), mesh)
return field
# Performs painting on a padded mesh, with halos on the two first dimensions
field = cic_op(displacement)
# Run halo exchange to get the correct values at the boundaries
field = jaxdecomp.halo_exchange(field, halo_extents=(hs // 2, hs // 2, 0), halo_periods=(True, True, True))
@partial(shmap, in_specs=(P("x", "y"),), out_specs=P("x", "y"))
def unpad(x):
"""Removes the padding and reduce the halo regions"""
x = x.at[hs : hs + hs // 2].add(x[: hs // 2])
x = x.at[-(hs + hs // 2) : -hs].add(x[-hs // 2 :])
x = x.at[:, hs : hs + hs // 2].add(x[:, : hs // 2])
x = x.at[:, -(hs + hs // 2) : -hs].add(x[:, -hs // 2 :])
return x[hs:-hs, hs:-hs, :]
# Unpad the output array
field = unpad(field)
return field
# Performs painting on a padded mesh, with halos on the two first dimensions
field = cic_op(displacement)
# Run halo exchange to get the correct values at the boundaries
field = jaxdecomp.halo_exchange(
field,
halo_extents=(hs // 2, hs // 2, 0),
halo_periods=(True, True, True))
@partial(shmap, in_specs=(P('x', 'y'),), out_specs=P('x', 'y'))
def unpad(x):
""" Removes the padding and reduce the halo regions"""
x = x.at[hs:hs + hs // 2].add(x[:hs // 2])
x = x.at[-(hs + hs // 2):-hs].add(x[-hs // 2:])
x = x.at[:, hs:hs + hs // 2].add(x[:, :hs // 2])
x = x.at[:, -(hs + hs // 2):-hs].add(x[:, -hs // 2:])
return x[hs:-hs, hs:-hs, :]
# Unpad the output array
field = unpad(field)
return field
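# Added note (illustrative, not in the original): reading the slicing above for hs = 4,
# each local slab is padded by 4 cells on both sides of the first two axes, the halo
# exchange shares hs // 2 = 2 cells with each neighbour, and unpad folds those 2 cells
# back onto the interior rows [4:6] and [-6:-4] before cropping the 4-cell pad, so mass
# deposited across a slab boundary is counted exactly once.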
@partial(jax.jit, static_argnames=('nc', 'box_size', 'halo_size'))
@partial(jax.jit, static_argnames=("nc", "box_size", "halo_size"))
def simulation_fn(key, nc, box_size, halo_size, a=1.0):
"""
"""
Run a simulation to generate initial conditions and density field using LPT.
Args:
@@ -205,76 +207,67 @@ def simulation_fn(key, nc, box_size, halo_size, a=1.0):
Returns:
tuple of jnp.ndarray: Initial conditions and final density field.
"""
# Build a default cosmology
cosmo = jc.Planck15()
"""
# Build a default cosmology
cosmo = jc.Planck15()
# Create a small function to generate the linear matter power spectrum at arbitrary k
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
pk_fn = jax.jit(lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).
reshape(x.shape))
# Create a small function to generate the linear matter power spectrum at arbitrary k
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
pk_fn = jax.jit(lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape))
# Generate a Gaussian field and gravitational forces from a power spectrum
initial_conditions, initial_forces = gaussian_field_and_forces(
key=key, nc=nc, box_size=box_size, power_spectrum=pk_fn)
# Generate a Gaussian field and gravitational forces from a power spectrum
initial_conditions, initial_forces = gaussian_field_and_forces(key=key, nc=nc, box_size=box_size, power_spectrum=pk_fn)
# Compute the LPT displacement that particles initially placed on a regular grid
# would experience at scale factor a, by simple Zeldovich approximation
initial_displacement = jc.background.growth_factor(
cosmo, jnp.atleast_1d(a)) * initial_forces
# Compute the LPT displacement that particles initially placed on a regular grid
# would experience at scale factor a, by simple Zeldovich approximation
initial_displacement = jc.background.growth_factor(cosmo, jnp.atleast_1d(a)) * initial_forces
# Paints the displaced particles on a mesh to obtain the density field
final_field = cic_paint(initial_displacement, halo_size)
# Paints the displaced particles on a mesh to obtain the density field
final_field = cic_paint(initial_displacement, halo_size)
return initial_conditions, final_field
return initial_conditions, final_field
def main(args):
print(f"Running with arguments {args}")
# Setting up distributed jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
# Setting up distributed random numbers
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]
# Create computing mesh and sharding information
pdims = tuple(map(int, args.pdims.split('x')))
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
# Run the simulation on the compute mesh
with mesh:
initial_conds, final_field = simulation_fn(
key=key, nc=args.nc, box_size=args.box_size, halo_size=args.halo_size)
# Create output directory to save the results
output_dir = args.output
os.makedirs(output_dir, exist_ok=True)
np.save(f'{output_dir}/initial_conditions_{rank}.npy',
initial_conds.addressable_data(0))
np.save(f'{output_dir}/field_{rank}.npy', final_field.addressable_data(0))
print(f"Finished saved to {output_dir}")
# Closing distributed jax
jax.distributed.shutdown()
if __name__ == '__main__':
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
parser.add_argument(
'--pdims', type=str, default='1x1', help="Processor grid dimensions")
parser.add_argument(
'--nc', type=int, default=256, help="Number of cells in the mesh")
parser.add_argument(
'--box_size', type=float, default=512., help="Box size in Mpc/h")
parser.add_argument(
'--halo_size', type=int, default=32, help="Halo size for painting")
parser.add_argument('--output', type=str, default='out')
args = parser.parse_args()
main(args)
print(f"Running with arguments {args}")
# Setting up distributed jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
# Setting up distributed random numbers
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]
# Create computing mesh and sharding information
pdims = tuple(map(int, args.pdims.split("x")))
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=("x", "y"))
# Run the simulation on the compute mesh
with mesh:
initial_conds, final_field = simulation_fn(key=key, nc=args.nc, box_size=args.box_size, halo_size=args.halo_size)
# Create output directory to save the results
output_dir = args.output
os.makedirs(output_dir, exist_ok=True)
np.save(f"{output_dir}/initial_conditions_{rank}.npy", initial_conds.addressable_data(0))
np.save(f"{output_dir}/field_{rank}.npy", final_field.addressable_data(0))
print(f"Finished saved to {output_dir}")
# Closing distributed jax
jax.distributed.shutdown()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
parser.add_argument("--pdims", type=str, default="1x1", help="Processor grid dimensions")
parser.add_argument("--nc", type=int, default=256, help="Number of cells in the mesh")
parser.add_argument("--box_size", type=float, default=512.0, help="Box size in Mpc/h")
parser.add_argument("--halo_size", type=int, default=32, help="Halo size for painting")
parser.add_argument("--output", type=str, default="out")
args = parser.parse_args()
main(args)
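# Added note (illustrative): the script expects one JAX process per GPU; a typical
# launch looks something like `mpirun -n 4 python lpt_script.py --pdims 2x2 --nc 256`
# (or the srun equivalent under SLURM). `lpt_script.py` is a placeholder name, since
# the actual file path is not shown in this diff.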
@@ -38,121 +38,116 @@ from jax.lax import scan
def _chunk_split(ptcl_num, chunk_size, *arrays):
"""Split and reshape particle arrays into chunks and remainders, with the remainders
"""Split and reshape particle arrays into chunks and remainders, with the remainders
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
remainder_size = ptcl_num % chunk_size
chunk_num = ptcl_num // chunk_size
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
remainder_size = ptcl_num % chunk_size
chunk_num = ptcl_num // chunk_size
remainder = None
chunks = arrays
if remainder_size:
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
remainder = None
chunks = arrays
if remainder_size:
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
# `scan` triggers errors in scatter and gather without the `full`
chunks = [
x.reshape(chunk_num, chunk_size, *x.shape[1:])
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
]
# `scan` triggers errors in scatter and gather without the `full`
chunks = [x.reshape(chunk_num, chunk_size, *x.shape[1:]) if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks]
return remainder, chunks
return remainder, chunks
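# Added worked example (illustrative): with ptcl_num = 10, chunk_size = 4 and a (10, 3)
# particle array, remainder_size = 10 % 4 = 2 and chunk_num = 10 // 4 = 2, so the
# remainder keeps the first 2 particles with shape (2, 3) while the chunks are reshaped
# to (2, 4, 3); a 0-d `val` is broadcast with jnp.full(2, val) so that `scan` sees one
# copy per chunk.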
def enmesh(i1, d1, a1, s1, b12, a2, s2):
"""Multilinear enmeshing."""
i1 = jnp.asarray(i1)
d1 = jnp.asarray(d1)
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
if s1 is not None:
s1 = jnp.array(s1, dtype=i1.dtype)
b12 = jnp.float64(b12)
if a2 is not None:
a2 = jnp.float64(a2)
if s2 is not None:
s2 = jnp.array(s2, dtype=i1.dtype)
dim = i1.shape[1]
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >> jnp.arange(
dim, dtype=i1.dtype)) & 1
if a2 is not None:
P = i1 * a1 + d1 - b12
P = P[:, jnp.newaxis] # insert neighbor axis
i2 = P + neighbors * a2 # multilinear
"""Multilinear enmeshing."""
i1 = jnp.asarray(i1)
d1 = jnp.asarray(d1)
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
if s1 is not None:
L = s1 * a1
i2 %= L
s1 = jnp.array(s1, dtype=i1.dtype)
b12 = jnp.float64(b12)
if a2 is not None:
a2 = jnp.float64(a2)
if s2 is not None:
s2 = jnp.array(s2, dtype=i1.dtype)
i2 //= a2
d2 = P - i2 * a2
dim = i1.shape[1]
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >> jnp.arange(dim, dtype=i1.dtype)) & 1
if s1 is not None:
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
if a2 is not None:
P = i1 * a1 + d1 - b12
P = P[:, jnp.newaxis] # insert neighbor axis
i2 = P + neighbors * a2 # multilinear
i2 = i2.astype(i1.dtype)
d2 = d2.astype(d1.dtype)
a2 = a2.astype(d1.dtype)
if s1 is not None:
L = s1 * a1
i2 %= L
d2 /= a2
else:
i12, d12 = jnp.divmod(b12, a1)
i1 -= i12.astype(i1.dtype)
d1 -= d12.astype(d1.dtype)
i2 //= a2
d2 = P - i2 * a2
# insert neighbor axis
i1 = i1[:, jnp.newaxis]
d1 = d1[:, jnp.newaxis]
if s1 is not None:
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
# multilinear
d1 /= a1
i2 = jnp.floor(d1).astype(i1.dtype)
i2 += neighbors
d2 = d1 - i2
i2 += i1
i2 = i2.astype(i1.dtype)
d2 = d2.astype(d1.dtype)
a2 = a2.astype(d1.dtype)
if s1 is not None:
i2 %= s1
d2 /= a2
else:
i12, d12 = jnp.divmod(b12, a1)
i1 -= i12.astype(i1.dtype)
d1 -= d12.astype(d1.dtype)
# insert neighbor axis
i1 = i1[:, jnp.newaxis]
d1 = d1[:, jnp.newaxis]
# multilinear
d1 /= a1
i2 = jnp.floor(d1).astype(i1.dtype)
i2 += neighbors
d2 = d1 - i2
i2 += i1
if s1 is not None:
i2 %= s1
f2 = 1 - jnp.abs(d2)
f2 = 1 - jnp.abs(d2)
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
i2 = jnp.where(i2 < 0, s2, i2)
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
i2 = jnp.where(i2 < 0, s2, i2)
f2 = f2.prod(axis=-1)
f2 = f2.prod(axis=-1)
return i2, f2
return i2, f2
def scatter(pmid, disp, mesh, chunk_size=2**24, val=1., offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
val = jnp.asarray(val)
mesh = jnp.asarray(mesh)
def scatter(pmid, disp, mesh, chunk_size=2**24, val=1.0, offset=0, cell_size=1.0):
ptcl_num, spatial_ndim = pmid.shape
val = jnp.asarray(val)
mesh = jnp.asarray(mesh)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh, offset, cell_size
if remainder is not None:
carry = _scatter_chunk(carry, remainder)[0]
carry = scan(_scatter_chunk, carry, chunks)[0]
mesh = carry[0]
carry = mesh, offset, cell_size
if remainder is not None:
carry = _scatter_chunk(carry, remainder)[0]
carry = scan(_scatter_chunk, carry, chunks)[0]
mesh = carry[0]
return mesh
return mesh
def _scatter_chunk(carry, chunk):
mesh, offset, cell_size = carry
pmid, disp, val = chunk
spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape
# multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, spatial_shape, offset, cell_size,
spatial_shape)
# scatter
ind = tuple(ind[..., i] for i in range(spatial_ndim))
mesh = mesh.at[ind].add(val * frac)
carry = mesh, offset, cell_size
return carry, None
mesh, offset, cell_size = carry
pmid, disp, val = chunk
spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape
# multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, spatial_shape, offset, cell_size, spatial_shape)
# scatter
ind = tuple(ind[..., i] for i in range(spatial_ndim))
mesh = mesh.at[ind].add(val * frac)
carry = mesh, offset, cell_size
return carry, None
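# Added usage sketch (illustrative, not part of the module): particles placed exactly on
# grid nodes with zero displacement each deposit unit mass into a single cell.
if __name__ == "__main__":
    import jax.numpy as jnp

    tiny_mesh = jnp.zeros((8, 8, 8), dtype="float32")
    pmid = jnp.array([[1, 2, 3], [4, 4, 4], [0, 0, 0], [7, 7, 7]])  # integer cell indices
    disp = jnp.zeros((4, 3), dtype="float32")                       # no displacement
    tiny_mesh = scatter(pmid, disp, tiny_mesh)
    assert float(tiny_mesh.sum()) == 4.0  # total painted mass equals the particle count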
@@ -70,8 +70,8 @@
}
],
"source": [
"folder = '../out'\n",
"pdims=(4,4)\n",
"folder = \"../out\"\n",
"pdims = (4, 4)\n",
"\n",
"init_field_slices = []\n",
"field_slices = []\n",
@@ -79,11 +79,11 @@
"for i in range(pdims[0]):\n",
" row_init_field = []\n",
" row_field = []\n",
" \n",
"\n",
" for j in range(pdims[1]):\n",
" slice_index = i * pdims[1] + j \n",
" row_field.append(np.load(f'{folder}/field_{slice_index}.npy'))\n",
" row_init_field.append(np.load(f'{folder}/initial_conditions_{slice_index}.npy'))\n",
" slice_index = i * pdims[1] + j\n",
" row_field.append(np.load(f\"{folder}/field_{slice_index}.npy\"))\n",
" row_init_field.append(np.load(f\"{folder}/initial_conditions_{slice_index}.npy\"))\n",
"\n",
" field_slices.append(np.vstack(row_field))\n",
" init_field_slices.append(np.vstack(row_init_field))\n",
@@ -148,23 +148,31 @@
" slicing[proj_axis] = slice(None, sum_over)\n",
" slicing = tuple(slicing)\n",
"\n",
"\n",
" fig, axes = plt.subplots(1, 2, figsize=(12, 6))\n",
"\n",
" # Flatten axes for easy indexing\n",
" axes = axes.flatten()\n",
"\n",
" # Plot initial conditions\n",
" axes[0].imshow(initial_conditions[slicing].sum(axis=proj_axis), cmap='magma', extent=[0, box_size + 5, 0, box_size + 5])\n",
" axes[0].set_xlabel('Mpc/h')\n",
" axes[0].set_ylabel('Mpc/h')\n",
" axes[0].set_title('Initial conditions')\n",
" axes[0].imshow(\n",
" initial_conditions[slicing].sum(axis=proj_axis),\n",
" cmap=\"magma\",\n",
" extent=[0, box_size + 5, 0, box_size + 5],\n",
" )\n",
" axes[0].set_xlabel(\"Mpc/h\")\n",
" axes[0].set_ylabel(\"Mpc/h\")\n",
" axes[0].set_title(\"Initial conditions\")\n",
"\n",
" # Plot LPT density field at z=0\n",
" axes[1].imshow(field[slicing].sum(axis=proj_axis), cmap='magma', extent=[0, box_size + 5, 0, box_size + 5])\n",
" axes[1].set_xlabel('Mpc/h')\n",
" axes[1].set_ylabel('Mpc/h')\n",
" axes[1].set_title('LPT density field at z=0')\n",
" axes[1].imshow(\n",
" field[slicing].sum(axis=proj_axis),\n",
" cmap=\"magma\",\n",
" extent=[0, box_size + 5, 0, box_size + 5],\n",
" )\n",
" axes[1].set_xlabel(\"Mpc/h\")\n",
" axes[1].set_ylabel(\"Mpc/h\")\n",
" axes[1].set_title(\"LPT density field at z=0\")\n",
"\n",
"\n",
"for i in range(3):\n",
" plot(i)"
@@ -201,10 +209,14 @@
"box_size = initial_conditions.shape[0] + 5\n",
"plt.figure(figsize=(20, 20))\n",
"# Generate the plot\n",
"plt.imshow(np.log10(field[:16].sum(axis=proj_axis) + 1), cmap='magma', extent=[0, box_size, 0, box_size])\n",
"plt.xlabel('Mpc/h')\n",
"plt.ylabel('Mpc/h')\n",
"plt.title('LPT density field at z=0')\n",
"plt.imshow(\n",
" np.log10(field[:16].sum(axis=proj_axis) + 1),\n",
" cmap=\"magma\",\n",
" extent=[0, box_size, 0, box_size],\n",
")\n",
"plt.xlabel(\"Mpc/h\")\n",
"plt.ylabel(\"Mpc/h\")\n",
"plt.title(\"LPT density field at z=0\")\n",
"\n",
"# Display the plot\n",
"plt.show()"
@@ -3,7 +3,7 @@ requires = ["scikit-build-core>=0.4.0", "pybind11>=2.9.0"]
build-backend = "scikit_build_core.build"
[project]
name = "jaxdecomp"
version = "0.2.0"
version = "0.2.4"
description = "JAX bindings for the cuDecomp library"
authors = [
{ name = "Wassim Kabalan" },
@@ -23,24 +23,49 @@ classifiers = [
"Intended Audience :: Science/Research",
]
dependencies = [
"jaxtyping>=0.2.33",
"jax>=0.4.30",
"jaxtyping>=0.2.0",
"jax>=0.4.35",
"typing-extensions; python_version < '3.11'",
]
[project.optional-dependencies]
test = ["pytest>=8.0.0" , "jax[cpu]>=0.4.30"]
test = ["pytest>=8.0.0"]
[tool.scikit-build]
minimum-version = "0.8"
cmake.version = ">=3.25"
build-dir = "build/{wheel_tag}"
wheel.py-api = "py3"
cmake.build-type = "Release"
wheel.install-dir = "jaxdecomp/_src"
[tool.scikit-build.cmake.define]
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""
#[tool.cibuildwheel]
#test-extras = "test"
#test-command = "pytest {project}/tests"
[tool.cibuildwheel]
test-extras = "test"
test-command = "pytest {project}/tests"
[tool.ruff]
line-length = 150
src = ["src"]
exclude = ["third_party"]
[tool.ruff.lint]
select = [
# pycodestyle
'E',
# pyflakes
'F',
# pyupgrade
'UP',
# flake8-debugger
'T10',
]
ignore = [
'E402', # module level import not at top of file
'E203',
'E731',
'E701',
'E741',
'E722',
'UP037', # conflicts with jaxtyping Array annotations
]
@@ -17,11 +17,10 @@ config.halo_comm_backend = jaxdecomp.config.halo_comm_backend
config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend
# Run autotune
tuned_config = jaxdecomp.get_autotuned_config(config, False, False, True, True,
(32, 32, 32), (True, True, True))
tuned_config = jaxdecomp.get_autotuned_config(config, False, False, True, True, (32, 32, 32), (True, True, True))
if rank == 0:
print(rank, "*** Results of optimization ***")
print(rank, "pdims", tuned_config.pdims)
print(rank, "halo_comm_backend", tuned_config.halo_comm_backend)
print(rank, "transpose_comm_backend", tuned_config.transpose_comm_backend)
print(rank, "*** Results of optimization ***")
print(rank, "pdims", tuned_config.pdims)
print(rank, "halo_comm_backend", tuned_config.halo_comm_backend)
print(rank, "transpose_comm_backend", tuned_config.transpose_comm_backend)
import time
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
@@ -13,45 +12,42 @@ jaxdecomp.init()
jax.distributed.initialize()
rank = jax.process_index()
#jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_P2P_PL)
# jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_P2P_PL)
pdims = (2, 2)
global_shape = (1024, 1024, 1024)
# Initialize a local slice of the global array
array = jax.random.normal(
shape=[
global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]
],
key=jax.random.PRNGKey(0))
shape=[global_shape[0] // pdims[1], global_shape[1] // pdims[0], global_shape[2]],
key=jax.random.PRNGKey(0),
)
# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
global_array = multihost_utils.host_local_array_to_global_array(
array, mesh, P('z', 'y'))
mesh = Mesh(devices, axis_names=("z", "y"))
global_array = multihost_utils.host_local_array_to_global_array(array, mesh, P("z", "y"))
@jax.jit
def do_fft(x):
return jaxdecomp.fft.pfft3d(x)
return jaxdecomp.fft.pfft3d(x)
with mesh:
do_fft(global_array)
before = time.time()
karray = do_fft(global_array).block_until_ready()
after = time.time()
print(rank, 'took', after - before, 's')
do_fft(global_array)
before = time.time()
karray = do_fft(global_array).block_until_ready()
after = time.time()
print(rank, "took", after - before, "s")
# And now, let's do the inverse FFT
rec_array = jaxdecomp.fft.pifft3d(karray)
# And now, let's do the inverse FFT
rec_array = jaxdecomp.fft.pifft3d(karray)
diff = jax.jit(lambda x, y: abs(x - y).max())(rec_array, global_array)
diff = jax.jit(lambda x, y: abs(x - y).max())(rec_array, global_array)
# Let's test if things are like we expect
if rank == 0:
print('maximum reconstruction difference', diff)
print("maximum reconstruction difference", diff)
jaxdecomp.finalize()
jax.distributed.shutdown()
from dataclasses import dataclass
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple
from jaxdecomp._src.pencil_utils import get_output_specs
from jaxdecomp.fft import fftfreq3d, pfft3d, pifft3d, rfftfreq3d
from jaxdecomp.halo import halo_exchange
from jaxdecomp.transpose import (transposeXtoY, transposeXtoZ, transposeYtoX,
transposeYtoZ, transposeZtoX, transposeZtoY)
from jaxdecomp.transpose import (
transposeXtoY,
transposeXtoZ,
transposeYtoX,
transposeYtoZ,
transposeZtoX,
transposeZtoY,
)
from ._src import (HALO_COMM_MPI, HALO_COMM_MPI_BLOCKING, HALO_COMM_NCCL,
HALO_COMM_NVSHMEM, HALO_COMM_NVSHMEM_BLOCKING, NO_DECOMP,
PENCILS, SLAB_XY, SLAB_YZ, TRANSPOSE_COMM_MPI_A2A,
TRANSPOSE_COMM_MPI_P2P, TRANSPOSE_COMM_MPI_P2P_PL,
TRANSPOSE_COMM_NCCL, TRANSPOSE_COMM_NCCL_PL,
TRANSPOSE_COMM_NVSHMEM, TRANSPOSE_COMM_NVSHMEM_PL,
TRANSPOSE_XY, TRANSPOSE_YX, TRANSPOSE_YZ, TRANSPOSE_ZY,
HaloCommBackend, TransposeCommBackend, finalize,
get_autotuned_config, get_pencil_info, init, make_config)
from ._src import (
HALO_COMM_MPI,
HALO_COMM_MPI_BLOCKING,
HALO_COMM_NCCL,
HALO_COMM_NVSHMEM,
HALO_COMM_NVSHMEM_BLOCKING,
NO_DECOMP,
PENCILS,
SLAB_XY,
SLAB_YZ,
TRANSPOSE_COMM_MPI_A2A,
TRANSPOSE_COMM_MPI_P2P,
TRANSPOSE_COMM_MPI_P2P_PL,
TRANSPOSE_COMM_NCCL,
TRANSPOSE_COMM_NCCL_PL,
TRANSPOSE_COMM_NVSHMEM,
TRANSPOSE_COMM_NVSHMEM_PL,
TRANSPOSE_XY,
TRANSPOSE_YX,
TRANSPOSE_YZ,
TRANSPOSE_ZY,
HaloCommBackend,
TransposeCommBackend,
finalize,
get_autotuned_config,
get_pencil_info,
init,
make_config,
)
try:
__version__ = version("jaxDecomp")
__version__ = version("jaxDecomp")
except PackageNotFoundError:
# package is not installed
pass
# package is not installed
pass
__all__ = [
"config",
@@ -51,22 +76,33 @@ __all__ = [
"get_output_specs",
"fftfreq3d",
"rfftfreq3d",
"HALO_COMM_MPI",
"HALO_COMM_MPI_BLOCKING",
"HALO_COMM_NVSHMEM",
"HALO_COMM_NVSHMEM_BLOCKING",
"TRANSPOSE_COMM_MPI_A2A",
"TRANSPOSE_COMM_MPI_P2P",
"TRANSPOSE_COMM_MPI_P2P_PL",
"TRANSPOSE_COMM_NCCL_PL",
"TRANSPOSE_COMM_NVSHMEM",
"TRANSPOSE_COMM_NVSHMEM_PL",
]
@dataclass
class JAXDecompConfig:
"""Class for storing the configuration state of the library."""
halo_comm_backend: HaloCommBackend = HALO_COMM_NCCL
transpose_comm_backend: TransposeCommBackend = TRANSPOSE_COMM_NCCL
transpose_axis_contiguous: bool = True
transpose_axis_contiguous_2: bool = True
"""Class for storing the configuration state of the library."""
def update(self, key, value):
if hasattr(self, key):
setattr(self, key, value)
else:
raise ValueError("key %s is not a valid configuration key" % key)
halo_comm_backend: HaloCommBackend = HALO_COMM_NCCL
transpose_comm_backend: TransposeCommBackend = TRANSPOSE_COMM_NCCL
transpose_axis_contiguous: bool = True
transpose_axis_contiguous_2: bool = True
def update(self, key, value):
if hasattr(self, key):
setattr(self, key, value)
else:
raise ValueError(f"key {key} is not a valid configuration key")
# Declare the global configuration object
......
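# Added usage sketch (illustrative): the global configuration object is mutated through
# `jaxdecomp.config.update`, mirroring the commented-out call in the benchmark script.
#
#   import jaxdecomp
#   jaxdecomp.config.update("halo_comm_backend", jaxdecomp.HALO_COMM_MPI)
#   jaxdecomp.config.update("transpose_comm_backend", jaxdecomp.TRANSPOSE_COMM_NCCL)
#   jaxdecomp.config.update("not_a_key", 0)  # raises ValueError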
from jax.lib import xla_client
from . import _jaxdecomp
from jaxdecomplib import _jaxdecomp
init = _jaxdecomp.init
finalize = _jaxdecomp.finalize
@@ -9,19 +8,34 @@ get_autotuned_config = _jaxdecomp.get_autotuned_config
make_config = _jaxdecomp.GridConfig
# Loading the comm configuration flags at the top level
from ._jaxdecomp import TransposeCommBackend # yapf: disable
from ._jaxdecomp import (HALO_COMM_MPI, HALO_COMM_MPI_BLOCKING, HALO_COMM_NCCL,
HALO_COMM_NVSHMEM, HALO_COMM_NVSHMEM_BLOCKING,
NO_DECOMP, PENCILS, SLAB_XY, SLAB_YZ,
TRANSPOSE_COMM_MPI_A2A, TRANSPOSE_COMM_MPI_P2P,
TRANSPOSE_COMM_MPI_P2P_PL, TRANSPOSE_COMM_NCCL,
TRANSPOSE_COMM_NCCL_PL, TRANSPOSE_COMM_NVSHMEM,
TRANSPOSE_COMM_NVSHMEM_PL, TRANSPOSE_XY, TRANSPOSE_YX,
TRANSPOSE_YZ, TRANSPOSE_ZY, HaloCommBackend)
from jaxdecomplib._jaxdecomp import HaloCommBackend # yapf: disable
from jaxdecomplib._jaxdecomp import TransposeCommBackend # yapf: disable
from jaxdecomplib._jaxdecomp import ( # dummy line to avoid yapf reformatting
HALO_COMM_MPI,
HALO_COMM_MPI_BLOCKING,
HALO_COMM_NCCL,
HALO_COMM_NVSHMEM,
HALO_COMM_NVSHMEM_BLOCKING,
NO_DECOMP,
PENCILS,
SLAB_XY,
SLAB_YZ,
TRANSPOSE_COMM_MPI_A2A,
TRANSPOSE_COMM_MPI_P2P,
TRANSPOSE_COMM_MPI_P2P_PL,
TRANSPOSE_COMM_NCCL,
TRANSPOSE_COMM_NCCL_PL,
TRANSPOSE_COMM_NVSHMEM,
TRANSPOSE_COMM_NVSHMEM_PL,
TRANSPOSE_XY,
TRANSPOSE_YX,
TRANSPOSE_YZ,
TRANSPOSE_ZY,
)
# Registering ops for XLA
for name, fn in _jaxdecomp.registrations().items():
xla_client.register_custom_call_target(name, fn, platform="gpu")
xla_client.register_custom_call_target(name, fn, platform="gpu")
__all__ = [
"init",
@@ -37,4 +51,18 @@ __all__ = [
"SLAB_YZ",
"PENCILS",
"NO_DECOMP",
"HALO_COMM_MPI",
"HALO_COMM_MPI_BLOCKING",
"HALO_COMM_NCCL",
"HALO_COMM_NVSHMEM",
"HALO_COMM_NVSHMEM_BLOCKING",
"TRANSPOSE_COMM_MPI_A2A",
"TRANSPOSE_COMM_MPI_P2P",
"TRANSPOSE_COMM_MPI_P2P_PL",
"TRANSPOSE_COMM_NCCL",
"TRANSPOSE_COMM_NCCL_PL",
"TRANSPOSE_COMM_NVSHMEM",
"TRANSPOSE_COMM_NVSHMEM_PL",
"HaloCommBackend",
"TransposeCommBackend",
]
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from math import prod
from typing import Tuple, TypeAlias
from typing import TypeAlias
from jax import lax
from jax import numpy as jnp
from jax.lib import xla_client
from jaxtyping import Array
FftType: TypeAlias = xla_client.FftType
FftType: TypeAlias = lax.FftType
FORWARD_FFTs = {FftType.FFT, FftType.RFFT}
INVERSE_FFTs = {FftType.IFFT, FftType.IRFFT}
def ADJOINT(fft_type: FftType) -> FftType:
"""Returns the adjoint (inverse) of the given FFT type.
"""Returns the adjoint (inverse) of the given FFT type.
Args:
fft_type: The type of FFT (FftType).
@@ -23,21 +23,21 @@ def ADJOINT(fft_type: FftType) -> FftType:
Raises:
ValueError: If an unknown FFT type is provided.
"""
match fft_type:
case FftType.FFT:
return FftType.IFFT
case FftType.IFFT:
return FftType.FFT
case FftType.RFFT:
return FftType.IRFFT
case FftType.IRFFT:
return FftType.RFFT
case _:
raise ValueError(f"Unknown FFT type '{fft_type}'")
match fft_type:
case FftType.FFT:
return FftType.IFFT
case FftType.IFFT:
return FftType.FFT
case FftType.RFFT:
return FftType.IRFFT
case FftType.IRFFT:
return FftType.RFFT
case _:
raise ValueError(f"Unknown FFT type '{fft_type}'")
def COMPLEX(fft_type: FftType) -> FftType:
"""Returns the complex equivalent of the given FFT type.
"""Returns the complex equivalent of the given FFT type.
Args:
fft_type: The type of FFT (FftType).
@@ -48,17 +48,17 @@ def COMPLEX(fft_type: FftType) -> FftType:
Raises:
ValueError: If an unknown FFT type is provided.
"""
match fft_type:
case FftType.RFFT | FftType.FFT:
return FftType.FFT
case FftType.IRFFT | FftType.IFFT:
return FftType.IFFT
case _:
raise ValueError(f"Unknown FFT type '{fft_type}'")
match fft_type:
case FftType.RFFT | FftType.FFT:
return FftType.FFT
case FftType.IRFFT | FftType.IFFT:
return FftType.IFFT
case _:
raise ValueError(f"Unknown FFT type '{fft_type}'")
def _un_normalize_fft(s: Tuple[int, ...], fft_type: FftType) -> Array:
"""Computes the un-normalization factor for the FFT.
def _un_normalize_fft(s: tuple[int, ...], fft_type: FftType) -> Array:
"""Computes the un-normalization factor for the FFT.
Args:
s: Shape of the array (Tuple[int, ...]).
@@ -67,14 +67,14 @@ def _un_normalize_fft(s: Tuple[int, ...], fft_type: FftType) -> Array:
Returns:
The un-normalization factor (Array).
"""
if fft_type in FORWARD_FFTs:
return jnp.array(1)
else:
return jnp.array(prod(s))
if fft_type in FORWARD_FFTs:
return jnp.array(1)
else:
return jnp.array(prod(s))
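# Added note (not in the original): jnp.fft.ifftn and friends divide by prod(s), so
# scaling the inverse transforms by prod(s) turns them into true adjoints of the
# unnormalized forward transforms; e.g. for s = (4, 4, 4) the factor is 64, and
# <fftn(x), y> == <x, 64 * ifftn(y)> for the plain complex FFT.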
def fftn(a: Array, fft_type: FftType, adjoint: bool) -> Array:
"""Performs an n-dimensional FFT on the input array.
"""Performs an n-dimensional FFT on the input array.
Args:
a: Input array (Array).
@@ -87,33 +87,33 @@ def fftn(a: Array, fft_type: FftType, adjoint: bool) -> Array:
Raises:
ValueError: If an unknown FFT type is provided.
"""
if fft_type in FORWARD_FFTs:
axes = tuple(range(0, 3))
else:
axes = tuple(range(2, -1, -1))
if fft_type in FORWARD_FFTs:
axes = tuple(range(0, 3))
else:
axes = tuple(range(2, -1, -1))
if adjoint:
fft_type = ADJOINT(fft_type)
if adjoint:
fft_type = ADJOINT(fft_type)
if fft_type == FftType.FFT:
a = jnp.fft.fftn(a, axes=axes)
elif fft_type == FftType.IFFT:
a = jnp.fft.ifftn(a, axes=axes)
elif fft_type == FftType.RFFT:
a = jnp.fft.rfftn(a, axes=axes)
elif fft_type == FftType.IRFFT:
a = jnp.fft.irfftn(a, axes=axes)
else:
raise ValueError(f"Unknown FFT type '{fft_type}'")
if fft_type == FftType.FFT:
a = jnp.fft.fftn(a, axes=axes)
elif fft_type == FftType.IFFT:
a = jnp.fft.ifftn(a, axes=axes)
elif fft_type == FftType.RFFT:
a = jnp.fft.rfftn(a, axes=axes)
elif fft_type == FftType.IRFFT:
a = jnp.fft.irfftn(a, axes=axes)
else:
raise ValueError(f"Unknown FFT type '{fft_type}'")
s = a.shape
a *= _un_normalize_fft(s, fft_type)
s = a.shape
a *= _un_normalize_fft(s, fft_type)
return a
return a
def fft(a: Array, fft_type: FftType, axis: int, adjoint: bool) -> Array:
"""Performs a 1-dimensional FFT along the specified axis of the input array.
"""Performs a 1-dimensional FFT along the specified axis of the input array.
Args:
a: Input array (Array).
@@ -127,29 +127,28 @@ def fft(a: Array, fft_type: FftType, axis: int, adjoint: bool) -> Array:
Raises:
ValueError: If an unknown FFT type is provided.
"""
if adjoint:
fft_type = ADJOINT(fft_type)
if adjoint:
fft_type = ADJOINT(fft_type)
if fft_type == FftType.FFT:
a = jnp.fft.fft(a, axis=axis)
elif fft_type == FftType.IFFT:
a = jnp.fft.ifft(a, axis=axis)
elif fft_type == FftType.RFFT:
a = jnp.fft.rfft(a, axis=axis)
elif fft_type == FftType.IRFFT:
a = jnp.fft.irfft(a, axis=axis)
else:
raise ValueError(f"Unknown FFT type '{fft_type}'")
if fft_type == FftType.FFT:
a = jnp.fft.fft(a, axis=axis)
elif fft_type == FftType.IFFT:
a = jnp.fft.ifft(a, axis=axis)
elif fft_type == FftType.RFFT:
a = jnp.fft.rfft(a, axis=axis)
elif fft_type == FftType.IRFFT:
a = jnp.fft.irfft(a, axis=axis)
else:
raise ValueError(f"Unknown FFT type '{fft_type}'")
s = (a.shape[axis],)
a *= _un_normalize_fft(s, fft_type)
s = (a.shape[axis],)
a *= _un_normalize_fft(s, fft_type)
return a
return a
def fft2(a: Array, fft_type: FftType, axes: Tuple[int, int],
adjoint: bool) -> Array:
"""Performs a 2-dimensional FFT along the specified axes of the input array.
def fft2(a: Array, fft_type: FftType, axes: tuple[int, int], adjoint: bool) -> Array:
"""Performs a 2-dimensional FFT along the specified axes of the input array.
Args:
a: Input array (Array).
@@ -163,21 +162,21 @@ def fft2(a: Array, fft_type: FftType, axes: Tuple[int, int],
Raises:
ValueError: If an unknown FFT type is provided.
"""
if adjoint:
fft_type = ADJOINT(fft_type)
if fft_type == FftType.FFT:
a = jnp.fft.fft2(a, axes=axes)
elif fft_type == FftType.IFFT:
a = jnp.fft.ifft2(a, axes=axes)
elif fft_type == FftType.RFFT:
a = jnp.fft.rfft2(a, axes=axes)
elif fft_type == FftType.IRFFT:
a = jnp.fft.irfft2(a, axes=axes)
else:
raise ValueError(f"Unknown FFT type '{fft_type}'")
s = tuple(a.shape[i] for i in axes)
a *= _un_normalize_fft(s, fft_type)
return a
if adjoint:
fft_type = ADJOINT(fft_type)
if fft_type == FftType.FFT:
a = jnp.fft.fft2(a, axes=axes)
elif fft_type == FftType.IFFT:
a = jnp.fft.ifft2(a, axes=axes)
elif fft_type == FftType.RFFT:
a = jnp.fft.rfft2(a, axes=axes)
elif fft_type == FftType.IRFFT:
a = jnp.fft.irfft2(a, axes=axes)
else:
raise ValueError(f"Unknown FFT type '{fft_type}'")
s = tuple(a.shape[i] for i in axes)
a *= _un_normalize_fft(s, fft_type)
return a
This diff is collapsed.
from functools import partial
from typing import Tuple
import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental.shard_alike import shard_alike
from jax.lib import xla_client
from jaxtyping import Array
FftType = xla_client.FftType
FftType = lax.FftType
@partial(jax.jit, static_argnums=(1,))
def fftfreq3d(k_array: Array, d: float = 1.0) -> Tuple[Array, Array, Array]:
"""
def fftfreq3d(k_array: Array, d: float = 1.0) -> tuple[Array, Array, Array]:
"""
Computes the 3D FFT frequencies for a given array, assuming a Z-pencil configuration.
Parameters
@@ -27,32 +26,37 @@ def fftfreq3d(k_array: Array, d: float = 1.0) -> Tuple[Array, Array, Array]:
Tuple[Array, Array, Array]
The frequencies corresponding to the Z, Y, and X axes, respectively.
"""
if jnp.iscomplexobj(k_array):
dtype = jnp.float32 if k_array.dtype == jnp.complex64 else jnp.float64
else:
dtype = k_array.dtype
if jnp.iscomplexobj(k_array):
dtype = jnp.float32 if k_array.dtype == jnp.complex64 else jnp.float64
else:
dtype = k_array.dtype
# Compute the FFT frequencies for each axis
ky = jnp.fft.fftfreq(k_array.shape[0], d=d, dtype=dtype) * 2 * jnp.pi
kx = jnp.fft.fftfreq(k_array.shape[1], d=d, dtype=dtype) * 2 * jnp.pi
kz = jnp.fft.fftfreq(k_array.shape[2], d=d, dtype=dtype) * 2 * jnp.pi
# Compute the FFT frequencies for each axis
ky = jnp.fft.fftfreq(k_array.shape[0], d=d, dtype=dtype) * 2 * jnp.pi
kx = jnp.fft.fftfreq(k_array.shape[1], d=d, dtype=dtype) * 2 * jnp.pi
kz = jnp.fft.fftfreq(k_array.shape[2], d=d, dtype=dtype) * 2 * jnp.pi
k_array_structure = jax.tree.structure(k_array)
kx = jax.tree.unflatten(k_array_structure, (kx,))
ky = jax.tree.unflatten(k_array_structure, (ky,))
kz = jax.tree.unflatten(k_array_structure, (kz,))
# Ensure frequencies are sharded similarly to the input array
ky, _ = shard_alike(ky, k_array[:, 0, 0])
kx, _ = shard_alike(kx, k_array[0, :, 0])
kz, _ = shard_alike(kz, k_array[0, 0, :])
# Ensure frequencies are sharded similarly to the input array
ky, _ = shard_alike(ky, k_array[:, 0, 0])
kx, _ = shard_alike(kx, k_array[0, :, 0])
kz, _ = shard_alike(kz, k_array[0, 0, :])
# Reshape the frequencies to match the input array's dimensionality
ky = ky.reshape([-1, 1, 1])
kx = kx.reshape([1, -1, 1])
kz = kz.reshape([1, 1, -1])
# Reshape the frequencies to match the input array's dimensionality
ky = ky.reshape([-1, 1, 1])
kx = kx.reshape([1, -1, 1])
kz = kz.reshape([1, 1, -1])
return kz, ky, kx
return kz, ky, kx
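# Added usage sketch (illustrative): the reshaped frequencies broadcast against each
# other, so isotropic kernels can be built directly from the output, e.g.
#
#   kz, ky, kx = jaxdecomp.fft.fftfreq3d(k_array)
#   kk = kx**2 + ky**2 + kz**2                              # broadcasts to k_array.shape
#   k_array = k_array * jnp.where(kk == 0, 1.0, 1.0 / kk)   # inverse-Laplace kernel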
@partial(jax.jit, static_argnums=(1,))
def rfftfreq3d(k_array: Array, d: float = 1.0) -> Tuple[Array, Array, Array]:
"""
def rfftfreq3d(k_array: Array, d: float = 1.0) -> tuple[Array, Array, Array]:
"""
Computes the 3D FFT frequencies for a real input array, assuming a Z-pencil configuration.
The FFT is computed for the real input on the X axis using rfft.
@@ -68,24 +72,29 @@ def rfftfreq3d(k_array: Array, d: float = 1.0) -> Tuple[Array, Array, Array]:
Tuple[Array, Array, Array]
The frequencies corresponding to the Z, Y, and X axes, respectively.
"""
if jnp.iscomplexobj(k_array):
dtype = jnp.float32 if k_array.dtype == jnp.complex64 else jnp.float64
else:
dtype = k_array.dtype
# Compute the FFT frequencies for each axis
ky = jnp.fft.fftfreq(k_array.shape[0], d=d, dtype=dtype) * 2 * jnp.pi
kx = jnp.fft.rfftfreq(k_array.shape[1], d=d, dtype=dtype) * 2 * jnp.pi
kz = jnp.fft.fftfreq(k_array.shape[2], d=d, dtype=dtype) * 2 * jnp.pi
# Ensure frequencies are sharded similarly to the input array
ky, _ = shard_alike(ky, k_array[:, 0, 0])
kx, _ = shard_alike(kx, k_array[0, :, 0])
kz, _ = shard_alike(kz, k_array[0, 0, :])
# Reshape the frequencies to match the input array's dimensionality
ky = ky.reshape([-1, 1, 1])
kx = kx.reshape([1, -1, 1])
kz = kz.reshape([1, 1, -1])
return kz, ky, kx
if jnp.iscomplexobj(k_array):
dtype = jnp.float32 if k_array.dtype == jnp.complex64 else jnp.float64
else:
dtype = k_array.dtype
# Compute the FFT frequencies for each axis
ky = jnp.fft.fftfreq(k_array.shape[0], d=d, dtype=dtype) * 2 * jnp.pi
kx = jnp.fft.rfftfreq(k_array.shape[1], d=d, dtype=dtype) * 2 * jnp.pi
kz = jnp.fft.fftfreq(k_array.shape[2], d=d, dtype=dtype) * 2 * jnp.pi
k_array_structure = jax.tree.structure(k_array)
kx = jax.tree.unflatten(k_array_structure, (kx,))
ky = jax.tree.unflatten(k_array_structure, (ky,))
kz = jax.tree.unflatten(k_array_structure, (kz,))
# Ensure frequencies are sharded similarly to the input array
ky, _ = shard_alike(ky, k_array[:, 0, 0])
kx, _ = shard_alike(kx, k_array[0, :, 0])
kz, _ = shard_alike(kz, k_array[0, 0, :])
# Reshape the frequencies to match the input array's dimensionality
ky = ky.reshape([-1, 1, 1])
kx = kx.reshape([1, -1, 1])
kz = kz.reshape([1, 1, -1])
return kz, ky, kx
This diff is collapsed.