
[rl] Remove call to get_model_state_dict in push_model_state_dict #3066

Merged
daniellepintz merged 1 commit into main from fix-weight-sync-regression
Apr 23, 2026

Conversation

@daniellepintz
Contributor

The trainer refactor (#2985) changed push_model_state_dict to call get_model_state_dict() instead of model.state_dict(). While the end result of these calls is the same (both return a sharded state_dict with DTensors), calling get_model_state_dict adds some overhead.

Average push time over 10 steps for Qwen3 0.6B:

- with get_model_state_dict(): 0.0170s
- with model.state_dict(): 0.008s

Revert to the previous approach of calling self.model.state_dict() directly.
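
A minimal before/after sketch of the change. The `push` callable stands in for the TorchStore put and is illustrative only, not the actual Forge helper:

```python
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_model_state_dict


def push_model_state_dict_before(model: nn.Module, push) -> None:
    # Pre-PR (#2985 behavior): the DCP helper also returns the sharded
    # (DTensor) state dict by default, but adds per-call overhead.
    push(get_model_state_dict(model))


def push_model_state_dict_after(model: nn.Module, push) -> None:
    # This PR: read the sharded state dict directly off the module.
    push(model.state_dict())
```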

…del_state_dict

The trainer refactor (#2985) changed push_model_state_dict to call
get_model_state_dict() which unshards FSDP state (all-gather) before
pushing to TorchStore. This added an expensive all-gather and increased
the size of tensors transferred. Revert to passing self.model.state_dict()
directly, which keeps the sharded representation.
@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Apr 23, 2026
Member

@joecummings left a comment


Great catch, thank you! Would you mind opening a GitHub issue as well to add timing tests for our RL work? That would've caught this immediately.
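
For reference, a minimal sketch of such a timing check, assuming a `trainer` object with an argument-free `push_model_state_dict()`; the real Forge API (and whether the call is async) may differ:

```python
import statistics
import time


def average_push_time(trainer, steps: int = 10) -> float:
    """Average wall-clock seconds per weight push over `steps` pushes."""
    durations = []
    for _ in range(steps):
        start = time.perf_counter()
        trainer.push_model_state_dict()  # placeholder call; adapt to the real signature
        durations.append(time.perf_counter() - start)
    return statistics.mean(durations)


# A CI regression guard could then assert against a budget, e.g.:
# assert average_push_time(trainer) < 0.01
```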

@daniellepintz merged commit b21bf09 into main on Apr 23, 2026
10 of 14 checks passed
Review thread on the push_model_state_dict docstring shown in the diff:

> means "skip StorageVolumes and let the destination read directly from the source's GPU memory".
>
> Uses get_model_state_dict() to unshard FSDP state before pushing.
Contributor


Why don't we need to unshard the FSDP state now?

Contributor Author

@daniellepintz May 5, 2026


Actually, even before this PR we weren't unsharding the FSDP state; the comment was wrong. By default get_model_state_dict() returns the sharded state.

We don't need to unshard the FSDP state because TorchStore handles sharded DTensors. Each trainer rank can put its own shard of a tensor to TorchStore, and when the generator pulls, it will handle the resharding if it needs to.
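
A minimal sketch of that idea using plain DTensor primitives (this is not the TorchStore API, and in Forge the sharding comes from FSDP rather than `distribute_tensor`): each trainer rank holds only its local shard, and any unsharding happens on the reader side.

```python
# Run with: torchrun --nproc_per_node=2 dtensor_reshard_sketch.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# "Trainer" side: each rank holds only its row-shard of the weight.
torch.manual_seed(0)
weight = distribute_tensor(torch.randn(8, 4), mesh, placements=[Shard(0)])
print(f"rank {dist.get_rank()} holds local shard of shape {weight.to_local().shape}")

# "Generator" side: unsharding (an all-gather) happens on the reader,
# so the writer never needs to materialize the full tensor.
full = weight.full_tensor()
print(f"rank {dist.get_rank()} reconstructed full tensor of shape {full.shape}")

dist.destroy_process_group()
```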


Labels

ciflow/8gpu, CLA Signed
