From 8e7d48351229fd2bba9d3ede06ef5080df837dc5 Mon Sep 17 00:00:00 2001 From: okekayode Date: Sat, 30 May 2026 18:27:23 -0700 Subject: [PATCH] [Synth] Add an operation for declarative Cut rewrite pattern --- .../circt/Dialect/Synth/SynthAttributes.td | 10 +- include/circt/Dialect/Synth/SynthOps.h | 1 + include/circt/Dialect/Synth/SynthOps.td | 23 +++++ lib/Dialect/Synth/CMakeLists.txt | 1 + lib/Dialect/Synth/SynthAttributes.cpp | 39 ++++++++ lib/Dialect/Synth/SynthOps.cpp | 98 +++++++++++++++++++ lib/Dialect/Synth/Transforms/TechMapper.cpp | 17 ++-- test/Dialect/Synth/errors.mlir | 38 +++++++ test/Dialect/Synth/round-trip.mlir | 14 +++ 9 files changed, 229 insertions(+), 12 deletions(-) create mode 100644 lib/Dialect/Synth/SynthAttributes.cpp diff --git a/include/circt/Dialect/Synth/SynthAttributes.td b/include/circt/Dialect/Synth/SynthAttributes.td index 9844a3e554f2..e6d1093a71de 100644 --- a/include/circt/Dialect/Synth/SynthAttributes.td +++ b/include/circt/Dialect/Synth/SynthAttributes.td @@ -76,12 +76,14 @@ def MappingCostAttr : AttrDef { let summary = "Simplified timing and area cost for tech mapping"; let parameters = (ins "::mlir::FloatAttr":$area, - "::mlir::ArrayAttr":$arcs, - "::mlir::DictionaryAttr":$inputCaps + OptionalParameter<"::mlir::ArrayAttr">:$arcs, + OptionalParameter<"::mlir::DictionaryAttr">:$inputCaps ); + let genVerifyDecl = 1; let assemblyFormat = - "`<` `area` `=` $area `,` `arcs` `=` $arcs `,` " - "`input_caps` `=` $inputCaps `>`"; + "`<` `area` `=` $area " + "(`,` `arcs` `=` $arcs^)? " + "(`,` `input_caps` `=` $inputCaps^)? `>`"; } #endif // CIRCT_DIALECT_SYNTH_SYNTHATTRIBUTES_TD diff --git a/include/circt/Dialect/Synth/SynthOps.h b/include/circt/Dialect/Synth/SynthOps.h index aae1b0bee7f1..51f7b24c8134 100644 --- a/include/circt/Dialect/Synth/SynthOps.h +++ b/include/circt/Dialect/Synth/SynthOps.h @@ -13,6 +13,7 @@ #ifndef CIRCT_DIALECT_SYNTH_SYNTHOPS_H #define CIRCT_DIALECT_SYNTH_SYNTHOPS_H +#include "circt/Dialect/Synth/SynthAttributes.h" #include "circt/Dialect/Synth/SynthDialect.h" #include "circt/Dialect/Synth/SynthOpInterfaces.h" #include "circt/Support/LLVM.h" diff --git a/include/circt/Dialect/Synth/SynthOps.td b/include/circt/Dialect/Synth/SynthOps.td index b6b610770af9..0f79b562821a 100644 --- a/include/circt/Dialect/Synth/SynthOps.td +++ b/include/circt/Dialect/Synth/SynthOps.td @@ -305,6 +305,29 @@ def GambleOp : SymmetricThreeInputOp<"gamble", "evaluateGambleLogic"> { }]; } +def CutRewritePatternOp : SynthOp<"cut_rewrite_pattern", [ + IsolatedFromAbove, + SingleBlockImplicitTerminator<"YieldOp"> +]> { + let summary = "Declarative cut rewrite pattern"; + + let arguments = (ins + TypeAttrOf:$function_type, + MappingCostAttr:$cost + ); + + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def YieldOp : SynthOp<"yield", + [Pure, Terminator]> { + let summary = "Yield synth operations"; + + let arguments = (ins Variadic:$operands); + let assemblyFormat = "$operands attr-dict `:` type($operands)"; +} #endif // CIRCT_DIALECT_SYNTH_SYNTHOPS_TD diff --git a/lib/Dialect/Synth/CMakeLists.txt b/lib/Dialect/Synth/CMakeLists.txt index a4f378823fb4..da3ee219c7fc 100644 --- a/lib/Dialect/Synth/CMakeLists.txt +++ b/lib/Dialect/Synth/CMakeLists.txt @@ -5,6 +5,7 @@ ##===----------------------------------------------------------------------===// add_circt_dialect_library(CIRCTSynth + SynthAttributes.cpp SynthDialect.cpp SynthOpInterfaces.cpp SynthOps.cpp diff --git a/lib/Dialect/Synth/SynthAttributes.cpp b/lib/Dialect/Synth/SynthAttributes.cpp new file mode 100644 index 000000000000..294804b829b2 --- /dev/null +++ b/lib/Dialect/Synth/SynthAttributes.cpp @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Synth/SynthAttributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace circt; +using namespace circt::synth; +using namespace mlir; + +//===----------------------------------------------------------------------===// +// MappingCostAttr +//===----------------------------------------------------------------------===// + +LogicalResult +MappingCostAttr::verify(llvm::function_ref emitError, + FloatAttr area, ArrayAttr arcs, + DictionaryAttr inputCaps) { + if (arcs) + for (auto attr : arcs) + if (!isa(attr)) + return emitError() + << "expected arcs to contain synth.linear_timing_arc"; + + if (inputCaps) + for (auto entry : inputCaps) + if (!isa(entry.getValue())) + return emitError() + << "expected input_caps values to be floating-point attributes"; + + return success(); +} diff --git a/lib/Dialect/Synth/SynthOps.cpp b/lib/Dialect/Synth/SynthOps.cpp index 166018b6342d..2d6bbe0fe714 100644 --- a/lib/Dialect/Synth/SynthOps.cpp +++ b/lib/Dialect/Synth/SynthOps.cpp @@ -19,7 +19,10 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" @@ -626,3 +629,98 @@ void GambleOp::emitCNFWithoutInversion( // out = allSet | ~orSet circt::addOrClauses(outVar, {allSet, -orSet}, addClause); } + +//===----------------------------------------------------------------------===// +// CutRewritePatternOp +//===----------------------------------------------------------------------===// + +ParseResult CutRewritePatternOp::parse(OpAsmParser &parser, + OperationState &result) { + + SmallVector entryArgs; + SmallVector resultTypes; + SmallVector resultAttrs; + bool isVariadic = false; + + if (function_interface_impl::parseFunctionSignatureWithArguments( + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, + resultAttrs)) + return failure(); + + auto inputTypes = llvm::map_to_vector( + entryArgs, [](auto &arg) -> Type { return arg.type; }); + auto functionType = + parser.getBuilder().getFunctionType(inputTypes, resultTypes); + + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(functionType)); + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + return parser.parseRegion(*result.addRegion(), entryArgs, + /*enableNameShadowing=*/false); +} + +void CutRewritePatternOp::print(OpAsmPrinter &p) { + auto functionType = getFunctionType(); + call_interface_impl::printFunctionSignature( + p, functionType.getInputs(), /*argAttrs=*/{}, /*isVariadic=*/false, + functionType.getResults(), /*resultAttrs=*/{}, &getBody(), + /*printEmptyResult=*/false); + + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), + {getFunctionTypeAttrName()}); + + p << ' '; + p.printRegion(getBody(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); +} + +LogicalResult CutRewritePatternOp::verify() { + auto functionType = getFunctionType(); + + if (functionType.getNumResults() != 1) + return emitError() << "requires exactly one result"; + + for (auto type : functionType.getInputs()) + if (!type.isInteger(1)) + return emitError() << "argument type must be i1, but got " << type; + + for (auto type : functionType.getResults()) + if (!type.isInteger(1)) + return emitError() << "result type must be i1, but got " << type; + + // Check outputs. + auto *terminator = this->getBody().front().getTerminator(); + if (terminator->getOperands().size() != functionType.getNumResults()) + return emitError() << "result type doesn't match with the terminator"; + + for (auto [lhs, rhs] : llvm::zip(terminator->getOperands().getTypes(), + functionType.getResults())) + if (rhs != lhs) + return emitError() << rhs << " is expected but got " << lhs; + + auto blockArgs = this->getBody().front().getArguments(); + if (blockArgs.size() != functionType.getNumInputs()) + return emitError() << "operand type doesn't match with the block arg"; + + for (auto [blockArg, inputType] : + llvm::zip(blockArgs, functionType.getInputs())) + if (blockArg.getType() != inputType) + return emitError() << inputType << " is expected but got " + << blockArg.getType(); + + auto cost = getCost(); + if (auto arcs = cost.getArcs()) + if (!arcs.empty()) + return emitError() + << "mapping cost arcs for cut rewrite patterns must not use " + "input/output names"; + + if (auto inputCaps = cost.getInputCaps()) + if (inputCaps.size() != functionType.getNumInputs()) + return emitError() + << "input_caps size must match the number of arguments"; + + return success(); +} diff --git a/lib/Dialect/Synth/Transforms/TechMapper.cpp b/lib/Dialect/Synth/Transforms/TechMapper.cpp index 405c3a4987d6..70dc536b0d5f 100644 --- a/lib/Dialect/Synth/Transforms/TechMapper.cpp +++ b/lib/Dialect/Synth/Transforms/TechMapper.cpp @@ -222,14 +222,15 @@ struct TechMapperPass : public impl::TechMapperBase { } llvm::DenseMap delayByInput; - for (auto attr : mappingCost.getArcs()) { - auto arc = cast(attr); - if (!arc) { - hwModule.emitError( - "expected synth.linear_timing_arc in synth.mapping_cost arcs"); - signalPassFailure(); - return; - } + auto arcs = mappingCost.getArcs(); + if (!arcs) { + hwModule.emitError( + "expected synth.linear_timing_arc in synth.mapping_cost arcs"); + signalPassFailure(); + return; + } + for (auto attr : arcs) { + auto arc = dyn_cast(attr); if (arc.getPin() != outputName) { hwModule.emitError("mapping cost arc output '") diff --git a/test/Dialect/Synth/errors.mlir b/test/Dialect/Synth/errors.mlir index 91c6e764d973..ea170b7046cb 100644 --- a/test/Dialect/Synth/errors.mlir +++ b/test/Dialect/Synth/errors.mlir @@ -5,3 +5,41 @@ hw.module @test(out result : i1) { %0 = synth.choice : i1 hw.output %0 : i1 } + +// ----- + +// expected-error @below {{argument type must be i1, but got 'i2'}} +synth.cut_rewrite_pattern (%a: i2) -> i1 attributes {cost = #synth.mapping_cost} { + %0 = comb.extract %a from 0 : (i2) -> i1 + synth.yield %0 : i1 +} + +// ----- + +// expected-error @below {{result type must be i1, but got 'i2'}} +synth.cut_rewrite_pattern (%a: i1) -> i2 attributes {cost = #synth.mapping_cost} { + %0 = hw.constant 0 : i2 + synth.yield %0 : i2 +} + +// ----- + +// expected-error @below {{requires exactly one result}} +synth.cut_rewrite_pattern (%a: i1) -> (i1, i1) attributes {cost = #synth.mapping_cost} { + synth.yield %a, %a : i1, i1 +} + +// ----- + +// expected-error @below {{result type doesn't match with the terminator}} +synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost} { + "synth.yield"() : () -> () +} + +// ----- + +// expected-error @below {{'i1' is expected but got 'i2'}} +synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost} { + %0 = hw.constant 0 : i2 + synth.yield %0 : i2 +} diff --git a/test/Dialect/Synth/round-trip.mlir b/test/Dialect/Synth/round-trip.mlir index 03d5f72d78cf..3de8b1697d07 100644 --- a/test/Dialect/Synth/round-trip.mlir +++ b/test/Dialect/Synth/round-trip.mlir @@ -54,3 +54,17 @@ hw.module @mux_inv(in %c: i4, in %a: i4, in %b: i4) { hw.module @gamble(in %x: i1, in %y: i1, in %z: i1) { %0 = synth.gamble %x, not %y, %z : i1 } + +// CHECK-LABEL: synth.cut_rewrite_pattern +// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1, %{{.*}}: i1) -> i1 +synth.cut_rewrite_pattern (%a: i1, %b: i1, %c: i1) -> i1 attributes {cost = #synth.mapping_cost} { + %0 = synth.aig.and_inv %a, not %b, %c : i1 + synth.yield %0 : i1 +} + +// CHECK-LABEL: synth.cut_rewrite_pattern +// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1) -> i1 attributes {cost = #synth.mapping_cost i1 attributes {cost = #synth.mapping_cost} { + %0 = synth.aig.and_inv %a, %b : i1 + synth.yield %0 : i1 +}