
Fix JAX extension build with NVTE_UB_WITH_MPI=1 #2835

Merged

timmoon10 merged 4 commits into NVIDIA:main from GaetanLepage:main
Apr 11, 2026

Conversation

@GaetanLepage
Contributor

Description

When building Transformer Engine with NVTE_UB_WITH_MPI=1 and NVTE_FRAMEWORK=pytorch,jax, the JAX extension (transformer_engine_jax) fails to load at runtime with an undefined symbol error, while the PyTorch extension works fine.

In userbuffers.h, the ExtComm type is conditionally defined based on NVTE_UB_WITH_MPI:

#ifdef NVTE_UB_WITH_MPI
#define ExtComm MPI_Comm
#else
#define ExtComm const char *
#endif

This type flows into ExtAllgatherOp and ExtBarrierOp, which are parameters of the CommOverlapP2PBase constructor.
This means the constructor has a different mangled symbol name depending on whether NVTE_UB_WITH_MPI is defined.
The core library (libtransformer_engine.so) is built via CMake, which correctly sets -DNVTE_UB_WITH_MPI.
The PyTorch extension also adds this flag.
However, the JAX extension is missing this flag entirely.
As a result, transformer_engine_jax.so is compiled expecting the const char * variant of the constructor, while libtransformer_engine.so only exports the MPI_Comm variant, causing an undefined symbol error at import time.
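One way to confirm this kind of mismatch (a hedged diagnostic sketch, not part of the PR) is to compare which mangled constructor variants each shared object actually exports. The helper below only builds an nm/grep command string; the library path and symbol name are illustrative assumptions.

```python
# Hedged diagnostic sketch (not from this PR): build an `nm` pipeline that
# lists the demangled dynamic symbols containing a given name, so the
# exported CommOverlapP2PBase constructor variants of two shared objects
# can be compared by eye. Paths and the symbol name are illustrative.
import shlex


def symbol_check_cmd(so_path: str, symbol: str = "CommOverlapP2PBase") -> str:
    """Return a shell command that lists demangled dynamic symbols
    containing `symbol` in the given shared object."""
    return f"nm -DC {shlex.quote(so_path)} | grep {shlex.quote(symbol)}"


print(symbol_check_cmd("libtransformer_engine.so"))
print(symbol_check_cmd("transformer_engine_jax.so"))
```

If the JAX extension references a `char const*` overload of the constructor while the core library only exports an `MPI_Comm` overload, the two listings disagree, matching the undefined symbol error described above.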

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR adds the MPI include path and -DNVTE_UB_WITH_MPI compile definition to the JAX extension build, mirroring the existing handling in build_tools/pytorch.py.
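The shared helper could look roughly like the following. This is a sketch based on the PR description and review summary; the exact signature and environment handling in build_tools/utils.py may differ.

```python
# Hedged sketch of the shared MPI flag setup described in this PR; the real
# implementation lives in build_tools/utils.py and may differ in detail.
import os
from pathlib import Path


def setup_mpi_flags(include_dirs: list, cxx_flags: list) -> None:
    """When NVTE_UB_WITH_MPI is enabled, add the MPI include path and the
    -DNVTE_UB_WITH_MPI definition so the extension is compiled against the
    MPI_Comm-based CommOverlapP2PBase constructor signature."""
    if int(os.getenv("NVTE_UB_WITH_MPI", "0")):
        mpi_home = os.getenv("MPI_HOME")
        # Truthiness check guards against both an unset and an empty MPI_HOME.
        assert mpi_home, (
            "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
        )
        include_dirs.append(Path(mpi_home) / "include")
        cxx_flags.append("-DNVTE_UB_WITH_MPI")
```

Both the JAX and PyTorch extension builders would then call setup_mpi_flags(include_dirs, cxx_flags) before constructing the extension, so the two builds stay in sync.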

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Apr 4, 2026

Greptile Summary

This PR fixes a runtime undefined symbol error when building the JAX extension with NVTE_UB_WITH_MPI=1. The root cause is that transformer_engine_jax.so was compiled without -DNVTE_UB_WITH_MPI, causing a C++ symbol mangling mismatch against libtransformer_engine.so. The fix extracts the MPI flag setup into a shared setup_mpi_flags helper in utils.py and calls it from both the JAX and PyTorch extension builders.

Confidence Score: 5/5

Safe to merge — targeted bug fix with correct implementation and no regressions introduced.

The fix is minimal and correct: it adds the missing -DNVTE_UB_WITH_MPI compile flag to the JAX extension builder, directly resolving the symbol mangling mismatch. The refactoring into setup_mpi_flags is clean and functionally identical to the original inline code in pytorch.py. The only previously noted concern (empty MPI_HOME string) was already flagged in a prior review thread. No new P0/P1 issues exist.

No files require special attention.

Important Files Changed

  • build_tools/jax.py: core bug fix; adds the setup_mpi_flags(include_dirs, cxx_flags) call so the JAX extension is compiled with -DNVTE_UB_WITH_MPI when the flag is set, matching the PyTorch extension behavior.
  • build_tools/utils.py: new setup_mpi_flags helper that extracts the MPI include/flag setup from pytorch.py into a shared utility; logic is unchanged from the original inline version.
  • build_tools/pytorch.py: replaces the inline MPI setup block with a call to the new shared setup_mpi_flags helper; behavior is functionally identical.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Build with NVTE_UB_WITH_MPI=1] --> B{Extension type}
    B -->|JAX| C[jax.py: setup_jax_extension]
    B -->|PyTorch| D[pytorch.py: setup_pytorch_extension]
    C --> E[setup_mpi_flags - utils.py]
    D --> E
    E --> F{NVTE_UB_WITH_MPI set?}
    F -->|No| G[No MPI flags added]
    F -->|Yes| H{MPI_HOME set?}
    H -->|No| I[Assert error: MPI_HOME required]
    H -->|Yes| J[Add MPI include dir]
    J --> K[Add -DNVTE_UB_WITH_MPI flag]
    K --> L[Extension compiled with matching symbol mangling]
    L --> M[libtransformer_engine.so symbols resolved at runtime]

Reviews (4): Last reviewed commit: "Merge branch 'main' into main"

Comment on lines +105 to +106
assert (
os.getenv("MPI_HOME") is not None
Contributor

P2 Empty MPI_HOME string bypasses the guard

os.getenv("MPI_HOME") returns None only when the variable is unset. If a user exports MPI_HOME="" (empty string), the assert passes (empty string is not None), and Path("") silently resolves to the current working directory — not a valid MPI installation — causing confusing compile errors downstream.

Consider checking for a non-empty value:

Suggested change

-assert (
-    os.getenv("MPI_HOME") is not None
+mpi_home = os.getenv("MPI_HOME")
+assert mpi_home, (
+    "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
+)
+mpi_path = Path(mpi_home)
This also avoids calling os.getenv("MPI_HOME") twice (once in the assert, once for Path(...)). Note: the same pattern exists in build_tools/pytorch.py line 71–74.

Signed-off-by: Gaetan Lepage <gaetan@glepage.com>
Collaborator

@timmoon10 timmoon10 left a comment

LGTM, pending CI

@timmoon10
Collaborator

/te-ci L1

@timmoon10 timmoon10 merged commit 2dd31bb into NVIDIA:main Apr 11, 2026
10 of 14 checks passed