We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e83c097 commit 88c6beaCopy full SHA for 88c6bea
build_tools/jax.py
@@ -3,6 +3,7 @@
3
# See LICENSE for license information.
4
5
"""JAX related extensions."""
6
+
7
import os
8
from pathlib import Path
9
from packaging import version
@@ -100,6 +101,14 @@ def setup_jax_extension(
100
101
else:
102
cxx_flags.append("-g0")
103
104
+ if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
105
+ assert os.getenv("MPI_HOME") is not None, (
106
+ "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
107
+ )
108
+ mpi_path = Path(os.getenv("MPI_HOME"))
109
+ include_dirs.append(mpi_path / "include")
110
+ cxx_flags.append("-DNVTE_UB_WITH_MPI")
111
112
# Define TE/JAX as a Pybind11Extension
113
from pybind11.setup_helpers import Pybind11Extension
114
0 commit comments