
Compile the function ahead of time in the JAX example#6286

Open
rostan-t wants to merge 1 commit into NVIDIA:main from rostan-t:jax-example-fix-graph-capture

Conversation

@rostan-t
Collaborator

@rostan-t rostan-t commented Apr 9, 2026

This makes sure that DALI doesn't call cudaDeviceSynchronize() during a graph capture.

Category:

Bug fix (non-breaking change which fixes an issue)

Description:

Host-device synchronizing CUDA calls are forbidden while a CUDA graph is being captured. This can cause JAX to raise exceptions because:

  1. DALI has background threads that call cudaDeviceSynchronize()
  2. XLA captures CUDA graphs when compiling
  3. With jax.jit, compilation is lazy, so it happens while DALI is already running

There is no easy fix for this on the DALI side: even if the error returned by cudaDeviceSynchronize() is handled, the call still invalidates the capture.

This PR changes the JAX training example to make sure that the jitted function is compiled before DALI starts running.
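The warmup pattern can be sketched as follows. This is a minimal illustration, not the notebook's actual code: `update` here is a hypothetical stand-in for the example's jitted training step, and the dummy shapes are placeholders for the real batch shape/dtype.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the example's @jit-decorated training step.
@jax.jit
def update(params, batch):
    # Illustrative gradient-descent-like step on a scalar parameter.
    return params - 0.1 * jnp.mean(batch)

# Warmup: call once with dummy inputs of the same shape/dtype as the
# real data, *before* the DALI iterators start producing batches.
# XLA traces and compiles here (and, on GPU, captures its CUDA graphs),
# while no DALI background thread is issuing cudaDeviceSynchronize().
dummy_params = jnp.zeros(())
dummy_batch = jnp.zeros((8, 4), dtype=jnp.float32)
update(dummy_params, dummy_batch).block_until_ready()
```

Because jax.jit caches by input shape and dtype, later calls with real batches of the same signature reuse the compiled executable and trigger no further graph capture.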

Additional information:

Affected modules and functionalities:

JAX training example

Key points relevant for the review:

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: N/A

@review-notebook-app

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.



@greptile-apps
Contributor

greptile-apps bot commented Apr 9, 2026

Greptile Summary

This PR fixes a CUDA graph capture conflict in the JAX training example by adding a warmup cell that forces update (decorated with @jit) to compile before any DALI workload begins. The root cause is that XLA captures CUDA graphs during the first @jit invocation; previously this happened lazily during training, while DALI's background thread was already issuing cudaDeviceSynchronize, which is forbidden during graph capture. The fix is correct: update is the only @jit-decorated function in the training path, and accuracy/predict are not @jit, so no CUDA graph capture occurs during validation.

Confidence Score: 5/5

Safe to merge — the fix is correct and all remaining findings are P2 or lower.

The warmup approach is sound: update is the only @jit-decorated function in the training path, so warming it up with dummy inputs of the correct shape/dtype is sufficient to prevent CUDA graph capture from overlapping with DALI. No P0/P1 issues remain; prior review thread concerns have already been raised with the author.

No files require special attention.
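For comparison, compilation can also be forced explicitly rather than as a side effect of a first call, using JAX's ahead-of-time API (`.lower(...).compile()` on a jitted function). This is a sketch under assumed shapes; `step` is a hypothetical placeholder, not code from the notebook.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a jitted training step.
@jax.jit
def step(x):
    return jnp.sin(x) * 2.0

# Lower and compile ahead of time from abstract shapes/dtypes only;
# no real data is needed at this point.
x_spec = jax.ShapeDtypeStruct((4,), jnp.float32)
compiled = step.lower(x_spec).compile()

# `compiled` is callable like the original function. On GPU, XLA's
# CUDA graphs are captured during compile(), before any DALI
# background thread starts issuing cudaDeviceSynchronize().
y = compiled(jnp.zeros((4,), jnp.float32))
```

Either approach works for the example; calling the jitted function once with dummy inputs is simply less code in a notebook cell.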

Vulnerabilities

No security concerns identified.

Important Files Changed

| Filename | Overview |
| --- | --- |
| docs/examples/frameworks/jax/jax-basic_example.ipynb | Adds a pre-training warmup cell that JIT-compiles update with dummy inputs before the DALI iterators produce real data, preventing CUDA graph capture from conflicting with DALI's background cudaDeviceSynchronize calls. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant NB as Notebook
    participant DALI as DALI Background Thread
    participant JAX as JAX/XLA

    NB->>DALI: Create training_iterator & validation_iterator
    Note over DALI: Prefetch thread starts (may issue cudaDeviceSynchronize)

    rect rgb(200, 255, 200)
        Note over NB,JAX: Warmup cell (NEW)
        NB->>JAX: update(model, dummy_inputs) — first call
        JAX->>JAX: JIT trace + XLA compile
        Note over JAX: CUDA graph capture happens here<br/>(before DALI produces real batches)
        JAX-->>NB: compiled & cached
    end

    rect rgb(220, 220, 255)
        Note over NB,DALI: Training loop
        loop each epoch
            NB->>DALI: next(training_iterator)
            DALI-->>NB: batch
            NB->>JAX: update(model, batch)
            Note over JAX: Already compiled — no graph capture
            JAX-->>NB: updated model
        end
    end
```


Signed-off-by: Rostan Tabet <rtabet@nvidia.com>
@rostan-t rostan-t force-pushed the jax-example-fix-graph-capture branch from 1f9dbc5 to be34d7e on April 9, 2026 12:35

@rostan-t
Collaborator Author

!build

@dali-automaton
Collaborator

CI MESSAGE: [48206370]: BUILD STARTED

@dali-automaton
Collaborator

CI MESSAGE: [48206370]: BUILD PASSED
