Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils import TransformersKwargs, agnostic, auto_docstring, can_return_tuple, logging
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.import_utils import is_flash_linear_attention_available
from ...utils.output_capturing import capture_outputs
Expand Down Expand Up @@ -695,7 +695,7 @@ def __init__(self, config: OlmoHybridConfig, layer_idx: int):
else FusedRMSNormGated(
self.head_v_dim,
eps=1e-5,
device=torch.cuda.current_device(),
device=agnostic.gpu.current_device(),
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
)
)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ...modeling_rope_utils import dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, logging
from ...utils import TransformersKwargs, agnostic, auto_docstring, logging
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.import_utils import is_flash_linear_attention_available
from ...utils.output_capturing import capture_outputs
Expand Down Expand Up @@ -513,7 +513,7 @@ def __init__(self, config: OlmoHybridConfig, layer_idx: int):
else FusedRMSNormGated(
self.head_v_dim,
eps=1e-5,
device=torch.cuda.current_device(),
device=agnostic.gpu.current_device(),
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
)
)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/qwen3_5/modeling_qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
from ...utils import TransformersKwargs, agnostic, auto_docstring, can_return_tuple, logging, torch_compilable_check
from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
from ...utils.output_capturing import capture_outputs
Expand Down Expand Up @@ -395,7 +395,7 @@ def __init__(self, config: Qwen3_5Config, layer_idx: int):
self.head_v_dim,
eps=self.layer_norm_epsilon,
activation=self.activation,
device=torch.cuda.current_device(),
device=agnostic.gpu.current_device(),
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
)
)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
from ...utils import TransformersKwargs, agnostic, auto_docstring, can_return_tuple, logging, torch_compilable_check
from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
from ...utils.output_capturing import OutputRecorder, capture_outputs
Expand Down Expand Up @@ -396,7 +396,7 @@ def __init__(self, config: Qwen3_5MoeConfig, layer_idx: int):
self.head_v_dim,
eps=self.layer_norm_epsilon,
activation=self.activation,
device=torch.cuda.current_device(),
device=agnostic.gpu.current_device(),
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
)
)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/qwen3_next/modeling_qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils import TransformersKwargs, agnostic, auto_docstring, can_return_tuple, logging
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
from ...utils.output_capturing import OutputRecorder, capture_outputs
Expand Down Expand Up @@ -541,7 +541,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int):
self.head_v_dim,
eps=self.layer_norm_epsilon,
activation=self.activation,
device=torch.cuda.current_device(),
device=agnostic.gpu.current_device(),
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
)
)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/qwen3_next/modular_qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, logging
from ...utils import TransformersKwargs, agnostic, auto_docstring, logging
from ...utils.generic import merge_with_config_defaults
from ...utils.import_utils import (
is_causal_conv1d_available,
Expand Down Expand Up @@ -380,7 +380,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int):
self.head_v_dim,
eps=self.layer_norm_epsilon,
activation=self.activation,
device=torch.cuda.current_device(),
device=agnostic.gpu.current_device(),
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
)
)
Expand Down
83 changes: 83 additions & 0 deletions src/transformers/utils/agnostic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
GPU calls that are device-agnostic.
"""

try:
import torch
except Exception:
torch = None


class AgnosticGPU:
    """Device-agnostic accessor for accelerator queries.

    Call :meth:`configure` (or use the module-level ``gpu`` singleton) to get an
    instance for the best backend available at import time. The base class
    doubles as the CPU / no-accelerator interface: no accelerator, device
    index 0, zero devices.
    """

    # Backend identifier: "cuda", "xpu", "mps", or "cpu".
    name: str

    @staticmethod
    def configure() -> "AgnosticGPU":
        """Select a backend, preferring CUDA, then XPU, then MPS, else CPU."""
        if torch is None:
            return NoGPU()
        if torch.cuda.is_available():
            return CUDAGPU()
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return XPUGPU()
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return MPSGPU()
        return NoGPU()

    def is_accelerator_available(self) -> bool:
        """Whether an accelerator is usable (CPU fallback: always False)."""
        return False

    def current_device(self) -> int:
        """Index of the currently selected device (CPU fallback: 0)."""
        return 0

    def device_count(self) -> int:
        """Number of visible devices (CPU fallback: 0)."""
        return 0


class CUDAGPU(AgnosticGPU):
    """NVIDIA CUDA backend: delegates directly to ``torch.cuda``."""

    def __init__(self):
        assert torch is not None
        self.name = "cuda"
        self.is_accelerator_available = torch.cuda.is_available
        self.current_device = torch.cuda.current_device
        self.device_count = torch.cuda.device_count


class XPUGPU(AgnosticGPU):
    """Intel XPU backend: delegates directly to ``torch.xpu``."""

    def __init__(self):
        assert torch is not None
        self.name = "xpu"
        self.is_accelerator_available = torch.xpu.is_available
        self.current_device = torch.xpu.current_device
        self.device_count = torch.xpu.device_count


class MPSGPU(AgnosticGPU):
    """Apple Metal (MPS) backend.

    Availability is reported via ``torch.backends.mps`` — the same check
    ``configure()`` performs; ``torch.mps`` has no ``is_available`` attribute,
    so the original binding raised ``AttributeError`` at construction. MPS
    exposes a single device, so the inherited ``current_device`` (always 0)
    is appropriate, and ``device_count`` falls back to a constant 1 on torch
    releases that predate ``torch.mps.device_count``.
    """

    def __init__(self):
        assert torch is not None
        self.name = "mps"
        self.is_accelerator_available = torch.backends.mps.is_available
        # torch.mps (and its device_count) may be absent on older releases.
        mps_module = getattr(torch, "mps", None)
        self.device_count = getattr(mps_module, "device_count", lambda: 1)


class NoGPU(AgnosticGPU):
    """CPU-only fallback: inherits the base class's zero/False answers."""

    def __init__(self) -> None:
        self.name = "cpu"


# Singleton configured once at import time; modeling code calls e.g.
# ``agnostic.gpu.current_device()`` instead of ``torch.cuda.current_device()``.
gpu = AgnosticGPU.configure()
8 changes: 7 additions & 1 deletion src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,8 +794,14 @@ def is_mamba_2_ssm_available() -> bool:

@lru_cache
def is_flash_linear_attention_available():
    """Return True if flash-linear-attention (`fla` >= 0.2.2) is usable.

    Requires both an accelerator (any backend recognized by the
    device-agnostic helper) and a sufficiently recent `fla` install.
    Cached: the environment does not change within a process.
    """
    from . import agnostic

    fla_installed, fla_version = _is_package_available("fla", return_version=True)
    if not agnostic.gpu.is_accelerator_available():
        return False
    if not fla_installed:
        return False
    return version.parse(fla_version) >= version.parse("0.2.2")


@lru_cache
Expand Down
Loading