diff --git a/src/accelerate/utils/fsdp_utils.py b/src/accelerate/utils/fsdp_utils.py
index 6897888c0e4..233aabf0fc7 100644
--- a/src/accelerate/utils/fsdp_utils.py
+++ b/src/accelerate/utils/fsdp_utils.py
@@ -658,13 +658,28 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     )
 
     model_has_params4bit = False
+    incompatible_params4bit = set()
     for name, param in model.named_parameters():
         # this is a temporary fix whereby loading models with bnb params cannot be moved from
         # GPU to a meta device due with FSDP2 because torch operations don't return the original class type
         # bypassing the move to meta will still cause the VRAM spike, but at least it still will load
         if param.__class__.__name__ == "Params4bit":
             model_has_params4bit = True
-            break
+            # Exclude non-floating frozen Params4bit from FSDP sharding.
+            # Default uint8 quant_storage cannot survive fully_shard's DTensor conversion.
+            if (not param.requires_grad) and (not param.is_floating_point()) and (not param.is_complex()):
+                incompatible_params4bit.add(param)
+
+    if incompatible_params4bit and is_torch_version(">=", "2.7.0"):
+        ignored = set(fsdp2_kwargs.get("ignored_params", set()))
+        fsdp2_kwargs["ignored_params"] = ignored | incompatible_params4bit
+        if accelerator.is_main_process:
+            warnings.warn(
+                f"Found {len(incompatible_params4bit)} non-floating frozen Params4bit. "
+                "Excluding from FSDP2 sharding to prevent quant_state corruption. "
+                "To enable memory-efficient sharding of 4-bit weights, set "
+                "bnb_4bit_quant_storage to a floating dtype (e.g. bf16)."
+            )
 
     if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
         # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`