
[demo] verify fully_shard([norm, head]) and fully_shard([tok_embedding, norm, head]) works with chunked loss #2976

Closed

weifengpy wants to merge 6 commits into pytorch:main from weifengpy:chunked-ce-loss

Conversation

@weifengpy (Contributor) commented Apr 15, 2026

Removes the if/else branching on chunked loss in apply_fsdp: #2937

wwwjn added 6 commits April 10, 2026 15:40
Implements chunked cross-entropy loss that splits the sequence dimension
into N chunks, computing lm_head projection and CE loss per-chunk to avoid
materializing the full [B, L, V] logits tensor at once.

Key components:
- ChunkedCELoss: wraps lm_head + ce_loss with chunked forward/backward
- GradAccumulator: pre-allocated buffer for assembling chunk gradients
- _no_reshard_after_backward: FSDP2 context to avoid N all-gathers
- skip_lm_head kwarg on Decoder.forward() for the detach boundary
- ChunkedCELossFactory: deferred initialization (model not available at build time)
- Trainer integration with dedicated forward_backward_step branch
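
End to end, the flow is roughly the sketch below. This is an illustrative reconstruction from the description above, not the PR's actual `ChunkedCELoss`; `chunked_ce_forward_backward` and the tensor shapes are assumptions. The real implementation additionally runs the loop under the `_no_reshard_after_backward` context so the lm_head parameters are all-gathered once rather than N times.

```python
# Illustrative sketch only -- assumed helper, not the PR's ChunkedCELoss.
import torch
import torch.nn.functional as F


def chunked_ce_forward_backward(h, lm_head_weight, labels, num_chunks):
    """h: [B, L, D] decoder output (requires grad). Computes the mean CE loss
    chunk by chunk so only [B, L/num_chunks, V] logits are alive at a time,
    then drives a single backward through the decoder."""
    grad_h = torch.zeros_like(h)            # plays the role of GradAccumulator
    total_loss = torch.zeros((), device=h.device)
    num_tokens = labels.numel()

    for h_src, y_c, g_c in zip(
        h.chunk(num_chunks, dim=1),
        labels.chunk(num_chunks, dim=1),
        grad_h.chunk(num_chunks, dim=1),
    ):
        h_c = h_src.detach().requires_grad_(True)   # per-chunk detach boundary
        logits = F.linear(h_c, lm_head_weight)      # [B, L/N, V]
        loss_c = F.cross_entropy(
            logits.flatten(0, 1).float(), y_c.flatten(), reduction="sum"
        ) / num_tokens
        loss_c.backward()          # frees this chunk's logits; fills h_c.grad
        g_c.copy_(h_c.grad)        # assemble d(loss)/dh into the buffer
        total_loss += loss_c.detach()

    # One backward through the decoder, driven by the assembled gradient.
    (h * grad_h).sum().backward()
    return total_loss
```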
…CELoss

- Add loss_num_chunks to TrainingConfig (default 1, no-op)
- Trainer auto-wraps loss_fn in ChunkedCELossFactory when loss_num_chunks > 1
- Integration tests for FSDP, FSDP+TP(SP), FSDP+CP, FSDP+TP+CP, FSDP+compile
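
The trainer-side wiring is roughly the following sketch; `loss_num_chunks` and `ChunkedCELossFactory` come from this PR, while the config path and trainer attributes are assumptions for illustration.

```python
# Hedged sketch of the auto-wrap; exact config path is an assumption.
if job_config.training.loss_num_chunks > 1:
    # Deferred initialization: the factory binds to the model's lm_head
    # later, once the model has been built and parallelized.
    self.loss_fn = ChunkedCELossFactory(
        self.loss_fn, num_chunks=job_config.training.loss_num_chunks
    )
```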
FSDP2's backward hooks are one-shot per forward pass. The previous approach
of calling self.lm_head(h_chunk) triggered FSDP2's backward hooks during
chunk backward, leaving no hooks for the decoder backward (h.backward(grad)),
causing zero gradients on model parameters.

Fix: Use F.linear(h_chunk, lm_weight) to bypass FSDP2 module hooks during
chunk computation. Use (h * accumulated_grad).sum().backward() instead of
h.backward(grad) to properly trigger FSDP2's hooks in a single backward pass.
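
In code, the difference looks roughly like this (a paraphrased fragment of the description above, not runnable as-is; `h_chunk`, `lm_weight`, and `accumulated_grad` stand in for the PR's actual variables):

```python
# Broken: the FSDP2-wrapped module's backward hook fires on the first
# chunk's backward and is then spent, so the later decoder backward
# leaves zero gradients on the sharded parameters.
logits = self.lm_head(h_chunk)

# Fixed: use the weight functionally so no module hook fires per chunk...
logits = F.linear(h_chunk, lm_weight)

# ...and trigger FSDP2's one-shot hooks exactly once, in the final
# backward through the decoder, by folding the accumulated gradient in:
(h * accumulated_grad).sum().backward()   # instead of h.backward(grad)
```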
Replace bare function + build_fn pattern with proper loss classes.
CrossEntropyLoss and MSELoss encapsulate compilation logic internally.
The old function names (cross_entropy_loss, mse_loss) remain as public API
for backward compatibility. build_cross_entropy_loss and build_mse_loss
now return class instances.
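
A minimal sketch of the new shape (assuming a boolean compile flag on the training config; not the exact torchtitan code):

```python
import torch
import torch.nn.functional as F


def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Kept as a public function for backward compatibility.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))


class CrossEntropyLoss:
    """Callable loss object; whether to torch.compile is an internal detail."""

    def __init__(self, compile_loss: bool = False):
        self._fn = torch.compile(cross_entropy_loss) if compile_loss else cross_entropy_loss

    def __call__(self, pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return self._fn(pred, labels)


def build_cross_entropy_loss(job_config) -> CrossEntropyLoss:
    # Now returns a class instance rather than a bare (possibly compiled) function.
    # The config attribute below is an assumption for illustration.
    return CrossEntropyLoss(compile_loss=getattr(job_config.training, "compile", False))
```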
meta-cla bot added the CLA Signed label on Apr 15, 2026
weifengpy changed the title from "verify fully_shard([norm, head]) and fully_shard([tok_embedding, norm, head]) works with chunked loss" to "[demo] verify fully_shard([norm, head]) and fully_shard([tok_embedding, norm, head]) works with chunked loss" on Apr 15, 2026
@weifengpy weifengpy closed this May 8, 2026

Labels

ciflow/8gpu, CLA Signed
