Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion invokeai/app/invocations/anima_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def _run_llm_adapter_for_regions(

def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
device = TorchDevice.choose_torch_device()
inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
inference_dtype = TorchDevice.choose_anima_inference_dtype(device)

if self.denoising_start >= self.denoising_end:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion invokeai/app/invocations/anima_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _encode_prompt(
device = text_encoder.device

# Apply LoRA models to the text encoder
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
lora_dtype = TorchDevice.choose_anima_inference_dtype(device)
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _load_from_singlefile(

# Determine safe dtype
target_device = TorchDevice.choose_torch_device()
model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
model_dtype = TorchDevice.choose_anima_inference_dtype(target_device)

# Handle memory management
new_sd_size = sum(ten.nelement() * model_dtype.itemsize for ten in sd.values())
Expand Down
14 changes: 14 additions & 0 deletions invokeai/backend/util/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,17 @@ def choose_bfloat16_safe_dtype(cls, device: Optional[torch.device] = None) -> to
if device.type == "cuda":
return torch.float16
return torch.float32

@classmethod
def choose_anima_inference_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
"""Choose the inference dtype for Anima models, honoring config.precision.

When precision is 'auto', delegates to choose_bfloat16_safe_dtype (current
behavior). When precision is set to a specific value (float16, bfloat16,
float32), returns that dtype directly without hardware probing.
"""
device = device or cls.choose_torch_device()
config = get_config()
if config.precision == "auto":
return cls.choose_bfloat16_safe_dtype(device)
return NAME_TO_PRECISION[config.precision]
39 changes: 39 additions & 0 deletions tests/backend/util/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,42 @@ def test_legacy_precision_name():
assert "float16" == choose_precision(torch.device("cuda"))
assert "float16" == choose_precision(torch.device("mps"))
assert "float32" == choose_precision(torch.device("cpu"))


# ===== choose_anima_inference_dtype (config.precision honoring) ============


def test_choose_anima_inference_dtype_float16():
"""precision='float16' returns torch.float16 without touching hardware."""
config = get_config()
config.precision = "float16"
result = TorchDevice.choose_anima_inference_dtype(torch.device("cpu"))
assert result is torch.float16


def test_choose_anima_inference_dtype_bfloat16():
"""precision='bfloat16' returns torch.bfloat16 without touching hardware."""
config = get_config()
config.precision = "bfloat16"
result = TorchDevice.choose_anima_inference_dtype(torch.device("cpu"))
assert result is torch.bfloat16


def test_choose_anima_inference_dtype_float32():
"""precision='float32' returns torch.float32 without touching hardware."""
config = get_config()
config.precision = "float32"
result = TorchDevice.choose_anima_inference_dtype(torch.device("cpu"))
assert result is torch.float32


def test_choose_anima_inference_dtype_auto_delegates_to_safe_dtype():
"""precision='auto' delegates to choose_bfloat16_safe_dtype (current behavior)."""
config = get_config()
config.precision = "auto"
device = torch.device("cpu")
sentinel = torch.bfloat16
with patch.object(TorchDevice, "choose_bfloat16_safe_dtype", return_value=sentinel) as mock_safe:
result = TorchDevice.choose_anima_inference_dtype(device)
assert result is sentinel
mock_safe.assert_called_once_with(device)
Loading