
Fix batch invariant mode: using NCCL tree based all-reduce #2994

Merged
wwwjn merged 6 commits into main from fix-batch-invariant on Apr 21, 2026

Conversation

wwwjn (Contributor) commented Apr 16, 2026:

As titled. Previously we used a ring-based algorithm, which is not deterministic.

wwwjn added 2 commits April 15, 2026 15:13
Align torchtitan's batch-invariant NCCL settings with vLLM's
(vllm/model_executor/layers/batch_invariant.py) to achieve bitwise
identity between trainer and generator. Key changes:
- NCCL_ALGO: Ring -> allreduce:tree
- Add NCCL_LAUNCH_MODE=GROUP, NCCL_P2P_NET_DISABLE=1,
  NCCL_NTHREADS=1, NCCL_SOCKET_NTHREADS=1
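
A minimal sketch of how these knobs could be applied from Python; the helper name is hypothetical and this is not the actual torchtitan code. NCCL reads these environment variables when communicators are created, so they need to be set before `torch.distributed` initialization.

```python
import os

def _apply_batch_invariant_nccl_env() -> None:
    # Hypothetical helper; mirrors the settings listed in the commit message.
    os.environ["NCCL_ALGO"] = "allreduce:tree"   # pin all-reduce to the tree algorithm
    os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
    os.environ["NCCL_P2P_NET_DISABLE"] = "1"
    os.environ["NCCL_NTHREADS"] = "1"
    os.environ["NCCL_SOCKET_NTHREADS"] = "1"

# Must run before torch.distributed.init_process_group(backend="nccl", ...)
_apply_batch_invariant_nccl_env()
```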
meta-cla bot added the CLA Signed label on Apr 16, 2026
tianyu-l (Contributor) left a comment:

Shall we error out if SP and batch invariance are used together?

wwwjn added 3 commits April 16, 2026 14:58
SP uses reduce-scatter which only supports Ring algorithm in NCCL.
Unlike allreduce (pinned to tree for cross-node determinism),
reduce-scatter's Ring has not been validated for cross-node
bitwise determinism. Disable SP in the batch-invariant config
and error out if both are enabled together.
os.environ["NCCL_MIN_NCHANNELS"] = "1" # Single channel to avoid split interleaving
os.environ["NCCL_MAX_NCHANNELS"] = "1"
os.environ["NCCL_PROTO"] = "Simple" # LL/LL128 may reorder reductions
# Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py
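
For context on why protocol or algorithm choices can change results at all: floating-point addition is not associative, so two combine orders over the same partial values can differ in the last bits. A small self-contained illustration (not from the PR):

```python
import torch

torch.manual_seed(0)
vals = torch.randn(1024, dtype=torch.float32)

# Strict left-to-right accumulation.
seq_sum = vals[0].clone()
for v in vals[1:]:
    seq_sum += v

# Pairwise (tree-like) combine order over the same values.
pairs = vals.clone()
while pairs.numel() > 1:
    half = pairs.numel() // 2
    pairs = pairs[:half] + pairs[half:]
tree_sum = pairs[0]

# Both orders are "correct" but typically differ bitwise in float32.
print(seq_sum.item(), tree_sum.item(), torch.equal(seq_sum, tree_sum))
```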

Contributor (Author):

No, this is controlling vLLM's all-reduce kernel, which we are not relying on.

When enabled, batch-invariant mode will:
- Replace `mm`, `addmm`, `log_softmax`, and `mean.dim` with Triton kernels that use a fixed tile iteration order (via [batch_invariant_ops](https://github.com/thinking-machines-lab/batch_invariant_ops)); see the sketch below
- Force deterministic NCCL collectives (single channel, Simple protocol, tree all-reduce) matching vLLM's settings, instead of the previous single-channel Ring all-reduce
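
As an illustration of what the doc means by batch invariance (a sketch assuming a single CUDA device; not part of this PR): the result for a given row should be bitwise identical whether that row is computed alone or inside a larger batch. Without the batch-invariant kernels, a plain GPU matmul may tile its reduction differently for different batch sizes and fail this check.

```python
import torch

def row_is_batch_invariant(d: int = 4096, big_batch: int = 64) -> bool:
    torch.manual_seed(0)
    w = torch.randn(d, d, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(big_batch, d, device="cuda", dtype=torch.bfloat16)

    out_batched = x @ w        # row 0 computed as part of a larger batch
    out_single = x[:1] @ w     # the same row computed with batch size 1
    # Exact elementwise comparison, i.e. the bitwise-identity criterion.
    return torch.equal(out_batched[:1], out_single)
```

With the Triton batch-invariant mm/addmm kernels active, a check like this is expected to pass for the overridden ops.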
Contributor:

Curious: why did we previously see Ring as also deterministic?

Contributor (Author):

Because previously we only tested with TP=2 and a smaller sequence length, within a single node. For larger-scale tests, e.g. cross-node tests, the Tree algorithm is deterministic to the best of my knowledge.
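
A sketch of how reduction-order sensitivity of a collective can be observed empirically (assumes an initialized NCCL process group with one CUDA device per rank; this is an illustration, not the test used for this PR). It compares the all-reduce result against a fixed rank-order reference sum; a bitwise mismatch means the collective combined partial values in a different order, which is exactly the kind of difference that breaks trainer/generator bitwise identity.

```python
import torch
import torch.distributed as dist

def allreduce_matches_rank_order_sum(numel: int = 1 << 20) -> bool:
    rank = dist.get_rank()
    world = dist.get_world_size()
    device = torch.device("cuda", torch.cuda.current_device())
    torch.manual_seed(1234 + rank)              # distinct data per rank
    x = torch.randn(numel, device=device)

    # Reference: gather every rank's tensor and sum strictly in rank order.
    gathered = [torch.empty_like(x) for _ in range(world)]
    dist.all_gather(gathered, x)
    ref = gathered[0].clone()
    for t in gathered[1:]:
        ref += t

    out = x.clone()
    dist.all_reduce(out)                        # SUM by default
    return torch.equal(out, ref)                # exact elementwise comparison
```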

if self.trainer.parallelism.enable_sequence_parallel:
    raise ValueError(
        "batch_invariant mode doesn't support SP now. "
        "SP uses reduce-scatter, which only supports Ring in NCCL "
        "and has not been validated for cross-node bitwise determinism."
    )
Member:

Doesn't FSDP also need reduce-scatter?

Contributor (Author):

For the bit-wise identity check, we compare forward results between the trainer and the generator. When FSDP is enabled, the reduce-scatter is applied during the backward pass to sync gradients, which doesn't affect forward results.

Contributor:

Batch-invariant mode is about the forward pass only; nothing in the backward pass is required:

  • inference is forward-only
  • there is no hope of achieving batch invariance for the backward pass, because the batch dim will be reduced too

wwwjn merged commit 4984c14 into main on Apr 21, 2026; 13 of 20 checks passed.

Labels: ciflow/8gpu, CLA Signed
