-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Make get_rdma_gbs RoCE-LAG aware #671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -242,17 +242,68 @@ def check_fast_rdma_atomic_support(nic_name: str = _DEFAULT_NIC_NAME) -> bool: | |
| return False | ||
|
|
||
|
|
||
| # MLX5DV_CONTEXT_MASK_NUM_LAG_PORTS from <infiniband/mlx5dv.h>; pyverbs does not | ||
| # re-export this constant, and ``query_mlx5_device()`` 's default comp_mask=-1 | ||
| # (which ORs the masks pyverbs knows about) skips it on at least some versions, | ||
| # leaving ``num_lag_ports`` at 0. Pass the bit explicitly. | ||
| _MLX5DV_CONTEXT_MASK_NUM_LAG_PORTS = 1 << 9 | ||
|
|
||
|
|
||
| @functools.lru_cache() | ||
| def _query_num_lag_ports(nic_name: str) -> int: | ||
| """ | ||
| Number of physical LAG ports underneath an mlx5 RoCE device. | ||
|
|
||
| Mellanox RoCE LAG presents N physical rails as a single logical port, | ||
| so ``ibstat`` reports the per-rail rate, not the aggregated bandwidth. | ||
| Querying the mlx5 direct-verbs interface for ``num_lag_ports`` is the | ||
| canonical way to recover the rail count without an `ib_write_bw` probe. | ||
|
|
||
| Returns 1 (legacy behaviour) when ``pyverbs`` is unavailable or the | ||
| query fails for any reason. | ||
| """ | ||
| # noinspection PyBroadException | ||
| try: | ||
| from pyverbs.providers.mlx5.mlx5dv import Mlx5Context, Mlx5DVContextAttr | ||
| except ImportError: | ||
| return 1 | ||
| try: | ||
| ctx = Mlx5Context(attr=Mlx5DVContextAttr(), name=nic_name) | ||
| try: | ||
| dv = ctx.query_mlx5_device(comp_mask=_MLX5DV_CONTEXT_MASK_NUM_LAG_PORTS) | ||
| num_lag_ports = int(dv.num_lag_ports or 0) | ||
| finally: | ||
| ctx.close() | ||
| return max(num_lag_ports, 1) | ||
| except Exception: | ||
| return 1 | ||
|
|
||
|
|
||
| @functools.lru_cache() | ||
| def get_rdma_gbs(nic_name: str = _DEFAULT_NIC_NAME) -> float: | ||
| """ | ||
| Get the RDMA bandwidth in GB/s, cached. | ||
|
|
||
| On RoCE LAG fabrics the value is automatically scaled by the number of | ||
| underlying physical ports reported by the mlx5 direct-verbs interface, | ||
| so a 2-rail bond delivers ``2 * ibstat_rate / 8`` instead of half of it. | ||
| Setting ``EP_RDMA_GBS=<gbps>`` skips detection and uses the supplied | ||
| value directly (handy when ``ibstat`` is missing or behaves oddly). | ||
|
|
||
| Arguments: | ||
| nic_name: the NIC device name. | ||
|
|
||
| Returns: | ||
| gbs: the RDMA bandwidth in GB/s (0 if detection fails). | ||
| """ | ||
| override = os.getenv('EP_RDMA_GBS') | ||
| if override: | ||
| # noinspection PyBroadException | ||
| try: | ||
| return float(override) / 8 | ||
| except ValueError: | ||
| print(f'Invalid EP_RDMA_GBS={override!r}, ignoring and falling back to ibstat') | ||
|
Comment on lines
+299
to
+305
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔵 suggestion:
Comment on lines
+299
to
+305
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔵 suggestion: EP_RDMA_GBS 覆盖路径未对非正数做校验:若用户误设 EP_RDMA_GBS=0 或负值,会直接返回 0 或负的带宽值。get_rdma_gbs 的返回值在 elastic.py:614 被用于 SM 数计算(通常作为分母或速率),0/负值可能导致除零或异常的 SM 估计。建议在 float 转换成功后校验 >0,否则也回退到 ibstat 探测。 |
||
|
|
||
| # noinspection PyBroadException | ||
| try: | ||
| result = subprocess.run(['ibstat'], capture_output=True, text=True, check=True) | ||
|
|
@@ -261,8 +312,9 @@ def get_rdma_gbs(nic_name: str = _DEFAULT_NIC_NAME) -> float: | |
| pattern = rf"CA '{nic_name}'.*?Port \d+:\s*.*?Rate:\s*(\d+)" | ||
| match = re.search(pattern, output, re.DOTALL) | ||
| assert match | ||
| rate = int(match.group(1)) | ||
| return rate / 8 | ||
| rate_per_port = int(match.group(1)) | ||
| except Exception as e: | ||
| print(f'Failed to get RDMA connection speed: {e}') | ||
| return 0 | ||
|
|
||
| return rate_per_port * _query_num_lag_ports(nic_name) / 8 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| """ | ||
| Tests for the RoCE-LAG aware path in ``deep_ep.utils.envs.get_rdma_gbs``. | ||
|
|
||
| Both functions under test (`_query_num_lag_ports`, `get_rdma_gbs`) are | ||
| purely local — they hit the kernel via pyverbs / sysfs / ``ibstat`` without | ||
| any inter-node traffic — so a single process on one node is enough. | ||
|
|
||
| The script is also safe to launch under ``torchrun --nnodes=2 --nproc-per-node=8``; | ||
| each rank simply runs the same local probes and prints from its own context, | ||
| which exercises that the pyverbs query holds up under concurrent opens. | ||
|
|
||
| On H800 + CX-7 (2x200G NDR RoCE LAG): | ||
|
|
||
| ibstat 'mlx5_bond_1' Rate : 200 (Gb/s, per rail) | ||
| _query_num_lag_ports('mlx5_bond_1') : 2 | ||
| get_rdma_gbs('mlx5_bond_1') : 50.0 (GB/s, = 400 Gb/s aggregated) | ||
| """ | ||
|
|
||
| import os | ||
| import subprocess | ||
|
|
||
| from deep_ep.utils.envs import _query_num_lag_ports, get_rdma_gbs | ||
|
|
||
|
|
||
| def _rank_prefix() -> str: | ||
| rank = os.environ.get("RANK") | ||
| return f"[rank {rank}] " if rank is not None else "" | ||
|
|
||
|
|
||
| def _log(msg: str) -> None: | ||
| print(_rank_prefix() + msg, flush=True) | ||
|
|
||
|
|
||
| def _has_pyverbs() -> bool: | ||
| try: | ||
| import pyverbs.providers.mlx5.mlx5dv # noqa: F401 | ||
| except ImportError: | ||
| return False | ||
| return True | ||
|
|
||
|
|
||
| def _has_ibstat() -> bool: | ||
| try: | ||
| subprocess.run(["ibstat"], capture_output=True, check=False, timeout=2) | ||
| except FileNotFoundError: | ||
| return False | ||
| return True | ||
|
|
||
|
|
||
| def _ibstat_per_port_gbps(nic: str): | ||
| try: | ||
| out = subprocess.check_output(["ibstat", nic], text=True, timeout=5) | ||
| except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired): | ||
| return None | ||
| for line in out.splitlines(): | ||
| line = line.strip() | ||
| if line.startswith("Rate:"): | ||
| try: | ||
| return int(line.split()[1]) | ||
| except (IndexError, ValueError): | ||
| return None | ||
| return None | ||
|
|
||
|
|
||
| def _clear_caches() -> None: | ||
| _query_num_lag_ports.cache_clear() | ||
| get_rdma_gbs.cache_clear() | ||
|
|
||
|
|
||
| def test_query_num_lag_ports(): | ||
| """LAG-aware path returns >=1; on a real RoCE LAG fabric it returns >=2.""" | ||
| nic = os.environ.get("EP_NIC_NAME", "mlx5_bond_1") | ||
| _clear_caches() | ||
| n = _query_num_lag_ports(nic) | ||
| _log(f"_query_num_lag_ports({nic!r}) = {n}") | ||
|
|
||
| if not _has_pyverbs(): | ||
| # No pyverbs -> intentional fallback to legacy single-rail behaviour. | ||
| assert n == 1 | ||
| return | ||
| assert n >= 1, f"expected >= 1 LAG port, got {n}" | ||
|
|
||
|
|
||
| def test_get_rdma_gbs_lag_aggregation(): | ||
| """get_rdma_gbs should report rate_per_port * num_lag_ports / 8 in GB/s.""" | ||
| nic = os.environ.get("EP_NIC_NAME", "mlx5_bond_1") | ||
| os.environ.pop("EP_RDMA_GBS", None) | ||
| _clear_caches() | ||
|
|
||
| if not _has_ibstat(): | ||
| _log("skipping: ibstat not installed") | ||
| return | ||
|
|
||
| rate_per_port = _ibstat_per_port_gbps(nic) | ||
| if rate_per_port is None: | ||
| _log(f"skipping: ibstat reported no Rate for {nic}") | ||
| return | ||
| n_lag = _query_num_lag_ports(nic) | ||
| expected_gbs = rate_per_port * n_lag / 8.0 | ||
| actual_gbs = get_rdma_gbs(nic) | ||
| _log( | ||
| f"get_rdma_gbs({nic!r}) = {actual_gbs} GB/s " | ||
| f"(per-port {rate_per_port} Gb/s x {n_lag} rails / 8)" | ||
| ) | ||
| assert actual_gbs == expected_gbs, f"expected {expected_gbs}, got {actual_gbs}" | ||
| # Headline assertion: on a 2-rail 200G LAG, we recover ~50 GB/s | ||
| # (= 400 Gb/s aggregated), not the previous 25 GB/s single-rail value. | ||
| if n_lag >= 2 and rate_per_port >= 100: | ||
| aggregated_gbps = rate_per_port * n_lag | ||
| _log(f"LAG aggregation detected: {aggregated_gbps} Gb/s -> {actual_gbs} GB/s") | ||
|
Comment on lines
+98
to
+110
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔵 suggestion: |
||
|
|
||
|
|
||
| def test_ep_rdma_gbs_env_override(): | ||
| """EP_RDMA_GBS overrides the probe; value is in Gb/s and divided by 8.""" | ||
| nic = os.environ.get("EP_NIC_NAME", "mlx5_bond_1") | ||
| _clear_caches() | ||
| os.environ["EP_RDMA_GBS"] = "400" | ||
| try: | ||
| assert get_rdma_gbs(nic) == 50.0 | ||
| finally: | ||
| del os.environ["EP_RDMA_GBS"] | ||
| _clear_caches() | ||
| _log("EP_RDMA_GBS=400 -> 50.0 GB/s") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_query_num_lag_ports() | ||
| test_get_rdma_gbs_lag_aggregation() | ||
| test_ep_rdma_gbs_env_override() | ||
| _log("ALL TESTS PASSED") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🔵 suggestion: _query_num_lag_ports 使用 lru_cache 缓存结果,包含失败时返回的回退值 1。若首次调用因瞬时原因(如设备短暂不可用)失败而缓存为 1,后续即使 LAG 恢复也将一直沿用单轨值。鉴于该函数仅在启动期调用一次、与 get_rdma_gbs 的缓存语义一致,影响有限,但可在注释中说明此缓存语义以避免误用。