fix(FSDP2): auto-wrap policy ignoring _no_split_modules fallback#3999
Open
JohnGiorgi wants to merge 2 commits intohuggingface:mainfrom
Open
fix(FSDP2): auto-wrap policy ignoring _no_split_modules fallback#3999JohnGiorgi wants to merge 2 commits intohuggingface:mainfrom
JohnGiorgi wants to merge 2 commits intohuggingface:mainfrom
Conversation
The `policy()` closure inside `fsdp2_prepare_auto_wrap_policy` checks `fsdp2_plugin.transformer_cls_names_to_wrap is None` to decide whether to wrap a module. When `transformer_cls_names_to_wrap` is not explicitly set (the common case — relying on the model's `_no_split_modules`), this check makes the policy return `False` for every module, even though the local `transformer_cls_to_wrap` set was correctly populated from `_no_split_modules`. The fix changes the check to use the local `transformer_cls_to_wrap` variable instead of the plugin attribute, so that the `_no_split_modules` fallback is respected. Without this fix, the entire model becomes a single FSDP2 unit (only `fully_shard(model)` on the root), causing every GPU to all-gather the full model parameters during forward — leading to OOM on large models. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
SunMarc
approved these changes
Apr 7, 2026
Author
Thanks for taking a look! I think its ready to go |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Fixes a variable scoping bug in the
policy()closure returned byfsdp2_prepare_auto_wrap_policy.The bug
The closure checks
fsdp2_plugin.transformer_cls_names_to_wrap is None(the plugin attribute) instead ofnot transformer_cls_to_wrap(the local set that was just populated from_no_split_modules). When the plugin attribute isNone(which it is whenever the user relies on the model's_no_split_modulesrather than explicitly settingfsdp_transformer_layer_cls_to_wrap) the policy always returnsFalse, regardless of what was resolved into the local set.Why it matters
In the current codebase,
set_auto_wrap_policyhappens to mutatefsdp_plugin.transformer_cls_names_to_wrapbeforefsdp2_prepare_auto_wrap_policyruns, masking this bug through the standardAccelerator.prepare()path. But the function's contract is still broken: it builds a local set, then ignores it. Anyone calling the function directly (e.g. tests, custom training loops) gets wrong results, and relying on mutation from a different method is fragile.Reproduction (no GPU required)
Likely related to #3474.
Before submitting
Pull Request section?
to it if that's the case. Related: FSDP2 - High memory usage with LORA #3474
Who can review?
@S1ro1 @SunMarc