diff --git a/lib/core/symbols.js b/lib/core/symbols.js index ff37adc0448..ea34f0db5ae 100644 --- a/lib/core/symbols.js +++ b/lib/core/symbols.js @@ -42,6 +42,12 @@ module.exports = { kPendingIdx: Symbol('pending index'), kError: Symbol('error'), kClients: Symbol('clients'), + kHttp1OnlyClients: Symbol('http1-only clients'), + kGetDispatcherEntry: Symbol('get dispatcher entry'), + kSetDispatcherEntry: Symbol('set dispatcher entry'), + kDeleteDispatcherEntry: Symbol('delete dispatcher entry'), + kHasDispatcherForOrigin: Symbol('has dispatcher for origin'), + kForEachDispatcherEntry: Symbol('for each dispatcher entry'), kClient: Symbol('client'), kParser: Symbol('parser'), kOnDestroyed: Symbol('destroy callbacks'), diff --git a/lib/dispatcher/agent.js b/lib/dispatcher/agent.js index a1cc7fd6817..576da75388b 100644 --- a/lib/dispatcher/agent.js +++ b/lib/dispatcher/agent.js @@ -1,7 +1,20 @@ 'use strict' const { InvalidArgumentError, MaxOriginsReachedError } = require('../core/errors') -const { kClients, kRunning, kClose, kDestroy, kDispatch, kUrl } = require('../core/symbols') +const { + kClients, + kHttp1OnlyClients, + kRunning, + kClose, + kDestroy, + kDispatch, + kUrl, + kGetDispatcherEntry, + kSetDispatcherEntry, + kDeleteDispatcherEntry, + kHasDispatcherForOrigin, + kForEachDispatcherEntry +} = require('../core/symbols') const DispatcherBase = require('./dispatcher-base') const Pool = require('./pool') const Client = require('./client') @@ -21,6 +34,10 @@ function defaultFactory (origin, opts) { : new Pool(origin, opts) } +function shouldUseHttp1OnlyClients (allowH2) { + return allowH2 === false +} + class Agent extends DispatcherBase { constructor ({ factory = defaultFactory, maxOrigins = Infinity, connect, ...options } = {}) { if (typeof factory !== 'function') { @@ -44,6 +61,7 @@ class Agent extends DispatcherBase { this[kOptions] = { ...util.deepClone(options), maxOrigins, connect } this[kFactory] = factory this[kClients] = new Map() + this[kHttp1OnlyClients] = new Map() this[kOrigins] = new Set() this[kOnDrain] = (origin, targets) => { @@ -65,12 +83,45 @@ class Agent extends DispatcherBase { get [kRunning] () { let ret = 0 - for (const { dispatcher } of this[kClients].values()) { + + this[kForEachDispatcherEntry](({ dispatcher }) => { ret += dispatcher[kRunning] - } + }) + return ret } + [kGetDispatcherEntry] (origin, { allowH2 } = {}) { + return (shouldUseHttp1OnlyClients(allowH2) ? this[kHttp1OnlyClients] : this[kClients]).get(origin) + } + + [kSetDispatcherEntry] (origin, { allowH2 } = {}, entry) { + ;(shouldUseHttp1OnlyClients(allowH2) ? this[kHttp1OnlyClients] : this[kClients]).set(origin, entry) + this[kOrigins].add(origin) + } + + [kDeleteDispatcherEntry] (origin, { allowH2 } = {}) { + ;(shouldUseHttp1OnlyClients(allowH2) ? this[kHttp1OnlyClients] : this[kClients]).delete(origin) + + if (!this[kHasDispatcherForOrigin](origin)) { + this[kOrigins].delete(origin) + } + } + + [kHasDispatcherForOrigin] (origin) { + return this[kClients].has(origin) || this[kHttp1OnlyClients].has(origin) + } + + [kForEachDispatcherEntry] (callback) { + for (const [origin, entry] of this[kClients]) { + callback(entry, { origin }) + } + + for (const [origin, entry] of this[kHttp1OnlyClients]) { + callback(entry, { origin, allowH2: false }) + } + } + [kDispatch] (opts, handler) { let origin if (opts.origin && (typeof opts.origin === 'string' || opts.origin instanceof URL)) { @@ -80,45 +131,34 @@ class Agent extends DispatcherBase { } const allowH2 = opts.allowH2 ?? this[kOptions].allowH2 - const key = allowH2 === false ? `${origin}#http1-only` : origin + const registry = { allowH2 } if (this[kOrigins].size >= this[kOptions].maxOrigins && !this[kOrigins].has(origin)) { throw new MaxOriginsReachedError() } - const result = this[kClients].get(key) + const result = this[kGetDispatcherEntry](origin, registry) let dispatcher = result && result.dispatcher if (!dispatcher) { const closeClientIfUnused = (connected) => { - const result = this[kClients].get(key) + const result = this[kGetDispatcherEntry](origin, registry) if (result) { if (connected) result.count -= 1 if (result.count <= 0) { - this[kClients].delete(key) + this[kDeleteDispatcherEntry](origin, registry) if (!result.dispatcher.destroyed) { result.dispatcher.close() } } - - let hasOrigin = false - for (const entry of this[kClients].values()) { - if (entry.origin === origin) { - hasOrigin = true - break - } - } - - if (!hasOrigin) { - this[kOrigins].delete(origin) - } } } + dispatcher = this[kFactory](opts.origin, allowH2 === false ? { ...this[kOptions], allowH2: false } : this[kOptions]) .on('drain', this[kOnDrain]) .on('connect', (origin, targets) => { - const result = this[kClients].get(key) + const result = this[kGetDispatcherEntry](origin, registry) if (result) { result.count += 1 } @@ -133,8 +173,7 @@ class Agent extends DispatcherBase { this[kOnConnectionError](origin, targets, err) }) - this[kClients].set(key, { count: 0, dispatcher, origin }) - this[kOrigins].add(origin) + this[kSetDispatcherEntry](origin, registry, { count: 0, dispatcher, origin }) } return dispatcher.dispatch(opts, handler) @@ -142,31 +181,41 @@ class Agent extends DispatcherBase { [kClose] () { const closePromises = [] - for (const { dispatcher } of this[kClients].values()) { + + this[kForEachDispatcherEntry](({ dispatcher }) => { closePromises.push(dispatcher.close()) - } + }) + this[kClients].clear() + this[kHttp1OnlyClients].clear() + this[kOrigins].clear() return Promise.all(closePromises) } [kDestroy] (err) { const destroyPromises = [] - for (const { dispatcher } of this[kClients].values()) { + + this[kForEachDispatcherEntry](({ dispatcher }) => { destroyPromises.push(dispatcher.destroy(err)) - } + }) + this[kClients].clear() + this[kHttp1OnlyClients].clear() + this[kOrigins].clear() return Promise.all(destroyPromises) } get stats () { const allClientStats = {} - for (const { dispatcher } of this[kClients].values()) { + + this[kForEachDispatcherEntry](({ dispatcher }) => { if (dispatcher.stats) { allClientStats[dispatcher[kUrl].origin] = dispatcher.stats } - } + }) + return allClientStats } } diff --git a/lib/mock/mock-agent.js b/lib/mock/mock-agent.js index 61449e077ea..3561d2597b8 100644 --- a/lib/mock/mock-agent.js +++ b/lib/mock/mock-agent.js @@ -1,11 +1,18 @@ 'use strict' -const { kClients } = require('../core/symbols') +const { + kClients, + kGetDispatcherEntry, + kSetDispatcherEntry, + kDeleteDispatcherEntry, + kForEachDispatcherEntry +} = require('../core/symbols') const Agent = require('../dispatcher/agent') const { kAgent, kMockAgentSet, kMockAgentGet, + kMockAgentUnregisterDispatches, kDispatches, kIsMockActive, kNetConnect, @@ -23,11 +30,13 @@ const { const MockClient = require('./mock-client') const MockPool = require('./mock-pool') const { matchValue, normalizeSearchParams, buildAndValidateMockOptions, normalizeOrigin } = require('./mock-utils') -const { InvalidArgumentError, UndiciError } = require('../core/errors') +const { ClientDestroyedError, InvalidArgumentError, UndiciError } = require('../core/errors') const Dispatcher = require('../dispatcher/dispatcher') const PendingInterceptorsFormatter = require('./pending-interceptors-formatter') const { MockCallHistory } = require('./mock-call-history') +const kClosed = Symbol('closed') + class MockAgent extends Dispatcher { constructor (opts = {}) { super(opts) @@ -36,6 +45,7 @@ class MockAgent extends Dispatcher { this[kNetConnect] = true this[kIsMockActive] = true + this[kClosed] = false this[kMockAgentIsCallHistoryEnabled] = mockOptions.enableCallHistory ?? false this[kMockAgentAcceptsNonStandardSearchParameters] = mockOptions.acceptNonStandardSearchParameters ?? false this[kIgnoreTrailingSlash] = mockOptions.ignoreTrailingSlash ?? false @@ -47,7 +57,7 @@ class MockAgent extends Dispatcher { const agent = opts?.agent ? opts.agent : new Agent(opts) this[kAgent] = agent - this[kClients] = agent[kClients] + this[kClients] = new Map() this[kOptions] = mockOptions if (this[kMockAgentIsCallHistoryEnabled]) { @@ -62,14 +72,48 @@ class MockAgent extends Dispatcher { let dispatcher = this[kMockAgentGet](originKey) + if (!dispatcher && typeof originKey === 'string') { + dispatcher = this[kAgent][kGetDispatcherEntry](originKey)?.dispatcher + } + + if (!dispatcher && typeof originKey === 'string') { + for (const [keyMatcher, result] of Array.from(this[kClients])) { + if (result && typeof keyMatcher !== 'string' && matchValue(keyMatcher, originKey)) { + dispatcher = this[kFactory](originKey) + dispatcher[kDispatches] = result.dispatcher[kDispatches] + break + } + } + } + if (!dispatcher) { dispatcher = this[kFactory](originKey) this[kMockAgentSet](originKey, dispatcher) } + + if (typeof originKey === 'string') { + this[kAgent][kSetDispatcherEntry](originKey, {}, { count: 0, dispatcher, origin: originKey }) + + if (!this[kAgent][kGetDispatcherEntry](originKey, { allowH2: false })) { + const http1Dispatcher = this[kFactory](originKey, { ...this[kOptions], allowH2: false }) + http1Dispatcher[kDispatches] = dispatcher[kDispatches] + this[kAgent][kSetDispatcherEntry](originKey, { allowH2: false }, { count: 0, dispatcher: http1Dispatcher, origin: originKey }) + } + } + return dispatcher } dispatch (opts, handler) { + if (this[kClosed]) { + const err = new ClientDestroyedError() + if (typeof handler?.onResponseError === 'function') { + handler.onResponseError(null, err) + return false + } + throw err + } + opts.origin = normalizeOrigin(opts.origin) // Call MockAgent.get to perform additional setup before dispatching as normal @@ -91,7 +135,17 @@ class MockAgent extends Dispatcher { } async close () { + this[kClosed] = true this.clearCallHistory() + + const closePromises = [] + for (const [origin, result] of this[kClients]) { + if (typeof origin !== 'string') { + closePromises.push(result.dispatcher.close()) + } + } + + await Promise.all(closePromises) await this[kAgent].close() this[kClients].clear() } @@ -167,12 +221,12 @@ class MockAgent extends Dispatcher { } [kMockAgentSet] (origin, dispatcher) { - this[kClients].set(origin, { count: 0, dispatcher }) + this[kClients].set(origin, { dispatcher }) } - [kFactory] (origin) { - const mockOptions = Object.assign({ agent: this }, this[kOptions]) - return this[kOptions] && this[kOptions].connections === 1 + [kFactory] (origin, options = this[kOptions]) { + const mockOptions = Object.assign({ agent: this }, options) + return options && options.connections === 1 ? new MockClient(origin, mockOptions) : new MockPool(origin, mockOptions) } @@ -190,22 +244,40 @@ class MockAgent extends Dispatcher { this[kMockAgentSet](origin, dispatcher) return dispatcher } - - // If we match, create a pool and assign the same dispatches - for (const [keyMatcher, result] of Array.from(this[kClients])) { - if (result && typeof keyMatcher !== 'string' && matchValue(keyMatcher, origin)) { - const dispatcher = this[kFactory](origin) - this[kMockAgentSet](origin, dispatcher) - dispatcher[kDispatches] = result.dispatcher[kDispatches] - return dispatcher - } - } } [kGetNetConnect] () { return this[kNetConnect] } + async [kMockAgentUnregisterDispatches] (dispatches, closingDispatcher) { + for (const [origin, result] of Array.from(this[kClients])) { + if (result?.dispatcher?.[kDispatches] === dispatches) { + this[kClients].delete(origin) + } + } + + const dispatchersToClose = new Set() + const entriesToDelete = [] + + this[kAgent][kForEachDispatcherEntry]((entry, meta) => { + if (entry.dispatcher[kDispatches] === dispatches) { + entriesToDelete.push(meta) + dispatchersToClose.add(entry.dispatcher) + } + }) + + for (const { origin, allowH2 } of entriesToDelete) { + this[kAgent][kDeleteDispatcherEntry](origin, { allowH2 }) + } + + for (const dispatcher of dispatchersToClose) { + if (dispatcher !== closingDispatcher && !dispatcher.closed && !dispatcher.destroyed) { + await dispatcher.close() + } + } + } + pendingInterceptors () { const mockAgentClients = this[kClients] diff --git a/lib/mock/mock-client.js b/lib/mock/mock-client.js index b3be7ab3b91..57364ecfd79 100644 --- a/lib/mock/mock-client.js +++ b/lib/mock/mock-client.js @@ -6,6 +6,7 @@ const { buildMockDispatch } = require('./mock-utils') const { kDispatches, kMockAgent, + kMockAgentUnregisterDispatches, kClose, kOriginalClose, kOrigin, @@ -61,7 +62,7 @@ class MockClient extends Client { async [kClose] () { await promisify(this[kOriginalClose])() this[kConnected] = 0 - this[kMockAgent][Symbols.kClients].delete(this[kOrigin]) + await this[kMockAgent][kMockAgentUnregisterDispatches](this[kDispatches], this) } } diff --git a/lib/mock/mock-pool.js b/lib/mock/mock-pool.js index 2121e3c99a3..df0785f30c3 100644 --- a/lib/mock/mock-pool.js +++ b/lib/mock/mock-pool.js @@ -6,6 +6,7 @@ const { buildMockDispatch } = require('./mock-utils') const { kDispatches, kMockAgent, + kMockAgentUnregisterDispatches, kClose, kOriginalClose, kOrigin, @@ -61,7 +62,7 @@ class MockPool extends Pool { async [kClose] () { await promisify(this[kOriginalClose])() this[kConnected] = 0 - this[kMockAgent][Symbols.kClients].delete(this[kOrigin]) + await this[kMockAgent][kMockAgentUnregisterDispatches](this[kDispatches], this) } } diff --git a/lib/mock/mock-symbols.js b/lib/mock/mock-symbols.js index 9b23e8e3cf0..4410adbce9b 100644 --- a/lib/mock/mock-symbols.js +++ b/lib/mock/mock-symbols.js @@ -12,6 +12,7 @@ module.exports = { kMockAgent: Symbol('mock agent'), kMockAgentSet: Symbol('mock agent set'), kMockAgentGet: Symbol('mock agent get'), + kMockAgentUnregisterDispatches: Symbol('mock agent unregister dispatches'), kMockDispatch: Symbol('mock dispatch'), kClose: Symbol('close'), kOriginalClose: Symbol('original agent close'), diff --git a/test/mock-agent.js b/test/mock-agent.js index 500a8833116..3d4dfc95736 100644 --- a/test/mock-agent.js +++ b/test/mock-agent.js @@ -5,11 +5,11 @@ const { createServer } = require('node:http') const { once } = require('node:events') const { request, setGlobalDispatcher, MockAgent, Agent } = require('..') const { getResponse } = require('../lib/mock/mock-utils') -const { kClients, kConnected } = require('../lib/core/symbols') +const { kClients, kConnected, kGetDispatcherEntry, kHttp1OnlyClients } = require('../lib/core/symbols') const { InvalidArgumentError, ClientDestroyedError } = require('../lib/core/errors') const MockClient = require('../lib/mock/mock-client') const MockPool = require('../lib/mock/mock-pool') -const { kAgent, kMockAgentIsCallHistoryEnabled } = require('../lib/mock/mock-symbols') +const { kAgent, kDispatches, kMockAgentIsCallHistoryEnabled } = require('../lib/mock/mock-symbols') const Dispatcher = require('../lib/dispatcher/dispatcher') const { MockNotMatchedError } = require('../lib/mock/mock-errors') const { fetch } = require('..') @@ -171,6 +171,26 @@ describe('MockAgent - get', () => { const mockPool2 = mockAgent.get(baseUrl) t.assert.strictEqual(mockPool1, mockPool2) }) + + test('should register protocol-specific dispatchers without exposing agent pools', (t) => { + t.plan(6) + + const baseUrl = 'http://localhost:9999' + const agent = new Agent() + const mockAgent = new MockAgent({ agent }) + after(() => mockAgent.close()) + + const mockPool = mockAgent.get(baseUrl) + const defaultPool = agent[kGetDispatcherEntry](baseUrl).dispatcher + const http1OnlyPool = agent[kGetDispatcherEntry](baseUrl, { allowH2: false }).dispatcher + + t.assert.strictEqual(mockAgent[kClients].size, 1) + t.assert.ok(http1OnlyPool instanceof MockPool) + t.assert.notStrictEqual(http1OnlyPool, mockPool) + t.assert.strictEqual(defaultPool, mockPool) + t.assert.strictEqual(http1OnlyPool[kDispatches], mockPool[kDispatches]) + t.assert.strictEqual(agent[kHttp1OnlyClients].size, 1) + }) }) describe('MockAgent - dispatch', () => { @@ -267,8 +287,8 @@ test('MockAgent - .close should clean up registered clients', async (t) => { t.assert.strictEqual(mockAgent[kClients].size, 0) }) -test('MockAgent - [kClients] should match encapsulated agent', async (t) => { - t.plan(1) +test('MockAgent - [kClients] should not expose encapsulated agent pools', async (t) => { + t.plan(3) const server = createServer({ joinDuplicateHeaders: true }, (req, res) => { res.setHeader('content-type', 'text/plain') @@ -295,8 +315,10 @@ test('MockAgent - [kClients] should match encapsulated agent', async (t) => { method: 'GET' }).reply(200, 'hello') - // The MockAgent should encapsulate the input agent clients - t.assert.strictEqual(mockAgent[kClients].size, agent[kClients].size) + // The MockAgent should not expose the encapsulated agent pool registry + t.assert.notStrictEqual(mockAgent[kClients], agent[kClients]) + t.assert.strictEqual(mockAgent[kClients].size, 1) + t.assert.strictEqual(agent[kClients].size, 1) }) test('MockAgent - basic intercept with MockAgent.request', async (t) => { diff --git a/test/node-test/global-dispatcher-version.js b/test/node-test/global-dispatcher-version.js index e20a181a4dd..dd6db9e772b 100644 --- a/test/node-test/global-dispatcher-version.js +++ b/test/node-test/global-dispatcher-version.js @@ -97,6 +97,42 @@ test('setGlobalDispatcher mirrors a v1-compatible dispatcher that Node.js global assert.strictEqual(payload.mirroredV2, true) }) +test('setGlobalDispatcher mirrors a MockAgent that Node.js global fetch uses', () => { + const script = ` + const { MockAgent, setGlobalDispatcher } = require('./index.js') + const http = require('node:http') + const { once } = require('node:events') + + ;(async () => { + const server = http.createServer((req, res) => res.end('real')) + server.listen(0) + await once(server, 'listening') + + const origin = 'http://127.0.0.1:' + server.address().port + const agent = new MockAgent() + agent.disableNetConnect() + agent.get(origin).intercept({ + path: '/v1/test', + method: 'GET' + }).reply(200, 'mock') + + setGlobalDispatcher(agent) + + const res = await fetch(origin + '/v1/test') + process.stdout.write(await res.text()) + + server.close() + })().catch((err) => { + console.error(err?.cause?.stack || err?.stack || err) + process.exit(1) + }) + ` + + const result = runNode(script) + assert.strictEqual(result.status, 0, result.stderr) + assert.strictEqual(result.stdout, 'mock') +}) + test('Dispatcher1Wrapper bridges legacy handlers to a new Agent', () => { const script = ` const { Agent, Dispatcher1Wrapper } = require('./index.js')