Wassim KABALAN / jaxDecomp / Commits

Commit 320201ce, authored 1 month ago by Wassim KABALAN
Commit message: toto
Parent: 96d9e924
No related branches, tags, or merge requests found.
Showing 2 changed files
  benchmarks/jaxfft.py      +203 -0  (203 additions, 0 deletions)
  benchmarks/mpi4jaxfft.py  +251 -0  (251 additions, 0 deletions)
with 454 additions and 0 deletions
benchmarks/jaxfft.py (new file, mode 100644, +203 -0, view file @ 320201ce)
import jax

jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()

import argparse
import re
import time
from functools import partial

import jax.numpy as jnp
import jaxdecomp
import numpy as np
from cupy.cuda.nvtx import RangePop, RangePush
from jax.experimental import mesh_utils, multihost_utils
from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax_hpc_profiler import Timer


def run_benchmark(pdims, global_shape, backend, nb_nodes, precision, iterations,
                  trace, contiguous, output_path):

    pfft_backend = "cudecomp"
    if backend == "NCCL":
        jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_NCCL)
    elif backend == "NCCL_PL":
        jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_NCCL_PL)
    elif backend == "MPI_P2P":
        jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_P2P)
    elif backend == "MPI":
        jaxdecomp.config.update('transpose_comm_backend', jaxdecomp.TRANSPOSE_COMM_MPI_A2A)
    else:
        pfft_backend = "jax"

    # Initialize the local slice with the local slice shape
    array = jax.random.normal(
        shape=[global_shape[0] // pdims[0], global_shape[1] // pdims[1], global_shape[2]],
        key=jax.random.PRNGKey(rank))

    if contiguous:
        jaxdecomp.config.update('transpose_axis_contiguous', contiguous)

    # Remap to the global array from the local slice
    devices = mesh_utils.create_device_mesh(pdims)
    mesh = Mesh(devices.T, axis_names=('z', 'y'))
    global_array = multihost_utils.host_local_array_to_global_array(array, mesh, P('z', 'y'))

    if jax.process_index() == 0:
        print(f"Devices {jax.devices()}")
        print(f"Global dims {global_shape}, pdims {pdims}, Backend {backend} "
              f"local array shape {array.shape} global array shape {global_array.shape}")
        print("Sharding :")
        print(global_array.sharding)

    @jax.jit
    def do_fft(x):
        return jaxdecomp.fft.pfft3d(x, backend=pfft_backend)

    @jax.jit
    def do_ifft(x):
        return jaxdecomp.fft.pifft3d(x, backend=pfft_backend)

    @jax.jit
    def get_diff(arr1, arr2):
        return jnp.abs(arr1 - arr2).max()

    fft_chrono = Timer(save_jaxpr=True)
    ifft_chrono = Timer(save_jaxpr=True)
    with mesh:
        if trace:
            jit_fft_output = f"{output_path}/jit_fft_trace"
            first_run_fft_output = f"{output_path}/first_run_fft_trace"
            second_run_fft_output = f"{output_path}/second_run_fft_trace"
            jit_ifft_output = f"{output_path}/jit_ifft_trace"
            first_run_ifft_output = f"{output_path}/first_run_ifft_trace"
            second_run_ifft_output = f"{output_path}/second_run_ifft_trace"

            with jax.profiler.trace(jit_fft_output, create_perfetto_trace=True):
                global_array = do_fft(global_array).block_until_ready()
            with jax.profiler.trace(jit_ifft_output, create_perfetto_trace=True):
                global_array = do_ifft(global_array).block_until_ready()
            with jax.profiler.trace(first_run_fft_output, create_perfetto_trace=True):
                global_array = do_fft(global_array).block_until_ready()
            with jax.profiler.trace(first_run_ifft_output, create_perfetto_trace=True):
                global_array = do_ifft(global_array).block_until_ready()
            with jax.profiler.trace(second_run_fft_output, create_perfetto_trace=True):
                global_array = do_fft(global_array).block_until_ready()
            with jax.profiler.trace(second_run_ifft_output, create_perfetto_trace=True):
                global_array = do_ifft(global_array).block_until_ready()
        else:
            # Warm start
            RangePush("warmup")
            global_array = fft_chrono.chrono_jit(do_fft, global_array)
            global_array = ifft_chrono.chrono_jit(do_ifft, global_array)
            RangePop()
            sync_global_devices("warmup")
            for i in range(iterations):
                RangePush(f"fft iter {i}")
                global_array = fft_chrono.chrono_fun(do_fft, global_array)
                RangePop()
                RangePush(f"ifft iter {i}")
                global_array = ifft_chrono.chrono_fun(do_ifft, global_array)
                RangePop()
    cont_str = "-cont" if contiguous else "-noncont"

    if not trace:
        fft_metadata = {
            'function': f"FFT{cont_str}",
            'precision': precision,
            'x': str(global_shape[0]),
            'px': str(pdims[0]),
            'py': str(pdims[1]),
            'backend': backend,
            'nodes': str(nb_nodes),
        }
        ifft_metadata = {
            'function': f"IFFT{cont_str}",
            'precision': precision,
            'x': str(global_shape[0]),
            'px': str(pdims[0]),
            'py': str(pdims[1]),
            'backend': backend,
            'nodes': str(nb_nodes),
        }

        fft_chrono.report(f"{output_path}/jaxfft.csv", **fft_metadata)
        ifft_chrono.report(f"{output_path}/jaxfft.csv", **ifft_metadata)

    print("Done")
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='NBody Benchmark')
    parser.add_argument('-p', '--pdims', type=str, help='GPU grid', required=True)
    parser.add_argument('-l', '--local_shape', type=int, help='Local shape', default=None)
    parser.add_argument('-g', '--global_shape', type=int, help='Global shape of the array', default=None)
    parser.add_argument('-b', '--backend', type=str, help='Backend to use for transpose comm',
                        choices=["NCCL", "NCCL_PL", "MPI_P2P", "MPI"], default="NCCL")
    parser.add_argument('-n', '--nb_nodes', type=int, help='Number of nodes', default=1)
    parser.add_argument('-o', '--output_path', type=str, help='Output path', default=".")
    parser.add_argument('-pr', '--precision', type=str, help='Precision', default="float32")
    parser.add_argument('-i', '--iterations', type=int, help='Number of iterations', default=10)
    # Note: with type=bool, argparse treats any non-empty string as True
    parser.add_argument('-c', '--contiguous', type=bool, help='Contiguous', default=True)
    parser.add_argument('-t', '--trace', action='store_true', help='Profile using tensorboard')

    args = parser.parse_args()
    if args.local_shape is not None and args.global_shape is not None:
        print("Please provide either local_shape or global_shape, not both")
        parser.print_help()
        exit(0)
    if args.local_shape is not None:
        global_shape = (args.local_shape * jax.device_count(),
                        args.local_shape * jax.device_count(),
                        args.local_shape * jax.device_count())
    elif args.global_shape is not None:
        global_shape = (args.global_shape, args.global_shape, args.global_shape)
    else:
        print("Please provide either local_shape or global_shape")
        parser.print_help()
        exit(0)

    if args.precision == "float32":
        jax.config.update("jax_enable_x64", False)
    elif args.precision == "float64":
        jax.config.update("jax_enable_x64", True)
    else:
        print("Precision should be either float32 or float64")
        parser.print_help()
        exit(0)

    pdims = tuple(map(int, args.pdims.split("x")))

    backend = args.backend
    nb_nodes = args.nb_nodes
    output_path = args.output_path
    import os
    os.makedirs(output_path, exist_ok=True)

    for dim in global_shape:
        for pdim in pdims:
            if dim % pdim != 0:
                print(f"Global shape {global_shape} is not divisible by pdims {pdims}")
                exit(0)
                # Do not raise error for slurm jobs
                # raise ValueError(f"Global shape {global_shape} is not divisible by pdims {pdims}")

    run_benchmark(pdims, global_shape, backend, nb_nodes, args.precision,
                  args.iterations, args.trace, args.contiguous, output_path)

    jaxdecomp.finalize()
    jax.distributed.shutdown()
\ No newline at end of file
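The get_diff helper above is defined but never called. As a minimal sketch of how the round trip could be validated, assuming jaxDecomp's pure-JAX backend (the "jax" string selected in the else branch above) also accepts a plain single-device array, one could run:

    import jax
    import jax.numpy as jnp
    import jaxdecomp

    # Small single-device test array; shape chosen arbitrarily for illustration
    x = jax.random.normal(jax.random.PRNGKey(0), (32, 32, 32))

    k_array = jaxdecomp.fft.pfft3d(x, backend="jax")        # forward 3D FFT
    x_back = jaxdecomp.fft.pifft3d(k_array, backend="jax")  # inverse 3D FFT

    # The forward/inverse round trip should recover the input up to round-off;
    # this is the kind of comparison get_diff could perform on the benchmark arrays.
    print(jnp.abs(x_back.real - x).max())

The same check could be dropped into the benchmark after the warm-up step, comparing the array before and after one do_fft/do_ifft pair.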
benchmarks/mpi4jaxfft.py (new file, mode 100644, +251 -0, view file @ 320201ce)
import jax

jax.distributed.initialize()

import argparse
import os
import time

import jax.numpy as jnp
import mpi4jax
import numpy as np
from cupy.cuda.nvtx import RangePop, RangePush
from mpi4py import MPI

# Create communicators
world = MPI.COMM_WORLD
rank = world.Get_rank()
size = world.Get_size()

if rank == 0:
    print("Communication setup done!")


def chrono_fun(fun, *args):
    start = time.perf_counter()
    out = fun(*args).block_until_ready()
    end = time.perf_counter()
    return out, end - start
def fft3d(arr, comms=None):
    """ Computes forward FFT, note that the output is transposed """
    if comms is not None:
        shape = list(arr.shape)
        nx = comms[0].Get_size()
        ny = comms[1].Get_size()

    # First FFT along z
    arr = jnp.fft.fft(arr)  # [x, y, z]
    # Perform single gpu or distributed transpose
    if comms is None:
        arr = arr.transpose([1, 2, 0])
    else:
        arr = arr.reshape(shape[:-1] + [nx, shape[-1] // nx])
        # arr = arr.transpose([2, 1, 3, 0])  # [y, z, x]
        arr = jnp.einsum('ij,xyjz->iyzx', jnp.eye(nx), arr)
        # TODO: remove this hack when we understand why transpose before alltoall doesn't work
        arr, token = mpi4jax.alltoall(arr, comm=comms[0])
        arr = arr.transpose([1, 2, 0, 3]).reshape(shape)  # Now [y, z, x]

    # Second FFT along x
    arr = jnp.fft.fft(arr)
    # Perform single gpu or distributed transpose
    if comms is None:
        arr = arr.transpose([1, 2, 0])
    else:
        arr = arr.reshape(shape[:-1] + [ny, shape[-1] // ny])
        # arr = arr.transpose([2, 1, 3, 0])  # [z, x, y]
        arr = jnp.einsum('ij,yzjx->izxy', jnp.eye(ny), arr)
        # TODO: remove this hack when we understand why transpose before alltoall doesn't work
        arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
        arr = arr.transpose([1, 2, 0, 3]).reshape(shape)  # Now [z, x, y]

    # Third FFT along y
    return jnp.fft.fft(arr)


def ifft3d(arr, comms=None):
    """ Let's assume that the data is distributed across x """
    if comms is not None:
        shape = list(arr.shape)
        nx = comms[0].Get_size()
        ny = comms[1].Get_size()

    # First IFFT along y
    arr = jnp.fft.ifft(arr)  # Now [z, x, y]
    # Perform single gpu or distributed transpose
    if comms is None:
        arr = arr.transpose([0, 2, 1])
    else:
        arr = arr.reshape(shape[:-1] + [ny, shape[-1] // ny])
        # arr = arr.transpose([2, 0, 3, 1])  # Now [z, y, x]
        arr = jnp.einsum('ij,zxjy->izyx', jnp.eye(ny), arr)
        # TODO: remove this hack when we understand why transpose before alltoall doesn't work
        arr, token = mpi4jax.alltoall(arr, comm=comms[1])
        arr = arr.transpose([1, 2, 0, 3]).reshape(shape)  # Now [z, y, x]

    # Second IFFT along x
    arr = jnp.fft.ifft(arr)
    # Perform single gpu or distributed transpose
    if comms is None:
        arr = arr.transpose([2, 1, 0])
    else:
        arr = arr.reshape(shape[:-1] + [nx, shape[-1] // nx])
        # arr = arr.transpose([2, 3, 1, 0])  # now [x, y, z]
        arr = jnp.einsum('ij,zyjx->ixyz', jnp.eye(nx), arr)
        # TODO: remove this hack when we understand why transpose before alltoall doesn't work
        arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
        arr = arr.transpose([1, 2, 0, 3]).reshape(shape)  # Now [x, y, z]

    # Third IFFT along z
    return jnp.fft.ifft(arr)
def normal(key, shape, comms=None):
    """ Generates a normal variable for the given global shape. """
    if comms is None:
        return jax.random.normal(key, shape)

    nx = comms[0].Get_size()
    ny = comms[1].Get_size()

    print(shape)
    return jax.random.normal(key, [shape[0] // nx, shape[1] // ny] + list(shape[2:]))
def run_benchmark(global_shape, nb_nodes, pdims, precision, iterations, output_path):
    """ Run the benchmark """
    cart_comm = MPI.COMM_WORLD.Create_cart(dims=list(pdims), periods=[True, True])
    comms = [cart_comm.Sub([True, False]), cart_comm.Sub([False, True])]

    backend = "MPI4JAX"
    # Setup random keys
    master_key = jax.random.PRNGKey(42)
    key = jax.random.split(master_key, size)[rank]

    # Size of the FFT
    N = 256
    mesh_shape = list(global_shape)

    # Generate a random gaussian variable for the global mesh shape
    global_array = normal(key, mesh_shape, comms=comms)

    if jax.process_index() == 0:
        print(f"Devices {jax.devices()}")
        print(f"Global dims {global_shape}, pdims ({comms[0].Get_size()}, {comms[1].Get_size()}), "
              f"Backend {backend} original_array shape {global_array.shape}")

    @jax.jit
    def do_fft(x):
        return fft3d(x, comms=comms)

    @jax.jit
    def do_ifft(x):
        return ifft3d(x, comms=comms)

    jit_fft_time = 0
    jit_ifft_time = 0
    ffts_times = []
    iffts_times = []
    # Warm start
    RangePush("warmup")
    global_array, jit_fft_time = chrono_fun(do_fft, global_array)
    global_array, jit_ifft_time = chrono_fun(do_ifft, global_array)
    RangePop()
    MPI.COMM_WORLD.Barrier()

    for i in range(iterations):
        RangePush(f"fft iter {i}")
        global_array, fft_time = chrono_fun(do_fft, global_array)
        RangePop()
        ffts_times.append(fft_time)
        RangePush(f"ifft iter {i}")
        global_array, ifft_time = chrono_fun(do_ifft, global_array)
        RangePop()
        iffts_times.append(ifft_time)

    ffts_times = np.array(ffts_times) * 1e3
    iffts_times = np.array(iffts_times) * 1e3

    # FFT
    jit_fft_time *= 1e3
    fft_min_time = np.min(ffts_times)
    fft_max_time = np.max(ffts_times)
    fft_mean_time = jnp.mean(ffts_times)
    fft_std_time = jnp.std(ffts_times)
    last_fft_time = ffts_times[-1]

    # IFFT
    jit_ifft_time *= 1e3
    ifft_min_time = np.min(iffts_times)
    ifft_max_time = np.max(iffts_times)
    ifft_mean_time = jnp.mean(iffts_times)
    ifft_std_time = jnp.std(iffts_times)
    last_ifft_time = iffts_times[-1]

    # RANK TYPE PRECISION SIZE PDIMS BACKEND NB_NODES JIT_TIME MIN MAX MEAN STD
    with open(f"{output_path}/mpi4jaxfft.csv", 'a') as f:
        f.write(
            f"{jax.process_index()},FFT,{precision},{global_shape[0]},{global_shape[1]},{global_shape[2]},"
            f"{pdims[0]},{pdims[1]},{backend},{nb_nodes},{jit_fft_time:.4f},{fft_min_time:.4f},"
            f"{fft_max_time:.4f},{fft_mean_time:.4f},{fft_std_time:.4f},{last_fft_time:.4f}\n")
        f.write(
            f"{jax.process_index()},IFFT,{precision},{global_shape[0]},{global_shape[1]},{global_shape[2]},"
            f"{pdims[0]},{pdims[1]},{backend},{nb_nodes},{jit_ifft_time:.4f},{ifft_min_time:.4f},"
            f"{ifft_max_time:.4f},{ifft_mean_time:.4f},{ifft_std_time:.4f},{last_ifft_time:.4f}\n")

    print("Done")
    print("FFT times")
    print(f"JIT time {jit_fft_time:.4f} ms")
    for i in range(iterations):
        print(f"FFT {i} time {ffts_times[i]:.4f} ms")
    print("IFFT times")
    print(f"JIT time {jit_ifft_time:.4f} ms")
    for i in range(iterations):
        print(f"IFFT {i} time {iffts_times[i]:.4f} ms")
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='NBody MPI4JAX Benchmark')
    parser.add_argument('-g', '--global_shape', type=int, help='Global shape', default=None)
    parser.add_argument('-l', '--local_shape', type=int, help='Local shape', default=None)
    parser.add_argument('-n', '--nb_nodes', type=int, help='Number of nodes', default=1)
    parser.add_argument('-p', '--pdims', type=str, help='GPU grid', required=True)
    parser.add_argument('-o', '--output_path', type=str, help='Output path', default=".")
    parser.add_argument('-pr', '--precision', type=str, help='Precision', default="float32")
    parser.add_argument('-i', '--iterations', type=int, help='Number of iterations', default=10)

    args = parser.parse_args()

    if args.local_shape is not None:
        global_shape = (args.local_shape * jax.device_count(),
                        args.local_shape * jax.device_count(),
                        args.local_shape * jax.device_count())
    elif args.global_shape is not None:
        global_shape = (args.global_shape, args.global_shape, args.global_shape)
    else:
        print("Please provide either local_shape or global_shape")
        parser.print_help()
        exit(0)

    if args.precision == "float32":
        jax.config.update("jax_enable_x64", False)
    elif args.precision == "float64":
        jax.config.update("jax_enable_x64", True)
    else:
        print("Precision should be either float32 or float64")
        parser.print_help()
        exit(0)

    nb_nodes = args.nb_nodes
    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)

    pdims = [int(x) for x in args.pdims.split("x")]

    run_benchmark(global_shape, nb_nodes, pdims, args.precision, args.iterations, output_path)
\ No newline at end of file
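For reference, when comms is None the fft3d above reduces to three single-axis FFTs separated by plain axis rotations; the distributed branch implements each rotation as a reshape, the einsum workaround noted in the TODOs, an mpi4jax.alltoall, and a final transpose/reshape. A minimal single-process sketch (an illustration added here, not part of the commit) shows that this composition matches jnp.fft.fftn up to the final [z, x, y] axis ordering:

    import jax
    import jax.numpy as jnp
    import numpy as np

    x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))

    # fft3d(x, comms=None): FFT along the last axis, then rotate axes, three times.
    y = jnp.fft.fft(x)          # FFT along z, layout [x, y, kz]
    y = y.transpose([1, 2, 0])  # layout [y, kz, x]
    y = jnp.fft.fft(y)          # FFT along x, layout [y, kz, kx]
    y = y.transpose([1, 2, 0])  # layout [kz, kx, y]
    y = jnp.fft.fft(y)          # FFT along y, layout [kz, kx, ky]

    # Same values as a dense 3D FFT with axes reordered to [kz, kx, ky]
    ref = jnp.fft.fftn(x).transpose([2, 0, 1])
    print(np.max(np.abs(np.asarray(y - ref))))  # small (float32 round-off)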