Merge pull request #30 from DifferentiableUniverseInitiative/implement-jax-versions
Implement jax versions
No related branches found
No related tags found
Showing
- .github/workflows/tests.yml 39 additions, 0 deletions.github/workflows/tests.yml
- .gitmodules 0 additions, 3 deletions.gitmodules
- CMakeLists.txt 74 additions, 62 deletionsCMakeLists.txt
- include/fft.h 18 additions, 26 deletionsinclude/fft.h
- include/jaxdecomp.h 50 additions, 0 deletionsinclude/jaxdecomp.h
- include/transpose.h 19 additions, 7 deletionsinclude/transpose.h
- jaxdecomp/__init__.py 24 additions, 13 deletionsjaxdecomp/__init__.py
- jaxdecomp/_src/__init__.py 18 additions, 4 deletionsjaxdecomp/_src/__init__.py
- jaxdecomp/_src/cudecomp/__init__.py 0 additions, 0 deletionsjaxdecomp/_src/cudecomp/__init__.py
- jaxdecomp/_src/cudecomp/fft.py 69 additions, 204 deletionsjaxdecomp/_src/cudecomp/fft.py
- jaxdecomp/_src/cudecomp/halo.py 83 additions, 87 deletionsjaxdecomp/_src/cudecomp/halo.py
- jaxdecomp/_src/cudecomp/transpose.py 474 additions, 0 deletionsjaxdecomp/_src/cudecomp/transpose.py
- jaxdecomp/_src/fft_utils.py 113 additions, 0 deletionsjaxdecomp/_src/fft_utils.py
- jaxdecomp/_src/jax/__init__.py 0 additions, 0 deletionsjaxdecomp/_src/jax/__init__.py
- jaxdecomp/_src/jax/fft.py 346 additions, 0 deletionsjaxdecomp/_src/jax/fft.py
- jaxdecomp/_src/jax/fftfreq.py 65 additions, 0 deletionsjaxdecomp/_src/jax/fftfreq.py
- jaxdecomp/_src/jax/halo.py 359 additions, 0 deletionsjaxdecomp/_src/jax/halo.py
- jaxdecomp/_src/jax/transpose.py 368 additions, 0 deletionsjaxdecomp/_src/jax/transpose.py
- jaxdecomp/_src/pencil_utils.py 230 additions, 0 deletionsjaxdecomp/_src/pencil_utils.py
- jaxdecomp/_src/spmd_ops.py 57 additions, 5 deletionsjaxdecomp/_src/spmd_ops.py
Loading
Please register or sign in to comment