Fix batch invariant mode: use NCCL tree-based all-reduce #2994
Conversation
Align torchtitan's batch-invariant NCCL settings with vLLM's (vllm/model_executor/layers/batch_invariant.py) to achieve bitwise identity between trainer and generator. Key changes:
- NCCL_ALGO: Ring -> allreduce:tree
- Add NCCL_LAUNCH_MODE=GROUP, NCCL_P2P_NET_DISABLE=1, NCCL_NTHREADS=1, NCCL_SOCKET_NTHREADS=1
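For reference, the resulting environment configuration amounts to roughly the sketch below. The dict and function names are illustrative, not torchtitan's actual code; the values simply mirror the list above plus the channel/protocol settings already present in the diff, and they must be set before the NCCL communicator is created.

```python
import os

# Illustrative only: names are not torchtitan's; values mirror the PR description
# and the existing channel/protocol settings. Must be set before NCCL init.
_BATCH_INVARIANT_NCCL_ENV = {
    "NCCL_ALGO": "allreduce:tree",   # was "Ring"; Tree for cross-node determinism
    "NCCL_LAUNCH_MODE": "GROUP",
    "NCCL_P2P_NET_DISABLE": "1",
    "NCCL_NTHREADS": "1",
    "NCCL_SOCKET_NTHREADS": "1",
    "NCCL_MIN_NCHANNELS": "1",       # single channel to avoid split interleaving
    "NCCL_MAX_NCHANNELS": "1",
    "NCCL_PROTO": "Simple",          # LL/LL128 may reorder reductions
}

def _set_batch_invariant_nccl_env() -> None:
    for key, value in _BATCH_INVARIANT_NCCL_ENV.items():
        os.environ[key] = value
```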
tianyu-l
left a comment
shall we error out if SP + batch invariance are used together?
SP uses reduce-scatter, which only supports the Ring algorithm in NCCL. Unlike allreduce (pinned to Tree for cross-node determinism), reduce-scatter's Ring has not been validated for cross-node bitwise determinism. We disable SP in the batch-invariant config and error out if both are enabled together.
| os.environ["NCCL_MIN_NCHANNELS"] = "1" # Single channel to avoid split interleaving | ||
| os.environ["NCCL_MAX_NCHANNELS"] = "1" | ||
| os.environ["NCCL_PROTO"] = "Simple" # LL/LL128 may reorder reductions | ||
| # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/batch_invariant.py |
For batch invariance, do we need to set os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"? https://github.com/vllm-project/vllm/blob/219bb5b8c0dcc6a5d5f894e9168fa5b8c2f8255a/vllm/model_executor/layers/batch_invariant.py#L1031
No, that controls vLLM's all-reduce kernel, which we are not relying on.
When enabled, batch-invariant mode:
- Replaces `mm`, `addmm`, `log_softmax`, and `mean.dim` with Triton kernels that use a fixed tile iteration order (via [batch_invariant_ops](https://github.com/thinking-machines-lab/batch_invariant_ops))
- Forces deterministic NCCL collectives (single channel, Simple protocol, Tree allreduce) matching vLLM's settings (previously: Ring all-reduce with a single channel)
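As a quick illustration of what the first bullet buys you, the check below compares a row computed as part of a batch against the same row computed alone; with the batch-invariant kernels active the two should match bit-for-bit. This is a minimal sketch assuming the `set_batch_invariant_mode` context manager shown in the batch_invariant_ops README; it is not code from this PR.

```python
import torch
from batch_invariant_ops import set_batch_invariant_mode  # assumed API from the repo's README

# 16-row batch vs. a single row: with batch-invariant kernels, row 0 must be
# bitwise identical in both cases.
a = torch.linspace(-1, 1, 16 * 4096, device="cuda").reshape(16, 4096)
b = torch.linspace(-1, 1, 4096 * 4096, device="cuda").reshape(4096, 4096)

with set_batch_invariant_mode(True):
    full = torch.mm(a, b)[:1]     # row 0 computed inside the full batch
    single = torch.mm(a[:1], b)   # row 0 computed on its own
print("bitwise identical:", torch.equal(full, single))
```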
Curious: why did Ring previously appear deterministic as well?
Because previously we only tested TP=2 with a smaller sequence length, within a single node. For larger-scale tests, e.g. cross-node tests, the Tree algorithm is deterministic to the best of my knowledge.
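A minimal cross-node repeatability check along these lines could look as follows. This is a sketch, not part of the PR: the script layout and reference-file path are made up, and it assumes it is launched twice via torchrun with the same world size and topology.

```python
# Run this script twice under torchrun with identical topology: with
# NCCL_ALGO=allreduce:tree the saved tensors should match bit-for-bit across
# runs; the previous Ring setting was not validated for this cross-node.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("NCCL_ALGO", "allreduce:tree")
dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

torch.manual_seed(1234 + rank)            # same per-rank input on every run
x = torch.randn(1 << 22, device="cuda")
dist.all_reduce(x)                        # uses the NCCL_ALGO set above

ref_path = f"allreduce_ref_rank{rank}.pt"  # hypothetical reference file
if os.path.exists(ref_path):
    ref = torch.load(ref_path).cuda()
    print(f"rank {rank}: bitwise identical = {torch.equal(x, ref)}")
else:
    torch.save(x.cpu(), ref_path)

dist.destroy_process_group()
```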
if self.trainer.parallelism.enable_sequence_parallel:
    raise ValueError(
        "batch_invariant mode doesn't support SP now. "
        "SP uses reduce-scatter which only supports Ring in NCCL "
Doesn't FSDP also need reduce-scatter?
For the bit-wise identity check, we compare forward results between the trainer and the generator. When FSDP is enabled, the reduce-scatter is applied during backward to sync gradients, which won't affect the forward results.
Batch-invariant mode is about forward only; nothing backward is required:
- inference is forward only
- there is no hope of achieving batch invariance for backward, because the batch dim will be reduced too
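Concretely, the identity check described above boils down to something like this sketch (function and variable names are illustrative, not torchtitan's actual APIs):

```python
import torch

@torch.no_grad()
def check_bitwise_identity(trainer_model, generator_model, input_ids):
    # Forward only on both sides; FSDP's reduce-scatter runs in backward and
    # therefore never enters this comparison.
    trainer_logits = trainer_model(input_ids)
    generator_logits = generator_model(input_ids)
    # torch.equal is exact equality: every element must match bit-for-bit.
    return torch.equal(trainer_logits, generator_logits)
```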
As titled. Previously we used the Ring-based algorithm, which is not deterministic.