Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 229 additions & 0 deletions tests/cpp/operator/test_swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,173 @@ TEST_P(UnswizzleTestSuite, TestUnswizzle) {
transa);
}

void performTestGroupedUnswizzleMXFP8(const int num_tensors, const size_t M, const size_t K) {
using namespace transformer_engine;
using namespace test;

std::vector<std::unique_ptr<Tensor>> input_tensors;
std::vector<std::unique_ptr<Tensor>> output_tensors;
std::vector<Tensor*> input_ptrs;
std::vector<Tensor*> output_ptrs;
input_tensors.reserve(num_tensors);
output_tensors.reserve(num_tensors);
input_ptrs.reserve(num_tensors);
output_ptrs.reserve(num_tensors);

const std::vector<size_t> shape{M, K};
for (int i = 0; i < num_tensors; ++i) {
auto input = std::make_unique<Tensor>("input_" + std::to_string(i), shape,
DType::kFloat8E4M3, true, true,
NVTE_MXFP8_1D_SCALING);
auto output = std::make_unique<Tensor>("output_" + std::to_string(i), shape,
DType::kFloat8E4M3, true, true,
NVTE_MXFP8_1D_SCALING);
fillUniform(input.get());
fillUniform(output.get());

input_ptrs.push_back(input.get());
output_ptrs.push_back(output.get());
input_tensors.emplace_back(std::move(input));
output_tensors.emplace_back(std::move(output));
}

GroupedBuffers grouped_input = build_grouped_tensor(input_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_output = build_grouped_tensor(output_ptrs, NVTE_MXFP8_1D_SCALING);
const uint8_t input_swizzled = 1;
nvte_set_grouped_tensor_param(grouped_input.get_handle(),
kNVTEGroupedWithGEMMSwizzledScales,
&input_swizzled, sizeof(input_swizzled));
const uint8_t output_swizzled = 0;
nvte_set_grouped_tensor_param(grouped_output.get_handle(),
kNVTEGroupedWithGEMMSwizzledScales,
&output_swizzled, sizeof(output_swizzled));

const NVTEShape row_shape = input_tensors[0]->rowwise_scale_inv_shape();
const NVTEShape col_shape = input_tensors[0]->columnwise_scale_inv_shape();
const size_t row_numel = row_shape.data[0] * row_shape.data[1];
const size_t col_numel = col_shape.data[0] * col_shape.data[1];

NVTE_CHECK_CUDA(cudaMemset(grouped_output.scale_inv.get(), 0, num_tensors * row_numel));
NVTE_CHECK_CUDA(cudaMemset(grouped_output.columnwise_scale_inv.get(), 0, num_tensors * col_numel));

nvte_unswizzle_grouped_scaling_factors(grouped_input.get_handle(),
grouped_output.get_handle(), 0);

std::vector<uint8_t> output_row(num_tensors * row_numel);
std::vector<uint8_t> output_col(num_tensors * col_numel);
NVTE_CHECK_CUDA(cudaMemcpy(output_row.data(), grouped_output.scale_inv.get(),
output_row.size(), cudaMemcpyDeviceToHost));
NVTE_CHECK_CUDA(cudaMemcpy(output_col.data(), grouped_output.columnwise_scale_inv.get(),
output_col.size(), cudaMemcpyDeviceToHost));

std::vector<uint8_t> ref_row(num_tensors * row_numel);
std::vector<uint8_t> ref_col(num_tensors * col_numel);
for (int i = 0; i < num_tensors; ++i) {
compute_ref_unswizzle<128, 4, true>(input_tensors[i]->rowwise_cpu_scale_inv_ptr<uint8_t>(),
ref_row.data() + i * row_numel,
row_shape.data[0], row_shape.data[1]);
compute_ref_unswizzle<128, 4, false>(
input_tensors[i]->columnwise_cpu_scale_inv_ptr<uint8_t>(),
ref_col.data() + i * col_numel,
col_shape.data[1], col_shape.data[0]);
}

compareResults("grouped_unswizzle_rowwise", output_row.data(), ref_row.data(),
num_tensors * row_numel);
compareResults("grouped_unswizzle_colwise", output_col.data(), ref_col.data(),
num_tensors * col_numel);
}

void performTestGroupedSwizzleUnswizzleRoundtrip(const int num_tensors, const size_t M,
const size_t K) {
using namespace transformer_engine;
using namespace test;

constexpr size_t BLOCK_SIZE = 32;
const std::vector<size_t> shape{M, K};

std::vector<std::unique_ptr<Tensor>> orig_tensors, mid_tensors, final_tensors;
std::vector<Tensor*> orig_ptrs, mid_ptrs, final_ptrs;
orig_tensors.reserve(num_tensors);
mid_tensors.reserve(num_tensors);
final_tensors.reserve(num_tensors);

for (int i = 0; i < num_tensors; ++i) {
auto orig = std::make_unique<Tensor>("orig_" + std::to_string(i), shape,
DType::kFloat8E4M3, true, true, NVTE_MXFP8_1D_SCALING);
auto mid = std::make_unique<Tensor>("mid_" + std::to_string(i), shape,
DType::kFloat8E4M3, true, true, NVTE_MXFP8_1D_SCALING);
auto fin = std::make_unique<Tensor>("fin_" + std::to_string(i), shape,
DType::kFloat8E4M3, true, true, NVTE_MXFP8_1D_SCALING);
fillUniform(orig.get());

// Zero padding so the round-trip comparison is exact.
orig->to_cpu();
const NVTEShape rs = orig->rowwise_scale_inv_shape();
zero_scale_inv_padding(orig->rowwise_cpu_scale_inv_ptr<uint8_t>(),
rs.data[0], rs.data[1],
M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE);
const NVTEShape cs = orig->columnwise_scale_inv_shape();
zero_scale_inv_padding(orig->columnwise_cpu_scale_inv_ptr<uint8_t>(),
cs.data[0], cs.data[1],
(M + BLOCK_SIZE - 1) / BLOCK_SIZE, K);
orig->from_cpu();

orig_ptrs.push_back(orig.get());
mid_ptrs.push_back(mid.get());
final_ptrs.push_back(fin.get());
orig_tensors.emplace_back(std::move(orig));
mid_tensors.emplace_back(std::move(mid));
final_tensors.emplace_back(std::move(fin));
}

GroupedBuffers grouped_orig = build_grouped_tensor(orig_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_mid = build_grouped_tensor(mid_ptrs, NVTE_MXFP8_1D_SCALING);
GroupedBuffers grouped_fin = build_grouped_tensor(final_ptrs, NVTE_MXFP8_1D_SCALING);

const NVTEShape row_shape = orig_tensors[0]->rowwise_scale_inv_shape();
const NVTEShape col_shape = orig_tensors[0]->columnwise_scale_inv_shape();
const size_t row_numel = row_shape.data[0] * row_shape.data[1];
const size_t col_numel = col_shape.data[0] * col_shape.data[1];

const uint8_t no_swizzle = 0, has_swizzle = 1;
nvte_set_grouped_tensor_param(grouped_orig.get_handle(), kNVTEGroupedWithGEMMSwizzledScales,
&no_swizzle, sizeof(no_swizzle));
nvte_set_grouped_tensor_param(grouped_mid.get_handle(), kNVTEGroupedWithGEMMSwizzledScales,
&has_swizzle, sizeof(has_swizzle));
nvte_set_grouped_tensor_param(grouped_fin.get_handle(), kNVTEGroupedWithGEMMSwizzledScales,
&no_swizzle, sizeof(no_swizzle));

NVTE_CHECK_CUDA(cudaMemset(grouped_mid.scale_inv.get(), 0, num_tensors * row_numel));
NVTE_CHECK_CUDA(cudaMemset(grouped_mid.columnwise_scale_inv.get(), 0, num_tensors * col_numel));
NVTE_CHECK_CUDA(cudaMemset(grouped_fin.scale_inv.get(), 0, num_tensors * row_numel));
NVTE_CHECK_CUDA(cudaMemset(grouped_fin.columnwise_scale_inv.get(), 0, num_tensors * col_numel));

nvte_swizzle_grouped_scaling_factors(grouped_orig.get_handle(), grouped_mid.get_handle(), 0);
nvte_unswizzle_grouped_scaling_factors(grouped_mid.get_handle(), grouped_fin.get_handle(), 0);

std::vector<uint8_t> result_row(num_tensors * row_numel);
std::vector<uint8_t> result_col(num_tensors * col_numel);
NVTE_CHECK_CUDA(cudaMemcpy(result_row.data(), grouped_fin.scale_inv.get(),
result_row.size(), cudaMemcpyDeviceToHost));
NVTE_CHECK_CUDA(cudaMemcpy(result_col.data(), grouped_fin.columnwise_scale_inv.get(),
result_col.size(), cudaMemcpyDeviceToHost));

std::vector<uint8_t> ref_row(num_tensors * row_numel);
std::vector<uint8_t> ref_col(num_tensors * col_numel);
for (int i = 0; i < num_tensors; ++i) {
memcpy(ref_row.data() + i * row_numel,
orig_tensors[i]->rowwise_cpu_scale_inv_ptr<uint8_t>(), row_numel);
memcpy(ref_col.data() + i * col_numel,
orig_tensors[i]->columnwise_cpu_scale_inv_ptr<uint8_t>(), col_numel);
}

compareResults("grouped_roundtrip_rowwise", result_row.data(), ref_row.data(),
num_tensors * row_numel);
compareResults("grouped_roundtrip_colwise", result_col.data(), ref_col.data(),
num_tensors * col_numel);
}

// Parameterized fixture for grouped swizzle tests.
// Parameters: (number of tensors in the group, M, K).
class SwizzleGroupedTestSuite
    : public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};

Expand Down Expand Up @@ -374,6 +541,68 @@ INSTANTIATE_TEST_SUITE_P(
}
);

// Parameterized fixture for grouped unswizzle tests.
// Parameters: (number of tensors in the group, M, K).
class UnswizzleGroupedTestSuite
    : public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};

// Grouped MXFP8 unswizzle against the CPU reference.
TEST_P(UnswizzleGroupedTestSuite, TestGroupedUnswizzleMXFP8) {
  const auto& [num_tensors, M, K] = GetParam();
  performTestGroupedUnswizzleMXFP8(num_tensors, M, K);
}

// Cases cover aligned and padded shapes (M or K not multiples of 128/32),
// down to the degenerate single-row tensor.
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  UnswizzleGroupedTestSuite,
  ::testing::Values(
    std::make_tuple(3, 256, 256),
    std::make_tuple(4, 128, 128),
    std::make_tuple(3, 200, 256),
    std::make_tuple(2, 65, 256),
    std::make_tuple(3, 256, 160),
    std::make_tuple(2, 256, 96),
    std::make_tuple(3, 200, 160),
    std::make_tuple(4, 33, 64),
    std::make_tuple(2, 1, 32)
  ),
  [](const testing::TestParamInfo<UnswizzleGroupedTestSuite::ParamType>& info) {
    const auto& [n, m, k] = info.param;
    return "n" + std::to_string(n) + "_M" + std::to_string(m) + "_K" + std::to_string(k);
  }
);

// Parameterized fixture for grouped swizzle->unswizzle round-trip tests.
// Parameters: (number of tensors in the group, M, K).
class SwizzleUnswizzleGroupedRoundtripTestSuite
    : public ::testing::TestWithParam<std::tuple<int, size_t, size_t>> {};

// Swizzle then unswizzle must reproduce the original scaling factors.
TEST_P(SwizzleUnswizzleGroupedRoundtripTestSuite, TestGroupedSwizzleUnswizzleRoundtrip) {
  const auto& [num_tensors, M, K] = GetParam();
  performTestGroupedSwizzleUnswizzleRoundtrip(num_tensors, M, K);
}

// Same shape matrix as the one-way unswizzle suite, so the round-trip is
// exercised on both aligned and padded scale layouts.
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  SwizzleUnswizzleGroupedRoundtripTestSuite,
  ::testing::Values(
    std::make_tuple(3, 256, 256),
    std::make_tuple(4, 128, 128),
    std::make_tuple(3, 200, 256),
    std::make_tuple(2, 65, 256),
    std::make_tuple(3, 256, 160),
    std::make_tuple(2, 256, 96),
    std::make_tuple(3, 200, 160),
    std::make_tuple(4, 33, 64),
    std::make_tuple(2, 1, 32)
  ),
  [](const testing::TestParamInfo<SwizzleUnswizzleGroupedRoundtripTestSuite::ParamType>& info) {
    const auto& [n, m, k] = info.param;
    return "n" + std::to_string(n) + "_M" + std::to_string(m) + "_K" + std::to_string(k);
  }
);

namespace {

std::vector<std::pair<int, int>> num_tiles = {
Expand Down
16 changes: 16 additions & 0 deletions transformer_engine/common/include/transformer_engine/swizzle.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input,
void nvte_swizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream);

/*! \brief Unswizzle scaling factors from the interleaved GEMM layout back to row-major (grouped).
 *
 *  Inverse of nvte_swizzle_grouped_scaling_factors.
 *
 *  \param[in]     input   Input grouped tensor with swizzled scale_inv.
 *  \param[in,out] output  Output grouped tensor which hosts non-swizzled scale_inv.
 *  \param[in]     stream  CUDA stream used for the operation.
 *
 *  Requirements:
 *   - Scaling mode must be MXFP8 1D scaling.
 *   - scale_inv is stored in row-major order in the output.
 *   - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
 *   - Data is quantized along the K-dimension, i.e. each 1D-scaling block lies
 *     along the K-dimension.
 *   - All tensors in the grouped tensor must have the same shape.
 */
void nvte_unswizzle_grouped_scaling_factors(const NVTEGroupedTensor input, NVTEGroupedTensor output,
                                            cudaStream_t stream);

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
Loading
Loading