diff --git a/src/accelerate/utils/fsdp_utils.py b/src/accelerate/utils/fsdp_utils.py index 75e54c7e51e..d23bf9bd5ae 100644 --- a/src/accelerate/utils/fsdp_utils.py +++ b/src/accelerate/utils/fsdp_utils.py @@ -790,7 +790,7 @@ def fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model: torch.nn.Module) -> Call transformer_cls_to_wrap.add(transformer_cls) def policy(module: torch.nn.Module) -> bool: - if fsdp2_plugin.transformer_cls_names_to_wrap is None: + if not transformer_cls_to_wrap: return False return isinstance(module, tuple(transformer_cls_to_wrap))