Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions src/execution/AsyncWorkTracker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import { isPromise } from '../jsutils/isPromise.js';
import type { PromiseOrValue } from '../jsutils/PromiseOrValue.js';

/** @internal */
export class AsyncWorkTracker {
pendingAsyncWork: Set<Promise<void>>;

constructor() {
this.pendingAsyncWork = new Set<Promise<void>>();
}

add(promise: Promise<unknown>): void {
const pendingAsyncWork = this.pendingAsyncWork;
const promiseToSettle = promise.then(
() => {
pendingAsyncWork.delete(promiseToSettle);
},
() => {
pendingAsyncWork.delete(promiseToSettle);
},
);
pendingAsyncWork.add(promiseToSettle);
}

addValues(values: ReadonlyArray<PromiseOrValue<unknown>>): void {
for (const value of values) {
if (isPromise(value)) {
this.add(value);
}
}
}

async wait(): Promise<void> {
await Promise.resolve();
return waitForPendingSet(this.pendingAsyncWork);
}

promiseAllTrackOnReject<T>(
values: ReadonlyArray<PromiseOrValue<T>>,
): Promise<Array<T>> {
const promise = Promise.all(values);
promise.then(undefined, () => {
this.addValues(values);
});
return promise;
}

promiseTrackPending<T, TResult>(
values: ReadonlyArray<PromiseOrValue<T>>,
combinator: (promises: ReadonlyArray<Promise<T>>) => Promise<TResult>,
): Promise<TResult> {
const promises = values.map((value) => Promise.resolve(value));
const settled = promises.map(() => false);

for (let index = 0; index < promises.length; index++) {
const promise = promises[index];
promise.then(
() => {
settled[index] = true;
},
() => {
settled[index] = true;
},
);
}

const trackPending = () => {
for (let index = 0; index < promises.length; index++) {
if (!settled[index]) {
this.add(promises[index]);
}
}
};

const promise = combinator(promises);
promise.then(trackPending, trackPending);
return promise;
}
}

async function waitForPendingSet(
pendingPromises: ReadonlySet<Promise<void>>,
): Promise<void> {
while (pendingPromises.size > 0) {
// eslint-disable-next-line no-await-in-loop
await Promise.allSettled(Array.from(pendingPromises));
}
}
64 changes: 38 additions & 26 deletions src/execution/Executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { addPath, pathToArray } from '../jsutils/Path.js';
import { promiseForObject } from '../jsutils/promiseForObject.js';
import type { PromiseOrValue } from '../jsutils/PromiseOrValue.js';
import { promiseReduce } from '../jsutils/promiseReduce.js';
import { toError } from '../jsutils/toError.js';

import { ensureGraphQLError } from '../error/ensureGraphQLError.js';
import type { GraphQLFormattedError } from '../error/GraphQLError.js';
Expand All @@ -32,6 +33,7 @@ import type {
GraphQLObjectType,
GraphQLOutputType,
GraphQLResolveInfo,
GraphQLResolveInfoHelpers,
GraphQLTypeResolver,
} from '../type/definition.js';
import {
Expand Down Expand Up @@ -231,6 +233,11 @@ export class Executor<
abortResultPromise: ((reason?: unknown) => void) | undefined;
resolverAbortController: AbortController | undefined;
getAbortSignal: () => AbortSignal | undefined;
getAsyncHelpers: () => GraphQLResolveInfoHelpers;
trackPromise: (cleanup: Promise<unknown>) => void;
promiseAll: <T>(
values: ReadonlyArray<PromiseOrValue<T>>,
) => Promise<Array<T>>;

constructor(
validatedExecutionArgs: ValidatedExecutionArgs,
Expand All @@ -249,8 +256,12 @@ export class Executor<
} else {
this.sharedExecutionContext = sharedExecutionContext;
}
const { getAbortSignal } = this.sharedExecutionContext;
const { getAbortSignal, getAsyncHelpers, getRegisterCleanup, promiseAll } =
this.sharedExecutionContext;
this.getAbortSignal = getAbortSignal;
this.getAsyncHelpers = getAsyncHelpers;
this.trackPromise = getRegisterCleanup();
this.promiseAll = promiseAll;
}

executeQueryOrMutationOrSubscriptionEvent(): PromiseOrValue<
Expand All @@ -261,10 +272,7 @@ export class Executor<
if (externalAbortSignal) {
externalAbortSignal.throwIfAborted();
const onExternalAbort = () => {
const aborted = this.abort(externalAbortSignal.reason);
if (isPromise(aborted)) {
aborted.catch(() => undefined);
}
this.abort(externalAbortSignal.reason);
};
removeExternalAbortListener = () =>
externalAbortSignal.removeEventListener('abort', onExternalAbort);
Expand Down Expand Up @@ -326,6 +334,7 @@ export class Executor<
return this.buildResponse(null);
},
);
this.sharedExecutionContext.asyncWorkTracker.add(promise);
const { promise: cancellablePromise, abort: abortResultPromise } =
withCancellation(promise);
this.abortResultPromise = abortResultPromise;
Expand All @@ -349,7 +358,7 @@ export class Executor<
}
}

abort(reason?: unknown): PromiseOrValue<void> {
abort(reason?: unknown): void {
if (this.aborted) {
return;
}
Expand Down Expand Up @@ -508,10 +517,12 @@ export class Executor<
}
} catch (error) {
if (containsPromise) {
// Ensure that any promises returned by other fields are handled, as they may also reject.
return promiseForObject(results).finally(() => {
throw error;
}) as never;
this.sharedExecutionContext.asyncWorkTracker.addValues(
Object.values(results),
);
if (this.aborted) {
return Promise.reject(toError(error));
}
}
throw error;
}
Expand All @@ -524,7 +535,7 @@ export class Executor<
// Otherwise, results is a map from field name to the result of resolving that
// field, which is possibly a promise. Return a promise that will return this
// same map, but with any promises replaced with the values they resolved to.
return promiseForObject(results);
return promiseForObject(results, this.promiseAll);
}

/**
Expand Down Expand Up @@ -561,6 +572,7 @@ export class Executor<
parentType,
path,
this.getAbortSignal,
this.getAsyncHelpers,
);

// Get the resolve function, regardless of if its result is normal or abrupt (error).
Expand Down Expand Up @@ -857,26 +869,26 @@ export class Executor<
index++;
}
} catch (error) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
returnIteratorCatchingErrors(asyncIterator);
this.trackPromise(returnIteratorCatchingErrors(asyncIterator));
if (containsPromise) {
return Promise.all(completedResults).finally(() => {
throw error;
});
this.sharedExecutionContext.asyncWorkTracker.addValues(
completedResults,
);
}
throw error;
}

// Throwing on completion outside of the loop may allow engines to better optimize
if (this.aborted) {
if (!iteration?.done) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
returnIteratorCatchingErrors(asyncIterator);
this.trackPromise(returnIteratorCatchingErrors(asyncIterator));
}
throw new Error('Aborted!');
}

return containsPromise ? Promise.all(completedResults) : completedResults;
return containsPromise
? this.promiseAll(completedResults)
: completedResults;
}

/* c8 ignore next 12 */
Expand Down Expand Up @@ -997,17 +1009,17 @@ export class Executor<
index++;
}
} catch (error) {
const maybePromises = containsPromise ? completedResults : [];
maybePromises.push(...collectIteratorPromises(iterator));
if (maybePromises.length) {
return Promise.all(maybePromises).finally(() => {
throw error;
});
const asyncWorkTracker = this.sharedExecutionContext.asyncWorkTracker;
if (containsPromise) {
asyncWorkTracker.addValues(completedResults);
}
asyncWorkTracker.addValues(collectIteratorPromises(iterator));
throw error;
}

return containsPromise ? Promise.all(completedResults) : completedResults;
return containsPromise
? this.promiseAll(completedResults)
: completedResults;
}

completeMaybePromisedListItemValue(
Expand Down
Loading
Loading