diff --git a/apex/contrib/fmha/fmha.py b/apex/contrib/fmha/fmha.py index af067a345..e2da2d79c 100644 --- a/apex/contrib/fmha/fmha.py +++ b/apex/contrib/fmha/fmha.py @@ -34,9 +34,13 @@ class FMHAFun(torch.autograd.Function): @staticmethod def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors): batch_size = cu_seqlens.numel() - 1 - if batch_size < 4: + # Note: bwd_nl (no-loop backward) only works correctly for batch_size <= 2. + # For batch_size == 3, bwd_nl produces incorrect results due to num_chunks + # configuration issue in the CUDA kernel, so we use the regular path instead. + use_nl = batch_size <= 2 + if use_nl: max_s = 512 - context, S_dmask = mha.fwd_nl( + context, S_dmask = mha.fwd( qkv, cu_seqlens, p_dropout, max_s, is_training, True, zero_tensors, None ) else: @@ -55,13 +59,13 @@ def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors): ctx.p_dropout = p_dropout ctx.max_s = max_s ctx.zero_tensors = zero_tensors + ctx.use_nl = use_nl return context @staticmethod def backward(ctx, dout): qkv, S_dmask = ctx.saved_tensors - batch_size = ctx.cu_seqlens.numel() - 1 - if batch_size < 4: + if ctx.use_nl: dqkv, dp, _ = mha.bwd_nl( dout, qkv, diff --git a/apex/contrib/test/fmha/test_fmha.py b/apex/contrib/test/fmha/test_fmha.py index 4d2e76caf..84e337be5 100644 --- a/apex/contrib/test/fmha/test_fmha.py +++ b/apex/contrib/test/fmha/test_fmha.py @@ -91,7 +91,11 @@ def run_test(self, s: int, b: int, zero_tensors: bool): qkv.requires_grad = True - if b < 4: + # Note: bwd_nl (no-loop backward) only works correctly for batch_size <= 2. + # For batch_size == 3, bwd_nl produces incorrect results due to num_chunks + # configuration issue in the CUDA kernel, so we use the regular bwd instead. + use_nl = b <= 2 + if use_nl: ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, True, zero_tensors, None) else: ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, False, zero_tensors, None) @@ -109,7 +113,7 @@ def run_test(self, s: int, b: int, zero_tensors: bool): dw2 = dw.permute(0, 2, 1, 3).clone().detach().contiguous() - if b < 4: + if use_nl: dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors) else: dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors)