diff --git a/invokeai/app/invocations/anima_denoise.py b/invokeai/app/invocations/anima_denoise.py index 0016a4fd261..a1e78aeccb7 100644 --- a/invokeai/app/invocations/anima_denoise.py +++ b/invokeai/app/invocations/anima_denoise.py @@ -396,7 +396,7 @@ def _run_llm_adapter_for_regions( def _run_diffusion(self, context: InvocationContext) -> torch.Tensor: device = TorchDevice.choose_torch_device() - inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device) + inference_dtype = TorchDevice.choose_anima_inference_dtype(device) if self.denoising_start >= self.denoising_end: raise ValueError( diff --git a/invokeai/app/invocations/anima_text_encoder.py b/invokeai/app/invocations/anima_text_encoder.py index 1856a69ae76..a8098889f79 100644 --- a/invokeai/app/invocations/anima_text_encoder.py +++ b/invokeai/app/invocations/anima_text_encoder.py @@ -132,7 +132,7 @@ def _encode_prompt( device = text_encoder.device # Apply LoRA models to the text encoder - lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device) + lora_dtype = TorchDevice.choose_anima_inference_dtype(device) exit_stack.enter_context( LayerPatcher.apply_smart_model_patches( model=text_encoder, diff --git a/invokeai/backend/model_manager/load/model_loaders/anima.py b/invokeai/backend/model_manager/load/model_loaders/anima.py index d63393fb7bc..94dc2dcc5a6 100644 --- a/invokeai/backend/model_manager/load/model_loaders/anima.py +++ b/invokeai/backend/model_manager/load/model_loaders/anima.py @@ -118,7 +118,7 @@ def _load_from_singlefile( # Determine safe dtype target_device = TorchDevice.choose_torch_device() - model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device) + model_dtype = TorchDevice.choose_anima_inference_dtype(target_device) # Handle memory management new_sd_size = sum(ten.nelement() * model_dtype.itemsize for ten in sd.values()) diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 01b4684af1a..359ce45dc4f 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -137,3 +137,17 @@ def choose_bfloat16_safe_dtype(cls, device: Optional[torch.device] = None) -> to if device.type == "cuda": return torch.float16 return torch.float32 + + @classmethod + def choose_anima_inference_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype: + """Choose the inference dtype for Anima models, honoring config.precision. + + When precision is 'auto', delegates to choose_bfloat16_safe_dtype (current + behavior). When precision is set to a specific value (float16, bfloat16, + float32), returns that dtype directly without hardware probing. + """ + device = device or cls.choose_torch_device() + config = get_config() + if config.precision == "auto": + return cls.choose_bfloat16_safe_dtype(device) + return NAME_TO_PRECISION[config.precision] diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py index b65137c08dc..3f134e3c3da 100644 --- a/tests/backend/util/test_devices.py +++ b/tests/backend/util/test_devices.py @@ -130,3 +130,42 @@ def test_legacy_precision_name(): assert "float16" == choose_precision(torch.device("cuda")) assert "float16" == choose_precision(torch.device("mps")) assert "float32" == choose_precision(torch.device("cpu")) + + +# ===== choose_anima_inference_dtype (config.precision honoring) ============ + + +def test_choose_anima_inference_dtype_float16(): + """precision='float16' returns torch.float16 without touching hardware.""" + config = get_config() + config.precision = "float16" + result = TorchDevice.choose_anima_inference_dtype(torch.device("cpu")) + assert result is torch.float16 + + +def test_choose_anima_inference_dtype_bfloat16(): + """precision='bfloat16' returns torch.bfloat16 without touching hardware.""" + config = get_config() + config.precision = "bfloat16" + result = TorchDevice.choose_anima_inference_dtype(torch.device("cpu")) + assert result is torch.bfloat16 + + +def test_choose_anima_inference_dtype_float32(): + """precision='float32' returns torch.float32 without touching hardware.""" + config = get_config() + config.precision = "float32" + result = TorchDevice.choose_anima_inference_dtype(torch.device("cpu")) + assert result is torch.float32 + + +def test_choose_anima_inference_dtype_auto_delegates_to_safe_dtype(): + """precision='auto' delegates to choose_bfloat16_safe_dtype (current behavior).""" + config = get_config() + config.precision = "auto" + device = torch.device("cpu") + sentinel = torch.bfloat16 + with patch.object(TorchDevice, "choose_bfloat16_safe_dtype", return_value=sentinel) as mock_safe: + result = TorchDevice.choose_anima_inference_dtype(device) + assert result is sentinel + mock_safe.assert_called_once_with(device)