
fix(FSDP2): auto-wrap policy ignoring _no_split_modules fallback#3999

Open
JohnGiorgi wants to merge 2 commits into huggingface:main from JohnGiorgi:fix/fsdp2-auto-wrap-policy

Conversation


@JohnGiorgi JohnGiorgi commented Apr 7, 2026

What does this PR do?

Fixes a variable scoping bug in the policy() closure returned by fsdp2_prepare_auto_wrap_policy.

The bug

The closure checks `fsdp2_plugin.transformer_cls_names_to_wrap is None` (the plugin attribute) instead of `not transformer_cls_to_wrap` (the local set that was just populated from `_no_split_modules`). When the plugin attribute is `None` (which it is whenever the user relies on the model's `_no_split_modules` rather than explicitly setting `fsdp_transformer_layer_cls_to_wrap`), the policy always returns `False`, regardless of what was resolved into the local set.

         def policy(module: torch.nn.Module) -> bool:
-            if fsdp2_plugin.transformer_cls_names_to_wrap is None:
+            if not transformer_cls_to_wrap:
                 return False
             return isinstance(module, tuple(transformer_cls_to_wrap))
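The scoping mistake can be isolated in a minimal, dependency-free sketch (the class and function names below are hypothetical stand-ins for illustration, not the actual accelerate code):

```python
class Plugin:
    # User did not set the attribute explicitly, so it stays None.
    transformer_cls_names_to_wrap = None


def make_policy_buggy(plugin, no_split_modules):
    # Local set populated from the model's fallback (the _no_split_modules analogue).
    transformer_cls_to_wrap = set(no_split_modules)

    def policy(module_cls_name):
        # BUG: consults the plugin attribute, which is still None,
        # so every module is rejected despite the populated local set.
        if plugin.transformer_cls_names_to_wrap is None:
            return False
        return module_cls_name in transformer_cls_to_wrap

    return policy


def make_policy_fixed(plugin, no_split_modules):
    transformer_cls_to_wrap = set(no_split_modules)

    def policy(module_cls_name):
        # FIX: check the local set that was just built.
        if not transformer_cls_to_wrap:
            return False
        return module_cls_name in transformer_cls_to_wrap

    return policy


buggy = make_policy_buggy(Plugin(), ["LlamaDecoderLayer"])
fixed = make_policy_fixed(Plugin(), ["LlamaDecoderLayer"])
print(buggy("LlamaDecoderLayer"))  # False -- fallback silently ignored
print(fixed("LlamaDecoderLayer"))  # True
```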

Why it matters

In the current codebase, set_auto_wrap_policy happens to mutate fsdp_plugin.transformer_cls_names_to_wrap before fsdp2_prepare_auto_wrap_policy runs, masking this bug through the standard Accelerator.prepare() path. But the function's contract is still broken: it builds a local set, then ignores it. Anyone calling the function directly (e.g. tests, custom training loops) gets wrong results, and relying on mutation from a different method is fragile.

Reproduction (no GPU required)
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
from accelerate.utils.fsdp_utils import fsdp2_prepare_auto_wrap_policy

config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy="TRANSFORMER_BASED_WRAP",
    fsdp_version=2,
)

policy_func = fsdp2_prepare_auto_wrap_policy(plugin, model)
matched = sum(1 for m in model.modules() if policy_func(m))
layers = sum(1 for m in model.modules() if m.__class__.__name__ == "LlamaDecoderLayer")
print(f"Matched: {matched}, Expected: {layers}")
# Before fix: Matched: 0, Expected: 30
# After fix:  Matched: 30, Expected: 30

Likely related to #3474.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. Related: FSDP2 - High memory usage with LORA #3474
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@S1ro1 @SunMarc

JohnGiorgi and others added 2 commits April 7, 2026 00:04
The `policy()` closure inside `fsdp2_prepare_auto_wrap_policy` checks
`fsdp2_plugin.transformer_cls_names_to_wrap is None` to decide whether
to wrap a module. When `transformer_cls_names_to_wrap` is not explicitly
set (the common case — relying on the model's `_no_split_modules`), this
check makes the policy return `False` for every module, even though the
local `transformer_cls_to_wrap` set was correctly populated from
`_no_split_modules`.

The fix changes the check to use the local `transformer_cls_to_wrap`
variable instead of the plugin attribute, so that the `_no_split_modules`
fallback is respected.

Without this fix, the entire model becomes a single FSDP2 unit (only
`fully_shard(model)` on the root), causing every GPU to all-gather the
full model parameters during forward — leading to OOM on large models.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@JohnGiorgi JohnGiorgi changed the title Fix FSDP2 auto-wrap policy ignoring _no_split_modules fallback fix(FSDP2): auto-wrap policy ignoring _no_split_modules fallback Apr 7, 2026
@JohnGiorgi JohnGiorgi marked this pull request as ready for review April 7, 2026 00:21
@JohnGiorgi JohnGiorgi marked this pull request as draft April 7, 2026 12:59

@SunMarc SunMarc left a comment


Indeed, thanks for fixing this! Lmk when this is good to be merged!

@JohnGiorgi JohnGiorgi marked this pull request as ready for review April 7, 2026 16:28
@JohnGiorgi

> Indeed, thanks for fixing this! Lmk when this is good to be merged!

Thanks for taking a look! I think it's ready to go.
