diff --git a/include/circt/Dialect/Synth/SynthOps.h b/include/circt/Dialect/Synth/SynthOps.h index aae1b0bee7f1..51f7b24c8134 100644 --- a/include/circt/Dialect/Synth/SynthOps.h +++ b/include/circt/Dialect/Synth/SynthOps.h @@ -13,6 +13,7 @@ #ifndef CIRCT_DIALECT_SYNTH_SYNTHOPS_H #define CIRCT_DIALECT_SYNTH_SYNTHOPS_H +#include "circt/Dialect/Synth/SynthAttributes.h" #include "circt/Dialect/Synth/SynthDialect.h" #include "circt/Dialect/Synth/SynthOpInterfaces.h" #include "circt/Support/LLVM.h" diff --git a/include/circt/Dialect/Synth/SynthOps.td b/include/circt/Dialect/Synth/SynthOps.td index b6b610770af9..0f79b562821a 100644 --- a/include/circt/Dialect/Synth/SynthOps.td +++ b/include/circt/Dialect/Synth/SynthOps.td @@ -305,6 +305,29 @@ def GambleOp : SymmetricThreeInputOp<"gamble", "evaluateGambleLogic"> { }]; } +def CutRewritePatternOp : SynthOp<"cut_rewrite_pattern", [ + IsolatedFromAbove, + SingleBlockImplicitTerminator<"YieldOp"> +]> { + let summary = "Declarative cut rewrite pattern"; + + let arguments = (ins + TypeAttrOf:$function_type, + MappingCostAttr:$cost + ); + + let regions = (region SizedRegion<1>:$body); + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def YieldOp : SynthOp<"yield", + [Pure, Terminator]> { + let summary = "Yield synth operations"; + + let arguments = (ins Variadic:$operands); + let assemblyFormat = "$operands attr-dict `:` type($operands)"; +} #endif // CIRCT_DIALECT_SYNTH_SYNTHOPS_TD diff --git a/lib/Dialect/Synth/SynthOps.cpp b/lib/Dialect/Synth/SynthOps.cpp index 166018b6342d..c6279c29a1d4 100644 --- a/lib/Dialect/Synth/SynthOps.cpp +++ b/lib/Dialect/Synth/SynthOps.cpp @@ -19,7 +19,10 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionImplementation.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" @@ -626,3 +629,86 @@ void GambleOp::emitCNFWithoutInversion( // out = allSet | ~orSet circt::addOrClauses(outVar, {allSet, -orSet}, addClause); } + +//===----------------------------------------------------------------------===// +// CutRewritePatternOp +//===----------------------------------------------------------------------===// + +ParseResult CutRewritePatternOp::parse(OpAsmParser &parser, + OperationState &result) { + + SmallVector entryArgs; + SmallVector resultTypes; + SmallVector resultAttrs; + bool isVariadic = false; + + if (function_interface_impl::parseFunctionSignatureWithArguments( + parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, + resultAttrs)) + return failure(); + + auto inputTypes = llvm::map_to_vector( + entryArgs, [](auto &arg) -> Type { return arg.type; }); + auto functionType = + parser.getBuilder().getFunctionType(inputTypes, resultTypes); + + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(functionType)); + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + return parser.parseRegion(*result.addRegion(), entryArgs, + /*enableNameShadowing=*/false); +} + +void CutRewritePatternOp::print(OpAsmPrinter &p) { + auto functionType = getFunctionType(); + call_interface_impl::printFunctionSignature( + p, functionType.getInputs(), /*argAttrs=*/{}, /*isVariadic=*/false, + functionType.getResults(), /*resultAttrs=*/{}, &getBody(), + /*printEmptyResult=*/false); + + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), + {getFunctionTypeAttrName()}); + + p << ' '; + p.printRegion(getBody(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); +} + +LogicalResult CutRewritePatternOp::verify() { + auto functionType = getFunctionType(); + + if (functionType.getNumResults() != 1) + return emitError() << "requires exactly one result"; + + for (auto type : functionType.getInputs()) + if (!type.isInteger(1)) + return emitError() << "argument type must be i1, but got " << type; + + for (auto type : functionType.getResults()) + if (!type.isInteger(1)) + return emitError() << "result type must be i1, but got " << type; + + // Check outputs. + auto *terminator = this->getBody().front().getTerminator(); + if (terminator->getOperands().size() != functionType.getNumResults()) + return emitError() << "result type doesn't match with the terminator"; + + for (auto [lhs, rhs] : llvm::zip(terminator->getOperands().getTypes(), + functionType.getResults())) + if (rhs != lhs) + return emitError() << rhs << " is expected but got " << lhs; + + auto blockArgs = this->getBody().front().getArguments(); + if (blockArgs.size() != functionType.getNumInputs()) + return emitError() << "operand type doesn't match with the block arg"; + + for (auto [blockArg, inputType] : + llvm::zip(blockArgs, functionType.getInputs())) + if (blockArg.getType() != inputType) + return emitError() << inputType << " is expected but got " + << blockArg.getType(); + + return success(); +} diff --git a/test/Dialect/Synth/errors.mlir b/test/Dialect/Synth/errors.mlir index 91c6e764d973..e98c9aaddfe8 100644 --- a/test/Dialect/Synth/errors.mlir +++ b/test/Dialect/Synth/errors.mlir @@ -5,3 +5,41 @@ hw.module @test(out result : i1) { %0 = synth.choice : i1 hw.output %0 : i1 } + +// ----- + +// expected-error @below {{argument type must be i1, but got 'i2'}} +synth.cut_rewrite_pattern (%a: i2) -> i1 attributes {cost = #synth.mapping_cost} { + %0 = comb.extract %a from 0 : (i2) -> i1 + synth.yield %0 : i1 +} + +// ----- + +// expected-error @below {{result type must be i1, but got 'i2'}} +synth.cut_rewrite_pattern (%a: i1) -> i2 attributes {cost = #synth.mapping_cost} { + %0 = hw.constant 0 : i2 + synth.yield %0 : i2 +} + +// ----- + +// expected-error @below {{requires exactly one result}} +synth.cut_rewrite_pattern (%a: i1) -> (i1, i1) attributes {cost = #synth.mapping_cost} { + synth.yield %a, %a : i1, i1 +} + +// ----- + +// expected-error @below {{result type doesn't match with the terminator}} +synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost} { + "synth.yield"() : () -> () +} + +// ----- + +// expected-error @below {{'i1' is expected but got 'i2'}} +synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost} { + %0 = hw.constant 0 : i2 + synth.yield %0 : i2 +} diff --git a/test/Dialect/Synth/round-trip.mlir b/test/Dialect/Synth/round-trip.mlir index 03d5f72d78cf..08ed4de457b9 100644 --- a/test/Dialect/Synth/round-trip.mlir +++ b/test/Dialect/Synth/round-trip.mlir @@ -54,3 +54,17 @@ hw.module @mux_inv(in %c: i4, in %a: i4, in %b: i4) { hw.module @gamble(in %x: i1, in %y: i1, in %z: i1) { %0 = synth.gamble %x, not %y, %z : i1 } + +// CHECK-LABEL: synth.cut_rewrite_pattern +// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1, %{{.*}}: i1) -> i1 +synth.cut_rewrite_pattern (%a: i1, %b: i1, %c: i1) -> i1 attributes {cost = #synth.mapping_cost} { + %0 = synth.aig.and_inv %a, not %b, %c : i1 + synth.yield %0 : i1 +} + +// CHECK-LABEL: synth.cut_rewrite_pattern +// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1) -> i1 attributes {cost = #synth.mapping_cost i1 attributes {cost = #synth.mapping_cost} { + %0 = synth.aig.and_inv %a, %b : i1 + synth.yield %0 : i1 +}