Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions src/core/Microsoft.Dynamic/Ast/AsyncEnumerableExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

#nullable enable

#if NET

using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

using Microsoft.Scripting.Utils;

namespace Microsoft.Scripting.Ast {
/// <summary>
/// Wraps an async-generator body (one that contains both <see cref="AwaitExpression"/> nodes and
/// <see cref="YieldExpression"/> nodes targeting <see cref="YieldLabel"/>) into an expression that
/// evaluates to <see cref="IAsyncEnumerable{T}"/> of <see cref="object"/>.
/// </summary>
/// <remarks>
/// Await points are rewritten to <c>yield AwaitPoint(task)</c> against the <em>same</em> label as the
/// language-level <c>yield</c>s, so a single <c>GeneratorRewriter</c>-produced
/// <see cref="IEnumerator{T}"/> carries both kinds of items.
/// <see cref="Microsoft.Scripting.Runtime.AsyncHelpers.DriveAsyncEnumerable"/> then awaits
/// <see cref="Microsoft.Scripting.Runtime.AwaitPoint"/> items internally and emits the rest to the
/// consumer. This marker is what lets <c>await</c> and <c>yield</c> coexist: a yielded Task is not an
/// AwaitPoint, so it is surfaced as a value rather than awaited.
/// </remarks>
public sealed class AsyncEnumerableExpression : Expression {
private Expression? _reduced;

internal AsyncEnumerableExpression(string? name, Expression body, LabelTarget yieldLabel,
Expression? cancellationToken = null,
Expression? cancellationException = null) {
Name = name;
Body = body;
YieldLabel = yieldLabel;
CancellationToken = cancellationToken ?? Expression.Default(typeof(CancellationToken));

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guess it doesn't matter either way (just noting the difference), but in AsyncExpression you defined DefaultCancellationToken and DefaultCancellationException.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I didn't notice it. I prefer the reusable statics, since it is slightly better performance at a minuscule startup cost and this is the style many IronPython expressions employ. To keep it DRY and not to incur more startup cost than necessary, I moved it to Utils. Also some more shared stuff, keeping all that internal.

CancellationException = cancellationException ?? Expression.Constant(null, typeof(StrongBox<Exception?>));
}

/// <summary>Optional diagnostic name (forwarded to the inner generator).</summary>
public string? Name { get; }

/// <summary>The generator body. May contain <see cref="AwaitExpression"/> and <see cref="YieldExpression"/> nodes.</summary>
public Expression Body { get; }

/// <summary>
/// The label both the language-level <c>yield</c>s and the rewritten <c>await</c>s target, so they
/// land in one generator. Supplied by the host (e.g. IronPython's shared generator label).
/// </summary>
public LabelTarget YieldLabel { get; }

/// <summary>Expression evaluating to the cancellation token (see <see cref="AsyncExpression"/>). Default <c>default(CancellationToken)</c>.</summary>
public Expression CancellationToken { get; }

/// <summary>Expression evaluating to a <c>StrongBox&lt;Exception?&gt;</c> exception override (or null). Default null.</summary>
public Expression CancellationException { get; }

public override bool CanReduce => true;

public override Type Type => typeof(IAsyncEnumerable<object?>);

public override ExpressionType NodeType => ExpressionType.Extension;

public override Expression Reduce() {
return _reduced ??= new AsyncEnumerableRewriter(this).Reduce();
}

protected override Expression VisitChildren(ExpressionVisitor visitor) {
Expression b = visitor.Visit(Body);
Expression ct = visitor.Visit(CancellationToken);
Expression ce = visitor.Visit(CancellationException);
if (b == Body && ct == CancellationToken && ce == CancellationException) return this;
return new AsyncEnumerableExpression(Name, b, YieldLabel, ct, ce);
}
}

public partial class Utils {
/// <summary>
/// Wraps an async-generator body in an <see cref="AsyncEnumerableExpression"/> producing <c>IAsyncEnumerable&lt;object&gt;</c>.
/// </summary>
/// <param name="yieldLabel">
/// It must be the same label the body's language-level <c>yield</c>s target.
/// </param>
public static AsyncEnumerableExpression AsyncEnumerable(string? name, Expression body, LabelTarget yieldLabel) {
ContractUtils.RequiresNotNull(body, nameof(body));
ContractUtils.RequiresNotNull(yieldLabel, nameof(yieldLabel));
return new AsyncEnumerableExpression(name, body, yieldLabel);
}

/// <summary>
/// Wraps an async-generator body in an <see cref="AsyncEnumerableExpression"/> producing <c>IAsyncEnumerable&lt;object&gt;</c>,
/// with a caller-provided <see cref="System.Threading.CancellationToken"/> and, optionally, an exception-override box.
/// </summary>
/// <remarks>
/// When cancellation fires and the box's <c>Value</c> is non-null, that exception is delivered to
/// the body instead of a fresh <see cref="System.OperationCanceledException"/>. This lets a host inject
/// an arbitrary exception (e.g. Python's <c>coro.throw(exc)</c>) by populating the box and then
/// cancelling the token. <paramref name="cancellationException"/> defaults to <c>null</c>
/// — the plain OCE-on-cancellation behavior.
/// </remarks>
/// <param name="yieldLabel">
/// It must be the same label the body's language-level <c>yield</c>s target.
/// </param>
public static AsyncEnumerableExpression AsyncEnumerable(string? name, Expression body, LabelTarget yieldLabel,
Expression cancellationToken,
Expression? cancellationException = null) {
ContractUtils.RequiresNotNull(body, nameof(body));
ContractUtils.RequiresNotNull(yieldLabel, nameof(yieldLabel));
RequireType(cancellationToken, typeof(CancellationToken), nameof(cancellationToken));
if (cancellationException is not null) {
RequireType(cancellationException, typeof(StrongBox<Exception?>), nameof(cancellationException));
}
return new AsyncEnumerableExpression(name, body, yieldLabel, cancellationToken, cancellationException);
}
}
}

#endif
134 changes: 134 additions & 0 deletions src/core/Microsoft.Dynamic/Ast/AsyncEnumerableRewriter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information.

#nullable enable

#if NET

using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading.Tasks;

namespace Microsoft.Scripting.Ast {
/// <summary>
/// Reduces an <see cref="AsyncEnumerableExpression"/> to an <c>IAsyncEnumerable&lt;object&gt;</c>-valued
/// expression tree that yields each <see cref="AwaitExpression"/>'s operand (wrapped in an
/// <see cref="Microsoft.Scripting.Runtime.AwaitPoint"/>) alongside language-level yields, and hands
/// the resulting state machine to <see cref="Microsoft.Scripting.Runtime.AsyncHelpers.DriveAsyncEnumerable"/>.
/// </summary>
internal sealed class AsyncEnumerableRewriter {
private static readonly MethodInfo s_driveMethod
= typeof(Microsoft.Scripting.Runtime.AsyncHelpers).GetMethod("DriveAsyncEnumerable")!;
private static readonly ConstructorInfo s_awaitPointCtor
= typeof(Microsoft.Scripting.Runtime.AwaitPoint).GetConstructor([typeof(Task)])!;
private static readonly FieldInfo s_valueSlotField
= typeof(StrongBox<object?>).GetField(nameof(StrongBox<object?>.Value))!;
private static readonly FieldInfo s_exceptionSlotField
= typeof(StrongBox<Exception?>).GetField(nameof(StrongBox<Exception?>.Value))!;
private static readonly ConstructorInfo s_valueSlotCtor
= typeof(StrongBox<object?>).GetConstructor(Type.EmptyTypes)!;
private static readonly ConstructorInfo s_exceptionSlotCtor
= typeof(StrongBox<Exception?>).GetConstructor(Type.EmptyTypes)!;

private readonly AsyncEnumerableExpression _node;

public AsyncEnumerableRewriter(AsyncEnumerableExpression node) {
_node = node;
}

public Expression Reduce() {
// valueSlot / exceptionSlot carry the per-await result / fault back into the body at each
// await's resume point (same role as in AsyncRewriter). The generator's final value is
// irrelevant — generators don't return a value — so there is no capture step here.
ParameterExpression valueSlot = Expression.Variable(typeof(StrongBox<object?>), "$asyncValue");
ParameterExpression exceptionSlot = Expression.Variable(typeof(StrongBox<Exception?>), "$awaitException");

var rewriter = new AwaitToAwaitPointRewriter(_node.YieldLabel, valueSlot, exceptionSlot);
Expression rewrittenBody = rewriter.Visit(_node.Body);

// Coerce to void for Utils.Generator (the generator body's value is discarded).
Expression generatorBody = rewrittenBody.Type == typeof(void)
? rewrittenBody
: Expression.Block(typeof(void), rewrittenBody);

Expression generator = Utils.Generator(
_node.Name ?? "$asyncgen",
_node.YieldLabel,
generatorBody,
typeof(IEnumerator<object>),
rewriteAssignments: false);

// Argument order matches DriveAsyncEnumerable: ..., cancellationToken, cancellationException
// (same as DriveAsync — cancellationToken is the last required parameter).
Expression drive = Expression.Call(
s_driveMethod,
generator,
valueSlot,
exceptionSlot,
_node.CancellationToken,
_node.CancellationException);

return Expression.Block(
typeof(IAsyncEnumerable<object?>),
[valueSlot, exceptionSlot],
Expression.Assign(valueSlot, Expression.New(s_valueSlotCtor)),
Expression.Assign(exceptionSlot, Expression.New(s_exceptionSlotCtor)),
drive);
}

/// <summary>
/// Rewrites <c>AwaitExpression(task)</c> → <c>{ yield AwaitPoint(task); rethrow-if-pending; valueSlot.Value }</c>,
/// targeting the shared yield label. Mirrors <c>AsyncRewriter.AwaitToYieldRewriter</c> but wraps the awaited
/// Task in an <see cref="Microsoft.Scripting.Runtime.AwaitPoint"/> so the driver distinguishes it from a
/// value yielded by a language-level <c>yield</c>.
/// </summary>
private sealed class AwaitToAwaitPointRewriter : ExpressionVisitor {
private static readonly MethodInfo s_captureMethod
= typeof(ExceptionDispatchInfo).GetMethod(nameof(ExceptionDispatchInfo.Capture))!;
private static readonly MethodInfo s_throwMethod
= typeof(ExceptionDispatchInfo).GetMethod(nameof(ExceptionDispatchInfo.Throw), Type.EmptyTypes)!;

private readonly LabelTarget _yieldLabel;
private readonly ParameterExpression _valueSlot;
private readonly ParameterExpression _exceptionSlot;

public AwaitToAwaitPointRewriter(LabelTarget yieldLabel, ParameterExpression valueSlot, ParameterExpression exceptionSlot) {
_yieldLabel = yieldLabel;
_valueSlot = valueSlot;
_exceptionSlot = exceptionSlot;
}

protected override Expression VisitExtension(Expression node) {
if (node is AwaitExpression aw) {
Expression operand = Visit(aw.Operand);
// Wrap the awaited Task in an AwaitPoint marker, then box to object for the yield.
Expression awaitPoint = Expression.New(s_awaitPointCtor, Expression.Convert(operand, typeof(Task)));
Expression yielded = Expression.Convert(awaitPoint, typeof(object));

Expression readException = Expression.Field(_exceptionSlot, s_exceptionSlotField);
Expression readSlot = Expression.Field(_valueSlot, s_valueSlotField);

Expression rethrow = Expression.IfThen(
Expression.ReferenceNotEqual(readException, Expression.Constant(null, typeof(Exception))),
Expression.Call(
Expression.Call(s_captureMethod, readException),
s_throwMethod));

return Expression.Block(
typeof(object),
Utils.YieldReturn(_yieldLabel, yielded),
rethrow,
readSlot);
}
return base.VisitExtension(node);
}
}
}
}

#endif
Loading
Loading