diff --git a/.changeset/tall-peaches-tease.md b/.changeset/tall-peaches-tease.md new file mode 100644 index 000000000000..62ab40a40c14 --- /dev/null +++ b/.changeset/tall-peaches-tease.md @@ -0,0 +1,10 @@ +--- +'@ai-sdk/google-vertex': patch +--- + +fix(provider/google-vertex): avoid recreating Node GoogleAuth clients for repeated requests + +Create Google auth token generators per provider instance instead of using a +module-level shared `GoogleAuth` cache. This avoids unnecessary `GoogleAuth` +recreation when `googleAuthOptions` are omitted or when multiple provider +instances use equivalent auth settings. diff --git a/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider-node.test.ts b/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider-node.test.ts index f6ae7358bee9..4df2e8c06ef6 100644 --- a/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider-node.test.ts +++ b/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider-node.test.ts @@ -1,12 +1,14 @@ import { resolve } from '@ai-sdk/provider-utils'; import { createVertexAnthropic as createVertexAnthropicOriginal } from './google-vertex-anthropic-provider'; import { createVertexAnthropic as createVertexAnthropicNode } from './google-vertex-anthropic-provider-node'; -import { generateAuthToken } from '../google-vertex-auth-google-auth-library'; +import { createAuthTokenGenerator } from '../google-vertex-auth-google-auth-library'; import { describe, beforeEach, expect, it, vi } from 'vitest'; // Mock the imported modules vi.mock('../google-vertex-auth-google-auth-library', () => ({ - generateAuthToken: vi.fn().mockResolvedValue('mock-auth-token'), + createAuthTokenGenerator: vi.fn(() => + vi.fn().mockResolvedValue('mock-auth-token'), + ), })); vi.mock('./google-vertex-anthropic-provider', () => ({ @@ -51,7 +53,7 @@ describe('google-vertex-anthropic-provider-node', () => { }); }); - it('passes googleAuthOptions to generateAuthToken', async () => { + it('passes googleAuthOptions to createAuthTokenGenerator', async () => { createVertexAnthropicNode({ googleAuthOptions: { scopes: ['https://www.googleapis.com/auth/cloud-platform'], @@ -63,11 +65,25 @@ describe('google-vertex-anthropic-provider-node', () => { const passedOptions = vi.mocked(createVertexAnthropicOriginal).mock .calls[0][0]; - await resolve(passedOptions?.headers); // call the headers function + await resolve(passedOptions?.headers); - expect(generateAuthToken).toHaveBeenCalledWith({ + expect(createAuthTokenGenerator).toHaveBeenCalledWith({ scopes: ['https://www.googleapis.com/auth/cloud-platform'], keyFile: 'path/to/key.json', }); }); + + it('creates the auth token generator once per provider instance', async () => { + createVertexAnthropicNode({ project: 'test-project' }); + + expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1); + + const passedOptions = vi.mocked(createVertexAnthropicOriginal).mock + .calls[0][0]; + + await resolve(passedOptions?.headers); + await resolve(passedOptions?.headers); + + expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1); + }); }); diff --git a/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider-node.ts b/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider-node.ts index b3d946b3e21e..757fbea9ee93 100644 --- a/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider-node.ts +++ b/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider-node.ts @@ -1,6 +1,6 @@ import { resolve } from '@ai-sdk/provider-utils'; import { GoogleAuthOptions } from 'google-auth-library'; -import { generateAuthToken } from '../google-vertex-auth-google-auth-library'; +import { createAuthTokenGenerator } from '../google-vertex-auth-google-auth-library'; import { createVertexAnthropic as createVertexAnthropicOriginal, GoogleVertexAnthropicProvider, @@ -22,12 +22,12 @@ export interface GoogleVertexAnthropicProviderSettings extends GoogleVertexAnthr export function createVertexAnthropic( options: GoogleVertexAnthropicProviderSettings = {}, ): GoogleVertexAnthropicProvider { + const generateAuthToken = createAuthTokenGenerator(options.googleAuthOptions); + return createVertexAnthropicOriginal({ ...options, headers: async () => ({ - Authorization: `Bearer ${await generateAuthToken( - options.googleAuthOptions, - )}`, + Authorization: `Bearer ${await generateAuthToken()}`, ...(await resolve(options.headers)), }), }); diff --git a/packages/google-vertex/src/google-vertex-auth-google-auth-library.test.ts b/packages/google-vertex/src/google-vertex-auth-google-auth-library.test.ts index 844618d48126..71d32d3e99af 100644 --- a/packages/google-vertex/src/google-vertex-auth-google-auth-library.test.ts +++ b/packages/google-vertex/src/google-vertex-auth-google-auth-library.test.ts @@ -1,58 +1,80 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; -import { - generateAuthToken, - _resetAuthInstance, -} from './google-vertex-auth-google-auth-library'; +import { createAuthTokenGenerator } from './google-vertex-auth-google-auth-library'; import { GoogleAuth } from 'google-auth-library'; +const getAccessToken = vi.fn().mockResolvedValue({ token: 'mocked-token' }); +const getClient = vi.fn().mockResolvedValue({ getAccessToken }); + vi.mock('google-auth-library', () => { return { GoogleAuth: vi.fn(function () { return { - getClient: vi.fn().mockResolvedValue({ - getAccessToken: vi.fn().mockResolvedValue({ token: 'mocked-token' }), - }), + getClient, }; }), }; }); -describe('generateAuthToken', () => { +describe('createAuthTokenGenerator', () => { beforeEach(() => { vi.clearAllMocks(); - _resetAuthInstance(); + getAccessToken.mockResolvedValue({ token: 'mocked-token' }); }); it('should generate a valid auth token', async () => { + const generateAuthToken = createAuthTokenGenerator(); + const token = await generateAuthToken(); + expect(token).toBe('mocked-token'); }); it('should return null if no token is received', async () => { - // Reset the mock completely - vi.mocked(GoogleAuth).mockReset(); - - // Create a new mock implementation - vi.mocked(GoogleAuth).mockImplementation(function () { - return { - getClient: vi.fn().mockResolvedValue({ - getAccessToken: vi.fn().mockResolvedValue({ token: null }), - }), - isGCE: vi.fn(), - } as unknown as GoogleAuth; - }); + getAccessToken.mockResolvedValueOnce({ token: null }); + const generateAuthToken = createAuthTokenGenerator(); const token = await generateAuthToken(); + expect(token).toBeNull(); }); - it('should create new auth instance with provided options', async () => { - const options = { keyFile: 'test-key.json' }; - await generateAuthToken(options); + it('should create a new auth instance with provided options', () => { + createAuthTokenGenerator({ keyFile: 'test-key.json' }); expect(GoogleAuth).toHaveBeenCalledWith({ scopes: ['https://www.googleapis.com/auth/cloud-platform'], keyFile: 'test-key.json', }); }); + + it('should create only one GoogleAuth instance for repeated calls', async () => { + const generateAuthToken = createAuthTokenGenerator(); + + await generateAuthToken(); + await generateAuthToken(); + + expect(GoogleAuth).toHaveBeenCalledTimes(1); + }); + + it('should create independent generators for separate option sets', async () => { + const generateFirstAuthToken = createAuthTokenGenerator({ + keyFile: 'first-key.json', + }); + const generateSecondAuthToken = createAuthTokenGenerator({ + keyFile: 'second-key.json', + }); + + await generateFirstAuthToken(); + await generateSecondAuthToken(); + + expect(GoogleAuth).toHaveBeenCalledTimes(2); + expect(GoogleAuth).toHaveBeenNthCalledWith(1, { + scopes: ['https://www.googleapis.com/auth/cloud-platform'], + keyFile: 'first-key.json', + }); + expect(GoogleAuth).toHaveBeenNthCalledWith(2, { + scopes: ['https://www.googleapis.com/auth/cloud-platform'], + keyFile: 'second-key.json', + }); + }); }); diff --git a/packages/google-vertex/src/google-vertex-auth-google-auth-library.ts b/packages/google-vertex/src/google-vertex-auth-google-auth-library.ts index 34832e25765d..2360fedad0a1 100644 --- a/packages/google-vertex/src/google-vertex-auth-google-auth-library.ts +++ b/packages/google-vertex/src/google-vertex-auth-google-auth-library.ts @@ -1,27 +1,14 @@ import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; -let authInstance: GoogleAuth | null = null; -let authOptions: GoogleAuthOptions | null = null; - -function getAuth(options: GoogleAuthOptions) { - if (!authInstance || options !== authOptions) { - authInstance = new GoogleAuth({ - scopes: ['https://www.googleapis.com/auth/cloud-platform'], - ...options, - }); - authOptions = options; - } - return authInstance; -} - -export async function generateAuthToken(options?: GoogleAuthOptions) { - const auth = getAuth(options || {}); - const client = await auth.getClient(); - const token = await client.getAccessToken(); - return token?.token || null; -} - -// For testing purposes only -export function _resetAuthInstance() { - authInstance = null; +export function createAuthTokenGenerator(options?: GoogleAuthOptions) { + const auth = new GoogleAuth({ + scopes: ['https://www.googleapis.com/auth/cloud-platform'], + ...options, + }); + + return async function generateAuthToken() { + const client = await auth.getClient(); + const token = await client.getAccessToken(); + return token?.token ?? null; + }; } diff --git a/packages/google-vertex/src/google-vertex-provider-node.test.ts b/packages/google-vertex/src/google-vertex-provider-node.test.ts index 399f0a60ce9e..3b0021bcad5b 100644 --- a/packages/google-vertex/src/google-vertex-provider-node.test.ts +++ b/packages/google-vertex/src/google-vertex-provider-node.test.ts @@ -1,12 +1,14 @@ import { resolve } from '@ai-sdk/provider-utils'; import { createVertex as createVertexOriginal } from './google-vertex-provider'; import { createVertex as createVertexNode } from './google-vertex-provider-node'; -import { generateAuthToken } from './google-vertex-auth-google-auth-library'; +import { createAuthTokenGenerator } from './google-vertex-auth-google-auth-library'; import { describe, beforeEach, afterEach, expect, it, vi } from 'vitest'; // Mock the imported modules vi.mock('./google-vertex-auth-google-auth-library', () => ({ - generateAuthToken: vi.fn().mockResolvedValue('mock-auth-token'), + createAuthTokenGenerator: vi.fn(() => + vi.fn().mockResolvedValue('mock-auth-token'), + ), })); vi.mock('./google-vertex-provider', () => ({ @@ -54,7 +56,7 @@ describe('google-vertex-provider-node', () => { }); }); - it('passes googleAuthOptions to generateAuthToken', async () => { + it('passes googleAuthOptions to createAuthTokenGenerator', async () => { createVertexNode({ googleAuthOptions: { scopes: ['https://www.googleapis.com/auth/cloud-platform'], @@ -65,9 +67,9 @@ describe('google-vertex-provider-node', () => { expect(createVertexOriginal).toHaveBeenCalledTimes(1); const passedOptions = vi.mocked(createVertexOriginal).mock.calls[0][0]; - await resolve(passedOptions?.headers); // call the headers function + await resolve(passedOptions?.headers); - expect(generateAuthToken).toHaveBeenCalledWith({ + expect(createAuthTokenGenerator).toHaveBeenCalledWith({ scopes: ['https://www.googleapis.com/auth/cloud-platform'], keyFile: 'path/to/key.json', }); @@ -83,6 +85,19 @@ describe('google-vertex-provider-node', () => { expect(passedOptions?.apiKey).toBe('test-api-key'); expect(passedOptions?.headers).toBeUndefined(); - expect(generateAuthToken).not.toHaveBeenCalled(); + expect(createAuthTokenGenerator).not.toHaveBeenCalled(); + }); + + it('creates the auth token generator once per provider instance', async () => { + createVertexNode({ project: 'test-project' }); + + expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1); + + const passedOptions = vi.mocked(createVertexOriginal).mock.calls[0][0]; + + await resolve(passedOptions?.headers); + await resolve(passedOptions?.headers); + + expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1); }); }); diff --git a/packages/google-vertex/src/google-vertex-provider-node.ts b/packages/google-vertex/src/google-vertex-provider-node.ts index a75f838a54d4..6cbff91768cf 100644 --- a/packages/google-vertex/src/google-vertex-provider-node.ts +++ b/packages/google-vertex/src/google-vertex-provider-node.ts @@ -1,6 +1,6 @@ import { loadOptionalSetting, resolve } from '@ai-sdk/provider-utils'; import { GoogleAuthOptions } from 'google-auth-library'; -import { generateAuthToken } from './google-vertex-auth-google-auth-library'; +import { createAuthTokenGenerator } from './google-vertex-auth-google-auth-library'; import { createVertex as createVertexOriginal, GoogleVertexProvider, @@ -31,12 +31,12 @@ export function createVertex( return createVertexOriginal(options); } + const generateAuthToken = createAuthTokenGenerator(options.googleAuthOptions); + return createVertexOriginal({ ...options, headers: async () => ({ - Authorization: `Bearer ${await generateAuthToken( - options.googleAuthOptions, - )}`, + Authorization: `Bearer ${await generateAuthToken()}`, ...(await resolve(options.headers)), }), }); diff --git a/packages/google-vertex/src/maas/google-vertex-maas-provider-node.test.ts b/packages/google-vertex/src/maas/google-vertex-maas-provider-node.test.ts index 08c46fd8756b..fa840efbe4e6 100644 --- a/packages/google-vertex/src/maas/google-vertex-maas-provider-node.test.ts +++ b/packages/google-vertex/src/maas/google-vertex-maas-provider-node.test.ts @@ -3,11 +3,15 @@ import { createVertexMaas, vertexMaas, } from './google-vertex-maas-provider-node'; -import { generateAuthToken } from '../google-vertex-auth-google-auth-library'; +import { createAuthTokenGenerator } from '../google-vertex-auth-google-auth-library'; + +const { generateAuthToken } = vi.hoisted(() => ({ + generateAuthToken: vi.fn(), +})); // Mock the imported modules vi.mock('../google-vertex-auth-google-auth-library', () => ({ - generateAuthToken: vi.fn(), + createAuthTokenGenerator: vi.fn(() => generateAuthToken), })); vi.mock('./google-vertex-maas-provider', () => ({ @@ -29,6 +33,7 @@ vi.mock('@ai-sdk/provider-utils', () => ({ describe('google-vertex-maas-provider-node', () => { beforeEach(() => { vi.clearAllMocks(); + generateAuthToken.mockReset(); }); it('should create provider with auth wrapper', async () => { @@ -66,7 +71,8 @@ describe('google-vertex-maas-provider-node', () => { headers: { 'Content-Type': 'application/json' }, }); - expect(generateAuthToken).toHaveBeenCalledWith(undefined); + expect(createAuthTokenGenerator).toHaveBeenCalledWith(undefined); + expect(generateAuthToken).toHaveBeenCalledWith(); expect(mockFetch).toHaveBeenCalledWith('https://example.com/test', { method: 'POST', @@ -77,7 +83,7 @@ describe('google-vertex-maas-provider-node', () => { }); }); - it('should pass googleAuthOptions to generateAuthToken', async () => { + it('should pass googleAuthOptions to createAuthTokenGenerator', async () => { vi.mocked(generateAuthToken).mockResolvedValue('mock-auth-token'); const mockFetch = vi.fn().mockResolvedValue(new Response('{}')); global.fetch = mockFetch; @@ -91,7 +97,8 @@ describe('google-vertex-maas-provider-node', () => { const customFetch = (provider as any).fetch; await customFetch('https://example.com/test', {}); - expect(generateAuthToken).toHaveBeenCalledWith(googleAuthOptions); + expect(createAuthTokenGenerator).toHaveBeenCalledWith(googleAuthOptions); + expect(generateAuthToken).toHaveBeenCalledWith(); }); it('should merge custom headers with auth header', async () => { @@ -200,6 +207,25 @@ describe('google-vertex-maas-provider-node', () => { expect(typeof vertexMaas).toBe('function'); }); + it('creates the auth token generator once per provider instance', async () => { + vi.mocked(generateAuthToken).mockResolvedValue('mock-auth-token'); + const mockFetch = vi.fn().mockResolvedValue(new Response('{}')); + global.fetch = mockFetch; + + const provider = createVertexMaas({ + project: 'test-project', + }); + + expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1); + + const customFetch = (provider as any).fetch; + await customFetch('https://example.com/test-1', {}); + await customFetch('https://example.com/test-2', {}); + + expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1); + expect(generateAuthToken).toHaveBeenCalledTimes(2); + }); + it('should set headers to undefined in base provider call', async () => { createVertexMaas({ project: 'test-project', diff --git a/packages/google-vertex/src/maas/google-vertex-maas-provider-node.ts b/packages/google-vertex/src/maas/google-vertex-maas-provider-node.ts index 65376dcdb1e5..eef31b114cc6 100644 --- a/packages/google-vertex/src/maas/google-vertex-maas-provider-node.ts +++ b/packages/google-vertex/src/maas/google-vertex-maas-provider-node.ts @@ -1,6 +1,6 @@ import { FetchFunction, resolve } from '@ai-sdk/provider-utils'; import { GoogleAuthOptions } from 'google-auth-library'; -import { generateAuthToken } from '../google-vertex-auth-google-auth-library'; +import { createAuthTokenGenerator } from '../google-vertex-auth-google-auth-library'; import { createVertexMaas as createVertexMaasOriginal, GoogleVertexMaasProvider, @@ -29,9 +29,11 @@ export interface GoogleVertexMaasProviderSettings extends GoogleVertexMaasProvid export function createVertexMaas( options: GoogleVertexMaasProviderSettings = {}, ): GoogleVertexMaasProvider { + const generateAuthToken = createAuthTokenGenerator(options.googleAuthOptions); + // Create a custom fetch wrapper that adds auth headers const customFetch: FetchFunction = async (url, init) => { - const token = await generateAuthToken(options.googleAuthOptions); + const token = await generateAuthToken(); const resolvedHeaders = await resolve(options.headers); const authHeaders = { ...resolvedHeaders,