From da4f8a25ab209109cba468d1ee280ab5fcca8d4d Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 26 May 2026 18:26:19 +0000 Subject: [PATCH] [mlir][SCF] Allow early exit from regions via region-breaking terminators --- mlir/docs/LangRef.md | 28 +- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 167 +++++----- mlir/include/mlir/IR/OpBase.td | 2 + mlir/include/mlir/IR/OpDefinition.h | 8 + .../SCFToControlFlow/SCFToControlFlow.cpp | 12 + mlir/lib/Dialect/SCF/IR/SCF.cpp | 286 +++++++++++++----- .../convert-to-cfg-unsupported.mlir | 31 ++ mlir/test/Dialect/SCF/invalid.mlir | 88 ++++-- mlir/test/Dialect/SCF/ops.mlir | 152 ++++++++-- 9 files changed, 573 insertions(+), 201 deletions(-) create mode 100644 mlir/test/Conversion/SCFToControlFlow/convert-to-cfg-unsupported.mlir diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md index 0e6fb006da48b..c0ac46efcaaa3 100644 --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -515,13 +515,17 @@ However, when control flow enters a region, it always begins in the first block of the region, called the *entry* block. Terminator operations ending each block represent control flow by explicitly specifying the successor blocks of the block. Control flow can only pass to one of the specified successor blocks as in -a `branch` operation, or back to the containing operation as in a `return` +a `branch` operation, or back to an enclosing operation as in a `return` operation. Terminator operations without successors can only pass control back -to the containing operation. Within these restrictions, the particular semantics -of terminator operations is determined by the specific dialect operations -involved. Blocks (other than the entry block) that are not listed as a successor -of a terminator operation are defined to be unreachable and can be removed -without affecting the semantics of the containing operation. +to an enclosing operation. 
By default, control returns to the *immediately* +containing operation, but a terminator may also pass control further out by +referring to an outer enclosing operation through a [token](#token-type) operand +("early exit"); see the [Region-Breaking +Terminators](#region-breaking-terminators) section. Within these restrictions, +the particular semantics of terminator operations is determined by the specific +dialect operations involved. Blocks (other than the entry block) that are not +listed as a successor of a terminator operation are defined to be unreachable +and can be removed without affecting the semantics of the containing operation. Although control flow always enters a region through the entry block, control flow may exit a region through any block with an appropriate terminator. The @@ -558,6 +562,18 @@ func.func @accelerator_compute(i64, i1) -> i64 { // An SSACFG region } ``` +#### Region-Breaking Terminators + +A region-breaking terminator is a terminator that passes control back to an +enclosing operation other than its immediately containing one. It identifies +its destination through a [token](#token-type) operand: the token must be an +entry block argument of a region of the enclosing operation that the terminator +transfers control to. + +Any operation on the path from a region-breaking terminator to the operation it +transfers control to (excluding the target operation itself) must carry the +`PropagateControlFlowBreak` trait. 
+ #### Operations with Multiple Regions An operation containing multiple regions also completely determines the diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 57c07fa0a50fc..d2bbad83bfc8d 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -153,8 +153,10 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [ def LoopOp : SCF_Op<"loop", [ AutomaticAllocationScope, + PropagateControlFlowBreak, RecursiveMemoryEffects, - SingleBlock + SingleBlock, + TokenProducerTrait ]> { let summary = "Loop until a break operation"; let description = [{ @@ -165,37 +167,53 @@ def LoopOp : SCF_Op<"loop", [ by `initValues` and updated by each iteration of the loop, and (2) a region which represents the loop body. - The loop body must end with an explicit terminator, which must be one of: - - - `scf.continue`: re-enters the loop, supplying the next iteration's value - for each loop-carried variable. Terminator operand types and loop operand - types must match. If the loop has op results, its values are undefined. - - `scf.break`: terminates the loop, supplying the final values for the - `scf.loop` results. Terminator operand types and loop op result types - must match. - - Note: This operation will be extended in the future to support breaking and - continuing from nested regions. For now, `scf.break` and `scf.continue` - must be terminators of the loop body. In practice this means that an - `scf.loop` either runs forever (terminator is `scf.continue`) or executes - exactly one iteration (terminator is `scf.break`). + The loop body has a single block with a `token` block argument, which + identifies the loop, followed by one block argument per loop-carried value. + The loop body must terminate with an `scf.break` or `scf.continue` op. 
+ + `scf.break` and `scf.continue` targeting this `scf.loop` op may appear as + terminators of the loop body or any block nested inside the loop body, as + long as every op on the path from the terminator up to this `scf.loop` op + carries the `PropagateControlFlowBreak` op trait. + + - `scf.continue` terminator that targets this loop: re-enters this loop, + supplying the next iteration's loop-carried variables. Terminator operand + types and loop-carried variable types of this loop op must match. + - `scf.break` terminator that targets this loop: terminates this loop, + supplying the final values (op results). Terminator operand types and the + op result types of this loop op must match. Examples: ```mlir - // Loop with iteration-carried values updated by `scf.continue`. - scf.loop iter_args(%i = %init) : i32 { + // Loop with iteration-carried values updated by `scf.continue`. This + // is an infinite loop. + scf.loop %t iter_args(%i = %init) : i32 { %v = "some.compute"(%i) : (i32) -> (i32) - scf.continue %v : i32 + scf.continue %t, %v : token, i32 } ``` ```mlir // Loop with both an iteration-carried value and a result. The iter_arg - // and result types may differ. - %r = scf.loop iter_args(%i = %init) : i32 -> i64 { + // and result types may differ. This is a loop with exactly one iteration. + %r = scf.loop %t iter_args(%i = %init) : i32 -> i64 { %v = "some.compute"(%i) : (i32) -> (i64) - scf.break %v : i64 + scf.break %t, %v : token, i64 + } + ``` + + ```mlir + // Early exit driven by a condition: when `%done` is true, the `scf.if`'s + // then-branch breaks out of the enclosing `scf.loop`; otherwise control + // falls through to the trailing `scf.continue`. 
+ %r = scf.loop %t iter_args(%i = %init) : i32 -> i32 { + %next = "some.compute"(%i) : (i32) -> (i32) + %done = "some.predicate"(%next) : (i32) -> (i1) + scf.if %done { + scf.break %t, %next : token, i32 + } + scf.continue %t, %next : token, i32 } ``` }]; @@ -209,25 +227,33 @@ def LoopOp : SCF_Op<"loop", [ CArg<"::mlir::TypeRange", "{}">:$resultTypes, CArg<"::mlir::ValueRange", "{}">:$initValues, CArg<"::llvm::function_ref", "nullptr">:$bodyBuilder)> + "::mlir::Value, ::mlir::ValueRange)>", "nullptr">:$bodyBuilder)> ]; let extraClassDeclaration = [{ - /// Return the iteration values of the loop region. + /// Returns the loop body block. + Block *getBody() { return &getRegion().front(); } + + /// Returns the entry block argument that holds the loop's token. + ::mlir::BlockArgument getRegionToken() { + return getBody()->getArgument(0); + } + + /// Return the iteration values of the loop region (skipping the leading + /// token argument). Block::BlockArgListType getRegionIterValues() { - return getRegion().getArguments(); + return getBody()->getArguments().drop_front(); } /// Return the `index`-th region iteration value. - BlockArgument getRegionIterValue(unsigned index) { + ::mlir::BlockArgument getRegionIterValue(unsigned index) { return getRegionIterValues()[index]; } /// Returns the number of region arguments for loop-carried values. - unsigned getNumRegionIterValues() { return getRegion().getNumArguments(); } - - /// Returns the loop body block. 
- Block *getBody() { return &getRegion().front(); } + unsigned getNumRegionIterValues() { + return getBody()->getNumArguments() - 1; + } }]; let hasCustomAssemblyFormat = 1; @@ -239,30 +265,31 @@ def LoopOp : SCF_Op<"loop", [ //===----------------------------------------------------------------------===// def BreakOp : SCF_Op<"break", [ - Pure, ReturnLike, Terminator, ParentOneOf<["LoopOp"]> + Pure, Terminator, TokenConsumerTrait ]> { let summary = "Break from an `scf.loop`"; let description = [{ - The `scf.break` operation terminates the immediately enclosing `scf.loop`. - Its operands become the loop's result values; their types must match the - result types of the enclosing `scf.loop` (verified by the loop). - - Example: - - ```mlir - %r = scf.loop -> i32 { - ... - scf.break %v : i32 - } - ``` + The `scf.break` operation is a region-breaking terminator that terminates + the `scf.loop` identified by its `token` operand. The operands become + the target loop's result values; their types must match the result types + of the target `scf.loop`. + + `scf.break` may appear as the terminator of any block nested inside the + target `scf.loop`, as long as every operation on the path between this + `scf.break` and the target loop carries the `PropagateControlFlowBreak` + op trait. }]; - let arguments = (ins Variadic:$operands); - let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; + let arguments = (ins Token:$token, Variadic:$values); + let builders = [OpBuilder<(ins "::mlir::Value":$token), [{ + $_state.addOperands(token); + }]>]; let assemblyFormat = [{ - attr-dict ($operands^ `:` type($operands))? + $token (`,` $values^)? attr-dict `:` type($token) (`,` type($values)^)? 
}]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -270,31 +297,31 @@ def BreakOp : SCF_Op<"break", [ //===----------------------------------------------------------------------===// def ContinueOp : SCF_Op<"continue", [ - Pure, Terminator, ParentOneOf<["LoopOp"]> + Pure, Terminator, TokenConsumerTrait ]> { let summary = "Continue to the next iteration of an `scf.loop`"; let description = [{ - The `scf.continue` operation re-enters the immediately enclosing `scf.loop` - for its next iteration. Its operands become the loop-carried values - (`iter_args`) for the next iteration; their types must match the loop's - iter_arg types (verified by the loop). - - Example: - - ```mlir - scf.loop iter_args(%i = %init) : i32 { - %next = arith.addi %i, %one : i32 - scf.continue %next : i32 - } - ``` + The `scf.continue` operation is a region-breaking terminator that re-enters + the `scf.loop` identified by its `token` operand for its next iteration. + The operands become the loop-carried values (`iter_args`) for the next + iteration; their types must match the target loop's iter_arg types. + + `scf.continue` may appear as the terminator of any block nested inside the + target `scf.loop`, as long as every operation on the path between this + `scf.continue` and the target loop carries the `PropagateControlFlowBreak` + op trait. }]; - let arguments = (ins Variadic:$operands); - let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; + let arguments = (ins Token:$token, Variadic:$values); + let builders = [OpBuilder<(ins "::mlir::Value":$token), [{ + $_state.addOperands(token); + }]>]; let assemblyFormat = [{ - attr-dict ($operands^ `:` type($operands))? + $token (`,` $values^)? attr-dict `:` type($token) (`,` type($values)^)? 
}]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -856,7 +883,7 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">, + InferTypeOpAdaptor, PropagateControlFlowBreak, SingleBlock, RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> { let summary = "if-then-else operation"; let description = [{ @@ -893,9 +920,16 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods:$results); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 0d0669e90c3f7..345cf4f040ed2 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -102,6 +102,8 @@ def Terminator : NativeOpTrait<"IsTerminator">; def TokenProducerTrait : NativeOpTrait<"TokenProducerTrait">; // Op consumes builtin token values. def TokenConsumerTrait : NativeOpTrait<"TokenConsumerTrait">; +// Op is transparent to region-breaking terminators. +def PropagateControlFlowBreak : NativeOpTrait<"PropagateControlFlowBreak">; // Op can be safely normalized in the presence of MemRefs with // non-identity maps. def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index c85aefc19eb81..4d33774e4e589 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -790,6 +790,14 @@ template class TokenConsumerTrait : public TraitBase { }; +/// This trait marks operations that are transparent to region-breaking +/// terminators: a region-breaking terminator (i.e., a terminator that passes +/// control to an enclosing operation) may appear as a terminator of any block +/// within this op. +template +class PropagateControlFlowBreak + : public TraitBase {}; + /// This class provides verification for ops that are known to have zero /// successors. 
template diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 2972d79c4302f..0500ccb101ed1 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -402,6 +402,18 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { auto loc = ifOp.getLoc(); + // Only `scf.if` ops whose regions terminate with `scf.yield` are supported. + // Region-breaking terminators (`scf.break` / `scf.continue`) are not yet + // handled by this lowering. + auto isYieldTerminated = [](Region ®ion) { + return region.empty() || isa(region.front().back()); + }; + if (!isYieldTerminated(ifOp.getThenRegion()) || + !isYieldTerminated(ifOp.getElseRegion())) + return rewriter.notifyMatchFailure( + ifOp, "lowering of 'scf.if' with a non-'scf.yield' terminator is " + "not implemented yet"); + // Start by splitting the block containing the 'scf.if' into two parts. // The part before will contain the condition, the part after will be the // continuation point. diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 60e5975f4ec48..50b2e01e894b6 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -289,15 +289,17 @@ ValueRange ExecuteRegionOp::getSuccessorInputs(RegionSuccessor successor) { void LoopOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTypes, ValueRange initValues, - function_ref bodyBuilder) { + function_ref bodyBuilder) { result.addOperands(initValues); result.addTypes(resultTypes); - // Build the body region with a single entry block, one argument per init - // value. The caller-supplied `bodyBuilder` is responsible for terminating - // the block with either `scf.continue` or `scf.break`. 
+ // Build the body region with a single entry block whose first argument is + // the loop's token, followed by one argument per init value. The + // caller-supplied `bodyBuilder` is responsible for terminating the block + // with either `scf.continue` or `scf.break`. Region *bodyRegion = result.addRegion(); Block *bodyBlock = builder.createBlock(bodyRegion); + bodyBlock->addArgument(TokenType::get(builder.getContext()), result.location); SmallVector argTypes(initValues.getTypes()); SmallVector argLocs(initValues.size(), result.location); bodyBlock->addArguments(argTypes, argLocs); @@ -305,7 +307,8 @@ void LoopOp::build( if (bodyBuilder) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(bodyBlock); - bodyBuilder(builder, result.location, bodyBlock->getArguments()); + bodyBuilder(builder, result.location, bodyBlock->getArgument(0), + bodyBlock->getArguments().drop_front()); } } @@ -313,11 +316,14 @@ LogicalResult LoopOp::verifyRegions() { if (getRegion().empty()) return emitOpError("region cannot be empty"); Block &body = getRegion().front(); - if (body.getNumArguments() != getNumOperands()) + if (body.getNumArguments() != getNumOperands() + 1) return emitOpError( "mismatch in number of loop-carried values and defined values"); + // The first entry block argument must be a token. + if (!isa(body.getArgument(0).getType())) + return emitOpError("first region argument must be a token"); for (auto [index, regionArg, initOperand] : - llvm::enumerate(body.getArguments(), getOperands())) { + llvm::enumerate(body.getArguments().drop_front(), getOperands())) { if (regionArg.getType() != initOperand.getType()) return emitOpError() << "type mismatch between " << index << "th iter operand (" << initOperand.getType() @@ -325,40 +331,50 @@ LogicalResult LoopOp::verifyRegions() { << ")"; } - // The loop body must end with an explicit `scf.break` or `scf.continue`. + // The loop body must end with `scf.break` or `scf.continue`. 
Operation *terminator = body.getTerminator(); - if (auto breakOp = dyn_cast(terminator)) { - if (breakOp.getNumOperands() != getNumResults()) - return breakOp.emitOpError() - << "has " << breakOp.getNumOperands() - << " operands, but enclosing scf.loop returns " << getNumResults() - << " result(s)"; - for (auto [index, operandType, resultType] : - llvm::enumerate(breakOp.getOperandTypes(), getResultTypes())) { - if (operandType != resultType) - return breakOp.emitOpError() - << "type mismatch between " << index << "th operand (" - << operandType << ") and " << index - << "th result of enclosing scf.loop (" << resultType << ")"; - } - } else if (auto continueOp = dyn_cast(terminator)) { - if (continueOp.getNumOperands() != getNumRegionIterValues()) - return continueOp.emitOpError() - << "has " << continueOp.getNumOperands() - << " operands, but enclosing scf.loop has " - << getNumRegionIterValues() << " iter_args"; - for (auto [index, operandType, iterArgType] : llvm::enumerate( - continueOp.getOperandTypes(), body.getArgumentTypes())) { - if (operandType != iterArgType) - return continueOp.emitOpError() - << "type mismatch between " << index << "th operand (" - << operandType << ") and " << index - << "th iter_arg of enclosing scf.loop (" << iterArgType << ")"; - } - } else { + if (!isa(terminator)) return emitOpError("body must be terminated by 'scf.break' or " "'scf.continue', got '") << terminator->getName() << "'"; + + // Verify the operand counts/types of every `scf.break`/`scf.continue` that + // targets this loop. Token users that are not `scf.break` / `scf.continue` + // are ignored here. 
+ Value token = getRegionToken(); + TypeRange resultTypes = getResultTypes(); + auto iterArgTypes = ValueRange(getRegionIterValues()).getTypes(); + for (Operation *user : token.getUsers()) { + if (auto breakOp = dyn_cast(user)) { + if (breakOp.getOperands().size() - 1 != resultTypes.size()) + return breakOp.emitOpError() + << "has " << (breakOp.getOperands().size() - 1) + << " value operand(s), but target scf.loop returns " + << resultTypes.size() << " result(s)"; + for (auto [index, operandType, resultType] : llvm::enumerate( + breakOp.getOperands().drop_front().getTypes(), resultTypes)) { + if (operandType != resultType) + return breakOp.emitOpError() + << "type mismatch between " << index << "th value operand (" + << operandType << ") and " << index + << "th result of target scf.loop (" << resultType << ")"; + } + } else if (auto continueOp = dyn_cast(user)) { + if (continueOp.getOperands().size() - 1 != iterArgTypes.size()) + return continueOp.emitOpError() + << "has " << (continueOp.getOperands().size() - 1) + << " value operand(s), but target scf.loop has " + << iterArgTypes.size() << " iter_arg(s)"; + for (auto [index, operandType, iterArgType] : + llvm::enumerate(continueOp.getOperands().drop_front().getTypes(), + iterArgTypes)) { + if (operandType != iterArgType) + return continueOp.emitOpError() + << "type mismatch between " << index << "th value operand (" + << operandType << ") and " << index + << "th iter_arg of target scf.loop (" << iterArgType << ")"; + } + } } return success(); } @@ -377,8 +393,10 @@ static void printFunctionalTypeList(OpAsmPrinter &p, TypeRange types) { void LoopOp::print(OpAsmPrinter &p) { p << " "; + // Print the token block argument first. 
+ p.printRegionArgument(getRegionToken(), /*argAttrs=*/{}, /*omitType=*/true); if (!getInitValues().empty()) { - p << "iter_args("; + p << " iter_args("; llvm::interleaveComma( llvm::zip(getRegionIterValues(), getInitValues()), p, [&](auto it) { p.printRegionArgument(std::get<0>(it), /*argAttrs=*/{}, @@ -387,13 +405,12 @@ void LoopOp::print(OpAsmPrinter &p) { }); p << ") : "; printFunctionalTypeList(p, getInitValues().getTypes()); - p << " "; } if (!getResultTypes().empty()) { - p << "-> "; + p << " -> "; printFunctionalTypeList(p, getResultTypes()); - p << " "; } + p << " "; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); @@ -401,16 +418,20 @@ void LoopOp::print(OpAsmPrinter &p) { } ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { + // The first region argument is always the token. + OpAsmParser::Argument tokenArg; + if (parser.parseArgument(tokenArg)) + return failure(); + tokenArg.type = TokenType::get(result.getContext()); + SmallVector regionArgs; + regionArgs.push_back(tokenArg); SmallVector iterOperands; SmallVector iterTypes; - if (failed(parser.parseOptionalKeyword("iter_args"))) { - // No iter_args, but may still have a result type list. 
- if (parser.parseOptionalArrowTypeList(result.types)) - return failure(); - } else { - if (parser.parseAssignmentList(regionArgs, iterOperands) || + if (succeeded(parser.parseOptionalKeyword("iter_args"))) { + SmallVector iterArgs; + if (parser.parseAssignmentList(iterArgs, iterOperands) || parser.parseColon()) return failure(); if (parser.parseOptionalLParen()) { @@ -423,14 +444,15 @@ ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseTypeList(iterTypes) || parser.parseRParen()) return failure(); } - if (regionArgs.size() != iterTypes.size()) + if (iterArgs.size() != iterTypes.size()) return parser.emitError(parser.getCurrentLocation(), "found different number of iter_args and types"); - if (parser.parseOptionalArrowTypeList(result.types)) - return failure(); - for (auto [regionArg, type] : llvm::zip_equal(regionArgs, iterTypes)) - regionArg.type = type; + for (auto [iterArg, type] : llvm::zip_equal(iterArgs, iterTypes)) + iterArg.type = type; + regionArgs.append(iterArgs.begin(), iterArgs.end()); } + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); Region *body = result.addRegion(); if (parser.parseRegion(*body, regionArgs)) @@ -444,6 +466,57 @@ ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } +//===----------------------------------------------------------------------===// +// BreakOp / ContinueOp +//===----------------------------------------------------------------------===// + +/// Common verifier for `scf.break` / `scf.continue`: ensure the `token` +/// operand is the token of an enclosing `scf.loop`, and that every operation +/// on the path between this terminator and the target loop allows +/// region-breaking terminators in its regions. +static LogicalResult verifyRegionBreakingTerminator(Operation *op, Value token, + StringRef opName) { + // The token must be the entry block token of some `scf.loop`. 
(The builtin + // token verifier already ensures the token is produced by a producer with + // `TokenProducerTrait`; here we additionally constrain that producer to be + // an `scf.loop`.) + auto blockArg = dyn_cast(token); + if (!blockArg) + return op->emitOpError() + << "expects the token operand to be the token of an enclosing " + "'scf.loop'"; + Operation *targetLoop = blockArg.getOwner()->getParentOp(); + if (!isa_and_nonnull(targetLoop) || + cast(targetLoop).getRegionToken() != blockArg) + return op->emitOpError() + << "expects the token operand to be the token of an enclosing " + "'scf.loop'"; + + // Every op on the path from this terminator up to the target loop must + // be transparent to region-breaking terminators, i.e., it must implement + // the `PropagateControlFlowBreak` trait. + Operation *cursor = op->getParentOp(); + while (cursor && cursor != targetLoop) { + if (!cursor->mightHaveTrait()) + return op->emitOpError() + << "cannot " << opName << " through '" << cursor->getName() + << "': op does not implement the 'PropagateControlFlowBreak' " + "trait"; + cursor = cursor->getParentOp(); + } + if (cursor != targetLoop) + return op->emitOpError("target 'scf.loop' is not an ancestor"); + return success(); +} + +LogicalResult BreakOp::verify() { + return verifyRegionBreakingTerminator(*this, getToken(), "break"); +} + +LogicalResult ContinueOp::verify() { + return verifyRegionBreakingTerminator(*this, getToken(), "continue"); +} + //===----------------------------------------------------------------------===// // ConditionOp //===----------------------------------------------------------------------===// @@ -2099,20 +2172,47 @@ IfOp::inferReturnTypes(MLIRContext *ctx, std::optional loc, SmallVectorImpl &inferredReturnTypes) { if (adaptor.getRegions().empty()) return failure(); - Region *r = &adaptor.getThenRegion(); - if (r->empty()) - return failure(); - Block &b = r->front(); - if (b.empty()) - return failure(); - auto yieldOp = 
llvm::dyn_cast(b.back()); - if (!yieldOp) - return failure(); - TypeRange types = yieldOp.getOperandTypes(); - llvm::append_range(inferredReturnTypes, types); + // Pick the first region whose fall-through terminator is `scf.yield` and + // use its operand types as the inferred result types. Regions ending with + // a region-breaking terminator (`scf.break` / `scf.continue`) transfer + // control out of this op and therefore do not contribute to the result + // types. + auto tryRegion = [&](Region &r) -> std::optional { + if (r.empty()) + return std::nullopt; + Block &b = r.front(); + if (b.empty()) + return std::nullopt; + if (auto yieldOp = llvm::dyn_cast(b.back())) + return TypeRange(yieldOp.getOperandTypes()); + return std::nullopt; + }; + if (auto types = tryRegion(adaptor.getThenRegion())) { + llvm::append_range(inferredReturnTypes, *types); + return success(); + } + if (auto types = tryRegion(adaptor.getElseRegion())) { + llvm::append_range(inferredReturnTypes, *types); + return success(); + } + // Neither region falls through with `scf.yield`. The op produces no + // values in this case. return success(); } +/// Ensure that `region`, which belongs to an `scf.if`, terminates with an +/// `scf.yield` if it does not already have a terminator. 
+template +static void ensureIfYieldTerminator(Region ®ion, BuilderTy &builder, + Location loc) { + ::mlir::impl::ensureRegionTerminator( + region, builder, loc, [](OpBuilder &b, Location l) -> Operation * { + OperationState state(l, YieldOp::getOperationName()); + YieldOp::build(b, state); + return Operation::create(state); + }); +} + void IfOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value cond) { return build(builder, result, resultTypes, cond, /*addThenBlock=*/false, @@ -2152,14 +2252,14 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Region *thenRegion = result.addRegion(); builder.createBlock(thenRegion); if (resultTypes.empty()) - IfOp::ensureTerminator(*thenRegion, builder, result.location); + ensureIfYieldTerminator(*thenRegion, builder, result.location); // Build else region. Region *elseRegion = result.addRegion(); if (withElseRegion) { builder.createBlock(elseRegion); if (resultTypes.empty()) - IfOp::ensureTerminator(*elseRegion, builder, result.location); + ensureIfYieldTerminator(*elseRegion, builder, result.location); } } @@ -2196,6 +2296,33 @@ void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, LogicalResult IfOp::verify() { if (getNumResults() != 0 && getElseRegion().empty()) return emitOpError("must have an else block if defining values"); + + // The terminator of each region must be either `scf.yield` (the regular + // case) or a region-breaking terminator (`scf.break` / `scf.continue`) + // that transfers control to an enclosing `scf.loop`. The validity of + // region-breaking terminators is checked by their own verifiers; here we + // only enforce that nothing else terminates an `scf.if` region. 
+ auto verifyRegionTerminator = [&](Region ®ion, + StringRef name) -> LogicalResult { + if (region.empty()) + return success(); + Operation &terminator = region.front().back(); + if (isa(terminator)) + return success(); + return emitOpError() << "expects '" << name + << "' region to be terminated by 'scf.yield', " + "'scf.break', or 'scf.continue', found '" + << terminator.getName() << "'"; + }; + if (failed(verifyRegionTerminator(getThenRegion(), "then"))) + return failure(); + if (failed(verifyRegionTerminator(getElseRegion(), "else"))) + return failure(); + + // Yield-operand count/type matching against the op's results is verified + // by the `RegionBranchOpInterface` verifier (which is aware of yield as a + // RegionBranchTerminatorOpInterface and skips region-breaking + // terminators). return success(); } @@ -2217,13 +2344,13 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the 'then' region. if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); - IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); + ensureIfYieldTerminator(*thenRegion, parser.getBuilder(), result.location); // If we find an 'else' keyword then parse the 'else' region. if (!parser.parseOptionalKeyword("else")) { if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); - IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); + ensureIfYieldTerminator(*elseRegion, parser.getBuilder(), result.location); } // Parse the optional attribute list. @@ -2233,18 +2360,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { } void IfOp::print(OpAsmPrinter &p) { - bool printBlockTerminators = false; + bool hasResults = !getResults().empty(); p << " " << getCondition(); - if (!getResults().empty()) { + if (hasResults) p << " -> (" << getResultTypes() << ")"; - // Print yield explicitly if the op defines values. 
- printBlockTerminators = true; - } + + // We must print the terminator whenever the op produces results, or when + // the terminator is not a plain `scf.yield` (so that region-breaking + // terminators round-trip). + auto needsExplicitTerminator = [&](Region ®ion) { + if (region.empty()) + return false; + if (hasResults) + return true; + return !isa(region.front().back()); + }; + p << ' '; p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + /*printBlockTerminators=*/ + needsExplicitTerminator(getThenRegion())); // Print the 'else' regions if it exists and has a block. auto &elseRegion = getElseRegion(); @@ -2252,7 +2389,8 @@ void IfOp::print(OpAsmPrinter &p) { p << " else "; p.printRegion(elseRegion, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + /*printBlockTerminators=*/ + needsExplicitTerminator(elseRegion)); } p.printOptionalAttrDict((*this)->getAttrs()); diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg-unsupported.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg-unsupported.mlir new file mode 100644 index 0000000000000..fb2e1ed218102 --- /dev/null +++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg-unsupported.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt -convert-scf-to-cf -split-input-file -verify-diagnostics %s + +// Lowering of `scf.if` whose then/else region is terminated by an +// `scf.break` / `scf.continue` (region-breaking terminator) is not yet +// implemented. The pattern reports a match failure and the partial +// conversion driver leaves the op unconverted, which causes the pass to +// fail. `scf.loop` is not in the conversion target, so it stays legal. 
+ +func.func @if_break_unsupported(%cond: i1, %v: i32) -> i32 { + %r = scf.loop %t -> i32 { + // expected-error@+1 {{failed to legalize operation 'scf.if' that was explicitly marked illegal}} + scf.if %cond { + scf.break %t, %v : token, i32 + } + scf.continue %t : token + } + return %r : i32 +} + +// ----- + +func.func @if_continue_unsupported(%cond: i1, %init: i32) { + scf.loop %t iter_args(%i = %init) : i32 { + // expected-error@+1 {{failed to legalize operation 'scf.if' that was explicitly marked illegal}} + scf.if %cond { + scf.continue %t, %i : token, i32 + } + scf.continue %t, %i : token, i32 + } + return +} diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index 099c02631804f..53eec9f688dce 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -855,16 +855,21 @@ func.func @for_missing_induction_var(%arg0: index, %arg1: index) { // ----- -func.func @break_outside_loop(%v: i32) { - // expected-error@+1 {{'scf.break' op expects parent op 'scf.loop'}} - scf.break %v : i32 +// A token operand to scf.break must be the token of an enclosing scf.loop. +// Here we exercise the SCF verifier path using a token produced by a +// non-`scf.loop` token producer. 
+func.func @break_outside_loop() { + %t = test.token.produce + // expected-error@+1 {{'scf.break' op expects the token operand to be the token of an enclosing 'scf.loop'}} + scf.break %t : token } // ----- func.func @continue_outside_loop() { - // expected-error@+1 {{'scf.continue' op expects parent op 'scf.loop'}} - scf.continue + %t = test.token.produce + // expected-error@+1 {{'scf.continue' op expects the token operand to be the token of an enclosing 'scf.loop'}} + scf.continue %t : token } // ----- @@ -872,7 +877,7 @@ func.func @continue_outside_loop() { func.func @loop_bad_terminator() { // expected-error@+1 {{'scf.loop' op body must be terminated by 'scf.break' or 'scf.continue'}} "scf.loop"() ({ - ^bb0: + ^bb0(%t: token): "test.foo"() : () -> () "test.terminator"() : () -> () }) : () -> () @@ -881,11 +886,33 @@ func.func @loop_bad_terminator() { // ----- +func.func @loop_missing_token_arg() { + // expected-error@+1 {{'scf.loop' op mismatch in number of loop-carried values and defined values}} + "scf.loop"() ({ + ^bb0: + "test.terminator"() : () -> () + }) : () -> () + return +} + +// ----- + +func.func @loop_first_arg_not_token() { + // expected-error@+1 {{'scf.loop' op first region argument must be a token}} + "scf.loop"() ({ + ^bb0(%i: i32): + "test.terminator"() : () -> () + }) : () -> () + return +} + +// ----- + func.func @loop_init_arg_count_mismatch(%init: i32) { // expected-error@+1 {{'scf.loop' op mismatch in number of loop-carried values and defined values}} "scf.loop"(%init) ({ - ^bb0: - scf.continue + ^bb0(%t: token): + scf.continue %t : token }) : (i32) -> () return } @@ -895,8 +922,8 @@ func.func @loop_init_arg_count_mismatch(%init: i32) { func.func @loop_init_arg_type_mismatch(%init: i32) { // expected-error@+1 {{'scf.loop' op type mismatch between 0th iter operand ('i32') and region argument ('i64')}} "scf.loop"(%init) ({ - ^bb0(%i: i64): - scf.continue %i : i64 + ^bb0(%t: token, %i: i64): + scf.continue %t, %i : token, i64 }) : (i32) -> () 
return } @@ -904,9 +931,9 @@ func.func @loop_init_arg_type_mismatch(%init: i32) { // ----- func.func @loop_break_count_mismatch(%v: i32) -> (i32, i32) { - // expected-error@+2 {{'scf.break' op has 1 operands, but enclosing scf.loop returns 2 result(s)}} - %r:2 = scf.loop -> (i32, i32) { - scf.break %v : i32 + // expected-error@+2 {{'scf.break' op has 1 value operand(s), but target scf.loop returns 2 result(s)}} + %r:2 = scf.loop %t -> (i32, i32) { + scf.break %t, %v : token, i32 } return %r#0, %r#1 : i32, i32 } @@ -914,9 +941,9 @@ func.func @loop_break_count_mismatch(%v: i32) -> (i32, i32) { // ----- func.func @loop_break_type_mismatch(%v: i32) -> i64 { - // expected-error@+2 {{'scf.break' op type mismatch between 0th operand ('i32') and 0th result of enclosing scf.loop ('i64')}} - %r = scf.loop -> i64 { - scf.break %v : i32 + // expected-error@+2 {{'scf.break' op type mismatch between 0th value operand ('i32') and 0th result of target scf.loop ('i64')}} + %r = scf.loop %t -> i64 { + scf.break %t, %v : token, i32 } return %r : i64 } @@ -924,9 +951,9 @@ func.func @loop_break_type_mismatch(%v: i32) -> i64 { // ----- func.func @loop_continue_count_mismatch(%init: i32) { - // expected-error@+2 {{'scf.continue' op has 0 operands, but enclosing scf.loop has 1 iter_args}} - scf.loop iter_args(%i = %init) : i32 { - scf.continue + // expected-error@+2 {{'scf.continue' op has 0 value operand(s), but target scf.loop has 1 iter_arg(s)}} + scf.loop %t iter_args(%i = %init) : i32 { + scf.continue %t : token } return } @@ -934,9 +961,9 @@ func.func @loop_continue_count_mismatch(%init: i32) { // ----- func.func @loop_continue_type_mismatch(%init: i32, %v: i64) { - // expected-error@+2 {{'scf.continue' op type mismatch between 0th operand ('i64') and 0th iter_arg of enclosing scf.loop ('i32')}} - scf.loop iter_args(%i = %init) : i32 { - scf.continue %v : i64 + // expected-error@+2 {{'scf.continue' op type mismatch between 0th value operand ('i64') and 0th iter_arg of target 
scf.loop ('i32')}} + scf.loop %t iter_args(%i = %init) : i32 { + scf.continue %t, %v : token, i64 } return } @@ -946,10 +973,21 @@ func.func @loop_continue_type_mismatch(%init: i32, %v: i64) { func.func @loop_more_than_one_block(%v: i32) -> i32 { // expected-error@+1 {{'scf.loop' op expects region #0 to have 0 or 1 blocks}} %r = "scf.loop"() ({ - ^bb0: + ^bb0(%t: token): "test.unreachable"() [^bb1] : () -> () - ^bb1: - scf.break %v : i32 + ^bb1(%t2: token): + scf.break %t2 : token }) : () -> i32 return %r : i32 } + +// ----- + +// scf.if regions must terminate with scf.yield, scf.break, or scf.continue. +func.func @if_invalid_terminator(%cond: i1) { + // expected-error@+1 {{'scf.if' op expects 'then' region to be terminated by 'scf.yield', 'scf.break', or 'scf.continue', found 'func.return'}} + "scf.if"(%cond) ({ + func.return + }, {}) : (i1) -> () + return +} diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir index e8f5294b40a4d..ceeff4957fc56 100644 --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -444,62 +444,62 @@ func.func @switch(%arg0: index) -> i32 { // CHECK-LABEL: @loop_infinite func.func @loop_infinite() { - // CHECK: scf.loop { - scf.loop { + // CHECK: scf.loop %[[T:.*]] { + scf.loop %t { // CHECK-NEXT: "test.foo" "test.foo"() : () -> () - // CHECK-NEXT: scf.continue - scf.continue + // CHECK-NEXT: scf.continue %[[T]] : token + scf.continue %t : token } return } // CHECK-LABEL: @loop_break_no_operands func.func @loop_break_no_operands() { - // CHECK: scf.loop { - scf.loop { - // CHECK-NEXT: scf.break - scf.break + // CHECK: scf.loop %[[T:.*]] { + scf.loop %t { + // CHECK-NEXT: scf.break %[[T]] : token + scf.break %t : token } return } // CHECK-LABEL: @loop_break_single func.func @loop_break_single(%v: i32) -> i32 { - // CHECK: %{{.*}} = scf.loop -> i32 { - %r = scf.loop -> i32 { - // CHECK-NEXT: scf.break %{{.*}} : i32 - scf.break %v : i32 + // CHECK: %{{.*}} = scf.loop %[[T:.*]] -> i32 { + %r = 
scf.loop %t -> i32 { + // CHECK-NEXT: scf.break %[[T]], %{{.*}} : token, i32 + scf.break %t, %v : token, i32 } return %r : i32 } // CHECK-LABEL: @loop_break_multi func.func @loop_break_multi(%v: i32, %w: i64) -> (i32, i64) { - // CHECK: %{{.*}}:2 = scf.loop -> (i32, i64) { - %r:2 = scf.loop -> (i32, i64) { - // CHECK-NEXT: scf.break %{{.*}}, %{{.*}} : i32, i64 - scf.break %v, %w : i32, i64 + // CHECK: %{{.*}}:2 = scf.loop %[[T:.*]] -> (i32, i64) { + %r:2 = scf.loop %t -> (i32, i64) { + // CHECK-NEXT: scf.break %[[T]], %{{.*}}, %{{.*}} : token, i32, i64 + scf.break %t, %v, %w : token, i32, i64 } return %r#0, %r#1 : i32, i64 } // CHECK-LABEL: @loop_iter_single func.func @loop_iter_single(%init: i32) { - // CHECK: scf.loop iter_args(%{{.*}} = %{{.*}}) : i32 { - scf.loop iter_args(%i = %init) : i32 { - // CHECK: scf.continue %{{.*}} : i32 - scf.continue %i : i32 + // CHECK: scf.loop %[[T:.*]] iter_args(%[[I:.*]] = %{{.*}}) : i32 { + scf.loop %t iter_args(%i = %init) : i32 { + // CHECK: scf.continue %[[T]], %[[I]] : token, i32 + scf.continue %t, %i : token, i32 } return } // CHECK-LABEL: @loop_iter_multi func.func @loop_iter_multi(%init0: i32, %init1: i64) { - // CHECK: scf.loop iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : (i32, i64) { - scf.loop iter_args(%i = %init0, %j = %init1) : (i32, i64) { - // CHECK: scf.continue %{{.*}}, %{{.*}} : i32, i64 - scf.continue %i, %j : i32, i64 + // CHECK: scf.loop %[[T:.*]] iter_args(%[[I:.*]] = %{{.*}}, %[[J:.*]] = %{{.*}}) : (i32, i64) { + scf.loop %t iter_args(%i = %init0, %j = %init1) : (i32, i64) { + // CHECK: scf.continue %[[T]], %[[I]], %[[J]] : token, i32, i64 + scf.continue %t, %i, %j : token, i32, i64 } return } @@ -507,10 +507,106 @@ func.func @loop_iter_multi(%init0: i32, %init1: i64) { // Loop with iter_args of one type and a single result of another type. 
// CHECK-LABEL: @loop_iter_and_result func.func @loop_iter_and_result(%init: i32, %v: i64) -> i64 { - // CHECK: %{{.*}} = scf.loop iter_args(%{{.*}} = %{{.*}}) : i32 -> i64 { - %r = scf.loop iter_args(%i = %init) : i32 -> i64 { - // CHECK: scf.break %{{.*}} : i64 - scf.break %v : i64 + // CHECK: %{{.*}} = scf.loop %[[T:.*]] iter_args(%{{.*}} = %{{.*}}) : i32 -> i64 { + %r = scf.loop %t iter_args(%i = %init) : i32 -> i64 { + // CHECK: scf.break %[[T]], %{{.*}} : token, i64 + scf.break %t, %v : token, i64 } return %r : i64 } + +// Early exit: the inner loop body terminates by breaking the outer loop. This +// exercises the token-based region-breaking semantics. The outer loop's body +// still needs its own terminator (here `scf.continue %t_outer`) after the +// inner loop, even though it is unreachable in practice. +// CHECK-LABEL: @loop_nested_early_break +func.func @loop_nested_early_break(%v: i32) -> i32 { + // CHECK: %{{.*}} = scf.loop %[[T_OUT:.*]] -> i32 { + %r = scf.loop %t_outer -> i32 { + // CHECK: scf.loop %{{.*}} { + scf.loop %t_inner { + // CHECK: scf.break %[[T_OUT]], %{{.*}} : token, i32 + scf.break %t_outer, %v : token, i32 + } + // CHECK: scf.continue %[[T_OUT]] : token + scf.continue %t_outer : token + } + return %r : i32 +} + +// Early continue from the inner loop, targeting the outer loop. +// CHECK-LABEL: @loop_nested_early_continue +func.func @loop_nested_early_continue(%init: i32) { + // CHECK: scf.loop %[[T_OUT:.*]] iter_args(%[[I:.*]] = %{{.*}}) : i32 { + scf.loop %t_outer iter_args(%i = %init) : i32 { + // CHECK: scf.loop %{{.*}} { + scf.loop %t_inner { + // CHECK: scf.continue %[[T_OUT]], %[[I]] : token, i32 + scf.continue %t_outer, %i : token, i32 + } + // CHECK: scf.continue %[[T_OUT]], %[[I]] : token, i32 + scf.continue %t_outer, %i : token, i32 + } + return +} + +// scf.if carries the `PropagateControlFlowBreak` trait, so scf.break / +// scf.continue may appear as terminators inside its regions when they target +// an enclosing scf.loop. 
+ +// CHECK-LABEL: @if_break_in_then +func.func @if_break_in_then(%cond: i1, %v: i32) -> i32 { + // CHECK: scf.loop %[[T:[A-Za-z0-9_]+]] -> i32 { + %r = scf.loop %t -> i32 { + // CHECK: scf.if %{{.*}} { + // CHECK-NEXT: scf.break %[[T]], %{{.*}} : token, i32 + // CHECK-NEXT: } + scf.if %cond { + scf.break %t, %v : token, i32 + } + // CHECK: scf.continue %[[T]] : token + scf.continue %t : token + } + return %r : i32 +} + +// CHECK-LABEL: @if_continue_in_both_branches +func.func @if_continue_in_both_branches(%cond: i1, %init: i32) { + // CHECK: scf.loop %[[T:[A-Za-z0-9_]+]] iter_args(%[[I:[A-Za-z0-9_]+]] = %{{.*}}) : i32 { + scf.loop %t iter_args(%i = %init) : i32 { + // CHECK: scf.if %{{.*}} { + // CHECK-NEXT: scf.continue %[[T]], %[[I]] : token, i32 + // CHECK-NEXT: } else { + // CHECK-NEXT: scf.continue %[[T]], %[[I]] : token, i32 + // CHECK-NEXT: } + scf.if %cond { + scf.continue %t, %i : token, i32 + } else { + scf.continue %t, %i : token, i32 + } + // The scf.loop body needs a terminator; this one is unreachable in + // practice because both `scf.if` regions transfer control out. + scf.continue %t, %i : token, i32 + } + return +} + +// scf.if break/continue can also propagate through nested scf.if ops, since +// scf.if carries the trait. +// CHECK-LABEL: @nested_if_break +func.func @nested_if_break(%c0: i1, %c1: i1, %v: i32) -> i32 { + // CHECK: scf.loop %[[T:[A-Za-z0-9_]+]] -> i32 { + %r = scf.loop %t -> i32 { + // CHECK: scf.if %{{.*}} { + scf.if %c0 { + // CHECK: scf.if %{{.*}} { + // CHECK-NEXT: scf.break %[[T]], %{{.*}} : token, i32 + // CHECK-NEXT: } + scf.if %c1 { + scf.break %t, %v : token, i32 + } + } + scf.continue %t : token + } + return %r : i32 +}