Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/accelerate/utils/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,13 +658,28 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
)

# Scan parameters once for bitsandbytes 4-bit ("Params4bit") weights.
# `model_has_params4bit` is a temporary fix: models with bnb params cannot be
# moved from GPU to a meta device with FSDP2 because torch operations don't
# return the original class type. Bypassing the move to meta still causes the
# VRAM spike, but at least the model still loads.
model_has_params4bit = False
# Params4bit whose storage dtype cannot survive fully_shard's DTensor
# conversion; collected so they can be excluded from sharding below.
incompatible_params4bit = set()
for name, param in model.named_parameters():
    if param.__class__.__name__ == "Params4bit":
        model_has_params4bit = True
        # No early `break` here: the full scan is required so that
        # `incompatible_params4bit` is complete, not just the params seen
        # before the first Params4bit.
        # Exclude non-floating frozen Params4bit from FSDP sharding: the
        # default uint8 quant_storage cannot survive fully_shard's DTensor
        # conversion.
        if (not param.requires_grad) and (not param.is_floating_point()) and (not param.is_complex()):
            incompatible_params4bit.add(param)

# `ignored_params` is only supported by fully_shard from torch 2.7.0 onward.
if incompatible_params4bit and is_torch_version(">=", "2.7.0"):
    # Merge with any ignored_params the caller already supplied.
    ignored = set(fsdp2_kwargs.get("ignored_params", set()))
    fsdp2_kwargs["ignored_params"] = ignored | incompatible_params4bit
    if accelerator.is_main_process:
        warnings.warn(
            f"Found {len(incompatible_params4bit)} non-floating frozen Params4bit. "
            "Excluding from FSDP2 sharding to prevent quant_state corruption. "
            "To enable memory-efficient sharding of 4-bit weights, set "
            "bnb_4bit_quant_storage to a floating dtype (e.g. bf16)."
        )

if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
Expand Down
Loading