From 63868a32d7138690419fa48fec7030b05ef1beed Mon Sep 17 00:00:00 2001 From: Fabian Schuiki Date: Tue, 2 Jun 2026 09:28:40 -0700 Subject: [PATCH] [Arc] Add coroutine ops Extend the Arc dialect with a set of `arc.coroutine.*` operations and types. These model suspendable coroutines similar to async/await fns in Rust. The compiler is expected to split the function at yield points, determine what local state needs to be persisted to the next re-entry, and then allow the coroutine to be polled using the persisted state and an indication of where to resume execution. The coroutine operations will later allow us to lower `llhd.process` and `llhd.coroutine` operations to `arc.coroutine.*`, and extend Arcilator to support concurrently executing these coroutines as a first step towards proper process support. This PR only adds operations, tests, and some rationale to the Arc dialect documentation. --- docs/Dialects/Arc.md | 62 ++++++ include/circt/Dialect/Arc/ArcOps.h | 4 + include/circt/Dialect/Arc/ArcOps.td | 290 ++++++++++++++++++++++++++ include/circt/Dialect/Arc/ArcTypes.td | 24 +++ lib/Dialect/Arc/ArcOps.cpp | 151 ++++++++++++++ test/Dialect/Arc/basic-errors.mlir | 196 +++++++++++++++++ test/Dialect/Arc/basic.mlir | 68 ++++++ 7 files changed, 795 insertions(+) diff --git a/docs/Dialects/Arc.md b/docs/Dialects/Arc.md index 705925db73b3..6626df5861dd 100644 --- a/docs/Dialects/Arc.md +++ b/docs/Dialects/Arc.md @@ -13,6 +13,68 @@ It transforms hardware descriptions from the HW, Seq, and Comb dialects into a f The Arc dialect is used by the *arcilator* simulation tool, which compiles Arc IR to a binary object via LLVM for fast simulation. +## Process and Coroutine Lowering + +LLHD distinguishes two suspendable constructs. +An `llhd.process` defines procedural behavior inline in an `hw.module`; it runs once and may suspend execution at `llhd.wait` ops or terminate at `llhd.halt`. +An `llhd.coroutine` is a separately-defined suspendable subroutine, invoked at `llhd.call_coroutine` sites from inside a process or another coroutine; it terminates with `llhd.return`. + +A process is, semantically, a coroutine defined inline in a module and invoked exactly once at its definition site. +Both bodies are SSACFG regions that are turned into a state machine driven by a *program counter* (PC), with values live across a suspension carried in *persistent state*. +Processes and coroutines therefore share a single lowering mechanism with only minor differences. + +### Outlined Form + +Both constructs are rewritten into a canonical outlined form: an `arc.coroutine.define` definition plus one or more call sites that re-enter it. +For a process, the call site is an `arc.coroutine.instance` placed in the enclosing `hw.module`. +For a coroutine, each `llhd.call_coroutine` becomes an `arc.coroutine.call` inside its parent coroutine's body. +After outlining, processes and coroutines are no longer distinguished. +Recursive coroutines are rejected during lowering. + +### Program Counter + +Every coroutine uses the same PC encoding: + +| Name | Value | Meaning | +|-----------|-----------|------------------------------------------------------| +| `START` | `0` | First entry; the body executes from its entry block. | +| resume | `1..N` | Resume at one of the body's suspension points. | +| `RETURN` | `MAX-1` | The body returned normally; results are valid. | +| `HALT` | `MAX` | The body halted; no further execution. | + +`START = 0` matches the zero-initialized layout of fresh persistent state and requires no special initialization at runtime. +Resume PCs are densely packed low integers, lowering to a single `switch` and keeping the per-coroutine PC width small. +`RETURN` and `HALT` are shared constants across all coroutines, so call sites dispatch on completion uniformly. + +### Persistent State + +The state carried across a suspension op corresponds to all the SSA values that are alive from that op into the resume block. +It is therefore not listed in the coroutine definition explicitly, but implied by its data and control flow structure. +The persistent state only becomes explicit when calling a coroutine, since the caller needs to decide how to re-enter a coroutine. + +When lowering to a concrete implementation, the persisted state is a union of structs, with one variant per resume block capturing all the live values. +Multiple suspension ops targeting the same resume block share a variant. + +Each resume block's first arguments must match the coroutine's function type. +These leading arguments are supplied fresh by the caller on each resumption and are therefore *not* part of the persistent state. +Any remaining block arguments hold the values passed as destination operands from the suspension op and *are* part of the persistent state. +The values captured into each variant are the SSA values that are live across the suspension ops into the resume block. + +When a coroutine contains an `arc.coroutine.call`, the callee's state and PC are SSA values returned from the call. +If the call site is itself suspended -- i.e. the callee did not complete in a single eval -- those values are live across the parent's "I am inside a call" suspension point and are captured into the parent's variant like any other block argument. +State allocation is therefore compositional: the size of a coroutine's persistent state is the size of its own union plus, transitively, the size of each callee's persistent state at each call site. +Lowering proceeds bottom-up over the call graph so that callee state sizes are known by the time a parent is lowered. + +### Instances and Wakeup + +`arc.coroutine.instance` exists only inside `hw.module` bodies and represents the once-per-module entry into a top-level coroutine. +It guards entry into the coroutine with `if (now >= my_wakeup && resume_pc != HALT)`. +The referenced coroutine must produce an `i64` wakeup time as its last result, which is not returned as a result from the instance op. +The model's `next_wakeup` slot is reset to `UINT64_MAX` by `LowerState` at the top of every eval body. +Each `arc.coroutine.instance`, regardless of whether it dispatched, contributes its current stored wakeup to a min-reduction into that slot. +The driver reads the slot after eval to decide when next to call the model. + + ## Types [include "Dialects/ArcTypes.md"] diff --git a/include/circt/Dialect/Arc/ArcOps.h b/include/circt/Dialect/Arc/ArcOps.h index e0a48dcd37be..74be4c2075ab 100644 --- a/include/circt/Dialect/Arc/ArcOps.h +++ b/include/circt/Dialect/Arc/ArcOps.h @@ -25,6 +25,10 @@ #include "circt/Dialect/Arc/ArcInterfaces.h.inc" +namespace circt::hw { +class HWModuleOp; +} // namespace circt::hw + #define GET_OP_CLASSES #include "circt/Dialect/Arc/Arc.h.inc" diff --git a/include/circt/Dialect/Arc/ArcOps.td b/include/circt/Dialect/Arc/ArcOps.td index 8a1b6327ec96..774d1c417b9c 100644 --- a/include/circt/Dialect/Arc/ArcOps.td +++ b/include/circt/Dialect/Arc/ArcOps.td @@ -676,6 +676,296 @@ def SetNextWakeupOp : ArcOp<"set_next_wakeup", [ }]; } +//===----------------------------------------------------------------------===// +// Coroutines +//===----------------------------------------------------------------------===// + +def CoroutineDefineOp : ArcOp<"coroutine.define", [ + FunctionOpInterface, + IsolatedFromAbove, + Symbol, +]> { + let summary = "Coroutine definition"; + let description = [{ + Define a coroutine. Coroutines are suspendable functions that are entered at + their entry block or re-entered at a resume block. Local state and values + are persisted by an opaque state held by the caller, alongside an opaque + program counter value indicating at which point the coroutine should be + resumed. + + Coroutines can be suspended by the `arc.coroutine.yield` terminator, which + returns control and local state back to the caller. The caller can then + re-enter the coroutine by passing that control and local state back into the + coroutine. Coroutines can be finished by the `arc.coroutine.return` and + `arc.coroutine.halt` terminators, which return control back to the caller + with a corresponding program counter indicating return or halt. + + Arguments are passed to the coroutine upon each entry. The entry block's + arguments hold the values supplied by the caller on the first call. Each + block targeted by a suspension op must have its first arguments match the + coroutine's function type. On resumption, those arguments are bound to the + values supplied by the caller and may differ from the values passed on prior + entries. + + Results are returned from the coroutine upon each suspension. Each of the + terminators must provide a set of values to be yielded back to the caller + upon suspension. + + The local state of a coroutine is represented by `!arc.coroutine_state<@A>`, + and the program counter for resuming by `!arc.coroutine_pc<@A>`. These types + are opaque and are expanded to a concrete union/struct of local state and a + concrete integer PC via a lowering. These types are only used on coroutine + calls; coroutine definitions define them implicitly by values carried across + resume points. + + To nest coroutines, a coroutine definition can call another coroutine and + carry the `!arc.coroutine_state` and `!arc.coroutine_pc` of that call as + local state across its own suspension points. + }]; + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let results = (outs); + let regions = (region MinSizedRegion<1>:$body); + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins "mlir::StringAttr":$sym_name, + "mlir::TypeAttr":$function_type), [{ + build($_builder, $_state, sym_name, function_type, mlir::ArrayAttr(), + mlir::ArrayAttr()); + }]>, + OpBuilder<(ins "mlir::StringRef":$sym_name, + "mlir::FunctionType":$function_type), [{ + build($_builder, $_state, + $_builder.getStringAttr(sym_name), + mlir::TypeAttr::get(function_type), + mlir::ArrayAttr(), mlir::ArrayAttr()); + }]>, + ]; + + let extraClassDeclaration = [{ + /// Returns the argument types of this coroutine. + mlir::ArrayRef getArgumentTypes() { + return getFunctionType().getInputs(); + } + + /// Returns the result types of this coroutine. + mlir::ArrayRef getResultTypes() { + return getFunctionType().getResults(); + } + + mlir::Region *getCallableRegion() { return &getBody(); } + }]; +} + +// A coroutine state/PC type whose `coroutine` parameter matches the enclosing +// op's `callee` symbol attribute. +class CoroutineTypeOfCallee : ConfinedType($_self).getCoroutine() == " + "($_op).getAttrOfType<::mlir::FlatSymbolRefAttr>(\"callee\")"> +], baseType.summary # " bound to the op's callee symbol", baseType.cppType>; + +def CoroutineStateOfCallee : CoroutineTypeOfCallee; +def CoroutinePCOfCallee : CoroutineTypeOfCallee; + +def CoroutineCallOp : ArcOp<"coroutine.call", [ + CallOpInterface, + DeclareOpInterfaceMethods, +]> { + let summary = "Call a coroutine"; + let description = [{ + Call an `arc.coroutine.define`. The coroutine is resumed at the point + indicated by the `pc` operand, and local state is restored from the `state` + operand. The `args` are passed into the coroutine and may differ between + subsequent re-entries. When the coroutine suspends or finishes, control is + transferred back to the caller and the call op returns the coroutine's + resume program counter, resume state, and the values yielded back from the + coroutine as results. + + The caller is responsible for interpreting the program counter returned from + the coroutine. A `return` indicates that the coroutine is finished and + control shall continue in the parent. A `halt` indicates that the coroutine + suspends forever, and the parent should also return `halt` if it is a + coroutine itself. Any other value indicates that the callee has suspended + and expects to be re-entered at a later point, and the caller must suspend + itself and re-enter the callee if it is a coroutine itself. + }]; + let arguments = (ins + FlatSymbolRefAttr:$callee, + CoroutineStateOfCallee:$state, + CoroutinePCOfCallee:$pc, + Variadic:$args + ); + let results = (outs + CoroutineStateOfCallee:$resumeState, + CoroutinePCOfCallee:$resumePC, + Variadic:$results + ); + let assemblyFormat = [{ + $callee `(` $state `,` $pc (`,` $args^)? `)` attr-dict + `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + operand_range getArgOperands() { + return getArgs(); + } + MutableOperandRange getArgOperandsMutable() { + return getArgsMutable(); + } + + mlir::CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + void setCalleeFromCallable(mlir::CallInterfaceCallable callee) { + (*this)->setAttr(getCalleeAttrName(), + llvm::cast(callee)); + } + + /// CallOpInterface requires ArgAndResultAttrsOpInterface. Call sites + /// don't carry these attributes, so stub them out as no-ops. + mlir::ArrayAttr getArgAttrsAttr() { return nullptr; } + mlir::ArrayAttr getResAttrsAttr() { return nullptr; } + void setArgAttrsAttr(mlir::ArrayAttr args) {} + void setResAttrsAttr(mlir::ArrayAttr args) {} + mlir::Attribute removeArgAttrsAttr() { return nullptr; } + mlir::Attribute removeResAttrsAttr() { return nullptr; } + }]; +} + +def CoroutineInstanceOp : ArcOp<"coroutine.instance", [ + CallOpInterface, + DeclareOpInterfaceMethods, + HasParent<"hw::HWModuleOp">, +]> { + let summary = "Continuously run a coroutine in an hw.module"; + let description = [{ + Execute a coroutine concurrently in an `hw.module`. The program counter and + state of the coroutine are held implicitly by the instance and passed into + the coroutine when executed next. The values yielded by the coroutine are + produced as results of the instance. The callee must produce a wakeup time + as its last result value. This wakeup time is not exposed as a result of + the instance op and is instead used to schedule the next execution. + }]; + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$args + ); + let results = (outs Variadic:$results); + let assemblyFormat = [{ + $callee `(` $args `)` attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + operand_range getArgOperands() { + return getArgs(); + } + MutableOperandRange getArgOperandsMutable() { + return getArgsMutable(); + } + + mlir::CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + void setCalleeFromCallable(mlir::CallInterfaceCallable callee) { + (*this)->setAttr(getCalleeAttrName(), + llvm::cast(callee)); + } + + mlir::ArrayAttr getArgAttrsAttr() { return nullptr; } + mlir::ArrayAttr getResAttrsAttr() { return nullptr; } + void setArgAttrsAttr(mlir::ArrayAttr args) {} + void setResAttrsAttr(mlir::ArrayAttr args) {} + mlir::Attribute removeArgAttrsAttr() { return nullptr; } + mlir::Attribute removeResAttrsAttr() { return nullptr; } + }]; +} + +def CoroutineYieldOp : ArcOp<"coroutine.yield", [ + AttrSizedOperandSegments, + HasParent<"CoroutineDefineOp">, + Terminator, +]> { + let summary = "Suspend a coroutine and request resumption at a block"; + let description = [{ + Suspend a coroutine. Control is transferred back to the caller, alongside + a program counter value indicating the destination block, and the local + state needed to restore the values live across the yield op. Additionally, + the yield operands are returned to the caller and must match the result + types of the coroutine. + + The destination block's first arguments must match the coroutine's function + type. On resumption, those arguments are bound to the values supplied by the + caller. The remaining arguments of the destination block are bound to the + destination operands of the yield op. + }]; + let arguments = (ins + Variadic:$yieldOperands, + Variadic:$destOperands + ); + let successors = (successor AnySuccessor:$dest); + let assemblyFormat = [{ + (` ` `(` $yieldOperands^ `:` type($yieldOperands) `)` `,`)? + $dest (`(` $destOperands^ `:` type($destOperands) `)`)? + attr-dict + }]; + let hasVerifier = 1; +} + +def CoroutineReturnOp : ArcOp<"coroutine.return", [ + HasParent<"CoroutineDefineOp">, + Terminator, +]> { + let summary = "Return from a coroutine"; + let description = [{ + Returns control from a coroutine to the caller, yielding back a special + sentinel program counter value indicating that the coroutine has run to + completion. Additionally, the yield operands are returned to the caller and + must match the result types of the coroutine. + }]; + let arguments = (ins Variadic:$yieldOperands); + let assemblyFormat = [{ + ($yieldOperands^ `:` type($yieldOperands))? attr-dict + }]; + let hasVerifier = 1; +} + +def CoroutineHaltOp : ArcOp<"coroutine.halt", [ + HasParent<"CoroutineDefineOp">, + Terminator, +]> { + let summary = "Halt a coroutine permanently"; + let description = [{ + Halts execution of a coroutine forever. Control effectively gets stuck + indefinitely at the halt operation, also preventing all callers from making + progress. Yields back a special sentinel program counter value to the caller + which the caller must translate into either halting itself if it is a + coroutine, or arranging for the coroutine to never be re-entered again. + }]; + let arguments = (ins Variadic:$yieldOperands); + let assemblyFormat = [{ + ($yieldOperands^ `:` type($yieldOperands))? attr-dict + }]; + let hasVerifier = 1; +} + +class CoroutinePCBase : ArcOp { + let summary = "Check whether a coroutine PC is a sentinel value"; + let arguments = (ins CoroutinePCType:$pc); + let results = (outs I1:$result); + let assemblyFormat = [{ + $pc `:` type($pc) attr-dict + }]; +} +def CoroutinePCIsReturnOp : CoroutinePCBase<"coroutine.pc_is_return">; +def CoroutinePCIsHaltOp : CoroutinePCBase<"coroutine.pc_is_halt">; + //===----------------------------------------------------------------------===// // Procedural Ops //===----------------------------------------------------------------------===// diff --git a/include/circt/Dialect/Arc/ArcTypes.td b/include/circt/Dialect/Arc/ArcTypes.td index f48b7bd9e455..0edd759a0a80 100644 --- a/include/circt/Dialect/Arc/ArcTypes.td +++ b/include/circt/Dialect/Arc/ArcTypes.td @@ -59,6 +59,30 @@ def SimModelInstance : ArcTypeDef<"SimModelInstance"> { let assemblyFormat = "`<` $model `>`"; } +def CoroutinePCType : ArcTypeDef<"CoroutinePC"> { + let summary = "Program counter of a coroutine"; + let description = [{ + An opaque program counter value indicating where to resume execution of a + coroutine. The concrete size is determined by lowering. + }]; + let mnemonic = "coroutine_pc"; + let parameters = (ins "mlir::FlatSymbolRefAttr":$coroutine); + let assemblyFormat = "`<` $coroutine `>`"; +} + +def CoroutineStateType : ArcTypeDef<"CoroutineState"> { + let summary = "State persisted between re-entries into a coroutine"; + let description = [{ + An opaque value representing the local state within a coroutine that + persists between re-entries into that coroutine. The concrete layout is + determined by lowering and is likely a union of structs, with each struct + corresponding to a unique resumption block in the coroutine. + }]; + let mnemonic = "coroutine_state"; + let parameters = (ins "mlir::FlatSymbolRefAttr":$coroutine); + let assemblyFormat = "`<` $coroutine `>`"; +} + def ArrayRefType : ArcTypeDef<"ArrayRef", [ DeclareTypeInterfaceMethods, DeclareTypeInterfaceMethods diff --git a/lib/Dialect/Arc/ArcOps.cpp b/lib/Dialect/Arc/ArcOps.cpp index 18cae04e4114..3163af7c9083 100644 --- a/lib/Dialect/Arc/ArcOps.cpp +++ b/lib/Dialect/Arc/ArcOps.cpp @@ -756,6 +756,157 @@ SimGetNextWakeupOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +//===----------------------------------------------------------------------===// +// CoroutineDefineOp +//===----------------------------------------------------------------------===// + +/// Resolve the callee symbol to a `CoroutineDefineOp` and verify that the +/// given operand and result types match its function type. +static LogicalResult verifyCoroutineCallTypes(Operation *op, + FlatSymbolRefAttr callee, + TypeRange operands, + TypeRange results, + SymbolTableCollection &symTable) { + auto defineOp = + symTable.lookupNearestSymbolFrom(op, callee); + if (!defineOp) + return op->emitOpError() << "`" << callee.getValue() + << "` does not reference a valid " + "`arc.coroutine.define`"; + + auto fnType = defineOp.getFunctionType(); + if (failed(verifyTypeListEquivalence(op, fnType.getInputs(), operands, + "operand"))) + return failure(); + if (failed(verifyTypeListEquivalence(op, fnType.getResults(), results, + "result"))) + return failure(); + return success(); +} + +ParseResult CoroutineDefineOp::parse(OpAsmParser &parser, + OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void CoroutineDefineOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, "function_type", getArgAttrsAttrName(), + getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// CoroutineCallOp +//===----------------------------------------------------------------------===// + +LogicalResult +CoroutineCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // The `state`/`pc` and `resumeState`/`resumePC` types are constrained to + // wrap the callee symbol by the `CoroutineCalleeWrappedType` traits on the + // op. All that remains is to resolve the callee and check that the trailing + // arg/result types match its signature. + auto callee = (*this)->getAttrOfType("callee"); + return verifyCoroutineCallTypes(*this, callee, getArgs().getTypes(), + getResults().getTypes(), symbolTable); +} + +//===----------------------------------------------------------------------===// +// CoroutineInstanceOp +//===----------------------------------------------------------------------===// + +// An instance hides the coroutine's trailing wakeup time. Verify that the +// callee declares a wakeup as its last result and that the instance's args and +// results match the callee's signature with that last result removed. +LogicalResult +CoroutineInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto callee = (*this)->getAttrOfType("callee"); + auto defineOp = + symbolTable.lookupNearestSymbolFrom(*this, callee); + if (!defineOp) + return emitOpError() << "`" << callee.getValue() + << "` does not reference a valid " + "`arc.coroutine.define`"; + + auto fnType = defineOp.getFunctionType(); + auto fnResults = fnType.getResults(); + if (fnResults.empty() || !fnResults.back().isInteger(64)) + return emitOpError() << "referenced coroutine `" << callee.getValue() + << "` must produce an `i64` wakeup time as its " + "last result"; + + if (failed(verifyTypeListEquivalence(*this, fnType.getInputs(), + getArgs().getTypes(), "operand"))) + return failure(); + if (failed(verifyTypeListEquivalence(*this, fnResults.drop_back(), + getResults().getTypes(), "result"))) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// Coroutine Terminators +//===----------------------------------------------------------------------===// + +// The three terminators all yield values back through the enclosing +// `arc.coroutine.define`'s result types. The helper below extracts the +// expected types from the parent and checks them against the given operand +// types. +static LogicalResult verifyCoroutineTerminator(Operation *op, + TypeRange yieldOperands) { + auto parent = op->getParentOfType(); + return verifyTypeListEquivalence(op, parent.getResultTypes(), yieldOperands, + "yielded value"); +} + +LogicalResult CoroutineYieldOp::verify() { + if (failed(verifyCoroutineTerminator(*this, getYieldOperands().getTypes()))) + return failure(); + + // Verify that the destination block's leading arguments match the + // coroutine's function type, and that the remaining arguments match the + // yield's destination operands. + auto parent = (*this)->getParentOfType(); + TypeRange coroutineArgTypes = parent.getArgumentTypes(); + TypeRange destArgTypes = getDest()->getArgumentTypes(); + TypeRange destOperandTypes = getDestOperands().getTypes(); + size_t expectedCount = coroutineArgTypes.size() + destOperandTypes.size(); + if (destArgTypes.size() != expectedCount) + return emitOpError("destination block has ") + << destArgTypes.size() << " arguments, but expected " + << expectedCount << " (" << coroutineArgTypes.size() + << " coroutine arguments followed " << destOperandTypes.size() + << " yield destination operands)"; + + if (failed(verifyTypeListEquivalence( + *this, coroutineArgTypes, + destArgTypes.take_front(coroutineArgTypes.size()), + "destination resume argument"))) + return failure(); + + if (failed(verifyTypeListEquivalence( + *this, destArgTypes.drop_front(coroutineArgTypes.size()), + destOperandTypes, "destination operand"))) + return failure(); + + return success(); +} + +LogicalResult CoroutineReturnOp::verify() { + return verifyCoroutineTerminator(*this, getYieldOperands().getTypes()); +} + +LogicalResult CoroutineHaltOp::verify() { + return verifyCoroutineTerminator(*this, getYieldOperands().getTypes()); +} + //===----------------------------------------------------------------------===// // ExecuteOp //===----------------------------------------------------------------------===// diff --git a/test/Dialect/Arc/basic-errors.mlir b/test/Dialect/Arc/basic-errors.mlir index 2375c3399275..b0dd34a0fcf5 100644 --- a/test/Dialect/Arc/basic-errors.mlir +++ b/test/Dialect/Arc/basic-errors.mlir @@ -624,3 +624,199 @@ arc.execute -> (i42) { // expected-note @below {{actual type: 'i19'}} arc.output %0 : i19 } + +// ----- + +func.func @Foo(%arg0: !arc.coroutine_state<@NotACoroutine>, %arg1: !arc.coroutine_pc<@NotACoroutine>) { + // expected-error @below {{`NotACoroutine` does not reference a valid `arc.coroutine.define`}} + arc.coroutine.call @NotACoroutine(%arg0, %arg1) : (!arc.coroutine_state<@NotACoroutine>, !arc.coroutine_pc<@NotACoroutine>) -> (!arc.coroutine_state<@NotACoroutine>, !arc.coroutine_pc<@NotACoroutine>) + return +} +func.func @NotACoroutine() { + return +} + +// ----- + +hw.module @Foo() { + // expected-error @below {{`NotACoroutine` does not reference a valid `arc.coroutine.define`}} + arc.coroutine.instance @NotACoroutine() : () -> () +} +func.func @NotACoroutine() { + return +} + +// ----- + +func.func @Foo(%arg0: !arc.coroutine_state<@NeedsI42>, %arg1: !arc.coroutine_pc<@NeedsI42>, %arg2: i9001) { + // expected-error @below {{operand type mismatch: operand #0}} + // expected-note @below {{expected type: 'i42'}} + // expected-note @below {{actual type: 'i9001'}} + arc.coroutine.call @NeedsI42(%arg0, %arg1, %arg2) : (!arc.coroutine_state<@NeedsI42>, !arc.coroutine_pc<@NeedsI42>, i9001) -> (!arc.coroutine_state<@NeedsI42>, !arc.coroutine_pc<@NeedsI42>) + return +} +arc.coroutine.define @NeedsI42(%arg0: i42) { + arc.coroutine.return +} + +// ----- + +hw.module @Foo(in %a: i9001) { + // expected-error @below {{operand type mismatch: operand #0}} + // expected-note @below {{expected type: 'i42'}} + // expected-note @below {{actual type: 'i9001'}} + arc.coroutine.instance @NeedsI42(%a) : (i9001) -> () +} +arc.coroutine.define @NeedsI42(%arg0: i42) -> i64 { + %c0_i64 = hw.constant 0 : i64 + arc.coroutine.return %c0_i64 : i64 +} + +// ----- + +func.func @Foo(%arg0: !arc.coroutine_state<@BarB>, %arg1: !arc.coroutine_pc<@BarB>) { + // expected-error @below {{bound to the op's callee symbol}} + arc.coroutine.call @BarA(%arg0, %arg1) : (!arc.coroutine_state<@BarB>, !arc.coroutine_pc<@BarB>) -> (!arc.coroutine_state<@BarB>, !arc.coroutine_pc<@BarB>) + return +} +arc.coroutine.define @BarA() { + arc.coroutine.return +} +arc.coroutine.define @BarB() { + arc.coroutine.return +} + +// ----- + +hw.module @Foo() { + // expected-error @below {{`DoesNotExist` does not reference a valid `arc.coroutine.define`}} + arc.coroutine.instance @DoesNotExist() : () -> () +} + +// ----- + +func.func @Foo(%arg0: !arc.coroutine_state<@DoesNotExist>, %arg1: !arc.coroutine_pc<@DoesNotExist>) { + // expected-error @below {{`DoesNotExist` does not reference a valid `arc.coroutine.define`}} + arc.coroutine.call @DoesNotExist(%arg0, %arg1) : (!arc.coroutine_state<@DoesNotExist>, !arc.coroutine_pc<@DoesNotExist>) -> (!arc.coroutine_state<@DoesNotExist>, !arc.coroutine_pc<@DoesNotExist>) + return +} + +// ----- + +hw.module @Foo() { + // expected-error @below {{referenced coroutine `Bar` must produce an `i64` wakeup time as its last result}} + arc.coroutine.instance @Bar() : () -> () +} +arc.coroutine.define @Bar() { + arc.coroutine.return +} + +// ----- + +hw.module @Foo() { + // expected-error @below {{referenced coroutine `Bar` must produce an `i64` wakeup time as its last result}} + arc.coroutine.instance @Bar() : () -> i42 +} +arc.coroutine.define @Bar() -> i42 { + %c0_i42 = hw.constant 0 : i42 + arc.coroutine.return %c0_i42 : i42 +} + +// ----- + +arc.coroutine.define @Foo() -> i42 { + // expected-error @below {{incorrect number of yielded values: expected 1, but got 0}} + arc.coroutine.return +} + +// ----- + +arc.coroutine.define @Foo() -> i42 { + %c0_i9001 = hw.constant 0 : i9001 + // expected-error @below {{yielded value type mismatch: yielded value #0}} + // expected-note @below {{expected type: 'i42'}} + // expected-note @below {{actual type: 'i9001'}} + arc.coroutine.return %c0_i9001 : i9001 +} + +// ----- + +arc.coroutine.define @Foo() -> i42 { + // expected-error @below {{incorrect number of yielded values: expected 1, but got 0}} + arc.coroutine.halt +} + +// ----- + +arc.coroutine.define @Foo() -> i42 { + %c0_i9001 = hw.constant 0 : i9001 + // expected-error @below {{yielded value type mismatch: yielded value #0}} + // expected-note @below {{expected type: 'i42'}} + // expected-note @below {{actual type: 'i9001'}} + arc.coroutine.halt %c0_i9001 : i9001 +} + +// ----- + +arc.coroutine.define @Foo() -> i42 { + // expected-error @below {{incorrect number of yielded values: expected 1, but got 0}} + arc.coroutine.yield ^bb1 +^bb1: + %c0_i42 = hw.constant 0 : i42 + arc.coroutine.halt %c0_i42 : i42 +} + +// ----- + +arc.coroutine.define @Foo() -> i42 { + %c0_i9001 = hw.constant 0 : i9001 + // expected-error @below {{yielded value type mismatch: yielded value #0}} + // expected-note @below {{expected type: 'i42'}} + // expected-note @below {{actual type: 'i9001'}} + arc.coroutine.yield (%c0_i9001 : i9001), ^bb1 +^bb1: + %c0_i42 = hw.constant 0 : i42 + arc.coroutine.halt %c0_i42 : i42 +} + +// ----- + +arc.coroutine.define @Foo(%arg0: i42) { + // expected-error @below {{destination block has 0 arguments, but expected 1}} + arc.coroutine.yield ^bb1 +^bb1: + arc.coroutine.halt +} + +// ----- + +arc.coroutine.define @Foo(%arg0: i42) { + %c0_i42 = hw.constant 0 : i42 + // expected-error @below {{destination block has 1 arguments, but expected 2}} + arc.coroutine.yield ^bb1(%c0_i42 : i42) +^bb1(%arg1: i42): + arc.coroutine.halt +} + +// ----- + +arc.coroutine.define @Foo(%arg0: i42) { + // expected-error @below {{destination resume argument type mismatch: destination resume argument #0}} + // expected-note @below {{expected type: 'i42'}} + // expected-note @below {{actual type: 'i9001'}} + arc.coroutine.yield ^bb1 +^bb1(%arg1: i9001): + arc.coroutine.halt +} + +// ----- + +arc.coroutine.define @Foo(%arg0: i42) { + %c0_i42 = hw.constant 0 : i42 + // expected-error @below {{destination operand type mismatch: destination operand #0}} + // expected-note @below {{expected type: 'i9001'}} + // expected-note @below {{actual type: 'i42'}} + arc.coroutine.yield ^bb1(%c0_i42 : i42) +^bb1(%arg1: i42, %arg2: i9001): + arc.coroutine.halt +} diff --git a/test/Dialect/Arc/basic.mlir b/test/Dialect/Arc/basic.mlir index c362fc0a510b..d0f75385bfee 100644 --- a/test/Dialect/Arc/basic.mlir +++ b/test/Dialect/Arc/basic.mlir @@ -437,3 +437,71 @@ func.func @SimGetSetTime() { return } hw.module @TimeTestModule() {} + + +// CHECK-LABEL: arc.coroutine.define @CoroutineEmpty +arc.coroutine.define @CoroutineEmpty() { + arc.coroutine.return +} + +// CHECK-LABEL: arc.coroutine.define @CoroutineNoResults +arc.coroutine.define @CoroutineNoResults(%arg0: i42) { + // CHECK: arc.coroutine.yield ^bb1 + arc.coroutine.yield ^bb1 +^bb1(%arg1: i42): + // CHECK: arc.coroutine.yield ^bb2(%arg0 : i42) + arc.coroutine.yield ^bb2(%arg0 : i42) +^bb2(%arg2: i42, %arg3: i42): + // CHECK: arc.coroutine.return + arc.coroutine.return +^bb3: + // CHECK: arc.coroutine.halt + arc.coroutine.halt +} + +// CHECK-LABEL: arc.coroutine.define @CoroutineWithResults +arc.coroutine.define @CoroutineWithResults(%arg0: i42, %arg1: i9001) -> (i42, i9001) { + // CHECK: arc.coroutine.yield (%arg0, %arg1 : i42, i9001), ^bb1 + arc.coroutine.yield (%arg0, %arg1 : i42, i9001), ^bb1 +^bb1(%arg2: i42, %arg3: i9001): + // CHECK: arc.coroutine.return %arg0, %arg1 : i42, i9001 + arc.coroutine.return %arg0, %arg1 : i42, i9001 +^bb3: + // CHECK: arc.coroutine.halt %arg0, %arg1 : i42, i9001 + arc.coroutine.halt %arg0, %arg1 : i42, i9001 +} + +// CHECK-LABEL: func.func @CoroutineCallEmpty +func.func @CoroutineCallEmpty(%arg0: !arc.coroutine_state<@CoroutineEmpty>, %arg1: !arc.coroutine_pc<@CoroutineEmpty>) { + // CHECK: arc.coroutine.call @CoroutineEmpty(%arg0, %arg1) + // CHECK-SAME: : (!arc.coroutine_state<@CoroutineEmpty>, !arc.coroutine_pc<@CoroutineEmpty>) + // CHECK-SAME: -> (!arc.coroutine_state<@CoroutineEmpty>, !arc.coroutine_pc<@CoroutineEmpty>) + %0, %1 = arc.coroutine.call @CoroutineEmpty(%arg0, %arg1) : (!arc.coroutine_state<@CoroutineEmpty>, !arc.coroutine_pc<@CoroutineEmpty>) -> (!arc.coroutine_state<@CoroutineEmpty>, !arc.coroutine_pc<@CoroutineEmpty>) + return +} + +// CHECK-LABEL: func.func @CoroutineCallWithResults +func.func @CoroutineCallWithResults( + %arg0: !arc.coroutine_state<@CoroutineWithResults>, + %arg1: !arc.coroutine_pc<@CoroutineWithResults>, + %arg2: i42, + %arg3: i9001 +) { + // CHECK: arc.coroutine.call @CoroutineWithResults(%arg0, %arg1, %arg2, %arg3) + // CHECK-SAME: : (!arc.coroutine_state<@CoroutineWithResults>, !arc.coroutine_pc<@CoroutineWithResults>, i42, i9001) + // CHECK-SAME: -> (!arc.coroutine_state<@CoroutineWithResults>, !arc.coroutine_pc<@CoroutineWithResults>, i42, i9001) + %0, %1, %2:2 = arc.coroutine.call @CoroutineWithResults(%arg0, %arg1, %arg2, %arg3) : (!arc.coroutine_state<@CoroutineWithResults>, !arc.coroutine_pc<@CoroutineWithResults>, i42, i9001) -> (!arc.coroutine_state<@CoroutineWithResults>, !arc.coroutine_pc<@CoroutineWithResults>, i42, i9001) + return +} + +// CHECK-LABEL: hw.module @CoroutineInstanceA +hw.module @CoroutineInstanceA(in %a: i42, out z: i9001) { + // CHECK: arc.coroutine.instance @CoroutineInstanceB(%a) : (i42) -> i9001 + %0 = arc.coroutine.instance @CoroutineInstanceB(%a) : (i42) -> i9001 + hw.output %0 : i9001 +} +arc.coroutine.define @CoroutineInstanceB(%arg0: i42) -> (i9001, i64) { + %c0_i9001 = hw.constant 0 : i9001 + %c0_i64 = hw.constant 0 : i64 + arc.coroutine.halt %c0_i9001, %c0_i64 : i9001, i64 +}