Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/circt/Dialect/Synth/SynthOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef CIRCT_DIALECT_SYNTH_SYNTHOPS_H
#define CIRCT_DIALECT_SYNTH_SYNTHOPS_H

#include "circt/Dialect/Synth/SynthAttributes.h"
#include "circt/Dialect/Synth/SynthDialect.h"
#include "circt/Dialect/Synth/SynthOpInterfaces.h"
#include "circt/Support/LLVM.h"
Expand Down
23 changes: 23 additions & 0 deletions include/circt/Dialect/Synth/SynthOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,29 @@ def GambleOp : SymmetricThreeInputOp<"gamble", "evaluateGambleLogic"> {
}];
}

def CutRewritePatternOp : SynthOp<"cut_rewrite_pattern", [
IsolatedFromAbove,
SingleBlockImplicitTerminator<"YieldOp">
]> {
let summary = "Declarative cut rewrite pattern";

let arguments = (ins
TypeAttrOf<FunctionType>:$function_type,
MappingCostAttr:$cost
);

let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}

def YieldOp : SynthOp<"yield",
[Pure, Terminator]> {
let summary = "Yield synth operations";

let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = "$operands attr-dict `:` type($operands)";
}


#endif // CIRCT_DIALECT_SYNTH_SYNTHOPS_TD
86 changes: 86 additions & 0 deletions lib/Dialect/Synth/SynthOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
Expand Down Expand Up @@ -626,3 +629,86 @@ void GambleOp::emitCNFWithoutInversion(
// out = allSet | ~orSet
circt::addOrClauses(outVar, {allSet, -orSet}, addClause);
}

//===----------------------------------------------------------------------===//
// CutRewritePatternOp
//===----------------------------------------------------------------------===//

ParseResult CutRewritePatternOp::parse(OpAsmParser &parser,
OperationState &result) {

SmallVector<OpAsmParser::Argument> entryArgs;
SmallVector<Type> resultTypes;
SmallVector<DictionaryAttr> resultAttrs;
bool isVariadic = false;

if (function_interface_impl::parseFunctionSignatureWithArguments(
parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
resultAttrs))
return failure();

auto inputTypes = llvm::map_to_vector(
entryArgs, [](auto &arg) -> Type { return arg.type; });
auto functionType =
parser.getBuilder().getFunctionType(inputTypes, resultTypes);

result.addAttribute(getFunctionTypeAttrName(result.name),
TypeAttr::get(functionType));
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();

return parser.parseRegion(*result.addRegion(), entryArgs,
/*enableNameShadowing=*/false);
}

void CutRewritePatternOp::print(OpAsmPrinter &p) {
auto functionType = getFunctionType();
call_interface_impl::printFunctionSignature(
p, functionType.getInputs(), /*argAttrs=*/{}, /*isVariadic=*/false,
functionType.getResults(), /*resultAttrs=*/{}, &getBody(),
/*printEmptyResult=*/false);

p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
{getFunctionTypeAttrName()});

p << ' ';
p.printRegion(getBody(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
}

LogicalResult CutRewritePatternOp::verify() {
auto functionType = getFunctionType();

if (functionType.getNumResults() != 1)
return emitError() << "requires exactly one result";

for (auto type : functionType.getInputs())
if (!type.isInteger(1))
return emitError() << "argument type must be i1, but got " << type;

for (auto type : functionType.getResults())
if (!type.isInteger(1))
return emitError() << "result type must be i1, but got " << type;

// Check outputs.
auto *terminator = this->getBody().front().getTerminator();
if (terminator->getOperands().size() != functionType.getNumResults())
return emitError() << "result type doesn't match with the terminator";

for (auto [lhs, rhs] : llvm::zip(terminator->getOperands().getTypes(),
functionType.getResults()))
if (rhs != lhs)
return emitError() << rhs << " is expected but got " << lhs;

auto blockArgs = this->getBody().front().getArguments();
if (blockArgs.size() != functionType.getNumInputs())
return emitError() << "operand type doesn't match with the block arg";

for (auto [blockArg, inputType] :
llvm::zip(blockArgs, functionType.getInputs()))
if (blockArg.getType() != inputType)
return emitError() << inputType << " is expected but got "
<< blockArg.getType();

return success();
}
38 changes: 38 additions & 0 deletions test/Dialect/Synth/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,41 @@ hw.module @test(out result : i1) {
%0 = synth.choice : i1
hw.output %0 : i1
}

// -----

// expected-error @below {{argument type must be i1, but got 'i2'}}
synth.cut_rewrite_pattern (%a: i2) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
%0 = comb.extract %a from 0 : (i2) -> i1
synth.yield %0 : i1
}

// -----

// expected-error @below {{result type must be i1, but got 'i2'}}
synth.cut_rewrite_pattern (%a: i1) -> i2 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
%0 = hw.constant 0 : i2
synth.yield %0 : i2
}

// -----

// expected-error @below {{requires exactly one result}}
synth.cut_rewrite_pattern (%a: i1) -> (i1, i1) attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
synth.yield %a, %a : i1, i1
}

// -----

// expected-error @below {{result type doesn't match with the terminator}}
synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
"synth.yield"() : () -> ()
}

// -----

// expected-error @below {{'i1' is expected but got 'i2'}}
synth.cut_rewrite_pattern (%a: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
%0 = hw.constant 0 : i2
synth.yield %0 : i2
}
14 changes: 14 additions & 0 deletions test/Dialect/Synth/round-trip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,17 @@ hw.module @mux_inv(in %c: i4, in %a: i4, in %b: i4) {
hw.module @gamble(in %x: i1, in %y: i1, in %z: i1) {
%0 = synth.gamble %x, not %y, %z : i1
}

// CHECK-LABEL: synth.cut_rewrite_pattern
// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1, %{{.*}}: i1) -> i1
synth.cut_rewrite_pattern (%a: i1, %b: i1, %c: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also improve changing mapping cost's attribute representation:

  1. arcs currently store input/output names but this op doesn't have input/output names, so I think we should drop names in the attributes.
  2. This is MLIR textual representation for mapping_cost but input_caps should be omitted when empty.

Copy link
Copy Markdown
Contributor Author

@okekayode okekayode May 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, do we prefer to add a custom print/parser for MappingCostAttr to omit empty input_caps, and alter synth.linear_timing_arc to accept nameless/pos args for this op?

I do think also switching to OptionalAttr on the MappingCostAttr args may be cleaner, but this would require more verifications of null/absences in TechMapper.cpp.

WDYT?

%0 = synth.aig.and_inv %a, not %b, %c : i1
synth.yield %0 : i1
}

// CHECK-LABEL: synth.cut_rewrite_pattern
// CHECK-SAME: (%{{.*}}: i1, %{{.*}}: i1) -> i1 attributes {cost = #synth.mapping_cost<area =
synth.cut_rewrite_pattern (%a: i1, %b: i1) -> i1 attributes {cost = #synth.mapping_cost<area = 1.0 : f64, arcs = [], input_caps = {}>} {
%0 = synth.aig.and_inv %a, %b : i1
synth.yield %0 : i1
}
Loading