Skip to content

Commit 88c6bea

Browse files
committed
Fix JAX extension build with NVTE_UB_WITH_MPI=1
Signed-off-by: Gaetan Lepage <gaetan@glepage.com>
1 parent e83c097 commit 88c6bea

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

build_tools/jax.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# See LICENSE for license information.
44

55
"""JAX related extensions."""
6+
67
import os
78
from pathlib import Path
89
from packaging import version
@@ -100,6 +101,14 @@ def setup_jax_extension(
100101
else:
101102
cxx_flags.append("-g0")
102103

104+
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
105+
assert os.getenv("MPI_HOME") is not None, (
106+
"MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
107+
)
108+
mpi_path = Path(os.getenv("MPI_HOME"))
109+
include_dirs.append(mpi_path / "include")
110+
cxx_flags.append("-DNVTE_UB_WITH_MPI")
111+
103112
# Define TE/JAX as a Pybind11Extension
104113
from pybind11.setup_helpers import Pybind11Extension
105114

0 commit comments

Comments
 (0)