diff --git a/lib/Dialect/P4HIR/P4HIR_Ops.cpp b/lib/Dialect/P4HIR/P4HIR_Ops.cpp
index 1d678b3e..82baf5ef 100644
--- a/lib/Dialect/P4HIR/P4HIR_Ops.cpp
+++ b/lib/Dialect/P4HIR/P4HIR_Ops.cpp
@@ -250,12 +250,57 @@ OpFoldResult P4HIR::CastOp::fold(FoldAdaptor) {
     return {};
 }
 
+/// Returns true if folding a cast chain A -> B -> C into A -> C preserves
+/// the cast semantics
+///
+/// The key unsafe case is when the second cast widens (w_B < w_C), because
+/// the extension type (zero vs sign) depends on the source's signedness.
+/// If the intermediate type B differs from A in a way that changes how
+/// the widening is performed, the fold would alter semantics.
+///
+/// Fold is safe when:
+/// - w_B >= w_C: second cast doesn't widen (truncation and reinterpretation
+///   are sign-independent), OR
+/// - w_A <= w_B AND s_A == s_B: first cast doesn't truncate and preserves
+///   signedness, so the widening in the second cast uses the same extension
+///   type as the direct A -> C cast would.
+static bool isSafeCastComposition(mlir::Type srcType, mlir::Type midType, mlir::Type dstType) {
+    if (srcType == midType || midType == dstType) return true;
+
+    auto srcBits = mlir::dyn_cast<P4HIR::BitsType>(srcType);
+    auto midBits = mlir::dyn_cast<P4HIR::BitsType>(midType);
+    auto dstBits = mlir::dyn_cast<P4HIR::BitsType>(dstType);
+
+    if (srcBits && midBits && dstBits) {
+        unsigned wA = srcBits.getWidth();
+        unsigned wB = midBits.getWidth();
+        unsigned wC = dstBits.getWidth();
+
+        // Safe if the second cast doesn't widen.
+        if (wB >= wC) return true;
+
+        // Second cast widens (wB < wC). Safe only if the first cast doesn't
+        // truncate and preserves signedness, so the composed extension matches
+        // the direct A -> C extension.
+        return wA <= wB && srcBits.isSigned() == midBits.isSigned();
+    }
+
+    // For non-BitsType chains, be conservative and don't fold.
+    return false;
+}
+
 LogicalResult P4HIR::CastOp::canonicalize(P4HIR::CastOp op, PatternRewriter &rewriter) {
     // Composition.
     // %b = cast(%a) : A -> B
     //      cast(%b) : B -> C
     // ===> cast(%a) : A -> C
     if (auto inputCast = mlir::dyn_cast_if_present<P4HIR::CastOp>(op.getSrc().getDefiningOp())) {
+        mlir::Type srcType = inputCast.getSrc().getType();
+        mlir::Type midType = inputCast.getType();
+        mlir::Type dstType = op.getType();
+
+        if (!isSafeCastComposition(srcType, midType, dstType)) return failure();
+
         auto bitcast =
             rewriter.createOrFold<P4HIR::CastOp>(op.getLoc(), op.getType(), inputCast.getSrc());
         rewriter.replaceOp(op, bitcast);
diff --git a/test/Transforms/Folds/cast.mlir b/test/Transforms/Folds/cast.mlir
index 0c764fb1..eb40fe9b 100644
--- a/test/Transforms/Folds/cast.mlir
+++ b/test/Transforms/Folds/cast.mlir
@@ -3,6 +3,8 @@
 !i8i = !p4hir.int<8>
 !i16i = !p4hir.int<16>
 !b8i = !p4hir.bit<8>
+!b16i = !p4hir.bit<16>
+!b32i = !p4hir.bit<32>
 !infint = !p4hir.infint
 #false = #p4hir.bool<false> : !p4hir.bool
 #true = #p4hir.bool<true> : !p4hir.bool
@@ -66,4 +68,76 @@ module {
   %cast6 = p4hir.cast(%c-128_i8i : !i8i) : !i16i
   p4hir.call @blackhole_i16i(%cast6) : (!i16i) -> ()
   // CHECK: p4hir.call @blackhole_i16i (%[[cm128_i16i]]) : (!i16i) -> ()
+
+  // ---- Cast chain composition safety tests ----
+
+  p4hir.func @blackhole_b8i(!b8i)
+  p4hir.func @blackhole_b16i(!b16i)
+  p4hir.func @blackhole_b32i(!b32i)
+
+  // Use non-constant arguments so chains are not constant-folded away.
+  p4hir.func @cast_chain_tests(%arg_b8 : !b8i, %arg_b16 : !b16i, %arg_i8 : !i8i, %arg_b32 : !b32i) {
+    // Safe: widen then reinterpret (w_B >= w_C).
+    // bit<8> -> bit<16> -> int<16> folds to bit<8> -> int<16>.
+    // CHECK-LABEL: @cast_chain_tests
+    // CHECK: %[[V0:.*]] = p4hir.cast(%arg0 : !b8i) : !i16i
+    // CHECK: p4hir.call @blackhole_i16i (%[[V0]])
+    %c0 = p4hir.cast(%arg_b8 : !b8i) : !b16i
+    %c1 = p4hir.cast(%c0 : !b16i) : !i16i
+    p4hir.call @blackhole_i16i(%c1) : (!i16i) -> ()
+
+    // Unsafe: reinterpret then widen (sign change before widen).
+    // bit<8> -> int<8> -> int<16> must NOT fold: the chain sign-extends,
+    // but a direct bit<8> -> int<16> would zero-extend.
+    // CHECK: %[[V1:.*]] = p4hir.cast(%arg0 : !b8i) : !i8i
+    // CHECK: %[[V2:.*]] = p4hir.cast(%[[V1]] : !i8i) : !i16i
+    // CHECK: p4hir.call @blackhole_i16i (%[[V2]])
+    %c2 = p4hir.cast(%arg_b8 : !b8i) : !i8i
+    %c3 = p4hir.cast(%c2 : !i8i) : !i16i
+    p4hir.call @blackhole_i16i(%c3) : (!i16i) -> ()
+
+    // Safe: same-sign widen then widen (w_A <= w_B, s_A == s_B).
+    // bit<8> -> bit<16> -> bit<32> folds to bit<8> -> bit<32>.
+    // CHECK: %[[V3:.*]] = p4hir.cast(%arg0 : !b8i) : !b32i
+    // CHECK: p4hir.call @blackhole_b32i (%[[V3]])
+    %c4 = p4hir.cast(%arg_b8 : !b8i) : !b16i
+    %c5 = p4hir.cast(%c4 : !b16i) : !b32i
+    p4hir.call @blackhole_b32i(%c5) : (!b32i) -> ()
+
+    // Unsafe: truncate then widen (w_A > w_B, w_B < w_C).
+    // bit<16> -> bit<8> -> bit<32> must NOT fold: lossy truncation
+    // followed by widening differs from the direct widen.
+    // CHECK: %[[V4:.*]] = p4hir.cast(%arg1 : !b16i) : !b8i
+    // CHECK: %[[V5:.*]] = p4hir.cast(%[[V4]] : !b8i) : !b32i
+    // CHECK: p4hir.call @blackhole_b32i (%[[V5]])
+    %c6 = p4hir.cast(%arg_b16 : !b16i) : !b8i
+    %c7 = p4hir.cast(%c6 : !b8i) : !b32i
+    p4hir.call @blackhole_b32i(%c7) : (!b32i) -> ()
+
+    // Safe: widen then truncate (w_B >= w_C).
+    // bit<8> -> bit<32> -> bit<16> folds to bit<8> -> bit<16>.
+    // CHECK: %[[V6:.*]] = p4hir.cast(%arg0 : !b8i) : !b16i
+    // CHECK: p4hir.call @blackhole_b16i (%[[V6]])
+    %c8 = p4hir.cast(%arg_b8 : !b8i) : !b32i
+    %c9 = p4hir.cast(%c8 : !b32i) : !b16i
+    p4hir.call @blackhole_b16i(%c9) : (!b16i) -> ()
+
+    // Safe: truncate then reinterpret (w_B >= w_C).
+    // bit<16> -> bit<8> -> int<8> folds to bit<16> -> int<8>.
+    // CHECK: %[[V7:.*]] = p4hir.cast(%arg1 : !b16i) : !i8i
+    // CHECK: p4hir.call @blackhole_i8i (%[[V7]])
+    %c10 = p4hir.cast(%arg_b16 : !b16i) : !b8i
+    %c11 = p4hir.cast(%c10 : !b8i) : !i8i
+    p4hir.call @blackhole_i8i(%c11) : (!i8i) -> ()
+
+    // Safe: truncate then truncate (w_B >= w_C).
+    // bit<32> -> bit<16> -> bit<8> folds to bit<32> -> bit<8>.
+    // CHECK: %[[V8:.*]] = p4hir.cast(%arg3 : !b32i) : !b8i
+    // CHECK: p4hir.call @blackhole_b8i (%[[V8]])
+    %c12 = p4hir.cast(%arg_b32 : !b32i) : !b16i
+    %c13 = p4hir.cast(%c12 : !b16i) : !b8i
+    p4hir.call @blackhole_b8i(%c13) : (!b8i) -> ()
+
+    p4hir.return
+  }
 }