diff --git a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py index bf72e06d94eb..dc61f5fccaa3 100644 --- a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py @@ -38,7 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils import TransformersKwargs, agnostic, auto_docstring, can_return_tuple, logging from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_flash_linear_attention_available from ...utils.output_capturing import capture_outputs @@ -695,7 +695,7 @@ def __init__(self, config: OlmoHybridConfig, layer_idx: int): else FusedRMSNormGated( self.head_v_dim, eps=1e-5, - device=torch.cuda.current_device(), + device=agnostic.gpu.current_device(), dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), ) ) diff --git a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py index 089f29309007..2e0c043737bd 100644 --- a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py @@ -32,7 +32,7 @@ from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils import TransformersKwargs, agnostic, auto_docstring, logging from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_flash_linear_attention_available from ...utils.output_capturing import 
capture_outputs @@ -513,7 +513,7 @@ def __init__(self, config: OlmoHybridConfig, layer_idx: int): else FusedRMSNormGated( self.head_v_dim, eps=1e-5, - device=torch.cuda.current_device(), + device=agnostic.gpu.current_device(), dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), ) ) diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index eba3eec02fdd..e3145b9f1b5c 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -44,7 +44,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from ...utils import TransformersKwargs, agnostic, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import capture_outputs @@ -395,7 +395,7 @@ def __init__(self, config: Qwen3_5Config, layer_idx: int): self.head_v_dim, eps=self.layer_norm_epsilon, activation=self.activation, - device=torch.cuda.current_device(), + device=agnostic.gpu.current_device(), dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), ) ) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index a8a46ecf508b..7627684d1f0a 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -45,7 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from 
...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from ...utils import TransformersKwargs, agnostic, auto_docstring, can_return_tuple, logging, torch_compilable_check from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import OutputRecorder, capture_outputs @@ -396,7 +396,7 @@ def __init__(self, config: Qwen3_5MoeConfig, layer_idx: int): self.head_v_dim, eps=self.layer_norm_epsilon, activation=self.activation, - device=torch.cuda.current_device(), + device=agnostic.gpu.current_device(), dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), ) ) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 9e7fa7e01c69..a13f71af9a6e 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils import TransformersKwargs, agnostic, auto_docstring, can_return_tuple, logging from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available from ...utils.output_capturing import OutputRecorder, capture_outputs @@ -541,7 +541,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.head_v_dim, eps=self.layer_norm_epsilon, activation=self.activation, - 
device=torch.cuda.current_device(), + device=agnostic.gpu.current_device(), dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), ) ) diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 417a9a59cf8b..a3f74b0c3e9b 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -28,7 +28,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils import TransformersKwargs, agnostic, auto_docstring, logging from ...utils.generic import merge_with_config_defaults from ...utils.import_utils import ( is_causal_conv1d_available, @@ -380,7 +380,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int): self.head_v_dim, eps=self.layer_norm_epsilon, activation=self.activation, - device=torch.cuda.current_device(), + device=agnostic.gpu.current_device(), dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(), ) ) diff --git a/src/transformers/utils/agnostic.py b/src/transformers/utils/agnostic.py new file mode 100644 index 000000000000..11ce7bdd09e8 --- /dev/null +++ b/src/transformers/utils/agnostic.py @@ -0,0 +1,83 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +GPU calls that are device-agnostic. +""" + +try: + import torch +except Exception: + torch = None + + +class AgnosticGPU: + @staticmethod + def configure() -> "AgnosticGPU": + return ( + NoGPU() + if torch is None + else CUDAGPU() + if torch.cuda.is_available() + else XPUGPU() + if (hasattr(torch, "xpu") and torch.xpu.is_available()) + else MPSGPU() + if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) + else NoGPU() + ) + + name: str + + def is_accelerator_available(self) -> bool: + return False + + def current_device(self) -> int: + return 0 + + def device_count(self) -> int: + return 0 + + +class CUDAGPU(AgnosticGPU): + def __init__(self): + assert torch is not None + self.name = "cuda" + self.is_accelerator_available = torch.cuda.is_available + self.current_device = torch.cuda.current_device + self.device_count = torch.cuda.device_count + + +class XPUGPU(AgnosticGPU): + def __init__(self): + assert torch is not None + self.name = "xpu" + self.is_accelerator_available = torch.xpu.is_available + self.current_device = torch.xpu.current_device + self.device_count = torch.xpu.device_count + + +class MPSGPU(AgnosticGPU): + def __init__(self): + assert torch is not None + self.name = "mps" + self.is_accelerator_available = torch.backends.mps.is_available + # MPS has a single device; the base-class current_device() returning 0 is correct + self.device_count = torch.mps.device_count + + +class NoGPU(AgnosticGPU): + def __init__(self) -> None: + self.name = "cpu" + + +gpu = AgnosticGPU.configure() diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 1e1ac2545f05..a97a2d37a570 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -794,8 +794,14 @@ def is_mamba_2_ssm_available() -> bool: @lru_cache def is_flash_linear_attention_available(): + from .
import agnostic + is_available, fla_version = _is_package_available("fla", return_version=True) - return is_torch_cuda_available() and is_available and version.parse(fla_version) >= version.parse("0.2.2") + return ( + agnostic.gpu.is_accelerator_available() + and is_available + and version.parse(fla_version) >= version.parse("0.2.2") + ) @lru_cache