From 2e61c5444b4f13ab6d3a6f6d02ccd7719467e092 Mon Sep 17 00:00:00 2001 From: jianzhu Date: Thu, 2 Apr 2026 21:42:39 +0800 Subject: [PATCH] fix(CP, MLA): CP works fine with MLA in a2a cp_comm_type --- .../pytorch/attention/dot_product_attention/context_parallel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64cccaac6e..6d5dc887bd 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4072,6 +4072,7 @@ def attn_forward_func_with_cp( enable_mla = k.shape[-1] != v.shape[-1] assert not enable_mla or cp_comm_type in [ "p2p", + "a2a", "a2a+p2p", ], f"Context parallelism does not support MLA with {cp_comm_type=}!"