Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deep_ep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def check_nvlink_connections(group: dist.ProcessGroup):
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_indices]
for i, handle in enumerate(handles):
for j, peer_handle in enumerate(handles):
if i >= j:
if i >= j or physical_device_indices[i] == physical_device_indices[j]:
continue
status = pynvml.nvmlDeviceGetP2PStatus(handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
assert status == pynvml.NVML_P2P_STATUS_OK,\
Expand Down
36 changes: 36 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
import pytest
from unittest import mock
import torch
import torch.distributed as dist

from deep_ep.utils import check_nvlink_connections


@pytest.fixture
def mock_pcie_gpu():
    """Make ``torch.cuda.get_device_name`` report a PCIE A100 for the test's duration.

    The reported name contains "PCIE", which is what the NVLink connection
    check is expected to key on.
    """
    fake_device_name = "NVIDIA A100-PCIE-40GB"
    device_name_patch = mock.patch.object(
        torch.cuda, "get_device_name", return_value=fake_device_name
    )
    with device_name_patch:
        # The patch is undone when the fixture generator is resumed at teardown.
        yield


def test_nvlink_check_duplicate_physical_gpu(mock_pcie_gpu):
    """Test NVLink check works when multiple ranks share same physical GPU.

    Fixes #582: AssertionError 'No NVLink connection between GPU X and GPU X'

    Args:
        mock_pcie_gpu: Fixture that makes the CUDA device name report a PCIE
            GPU so that the NVLink connection check is exercised.
    """
    group = mock.Mock(spec=dist.ProcessGroup)
    group.size.return_value = 2

    def fake_all_gather(output_list, _obj, _group):
        # Every rank reports the same physical GPU index.
        output_list[:] = [0, 0]

    # Patch every NVML entry point the check may touch so the test never
    # reaches the real driver; the created mocks are returned keyed by name.
    nvml_patches = mock.patch.multiple(
        "pynvml",
        nvmlInit=mock.DEFAULT,
        nvmlDeviceGetHandleByIndex=mock.DEFAULT,
        nvmlDeviceGetP2PStatus=mock.DEFAULT,
        nvmlShutdown=mock.DEFAULT,
    )

    # CUDA_VISIBLE_DEVICES lists the same physical device twice.
    with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,0"}), \
            mock.patch("torch.cuda.current_device", return_value=0), \
            mock.patch("torch.distributed.all_gather_object", side_effect=fake_all_gather), \
            nvml_patches as nvml:
        # Should not raise assertion for same GPU
        check_nvlink_connections(group)

    # A device must not be queried for an NVLink connection to itself.
    nvml["nvmlDeviceGetP2PStatus"].assert_not_called()