Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .changeset/tall-peaches-tease.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
'@ai-sdk/google-vertex': patch
---

fix(provider/google-vertex): avoid recreating Node GoogleAuth clients for repeated requests

Create Google auth token generators per provider instance instead of using a
module-level shared `GoogleAuth` cache. This avoids unnecessary `GoogleAuth`
recreation when `googleAuthOptions` are omitted or when multiple provider
instances use equivalent auth settings.
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { resolve } from '@ai-sdk/provider-utils';
import { createVertexAnthropic as createVertexAnthropicOriginal } from './google-vertex-anthropic-provider';
import { createVertexAnthropic as createVertexAnthropicNode } from './google-vertex-anthropic-provider-node';
import { generateAuthToken } from '../google-vertex-auth-google-auth-library';
import { createAuthTokenGenerator } from '../google-vertex-auth-google-auth-library';
import { describe, beforeEach, expect, it, vi } from 'vitest';

// Mock the imported modules
vi.mock('../google-vertex-auth-google-auth-library', () => ({
generateAuthToken: vi.fn().mockResolvedValue('mock-auth-token'),
createAuthTokenGenerator: vi.fn(() =>
vi.fn().mockResolvedValue('mock-auth-token'),
),
}));

vi.mock('./google-vertex-anthropic-provider', () => ({
Expand Down Expand Up @@ -51,7 +53,7 @@ describe('google-vertex-anthropic-provider-node', () => {
});
});

it('passes googleAuthOptions to generateAuthToken', async () => {
it('passes googleAuthOptions to createAuthTokenGenerator', async () => {
createVertexAnthropicNode({
googleAuthOptions: {
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
Expand All @@ -63,11 +65,25 @@ describe('google-vertex-anthropic-provider-node', () => {
const passedOptions = vi.mocked(createVertexAnthropicOriginal).mock
.calls[0][0];

await resolve(passedOptions?.headers); // call the headers function
await resolve(passedOptions?.headers);

expect(generateAuthToken).toHaveBeenCalledWith({
expect(createAuthTokenGenerator).toHaveBeenCalledWith({
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
keyFile: 'path/to/key.json',
});
});

it('creates the auth token generator once per provider instance', async () => {
createVertexAnthropicNode({ project: 'test-project' });

expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1);

const passedOptions = vi.mocked(createVertexAnthropicOriginal).mock
.calls[0][0];

await resolve(passedOptions?.headers);
await resolve(passedOptions?.headers);

expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1);
});
});
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { resolve } from '@ai-sdk/provider-utils';
import { GoogleAuthOptions } from 'google-auth-library';
import { generateAuthToken } from '../google-vertex-auth-google-auth-library';
import { createAuthTokenGenerator } from '../google-vertex-auth-google-auth-library';
import {
createVertexAnthropic as createVertexAnthropicOriginal,
GoogleVertexAnthropicProvider,
Expand All @@ -22,12 +22,12 @@ export interface GoogleVertexAnthropicProviderSettings extends GoogleVertexAnthr
export function createVertexAnthropic(
options: GoogleVertexAnthropicProviderSettings = {},
): GoogleVertexAnthropicProvider {
const generateAuthToken = createAuthTokenGenerator(options.googleAuthOptions);

return createVertexAnthropicOriginal({
...options,
headers: async () => ({
Authorization: `Bearer ${await generateAuthToken(
options.googleAuthOptions,
)}`,
Authorization: `Bearer ${await generateAuthToken()}`,
...(await resolve(options.headers)),
}),
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,58 +1,80 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import {
generateAuthToken,
_resetAuthInstance,
} from './google-vertex-auth-google-auth-library';
import { createAuthTokenGenerator } from './google-vertex-auth-google-auth-library';
import { GoogleAuth } from 'google-auth-library';

const getAccessToken = vi.fn().mockResolvedValue({ token: 'mocked-token' });
const getClient = vi.fn().mockResolvedValue({ getAccessToken });

vi.mock('google-auth-library', () => {
return {
GoogleAuth: vi.fn(function () {
return {
getClient: vi.fn().mockResolvedValue({
getAccessToken: vi.fn().mockResolvedValue({ token: 'mocked-token' }),
}),
getClient,
};
}),
};
});

describe('generateAuthToken', () => {
describe('createAuthTokenGenerator', () => {
beforeEach(() => {
vi.clearAllMocks();
_resetAuthInstance();
getAccessToken.mockResolvedValue({ token: 'mocked-token' });
});

it('should generate a valid auth token', async () => {
const generateAuthToken = createAuthTokenGenerator();

const token = await generateAuthToken();

expect(token).toBe('mocked-token');
});

it('should return null if no token is received', async () => {
// Reset the mock completely
vi.mocked(GoogleAuth).mockReset();

// Create a new mock implementation
vi.mocked(GoogleAuth).mockImplementation(function () {
return {
getClient: vi.fn().mockResolvedValue({
getAccessToken: vi.fn().mockResolvedValue({ token: null }),
}),
isGCE: vi.fn(),
} as unknown as GoogleAuth;
});
getAccessToken.mockResolvedValueOnce({ token: null });

const generateAuthToken = createAuthTokenGenerator();
const token = await generateAuthToken();

expect(token).toBeNull();
});

it('should create new auth instance with provided options', async () => {
const options = { keyFile: 'test-key.json' };
await generateAuthToken(options);
it('should create a new auth instance with provided options', () => {
createAuthTokenGenerator({ keyFile: 'test-key.json' });

expect(GoogleAuth).toHaveBeenCalledWith({
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
keyFile: 'test-key.json',
});
});

it('should create only one GoogleAuth instance for repeated calls', async () => {
const generateAuthToken = createAuthTokenGenerator();

await generateAuthToken();
await generateAuthToken();

expect(GoogleAuth).toHaveBeenCalledTimes(1);
});

it('should create independent generators for separate option sets', async () => {
const generateFirstAuthToken = createAuthTokenGenerator({
keyFile: 'first-key.json',
});
const generateSecondAuthToken = createAuthTokenGenerator({
keyFile: 'second-key.json',
});

await generateFirstAuthToken();
await generateSecondAuthToken();

expect(GoogleAuth).toHaveBeenCalledTimes(2);
expect(GoogleAuth).toHaveBeenNthCalledWith(1, {
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
keyFile: 'first-key.json',
});
expect(GoogleAuth).toHaveBeenNthCalledWith(2, {
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
keyFile: 'second-key.json',
});
});
});
Original file line number Diff line number Diff line change
@@ -1,27 +1,14 @@
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';

let authInstance: GoogleAuth | null = null;
let authOptions: GoogleAuthOptions | null = null;

function getAuth(options: GoogleAuthOptions) {
if (!authInstance || options !== authOptions) {
authInstance = new GoogleAuth({
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
...options,
});
authOptions = options;
}
return authInstance;
}

export async function generateAuthToken(options?: GoogleAuthOptions) {
const auth = getAuth(options || {});
const client = await auth.getClient();
const token = await client.getAccessToken();
return token?.token || null;
}

// For testing purposes only
export function _resetAuthInstance() {
authInstance = null;
export function createAuthTokenGenerator(options?: GoogleAuthOptions) {
const auth = new GoogleAuth({
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
...options,
});

return async function generateAuthToken() {
const client = await auth.getClient();
const token = await client.getAccessToken();
return token?.token ?? null;
};
}
27 changes: 21 additions & 6 deletions packages/google-vertex/src/google-vertex-provider-node.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import { resolve } from '@ai-sdk/provider-utils';
import { createVertex as createVertexOriginal } from './google-vertex-provider';
import { createVertex as createVertexNode } from './google-vertex-provider-node';
import { generateAuthToken } from './google-vertex-auth-google-auth-library';
import { createAuthTokenGenerator } from './google-vertex-auth-google-auth-library';
import { describe, beforeEach, afterEach, expect, it, vi } from 'vitest';

// Mock the imported modules
vi.mock('./google-vertex-auth-google-auth-library', () => ({
generateAuthToken: vi.fn().mockResolvedValue('mock-auth-token'),
createAuthTokenGenerator: vi.fn(() =>
vi.fn().mockResolvedValue('mock-auth-token'),
),
}));

vi.mock('./google-vertex-provider', () => ({
Expand Down Expand Up @@ -54,7 +56,7 @@ describe('google-vertex-provider-node', () => {
});
});

it('passes googleAuthOptions to generateAuthToken', async () => {
it('passes googleAuthOptions to createAuthTokenGenerator', async () => {
createVertexNode({
googleAuthOptions: {
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
Expand All @@ -65,9 +67,9 @@ describe('google-vertex-provider-node', () => {
expect(createVertexOriginal).toHaveBeenCalledTimes(1);
const passedOptions = vi.mocked(createVertexOriginal).mock.calls[0][0];

await resolve(passedOptions?.headers); // call the headers function
await resolve(passedOptions?.headers);

expect(generateAuthToken).toHaveBeenCalledWith({
expect(createAuthTokenGenerator).toHaveBeenCalledWith({
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
keyFile: 'path/to/key.json',
});
Expand All @@ -83,6 +85,19 @@ describe('google-vertex-provider-node', () => {

expect(passedOptions?.apiKey).toBe('test-api-key');
expect(passedOptions?.headers).toBeUndefined();
expect(generateAuthToken).not.toHaveBeenCalled();
expect(createAuthTokenGenerator).not.toHaveBeenCalled();
});

it('creates the auth token generator once per provider instance', async () => {
createVertexNode({ project: 'test-project' });

expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1);

const passedOptions = vi.mocked(createVertexOriginal).mock.calls[0][0];

await resolve(passedOptions?.headers);
await resolve(passedOptions?.headers);

expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1);
});
});
8 changes: 4 additions & 4 deletions packages/google-vertex/src/google-vertex-provider-node.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { loadOptionalSetting, resolve } from '@ai-sdk/provider-utils';
import { GoogleAuthOptions } from 'google-auth-library';
import { generateAuthToken } from './google-vertex-auth-google-auth-library';
import { createAuthTokenGenerator } from './google-vertex-auth-google-auth-library';
import {
createVertex as createVertexOriginal,
GoogleVertexProvider,
Expand Down Expand Up @@ -31,12 +31,12 @@ export function createVertex(
return createVertexOriginal(options);
}

const generateAuthToken = createAuthTokenGenerator(options.googleAuthOptions);

return createVertexOriginal({
...options,
headers: async () => ({
Authorization: `Bearer ${await generateAuthToken(
options.googleAuthOptions,
)}`,
Authorization: `Bearer ${await generateAuthToken()}`,
...(await resolve(options.headers)),
}),
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ import {
createVertexMaas,
vertexMaas,
} from './google-vertex-maas-provider-node';
import { generateAuthToken } from '../google-vertex-auth-google-auth-library';
import { createAuthTokenGenerator } from '../google-vertex-auth-google-auth-library';

const { generateAuthToken } = vi.hoisted(() => ({
generateAuthToken: vi.fn(),
}));

// Mock the imported modules
vi.mock('../google-vertex-auth-google-auth-library', () => ({
generateAuthToken: vi.fn(),
createAuthTokenGenerator: vi.fn(() => generateAuthToken),
}));

vi.mock('./google-vertex-maas-provider', () => ({
Expand All @@ -29,6 +33,7 @@ vi.mock('@ai-sdk/provider-utils', () => ({
describe('google-vertex-maas-provider-node', () => {
beforeEach(() => {
vi.clearAllMocks();
generateAuthToken.mockReset();
});

it('should create provider with auth wrapper', async () => {
Expand Down Expand Up @@ -66,7 +71,8 @@ describe('google-vertex-maas-provider-node', () => {
headers: { 'Content-Type': 'application/json' },
});

expect(generateAuthToken).toHaveBeenCalledWith(undefined);
expect(createAuthTokenGenerator).toHaveBeenCalledWith(undefined);
expect(generateAuthToken).toHaveBeenCalledWith();

expect(mockFetch).toHaveBeenCalledWith('https://example.com/test', {
method: 'POST',
Expand All @@ -77,7 +83,7 @@ describe('google-vertex-maas-provider-node', () => {
});
});

it('should pass googleAuthOptions to generateAuthToken', async () => {
it('should pass googleAuthOptions to createAuthTokenGenerator', async () => {
vi.mocked(generateAuthToken).mockResolvedValue('mock-auth-token');
const mockFetch = vi.fn().mockResolvedValue(new Response('{}'));
global.fetch = mockFetch;
Expand All @@ -91,7 +97,8 @@ describe('google-vertex-maas-provider-node', () => {
const customFetch = (provider as any).fetch;
await customFetch('https://example.com/test', {});

expect(generateAuthToken).toHaveBeenCalledWith(googleAuthOptions);
expect(createAuthTokenGenerator).toHaveBeenCalledWith(googleAuthOptions);
expect(generateAuthToken).toHaveBeenCalledWith();
});

it('should merge custom headers with auth header', async () => {
Expand Down Expand Up @@ -200,6 +207,25 @@ describe('google-vertex-maas-provider-node', () => {
expect(typeof vertexMaas).toBe('function');
});

it('creates the auth token generator once per provider instance', async () => {
vi.mocked(generateAuthToken).mockResolvedValue('mock-auth-token');
const mockFetch = vi.fn().mockResolvedValue(new Response('{}'));
global.fetch = mockFetch;

const provider = createVertexMaas({
project: 'test-project',
});

expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1);

const customFetch = (provider as any).fetch;
await customFetch('https://example.com/test-1', {});
await customFetch('https://example.com/test-2', {});

expect(createAuthTokenGenerator).toHaveBeenCalledTimes(1);
expect(generateAuthToken).toHaveBeenCalledTimes(2);
});

it('should set headers to undefined in base provider call', async () => {
createVertexMaas({
project: 'test-project',
Expand Down
Loading
Loading