Make get_rdma_gbs RoCE-LAG aware#671
Conversation
Mellanox RoCE LAG presents N physical rails (e.g. 2x200G NDR) as a single logical port, so the existing ibstat-based probe in get_rdma_gbs() returns the per-rail rate, not the aggregated bandwidth. On a 2-rail bond that's half of what's actually achievable, and the downstream get_theoretical_num_sms calculation then picks the wrong bottleneck and recommends fewer SMs than it should. Probe num_lag_ports via the mlx5 direct-verbs interface (pyverbs) and multiply the ibstat rate by it. Falls back cleanly to the old behaviour when pyverbs is not installed. Also add an EP_RDMA_GBS escape hatch (in Gbps, divided by 8 here, same units as ibstat's Rate field) so deployments with broken probes or exotic fabrics can override without patching the package. This matches the request in deepseek-ai#614. Verified on H800 + CX-7: ibstat reports 200 Gb/s per rail, ib_write_bw with 8 QPs reaches ~372 Gb/s, and Mlx5Context.query_mlx5_device() returns num_lag_ports=2, recovering ~46 GB/s = 372/8 (vs the old 25). Fixes deepseek-ai#614 (credit @michaelchen1996 for the original report and patch).
Three local checks that fall back gracefully when pyverbs / ibstat /
mlx5_bond_* is absent (so single-rail and non-RDMA CI hosts skip):
- _query_num_lag_ports returns >=1
- get_rdma_gbs matches rate_per_port * num_lag_ports / 8
- EP_RDMA_GBS env override is honored
Pure-local probes, so single-process is enough; the file also runs
cleanly under torchrun --nnodes=2 --nproc-per-node=8 (each rank just
prints from its own context, exercising concurrent pyverbs opens).
On H800 + CX-7 2x200G RoCE LAG the test prints:
get_rdma_gbs('mlx5_bond_1') = 50.0 GB/s
LAG aggregation detected: 400 Gb/s -> 50.0 GB/s
pyverbs's default comp_mask=-1 ORs only the mask bits it knows about,
and at least some versions do not include MLX5DV_CONTEXT_MASK_NUM_LAG_PORTS
in that list. The result is that num_lag_ports stays at 0 and we fall back
to the legacy single-rail behaviour even on real LAG fabrics.
Passing the bit explicitly (1 << 9, taken from rdma-core's mlx5dv.h)
restores the intended behaviour.
Verified on H800 + CX-7 2-rail LAG:
comp_mask=-1 (default) -> num_lag_ports = 0
comp_mask=1<<9 (this patch) -> num_lag_ports = 2
get_rdma_gbs('mlx5_bond_1') = 50.0 GB/s (was 25.0)
ds-review-bot
left a comment
There was a problem hiding this comment.
🤖 ds-review-bot Code Review
Model 1
本 MR 让 get_rdma_gbs 在 Mellanox RoCE LAG(多条物理 rail 绑定成一个逻辑端口)下感知聚合带宽:通过 pyverbs 的 mlx5dv_query_device 查询 num_lag_ports,把 ibstat 报告的每 rail 速率乘以 rail 数,从而恢复真实单卡带宽(H800+CX-7 2×200G 下由 25 GB/s 修正为 50 GB/s);同时新增 EP_RDMA_GBS 环境变量直接覆盖检测结果,并在缺失 pyverbs / 非 mlx5 / 非 LAG 时优雅回退到旧行为。整体实现方向与 Mooncake 的 C++ 实现一致,正则解析 × num_lag_ports / 8 的核心计算正确,对下游 get_theoretical_num_sms 的带宽建模语义也吻合(rdma_gbs 本就是按整卡聚合带宽参与 traffic/gbs 计算)。主要不足是 LAG 查询失败时静默降级到半带宽、不输出任何告警,以及新增的测试基本是“函数自洽”式断言、对真实带宽缺乏校验力。
Model 2
本 MR 让 get_rdma_gbs 支持 Mellanox RoCE LAG 场景:在 CX-7 双轨 bond(mlx5_bond_*)上,ibstat 仅报告单轨速率,导致下游 SM 数计算把真实带宽低估一半。修复方式是通过 pyverbs 的 mlx5 direct-verbs 查询 num_lag_ports 并将 ibstat 速率乘以该值,同时新增 EP_RDMA_GBS 环境变量覆盖项;当 pyverbs 缺失、非 mlx5 或未启用 LAG 时干净地回退到原单轨行为。整体实现逻辑正确、异常处理完备、回退路径清晰,测试也能在缺少依赖时自动跳过。未发现明显 bug,下方仅给出两点提示性建议。
Files reviewed: 2
Issues found: 🟡 1 warning | 🔵 4 suggestion
Inline comments posted: 5
| try: | ||
| ctx = Mlx5Context(attr=Mlx5DVContextAttr(), name=nic_name) | ||
| try: | ||
| dv = ctx.query_mlx5_device(comp_mask=_MLX5DV_CONTEXT_MASK_NUM_LAG_PORTS) | ||
| num_lag_ports = int(dv.num_lag_ports or 0) | ||
| finally: | ||
| ctx.close() | ||
| return max(num_lag_ports, 1) | ||
| except Exception: | ||
| return 1 |
There was a problem hiding this comment.
🟡 warning: _query_num_lag_ports 用 try/except Exception 吞掉所有异常并静默返回 1,而 get_rdma_gbs 在调用它时也不输出任何“是否应用了 LAG 聚合”的日志。这意味着当 pyverbs 已安装但查询失败(设备打不开、pyverbs 版本不匹配、num_lag_ports 字段不存在等)时,函数会无任何提示地退回到原先的半带宽结果——而这正是本 MR 要修复的问题,用户根本无法察觉修复没有生效(且错误只发生在一台 node 上,跨节点更难排查)。注意无效 EP_RDMA_GBS 都已加了 print(第 303-304 行),而真正核心的 LAG 查询失败反而无声无息。建议在该 except 分支至少 print 一条告警,或仅当确实是“pyverbs 缺失/非 mlx5”等预期情况才无声回退,其余异常应暴露出来。
| override = os.getenv('EP_RDMA_GBS') | ||
| if override: | ||
| # noinspection PyBroadException | ||
| try: | ||
| return float(override) / 8 | ||
| except ValueError: | ||
| print(f'Invalid EP_RDMA_GBS={override!r}, ignoring and falling back to ibstat') |
There was a problem hiding this comment.
🔵 suggestion: EP_RDMA_GBS 在 get_rdma_gbs 函数体内读取,但该函数是 @lru_cache(以 nic_name 为键)。因此仅在“首次调用”时才会读取该环境变量并缓存,之后即便修改或新设置 EP_RDMA_GBS 也不会再生效。若期望支持运行时动态覆盖,需要把该环境变量也纳入 lru_cache 的 key(或不在被缓存的函数内部读取 env);若仅希望启动前设置一次,建议在 docstring 里明确说明“需在首次调用前设置,之后修改无效”。
| n_lag = _query_num_lag_ports(nic) | ||
| expected_gbs = rate_per_port * n_lag / 8.0 | ||
| actual_gbs = get_rdma_gbs(nic) | ||
| _log( | ||
| f"get_rdma_gbs({nic!r}) = {actual_gbs} GB/s " | ||
| f"(per-port {rate_per_port} Gb/s x {n_lag} rails / 8)" | ||
| ) | ||
| assert actual_gbs == expected_gbs, f"expected {expected_gbs}, got {actual_gbs}" | ||
| # Headline assertion: on a 2-rail 200G LAG, we recover ~50 GB/s | ||
| # (= 400 Gb/s aggregated), not the previous 25 GB/s single-rail value. | ||
| if n_lag >= 2 and rate_per_port >= 100: | ||
| aggregated_gbps = rate_per_port * n_lag | ||
| _log(f"LAG aggregation detected: {aggregated_gbps} Gb/s -> {actual_gbs} GB/s") |
There was a problem hiding this comment.
🔵 suggestion: test_get_rdma_gbs_lag_aggregation 中 expected_gbs 是用 ibstat 重解析 + 同一个 _query_num_lag_ports(nic) 算出来的,而 get_rdma_gbs 内部恰好就是 rate_per_port * _query_num_lag_ports(nic) / 8,因此 actual_gbs == expected_gbs 本质上只是在验证“函数是否与自身一致”(仅能捕获重构导致的不一致)。真正的结论(LAG 下恢复约 50 GB/s、而非 25 GB/s)只是打印出来、并未断言;遇到非 LAG 环境又直接 skip。建议至少在检测到 LAG(n_lag >= 2)时断言聚合带宽大于单 rail 值(或约为 n_lag 倍),否则这个测试对“修复是否真的起效”几乎没有校验力,存在“测试通过但实际仍拿不到聚合带宽”的风险。
| override = os.getenv('EP_RDMA_GBS') | ||
| if override: | ||
| # noinspection PyBroadException | ||
| try: | ||
| return float(override) / 8 | ||
| except ValueError: | ||
| print(f'Invalid EP_RDMA_GBS={override!r}, ignoring and falling back to ibstat') |
There was a problem hiding this comment.
🔵 suggestion: EP_RDMA_GBS 覆盖路径未对非正数做校验:若用户误设 EP_RDMA_GBS=0 或负值,会直接返回 0 或负的带宽值。get_rdma_gbs 的返回值在 elastic.py:614 被用于 SM 数计算(通常作为分母或速率),0/负值可能导致除零或异常的 SM 估计。建议在 float 转换成功后校验 >0,否则也回退到 ibstat 探测。
| @functools.lru_cache() | ||
| def _query_num_lag_ports(nic_name: str) -> int: |
There was a problem hiding this comment.
🔵 suggestion: _query_num_lag_ports 使用 lru_cache 缓存结果,包含失败时返回的回退值 1。若首次调用因瞬时原因(如设备短暂不可用)失败而缓存为 1,后续即使 LAG 恢复也将一直沿用单轨值。鉴于该函数仅在启动期调用一次、与 get_rdma_gbs 的缓存语义一致,影响有限,但可在注释中说明此缓存语义以避免误用。
4426f40 to
2c5745a
Compare
Background
get_rdma_gbsreadsibstat'sRate:to feedElasticBuffer.get_theoretical_num_sms. On Mellanox RoCE LAG (e.g., CX-7 with two 200 Gb/s NDR rails bonded into onemlx5_bond_*),ibstatreports the per-rail rate — the mlx5 driver hashes traffic across both rails, so a single QP only ever uses one, but multiple QPs reach the bond's full aggregate. The probe therefore returns half of the real per-NIC capacity and the downstream SM-count calculation picks the wrong bottleneck.Same root cause as #614.
Measurement (H800 + CX-7,
mlx5_bond_*2-rail LAG)ibstat 'mlx5_bond_1' .. Rate:ib_write_bw -d mlx5_bond_1 -F(1 QP)ib_write_bw -d mlx5_bond_1 -F -q 8Mlx5Context.query_mlx5_device(comp_mask=MASK_NUM_LAG_PORTS).num_lag_portsReal per-NIC bandwidth is ~50 GB/s, not 25 — a 2x error.
Fix
Query
num_lag_portsvia pyverbs and multiply the ibstat rate by it. Same approach as Mooncake's transfer engine, seemooncake-transfer-engine/src/transport/rdma_transport/rdma_context.cpp:266for the C++ equivalent (mlx5dv_query_device+MLX5DV_CONTEXT_MASK_NUM_LAG_PORTS).Gotcha:
query_mlx5_device()'s defaultcomp_mask=-1ORs only the masks pyverbs knows about; at least the version I tested against didn't includeNUM_LAG_PORTS, sonum_lag_portscame back as 0. The bit must be passed explicitly. Hardcoded frominfiniband/mlx5dv.hsince pyverbs doesn't re-export the constant.Also adds an
EP_RDMA_GBS=<gbps>env override (also/8to GB/s) so users with broken probes or non-Mellanox fabrics can short-circuit detection — addresses the request in #614.Falls back cleanly to the original single-rail behaviour when
pyverbsis missing, when the NIC isn't mlx5, or when LAG isn't enabled.Tests
tests/utils/test_envs.py— three local checks (no distributed setup needed) that auto-skip when pyverbs / ibstat /mlx5_bond_*aren't present. On H800 + CX-7 LAG the test prints:Compatibility
No new hard dependencies;
pyverbsis optional (import istry/except). No API changes. Only the returned value changes, and only when LAG is actually in effect.Fixes #614. Credit @michaelchen1996 env-override sketch.