Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lib/core/symbols.js
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ module.exports = {
kPendingIdx: Symbol('pending index'),
kError: Symbol('error'),
kClients: Symbol('clients'),
kHttp1OnlyClients: Symbol('http1-only clients'),
kGetDispatcherEntry: Symbol('get dispatcher entry'),
kSetDispatcherEntry: Symbol('set dispatcher entry'),
kDeleteDispatcherEntry: Symbol('delete dispatcher entry'),
kHasDispatcherForOrigin: Symbol('has dispatcher for origin'),
kForEachDispatcherEntry: Symbol('for each dispatcher entry'),
kClient: Symbol('client'),
kParser: Symbol('parser'),
kOnDestroyed: Symbol('destroy callbacks'),
Expand Down
105 changes: 77 additions & 28 deletions lib/dispatcher/agent.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
'use strict'

const { InvalidArgumentError, MaxOriginsReachedError } = require('../core/errors')
const { kClients, kRunning, kClose, kDestroy, kDispatch, kUrl } = require('../core/symbols')
const {
kClients,
kHttp1OnlyClients,
kRunning,
kClose,
kDestroy,
kDispatch,
kUrl,
kGetDispatcherEntry,
kSetDispatcherEntry,
kDeleteDispatcherEntry,
kHasDispatcherForOrigin,
kForEachDispatcherEntry
} = require('../core/symbols')
const DispatcherBase = require('./dispatcher-base')
const Pool = require('./pool')
const Client = require('./client')
Expand All @@ -21,6 +34,10 @@ function defaultFactory (origin, opts) {
: new Pool(origin, opts)
}

function shouldUseHttp1OnlyClients (allowH2) {
return allowH2 === false
}

class Agent extends DispatcherBase {
constructor ({ factory = defaultFactory, maxOrigins = Infinity, connect, ...options } = {}) {
if (typeof factory !== 'function') {
Expand All @@ -44,6 +61,7 @@ class Agent extends DispatcherBase {
this[kOptions] = { ...util.deepClone(options), maxOrigins, connect }
this[kFactory] = factory
this[kClients] = new Map()
this[kHttp1OnlyClients] = new Map()
this[kOrigins] = new Set()

this[kOnDrain] = (origin, targets) => {
Expand All @@ -65,12 +83,45 @@ class Agent extends DispatcherBase {

get [kRunning] () {
let ret = 0
for (const { dispatcher } of this[kClients].values()) {

this[kForEachDispatcherEntry](({ dispatcher }) => {
ret += dispatcher[kRunning]
}
})

return ret
}

[kGetDispatcherEntry] (origin, { allowH2 } = {}) {
return (shouldUseHttp1OnlyClients(allowH2) ? this[kHttp1OnlyClients] : this[kClients]).get(origin)
}

[kSetDispatcherEntry] (origin, { allowH2 } = {}, entry) {
;(shouldUseHttp1OnlyClients(allowH2) ? this[kHttp1OnlyClients] : this[kClients]).set(origin, entry)
this[kOrigins].add(origin)
}

[kDeleteDispatcherEntry] (origin, { allowH2 } = {}) {
;(shouldUseHttp1OnlyClients(allowH2) ? this[kHttp1OnlyClients] : this[kClients]).delete(origin)

if (!this[kHasDispatcherForOrigin](origin)) {
this[kOrigins].delete(origin)
}
}

[kHasDispatcherForOrigin] (origin) {
return this[kClients].has(origin) || this[kHttp1OnlyClients].has(origin)
}

[kForEachDispatcherEntry] (callback) {
for (const [origin, entry] of this[kClients]) {
callback(entry, { origin })
}

for (const [origin, entry] of this[kHttp1OnlyClients]) {
callback(entry, { origin, allowH2: false })
}
}

[kDispatch] (opts, handler) {
let origin
if (opts.origin && (typeof opts.origin === 'string' || opts.origin instanceof URL)) {
Expand All @@ -80,45 +131,34 @@ class Agent extends DispatcherBase {
}

const allowH2 = opts.allowH2 ?? this[kOptions].allowH2
const key = allowH2 === false ? `${origin}#http1-only` : origin
const registry = { allowH2 }

if (this[kOrigins].size >= this[kOptions].maxOrigins && !this[kOrigins].has(origin)) {
throw new MaxOriginsReachedError()
}

const result = this[kClients].get(key)
const result = this[kGetDispatcherEntry](origin, registry)
let dispatcher = result && result.dispatcher
if (!dispatcher) {
const closeClientIfUnused = (connected) => {
const result = this[kClients].get(key)
const result = this[kGetDispatcherEntry](origin, registry)
if (result) {
if (connected) result.count -= 1
if (result.count <= 0) {
this[kClients].delete(key)
this[kDeleteDispatcherEntry](origin, registry)
if (!result.dispatcher.destroyed) {
result.dispatcher.close()
}
}

let hasOrigin = false
for (const entry of this[kClients].values()) {
if (entry.origin === origin) {
hasOrigin = true
break
}
}

if (!hasOrigin) {
this[kOrigins].delete(origin)
}
}
}

dispatcher = this[kFactory](opts.origin, allowH2 === false
? { ...this[kOptions], allowH2: false }
: this[kOptions])
.on('drain', this[kOnDrain])
.on('connect', (origin, targets) => {
const result = this[kClients].get(key)
const result = this[kGetDispatcherEntry](origin, registry)
if (result) {
result.count += 1
}
Expand All @@ -133,40 +173,49 @@ class Agent extends DispatcherBase {
this[kOnConnectionError](origin, targets, err)
})

this[kClients].set(key, { count: 0, dispatcher, origin })
this[kOrigins].add(origin)
this[kSetDispatcherEntry](origin, registry, { count: 0, dispatcher, origin })
}

return dispatcher.dispatch(opts, handler)
}

[kClose] () {
const closePromises = []
for (const { dispatcher } of this[kClients].values()) {

this[kForEachDispatcherEntry](({ dispatcher }) => {
closePromises.push(dispatcher.close())
}
})

this[kClients].clear()
this[kHttp1OnlyClients].clear()
this[kOrigins].clear()

return Promise.all(closePromises)
}

[kDestroy] (err) {
const destroyPromises = []
for (const { dispatcher } of this[kClients].values()) {

this[kForEachDispatcherEntry](({ dispatcher }) => {
destroyPromises.push(dispatcher.destroy(err))
}
})

this[kClients].clear()
this[kHttp1OnlyClients].clear()
this[kOrigins].clear()

return Promise.all(destroyPromises)
}

get stats () {
const allClientStats = {}
for (const { dispatcher } of this[kClients].values()) {

this[kForEachDispatcherEntry](({ dispatcher }) => {
if (dispatcher.stats) {
allClientStats[dispatcher[kUrl].origin] = dispatcher.stats
}
}
})

return allClientStats
}
}
Expand Down
106 changes: 89 additions & 17 deletions lib/mock/mock-agent.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
'use strict'

const { kClients } = require('../core/symbols')
const {
kClients,
kGetDispatcherEntry,
kSetDispatcherEntry,
kDeleteDispatcherEntry,
kForEachDispatcherEntry
} = require('../core/symbols')
const Agent = require('../dispatcher/agent')
const {
kAgent,
kMockAgentSet,
kMockAgentGet,
kMockAgentUnregisterDispatches,
kDispatches,
kIsMockActive,
kNetConnect,
Expand All @@ -23,11 +30,13 @@ const {
const MockClient = require('./mock-client')
const MockPool = require('./mock-pool')
const { matchValue, normalizeSearchParams, buildAndValidateMockOptions, normalizeOrigin } = require('./mock-utils')
const { InvalidArgumentError, UndiciError } = require('../core/errors')
const { ClientDestroyedError, InvalidArgumentError, UndiciError } = require('../core/errors')
const Dispatcher = require('../dispatcher/dispatcher')
const PendingInterceptorsFormatter = require('./pending-interceptors-formatter')
const { MockCallHistory } = require('./mock-call-history')

const kClosed = Symbol('closed')

class MockAgent extends Dispatcher {
constructor (opts = {}) {
super(opts)
Expand All @@ -36,6 +45,7 @@ class MockAgent extends Dispatcher {

this[kNetConnect] = true
this[kIsMockActive] = true
this[kClosed] = false
this[kMockAgentIsCallHistoryEnabled] = mockOptions.enableCallHistory ?? false
this[kMockAgentAcceptsNonStandardSearchParameters] = mockOptions.acceptNonStandardSearchParameters ?? false
this[kIgnoreTrailingSlash] = mockOptions.ignoreTrailingSlash ?? false
Expand All @@ -47,7 +57,7 @@ class MockAgent extends Dispatcher {
const agent = opts?.agent ? opts.agent : new Agent(opts)
this[kAgent] = agent

this[kClients] = agent[kClients]
this[kClients] = new Map()
this[kOptions] = mockOptions

if (this[kMockAgentIsCallHistoryEnabled]) {
Expand All @@ -62,14 +72,48 @@ class MockAgent extends Dispatcher {

let dispatcher = this[kMockAgentGet](originKey)

if (!dispatcher && typeof originKey === 'string') {
dispatcher = this[kAgent][kGetDispatcherEntry](originKey)?.dispatcher
}

if (!dispatcher && typeof originKey === 'string') {
for (const [keyMatcher, result] of Array.from(this[kClients])) {
if (result && typeof keyMatcher !== 'string' && matchValue(keyMatcher, originKey)) {
dispatcher = this[kFactory](originKey)
dispatcher[kDispatches] = result.dispatcher[kDispatches]
break
}
}
}

if (!dispatcher) {
dispatcher = this[kFactory](originKey)
this[kMockAgentSet](originKey, dispatcher)
}

if (typeof originKey === 'string') {
this[kAgent][kSetDispatcherEntry](originKey, {}, { count: 0, dispatcher, origin: originKey })

if (!this[kAgent][kGetDispatcherEntry](originKey, { allowH2: false })) {
const http1Dispatcher = this[kFactory](originKey, { ...this[kOptions], allowH2: false })
http1Dispatcher[kDispatches] = dispatcher[kDispatches]
this[kAgent][kSetDispatcherEntry](originKey, { allowH2: false }, { count: 0, dispatcher: http1Dispatcher, origin: originKey })
}
}

return dispatcher
}

dispatch (opts, handler) {
if (this[kClosed]) {
const err = new ClientDestroyedError()
if (typeof handler?.onResponseError === 'function') {
handler.onResponseError(null, err)
return false
}
throw err
}

opts.origin = normalizeOrigin(opts.origin)

// Call MockAgent.get to perform additional setup before dispatching as normal
Expand All @@ -91,7 +135,17 @@ class MockAgent extends Dispatcher {
}

async close () {
this[kClosed] = true
this.clearCallHistory()

const closePromises = []
for (const [origin, result] of this[kClients]) {
if (typeof origin !== 'string') {
closePromises.push(result.dispatcher.close())
}
}

await Promise.all(closePromises)
await this[kAgent].close()
this[kClients].clear()
}
Expand Down Expand Up @@ -167,12 +221,12 @@ class MockAgent extends Dispatcher {
}

[kMockAgentSet] (origin, dispatcher) {
this[kClients].set(origin, { count: 0, dispatcher })
this[kClients].set(origin, { dispatcher })
}

[kFactory] (origin) {
const mockOptions = Object.assign({ agent: this }, this[kOptions])
return this[kOptions] && this[kOptions].connections === 1
[kFactory] (origin, options = this[kOptions]) {
const mockOptions = Object.assign({ agent: this }, options)
return options && options.connections === 1
? new MockClient(origin, mockOptions)
: new MockPool(origin, mockOptions)
}
Expand All @@ -190,22 +244,40 @@ class MockAgent extends Dispatcher {
this[kMockAgentSet](origin, dispatcher)
return dispatcher
}

// If we match, create a pool and assign the same dispatches
for (const [keyMatcher, result] of Array.from(this[kClients])) {
if (result && typeof keyMatcher !== 'string' && matchValue(keyMatcher, origin)) {
const dispatcher = this[kFactory](origin)
this[kMockAgentSet](origin, dispatcher)
dispatcher[kDispatches] = result.dispatcher[kDispatches]
return dispatcher
}
}
}

[kGetNetConnect] () {
return this[kNetConnect]
}

async [kMockAgentUnregisterDispatches] (dispatches, closingDispatcher) {
for (const [origin, result] of Array.from(this[kClients])) {
if (result?.dispatcher?.[kDispatches] === dispatches) {
this[kClients].delete(origin)
}
}

const dispatchersToClose = new Set()
const entriesToDelete = []

this[kAgent][kForEachDispatcherEntry]((entry, meta) => {
if (entry.dispatcher[kDispatches] === dispatches) {
entriesToDelete.push(meta)
dispatchersToClose.add(entry.dispatcher)
}
})

for (const { origin, allowH2 } of entriesToDelete) {
this[kAgent][kDeleteDispatcherEntry](origin, { allowH2 })
}

for (const dispatcher of dispatchersToClose) {
if (dispatcher !== closingDispatcher && !dispatcher.closed && !dispatcher.destroyed) {
await dispatcher.close()
}
}
}

pendingInterceptors () {
const mockAgentClients = this[kClients]

Expand Down
Loading
Loading