Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions tests/cpp/operator/test_act.cu
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ void performTestGLU(const size_t N, const size_t H) {
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
if ((otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2)
&& N * H > 0) {
auto [atol, rtol] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol, rtol);
if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
Expand Down Expand Up @@ -392,7 +393,9 @@ std::vector<std::pair<size_t, size_t>> act_test_cases = {{2048, 12288},
{65536, 128},
{256, 256},
{257, 259},
{128, 128+1}};
{128, 128+1},
{0, 128},
{128, 0}};

} // namespace

Expand Down
4 changes: 3 additions & 1 deletion tests/cpp/operator/test_cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void performTest(const std::vector<size_t>& shape) {
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
if (isFp8Type(otype) && full_size > 0) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this problem shows up not only for activations, but also for a regular cast?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, many of our kernels are not robust to empty tensors. I still expect to see problems in the FP8 block-scale quantization kernels and transpose kernels.

auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
Expand All @@ -91,6 +91,8 @@ std::vector<std::vector<size_t>> test_cases = {
{5, 160},
{5, 4, 3, 160},
{217, 256},
{0, 128},
{128, 0},
};
} // namespace

Expand Down
4 changes: 3 additions & 1 deletion tests/cpp/operator/test_cast_gated_swiglu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void performTest(const std::vector<size_t>& shape) {
rows,
cols);

if (isFp8Type(otype)) {
if (isFp8Type(otype) && input_size > 0) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
Expand All @@ -118,6 +118,8 @@ std::vector<std::vector<size_t>> test_cases = {
{217, 256},
{1296},
{5, 4, 3, 160},
{0, 128},
{128, 0},
};

} // namespace
Expand Down
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,8 @@ std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{768, 1024},
{8192, 128},
{577, 1632},
{0, 128},
{128, 0},
};

std::vector<std::pair<size_t, size_t>> block_sizes = {
Expand Down
5 changes: 5 additions & 0 deletions transformer_engine/common/cast/fp8/gated_fp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,14 @@ void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
checkCuDriverContext(stream);

NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function.");

// Tensor dimensions
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
if (rows == 0 || cols == 0) {
return;
}

const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
Expand Down
15 changes: 15 additions & 0 deletions transformer_engine/common/cast/fp8/quantize_fp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
using namespace quantize_1D_kernel;

// Tensor size
const size_t N = product(input.data.shape);
if (N == 0) {
return;
}

const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
NVTE_CHECK(isFullTile, "Only full tiles are supported.");
Expand Down Expand Up @@ -391,8 +396,18 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
using namespace quantize_2D_kernel;
checkCuDriverContext(stream);

// Tensor dimensions
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();

// Skip kernel if tensor size is zero
if (rows == 0 || cols == 0) {
if constexpr (IS_DBIAS) {
NVTE_ERROR("Invalid tensor shape for DBias computation (shape=", input.shape(), ").");
}
return;
}

const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X);
const size_t blocks_Y = chunks_Y;
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
if (rows == 0 || cols == 0) {
return;
}

const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X);

Expand Down
13 changes: 9 additions & 4 deletions transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,13 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
using namespace gated_kernel;
checkCuDriverContext(stream);

const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
if (rows == 0 || cols == 0) {
return;
}

const bool USE_ROWWISE_SCALING = output->has_data();
const bool USE_COLWISE_SCALING = output->has_columnwise_data();
const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;
Expand All @@ -725,12 +732,10 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
scaling_type = ScalingType::COLWISE;
} else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) {
scaling_type = ScalingType::BIDIMENSIONAL;
} else {
NVTE_ERROR("Missing both row-wise and column-wise data.");
}

const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
const size_t output_cols = (IS_BWD ? 2 : 1) * cols;

const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);

Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,15 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
}
}

// Skip kernel if tensor size is zero
if (elts_total == 0) {
if constexpr (IS_DBIAS) {
NVTE_ERROR("Invalid grouped tensor shape for DBias computation (first_logical_dim=",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, couldn't we also output a zero-filled dbias tensor instead of throwing an error?

first_logical_dim, ", last_logical_dim=", last_logical_dim, ")");
}
return;
}

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input->dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
Expand Down
8 changes: 8 additions & 0 deletions transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,14 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();

// Skip kernel if tensor size is zero
if (rows == 0 || cols == 0) {
if constexpr (IS_DBIAS) {
NVTE_ERROR("Invalid tensor shape for DBias computation (shape=", input.shape(), ").");
}
return;
}

// Tensor chunk handled by each CUDA block
constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64;
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
constexpr int FP4_BLOCK_SIZE = 16;
const size_t N = input.flat_first_dim();
const size_t M = input.flat_last_dim();
if (N == 0 || M == 0) {
return;
}

NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,9 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop,

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
if (rows == 0 || cols == 0) {
return;
}

NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,9 @@ inline void quantize(const Tensor &input, const Tensor *noop, Tensor *output, cu

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
if (rows == 0 || cols == 0) {
return;
}

constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,9 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
if (rows == 0 || cols == 0) {
return;
}

NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,9 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop,

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
if (rows == 0 || cols == 0) {
return;
}

NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
Expand Down
Loading