Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion invokeai/app/invocations/anima_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def _run_llm_adapter_for_regions(

def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
device = TorchDevice.choose_torch_device()
inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
inference_dtype = TorchDevice.choose_anima_inference_dtype(device)

if self.denoising_start >= self.denoising_end:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion invokeai/app/invocations/anima_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _encode_prompt(
device = text_encoder.device

# Apply LoRA models to the text encoder
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
lora_dtype = TorchDevice.choose_anima_inference_dtype(device)
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _load_from_singlefile(

# Determine safe dtype
target_device = TorchDevice.choose_torch_device()
model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
model_dtype = TorchDevice.choose_anima_inference_dtype(target_device)

# Handle memory management
new_sd_size = sum(ten.nelement() * model_dtype.itemsize for ten in sd.values())
Expand Down
14 changes: 14 additions & 0 deletions invokeai/backend/util/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,17 @@ def choose_bfloat16_safe_dtype(cls, device: Optional[torch.device] = None) -> to
if device.type == "cuda":
return torch.float16
return torch.float32

@classmethod
def choose_anima_inference_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
    """Choose the inference dtype for Anima models, honoring config.precision.

    Args:
        device: Device the dtype is chosen for. It is only consulted when
            config.precision is 'auto'; if omitted on that path, it is
            resolved with choose_torch_device().

    Returns:
        When config.precision is 'auto', the result of
        choose_bfloat16_safe_dtype for the device. For any other precision
        value, the NAME_TO_PRECISION entry registered under that name.
    """
    precision = get_config().precision

    if precision != "auto":
        # An explicit precision is looked up directly. The device is not
        # resolved here on purpose: this path never uses it, so no device
        # selection is triggered when the caller omitted `device`.
        return NAME_TO_PRECISION[precision]

    # Only the 'auto' path needs a concrete device, so resolve it lazily.
    if device is None:
        device = cls.choose_torch_device()
    return cls.choose_bfloat16_safe_dtype(device)
39 changes: 39 additions & 0 deletions tests/backend/util/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,42 @@ def test_legacy_precision_name():
assert "float16" == choose_precision(torch.device("cuda"))
assert "float16" == choose_precision(torch.device("mps"))
assert "float32" == choose_precision(torch.device("cpu"))


# ===== choose_anima_inference_dtype (config.precision honoring) ============


def test_choose_anima_inference_dtype_float16():
    """precision='float16' makes a CPU-device call return torch.float16."""
    config = get_config()
    saved_precision = config.precision
    config.precision = "float16"
    try:
        result = TorchDevice.choose_anima_inference_dtype(torch.device("cpu"))
        assert result is torch.float16
    finally:
        # The config object is shared with the code under test, so put the
        # previous precision back to avoid leaking this override into other tests.
        config.precision = saved_precision


def test_choose_anima_inference_dtype_bfloat16():
    """precision='bfloat16' makes a CPU-device call return torch.bfloat16."""
    config = get_config()
    saved_precision = config.precision
    config.precision = "bfloat16"
    try:
        result = TorchDevice.choose_anima_inference_dtype(torch.device("cpu"))
        assert result is torch.bfloat16
    finally:
        # The config object is shared with the code under test, so put the
        # previous precision back to avoid leaking this override into other tests.
        config.precision = saved_precision


def test_choose_anima_inference_dtype_float32():
    """precision='float32' makes a CPU-device call return torch.float32."""
    config = get_config()
    saved_precision = config.precision
    config.precision = "float32"
    try:
        result = TorchDevice.choose_anima_inference_dtype(torch.device("cpu"))
        assert result is torch.float32
    finally:
        # The config object is shared with the code under test, so put the
        # previous precision back to avoid leaking this override into other tests.
        config.precision = saved_precision


def test_choose_anima_inference_dtype_auto_delegates_to_safe_dtype():
    """precision='auto' delegates to choose_bfloat16_safe_dtype with the given device."""
    config = get_config()
    saved_precision = config.precision
    config.precision = "auto"
    try:
        device = torch.device("cpu")
        # Canned return value: getting it back proves the result came from
        # the patched helper rather than from any other code path.
        sentinel = torch.bfloat16
        with patch.object(TorchDevice, "choose_bfloat16_safe_dtype", return_value=sentinel) as mock_safe:
            result = TorchDevice.choose_anima_inference_dtype(device)
        assert result is sentinel
        # The explicitly supplied device must be forwarded unchanged.
        mock_safe.assert_called_once_with(device)
    finally:
        # The config object is shared with the code under test, so put the
        # previous precision back to avoid leaking this override into other tests.
        config.precision = saved_precision
Loading