From 1e9a455470295c521628793ef7a75eb318481324 Mon Sep 17 00:00:00 2001 From: Yaacov Rydzinski Date: Thu, 9 Apr 2026 17:14:30 +0300 Subject: [PATCH] feat(execution): track async work for rejection safety canonize async work tracking so promises started during execution are still observed after early errors or abort paths add `trackPromise` to resolve info async helpers so resolver-adjacent code can register fire-and-forget work for rejection safety in defaultTypeResolver, use `trackPromise(...)` for started `isTypeOf` promises in synchronous early-return and synchronous-throw paths, while still awaiting all results in the async path --- src/execution/AsyncWorkTracker.ts | 42 +++++++ src/execution/Executor.ts | 54 +++++---- .../__tests__/AsyncWorkTracker-test.ts | 105 ++++++++++++++++++ src/execution/__tests__/executor-test.ts | 10 ++ src/execution/createSharedExecutionContext.ts | 32 ++++++ src/execution/execute.ts | 23 ++-- .../incremental/IncrementalExecutor.ts | 64 ++++------- src/index.ts | 1 + src/jsutils/promiseForObject.ts | 21 ++-- src/type/definition.ts | 5 + src/type/index.ts | 1 + 11 files changed, 277 insertions(+), 81 deletions(-) create mode 100644 src/execution/AsyncWorkTracker.ts create mode 100644 src/execution/__tests__/AsyncWorkTracker-test.ts diff --git a/src/execution/AsyncWorkTracker.ts b/src/execution/AsyncWorkTracker.ts new file mode 100644 index 0000000000..c2c2e83e89 --- /dev/null +++ b/src/execution/AsyncWorkTracker.ts @@ -0,0 +1,42 @@ +import { isPromise } from '../jsutils/isPromise.js'; +import type { PromiseOrValue } from '../jsutils/PromiseOrValue.js'; + +/** @internal */ +export class AsyncWorkTracker { + pendingAsyncWork: Set>; + + constructor() { + this.pendingAsyncWork = new Set>(); + } + + add(promise: Promise): void { + const pendingAsyncWork = this.pendingAsyncWork; + const promiseToSettle = promise.then( + () => { + pendingAsyncWork.delete(promiseToSettle); + }, + () => { + pendingAsyncWork.delete(promiseToSettle); + }, + ); + pendingAsyncWork.add(promiseToSettle); + } + + addValues(values: ReadonlyArray>): void { + for (const value of values) { + if (isPromise(value)) { + this.add(value); + } + } + } + + promiseAllTrackOnReject( + values: ReadonlyArray>, + ): Promise> { + const promise = Promise.all(values); + promise.then(undefined, () => { + this.addValues(values); + }); + return promise; + } +} diff --git a/src/execution/Executor.ts b/src/execution/Executor.ts index 3be5027579..e0b8fda607 100644 --- a/src/execution/Executor.ts +++ b/src/execution/Executor.ts @@ -32,6 +32,7 @@ import type { GraphQLObjectType, GraphQLOutputType, GraphQLResolveInfo, + GraphQLResolveInfoHelpers, GraphQLTypeResolver, } from '../type/definition.js'; import { @@ -231,6 +232,11 @@ export class Executor< abortResultPromise: ((reason?: unknown) => void) | undefined; resolverAbortController: AbortController | undefined; getAbortSignal: () => AbortSignal | undefined; + getAsyncHelpers: () => GraphQLResolveInfoHelpers; + promiseAll: ( + values: ReadonlyArray>, + ) => Promise>; + trackPromise: (promise: Promise) => void; constructor( validatedExecutionArgs: ValidatedExecutionArgs, @@ -249,8 +255,12 @@ export class Executor< } else { this.sharedExecutionContext = sharedExecutionContext; } - const { getAbortSignal } = this.sharedExecutionContext; + const { getAbortSignal, getAsyncHelpers, promiseAll, trackPromise } = + this.sharedExecutionContext; this.getAbortSignal = getAbortSignal; + this.getAsyncHelpers = getAsyncHelpers; + this.promiseAll = promiseAll; + this.trackPromise = trackPromise; } executeQueryOrMutationOrSubscriptionEvent(): PromiseOrValue< @@ -261,10 +271,7 @@ export class Executor< if (externalAbortSignal) { externalAbortSignal.throwIfAborted(); const onExternalAbort = () => { - const aborted = this.abort(externalAbortSignal.reason); - if (isPromise(aborted)) { - aborted.catch(() => undefined); - } + this.abort(externalAbortSignal.reason); }; removeExternalAbortListener = () => externalAbortSignal.removeEventListener('abort', onExternalAbort); @@ -324,6 +331,7 @@ export class Executor< return this.buildResponse(null); }, ); + this.sharedExecutionContext.asyncWorkTracker.add(promise); const { promise: cancellablePromise, abort: abortResultPromise } = withCancellation(promise); this.abortResultPromise = abortResultPromise; @@ -347,7 +355,7 @@ export class Executor< } } - abort(reason?: unknown): PromiseOrValue { + abort(reason?: unknown): void { if (this.aborted) { return; } @@ -506,8 +514,9 @@ export class Executor< } } catch (error) { if (containsPromise) { - // Ensure that any promises returned by other fields are handled, as they may also reject. - promiseForObject(results).catch(() => undefined); + this.sharedExecutionContext.asyncWorkTracker.addValues( + Object.values(results), + ); } throw error; } @@ -520,7 +529,7 @@ export class Executor< // Otherwise, results is a map from field name to the result of resolving that // field, which is possibly a promise. Return a promise that will return this // same map, but with any promises replaced with the values they resolved to. - return promiseForObject(results); + return promiseForObject(results, this.promiseAll); } /** @@ -557,6 +566,7 @@ export class Executor< parentType, path, this.getAbortSignal, + this.getAsyncHelpers, ); // Get the resolve function, regardless of if its result is normal or abrupt (error). @@ -853,10 +863,11 @@ export class Executor< index++; } } catch (error) { - // eslint-disable-next-line @typescript-eslint/no-floating-promises - returnIteratorCatchingErrors(asyncIterator); + this.trackPromise(returnIteratorCatchingErrors(asyncIterator)); if (containsPromise) { - Promise.all(completedResults).catch(() => undefined); + this.sharedExecutionContext.asyncWorkTracker.addValues( + completedResults, + ); } throw error; } @@ -864,13 +875,14 @@ export class Executor< // Throwing on completion outside of the loop may allow engines to better optimize if (this.aborted) { if (!iteration?.done) { - // eslint-disable-next-line @typescript-eslint/no-floating-promises - returnIteratorCatchingErrors(asyncIterator); + this.trackPromise(returnIteratorCatchingErrors(asyncIterator)); } throw new Error('Aborted!'); } - return containsPromise ? Promise.all(completedResults) : completedResults; + return containsPromise + ? this.promiseAll(completedResults) + : completedResults; } /* c8 ignore next 12 */ @@ -991,15 +1003,17 @@ export class Executor< index++; } } catch (error) { - const maybePromises = containsPromise ? completedResults : []; - maybePromises.push(...collectIteratorPromises(iterator)); - if (maybePromises.length) { - Promise.all(maybePromises).catch(() => undefined); + const asyncWorkTracker = this.sharedExecutionContext.asyncWorkTracker; + if (containsPromise) { + asyncWorkTracker.addValues(completedResults); } + asyncWorkTracker.addValues(collectIteratorPromises(iterator)); throw error; } - return containsPromise ? Promise.all(completedResults) : completedResults; + return containsPromise + ? this.promiseAll(completedResults) + : completedResults; } completeMaybePromisedListItemValue( diff --git a/src/execution/__tests__/AsyncWorkTracker-test.ts b/src/execution/__tests__/AsyncWorkTracker-test.ts new file mode 100644 index 0000000000..05f408ac4d --- /dev/null +++ b/src/execution/__tests__/AsyncWorkTracker-test.ts @@ -0,0 +1,105 @@ +import { expect } from 'chai'; +import { describe, it } from 'mocha'; + +import { expectEqualPromisesOrValues } from '../../__testUtils__/expectEqualPromisesOrValues.js'; +import { resolveOnNextTick } from '../../__testUtils__/resolveOnNextTick.js'; + +import { promiseWithResolvers } from '../../jsutils/promiseWithResolvers.js'; + +import { AsyncWorkTracker } from '../AsyncWorkTracker.js'; + +describe('AsyncWorkTracker', () => { + it('works to track promises', async () => { + const tracker = new AsyncWorkTracker(); + const delayed = promiseWithResolvers(); + + tracker.add(delayed.promise); + expect(tracker.pendingAsyncWork.size).to.equal(1); + delayed.resolve(1); + await resolveOnNextTick(); + expect(tracker.pendingAsyncWork.size).to.equal(0); + }); +}); + +describe('promiseAllTrackOnReject', () => { + it('resolves like Promise.all', async () => { + const tracker = new AsyncWorkTracker(); + + const values = [Promise.resolve(1), Promise.resolve(2), Promise.resolve(3)]; + + await expectEqualPromisesOrValues([ + tracker.promiseAllTrackOnReject(values), + Promise.all(values), + ]); + }); + + it('resolves synchronous values without tracking', async () => { + const tracker = new AsyncWorkTracker(); + + const result = await tracker.promiseAllTrackOnReject([1, 2, 3]); + + expect(result).to.deep.equal([1, 2, 3]); + expect(tracker.pendingAsyncWork.size).to.equal(0); + }); + + it('does not add an extra microtask on fulfilled promiseAll results', async () => { + const tracker = new AsyncWorkTracker(); + let settled = false; + + const promise = Promise.resolve(1); + const trackedPromise = tracker.promiseAllTrackOnReject([promise]); + trackedPromise.then( + () => { + settled = true; + }, + () => undefined, + ); + await Promise.all([promise]); + expect(settled).to.equal(true); + }); + + it('tracks all promises only after rejection', async () => { + const delayed = promiseWithResolvers(); + const tracker = new AsyncWorkTracker(); + const result = tracker.promiseAllTrackOnReject([ + Promise.reject(new Error('bad')), + delayed.promise, + ] as const); + expect(tracker.pendingAsyncWork.size).to.equal(0); + + await result.catch(() => undefined); + expect(tracker.pendingAsyncWork.size).to.equal(1); + delayed.resolve(undefined); + + await resolveOnNextTick(); + expect(tracker.pendingAsyncWork.size).to.equal(0); + }); + + it('tracks promises until they settle and catches later rejections', async () => { + let unhandledRejection: unknown = null; + const unhandledRejectionListener = (reason: unknown) => { + unhandledRejection = reason; + }; + // eslint-disable-next-line no-undef + process.on('unhandledRejection', unhandledRejectionListener); + + const tracker = new AsyncWorkTracker(); + const delayed = promiseWithResolvers(); + const result = tracker.promiseAllTrackOnReject([ + Promise.reject(new Error('bad')), + delayed.promise, + ] as const); + + await result.catch(() => undefined); + expect(tracker.pendingAsyncWork.size).to.equal(1); + + delayed.reject(new Error('late bad')); + await new Promise((resolve) => setTimeout(resolve, 20)); + + // eslint-disable-next-line no-undef + process.removeListener('unhandledRejection', unhandledRejectionListener); + + expect(tracker.pendingAsyncWork.size).to.equal(0); + expect(unhandledRejection).to.equal(null); + }); +}); diff --git a/src/execution/__tests__/executor-test.ts b/src/execution/__tests__/executor-test.ts index c05b952132..1f540e3b61 100644 --- a/src/execution/__tests__/executor-test.ts +++ b/src/execution/__tests__/executor-test.ts @@ -258,7 +258,10 @@ describe('Execute: Handles basic execution tasks', () => { 'operation', 'variableValues', 'getAbortSignal', + 'getAsyncHelpers', ); + const asyncHelpers = resolvedInfo?.getAsyncHelpers(); + expect(asyncHelpers).to.have.all.keys('trackPromise'); const operation = document.definitions[0]; assert(operation.kind === Kind.OPERATION_DEFINITION); @@ -295,6 +298,13 @@ describe('Execute: Handles basic execution tasks', () => { expect(abortSignal).to.be.instanceOf(AbortSignal); expect(resolvedInfo?.getAbortSignal()).to.equal(abortSignal); + expect(resolvedInfo?.getAsyncHelpers()).to.equal(asyncHelpers); + + const trackPromise = asyncHelpers?.trackPromise; + expect(trackPromise).to.be.a('function'); + expect(resolvedInfo?.getAsyncHelpers().trackPromise).to.equal(trackPromise); + trackPromise?.(Promise.resolve()); + resolve(); await result; diff --git a/src/execution/createSharedExecutionContext.ts b/src/execution/createSharedExecutionContext.ts index d2eadbd57c..381fbd3917 100644 --- a/src/execution/createSharedExecutionContext.ts +++ b/src/execution/createSharedExecutionContext.ts @@ -1,12 +1,44 @@ +import type { PromiseOrValue } from '../jsutils/PromiseOrValue.js'; + +import type { GraphQLResolveInfoHelpers } from '../type/index.js'; + +import { AsyncWorkTracker } from './AsyncWorkTracker.js'; + /** @internal */ export interface SharedExecutionContext { + asyncWorkTracker: AsyncWorkTracker; getAbortSignal: () => AbortSignal | undefined; + getAsyncHelpers: () => GraphQLResolveInfoHelpers; + promiseAll: ( + values: ReadonlyArray>, + ) => Promise>; + trackPromise: (promise: Promise) => void; } export function createSharedExecutionContext( abortSignal: AbortSignal | undefined, ): SharedExecutionContext { + const asyncWorkTracker = new AsyncWorkTracker(); + let resolveInfoHelpers: GraphQLResolveInfoHelpers | undefined; + + const promiseAll = ( + values: ReadonlyArray>, + ): Promise> => asyncWorkTracker.promiseAllTrackOnReject(values); + + const trackPromise = (promise: Promise): void => { + asyncWorkTracker.add(promise); + }; + + const getAsyncHelpers = (): GraphQLResolveInfoHelpers => + (resolveInfoHelpers ??= { + trackPromise, + }); + return { + asyncWorkTracker, getAbortSignal: () => abortSignal, + getAsyncHelpers, + promiseAll, + trackPromise, }; } diff --git a/src/execution/execute.ts b/src/execution/execute.ts index 1272749c37..181be66433 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -26,6 +26,7 @@ import type { GraphQLFieldResolver, GraphQLObjectType, GraphQLResolveInfo, + GraphQLResolveInfoHelpers, GraphQLTypeResolver, } from '../type/index.js'; import { assertValidSchema } from '../type/index.js'; @@ -440,7 +441,7 @@ export const defaultTypeResolver: GraphQLTypeResolver = // Otherwise, test each possible type. const possibleTypes = info.schema.getPossibleTypes(abstractType); - const promisedIsTypeOfResults = []; + const promisedIsTypeOfResults: Array> = []; try { for (let i = 0; i < possibleTypes.length; i++) { @@ -453,12 +454,10 @@ export const defaultTypeResolver: GraphQLTypeResolver = promisedIsTypeOfResults[i] = isTypeOfResult; } else if (isTypeOfResult) { if (promisedIsTypeOfResults.length) { - // Explicitly ignore any promise rejections - Promise.allSettled(promisedIsTypeOfResults) - /* c8 ignore next 3 */ - .catch(() => { - // Do nothing - }); + const { trackPromise } = info.getAsyncHelpers(); + for (const promisedIsTypeOfResult of promisedIsTypeOfResults) { + trackPromise(promisedIsTypeOfResult); + } } return type.name; } @@ -466,9 +465,10 @@ export const defaultTypeResolver: GraphQLTypeResolver = } } catch (error) { if (promisedIsTypeOfResults.length) { - return Promise.allSettled(promisedIsTypeOfResults).then(() => { - throw error; - }); + const { trackPromise } = info.getAsyncHelpers(); + for (const promisedIsTypeOfResult of promisedIsTypeOfResults) { + trackPromise(promisedIsTypeOfResult); + } } throw error; } @@ -609,6 +609,7 @@ function executeSubscription( rootType, path, sharedExecutionContext.getAbortSignal, + sharedExecutionContext.getAsyncHelpers, ); try { @@ -683,6 +684,7 @@ export function buildResolveInfo( parentType: GraphQLObjectType, path: Path, getAbortSignal: () => AbortSignal | undefined, + getAsyncHelpers: () => GraphQLResolveInfoHelpers, ): GraphQLResolveInfo { const { schema, fragmentDefinitions, rootValue, operation, variableValues } = validatedExecutionArgs; @@ -700,6 +702,7 @@ export function buildResolveInfo( operation, variableValues, getAbortSignal, + getAsyncHelpers, }; } diff --git a/src/execution/incremental/IncrementalExecutor.ts b/src/execution/incremental/IncrementalExecutor.ts index c4f64baf9b..46b301d042 100644 --- a/src/execution/incremental/IncrementalExecutor.ts +++ b/src/execution/incremental/IncrementalExecutor.ts @@ -320,25 +320,15 @@ export class IncrementalExecutor< ); } - override abort(reason?: unknown): PromiseOrValue { - const abortPromises: Array> = []; - const superAborted = super.abort(reason); - // Executor.abort is currently synchronous - invariant(!isPromise(superAborted)); + override abort(reason?: unknown): void { + super.abort(reason); for (const task of this.tasks) { const aborted = task.computation.abort(reason); - if (isPromise(aborted)) { - abortPromises.push(aborted); - } + invariant(!isPromise(aborted)); } for (const stream of this.streams) { const aborted = stream.queue.abort(reason); - if (isPromise(aborted)) { - abortPromises.push(aborted); - } - } - if (abortPromises.length > 0) { - return Promise.allSettled(abortPromises).then(() => undefined); + invariant(!isPromise(aborted)); } } @@ -555,7 +545,7 @@ export class IncrementalExecutor< deliveryGroupMap, ); } catch (error) { - ignoreAbortCleanup(this.abort()); + this.abort(); throw error; } @@ -564,7 +554,7 @@ export class IncrementalExecutor< (resolved) => this.buildExecutionGroupResult(deliveryGroups, path, resolved), (error: unknown) => { - ignoreAbortCleanup(this.abort()); + this.abort(); throw error; }, ); @@ -603,7 +593,8 @@ export class IncrementalExecutor< const filteredTasks: Array = []; for (const task of tasks) { if (collectedErrors.hasNulledPosition(task.path)) { - ignoreAbortCleanup(task.computation.abort(cancellationReason)); + const aborted = task.computation.abort(cancellationReason); + invariant(!isPromise(aborted)); } else { filteredTasks.push(task); } @@ -612,7 +603,8 @@ export class IncrementalExecutor< const filteredStreams: Array = []; for (const stream of streams) { if (collectedErrors.hasNulledPosition(stream.path)) { - ignoreAbortCleanup(stream.queue.abort(cancellationReason)); + const aborted = stream.queue.abort(cancellationReason); + invariant(!isPromise(aborted)); } else { filteredStreams.push(stream); } @@ -732,35 +724,27 @@ export class IncrementalExecutor< const { enableEarlyExecution } = this.validatedExecutionArgs; const queue = new Queue( async ({ push, stop, onStop, started }) => { - const abortStreamItems = new Set< - (reason?: unknown) => PromiseOrValue - >(); + const abortStreamItems = new Set<(reason?: unknown) => void>(); let finishedNormally = false; let stopRequested = false; onStop((reason) => { stopRequested = true; if (!finishedNormally) { - const abortPromises: Array> = []; for (const abortStreamItem of abortStreamItems) { - const result = abortStreamItem(reason); - if (isPromise(result)) { - abortPromises.push(result); - } + abortStreamItem(reason); } if (isAsync) { - const returned = returnIteratorCatchingErrors( - iterator as AsyncIterator, + this.sharedExecutionContext.trackPromise( + returnIteratorCatchingErrors( + iterator as AsyncIterator, + ), ); - abortPromises.push(returned); } else { - abortPromises.push( - ...collectIteratorPromises(iterator as Iterator), + this.sharedExecutionContext.asyncWorkTracker.addValues( + collectIteratorPromises(iterator as Iterator), ); } - if (abortPromises.length > 0) { - return Promise.allSettled(abortPromises).then(() => undefined); - } } }); await (enableEarlyExecution ? Promise.resolve() : started); @@ -872,7 +856,7 @@ export class IncrementalExecutor< }, ) .then(undefined, (error: unknown) => { - ignoreAbortCleanup(this.abort()); + this.abort(); throw error; }); } @@ -893,7 +877,7 @@ export class IncrementalExecutor< return this.buildStreamItemResult(null); } } catch (error) { - ignoreAbortCleanup(this.abort()); + this.abort(); throw error; } @@ -912,7 +896,7 @@ export class IncrementalExecutor< }, ) .then(undefined, (error: unknown) => { - ignoreAbortCleanup(this.abort()); + this.abort(); throw error; }); } @@ -931,12 +915,6 @@ export class IncrementalExecutor< } } -function ignoreAbortCleanup(aborted: PromiseOrValue): void { - if (isPromise(aborted)) { - aborted.catch(() => undefined); - } -} - function toNodes(fieldDetailsList: FieldDetailsList): ReadonlyArray { return fieldDetailsList.map((fieldDetails) => fieldDetails.node); } diff --git a/src/index.ts b/src/index.ts index 3baacf2a79..2b6acaf537 100644 --- a/src/index.ts +++ b/src/index.ts @@ -213,6 +213,7 @@ export type { GraphQLObjectTypeConfig, GraphQLObjectTypeExtensions, GraphQLResolveInfo, + GraphQLResolveInfoHelpers, ResponsePath, GraphQLScalarTypeConfig, GraphQLScalarTypeExtensions, diff --git a/src/jsutils/promiseForObject.ts b/src/jsutils/promiseForObject.ts index ff48d9f218..f8d9499666 100644 --- a/src/jsutils/promiseForObject.ts +++ b/src/jsutils/promiseForObject.ts @@ -1,4 +1,5 @@ import type { ObjMap } from './ObjMap.js'; +import type { PromiseOrValue } from './PromiseOrValue.js'; /** * This function transforms a JS object `ObjMap>` into @@ -7,16 +8,20 @@ import type { ObjMap } from './ObjMap.js'; * This is akin to bluebird's `Promise.props`, but implemented only using * `Promise.all` so it will work with any implementation of ES6 promises. */ -export async function promiseForObject( - object: ObjMap>, +export function promiseForObject( + object: Readonly>>, + promiseAll: ( + values: ReadonlyArray>, + ) => Promise>, ): Promise> { const keys = Object.keys(object); const values = Object.values(object); - const resolvedValues = await Promise.all(values); - const resolvedObject = Object.create(null); - for (let i = 0; i < keys.length; ++i) { - resolvedObject[keys[i]] = resolvedValues[i]; - } - return resolvedObject; + return promiseAll(values).then((resolvedValues) => { + const resolvedObject = Object.create(null); + for (let i = 0; i < keys.length; ++i) { + resolvedObject[keys[i]] = resolvedValues[i]; + } + return resolvedObject; + }); } diff --git a/src/type/definition.ts b/src/type/definition.ts index 30926a7fb7..206f97b3e4 100644 --- a/src/type/definition.ts +++ b/src/type/definition.ts @@ -1045,6 +1045,10 @@ export type GraphQLFieldResolver< info: GraphQLResolveInfo, ) => TResult; +export interface GraphQLResolveInfoHelpers { + readonly trackPromise: (promise: Promise) => void; +} + export interface GraphQLResolveInfo { readonly fieldName: string; readonly fieldNodes: ReadonlyArray; @@ -1057,6 +1061,7 @@ export interface GraphQLResolveInfo { readonly operation: OperationDefinitionNode; readonly variableValues: VariableValues; readonly getAbortSignal: () => AbortSignal | undefined; + readonly getAsyncHelpers: () => GraphQLResolveInfoHelpers; } /** diff --git a/src/type/index.ts b/src/type/index.ts index dd9d103868..7c50f97f34 100644 --- a/src/type/index.ts +++ b/src/type/index.ts @@ -121,6 +121,7 @@ export type { GraphQLObjectTypeConfig, GraphQLObjectTypeExtensions, GraphQLResolveInfo, + GraphQLResolveInfoHelpers, GraphQLScalarTypeConfig, GraphQLScalarTypeExtensions, GraphQLTypeResolver,