Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions lib/Dialect/P4HIR/P4HIR_Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,57 @@ OpFoldResult P4HIR::CastOp::fold(FoldAdaptor) {
return {};
}

/// Returns true if folding a cast chain A -> B -> C into A -> C preserves
/// the cast semantics
///
/// The key unsafe case is when the second cast widens (w_B < w_C), because
/// the extension type (zero vs sign) depends on the source's signedness.
/// If the intermediate type B differs from A in a way that changes how
/// the widening is performed, the fold would alter semantics.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this middle section of comments is overly verbose.
The safe cases are explained below and naming this the "key unsafe case" is somewhat misleading (I would say truncation is also key).

///
/// Fold is safe when:
/// - w_B >= w_C: second cast doesn't widen (truncation and reinterpretation
/// are sign-independent), OR
/// - w_A <= w_B AND s_A == s_B: first cast doesn't truncate and preserves
/// signedness, so the widening in the second cast uses the same extension
/// type as the direct A -> C cast would.
static bool isSafeCastComposition(mlir::Type srcType, mlir::Type midType, mlir::Type dstType) {
if (srcType == midType || midType == dstType) return true;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be needed due to what we have in CastOp::fold


auto srcBits = mlir::dyn_cast<P4HIR::BitsType>(srcType);
auto midBits = mlir::dyn_cast<P4HIR::BitsType>(midType);
auto dstBits = mlir::dyn_cast<P4HIR::BitsType>(dstType);

if (srcBits && midBits && dstBits) {
unsigned wA = srcBits.getWidth();
unsigned wB = midBits.getWidth();
unsigned wC = dstBits.getWidth();

// Safe if the second cast doesn't widen.
if (wB >= wC) return true;

// Second cast widens (wB < wC). Safe only if the first cast doesn't
// truncate and preserves signedness, so the composed extension matches
// the direct A -> C extension.
return wA <= wB && srcBits.isSigned() == midBits.isSigned();
}

// For non-BitsType chains, be conservative and don't fold.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please make this more complete by also handling booleans? They should be fairly easy to add and then you can also remove this comment.

return false;
}

LogicalResult P4HIR::CastOp::canonicalize(P4HIR::CastOp op, PatternRewriter &rewriter) {
// Composition.
// %b = cast(%a) : A -> B
// cast(%b) : B -> C
// ===> cast(%a) : A -> C
if (auto inputCast = mlir::dyn_cast_if_present<CastOp>(op.getSrc().getDefiningOp())) {
mlir::Type srcType = inputCast.getSrc().getType();
mlir::Type midType = inputCast.getType();
mlir::Type dstType = op.getType();

if (!isSafeCastComposition(srcType, midType, dstType)) return failure();

auto bitcast =
rewriter.createOrFold<P4HIR::CastOp>(op.getLoc(), op.getType(), inputCast.getSrc());
rewriter.replaceOp(op, bitcast);
Expand Down
74 changes: 74 additions & 0 deletions test/Transforms/Folds/cast.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
!i8i = !p4hir.int<8>
!i16i = !p4hir.int<16>
!b8i = !p4hir.bit<8>
!b16i = !p4hir.bit<16>
!b32i = !p4hir.bit<32>
!infint = !p4hir.infint
#false = #p4hir.bool<false> : !p4hir.bool
#true = #p4hir.bool<true> : !p4hir.bool
Expand Down Expand Up @@ -66,4 +68,76 @@ module {
%cast6 = p4hir.cast(%c-128_i8i : !i8i) : !i16i
p4hir.call @blackhole_i16i(%cast6) : (!i16i) -> ()
// CHECK: p4hir.call @blackhole_i16i (%[[cm128_i16i]]) : (!i16i) -> ()

// ---- Cast chain composition safety tests ----

p4hir.func @blackhole_b8i(!b8i)
p4hir.func @blackhole_b16i(!b16i)
p4hir.func @blackhole_b32i(!b32i)

// Use non-constant arguments so chains are not constant-folded away.
p4hir.func @cast_chain_tests(%arg_b8 : !b8i, %arg_b16 : !b16i, %arg_i8 : !i8i, %arg_b32 : !b32i) {
// Safe: widen then reinterpret (w_B >= w_C).
// bit<8> -> bit<16> -> int<16> folds to bit<8> -> int<16>.
// CHECK-LABEL: @cast_chain_tests
// CHECK: %[[V0:.*]] = p4hir.cast(%arg0 : !b8i) : !i16i
// CHECK: p4hir.call @blackhole_i16i (%[[V0]])
%c0 = p4hir.cast(%arg_b8 : !b8i) : !b16i
%c1 = p4hir.cast(%c0 : !b16i) : !i16i
p4hir.call @blackhole_i16i(%c1) : (!i16i) -> ()

// Unsafe: reinterpret then widen (sign change before widen).
// bit<8> -> int<8> -> int<16> must NOT fold: the chain sign-extends,
// but a direct bit<8> -> int<16> would zero-extend.
// CHECK: %[[V1:.*]] = p4hir.cast(%arg0 : !b8i) : !i8i
// CHECK: %[[V2:.*]] = p4hir.cast(%[[V1]] : !i8i) : !i16i
// CHECK: p4hir.call @blackhole_i16i (%[[V2]])
%c2 = p4hir.cast(%arg_b8 : !b8i) : !i8i
%c3 = p4hir.cast(%c2 : !i8i) : !i16i
p4hir.call @blackhole_i16i(%c3) : (!i16i) -> ()

// Safe: same-sign widen then widen (w_A <= w_B, s_A == s_B).
// bit<8> -> bit<16> -> bit<32> folds to bit<8> -> bit<32>.
// CHECK: %[[V3:.*]] = p4hir.cast(%arg0 : !b8i) : !b32i
// CHECK: p4hir.call @blackhole_b32i (%[[V3]])
%c4 = p4hir.cast(%arg_b8 : !b8i) : !b16i
%c5 = p4hir.cast(%c4 : !b16i) : !b32i
p4hir.call @blackhole_b32i(%c5) : (!b32i) -> ()

// Unsafe: truncate then widen (w_A > w_B, w_B < w_C).
// bit<16> -> bit<8> -> bit<32> must NOT fold: lossy truncation
// followed by widening differs from the direct widen.
// CHECK: %[[V4:.*]] = p4hir.cast(%arg1 : !b16i) : !b8i
// CHECK: %[[V5:.*]] = p4hir.cast(%[[V4]] : !b8i) : !b32i
// CHECK: p4hir.call @blackhole_b32i (%[[V5]])
%c6 = p4hir.cast(%arg_b16 : !b16i) : !b8i
%c7 = p4hir.cast(%c6 : !b8i) : !b32i
p4hir.call @blackhole_b32i(%c7) : (!b32i) -> ()

// Safe: widen then truncate (w_B >= w_C).
// bit<8> -> bit<32> -> bit<16> folds to bit<8> -> bit<16>.
// CHECK: %[[V6:.*]] = p4hir.cast(%arg0 : !b8i) : !b16i
// CHECK: p4hir.call @blackhole_b16i (%[[V6]])
%c8 = p4hir.cast(%arg_b8 : !b8i) : !b32i
%c9 = p4hir.cast(%c8 : !b32i) : !b16i
p4hir.call @blackhole_b16i(%c9) : (!b16i) -> ()

// Safe: truncate then reinterpret (w_B >= w_C).
// bit<16> -> bit<8> -> int<8> folds to bit<16> -> int<8>.
// CHECK: %[[V7:.*]] = p4hir.cast(%arg1 : !b16i) : !i8i
// CHECK: p4hir.call @blackhole_i8i (%[[V7]])
%c10 = p4hir.cast(%arg_b16 : !b16i) : !b8i
%c11 = p4hir.cast(%c10 : !b8i) : !i8i
p4hir.call @blackhole_i8i(%c11) : (!i8i) -> ()

// Safe: truncate then truncate (w_B >= w_C).
// bit<32> -> bit<16> -> bit<8> folds to bit<32> -> bit<8>.
// CHECK: %[[V8:.*]] = p4hir.cast(%arg3 : !b32i) : !b8i
// CHECK: p4hir.call @blackhole_b8i (%[[V8]])
%c12 = p4hir.cast(%arg_b32 : !b32i) : !b16i
%c13 = p4hir.cast(%c12 : !b16i) : !b8i
p4hir.call @blackhole_b8i(%c13) : (!b8i) -> ()

p4hir.return
}
}
Loading