diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu
index 7dfb34201d..806a2482ab 100644
--- a/tests/cpp/operator/test_swizzle.cu
+++ b/tests/cpp/operator/test_swizzle.cu
@@ -339,6 +339,173 @@ TEST_P(UnswizzleTestSuite, TestUnswizzle) {
                     transa);
 }
 
+void performTestGroupedUnswizzleMXFP8(const int num_tensors, const size_t M, const size_t K) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  std::vector<std::unique_ptr<Tensor>> input_tensors;
+  std::vector<std::unique_ptr<Tensor>> output_tensors;
+  std::vector<Tensor*> input_ptrs;
+  std::vector<Tensor*> output_ptrs;
+  input_tensors.reserve(num_tensors);
+  output_tensors.reserve(num_tensors);
+  input_ptrs.reserve(num_tensors);
+  output_ptrs.reserve(num_tensors);
+
+  const std::vector<size_t> shape{M, K};
+  for (int i = 0; i < num_tensors; ++i) {
+    auto input = std::make_unique<Tensor>("input_" + std::to_string(i), shape,
+                                          DType::kFloat8E4M3, true, true,
+                                          NVTE_MXFP8_1D_SCALING);
+    auto output = std::make_unique<Tensor>("output_" + std::to_string(i), shape,
+                                           DType::kFloat8E4M3, true, true,
+                                           NVTE_MXFP8_1D_SCALING);
+    fillUniform(input.get());
+    fillUniform(output.get());
+
+    input_ptrs.push_back(input.get());
+    output_ptrs.push_back(output.get());
+    input_tensors.emplace_back(std::move(input));
+    output_tensors.emplace_back(std::move(output));
+  }
+
+  GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING);
+  GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING);
+  const uint8_t input_swizzled = 1;
+  nvte_set_grouped_tensor_param(grouped_input.get_handle(),
+                                kNVTEGroupedWithGEMMSwizzledScales,
+                                &input_swizzled, sizeof(input_swizzled));
+  const uint8_t output_swizzled = 0;
+  nvte_set_grouped_tensor_param(grouped_output.get_handle(),
+                                kNVTEGroupedWithGEMMSwizzledScales,
+                                &output_swizzled, sizeof(output_swizzled));
+
+  const NVTEShape row_shape = input_tensors[0]->rowwise_scale_inv_shape();
+  const NVTEShape col_shape = input_tensors[0]->columnwise_scale_inv_shape();
+  const size_t row_numel = row_shape.data[0] * row_shape.data[1];
+  const size_t col_numel = col_shape.data[0] * col_shape.data[1];
+
+  NVTE_CHECK_CUDA(cudaMemset(grouped_output.scale_inv.get(), 0, num_tensors * row_numel));
+  NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0, num_tensors * col_numel));
+
+  nvte_unswizzle_grouped_scaling_factors(grouped_input.get_handle(),
+                                         grouped_output.get_handle(), 0);
+
+  std::vector<uint8_t> output_row(num_tensors * row_numel);
+  std::vector<uint8_t> output_col(num_tensors * col_numel);
+  NVTE_CHECK_CUDA(cudaMemcpy(output_row.data(), grouped_output.scale_inv.get(),
+                             output_row.size(), cudaMemcpyDeviceToHost));
+  NVTE_CHECK_CUDA(cudaMemcpy(output_col.data(), grouped_output.columnwise_scale_inv.get(),
+                             output_col.size(), cudaMemcpyDeviceToHost));
+
+  std::vector<uint8_t> ref_row(num_tensors * row_numel);
+  std::vector<uint8_t> ref_col(num_tensors * col_numel);
+  for (int i = 0; i < num_tensors; ++i) {
+    compute_ref_unswizzle<128, 4, true>(input_tensors[i]->rowwise_cpu_scale_inv_ptr(),
+                                        ref_row.data() + i * row_numel,
+                                        row_shape.data[0], row_shape.data[1]);
+    compute_ref_unswizzle<128, 4, false>(
+        input_tensors[i]->columnwise_cpu_scale_inv_ptr(),
+        ref_col.data() + i * col_numel,
+        col_shape.data[1], col_shape.data[0]);
+  }
+
+  compareResults("grouped_unswizzle_rowwise", output_row.data(), ref_row.data(),
+                 num_tensors * row_numel);
+  compareResults("grouped_unswizzle_colwise", output_col.data(), ref_col.data(),
+                 num_tensors * col_numel);
+}
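+
+// Note: the grouped scale buffers are packed back-to-back per tensor. With one-byte
+// E8M0 scale factors, tensor i's rowwise slice starts at byte offset i * row_numel
+// (and i * col_numel for the columnwise slice), as used in the reference loop above.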
+
+void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const size_t M,
+                                                 const size_t K) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  constexpr size_t BLOCK_SIZE = 32;
+  const std::vector<size_t> shape{M, K};
+
+  std::vector<std::unique_ptr<Tensor>> orig_tensors, mid_tensors, final_tensors;
+  std::vector<Tensor*> orig_ptrs, mid_ptrs, final_ptrs;
+  orig_tensors.reserve(num_tensors);
+  mid_tensors.reserve(num_tensors);
+  final_tensors.reserve(num_tensors);
+
+  for (int i = 0; i < num_tensors; ++i) {
+    auto orig = std::make_unique<Tensor>("orig_" + std::to_string(i), shape,
+                                         DType::kFloat8E4M3, true, true, NVTE_MXFP8_1D_SCALING);
+    auto mid = std::make_unique<Tensor>("mid_" + std::to_string(i), shape,
+                                        DType::kFloat8E4M3, true, true, NVTE_MXFP8_1D_SCALING);
+    auto fin = std::make_unique<Tensor>("fin_" + std::to_string(i), shape,
+                                        DType::kFloat8E4M3, true, true, NVTE_MXFP8_1D_SCALING);
+    fillUniform(orig.get());
+
+    // Zero out the padding so the round-trip comparison is exact.
+    orig->to_cpu();
+    const NVTEShape rs = orig->rowwise_scale_inv_shape();
+    zero_scale_inv_padding(orig->rowwise_cpu_scale_inv_ptr(),
+                           rs.data[0], rs.data[1],
+                           M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    const NVTEShape cs = orig->columnwise_scale_inv_shape();
+    zero_scale_inv_padding(orig->columnwise_cpu_scale_inv_ptr(),
+                           cs.data[0], cs.data[1],
+                           (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K);
+    orig->from_cpu();
+
+    orig_ptrs.push_back(orig.get());
+    mid_ptrs.push_back(mid.get());
+    final_ptrs.push_back(fin.get());
+    orig_tensors.emplace_back(std::move(orig));
+    mid_tensors.emplace_back(std::move(mid));
+    final_tensors.emplace_back(std::move(fin));
+  }
+
+  GroupedBuffers grouped_orig = build_grouped_tensor(orig_ptrs, NVTE_MXFP8_1D_SCALING);
+  GroupedBuffers grouped_mid = build_grouped_tensor(mid_ptrs, NVTE_MXFP8_1D_SCALING);
+  GroupedBuffers grouped_fin = build_grouped_tensor(final_ptrs, NVTE_MXFP8_1D_SCALING);
+
+  const NVTEShape row_shape = orig_tensors[0]->rowwise_scale_inv_shape();
+  const NVTEShape col_shape = orig_tensors[0]->columnwise_scale_inv_shape();
+  const size_t row_numel = row_shape.data[0] * row_shape.data[1];
+  const size_t col_numel = col_shape.data[0] * col_shape.data[1];
+
+  const uint8_t no_swizzle = 0, has_swizzle = 1;
+  nvte_set_grouped_tensor_param(grouped_orig.get_handle(), kNVTEGroupedWithGEMMSwizzledScales,
+                                &no_swizzle, sizeof(no_swizzle));
+  nvte_set_grouped_tensor_param(grouped_mid.get_handle(), kNVTEGroupedWithGEMMSwizzledScales,
+                                &has_swizzle, sizeof(has_swizzle));
+  nvte_set_grouped_tensor_param(grouped_fin.get_handle(), kNVTEGroupedWithGEMMSwizzledScales,
+                                &no_swizzle, sizeof(no_swizzle));
+
+  NVTE_CHECK_CUDA(cudaMemset(grouped_mid.scale_inv.get(), 0, num_tensors * row_numel));
+  NVTE_CHECK_CUDA(cudaMemset(grouped_mid.columnwise_scale_inv.get(), 0, num_tensors * col_numel));
+  NVTE_CHECK_CUDA(cudaMemset(grouped_fin.scale_inv.get(), 0, num_tensors * row_numel));
+  NVTE_CHECK_CUDA(cudaMemset(grouped_fin.columnwise_scale_inv.get(), 0, num_tensors * col_numel));
+
+  nvte_swizzle_grouped_scaling_factors(grouped_orig.get_handle(), grouped_mid.get_handle(), 0);
+  nvte_unswizzle_grouped_scaling_factors(grouped_mid.get_handle(), grouped_fin.get_handle(), 0);
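+
+  // Both kernels are launched on the default stream (stream 0); the synchronous
+  // cudaMemcpy calls below are implicitly ordered after them, so no explicit
+  // device synchronization is needed before reading back the results.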
+  std::vector<uint8_t> result_row(num_tensors * row_numel);
+  std::vector<uint8_t> result_col(num_tensors * col_numel);
+  NVTE_CHECK_CUDA(cudaMemcpy(result_row.data(), grouped_fin.scale_inv.get(),
+                             result_row.size(), cudaMemcpyDeviceToHost));
+  NVTE_CHECK_CUDA(cudaMemcpy(result_col.data(), grouped_fin.columnwise_scale_inv.get(),
+                             result_col.size(), cudaMemcpyDeviceToHost));
+
+  std::vector<uint8_t> ref_row(num_tensors * row_numel);
+  std::vector<uint8_t> ref_col(num_tensors * col_numel);
+  for (int i = 0; i < num_tensors; ++i) {
+    memcpy(ref_row.data() + i * row_numel,
+           orig_tensors[i]->rowwise_cpu_scale_inv_ptr(), row_numel);
+    memcpy(ref_col.data() + i * col_numel,
+           orig_tensors[i]->columnwise_cpu_scale_inv_ptr(), col_numel);
+  }
+
+  compareResults("grouped_roundtrip_rowwise", result_row.data(), ref_row.data(),
+                 num_tensors * row_numel);
+  compareResults("grouped_roundtrip_colwise", result_col.data(), ref_col.data(),
+                 num_tensors * col_numel);
+}
+
 class SwizzleGroupedTestSuite
     : public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};
 
@@ -374,6 +541,68 @@ INSTANTIATE_TEST_SUITE_P(
   }
 );
 
+class UnswizzleGroupedTestSuite
+    : public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};
+
+TEST_P(UnswizzleGroupedTestSuite, TestGroupedUnswizzleMXFP8) {
+  const auto num_tensors = std::get<0>(GetParam());
+  const auto M = std::get<1>(GetParam());
+  const auto K = std::get<2>(GetParam());
+  performTestGroupedUnswizzleMXFP8(num_tensors, M, K);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+  OperatorTest,
+  UnswizzleGroupedTestSuite,
+  ::testing::Values(
+    std::make_tuple(3, 256, 256),
+    std::make_tuple(4, 128, 128),
+    std::make_tuple(3, 200, 256),
+    std::make_tuple(2, 65, 256),
+    std::make_tuple(3, 256, 160),
+    std::make_tuple(2, 256, 96),
+    std::make_tuple(3, 200, 160),
+    std::make_tuple(4, 33, 64),
+    std::make_tuple(2, 1, 32)
+  ),
+  [](const testing::TestParamInfo<UnswizzleGroupedTestSuite::ParamType>& info) {
+    return "n" + std::to_string(std::get<0>(info.param)) +
+           "_M" + std::to_string(std::get<1>(info.param)) +
+           "_K" + std::to_string(std::get<2>(info.param));
+  }
+);
+
+class SwizzleUnswizzleGroupedRoundtripTestSuite
+    : public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};
+
+TEST_P(SwizzleUnswizzleGroupedRoundtripTestSuite, TestGroupedSwizzleUnswizzleRoundtrip) {
+  const auto num_tensors = std::get<0>(GetParam());
+  const auto M = std::get<1>(GetParam());
+  const auto K = std::get<2>(GetParam());
+  performTestGroupedSwizzleUnswizzleRoundtrip(num_tensors, M, K);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+  OperatorTest,
+  SwizzleUnswizzleGroupedRoundtripTestSuite,
+  ::testing::Values(
+    std::make_tuple(3, 256, 256),
+    std::make_tuple(4, 128, 128),
+    std::make_tuple(3, 200, 256),
+    std::make_tuple(2, 65, 256),
+    std::make_tuple(3, 256, 160),
+    std::make_tuple(2, 256, 96),
+    std::make_tuple(3, 200, 160),
+    std::make_tuple(4, 33, 64),
+    std::make_tuple(2, 1, 32)
+  ),
+  [](const testing::TestParamInfo<SwizzleUnswizzleGroupedRoundtripTestSuite::ParamType>& info) {
+    return "n" + std::to_string(std::get<0>(info.param)) +
+           "_M" + std::to_string(std::get<1>(info.param)) +
+           "_K" + std::to_string(std::get<2>(info.param));
+  }
+);
+
 namespace {
 
 std::vector<std::tuple<size_t, size_t>> num_tiles = {
diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h
index aa697aafe1..4e28de3beb 100644
--- a/transformer_engine/common/include/transformer_engine/swizzle.h
+++ b/transformer_engine/common/include/transformer_engine/swizzle.h
@@ -107,6 +107,22 @@ void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input,
 void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                                           cudaStream_t stream);
 
+/*! \brief Unswizzling scaling factors from the interleaved GEMM layout back to row-major (grouped)
+ *
+ *  \param[in]     input   Input grouped tensor with swizzled scale_inv.
+ *  \param[in,out] output  Output grouped tensor which hosts the non-swizzled scale_inv.
+ *  \param[in]     stream  CUDA stream used for the operation.
+ *
+ *  Requirements:
+ *  - scaling mode must be MXFP8 1D scaling.
+ *  - scale_inv is stored in row-major order in the output.
+ *  - scale_inv sizes are padded to multiples of 128x4 for row-scales and 4x128 for col-scales.
+ *  - data is quantized along the K-dimension, i.e. each 1D-scaling block lies along the K-dimension.
+ *  - all tensors in the grouped tensor must have the same shape.
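+ *
+ *  Illustrative padding example (with MXFP8 block size 32): for a 200x160 tensor,
+ *  the rowwise scale_inv is padded from 200x5 to 256x8, and the columnwise
+ *  scale_inv from 7x160 to 8x256.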
+ */
+void nvte_unswizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
+                                            cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index 28a879a376..d17d0b6d3d 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -485,6 +485,24 @@ __global__ void __launch_bounds__(TB_DIM* TB_DIM)
                                   gridDim.y);
 }
 
+template <int SF_TILE_DIM_M, int SF_TILE_DIM_K, int VEC_SIZE>
+__global__ void __launch_bounds__(TB_DIM* TB_DIM)
+    grouped_unswizzle_scaling_uniform_shape_kernel(const void* input, void* output, const int M,
+                                                   const int K, const size_t scale_stride_bytes,
+                                                   const bool row_scaling) {
+  const int tensor_id = blockIdx.z;
+  const uint8_t* input_base =
+      reinterpret_cast<const uint8_t*>(input) + tensor_id * scale_stride_bytes;
+  uint8_t* output_base = reinterpret_cast<uint8_t*>(output) + tensor_id * scale_stride_bytes;
+  if (row_scaling) {
+    unswizzle_row_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, VEC_SIZE>(
+        input_base, output_base, M, K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
+  } else {
+    unswizzle_col_scaling_kernel_impl<SF_TILE_DIM_M, SF_TILE_DIM_K, VEC_SIZE>(
+        input_base, output_base, M, K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
+  }
+}
+
 template <int LOAD_SIZE>
 __global__ void multi_tensor_unswizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) {
   const int bid = blockIdx.x;
@@ -1692,6 +1710,110 @@ void swizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor*
   }
 }
 
+void unswizzle_grouped_scaling_factors(const GroupedTensor* input, GroupedTensor* output,
+                                       cudaStream_t stream) {
+  NVTE_CHECK(output->scaling_mode == NVTE_MXFP8_1D_SCALING,
+             "Grouped unswizzle supports only MXFP8 scaling.");
+
+  CheckInputGroupedTensor(*input, "input");
+  CheckOutputGroupedTensor(*output, "output", false);
+  NVTE_CHECK(input->with_gemm_swizzled_scales,
+             "Expected input grouped tensor with scales in GEMM swizzled format.");
+  NVTE_CHECK(!output->with_gemm_swizzled_scales,
+             "Expected output grouped tensor with scales in compact format.");
+  NVTE_CHECK(input->scaling_mode == output->scaling_mode,
+             "Input and output grouped tensors must have matching scaling modes.");
+
+  const bool has_rowwise_scale_inv = output->scale_inv.has_data();
+  const bool has_columnwise_scale_inv = output->columnwise_scale_inv.has_data();
+  if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
+    return;
+  }
+
+  NVTE_CHECK(output->all_same_shape(), "Grouped unswizzle requires uniform tensor shapes.");
+
+  const size_t first_dim = output->get_common_first_dim();
+  const size_t last_dim = output->get_common_last_dim();
+
+  constexpr int SF_TILE_DIM_M = 128;
+  constexpr int SF_TILE_DIM_K = 4;
+  const dim3 block_size(TB_DIM, TB_DIM);
+
+  auto launch_grouped_unswizzle = [&](bool rowwise) {
+    const size_t m = rowwise ? first_dim : last_dim;
+    const size_t k = rowwise ? last_dim : first_dim;
+    const size_t padded_m = round_up_to_multiple(m, 128);
+    const size_t padded_k =
+        round_up_to_multiple(DIVUP(k, static_cast<size_t>(MXFP8_BLOCK_SIZE)), 4);
+    const size_t scale_elems = padded_m * padded_k;
+    const size_t scale_elem_size = rowwise ? typeToSize(output->scale_inv.dtype)
+                                           : typeToSize(output->columnwise_scale_inv.dtype);
+    const size_t scale_stride_bytes = scale_elems * scale_elem_size;
+
+    if (rowwise) {
+      NVTE_CHECK(input->scale_inv.numel() == input->num_tensors * scale_elems,
+                 "Grouped input scale_inv size does not match expected packed size.");
+      NVTE_CHECK(output->scale_inv.numel() == output->num_tensors * scale_elems,
+                 "Grouped output scale_inv size does not match expected packed size.");
+    } else {
+      NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
+                 "Grouped input columnwise_scale_inv size does not match expected packed size.");
+      NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
+                 "Grouped output columnwise_scale_inv size does not match expected packed size.");
+    }
+
+    const int num_tiles_m = padded_m / SF_TILE_DIM_M;
+    const int num_tiles_k = padded_k / SF_TILE_DIM_K;
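+    // Pick the widest vector width in {4, 2, 1} that divides the tile count along
+    // the vectorized dimension: (n - 1) % 4 + 1 is n mod 4 with 0 mapped to 4, and
+    // a remainder of 3 falls back to scalar loads (there is no 3-byte vector load).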
+    int vec_load_size = (rowwise ? ((num_tiles_k - 1) % 4 + 1) : ((num_tiles_m - 1) % 4 + 1));
+    if (vec_load_size == 3) vec_load_size = 1;
+    const int n_tiles_in_tb = TB_DIM * vec_load_size;
+
+    dim3 num_blocks;
+    if (rowwise) {
+      num_blocks = dim3(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m, output->num_tensors);
+    } else {
+      num_blocks =
+          dim3(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size), output->num_tensors);
+    }
+    const int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
+
+    const void* input_ptr = rowwise ? input->scale_inv.dptr : input->columnwise_scale_inv.dptr;
+    void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr;
+
+    using kernel_t = void (*)(const void*, void*, const int, const int, const size_t, const bool);
+    kernel_t kernel_fn = nullptr;
+    switch (vec_load_size) {
+      case 4:
+        kernel_fn =
+            grouped_unswizzle_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 4>;
+        break;
+      case 2:
+        kernel_fn =
+            grouped_unswizzle_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 2>;
+        break;
+      case 1:
+        kernel_fn =
+            grouped_unswizzle_scaling_uniform_shape_kernel<SF_TILE_DIM_M, SF_TILE_DIM_K, 1>;
+        break;
+      default:
+        NVTE_ERROR("Not valid vec_load_size.");
+    }
+    NVTE_CHECK_CUDA(
+        cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
+    kernel_fn<<<num_blocks, block_size, slm_size, stream>>>(input_ptr, output_ptr, padded_m,
+                                                            padded_k, scale_stride_bytes, rowwise);
+    NVTE_CHECK_CUDA(cudaGetLastError());
+  };
+
+  if (has_rowwise_scale_inv) {
+    launch_grouped_unswizzle(true);
+  }
+  if (has_columnwise_scale_inv) {
+    launch_grouped_unswizzle(false);
+  }
+}
+
 }  // namespace transformer_engine
 
 void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
@@ -1701,3 +1823,11 @@ void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGro
   swizzle_grouped_scaling_factors(convertNVTEGroupedTensorCheck(input),
                                   convertNVTEGroupedTensorCheck(output), stream);
 }
+
+void nvte_unswizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
+                                            cudaStream_t stream) {
+  NVTE_API_CALL(nvte_unswizzle_grouped_scaling_factors);
+  using namespace transformer_engine;
+  unswizzle_grouped_scaling_factors(convertNVTEGroupedTensorCheck(input),
+                                    convertNVTEGroupedTensorCheck(output), stream);
+}
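
A minimal usage sketch for the new entry point (illustrative, not part of the patch;
`in` and `out` are assumed to be NVTEGroupedTensor handles already built with
NVTE_MXFP8_1D_SCALING, uniform per-tensor shapes, and allocated scale_inv buffers):

    // Mark the input as carrying GEMM-swizzled scales and the output as compact.
    const uint8_t swizzled = 1, compact = 0;
    nvte_set_grouped_tensor_param(in, kNVTEGroupedWithGEMMSwizzledScales,
                                  &swizzled, sizeof(swizzled));
    nvte_set_grouped_tensor_param(out, kNVTEGroupedWithGEMMSwizzledScales,
                                  &compact, sizeof(compact));
    nvte_unswizzle_grouped_scaling_factors(in, out, /*stream=*/0);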