diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh index 245184045..b16fd6c27 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -197,6 +197,15 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_b, gmem_b_ptr + current_k_offset * shape_n); ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a, scheduler.current_shape_k, scheduler.current_shape_k); ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b, scheduler.current_shape_k, scheduler.current_shape_k); + + // Make sure tensormaps are not used by TMA before updating GMEM. + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + // Only used to prevent `ptxas` from moving the following GMEM stores before `cute::tma_desc_wait_group()`. + // Shouldn't be needed otherwise, since we only use one thread. + + __syncwarp(1U << lane_idx); + *(gmem_tensor_map_a) = *(smem_tensor_map_a); *(gmem_tensor_map_b) = *(smem_tensor_map_b); ptx::tensor_map_release_gpu(); diff --git a/tests/generators.py b/tests/generators.py index 343c77214..723648669 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -198,7 +198,7 @@ def enumerate_k_grouped_contiguous(dtype: torch.dtype): psum_list = (False, True) if get_arch_major() == 10 else (False, ) # Must with FP32 accumulation and 1D1D kernels for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64 - ( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32 + ( 8, 768, 2048, 128), ( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32 (16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16 real_ks_cpu = [max(1, int(expected_k_per_group * random.uniform(0.7, 1.3))) for _ in range(num_groups)] for use_psum_layout in psum_list: