diff --git a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td index 0fa59674d918..6300ac68ac97 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td @@ -645,6 +645,37 @@ def Mux4PadSel : Pat< $a, $b, $c, $d)), [(IntTypeWidthLTX<2> $cond)]>; + +// mux2cell(cond, 0, 1) -> not(cond) +def Mux2Not : Pat< + (Mux2CellIntrinsicOp:$old $cond, (ConstantOp:$zcst $_), + (ConstantOp:$ocst $_)), + (MoveNameHint $old, (NotPrimOp $cond)), [ + (EqualTypes $cond, $zcst), (EqualTypes $cond, $ocst), + (ZeroConstantOp $zcst), (OneConstantOp $ocst) + ]>; + +// mux4cell(cond, 0, 1, 0, 1) -> not(cond) +def Mux4Not : Pat< + (Mux4CellIntrinsicOp:$old $cond, (ConstantOp:$cst0 $_), + (ConstantOp:$cst1 $_), (ConstantOp:$cst0b $_), (ConstantOp:$cst1b $_)), + (MoveNameHint $old, (NotPrimOp $cond)), [ + (EqualTypes $cond, $cst0), (EqualTypes $cond, $cst1), + (ZeroConstantOp $cst0), (OneConstantOp $cst1), + (ZeroConstantOp $cst0b), (OneConstantOp $cst1b) + ]>; + + +// mux4cell(sel, 1, 0, 1, 0) -> not(sel) +def Mux4SelNot : Pat< + (Mux4CellIntrinsicOp:$old $cond, (ConstantOp:$cst0 $_), + (ConstantOp:$cst1 $_), (ConstantOp:$cst0b $_), (ConstantOp:$cst1b $_)), + (MoveNameHint $old, (NotPrimOp $cond)), [ + (EqualTypes $cond, $cst0), (EqualTypes $cond, $cst1), + (OneConstantOp $cst0), (ZeroConstantOp $cst1), + (OneConstantOp $cst0b), (ZeroConstantOp $cst1b) + ]>; + def CatDoubleConst : Pat < (CatPrimOp:$old (variadic $cst1, (CatPrimOp (variadic $cst2, $v)))), (MoveNameHint $old, (CatPrimOp (variadic (CatPrimOp (variadic $cst1, (AsUIntPrimOp $cst2))), (AsUIntPrimOp $v)))), diff --git a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp index fa769d910a1d..705d22aacda2 100644 --- a/lib/Dialect/FIRRTL/FIRRTLFolds.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLFolds.cpp @@ -1871,7 +1871,60 @@ OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) { return foldMux(*this, adaptor); } -OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; } +OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { + // mux4 : UInt<0> -> 0 + if (getType().getBitWidthOrSentinel() == 0) + return getIntAttr(getType(), APInt(0, 0, getType().isSignedInteger())); + + // mux4(sel, x, x, x, x) -> x + auto operands = {getV0(), getV1(), getV2(), getV3()}; + if (llvm::all_of(operands, [&](auto v) { return v == getV0(); }) && + getV0().getType() == getType()) + return getV0(); + + // The following folds require that the result has a known width. + if (getType().getBitWidthOrSentinel() < 0) + return {}; + + // mux4 with constant selector + if (auto cond = getConstant(adaptor.getSel())) { + switch (cond->getZExtValue()) { + case 0: + return getV0().getType() == getType() ? getV0() : Value{}; + case 1: + return getV1().getType() == getType() ? getV1() : Value{}; + case 2: + return getV2().getType() == getType() ? getV2() : Value{}; + case 3: + return getV3().getType() == getType() ? getV3() : Value{}; + default: + return {}; + } + } + + // mux4 with all constant operands + auto foldedVals = {adaptor.getV0(), adaptor.getV1(), adaptor.getV2(), + adaptor.getV3()}; + SmallVector constants; + for (auto v : foldedVals) { + if (auto c = getConstant(v)) { + constants.push_back(*c); + } else { + return {}; + } + } + + // All operands must have the same bit width + if (!llvm::all_of(constants, + [&](auto &c) { return c.getBitWidth() == constants[0].getBitWidth(); })) + return {}; + + // mux4(sel, c, c, c, c) -> c + if (llvm::all_of(constants, [&](auto &c) { return c == constants[0]; })) + return getIntAttr(getType(), constants[0]); + + return {}; +} namespace { @@ -2014,12 +2067,12 @@ void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results, void Mux2CellIntrinsicOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } void Mux4CellIntrinsicOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) { diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index ed368e5cb7ce..4017537caf59 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -1295,6 +1295,19 @@ LogicalResult ModuleState::processOp(Operation *op) { processDomainDefinition(createAnon); return success(); } + // DomainSubfieldOp has a DomainType input which is not processed by + // unifyAssociations. Check here that the domain is driven before access. + if (auto subfield = dyn_cast(op)) { + auto inputDomain = cast(subfield.getInput()); + auto *term = getOptTermForDomain(inputDomain); + // If the domain has no term, it was never driven. + if (!term) { + return subfield.emitOpError() + << "accesses an undriven domain; domain wires must be " + "driven before their subfields can be accessed"; + } + return success(); + } return unifyAssociations(op); } diff --git a/test/Dialect/FIRRTL/canonicalization.mlir b/test/Dialect/FIRRTL/canonicalization.mlir index af67120afc3e..bf28d5d399a6 100644 --- a/test/Dialect/FIRRTL/canonicalization.mlir +++ b/test/Dialect/FIRRTL/canonicalization.mlir @@ -4168,3 +4168,102 @@ firrtl.module @BoolBinaryFold(in %b: !firrtl.bool, } } + +// CHECK-LABEL: firrtl.circuit "Mux2CellMux4CellCanon" +firrtl.circuit "Mux2CellMux4CellCanon" { + +// CHECK-LABEL: firrtl.module @Mux2CellMux4CellCanon +firrtl.module @Mux2CellMux4CellCanon( + in %sel: !firrtl.uint<1>, + in %sel2: !firrtl.uint<2>, + in %a: !firrtl.uint<4>, + in %b: !firrtl.uint<4>, + in %c: !firrtl.uint<4>, + in %d: !firrtl.uint<4>, + out %out1: !firrtl.uint<4>, + out %out2: !firrtl.uint<4>, + out %out3: !firrtl.uint<4>, + out %out4: !firrtl.uint<1>, + out %out5: !firrtl.uint<4>, + out %out6: !firrtl.uint<4>, + out %out7: !firrtl.uint<4>, + out %out8: !firrtl.uint<4>, + out %out9: !firrtl.uint<4>, + out %out10: !firrtl.uint<4>, + out %out11: !firrtl.uint<1>, + out %out12: !firrtl.uint<1> +) { + // mux2cell(sel, a, a) -> a (fold) + // CHECK: firrtl.matchingconnect %out1, %a + %1 = firrtl.int.mux2cell (%sel, %a, %a) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> + firrtl.matchingconnect %out1, %1 : !firrtl.uint<4> + + // mux2cell(0, a, b) -> b (fold) + // CHECK: firrtl.matchingconnect %out2, %b + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %2 = firrtl.int.mux2cell (%c0_ui1, %a, %b) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> + firrtl.matchingconnect %out2, %2 : !firrtl.uint<4> + + // mux2cell(1, a, b) -> a (fold) + // CHECK: firrtl.matchingconnect %out3, %a + %c1_ui1 = firrtl.constant 1 : !firrtl.uint<1> + %3 = firrtl.int.mux2cell (%c1_ui1, %a, %b) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> + firrtl.matchingconnect %out3, %3 : !firrtl.uint<4> + + // mux2cell(sel, 0, 1) -> not(sel) (canonicalization - all types uint<1>) + %c0_ui1_2 = firrtl.constant 0 : !firrtl.uint<1> + %c1_ui1_2 = firrtl.constant 1 : !firrtl.uint<1> + %4 = firrtl.int.mux2cell (%sel, %c0_ui1_2, %c1_ui1_2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + // CHECK: [[NOT:%.+]] = firrtl.not %sel + // CHECK: firrtl.matchingconnect %out4, [[NOT]] + firrtl.matchingconnect %out4, %4 : !firrtl.uint<1> + + // mux2cell with all constant operands -> mux2cell + %c10_ui4 = firrtl.constant 10 : !firrtl.uint<4> + %c12_ui4 = firrtl.constant 12 : !firrtl.uint<4> + %5 = firrtl.int.mux2cell (%sel, %c10_ui4, %c12_ui4) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> + // CHECK: [[MUX:%.+]] = firrtl.int.mux2cell + // CHECK: firrtl.matchingconnect %out5, [[MUX]] + firrtl.matchingconnect %out5, %5 : !firrtl.uint<4> + + // mux4cell(sel, a, a, a, a) -> a (fold) + // CHECK: firrtl.matchingconnect %out6, %a + %6 = firrtl.int.mux4cell (%sel, %a, %a, %a, %a) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> + firrtl.matchingconnect %out6, %6 : !firrtl.uint<4> + + // mux4cell(0, d, c, b, a) -> a (fold) + // CHECK: firrtl.matchingconnect %out7, %a + %7 = firrtl.int.mux4cell (%c0_ui1, %d, %c, %b, %a) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> + firrtl.matchingconnect %out7, %7 : !firrtl.uint<4> + + // mux4cell(1, d, c, b, a) -> b (fold) + // CHECK: firrtl.matchingconnect %out8, %b + %8 = firrtl.int.mux4cell (%c1_ui1, %d, %c, %b, %a) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> + firrtl.matchingconnect %out8, %8 : !firrtl.uint<4> + + // mux4cell(2, d, c, b, a) -> c (fold) + %c2_ui2 = firrtl.constant 2 : !firrtl.uint<2> + // CHECK: firrtl.matchingconnect %out9, %c + %9 = firrtl.int.mux4cell (%c2_ui2, %d, %c, %b, %a) : (!firrtl.uint<2>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> + firrtl.matchingconnect %out9, %9 : !firrtl.uint<4> + + // mux4cell(3, d, c, b, a) -> d (fold) + %c3_ui2 = firrtl.constant 3 : !firrtl.uint<2> + // CHECK: firrtl.matchingconnect %out10, %d + %10 = firrtl.int.mux4cell (%c3_ui2, %d, %c, %b, %a) : (!firrtl.uint<2>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4> + firrtl.matchingconnect %out10, %10 : !firrtl.uint<4> + + // mux4cell(sel, 0, 1, 0, 1) -> not(sel) (canonicalization - all types uint<1>) + %11 = firrtl.int.mux4cell (%sel, %c0_ui1_2, %c1_ui1_2, %c0_ui1_2, %c1_ui1_2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + // CHECK: [[NOT2:%.+]] = firrtl.not %sel + // CHECK: firrtl.matchingconnect %out11, [[NOT2]] + firrtl.matchingconnect %out11, %11 : !firrtl.uint<1> + + // mux4cell(sel, 1, 0, 1, 0) -> not(sel) (all types uint<1>) + %12 = firrtl.int.mux4cell (%sel, %c1_ui1_2, %c0_ui1_2, %c1_ui1_2, %c0_ui1_2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + // CHECK: [[NOT:%.+]] = firrtl.not %sel + // CHECK: firrtl.matchingconnect %out12, [[NOT]] + firrtl.matchingconnect %out12, %12 : !firrtl.uint<1> +} + +} diff --git a/test/Dialect/FIRRTL/infer-domains-infer-errors.mlir b/test/Dialect/FIRRTL/infer-domains-infer-errors.mlir index 96923d92b2f3..21851d8e7db3 100644 --- a/test/Dialect/FIRRTL/infer-domains-infer-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-infer-errors.mlir @@ -52,3 +52,16 @@ firrtl.circuit "IllegalDomainCrossing" { firrtl.matchingconnect %b, %a : !firrtl.uint<1> } } + +// Test that accessing a subfield of an undriven domain wire produces an error. +// This catches the issue early in InferDomains instead of in LowerClasses. + +// CHECK-LABEL: UndrivenDomainWireSubfield +firrtl.circuit "UndrivenDomainWireSubfield" { + firrtl.domain @ClockDomain [#firrtl.domain.field<"source", !firrtl.string>] + firrtl.module @UndrivenDomainWireSubfield() { + %wire = firrtl.wire : !firrtl.domain<@ClockDomain(source: !firrtl.string)> + // expected-error @below {{accesses an undriven domain; domain wires must be driven before their subfields can be accessed}} + %1 = firrtl.domain.subfield %wire[source] : !firrtl.domain<@ClockDomain(source: !firrtl.string)> + } +} diff --git a/test/Dialect/FIRRTL/muxcell-canonicalization.mlir b/test/Dialect/FIRRTL/muxcell-canonicalization.mlir new file mode 100644 index 000000000000..7734415a214b --- /dev/null +++ b/test/Dialect/FIRRTL/muxcell-canonicalization.mlir @@ -0,0 +1,46 @@ +// Test for mux2cell/mux4cell canonicalization (Issue #5448) +// RUN: circt-opt -canonicalize %s | FileCheck %s + +// These tests verify that: +// 1. WITHOUT the feature: mux2cell/mux4cell operations remain unsimplified +// 2. WITH the feature: mux2cell/mux4cell are simplified to not(sel) or sel + +firrtl.circuit "MuxCellCanonicalization" { + // CHECK-LABEL: firrtl.module @MuxCellCanonicalization + firrtl.module @MuxCellCanonicalization( + in %sel: !firrtl.uint<1>, + out %out_mux2_not: !firrtl.uint<1>, + out %out_mux4_not: !firrtl.uint<1>, + out %out_mux4_identity: !firrtl.uint<1> + ) { + %c0 = firrtl.constant 0 : !firrtl.uint<1> + %c1 = firrtl.constant 1 : !firrtl.uint<1> + + // Test: mux2cell(sel, 0, 1) -> not(sel) + // This simplifies to a NOT operation when selector is 1-bit. + // Without the canonicalization pattern, this would remain as mux2cell. + // CHECK: [[NOT1:%.+]] = firrtl.not %sel + // CHECK: firrtl.matchingconnect %out_mux2_not, [[NOT1]] + %mux2 = firrtl.int.mux2cell(%sel, %c0, %c1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.matchingconnect %out_mux2_not, %mux2 : !firrtl.uint<1> + + // Test: mux4cell(sel, 0, 1, 0, 1) -> not(sel) + // This simplifies to a NOT operation. The selector is padded to uint<2>. + // Pattern matches only when selector and operand types match (both uint<1>). + // Without the canonicalization pattern, this would remain as mux4cell with pad. + // CHECK: [[NOT2:%.+]] = firrtl.not %sel + // CHECK: firrtl.matchingconnect %out_mux4_not, [[NOT2]] + %mux4 = firrtl.int.mux4cell(%sel, %c0, %c1, %c0, %c1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.matchingconnect %out_mux4_not, %mux4 : !firrtl.uint<1> + + // Test: mux4cell(sel, 1, 0, 1, 0) -> not(sel) + // This simplifies to just the selector. Operands are (1, 0, 1, 0) which maps to: + // sel=0 -> v0=1, sel=1 -> v1=0, sel=2 -> v2=1, sel=3 -> v3=0 + // So for 1-bit selector: sel=0 returns 1, sel=1 returns 0, which is not(sel). + // Without the fold, this would remain as mux4cell. + // CHECK: [[NOT3:%.+]] = firrtl.not %sel + // CHECK: firrtl.matchingconnect %out_mux4_identity, [[NOT3]] + %mux4id = firrtl.int.mux4cell(%sel, %c1, %c0, %c1, %c0) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + firrtl.matchingconnect %out_mux4_identity, %mux4id : !firrtl.uint<1> + } +}