Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions cuda_core/cuda/core/_linker.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,23 @@ cdef class Linker:
else:
return as_py(self._culink_handle)

@property
def backend(self) -> CompilerBackendType:
"""Return this Linker instance's underlying :class:`CompilerBackendType`."""
return CompilerBackendType.NVJITLINK if self._use_nvjitlink else CompilerBackendType.DRIVER
@classmethod
def which_backend(cls) -> CompilerBackendType:
"""Return which linking backend will be used.

Returns :attr:`~CompilerBackendType.NVJITLINK` when the nvJitLink
library is available and meets the minimum version requirement,
otherwise :attr:`~CompilerBackendType.DRIVER`.

.. note::

Prefer letting :class:`Linker` decide. Query ``which_backend()``
only when you need to dispatch based on input format (for
example: choose PTX vs. LTOIR before constructing a
``Linker``). The returned value names an implementation
detail whose support matrix may shift across CTK releases.
"""
return CompilerBackendType.DRIVER if _decide_nvjitlink_or_driver() else CompilerBackendType.NVJITLINK


# =============================================================================
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_program.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
self._linker = Linker(
ObjectCode._init(code_bytes, code_type), options=_translate_program_options(options)
)
self._backend = str(self._linker.backend)
self._backend = str(Linker.which_backend())

elif code_type == "nvvm":
_get_nvvm_module() # Validate NVVM availability
Expand Down
6 changes: 6 additions & 0 deletions cuda_core/docs/source/release/1.0.0-notes.rst
Comment thread
leofang marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ Breaking changes
``CUgraphConditionalHandle`` value. Previously, ``.handle`` had to be
extracted explicitly.

- :meth:`Linker.which_backend` is now a classmethod instead of the former
``backend`` instance property. Call sites must use ``Linker.which_backend()``
(with parentheses) instead of ``linker.backend``. This allows querying the
linking backend without constructing a ``Linker`` instance — for example, to
choose between PTX and LTOIR input before linking.

- :attr:`DeviceMemoryResource.peer_accessible_by` now returns a
:class:`collections.abc.MutableSet` of :obj:`~_device.Device` objects instead
of a sorted ``tuple[int, ...]``. The property setter is unchanged.
Expand Down
40 changes: 39 additions & 1 deletion cuda_core/tests/test_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

import inspect

import pytest

from cuda.core import Device, Linker, LinkerOptions, Program, ProgramOptions, _linker
Expand Down Expand Up @@ -92,7 +94,7 @@ def test_linker_init(compile_ptx_functions, options):
linker = Linker(*compile_ptx_functions, options=options)
object_code = linker.link("cubin")
assert isinstance(object_code, ObjectCode)
assert linker.backend == ("driver" if is_culink_backend else "nvJitLink")
assert Linker.which_backend() == ("driver" if is_culink_backend else "nvJitLink")


def test_linker_init_invalid_arch(compile_ptx_functions):
Expand Down Expand Up @@ -242,3 +244,39 @@ def test_linker_options_nvjitlink_options_as_str():
assert f"-arch={ARCH}" in options
assert "-g" in options
assert "-lineinfo" in options


class TestWhichBackendClassmethod:
def test_which_backend_returns_nvjitlink(self, monkeypatch):
monkeypatch.setattr(_linker, "_use_nvjitlink_backend", True)
assert Linker.which_backend() == "nvJitLink"

def test_which_backend_returns_driver(self, monkeypatch):
monkeypatch.setattr(_linker, "_use_nvjitlink_backend", False)
assert Linker.which_backend() == "driver"

def test_which_backend_invokes_probe_when_not_memoised(self, monkeypatch):
monkeypatch.setattr(_linker, "_use_nvjitlink_backend", None)
called = []

def fake_decide():
called.append(True)
return False # False = not falling back to driver = nvJitLink

monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", fake_decide)
result = Linker.which_backend()
assert result == "nvJitLink"
assert called, "_decide_nvjitlink_or_driver was not called"

def test_which_backend_is_classmethod(self):
attr = inspect.getattr_static(Linker, "which_backend")
assert isinstance(attr, classmethod)

def test_which_backend_is_not_property(self):
"""which_backend is a classmethod, not a property.

This is an intentional breaking change from the prior ``backend`` property API.
All call sites must use parens: ``Linker.which_backend()``.
"""
attr = inspect.getattr_static(Linker, "which_backend")
assert not isinstance(attr, property)
Loading