Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions deep_ep/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,68 @@ def check_fast_rdma_atomic_support(nic_name: str = _DEFAULT_NIC_NAME) -> bool:
return False


# MLX5DV_CONTEXT_MASK_NUM_LAG_PORTS from <infiniband/mlx5dv.h>; pyverbs does not
# re-export this constant, and ``query_mlx5_device()``'s default comp_mask=-1
# (which ORs the masks pyverbs knows about) skips it on at least some versions,
# leaving ``num_lag_ports`` at 0. Pass the bit explicitly.
_MLX5DV_CONTEXT_MASK_NUM_LAG_PORTS = 1 << 9


@functools.lru_cache()
def _query_num_lag_ports(nic_name: str) -> int:
    """
    Count the physical LAG ports backing an mlx5 RoCE device.

    Mellanox RoCE LAG exposes N physical rails as one logical port, so
    ``ibstat`` shows the per-rail rate rather than the aggregated bandwidth.
    Reading ``num_lag_ports`` through the mlx5 direct-verbs interface is the
    canonical way to recover the rail count without an `ib_write_bw` probe.

    The result is memoised per ``nic_name`` by ``lru_cache``. That includes
    the fallback value: if the first query fails, 1 keeps being returned for
    that NIC until ``cache_clear()`` is called.

    Arguments:
        nic_name: the NIC device name.

    Returns:
        num_lag_ports: the reported LAG port count, or 1 (legacy behaviour)
            when ``pyverbs`` cannot be imported, the query fails for any
            reason, or the device reports fewer than one port.
    """
    try:
        from pyverbs.providers.mlx5.mlx5dv import Mlx5Context, Mlx5DVContextAttr
    except ImportError:
        return 1

    # noinspection PyBroadException
    try:
        device = Mlx5Context(attr=Mlx5DVContextAttr(), name=nic_name)
        try:
            # Request the LAG-port field explicitly via its comp_mask bit.
            caps = device.query_mlx5_device(comp_mask=_MLX5DV_CONTEXT_MASK_NUM_LAG_PORTS)
            rails = int(caps.num_lag_ports or 0)
        finally:
            # Always release the device context, even if the query raised.
            device.close()
    except Exception:
        return 1

    # A report of 0 (or anything below 1) is treated as a single rail.
    return rails if rails > 1 else 1


@functools.lru_cache()
def get_rdma_gbs(nic_name: str = _DEFAULT_NIC_NAME) -> float:
"""
Get the RDMA bandwidth in GB/s, cached.

On RoCE LAG fabrics the value is automatically scaled by the number of
underlying physical ports reported by the mlx5 direct-verbs interface,
so a 2-rail bond delivers ``2 * ibstat_rate / 8`` instead of half of it.
Setting ``EP_RDMA_GBS=<gbps>`` skips detection and uses the supplied
value directly (handy when ``ibstat`` is missing or behaves oddly).

Arguments:
nic_name: the NIC device name.

Returns:
gbs: the RDMA bandwidth in GB/s (0 if detection fails).
"""
override = os.getenv('EP_RDMA_GBS')
if override:
# noinspection PyBroadException
try:
return float(override) / 8
except ValueError:
print(f'Invalid EP_RDMA_GBS={override!r}, ignoring and falling back to ibstat')
Comment on lines +299 to +305

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 suggestion: EP_RDMA_GBS 在 get_rdma_gbs 函数体内读取,但该函数是 @lru_cache(以 nic_name 为键)。因此仅在“首次调用”时才会读取该环境变量并缓存,之后即便修改或新设置 EP_RDMA_GBS 也不会再生效。若期望支持运行时动态覆盖,需要把该环境变量也纳入 lru_cache 的 key(或不在被缓存的函数内部读取 env);若仅希望启动前设置一次,建议在 docstring 里明确说明“需在首次调用前设置,之后修改无效”。

Comment on lines +299 to +305

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 suggestion: EP_RDMA_GBS 覆盖路径未对非正数做校验:若用户误设 EP_RDMA_GBS=0 或负值,会直接返回 0 或负的带宽值。get_rdma_gbs 的返回值在 elastic.py:614 被用于 SM 数计算(通常作为分母或速率),0/负值可能导致除零或异常的 SM 估计。建议在 float 转换成功后校验 >0,否则也回退到 ibstat 探测。


# noinspection PyBroadException
try:
result = subprocess.run(['ibstat'], capture_output=True, text=True, check=True)
Expand All @@ -261,8 +312,9 @@ def get_rdma_gbs(nic_name: str = _DEFAULT_NIC_NAME) -> float:
pattern = rf"CA '{nic_name}'.*?Port \d+:\s*.*?Rate:\s*(\d+)"
match = re.search(pattern, output, re.DOTALL)
assert match
rate = int(match.group(1))
return rate / 8
rate_per_port = int(match.group(1))
except Exception as e:
print(f'Failed to get RDMA connection speed: {e}')
return 0

return rate_per_port * _query_num_lag_ports(nic_name) / 8
130 changes: 130 additions & 0 deletions tests/utils/test_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
Tests for the RoCE-LAG aware path in ``deep_ep.utils.envs.get_rdma_gbs``.

Both functions under test (`_query_num_lag_ports`, `get_rdma_gbs`) are
purely local — they hit the kernel via pyverbs / sysfs / ``ibstat`` without
any inter-node traffic — so a single process on one node is enough.

The script is also safe to launch under ``torchrun --nnodes=2 --nproc-per-node=8``;
each rank simply runs the same local probes and prints from its own context,
which exercises that the pyverbs query holds up under concurrent opens.

On H800 + CX-7 (2x200G NDR RoCE LAG):

ibstat 'mlx5_bond_1' Rate : 200 (Gb/s, per rail)
_query_num_lag_ports('mlx5_bond_1') : 2
get_rdma_gbs('mlx5_bond_1') : 50.0 (GB/s, = 400 Gb/s aggregated)
"""

import os
import subprocess

from deep_ep.utils.envs import _query_num_lag_ports, get_rdma_gbs


def _rank_prefix() -> str:
    """Return ``"[rank <RANK>] "`` when the ``RANK`` env var is set, else ``""``."""
    if "RANK" not in os.environ:
        return ""
    return "[rank {}] ".format(os.environ["RANK"])


def _log(msg: str) -> None:
    """Print ``msg`` prefixed with the rank tag, flushing stdout immediately."""
    tagged = _rank_prefix() + msg
    print(tagged, flush=True)


def _has_pyverbs() -> bool:
    """Return True if the pyverbs mlx5 direct-verbs module can be imported."""
    importable = True
    try:
        import pyverbs.providers.mlx5.mlx5dv  # noqa: F401
    except ImportError:
        importable = False
    return importable


def _has_ibstat() -> bool:
    """
    Report whether an ``ibstat`` executable can be launched.

    The exit status of ``ibstat`` is ignored; only the ability to start it
    matters.

    Returns:
        True if the process could be started, False if the binary is missing
        or cannot be executed.
    """
    try:
        subprocess.run(["ibstat"], capture_output=True, check=False, timeout=2)
    except subprocess.TimeoutExpired:
        # The binary exists and was started; it just did not finish in time.
        return True
    except OSError:
        # FileNotFoundError, PermissionError and other launch failures all
        # mean ibstat is not usable here.
        return False
    return True


def _ibstat_per_port_gbps(nic: str):
    """
    Read the first ``Rate:`` value printed by ``ibstat <nic>``.

    Arguments:
        nic: the NIC device name passed to ``ibstat``.

    Returns:
        The integer following ``Rate:`` on the first such line (presumably
        Gb/s, per the function name), or ``None`` when ``ibstat`` cannot be
        launched, exits non-zero, times out, or prints no parseable rate.
    """
    try:
        out = subprocess.check_output(["ibstat", nic], text=True, timeout=5)
    except (OSError, subprocess.CalledProcessError, subprocess.TimeoutExpired):
        # OSError covers a missing binary as well as one that exists but
        # cannot be executed (e.g. PermissionError).
        return None
    for line in out.splitlines():
        line = line.strip()
        if line.startswith("Rate:"):
            # Only the first Rate line is considered, parseable or not.
            try:
                return int(line.split()[1])
            except (IndexError, ValueError):
                return None
    return None


def _clear_caches() -> None:
    """Drop the memoised results of both probes so the next call re-queries."""
    for cached_fn in (_query_num_lag_ports, get_rdma_gbs):
        cached_fn.cache_clear()


def test_query_num_lag_ports():
    """LAG-aware path returns >=1; without pyverbs it must fall back to exactly 1."""
    device = os.environ.get("EP_NIC_NAME", "mlx5_bond_1")
    _clear_caches()
    port_count = _query_num_lag_ports(device)
    _log(f"_query_num_lag_ports({device!r}) = {port_count}")

    if _has_pyverbs():
        assert port_count >= 1, f"expected >= 1 LAG port, got {port_count}"
    else:
        # No pyverbs -> intentional fallback to legacy single-rail behaviour.
        assert port_count == 1


def test_get_rdma_gbs_lag_aggregation():
    """
    get_rdma_gbs should report rate_per_port * num_lag_ports / 8 in GB/s.

    The ``EP_RDMA_GBS`` override is removed for the duration of the test so
    the probe path is exercised, and restored afterwards. The test returns
    early (logging a skip message) when ``ibstat`` is unavailable or reports
    no rate for the NIC.
    """
    nic = os.environ.get("EP_NIC_NAME", "mlx5_bond_1")
    saved_override = os.environ.pop("EP_RDMA_GBS", None)
    _clear_caches()

    try:
        if not _has_ibstat():
            _log("skipping: ibstat not installed")
            return

        rate_per_port = _ibstat_per_port_gbps(nic)
        if rate_per_port is None:
            _log(f"skipping: ibstat reported no Rate for {nic}")
            return
        n_lag = _query_num_lag_ports(nic)
        expected_gbs = rate_per_port * n_lag / 8.0
        actual_gbs = get_rdma_gbs(nic)
        _log(
            f"get_rdma_gbs({nic!r}) = {actual_gbs} GB/s "
            f"(per-port {rate_per_port} Gb/s x {n_lag} rails / 8)"
        )
        assert actual_gbs == expected_gbs, f"expected {expected_gbs}, got {actual_gbs}"

        # Headline assertion: with LAG present the reported bandwidth must
        # exceed the single-rail value (e.g. ~50 GB/s rather than 25 GB/s on
        # a 2-rail 200G bond), instead of merely being logged.
        if n_lag >= 2 and rate_per_port > 0:
            single_rail_gbs = rate_per_port / 8.0
            assert actual_gbs > single_rail_gbs, (
                f"LAG with {n_lag} rails should exceed single-rail "
                f"{single_rail_gbs} GB/s, got {actual_gbs}"
            )
        if n_lag >= 2 and rate_per_port >= 100:
            aggregated_gbps = rate_per_port * n_lag
            _log(f"LAG aggregation detected: {aggregated_gbps} Gb/s -> {actual_gbs} GB/s")
    finally:
        # Put back any override the caller had set, and drop results cached
        # while it was absent so later calls see the restored environment.
        if saved_override is not None:
            os.environ["EP_RDMA_GBS"] = saved_override
            _clear_caches()
Comment on lines +98 to +110

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 suggestion: test_get_rdma_gbs_lag_aggregation 中的 expected_gbs 是用 ibstat 重解析 + 同一个 _query_num_lag_ports(nic) 算出来的,而 get_rdma_gbs 内部恰好就是 rate_per_port * _query_num_lag_ports(nic) / 8,因此 actual_gbs == expected_gbs 本质上只是在验证“函数是否与自身一致”(仅能捕获重构导致的不一致)。真正的结论(LAG 下恢复约 50 GB/s、而非 25 GB/s)只是打印出来、并未断言;遇到非 LAG 环境又直接 skip。建议至少在检测到 LAG(n_lag >= 2)时断言聚合带宽大于单 rail 值(或约为 n_lag 倍),否则这个测试对“修复是否真的起效”几乎没有校验力,存在“测试通过但实际仍拿不到聚合带宽”的风险。



def test_ep_rdma_gbs_env_override():
    """
    EP_RDMA_GBS overrides the probe; value is in Gb/s and divided by 8.

    Any ``EP_RDMA_GBS`` value present before the test is restored afterwards
    instead of being deleted.
    """
    nic = os.environ.get("EP_NIC_NAME", "mlx5_bond_1")
    previous = os.environ.get("EP_RDMA_GBS")
    _clear_caches()
    os.environ["EP_RDMA_GBS"] = "400"
    try:
        assert get_rdma_gbs(nic) == 50.0
    finally:
        if previous is None:
            del os.environ["EP_RDMA_GBS"]
        else:
            os.environ["EP_RDMA_GBS"] = previous
        # The cached entry was computed under the forced override.
        _clear_caches()
    _log("EP_RDMA_GBS=400 -> 50.0 GB/s")


if __name__ == "__main__":
    # Run every check in order; an AssertionError aborts before the final log.
    for _check in (
        test_query_num_lag_ports,
        test_get_rdma_gbs_lag_aggregation,
        test_ep_rdma_gbs_env_override,
    ):
        _check()
    _log("ALL TESTS PASSED")