Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,37 @@ def Mux4PadSel : Pat<
$a, $b, $c, $d)),
[(IntTypeWidthLTX<2> $cond)]>;


// mux2cell(cond, 0, 1) -> not(cond)
def Mux2Not : Pat<
(Mux2CellIntrinsicOp:$old $cond, (ConstantOp:$zcst $_),
(ConstantOp:$ocst $_)),
(MoveNameHint $old, (NotPrimOp $cond)), [
(EqualTypes $cond, $zcst), (EqualTypes $cond, $ocst),
(ZeroConstantOp $zcst), (OneConstantOp $ocst)
]>;

// mux4cell(cond, 0, 1, 0, 1) -> not(cond)
def Mux4Not : Pat<
(Mux4CellIntrinsicOp:$old $cond, (ConstantOp:$cst0 $_),
(ConstantOp:$cst1 $_), (ConstantOp:$cst0b $_), (ConstantOp:$cst1b $_)),
(MoveNameHint $old, (NotPrimOp $cond)), [
(EqualTypes $cond, $cst0), (EqualTypes $cond, $cst1),
(ZeroConstantOp $cst0), (OneConstantOp $cst1),
(ZeroConstantOp $cst0b), (OneConstantOp $cst1b)
]>;


// mux4cell(sel, 1, 0, 1, 0) -> not(sel)
def Mux4SelNot : Pat<
(Mux4CellIntrinsicOp:$old $cond, (ConstantOp:$cst0 $_),
(ConstantOp:$cst1 $_), (ConstantOp:$cst0b $_), (ConstantOp:$cst1b $_)),
(MoveNameHint $old, (NotPrimOp $cond)), [
(EqualTypes $cond, $cst0), (EqualTypes $cond, $cst1),
(OneConstantOp $cst0), (ZeroConstantOp $cst1),
(OneConstantOp $cst0b), (ZeroConstantOp $cst1b)
]>;

def CatDoubleConst : Pat <
(CatPrimOp:$old (variadic $cst1, (CatPrimOp (variadic $cst2, $v)))),
(MoveNameHint $old, (CatPrimOp (variadic (CatPrimOp (variadic $cst1, (AsUIntPrimOp $cst2))), (AsUIntPrimOp $v)))),
Expand Down
59 changes: 56 additions & 3 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1871,7 +1871,60 @@ OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return foldMux(*this, adaptor);
}

OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
// mux4 : UInt<0> -> 0
if (getType().getBitWidthOrSentinel() == 0)
return getIntAttr(getType(), APInt(0, 0, getType().isSignedInteger()));

// mux4(sel, x, x, x, x) -> x
auto operands = {getV0(), getV1(), getV2(), getV3()};
if (llvm::all_of(operands, [&](auto v) { return v == getV0(); }) &&
getV0().getType() == getType())
return getV0();

// The following folds require that the result has a known width.
if (getType().getBitWidthOrSentinel() < 0)
return {};

// mux4 with constant selector
if (auto cond = getConstant(adaptor.getSel())) {
switch (cond->getZExtValue()) {
case 0:
return getV0().getType() == getType() ? getV0() : Value{};
case 1:
return getV1().getType() == getType() ? getV1() : Value{};
case 2:
return getV2().getType() == getType() ? getV2() : Value{};
case 3:
return getV3().getType() == getType() ? getV3() : Value{};
default:
return {};
}
}

// mux4 with all constant operands
auto foldedVals = {adaptor.getV0(), adaptor.getV1(), adaptor.getV2(),
adaptor.getV3()};
SmallVector<APInt> constants;
for (auto v : foldedVals) {
if (auto c = getConstant(v)) {
constants.push_back(*c);
} else {
return {};
}
}

// All operands must have the same bit width
if (!llvm::all_of(constants,
[&](auto &c) { return c.getBitWidth() == constants[0].getBitWidth(); }))
return {};

// mux4(sel, c, c, c, c) -> c
if (llvm::all_of(constants, [&](auto &c) { return c == constants[0]; }))
return getIntAttr(getType(), constants[0]);

return {};
}

namespace {

Expand Down Expand Up @@ -2014,12 +2067,12 @@ void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,

void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<patterns::Mux2PadSel>(context);
results.add<patterns::Mux2PadSel, patterns::Mux2Not>(context);
}

void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<patterns::Mux4PadSel>(context);
results.add<patterns::Mux4PadSel, patterns::Mux4Not, patterns::Mux4SelNot>(context);
}

OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
Expand Down
13 changes: 13 additions & 0 deletions lib/Dialect/FIRRTL/Transforms/InferDomains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1295,6 +1295,19 @@ LogicalResult ModuleState::processOp(Operation *op) {
processDomainDefinition(createAnon);
return success();
}
// DomainSubfieldOp has a DomainType input which is not processed by
// unifyAssociations. Check here that the domain is driven before access.
if (auto subfield = dyn_cast<DomainSubfieldOp>(op)) {
auto inputDomain = cast<DomainValue>(subfield.getInput());
auto *term = getOptTermForDomain(inputDomain);
// If the domain has no term, it was never driven.
if (!term) {
return subfield.emitOpError()
<< "accesses an undriven domain; domain wires must be "
"driven before their subfields can be accessed";
}
return success();
}

return unifyAssociations(op);
}
Expand Down
99 changes: 99 additions & 0 deletions test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4168,3 +4168,102 @@ firrtl.module @BoolBinaryFold(in %b: !firrtl.bool,
}

}

// CHECK-LABEL: firrtl.circuit "Mux2CellMux4CellCanon"
firrtl.circuit "Mux2CellMux4CellCanon" {

// CHECK-LABEL: firrtl.module @Mux2CellMux4CellCanon
firrtl.module @Mux2CellMux4CellCanon(
in %sel: !firrtl.uint<1>,
in %sel2: !firrtl.uint<2>,
in %a: !firrtl.uint<4>,
in %b: !firrtl.uint<4>,
in %c: !firrtl.uint<4>,
in %d: !firrtl.uint<4>,
out %out1: !firrtl.uint<4>,
out %out2: !firrtl.uint<4>,
out %out3: !firrtl.uint<4>,
out %out4: !firrtl.uint<1>,
out %out5: !firrtl.uint<4>,
out %out6: !firrtl.uint<4>,
out %out7: !firrtl.uint<4>,
out %out8: !firrtl.uint<4>,
out %out9: !firrtl.uint<4>,
out %out10: !firrtl.uint<4>,
out %out11: !firrtl.uint<1>,
out %out12: !firrtl.uint<1>
) {
// mux2cell(sel, a, a) -> a (fold)
// CHECK: firrtl.matchingconnect %out1, %a
%1 = firrtl.int.mux2cell (%sel, %a, %a) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.matchingconnect %out1, %1 : !firrtl.uint<4>

// mux2cell(0, a, b) -> b (fold)
// CHECK: firrtl.matchingconnect %out2, %b
%c0_ui1 = firrtl.constant 0 : !firrtl.uint<1>
%2 = firrtl.int.mux2cell (%c0_ui1, %a, %b) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.matchingconnect %out2, %2 : !firrtl.uint<4>

// mux2cell(1, a, b) -> a (fold)
// CHECK: firrtl.matchingconnect %out3, %a
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%3 = firrtl.int.mux2cell (%c1_ui1, %a, %b) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.matchingconnect %out3, %3 : !firrtl.uint<4>

// mux2cell(sel, 0, 1) -> not(sel) (canonicalization - all types uint<1>)
%c0_ui1_2 = firrtl.constant 0 : !firrtl.uint<1>
%c1_ui1_2 = firrtl.constant 1 : !firrtl.uint<1>
%4 = firrtl.int.mux2cell (%sel, %c0_ui1_2, %c1_ui1_2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: [[NOT:%.+]] = firrtl.not %sel
// CHECK: firrtl.matchingconnect %out4, [[NOT]]
firrtl.matchingconnect %out4, %4 : !firrtl.uint<1>

// mux2cell with all constant operands -> mux2cell
%c10_ui4 = firrtl.constant 10 : !firrtl.uint<4>
%c12_ui4 = firrtl.constant 12 : !firrtl.uint<4>
%5 = firrtl.int.mux2cell (%sel, %c10_ui4, %c12_ui4) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
// CHECK: [[MUX:%.+]] = firrtl.int.mux2cell
// CHECK: firrtl.matchingconnect %out5, [[MUX]]
firrtl.matchingconnect %out5, %5 : !firrtl.uint<4>

// mux4cell(sel, a, a, a, a) -> a (fold)
// CHECK: firrtl.matchingconnect %out6, %a
%6 = firrtl.int.mux4cell (%sel, %a, %a, %a, %a) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.matchingconnect %out6, %6 : !firrtl.uint<4>

// mux4cell(0, d, c, b, a) -> a (fold)
// CHECK: firrtl.matchingconnect %out7, %a
%7 = firrtl.int.mux4cell (%c0_ui1, %d, %c, %b, %a) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.matchingconnect %out7, %7 : !firrtl.uint<4>

// mux4cell(1, d, c, b, a) -> b (fold)
// CHECK: firrtl.matchingconnect %out8, %b
%8 = firrtl.int.mux4cell (%c1_ui1, %d, %c, %b, %a) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.matchingconnect %out8, %8 : !firrtl.uint<4>

// mux4cell(2, d, c, b, a) -> c (fold)
%c2_ui2 = firrtl.constant 2 : !firrtl.uint<2>
// CHECK: firrtl.matchingconnect %out9, %c
%9 = firrtl.int.mux4cell (%c2_ui2, %d, %c, %b, %a) : (!firrtl.uint<2>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.matchingconnect %out9, %9 : !firrtl.uint<4>

// mux4cell(3, d, c, b, a) -> d (fold)
%c3_ui2 = firrtl.constant 3 : !firrtl.uint<2>
// CHECK: firrtl.matchingconnect %out10, %d
%10 = firrtl.int.mux4cell (%c3_ui2, %d, %c, %b, %a) : (!firrtl.uint<2>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
firrtl.matchingconnect %out10, %10 : !firrtl.uint<4>

// mux4cell(sel, 0, 1, 0, 1) -> not(sel) (canonicalization - all types uint<1>)
%11 = firrtl.int.mux4cell (%sel, %c0_ui1_2, %c1_ui1_2, %c0_ui1_2, %c1_ui1_2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: [[NOT2:%.+]] = firrtl.not %sel
// CHECK: firrtl.matchingconnect %out11, [[NOT2]]
firrtl.matchingconnect %out11, %11 : !firrtl.uint<1>

// mux4cell(sel, 1, 0, 1, 0) -> not(sel) (all types uint<1>)
%12 = firrtl.int.mux4cell (%sel, %c1_ui1_2, %c0_ui1_2, %c1_ui1_2, %c0_ui1_2) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: [[NOT:%.+]] = firrtl.not %sel
// CHECK: firrtl.matchingconnect %out12, [[NOT]]
firrtl.matchingconnect %out12, %12 : !firrtl.uint<1>
}

}
13 changes: 13 additions & 0 deletions test/Dialect/FIRRTL/infer-domains-infer-errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,16 @@ firrtl.circuit "IllegalDomainCrossing" {
firrtl.matchingconnect %b, %a : !firrtl.uint<1>
}
}

// Test that accessing a subfield of an undriven domain wire produces an error.
// This catches the issue early in InferDomains instead of in LowerClasses.

// CHECK-LABEL: UndrivenDomainWireSubfield
firrtl.circuit "UndrivenDomainWireSubfield" {
firrtl.domain @ClockDomain [#firrtl.domain.field<"source", !firrtl.string>]
firrtl.module @UndrivenDomainWireSubfield() {
%wire = firrtl.wire : !firrtl.domain<@ClockDomain(source: !firrtl.string)>
// expected-error @below {{accesses an undriven domain; domain wires must be driven before their subfields can be accessed}}
%1 = firrtl.domain.subfield %wire[source] : !firrtl.domain<@ClockDomain(source: !firrtl.string)>
}
}
46 changes: 46 additions & 0 deletions test/Dialect/FIRRTL/muxcell-canonicalization.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Test for mux2cell/mux4cell canonicalization (Issue #5448)
// RUN: circt-opt -canonicalize %s | FileCheck %s

// These tests verify that:
// 1. WITHOUT the feature: mux2cell/mux4cell operations remain unsimplified
// 2. WITH the feature: mux2cell/mux4cell are simplified to not(sel) or sel

firrtl.circuit "MuxCellCanonicalization" {
// CHECK-LABEL: firrtl.module @MuxCellCanonicalization
firrtl.module @MuxCellCanonicalization(
in %sel: !firrtl.uint<1>,
out %out_mux2_not: !firrtl.uint<1>,
out %out_mux4_not: !firrtl.uint<1>,
out %out_mux4_identity: !firrtl.uint<1>
) {
%c0 = firrtl.constant 0 : !firrtl.uint<1>
%c1 = firrtl.constant 1 : !firrtl.uint<1>

// Test: mux2cell(sel, 0, 1) -> not(sel)
// This simplifies to a NOT operation when selector is 1-bit.
// Without the canonicalization pattern, this would remain as mux2cell.
// CHECK: [[NOT1:%.+]] = firrtl.not %sel
// CHECK: firrtl.matchingconnect %out_mux2_not, [[NOT1]]
%mux2 = firrtl.int.mux2cell(%sel, %c0, %c1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.matchingconnect %out_mux2_not, %mux2 : !firrtl.uint<1>

// Test: mux4cell(sel, 0, 1, 0, 1) -> not(sel)
// This simplifies to a NOT operation. The selector is padded to uint<2>.
// Pattern matches only when selector and operand types match (both uint<1>).
// Without the canonicalization pattern, this would remain as mux4cell with pad.
// CHECK: [[NOT2:%.+]] = firrtl.not %sel
// CHECK: firrtl.matchingconnect %out_mux4_not, [[NOT2]]
%mux4 = firrtl.int.mux4cell(%sel, %c0, %c1, %c0, %c1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.matchingconnect %out_mux4_not, %mux4 : !firrtl.uint<1>

// Test: mux4cell(sel, 1, 0, 1, 0) -> not(sel)
// This simplifies to just the selector. Operands are (1, 0, 1, 0) which maps to:
// sel=0 -> v0=1, sel=1 -> v1=0, sel=2 -> v2=1, sel=3 -> v3=0
// So for 1-bit selector: sel=0 returns 1, sel=1 returns 0, which is not(sel).
// Without the fold, this would remain as mux4cell.
// CHECK: [[NOT3:%.+]] = firrtl.not %sel
// CHECK: firrtl.matchingconnect %out_mux4_identity, [[NOT3]]
%mux4id = firrtl.int.mux4cell(%sel, %c1, %c0, %c1, %c0) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.matchingconnect %out_mux4_identity, %mux4id : !firrtl.uint<1>
}
}
Loading