Compile the function ahead of time in the JAX example #6286
rostan-t wants to merge 1 commit into NVIDIA:main
Greptile Summary
This PR fixes a CUDA graph capture conflict in the JAX training example by adding a warmup cell that forces […]
Confidence Score: 5/5
Safe to merge: the fix is correct and all remaining findings are P2 or lower. The warmup approach is sound: […] No files require special attention.
| Filename | Overview |
|---|---|
| docs/examples/frameworks/jax/jax-basic_example.ipynb | Adds a pre-training warmup cell that JIT-compiles update with dummy inputs before DALI iterators produce real data, preventing CUDA graph capture from conflicting with DALI's background cudaDeviceSynchronize calls. |
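The warmup described in the table row above can be sketched roughly as follows. This is a minimal illustration, not the notebook's code: the `update` body, the parameter layout, the shapes, and the `trace_count` counter are all hypothetical stand-ins. The only point it shows is that calling a `jax.jit` function once with dummy inputs of the right shapes and dtypes triggers tracing and compilation, so a later call with real data reuses the cached executable.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the Python function


@jax.jit
def update(params, images, labels):
    """Hypothetical stand-in for the notebook's `update`: one SGD step on a linear model."""
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls

    def loss_fn(p):
        logits = images @ p["w"] + p["b"]
        one_hot = jax.nn.one_hot(labels, logits.shape[-1])
        return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1))

    grads = jax.grad(loss_fn)(params)
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)


# Assumed shapes, chosen only for illustration.
batch_size, num_features, num_classes = 8, 16, 10
params = {
    "w": jnp.zeros((num_features, num_classes), dtype=jnp.float32),
    "b": jnp.zeros((num_classes,), dtype=jnp.float32),
}

# Warmup: dummy inputs must match the shapes and dtypes of the real batches,
# otherwise the later call would be traced and compiled again.
dummy_images = jnp.zeros((batch_size, num_features), dtype=jnp.float32)
dummy_labels = jnp.zeros((batch_size,), dtype=jnp.int32)
jax.block_until_ready(update(params, dummy_images, dummy_labels))
traces_after_warmup = trace_count

# A later call with "real" data of the same shapes reuses the compiled function.
real_images = jnp.ones((batch_size, num_features), dtype=jnp.float32)
real_labels = jnp.arange(batch_size, dtype=jnp.int32) % num_classes
params = update(params, real_images, real_labels)
```

The warmup result is discarded on purpose, so the model state that enters the training loop is unaffected by the dummy batch.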
Sequence Diagram
sequenceDiagram
participant NB as Notebook
participant DALI as DALI Background Thread
participant JAX as JAX/XLA
NB->>DALI: Create training_iterator & validation_iterator
Note over DALI: Prefetch thread starts (may issue cudaDeviceSynchronize)
rect rgb(200, 255, 200)
Note over NB,JAX: Warmup cell (NEW)
NB->>JAX: update(model, dummy_inputs) — first call
JAX->>JAX: JIT trace + XLA compile
Note over JAX: CUDA graph capture happens here<br/>(before DALI produces real batches)
JAX-->>NB: compiled & cached
end
rect rgb(220, 220, 255)
Note over NB,DALI: Training loop
loop each epoch
NB->>DALI: next(training_iterator)
DALI-->>NB: batch
NB->>JAX: update(model, batch)
Note over JAX: Already compiled — no graph capture
JAX-->>NB: updated model
end
end
Reviews (2): Last reviewed commit: "Compile the function ahead of time in th..."
Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
1f9dbc5 to be34d7e
!build
CI MESSAGE: [48206370]: BUILD STARTED
CI MESSAGE: [48206370]: BUILD PASSED
This makes sure that DALI doesn't call cudaDeviceSynchronize during a graph capture
Category:
Bug fix (non-breaking change which fixes an issue)
Description:
Host-device synchronizing CUDA calls are forbidden while a CUDA graph is being captured. This can cause JAX to raise exceptions because:
- DALI may call cudaDeviceSynchronize() from its background thread.
- JAX performs the graph capture when a jax.jit function is first called and compiled, so it happens while DALI is running.

There is no easy fix for this on the DALI side: even if the error in cudaDeviceSynchronize() is handled, the call invalidates the capture.

This PR changes the JAX training example to make sure that the function is compiled before DALI starts running.
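For readers unfamiliar with explicit ahead-of-time compilation in JAX, here is a minimal sketch using `jax.jit(...).lower(...).compile()` with abstract `jax.ShapeDtypeStruct` inputs. The `update` function and the shapes are hypothetical, and this is not necessarily what the notebook does. It also does not establish on its own that CUDA graph capture is moved earlier; if the capture happens on first execution, a warmup call with dummy data may still be needed.

```python
import jax
import jax.numpy as jnp


def update(weights, batch):
    """Hypothetical training step: one gradient step on a mean-squared activation."""

    def loss(w):
        return jnp.mean((batch @ w) ** 2)

    return weights - 0.01 * jax.grad(loss)(weights)


# Abstract descriptions of the inputs: only shape and dtype, no real data needed.
weights_spec = jax.ShapeDtypeStruct((16, 4), jnp.float32)
batch_spec = jax.ShapeDtypeStruct((8, 16), jnp.float32)

# Trace, lower, and compile before any data pipeline has to produce a batch.
compiled_update = jax.jit(update).lower(weights_spec, batch_spec).compile()

# Later, the compiled executable is called with arrays matching the specs.
weights = jnp.ones((16, 4), dtype=jnp.float32)
batch = jnp.ones((8, 16), dtype=jnp.float32)
new_weights = compiled_update(weights, batch)
```

Unlike a plain `jax.jit` function, the compiled object accepts only inputs that match the shapes and dtypes it was lowered for, which makes an accidental recompilation inside the training loop impossible.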
Additional information:
Affected modules and functionalities:
JAX training example
Key points relevant for the review:
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: N/A