
Improve compilation time (reduce from ~50 seconds to ~15s for vLLM) #3145

Merged
Lucaskabela merged 1 commit into pytorch:main from Lucaskabela:lucaskabela/improve_vllm_comp on May 1, 2026

Conversation

Lucaskabela (Contributor) commented Apr 28, 2026:

Fixes #3119 and #3071

Summary

We make significant improvements to vLLM compilation, saving ~40s overall (~20s from cudagraph capture, ~13s from Dynamo, and ~1s per step at runtime) from the following changes:

  1. Since we are using FA (FlashAttention)-based attention, which is traceable, we can use full-graph cudagraphs rather than piecewise capture (saving ~1s per step).
  2. We adjust the maximum cudagraph capture size based on the defaults from our configs; this cuts cudagraph capture time from ~30s to ~11s.
  3. We move compilation to the same compile pipeline the trainer model uses. While this misses some of vLLM's custom passes, it gives us more observability and control, ensuring a unified model definition. It also lets us leverage regional compile, reducing compile work from O(n) to O(1) (we compile one transformer layer and reuse it for the rest), cutting Dynamo compile time from ~17s to ~4s. A sketch of the regional-compile idea follows this list.
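
A minimal sketch of the regional-compile idea in (3), assuming a toy Block module and plain torch.compile; this is illustrative, not this PR's code:

import torch

class Block(torch.nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.ff(x))

model = torch.nn.Sequential(*[Block() for _ in range(8)])

# Compile each layer individually instead of the whole model. Dynamo caches
# compiled artifacts by code object, so the first Block pays the full compile
# cost and the remaining seven reuse it -> effectively O(1) compiles.
for block in model:
    block.compile()  # in-place torch.compile on the module

out = model(torch.randn(2, 64))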

Test plan

python torchtitan/experiments/rl/grpo.py --module rl --config rl_grpo_qwen3_0_6b

Test results:

Before

INFO 04-28 16:20:20 [backends.py:1128] [actor=<root>.<torchtitan.experiments.rl.actors.generator.VLLMGenerator generator{'gpus': 0/4}>] Dynamo bytecode transform time: 15.26 s
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|███████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:31<00:00,  1.13it/s]
...
[actor=<root>] Step  9 | Loss: +0.0020 | Reward: +0.743 (correctness=+0.450, format=+0.293) | Avg tokens: 100 | Logprob diff: mean=-6.4715e-05, max=2.2314e-01 | Time: 2.7s
[actor=<root>] Post-training validation
[actor=<root>] Summary:
  Pre:  mean_reward=+0.365 (correctness=+0.200, format=+0.165)
  Post: mean_reward=+0.700 (correctness=+0.400, format=+0.300)

After (this PR)

Capturing CUDA graphs (mixed prefill-decode, FULL): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:09<00:00,  1.34s/it]
...
[actor=<root>] Step  9 | Loss: +0.0020 | Reward: +0.743 (correctness=+0.450, format=+0.293) | Avg tokens: 100 | Logprob diff: mean=-6.4715e-05, max=2.2314e-01 | Time: 1.8s
[actor=<root>] Post-training validation
[actor=<root>] Summary:
  Pre:  mean_reward=+0.365 (correctness=+0.200, format=+0.165)
  Post: mean_reward=+0.700 (correctness=+0.400, format=+0.300)

meta-cla bot added the CLA Signed label on Apr 28, 2026
Lucaskabela force-pushed the lucaskabela/improve_vllm_comp branch 3 times, most recently from fadcce5 to e81ea28, on April 28, 2026 23:01
Lucaskabela linked an issue on Apr 28, 2026 that may be closed by this pull request
Lucaskabela force-pushed the lucaskabela/improve_vllm_comp branch 2 times, most recently from 8bac4cf to 671cc0a, on April 29, 2026 00:04
Lucaskabela requested review from acisseJZhong, daniellepintz, tianyu-l, and wwwjn, and removed the request for tianyu-l, on April 29, 2026 00:11
Lucaskabela marked this pull request as ready for review on April 29, 2026 00:11
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_wrapper.py Outdated
    disable_log_stats=True,
)
if config.max_num_seqs is not None:
    engine_kwargs["max_num_seqs"] = config.max_num_seqs
Contributor:

what does this field do?
what if we don't set it here?

Contributor (Author):

Commented above: this controls which cudagraph sizes we capture; not setting it defaults to the current behavior on main.

I can go ahead and make this set by default to avoid any sort of silent slowdown

Contributor:

wait these are two different kwargs -- the other is for vllm cudagraph behavior, what's this additional kwarg for?

Contributor (Author):

max_num_seqs controls other things too, like the padding for the max size (used in the KV cache)
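
For illustration, max_num_seqs is a standard vLLM engine argument bounding how many sequences are scheduled concurrently (a hedged sketch of plain vLLM usage, not this PR's code; the model name is arbitrary):

from vllm import LLM

# max_num_seqs caps concurrent sequences per scheduler step; downstream it
# also bounds padding and cudagraph capture sizes.
llm = LLM(model="Qwen/Qwen3-0.6B", max_num_seqs=32)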

kwargs: dict = dict(cudagraph_mode=self.cudagraph_mode, mode=0)

if max_num_seqs is not None and self.cudagraph_mode != "none":
    kwargs["cudagraph_capture_sizes"] = self._compute_cudagraph_capture_sizes(
Contributor:

what if we don't set it when cudagraph is enabled?

Contributor (Author):

Defaults to 256, which captures ~35 different sizes (ranging from 1 to 256) - so no incorrectness, just more memory and startup time used.
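
A hedged sketch of what a capture-size computation like _compute_cudagraph_capture_sizes might do; the exact policy in this PR may differ, and the function below is an illustrative assumption:

# Assumed policy (illustrative): capture a short power-of-two ladder up to
# max_num_seqs instead of vLLM's default ladder up to 256, so far fewer
# graphs are captured at startup.
def compute_cudagraph_capture_sizes(max_num_seqs: int) -> list[int]:
    sizes = [1, 2, 4]
    size = 8
    while size < max_num_seqs:
        sizes.append(size)
        size *= 2
    sizes.append(max_num_seqs)
    return sorted(set(sizes))

print(compute_cudagraph_capture_sizes(32))  # [1, 2, 4, 8, 16, 32] -> 6 graphs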

Comment on lines +49 to +51
require vLLM's whole-model torch.compile to split the graph around
non-capturable ops, which conflicts with per-layer compile.
See https://docs.vllm.ai/en/latest/design/cuda_graphs/#cudagraphmodes"""
Contributor:

What happens when we enable cudagraph for per-layer compile?

We can save compile time, but what's the impact on run time, e.g. when going to GB200 with significant CPU overhead.

Contributor (Author):

The test plan shows the time impact - we observe speedup over piecewise in this particular setup
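
For reference, a hedged sketch of selecting the cudagraph mode through vLLM's compilation config (field name per the vLLM cuda_graphs design doc linked above; treat the exact spelling and accepted values as assumptions):

from vllm import LLM

# Request FULL (whole-graph) cudagraph capture instead of the PIECEWISE
# default, matching the "Capturing CUDA graphs (..., FULL)" log above.
llm = LLM(
    model="Qwen/Qwen3-0.6B",
    compilation_config={"cudagraph_mode": "FULL"},
)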

tianyu-l (Contributor) commented Apr 29, 2026:

also would like to check if it works with EP, which is being enabled in #3142

MoE has dynamic shape, despite being full graph torch-compilable

Comment on lines +49 to +51 (same docstring hunk quoted above)
Contributor:

Also, it seems we moved compile back to torchtitan, but cudagraph application is still in vLLM. How far are we from moving cudagraph application into torchtitan as well?

Contributor (Author):

I would estimate 2-3 weeks, but I can leave a TODO here that we should unify the cudagraph config once we have it on the trainer side.

Lucaskabela force-pushed the lucaskabela/improve_vllm_comp branch from 671cc0a to f587e8a on April 29, 2026 18:00
Comment thread torchtitan/experiments/rl/actors/trainer.py Outdated
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/actors/generator.py
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_wrapper.py Outdated
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Lucaskabela marked this pull request as draft on April 29, 2026 23:59
Lucaskabela force-pushed the lucaskabela/improve_vllm_comp branch 2 times, most recently from 8509698 to 386d02e, on April 30, 2026 15:52
Lucaskabela marked this pull request as ready for review on April 30, 2026 16:12
Lucaskabela requested a review from tianyu-l on April 30, 2026 16:12
Lucaskabela force-pushed the lucaskabela/improve_vllm_comp branch 2 times, most recently from 411d944 to e671a2e, on April 30, 2026 23:05
tianyu-l (Contributor) left a comment:

nice! one nit

Comment thread torchtitan/experiments/rl/grpo.py Outdated
Lucaskabela force-pushed the lucaskabela/improve_vllm_comp branch from e671a2e to 5a362db on May 1, 2026 18:24
Lucaskabela merged commit 2e5f137 into pytorch:main on May 1, 2026 (7 of 11 checks passed)
Lucaskabela deleted the lucaskabela/improve_vllm_comp branch on May 6, 2026 19:28

Labels

ciflow/8gpu, CLA Signed

Projects

None yet

2 participants