diff --git a/lib/Dialect/LLHD/IR/LLHDOps.cpp b/lib/Dialect/LLHD/IR/LLHDOps.cpp index d56f5da60fc8..aeba2af218da 100644 --- a/lib/Dialect/LLHD/IR/LLHDOps.cpp +++ b/lib/Dialect/LLHD/IR/LLHDOps.cpp @@ -487,11 +487,84 @@ DrvOp::ensureOnlySafeAccesses(const MemorySlot &slot, //===----------------------------------------------------------------------===// LogicalResult ProcessOp::canonicalize(ProcessOp op, PatternRewriter &rewriter) { - if (op.getBody().hasOneBlock() && op.getNumResults() == 0) { - auto &block = op.getBody().front(); - if (block.getOperations().size() == 1 && isa(block.getTerminator())) - rewriter.eraseOp(op); + if (!op.getBody().hasOneBlock()) + return failure(); + + auto &block = op.getBody().front(); + auto haltOp = dyn_cast(block.getTerminator()); + if (!haltOp) + return failure(); + + if (op.getNumResults() == 0 && block.getOperations().size() == 1) { + rewriter.eraseOp(op); + return success(); + } + + // Only constants and halt terminator are expected in a single block. + if (!llvm::all_of(block.without_terminator(), [](auto &bodyOp) { + return bodyOp.template hasTrait(); + })) + return failure(); + + auto yieldOperands = haltOp.getYieldOperands(); + llvm::SmallDenseMap uniqueOperands; + llvm::SmallDenseMap origToNewPos; + llvm::BitVector operandsToErase(yieldOperands.size()); + + for (auto [operandNo, operand] : llvm::enumerate(yieldOperands)) { + auto *defOp = operand.getDefiningOp(); + if (defOp && defOp->hasTrait()) { + // If the constant is available outside the process, use it directly; + // otherwise move it outside. + if (!defOp->getParentRegion()->isProperAncestor(&op.getBody())) { + defOp->moveBefore(op); + } + rewriter.replaceAllUsesWith(op.getResult(operandNo), operand); + operandsToErase.set(operandNo); + continue; + } + + // Identify duplicate operands to merge and compute updated result + // positions for the process operation. + if (!uniqueOperands.contains(operand)) { + const auto newPos = uniqueOperands.size(); + uniqueOperands.insert(std::make_pair(operand, newPos)); + origToNewPos.insert(std::make_pair(operandNo, newPos)); + } else { + auto firstOccurrencePos = uniqueOperands.lookup(operand); + origToNewPos.insert(std::make_pair(operandNo, firstOccurrencePos)); + operandsToErase.set(operandNo); + } } + + const auto countOperandsToErase = operandsToErase.count(); + if (countOperandsToErase == 0) + return failure(); + + // Remove the process operation if all its results have been replaced with + // constants. + if (countOperandsToErase == op.getNumResults()) { + rewriter.eraseOp(op); + return success(); + } + + haltOp->eraseOperands(operandsToErase); + + SmallVector resultTypes = llvm::to_vector(haltOp->getOperandTypes()); + auto newProcessOp = ProcessOp::create(rewriter, op.getLoc(), resultTypes, + op->getOperands(), op->getAttrs()); + newProcessOp.getBody().takeBody(op.getBody()); + + // Update old results with new values, accounting for pruned halt operands. + for (auto oldResult : op.getResults()) { + auto newResultPos = origToNewPos.find(oldResult.getResultNumber()); + if (newResultPos == origToNewPos.end()) + continue; + auto newResult = newProcessOp.getResult(newResultPos->getSecond()); + rewriter.replaceAllUsesWith(oldResult, newResult); + } + + rewriter.eraseOp(op); return success(); } diff --git a/test/Dialect/LLHD/Canonicalization/processes.mlir b/test/Dialect/LLHD/Canonicalization/processes.mlir index ef1ce2ab61a7..96982bffe352 100644 --- a/test/Dialect/LLHD/Canonicalization/processes.mlir +++ b/test/Dialect/LLHD/Canonicalization/processes.mlir @@ -35,3 +35,60 @@ hw.module @InlineCombinational(in %a: i42, in %b: i42, in %c: i8917, out u: i42, } hw.output %0, %1 : i42, i9001 } + +// CHECK-LABEL: hw.module @IgnoreMultiBlockHalt +hw.module @IgnoreMultiBlockHalt(in %a : i1, in %b : i1, out v1 : i1, out v2 : i1) { + // CHECK: llhd.halt %a, %a + %0:2 = llhd.process -> i1, i1 { + ^bb0: + cf.br ^bb1 + ^bb1: + cf.cond_br %b, ^bb1, ^bb2 + ^bb2: + %true = hw.constant true + llhd.halt %a, %a : i1, i1 + } + hw.output %0#0, %0#1 : i1, i1 +} + +// CHECK-LABEL: hw.module @DeduplicateHaltOperands0 +hw.module @DeduplicateHaltOperands0(in %a : i1, in %b : i1, + out v1 : i1, out v2 : i1, out v3 : i1, out v4 : i1) { + // CHECK: %0:2 = llhd.process -> i1, i1 { + // CHECK-NEXT: llhd.halt %a, %b : i1, i1 + // CHECK-NEXT: } + // CHECK-NEXT: hw.output %0#0, %0#1, %0#0, %0#1 + %false = hw.constant false + %0:6 = llhd.process -> i1, i1, i1, i1, i1, i1 { + %true = hw.constant true + llhd.halt %false, %a, %b, %a, %true, %b : i1, i1, i1, i1, i1, i1 + } + hw.output %0#1, %0#2, %0#3, %0#5 : i1, i1, i1, i1 +} + +// CHECK-LABEL: hw.module @CanonProcessHalt0 +hw.module @CanonProcessHalt0(out v1 : i1, out v2 : i1) { + // CHECK-NOT: llhd.halt + // CHECK: hw.output %false, %true + %false = hw.constant false + %0:2 = llhd.process -> i1, i1 { + %true = hw.constant true + llhd.halt %false, %true : i1, i1 + } + hw.output %0#0, %0#1 : i1, i1 +} + +// CHECK-LABEL: hw.module @CanonProcessHalt1 +hw.module @CanonProcessHalt1(in %a : i1, in %b : i1, + out v1 : i1, out v2 : i1, out v3 : i1, out v4 : i1) { + // CHECK: %0:2 = llhd.process -> i1, i1 { + // CHECK-NEXT: llhd.halt %a, %b : i1, i1 + // CHECK-NEXT: } + // CHECK-NEXT: hw.output %0#1, %false, %0#0, %true + %0:4 = llhd.process -> i1, i1, i1, i1 { + %false = hw.constant false + %true = hw.constant true + llhd.halt %false, %a, %true, %b : i1, i1, i1, i1 + } + hw.output %0#3, %0#0, %0#1, %0#2 : i1, i1, i1, i1 +}