diff --git a/src/internal/concurrency/async-abort-controller.ts b/src/internal/concurrency/async-abort-controller.ts index d25ec2695..9a7d58a09 100644 --- a/src/internal/concurrency/async-abort-controller.ts +++ b/src/internal/concurrency/async-abort-controller.ts @@ -1,41 +1,100 @@ /** * This special AbortController is used to wait for all the abort handlers to finish before resolving the promise. */ +type AbortListener = EventListenerOrEventListenerObject + +type ListenerRecord = { + wrapped: EventListener + cleanup: () => void +} + export class AsyncAbortController extends AbortController { - protected promises: Promise[] = [] + protected runningPromises = new Set>() + protected abortListeners = new WeakMap>() protected _nextGroup?: AsyncAbortController constructor() { super() - const originalEventListener = this.signal.addEventListener + const originalAddEventListener = this.signal.addEventListener.bind(this.signal) + const originalRemoveEventListener = this.signal.removeEventListener.bind(this.signal) // Patch event addEventListener to keep track of listeners and their promises - this.signal.addEventListener = (type: string, listener: any, options: any) => { + this.signal.addEventListener = ( + type: string, + listener: EventListenerOrEventListenerObject | null, + options?: boolean | AddEventListenerOptions + ) => { + if (!listener) { + return + } + if (type !== 'abort') { - return originalEventListener.call(this.signal, type, listener, options) + return originalAddEventListener(type, listener, options) + } + + if (this.signal.aborted) { + return originalAddEventListener(type, listener, options) + } + + const capture = getCaptureOption(options) + const existingRecord = this.getAbortListenerRecord(listener, capture) + if (existingRecord) { + return originalAddEventListener(type, existingRecord.wrapped, options) + } + + const registrationSignal = getRegistrationSignal(options) + if (registrationSignal?.aborted) { + return + } + + let wrapped!: EventListener + const cleanupRegistrationSignal = this.watchListenerRemovalSignal( + registrationSignal, + listener, + capture + ) + + wrapped = (event: Event) => { + this.deleteAbortListenerRecord(listener, capture) + originalRemoveEventListener(type, wrapped, capture) + + const runningPromise = this.invokeAbortListener(listener, event) + this.runningPromises.add(runningPromise) + void runningPromise.finally(() => { + this.runningPromises.delete(runningPromise) + }) } - let resolving: undefined | (() => Promise) = undefined - const promise = new Promise((resolve, reject) => { - resolving = async (): Promise => { - return Promise.resolve() - .then(() => listener()) - .then(() => { - resolve() - }) - .catch((error) => { - reject(error) - }) - } + this.setAbortListenerRecord(listener, capture, { + wrapped, + cleanup: cleanupRegistrationSignal, }) - this.promises.push(promise) - if (!resolving) { - throw new Error('resolve is undefined') + return originalAddEventListener(type, wrapped, options) + } + + this.signal.removeEventListener = ( + type: string, + listener: EventListenerOrEventListenerObject | null, + options?: boolean | EventListenerOptions + ) => { + if (!listener) { + return + } + + if (type !== 'abort') { + return originalRemoveEventListener(type, listener, options) + } + + const capture = getCaptureOption(options) + const record = this.getAbortListenerRecord(listener, capture) + if (!record) { + return originalRemoveEventListener(type, listener, options) } - return originalEventListener.call(this.signal, type, resolving, options) + this.deleteAbortListenerRecord(listener, capture) + return originalRemoveEventListener(type, record.wrapped, options) } } @@ -50,8 +109,8 @@ export class AsyncAbortController extends AbortController { async abortAsync() { this.abort() - while (this.promises.length > 0) { - const promises = this.promises.splice(0, 100) + while (this.runningPromises.size > 0) { + const promises = Array.from(this.runningPromises) await Promise.allSettled(promises) } await this.abortNextGroup() @@ -62,4 +121,108 @@ export class AsyncAbortController extends AbortController { await this._nextGroup.abortAsync() } } + + protected invokeAbortListener(listener: AbortListener, event: Event): Promise { + try { + const result = + typeof listener === 'function' + ? listener.call(this.signal, event) + : listener.handleEvent(event) + + return Promise.resolve(result).then(() => undefined) + } catch (error) { + return Promise.reject(error) + } + } + + protected getAbortListenerRecord( + listener: AbortListener, + capture: boolean + ): ListenerRecord | undefined { + return this.abortListeners.get(listener)?.get(capture) + } + + protected setAbortListenerRecord( + listener: AbortListener, + capture: boolean, + record: ListenerRecord + ) { + const records = this.abortListeners.get(listener) ?? new Map() + records.set(capture, record) + this.abortListeners.set(listener, records) + } + + protected deleteAbortListenerRecord(listener: AbortListener, capture: boolean) { + const records = this.abortListeners.get(listener) + const record = records?.get(capture) + if (!records || !record) { + return + } + + record.cleanup() + records.delete(capture) + + if (records.size === 0) { + this.abortListeners.delete(listener) + } + } + + protected watchListenerRemovalSignal( + signal: AbortSignal | undefined, + listener: AbortListener, + capture: boolean + ): () => void { + if (!signal) { + return () => {} + } + + const onAbort = () => { + this.deleteAbortListenerRecord(listener, capture) + } + + addNativeEventListener(signal, 'abort', onAbort, { once: true }) + + return () => { + removeNativeEventListener(signal, 'abort', onAbort, { capture: false }) + } + } +} + +const nativeAddEventListener = EventTarget.prototype.addEventListener +const nativeRemoveEventListener = EventTarget.prototype.removeEventListener + +function addNativeEventListener( + target: EventTarget, + type: string, + listener: EventListenerOrEventListenerObject, + options?: boolean | AddEventListenerOptions +) { + nativeAddEventListener.call(target, type, listener, options) +} + +function removeNativeEventListener( + target: EventTarget, + type: string, + listener: EventListenerOrEventListenerObject, + options?: boolean | EventListenerOptions +) { + nativeRemoveEventListener.call(target, type, listener, options) +} + +function getCaptureOption(options?: boolean | EventListenerOptions): boolean { + if (typeof options === 'boolean') { + return options + } + + return options?.capture ?? false +} + +function getRegistrationSignal( + options?: boolean | AddEventListenerOptions +): AbortSignal | undefined { + if (typeof options === 'boolean') { + return undefined + } + + return options?.signal } diff --git a/src/test/async-abort-controller.test.ts b/src/test/async-abort-controller.test.ts index 335933f09..619747575 100644 --- a/src/test/async-abort-controller.test.ts +++ b/src/test/async-abort-controller.test.ts @@ -69,4 +69,112 @@ describe('AsyncAbortController', () => { expect(order).toEqual(['root:start', 'root:end', 'child', 'grandchild']) }) + + it('forwards the real abort event to function listeners with the signal as context', async () => { + const controller = new AsyncAbortController() + const seen: { + target: EventTarget | null + currentTarget: EventTarget | null + context: unknown + } = { + target: null, + currentTarget: null, + context: undefined, + } + + controller.signal.addEventListener('abort', function (event) { + seen.target = event.target + seen.currentTarget = event.currentTarget + seen.context = this + }) + + await controller.abortAsync() + + expect(seen.target).toBe(controller.signal) + expect(seen.currentTarget).toBe(controller.signal) + expect(seen.context).toBe(controller.signal) + }) + + it('waits for handleEvent listeners before aborting nested groups', async () => { + const controller = new AsyncAbortController() + const childGroup = controller.nextGroup + const order: string[] = [] + let releaseRootAbort!: () => void + const rootAbortDone = new Promise((resolve) => { + releaseRootAbort = resolve + }) + const listener = { + target: null as EventTarget | null, + async handleEvent(event: Event) { + this.target = event.target + order.push('root:start') + await rootAbortDone + order.push('root:end') + }, + } + + controller.signal.addEventListener('abort', listener) + childGroup.signal.addEventListener('abort', () => { + order.push('child') + }) + + const abortPromise = controller.abortAsync() + + await Promise.resolve() + expect(order).toEqual(['root:start']) + + releaseRootAbort() + await abortPromise + + expect(listener.target).toBe(controller.signal) + expect(order).toEqual(['root:start', 'root:end', 'child']) + }) + + it('ignores null abort listeners', async () => { + const controller = new AsyncAbortController() + const nullListener = null as unknown as EventListenerOrEventListenerObject + + expect(() => controller.signal.addEventListener('abort', nullListener)).not.toThrow() + await expect(controller.abortAsync()).resolves.toBeUndefined() + }) + + it('does not invoke or wait on explicitly removed abort listeners', async () => { + const controller = new AsyncAbortController() + const listener = jest.fn() + + controller.signal.addEventListener('abort', listener) + controller.signal.removeEventListener('abort', listener) + + await expect(controller.abortAsync()).resolves.toBeUndefined() + expect(listener).not.toHaveBeenCalled() + }) + + it('does not invoke or wait on abort listeners removed by a registration signal', async () => { + const controller = new AsyncAbortController() + const registration = new AbortController() + const listener = jest.fn() + + controller.signal.addEventListener('abort', listener, { + signal: registration.signal, + }) + + registration.abort() + + await expect(controller.abortAsync()).resolves.toBeUndefined() + expect(listener).not.toHaveBeenCalled() + }) + + it('ignores abort listeners registered with an already aborted signal', async () => { + const controller = new AsyncAbortController() + const registration = new AbortController() + const listener = jest.fn() + + registration.abort() + controller.signal.addEventListener('abort', listener, { + signal: registration.signal, + }) + + await expect(controller.abortAsync()).resolves.toBeUndefined() + expect(listener).not.toHaveBeenCalled() + }) })