diff --git a/include/circt/Dialect/Sim/SimTransforms.h b/include/circt/Dialect/Sim/SimTransforms.h new file mode 100644 index 000000000000..df6ff2e04dcc --- /dev/null +++ b/include/circt/Dialect/Sim/SimTransforms.h @@ -0,0 +1,44 @@ +//===- SimTransforms.h - Sim transform helpers -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares reusable transformation helpers for the Sim dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_SIM_SIMTRANSFORMS_H +#define CIRCT_DIALECT_SIM_SIMTRANSFORMS_H + +#include "mlir/IR/Builders.h" +#include "llvm/ADT/ArrayRef.h" + +namespace circt { +namespace sim { + +struct PrintProceduralizationRequest { + mlir::Location loc; + mlir::Value input; + mlir::Value condition; + mlir::Value stream; + + /// Operation used for diagnostics, if any. + mlir::Operation *anchorOp = nullptr; + + /// Operation to erase once the request has been proceduralized, if any. + mlir::Operation *cleanupRoot = nullptr; +}; + +/// Lower a list of same-clock print requests into a shared `hw.triggered` +/// region containing `sim.proc.print` operations. +mlir::LogicalResult proceduralizePrintsForClock( + mlir::OpBuilder &builder, mlir::Value clock, + llvm::ArrayRef printRequests); + +} // namespace sim +} // namespace circt + +#endif // CIRCT_DIALECT_SIM_SIMTRANSFORMS_H diff --git a/lib/Dialect/Sim/Transforms/ProceduralizeSim.cpp b/lib/Dialect/Sim/Transforms/ProceduralizeSim.cpp index b108d6720245..9076d03f6139 100644 --- a/lib/Dialect/Sim/Transforms/ProceduralizeSim.cpp +++ b/lib/Dialect/Sim/Transforms/ProceduralizeSim.cpp @@ -14,11 +14,13 @@ #include "circt/Dialect/HW/HWOps.h" #include "circt/Dialect/Seq/SeqOps.h" #include "circt/Dialect/Sim/SimOps.h" +#include "circt/Dialect/Sim/SimTransforms.h" #include "circt/Dialect/Sim/SimTypes.h" #include "circt/Support/Debug.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Pass/Pass.h" -#include "llvm/ADT/IndexedMap.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" @@ -37,8 +39,16 @@ using namespace circt; using namespace sim; namespace { +static InFlightDiagnostic +emitProceduralizationError(const PrintProceduralizationRequest &request, + const Twine &message) { + if (auto *anchor = request.anchorOp) + return anchor->emitError(message); + return mlir::emitError(request.loc, message); +} + static LogicalResult collectFormatStringFragments( - Value formatString, PrintFormattedOp anchorOp, + Value formatString, const PrintProceduralizationRequest &request, SmallVectorImpl &fragmentList, SmallSetVector &allFStringFragments, SmallSetVector &arguments) { @@ -46,7 +56,8 @@ static LogicalResult collectFormatStringFragments( if (auto concatInput = formatString.getDefiningOp()) { auto isAcyclic = concatInput.getFlattenedInputs(flatString); if (failed(isAcyclic)) { - anchorOp.emitError("Cyclic format string cannot be proceduralized."); + emitProceduralizationError(request, "Cyclic format string cannot be " + "proceduralized."); return failure(); } } else { @@ -57,8 +68,9 @@ static LogicalResult collectFormatStringFragments( for (auto fragment : flatString) { auto *fmtOp = fragment.getDefiningOp(); if (!fmtOp) { - anchorOp.emitError("Proceduralization of format strings passed as block " - "argument is unsupported."); + emitProceduralizationError( + request, "Proceduralization of format strings passed as block " + "argument is unsupported."); return failure(); } fragmentList.push_back(fmtOp); @@ -111,87 +123,140 @@ static Block *getOrCreateConditionBlock(OpBuilder &builder, return prevConditionBlock; } +// Prune the DAGs of formatting fragments left outside of the newly created +// TriggeredOps. +static void cleanupDeadPrintArtifacts(ArrayRef cleanupList) { + SmallVector worklist(cleanupList.begin(), cleanupList.end()); + SmallVector deferred; + SmallDenseSet erasedOps; + + bool noChange = true; + while (!worklist.empty() || !deferred.empty()) { + if (worklist.empty()) { + if (noChange) + break; + worklist = std::move(deferred); + deferred.clear(); + noChange = true; + } + + auto *opToErase = worklist.pop_back_val(); + if (!opToErase || erasedOps.contains(opToErase)) + continue; + + if (opToErase->getUses().empty()) { + if (auto concat = dyn_cast(opToErase)) + for (auto operand : concat.getInputs()) + deferred.push_back(operand.getDefiningOp()); + opToErase->erase(); + erasedOps.insert(opToErase); + noChange = false; + } else { + deferred.push_back(opToErase); + } + } +} + struct ProceduralizeSimPass : impl::ProceduralizeSimBase { public: void runOnOperation() override; private: - LogicalResult proceduralizePrintOps(Value clock, - ArrayRef printOps); - void cleanup(); - - // Mapping Clock -> List of printf ops - SmallMapVector, 2> printfOpMap; - - // List of formatting ops to be pruned after proceduralization. - SmallVector cleanupList; + // Mapping Clock -> List of print requests. + SmallMapVector, 2> + printRequestMap; }; } // namespace -LogicalResult ProceduralizeSimPass::proceduralizePrintOps( - Value clock, ArrayRef printOps) { - +LogicalResult circt::sim::proceduralizePrintsForClock( + OpBuilder &builder, Value clock, + ArrayRef printRequests) { // List of uniqued values to become arguments of the TriggeredOp. SmallSetVector arguments; - // Map print ops -> flattened list of format-string fragments. - SmallDenseMap, 4> printFragmentMap; + // Map print requests -> flattened list of format-string fragments. + SmallDenseMap, + 4> + printFragmentMap; // Map get_file ops -> flattened list of filename format-string fragments. SmallDenseMap, 4> fileNameFragmentMap; // All non-concat format-string fragment ops needed in the triggered body. SmallSetVector allFStringFragments; // Keep get_file ops in first-use order. SmallSetVector getFileOps; - SmallVector locs; SmallDenseSet alwaysEnabledConditions; - SmallVector livePrintOps; + SmallVector livePrintRequests; + SmallVector locs; + SmallVector cleanupList; + SmallVector sourceOpsToErase; - locs.reserve(printOps.size()); - for (auto printOp : printOps) { - if (auto cstCond = printOp.getCondition().getDefiningOp()) { + locs.reserve(printRequests.size()); + for (const auto &request : printRequests) { + if (auto cstCond = request.condition.getDefiningOp()) { if (cstCond.getValue().isZero()) { - printOp.erase(); + if (auto *inputDef = request.input.getDefiningOp()) + cleanupList.push_back(inputDef); + if (auto stream = request.stream) { + if (auto getFileOp = stream.getDefiningOp()) { + cleanupList.push_back(getFileOp); + cleanupList.push_back(getFileOp.getFileName().getDefiningOp()); + } else if (auto *streamDef = stream.getDefiningOp()) { + cleanupList.push_back(streamDef); + } + } + if (request.cleanupRoot) + sourceOpsToErase.push_back(request.cleanupRoot); continue; } - if (cstCond.getValue().isAllOnes()) - alwaysEnabledConditions.insert(printOp.getCondition()); + if (cstCond.getValue().isAllOnes()) { + alwaysEnabledConditions.insert(request.condition); + } else { + arguments.insert(request.condition); + } } else { - arguments.insert(printOp.getCondition()); + arguments.insert(request.condition); } - livePrintOps.push_back(printOp); - locs.push_back(printOp.getLoc()); + livePrintRequests.push_back(&request); + locs.push_back(request.loc); - auto &printFragments = printFragmentMap[printOp]; - if (failed(::collectFormatStringFragments(printOp.getInput(), printOp, - printFragments, - allFStringFragments, arguments))) + auto &printFragments = printFragmentMap[&request]; + if (failed(collectFormatStringFragments(request.input, request, + printFragments, allFStringFragments, + arguments))) return failure(); - if (auto stream = printOp.getStream()) { - auto getFileOp = stream.getDefiningOp(); - if (!getFileOp) { - if (!stream.getDefiningOp()) - printOp.emitError("proceduralization requires stream to be produced " - "by sim.get_file, block arguments are unsupported"); - else - printOp.emitError("proceduralization requires stream to be produced " - "by sim.get_file"); + if (auto stream = request.stream) { + if (auto getFileOp = stream.getDefiningOp()) { + getFileOps.insert(getFileOp); + auto &fileNameFragments = fileNameFragmentMap[getFileOp]; + if (fileNameFragments.empty() && + failed(collectFormatStringFragments( + getFileOp.getFileName(), request, fileNameFragments, + allFStringFragments, arguments))) + return failure(); + } else { + if (!stream.getDefiningOp()) { + emitProceduralizationError( + request, "proceduralization requires stream to be produced by " + "sim.get_file, block arguments are unsupported"); + } else { + emitProceduralizationError(request, + "proceduralization requires stream to be " + "produced by sim.get_file"); + } return failure(); } - getFileOps.insert(getFileOp); - auto &fileNameFragments = fileNameFragmentMap[getFileOp]; - if (fileNameFragments.empty() && - failed(::collectFormatStringFragments( - getFileOp.getFileName(), printOp, fileNameFragments, - allFStringFragments, arguments))) - return failure(); } } - if (livePrintOps.empty()) + if (livePrintRequests.empty()) { + for (auto *op : sourceOpsToErase) + if (op) + op->erase(); + cleanupDeadPrintArtifacts(cleanupList); return success(); + } - OpBuilder builder(livePrintOps.back()); auto fusedLoc = builder.getFusedLoc(locs); SmallVector argVec = arguments.takeVector(); @@ -207,13 +272,6 @@ LogicalResult ProceduralizeSimPass::proceduralizePrintOps( mapping.map(arg, trigOp.getBodyBlock()->getArgument(idx)); builder.setInsertionPointToStart(trigOp.getBodyBlock()); - if (!alwaysEnabledConditions.empty()) { - auto cstTrue = builder.createOrFold( - fusedLoc, IntegerAttr::get(builder.getI1Type(), 1)); - for (auto cstCond : alwaysEnabledConditions) - mapping.map(cstCond, cstTrue); - } - for (auto *fragment : allFStringFragments) { auto original = fragment->getResult(0); if (mapping.lookupOrNull(original)) @@ -224,7 +282,7 @@ LogicalResult ProceduralizeSimPass::proceduralizePrintOps( for (auto getFileOp : getFileOps) { auto &fileNameFragments = fileNameFragmentMap[getFileOp]; - Value clonedFileName = ::rematerializeFormatStringFromFragments( + Value clonedFileName = rematerializeFormatStringFromFragments( fileNameFragments, builder, mapping, getFileOp.getLoc()); auto clonedGetFile = @@ -238,95 +296,77 @@ LogicalResult ProceduralizeSimPass::proceduralizePrintOps( // Materialize print inputs before creating any conditional blocks. // Whether to actually construct strings eagerly/lazily is left to lowering // backends. - SmallDenseMap procPrintInputMap; + SmallDenseMap + procPrintInputMap; // Insert after rematerialized fragments/get_file ops so operands dominate. builder.setInsertionPointToEnd(trigOp.getBodyBlock()); - for (auto printOp : livePrintOps) { - auto &printFragments = printFragmentMap[printOp]; - procPrintInputMap[printOp] = ::rematerializeFormatStringFromFragments( - printFragments, builder, mapping, printOp.getLoc()); + for (auto *request : livePrintRequests) { + auto &printFragments = printFragmentMap[request]; + procPrintInputMap[request] = rematerializeFormatStringFromFragments( + printFragments, builder, mapping, request->loc); } Value prevConditionValue; Block *prevConditionBlock = nullptr; - for (auto printOp : livePrintOps) { - auto condArg = mapping.lookup(printOp.getCondition()); - auto *condBlock = - ::getOrCreateConditionBlock(builder, trigOp, printOp.getLoc(), condArg, + for (auto *request : livePrintRequests) { + if (alwaysEnabledConditions.contains(request->condition)) { + prevConditionValue = Value(); + prevConditionBlock = nullptr; + builder.setInsertionPointToEnd(trigOp.getBodyBlock()); + } else { + auto condArg = mapping.lookup(request->condition); + auto *condBlock = + getOrCreateConditionBlock(builder, trigOp, request->loc, condArg, prevConditionValue, prevConditionBlock); - - builder.setInsertionPoint(condBlock->getTerminator()); - Value procPrintInput = procPrintInputMap[printOp]; + builder.setInsertionPoint(condBlock->getTerminator()); + } Value procPrintStream; - if (auto stream = printOp.getStream()) { + if (auto stream = request->stream) { procPrintStream = mapping.lookupOrNull(stream); if (!procPrintStream) { - printOp.emitError("proceduralization failed to rematerialize stream"); + emitProceduralizationError(*request, + "proceduralization failed to rematerialize " + "stream"); return failure(); } + if (auto *streamDef = stream.getDefiningOp()) + cleanupList.push_back(streamDef); } - PrintFormattedProcOp::create(builder, printOp.getLoc(), procPrintInput, - procPrintStream); - cleanupList.push_back(printOp.getInput().getDefiningOp()); - printOp.erase(); + PrintFormattedProcOp::create(builder, request->loc, + procPrintInputMap[request], procPrintStream); + if (auto *inputDef = request->input.getDefiningOp()) + cleanupList.push_back(inputDef); + if (request->cleanupRoot) + sourceOpsToErase.push_back(request->cleanupRoot); } - return success(); -} - -// Prune the DAGs of formatting fragments left outside of the newly created -// TriggeredOps. -void ProceduralizeSimPass::cleanup() { - SmallVector cleanupNextList; - SmallDenseSet erasedOps; - - bool noChange = true; - while (!cleanupList.empty() || !cleanupNextList.empty()) { - - if (cleanupList.empty()) { - if (noChange) - break; - cleanupList = std::move(cleanupNextList); - cleanupNextList = {}; - noChange = true; - } - auto *opToErase = cleanupList.pop_back_val(); - if (erasedOps.contains(opToErase)) - continue; - - if (opToErase->getUses().empty()) { - // Remove a dead op. If it is a concat remove its operands, too. - if (auto concat = dyn_cast(opToErase)) - for (auto operand : concat.getInputs()) - cleanupNextList.push_back(operand.getDefiningOp()); - opToErase->erase(); - erasedOps.insert(opToErase); - noChange = false; - } else { - // Op still has uses, revisit later. - cleanupNextList.push_back(opToErase); - } - } + for (auto *op : sourceOpsToErase) + if (op) + op->erase(); + cleanupDeadPrintArtifacts(cleanupList); + return success(); } void ProceduralizeSimPass::runOnOperation() { LLVM_DEBUG(debugPassHeader(this) << "\n"); - printfOpMap.clear(); - cleanupList.clear(); + printRequestMap.clear(); auto theModule = getOperation(); - // Collect printf operations grouped by their clock. - theModule.walk( - [&](PrintFormattedOp op) { printfOpMap[op.getClock()].push_back(op); }); - - // Create a hw::TriggeredOp for each clock - for (auto &[clock, printOps] : printfOpMap) - if (failed(proceduralizePrintOps(clock, printOps))) { + // Collect print operations grouped by their clock, preserving IR order. + theModule.walk([&](PrintFormattedOp op) { + printRequestMap[op.getClock()].push_back({op.getLoc(), op.getInput(), + op.getCondition(), op.getStream(), + op, op}); + }); + + // Create a hw::TriggeredOp for each clock. + for (auto &[clock, requests] : printRequestMap) { + OpBuilder builder(requests.back().anchorOp); + if (failed(proceduralizePrintsForClock(builder, clock, requests))) { signalPassFailure(); return; } - - cleanup(); + } } diff --git a/test/Dialect/Sim/proceduralize-sim.mlir b/test/Dialect/Sim/proceduralize-sim.mlir index 49caba248244..d168e133c656 100644 --- a/test/Dialect/Sim/proceduralize-sim.mlir +++ b/test/Dialect/Sim/proceduralize-sim.mlir @@ -363,3 +363,46 @@ hw.module @shared_fmt_between_print_and_get_file( sim.print %msg on %clk if %cond to %file } + +// CHECK-LABEL: @mixed_conditional_and_unconditional +// CHECK-NEXT: %[[TRG:.*]] = seq.from_clock %clk +// CHECK-NEXT: hw.triggered posedge %[[TRG]](%a, %b) : i1, i1 { +// CHECK-NEXT: ^bb0(%[[ARGA:.*]]: i1, %[[ARGB:.*]]: i1): +// CHECK-DAG: %[[L0:.*]] = sim.fmt.literal "cond-a-0" +// CHECK-DAG: %[[L1:.*]] = sim.fmt.literal "always-1" +// CHECK-DAG: %[[L2:.*]] = sim.fmt.literal "cond-a-2" +// CHECK-DAG: %[[L3:.*]] = sim.fmt.literal "cond-b-3" +// CHECK-DAG: %[[L4:.*]] = sim.fmt.literal "always-4" +// CHECK-DAG: %[[L5:.*]] = sim.fmt.literal "cond-b-5" +// CHECK: scf.if %[[ARGA]] { +// CHECK-NEXT: sim.proc.print %[[L0]] +// CHECK-NEXT: } +// CHECK-NEXT: sim.proc.print %[[L1]] +// CHECK-NEXT: scf.if %[[ARGA]] { +// CHECK-NEXT: sim.proc.print %[[L2]] +// CHECK-NEXT: } +// CHECK-NEXT: scf.if %[[ARGB]] { +// CHECK-NEXT: sim.proc.print %[[L3]] +// CHECK-NEXT: } +// CHECK-NEXT: sim.proc.print %[[L4]] +// CHECK-NEXT: scf.if %[[ARGB]] { +// CHECK-NEXT: sim.proc.print %[[L5]] +// CHECK-NEXT: } +// CHECK-NEXT: } +hw.module @mixed_conditional_and_unconditional( + in %clk: !seq.clock, in %a: i1, in %b: i1) { + %true = hw.constant true + + %l0 = sim.fmt.literal "cond-a-0" + sim.print %l0 on %clk if %a + %l1 = sim.fmt.literal "always-1" + sim.print %l1 on %clk if %true + %l2 = sim.fmt.literal "cond-a-2" + sim.print %l2 on %clk if %a + %l3 = sim.fmt.literal "cond-b-3" + sim.print %l3 on %clk if %b + %l4 = sim.fmt.literal "always-4" + sim.print %l4 on %clk if %true + %l5 = sim.fmt.literal "cond-b-5" + sim.print %l5 on %clk if %b +}