Skip to content

Commit f6c8a4c

Browse files
author
arashed
committed
[P4HIR] Fix unsound CastOp chain folding (#317)
Cast chains A→B→C were unconditionally folded to A→C, but this changes semantics when the first cast alters signedness or truncates before a widening second cast (e.g. bit<8>→int<8>→int<16> sign-extends, but bit<8>→int<16> zero-extends; likewise bit<16>→bit<8>→bit<32> discards the high bits that a direct bit<16>→bit<32> would keep). Only fold when the second cast doesn't widen, or when the first cast preserves signedness without truncating. Fixes #317.
1 parent 8dc43ac commit f6c8a4c

2 files changed

Lines changed: 119 additions & 0 deletions

File tree

lib/Dialect/P4HIR/P4HIR_Ops.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,57 @@ OpFoldResult P4HIR::CastOp::fold(FoldAdaptor) {
250250
return {};
251251
}
252252

253+
/// Returns true if folding a cast chain A -> B -> C into A -> C preserves
254+
/// the cast semantics per the P4 specification (Section 8.11.1).
255+
///
256+
/// The key unsafe case is when the second cast widens (w_B < w_C), because
257+
/// the extension type (zero vs sign) depends on the source's signedness.
258+
/// If the intermediate type B differs from A in a way that changes how
259+
/// the widening is performed, the fold would alter semantics.
260+
///
261+
/// Fold is safe when:
262+
/// - w_B >= w_C: second cast doesn't widen (truncation and reinterpretation
263+
/// are sign-independent), OR
264+
/// - w_A <= w_B AND s_A == s_B: first cast doesn't truncate and preserves
265+
/// signedness, so the widening in the second cast uses the same extension
266+
/// type as the direct A -> C cast would.
267+
static bool isSafeCastComposition(mlir::Type srcType, mlir::Type midType, mlir::Type dstType) {
268+
if (srcType == midType || midType == dstType) return true;
269+
270+
auto srcBits = mlir::dyn_cast<P4HIR::BitsType>(srcType);
271+
auto midBits = mlir::dyn_cast<P4HIR::BitsType>(midType);
272+
auto dstBits = mlir::dyn_cast<P4HIR::BitsType>(dstType);
273+
274+
if (srcBits && midBits && dstBits) {
275+
unsigned wA = srcBits.getWidth();
276+
unsigned wB = midBits.getWidth();
277+
unsigned wC = dstBits.getWidth();
278+
279+
// Safe if the second cast doesn't widen.
280+
if (wB >= wC) return true;
281+
282+
// Second cast widens (wB < wC). Safe only if the first cast doesn't
283+
// truncate and preserves signedness, so the composed extension matches
284+
// the direct A -> C extension.
285+
return wA <= wB && srcBits.isSigned() == midBits.isSigned();
286+
}
287+
288+
// For non-BitsType chains, be conservative and don't fold.
289+
return false;
290+
}
291+
253292
LogicalResult P4HIR::CastOp::canonicalize(P4HIR::CastOp op, PatternRewriter &rewriter) {
254293
// Composition.
255294
// %b = cast(%a) : A -> B
256295
// cast(%b) : B -> C
257296
// ===> cast(%a) : A -> C
258297
if (auto inputCast = mlir::dyn_cast_if_present<CastOp>(op.getSrc().getDefiningOp())) {
298+
mlir::Type srcType = inputCast.getSrc().getType();
299+
mlir::Type midType = inputCast.getType();
300+
mlir::Type dstType = op.getType();
301+
302+
if (!isSafeCastComposition(srcType, midType, dstType)) return failure();
303+
259304
auto bitcast =
260305
rewriter.createOrFold<P4HIR::CastOp>(op.getLoc(), op.getType(), inputCast.getSrc());
261306
rewriter.replaceOp(op, bitcast);

test/Transforms/Folds/cast.mlir

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
!i8i = !p4hir.int<8>
44
!i16i = !p4hir.int<16>
55
!b8i = !p4hir.bit<8>
6+
!b16i = !p4hir.bit<16>
7+
!b32i = !p4hir.bit<32>
68
!infint = !p4hir.infint
79
#false = #p4hir.bool<false> : !p4hir.bool
810
#true = #p4hir.bool<true> : !p4hir.bool
@@ -66,4 +68,76 @@ module {
6668
%cast6 = p4hir.cast(%c-128_i8i : !i8i) : !i16i
6769
p4hir.call @blackhole_i16i(%cast6) : (!i16i) -> ()
6870
// CHECK: p4hir.call @blackhole_i16i (%[[cm128_i16i]]) : (!i16i) -> ()
71+
72+
// ---- Cast chain composition safety tests ----
73+
74+
p4hir.func @blackhole_b8i(!b8i)
75+
p4hir.func @blackhole_b16i(!b16i)
76+
p4hir.func @blackhole_b32i(!b32i)
77+
78+
// Use non-constant arguments so chains are not constant-folded away.
79+
p4hir.func @cast_chain_tests(%arg_b8 : !b8i, %arg_b16 : !b16i, %arg_i8 : !i8i, %arg_b32 : !b32i) {
80+
// Safe: widen then reinterpret (w_B >= w_C).
81+
// bit<8> -> bit<16> -> int<16> folds to bit<8> -> int<16>.
82+
// CHECK-LABEL: @cast_chain_tests
83+
// CHECK: %[[V0:.*]] = p4hir.cast(%arg0 : !b8i) : !i16i
84+
// CHECK: p4hir.call @blackhole_i16i (%[[V0]])
85+
%c0 = p4hir.cast(%arg_b8 : !b8i) : !b16i
86+
%c1 = p4hir.cast(%c0 : !b16i) : !i16i
87+
p4hir.call @blackhole_i16i(%c1) : (!i16i) -> ()
88+
89+
// Unsafe: reinterpret then widen (sign change before widen).
90+
// bit<8> -> int<8> -> int<16> must NOT fold: the chain sign-extends,
91+
// but a direct bit<8> -> int<16> would zero-extend.
92+
// CHECK: %[[V1:.*]] = p4hir.cast(%arg0 : !b8i) : !i8i
93+
// CHECK: %[[V2:.*]] = p4hir.cast(%[[V1]] : !i8i) : !i16i
94+
// CHECK: p4hir.call @blackhole_i16i (%[[V2]])
95+
%c2 = p4hir.cast(%arg_b8 : !b8i) : !i8i
96+
%c3 = p4hir.cast(%c2 : !i8i) : !i16i
97+
p4hir.call @blackhole_i16i(%c3) : (!i16i) -> ()
98+
99+
// Safe: same-sign widen then widen (w_A <= w_B, s_A == s_B).
100+
// bit<8> -> bit<16> -> bit<32> folds to bit<8> -> bit<32>.
101+
// CHECK: %[[V3:.*]] = p4hir.cast(%arg0 : !b8i) : !b32i
102+
// CHECK: p4hir.call @blackhole_b32i (%[[V3]])
103+
%c4 = p4hir.cast(%arg_b8 : !b8i) : !b16i
104+
%c5 = p4hir.cast(%c4 : !b16i) : !b32i
105+
p4hir.call @blackhole_b32i(%c5) : (!b32i) -> ()
106+
107+
// Unsafe: truncate then widen (w_A > w_B, w_B < w_C).
108+
// bit<16> -> bit<8> -> bit<32> must NOT fold: lossy truncation
109+
// followed by widening differs from the direct widen.
110+
// CHECK: %[[V4:.*]] = p4hir.cast(%arg1 : !b16i) : !b8i
111+
// CHECK: %[[V5:.*]] = p4hir.cast(%[[V4]] : !b8i) : !b32i
112+
// CHECK: p4hir.call @blackhole_b32i (%[[V5]])
113+
%c6 = p4hir.cast(%arg_b16 : !b16i) : !b8i
114+
%c7 = p4hir.cast(%c6 : !b8i) : !b32i
115+
p4hir.call @blackhole_b32i(%c7) : (!b32i) -> ()
116+
117+
// Safe: widen then truncate (w_B >= w_C).
118+
// bit<8> -> bit<32> -> bit<16> folds to bit<8> -> bit<16>.
119+
// CHECK: %[[V6:.*]] = p4hir.cast(%arg0 : !b8i) : !b16i
120+
// CHECK: p4hir.call @blackhole_b16i (%[[V6]])
121+
%c8 = p4hir.cast(%arg_b8 : !b8i) : !b32i
122+
%c9 = p4hir.cast(%c8 : !b32i) : !b16i
123+
p4hir.call @blackhole_b16i(%c9) : (!b16i) -> ()
124+
125+
// Safe: truncate then reinterpret (w_B >= w_C).
126+
// bit<16> -> bit<8> -> int<8> folds to bit<16> -> int<8>.
127+
// CHECK: %[[V7:.*]] = p4hir.cast(%arg1 : !b16i) : !i8i
128+
// CHECK: p4hir.call @blackhole_i8i (%[[V7]])
129+
%c10 = p4hir.cast(%arg_b16 : !b16i) : !b8i
130+
%c11 = p4hir.cast(%c10 : !b8i) : !i8i
131+
p4hir.call @blackhole_i8i(%c11) : (!i8i) -> ()
132+
133+
// Safe: truncate then truncate (w_B >= w_C).
134+
// bit<32> -> bit<16> -> bit<8> folds to bit<32> -> bit<8>.
135+
// CHECK: %[[V8:.*]] = p4hir.cast(%arg3 : !b32i) : !b8i
136+
// CHECK: p4hir.call @blackhole_b8i (%[[V8]])
137+
%c12 = p4hir.cast(%arg_b32 : !b32i) : !b16i
138+
%c13 = p4hir.cast(%c12 : !b16i) : !b8i
139+
p4hir.call @blackhole_b8i(%c13) : (!b8i) -> ()
140+
141+
p4hir.return
142+
}
69143
}

0 commit comments

Comments
 (0)