def f32_to_e8m0_even(amax: Tensor, *, emax: int, mbits: int) -> Tensor:
    """Assign E8M0 block scales with the OCP MX v1.0 §6.3 ``even_round`` rule.

    For every per-block maximum ``amax = max |x_i|`` the scale is
    ``X_scale = 2^(floor(log2(amax_rounded)) - emax)``, where
    ``amax_rounded`` is ``amax`` rounded to ``mbits`` bits of mantissa. The
    rounding is done directly on the float32 bit pattern: half an ULP (f32
    mantissa bit ``23 - mbits - 1``) is added and the mantissa is then masked
    off, so a carry out of the mantissa bumps the exponent by one.

    Args:
        amax: float32 tensor of per-block absolute maxima. Expected to be
            non-negative (the sign bit is masked away, not checked).
        emax: largest representable exponent of the element format, e.g.
            2 for E2M1 / E2M3, 4 for E3M2, 8 for E4M3.
        mbits: mantissa width of the element format in bits (E2M1: 1,
            E2M3: 3, E3M2: 2, E4M3: 3). It only determines where the
            half-ULP is added, i.e. which inputs round up into the next
            binade.

    Returns:
        Tensor with the same shape as ``amax`` holding one E8M0 codepoint
        per element, produced as uint8 and viewed as ``fp8_e8m0``. The
        codepoint is ``rounded_exp - emax`` clamped to ``[0, 254]``, where
        ``rounded_exp`` is the biased f32 exponent of ``amax_rounded``.
        Special cases: ``amax == 0`` yields 1 (smallest normal scale) and
        any input whose f32 exponent field is all ones (NaN, and therefore
        also Inf) yields 0xFF.

    Notes:
        ``f32_to_e8m0`` expects the pre-divided value ``amax / max_normal``
        and, per the original author's measurements, lands one exponent
        step low on roughly 68% of blocks whenever ``max_normal != 2^emax``
        (e.g. E2M3 with ``max_normal = 7.5``). Prefer this function for
        scales that must interoperate with Quark / OCP-spec checkpoints;
        it is intended to match
        ``quark.torch.quantization.utils.even_round``.
    """
    assert amax.dtype == torch.float32, f"amax must be float32, got {amax.dtype}"
    amax = amax.contiguous()

    # Reinterpret the floats as integers. The arithmetic runs in int64
    # because HIP does not implement uint32 ops; the 0xFFFFFFFF mask below
    # keeps only the original 32 bits.
    raw_bits = amax.view(torch.int32).to(torch.int64)

    half_ulp = 1 << (23 - mbits - 1)
    exp_field_mask = 0x7F800000  # exponent bits only; sign is 0 for amax >= 0

    # Add the half-ULP, then discard the mantissa: what is left is the
    # exponent field of amax_rounded, i.e. floor(log2(amax_rounded)) + 127.
    bumped = (raw_bits & 0xFFFFFFFF) + half_ulp
    biased_exp = (bumped & exp_field_mask) >> 23

    # E8M0 code = (biased_exp - 127) - emax + 127 = biased_exp - emax.
    code = torch.clamp(biased_exp - emax, min=0, max=254)

    # Edge cases (following Quark): zero -> code 1, all-ones exponent -> 0xFF.
    is_zero = amax == 0
    is_nonfinite = ((raw_bits >> 23) & 0xFF) == 0xFF
    code = code.masked_fill(is_zero, 1)
    code = code.masked_fill(is_nonfinite, 0xFF)
    return code.to(torch.uint8).view(fp8_e8m0)
+""" + +from __future__ import annotations + +import math +import os +import sys + +import pytest +import torch + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from tests.kernels.utils import fp4_utils # noqa: E402 + +# (dtype_name, emax, mbits, max_normal) -- straight from the OCP MX v1.0 spec. +_OCP_MX_TYPES = [ + ("E2M1", 2, 1, 6.0), # MXFP4 max = 4 * 1.5 + ("E2M3", 2, 3, 7.5), # MXFP6 max = 4 * 1.875 + ("E3M2", 4, 2, 28.0), # MXFP6 max = 16 * 1.75 + ("E4M3", 8, 3, 448.0), # MXFP8 max = 256 * 1.75 + ("E5M2", 15, 2, 57344.0), +] + + +def _e8m0_to_f32(byte_u8: torch.Tensor) -> torch.Tensor: + """uint8 E8M0 -> float32, OCP convention. byte=0 reserved (0+), byte=255 NaN.""" + bits = byte_u8.to(torch.int64) << 23 + f = bits.to(torch.int32).view(torch.float32) + f = torch.where(byte_u8 == 0, torch.full_like(f, 2.0**-126), f) + f = torch.where(byte_u8 == 0xFF, torch.full_like(f, float("nan")), f) + return f + + +def _reference_even_round_byte(amax: float, emax: int, mbits: int) -> int: + """Single-element reference: log2-floor-with-half-ULP-round-to-even. + + Independent re-derivation of the OCP §6.3 formula in plain python; used + to cross-check the vectorized ``f32_to_e8m0_even`` against arithmetic + that doesn't share its bit-manipulation tricks. + """ + if amax == 0.0: + return 1 # smallest normal scale, matches Quark even_round + if math.isnan(amax): + return 0xFF + # Add half-ULP at f32 mantissa bit (23 - mbits - 1), mask off mantissa. + import struct + + bits = struct.unpack("= amax" (no saturation) is NOT guaranteed + by RTNE because round-down can win at tie points. The earlier + ``f32_to_e8m0(amax / max_normal)`` path biases low by roughly one step + on most inputs and would fail the upper bound for many blocks. 
def test_f32_to_e8m0_legacy_off_by_one_on_e2m3():
    """Pin down the bias of the legacy ``f32_to_e8m0(amax / max_normal)`` path.

    For E2M3 (``max_normal = 7.5``) the legacy recipe is expected to land one
    E8M0 step below ``f32_to_e8m0_even`` on a large share of random inputs
    and never to differ in any other way. The test is meant to fail only if
    the legacy function's behavior changes; it documents *why*
    ``f32_to_e8m0_even`` exists.
    """
    emax, mbits, max_normal = 2, 3, 7.5
    rng = torch.Generator(device="cpu").manual_seed(2)
    amax = torch.rand(4096, generator=rng, dtype=torch.float32) * max_normal

    even = fp4_utils.f32_to_e8m0_even(amax, emax=emax, mbits=mbits).view(torch.uint8)
    # The legacy path takes the pre-divided value; the clamp keeps it at or
    # above 2**-126 before encoding.
    legacy_input = (amax / max_normal).clamp_(min=2**-126)
    legacy = fp4_utils.f32_to_e8m0(legacy_input).view(torch.uint8)

    diff = even.int() - legacy.int()
    differs = diff != 0

    # Wherever the two disagree, the only acceptable gap is even == legacy + 1.
    bad_dirs = (differs & (diff != 1)).sum().item()
    assert bad_dirs == 0, (
        f"Unexpected disagreement direction: {bad_dirs} bytes diverged in "
        f"directions other than +1. Distribution: {torch.unique(diff, return_counts=True)}"
    )

    # Loose lower bound on how often the legacy path is off by one.
    n_mism = differs.sum().item()
    assert n_mism > amax.numel() // 3, (
        f"Expected ≥1/3 of bytes to disagree (legacy off-by-one on E2M3); "
        f"got {n_mism}/{amax.numel()}. Has f32_to_e8m0 been changed?"
    )