Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,7 @@
import { inspect } from 'node:util';
import * as path from 'node:path';
import { readdir, stat } from 'node:fs/promises';
import {
Connection,
generateApiName,
Lifecycle,
Logger,
Messages,
SfError,
SfProject,
} from '@salesforce/core';
import { Connection, generateApiName, Lifecycle, Logger, Messages, SfError, SfProject } from '@salesforce/core';
import { ComponentSetBuilder } from '@salesforce/source-deploy-retrieve';
import { Duration } from '@salesforce/kit';
import {
Expand Down Expand Up @@ -108,16 +100,15 @@ export class Agent {
public static async init(
options: ProductionAgentOptions | ScriptAgentOptions
): Promise<ScriptAgent | ProductionAgent> {
// Create ConnectionManager which handles JWT and standard connections
// ConnectionManager isolates JWT (for SFAP) and standard (for org) connections so
// the caller's connection is never mutated by JWT upgrades or auto-refresh.
const connectionManager = await ConnectionManager.create(options.connection);

// Type guard: check if it's ScriptAgentOptions by looking for 'aabName'
if ('aabName' in options) {
// TypeScript now knows this is ScriptAgentOptions
return new ScriptAgent({ ...options, connectionManager });
return new ScriptAgent(options, connectionManager);
} else {
// TypeScript now knows this is ProductionAgentOptions
const agent = new ProductionAgent({ ...options, connectionManager });
const agent = new ProductionAgent(options, connectionManager);
await agent.getBotMetadata();
return agent;
}
Expand Down
42 changes: 38 additions & 4 deletions src/agents/agentBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
import { readFile, readdir, cp, mkdir } from 'node:fs/promises';
import { join } from 'node:path';
import { SfError } from '@salesforce/core';
import { Connection, SfError } from '@salesforce/core';
import { AgentPreviewInterface, type AgentPreviewSendResponse, type PlannerResponse, PreviewMetadata } from '../types';
import { getHistoryDir, SessionHistoryBuffer, TranscriptEntry } from '../utils';
import { ConnectionManager } from '../connectionManager';
Expand All @@ -29,6 +29,20 @@ export abstract class AgentBase {
* The display name of the agent (user-friendly name, not API name)
*/
public name: string | undefined;
/**
* The standard org connection used for SOQL queries, tooling API calls, and metadata
* operations. This is distinct from the JWT-upgraded connection used for SFAP API
* calls (held internally by the connection manager).
*/
protected readonly connection: Connection;
/**
* Holds isolated JWT and standard connections. When a ConnectionManager is supplied
* by Agent.init() / Agent.create() / Agent.createSpec(), the agent gets fully
* isolated connections (the caller's connection is not mutated). When undefined
* (consumers instantiating an agent directly), the supplied connection is used as
* both the JWT and standard connection — preserving the pre-isolation behavior.
*/
protected readonly connectionManager: ConnectionManager | undefined;
protected sessionId: string | undefined;
protected historyDir: string | undefined;
protected historyBuffer: SessionHistoryBuffer | undefined;
Expand All @@ -37,14 +51,25 @@ export abstract class AgentBase {
protected planIds = new Set<string>();
public abstract preview: AgentPreviewInterface;

protected constructor(protected readonly connectionManager: ConnectionManager) {}
protected constructor(connection: Connection, connectionManager?: ConnectionManager) {
this.connectionManager = connectionManager;
this.connection = connectionManager ? connectionManager.getStandardConnection() : connection;
}

/**
* Restore the connection by refreshing the standard (non-JWT) connection.
* Refreshes the access token on the standard connection.
*
* Retained for backward compatibility. With the connection manager in place the
* caller's original Connection object is no longer mutated by agent operations,
* so this method only refreshes the internal standard connection.
*/
public async restoreConnection(): Promise<void> {
await this.connectionManager.refreshStandardConnection();
if (this.connectionManager) {
await this.connectionManager.refreshStandardConnection();
return;
}
delete this.connection.accessToken;
await this.connection.refreshAuth();
}

public setSessionId(sessionId: string): void {
Expand Down Expand Up @@ -80,6 +105,15 @@ export abstract class AgentBase {
}
}

/**
* Returns the connection to use for SFAP API calls (api.salesforce.com/einstein/ai-agent).
* When a ConnectionManager is in use this is the JWT-upgraded connection; otherwise it
* is the connection supplied at construction time.
*/
protected getJwtConnection(): Connection {
return this.connectionManager ? this.connectionManager.getJwtConnection() : this.connection;
}

/**
* Get all traces from the current session
* Reads traces from the session directory if available, otherwise fetches from API
Expand Down
83 changes: 32 additions & 51 deletions src/agents/productionAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {
SessionHistoryBuffer,
} from '../utils';
import { createTraceFlag, findTraceFlag, getDebugLog } from '../apexUtils';
import { ConnectionManager } from '../connectionManager';
import { AgentBase } from './agentBase';
Messages.importMessagesDirectory(__dirname);
const messages = Messages.loadMessages('@salesforce/agents', 'agents');
Expand All @@ -52,16 +53,8 @@ export class ProductionAgent extends AgentBase {
private apiName: string | undefined;
private readonly apiBase: string;

public constructor(options: ProductionAgentOptions) {
// ConnectionManager should be provided by Agent.init(), but fallback to creating one if needed
const connectionManager = options.connectionManager;
if (!connectionManager) {
throw SfError.create({
name: 'MissingConnectionManager',
message: 'ConnectionManager is required. Use Agent.init() to create ProductionAgent instances.',
});
}
super(connectionManager);
public constructor(private options: ProductionAgentOptions, connectionManager?: ConnectionManager) {
super(options.connection, connectionManager);
this.apiBase = 'https://api.salesforce.com/einstein/ai-agent/v1';
if (!options.apiNameOrId) {
throw messages.createError('missingAgentNameOrId');
Expand Down Expand Up @@ -90,8 +83,7 @@ export class ProductionAgent extends AgentBase {
'Id, IsDeleted, DeveloperName, MasterLabel, CreatedDate, CreatedById, LastModifiedDate, LastModifiedById, SystemModstamp, BotUserId, Description, Type, AgentType, AgentTemplate';
const botVersionFields =
'Id, Status, IsDeleted, BotDefinitionId, DeveloperName, CreatedDate, CreatedById, LastModifiedDate, LastModifiedById, SystemModstamp, VersionNumber, CopilotPrimaryLanguage, ToneType, CopilotSecondaryLanguages';
const standardConn = this.connectionManager.getStandardConnection();
this.botMetadata = await standardConn.singleRecordQuery<BotMetadata>(
this.botMetadata = await this.connection.singleRecordQuery<BotMetadata>(
`SELECT ${botDefinitionFields}, (SELECT ${botVersionFields} FROM BotVersions WHERE IsDeleted = false ORDER BY VersionNumber) FROM BotDefinition WHERE ${whereClause} LIMIT 1`
);
this.id = this.botMetadata.Id;
Expand Down Expand Up @@ -217,10 +209,9 @@ export class ProductionAgent extends AgentBase {
protected async handleApexDebuggingSetup(): Promise<void> {
const botMetadata = await this.getBotMetadata();
if (botMetadata.BotUserId) {
const standardConn = this.connectionManager.getStandardConnection();
const traceFlag = await findTraceFlag(standardConn, botMetadata.BotUserId);
const traceFlag = await findTraceFlag(this.connection, botMetadata.BotUserId);
if (!traceFlag) {
await createTraceFlag(standardConn, botMetadata.BotUserId);
await createTraceFlag(this.connection, botMetadata.BotUserId);
}
}
}
Expand Down Expand Up @@ -274,17 +265,14 @@ export class ProductionAgent extends AgentBase {
};
await logTurnToHistory(userEntry, ++this.turnCounter, this.historyDir, this.historyBuffer);

const response = await requestWithEndpointFallback<AgentPreviewSendResponse>(
this.connectionManager.getJwtConnection(),
{
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
}
);
const response = await requestWithEndpointFallback<AgentPreviewSendResponse>(this.getJwtConnection(), {
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
});

const planId = response.messages.at(0)!.planId;
this.planIds.add(planId);
Expand All @@ -309,8 +297,7 @@ export class ProductionAgent extends AgentBase {
await this.historyBuffer.flush();

if (this.apexDebugging && this.canApexDebug()) {
const standardConn = this.connectionManager.getStandardConnection();
const apexLog = await getDebugLog(standardConn, start, Date.now());
const apexLog = await getDebugLog(this.connection, start, Date.now());
if (apexLog) {
response.apexDebugLog = apexLog;
}
Expand All @@ -336,7 +323,7 @@ export class ProductionAgent extends AgentBase {
}

const url = `/connect/bot-versions/${botVersionMetadata.Id}/activation`;
const maybeMock = new MaybeMock(this.connectionManager.getJwtConnection());
const maybeMock = new MaybeMock(this.getJwtConnection());
const response = await maybeMock.request<BotActivationResponse>('POST', url, { status: desiredState });
if (response.success) {
const versionToUpdate = this.botMetadata!.BotVersions.records.find(
Expand Down Expand Up @@ -365,7 +352,7 @@ export class ProductionAgent extends AgentBase {
const body = {
externalSessionKey: randomUUID(),
instanceConfig: {
endpoint: this.connectionManager.getStandardConnection().instanceUrl,
endpoint: this.options.connection.instanceUrl,
},
streamingCapabilities: {
chunkTypes: ['Text'],
Expand All @@ -374,17 +361,14 @@ export class ProductionAgent extends AgentBase {
};

try {
const response = await requestWithEndpointFallback<AgentPreviewStartResponse>(
this.connectionManager.getJwtConnection(),
{
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
}
);
const response = await requestWithEndpointFallback<AgentPreviewStartResponse>(this.getJwtConnection(), {
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
});
this.sessionId = response.sessionId;

const agentId = this.id!;
Expand Down Expand Up @@ -439,16 +423,13 @@ export class ProductionAgent extends AgentBase {
const url = `${this.apiBase}/sessions/${this.sessionId}`;
try {
// https://developer.salesforce.com/docs/einstein/genai/guide/agent-api-examples.html#end-session
const response = await requestWithEndpointFallback<AgentPreviewEndResponse>(
this.connectionManager.getJwtConnection(),
{
method: 'DELETE',
url,
headers: {
'x-session-end-reason': reason,
},
}
);
const response = await requestWithEndpointFallback<AgentPreviewEndResponse>(this.getJwtConnection(), {
method: 'DELETE',
url,
headers: {
'x-session-end-reason': reason,
},
});

// Write end entry and flush buffer
if (this.historyDir && this.historyBuffer) {
Expand Down
54 changes: 20 additions & 34 deletions src/agents/scriptAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import {
import { getDebugLog } from '../apexUtils';
import { generateAgentScript } from '../templates/agentScriptTemplate';
import { applyStringReplacementsToAgent } from '../stringReplacements';
import { ConnectionManager } from '../connectionManager';
import { ScriptAgentPublisher } from './scriptAgentPublisher';
import { AgentBase } from './agentBase';

Expand All @@ -63,16 +64,8 @@ export class ScriptAgent extends AgentBase {
private readonly aabDirectory: string;
private readonly metaContent: string;
private readonly agentFilePath: string;
public constructor(private options: ScriptAgentOptions) {
// ConnectionManager should be provided by Agent.init(), but fallback to creating one if needed
const connectionManager = options.connectionManager;
if (!connectionManager) {
throw SfError.create({
name: 'MissingConnectionManager',
message: 'ConnectionManager is required. Use Agent.init() to create ScriptAgent instances.',
});
}
super(connectionManager);
public constructor(private options: ScriptAgentOptions, connectionManager?: ConnectionManager) {
super(options.connection, connectionManager);
this.options = options;
this.apiBase = 'https://api.salesforce.com/einstein/ai-agent';

Expand Down Expand Up @@ -177,7 +170,7 @@ export class ScriptAgent extends AgentBase {
}

public async getTrace(planId: string): Promise<PlannerResponse> {
return requestWithEndpointFallback<PlannerResponse>(this.connectionManager.getJwtConnection(), {
return requestWithEndpointFallback<PlannerResponse>(this.getJwtConnection(), {
method: 'GET',
url: `${this.apiBase}/v1.1/preview/sessions/${this.sessionId!}/plans/${planId}`,
headers: {
Expand Down Expand Up @@ -213,7 +206,7 @@ export class ScriptAgent extends AgentBase {

try {
const response = await requestWithEndpointFallback<CompileAgentScriptResponse>(
this.connectionManager.getJwtConnection(),
this.getJwtConnection(),
{
method: 'POST',
url,
Expand Down Expand Up @@ -279,10 +272,11 @@ export class ScriptAgent extends AgentBase {
}

const publisher = new ScriptAgentPublisher(
this.connectionManager,
this.options.connection,
this.options.project,
this.agentJson!,
skipMetadataRetrieve
skipMetadataRetrieve,
this.connectionManager
);
return publisher.publishAgentJson();
}
Expand Down Expand Up @@ -403,17 +397,14 @@ export class ScriptAgent extends AgentBase {
this.historyBuffer
);

const response = await requestWithEndpointFallback<AgentPreviewSendResponse>(
this.connectionManager.getJwtConnection(),
{
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
}
);
const response = await requestWithEndpointFallback<AgentPreviewSendResponse>(this.getJwtConnection(), {
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
});

const planId = response.messages.at(0)!.planId;
this.planIds.add(planId);
Expand Down Expand Up @@ -443,9 +434,7 @@ export class ScriptAgent extends AgentBase {
await this.historyBuffer.flush();

if (this.apexDebugging && this.canApexDebug()) {
// Use standard connection for tooling API query to avoid JWT token clobbering
const standardConn = this.connectionManager.getStandardConnection();
const apexLog = await getDebugLog(standardConn, start, Date.now());
const apexLog = await getDebugLog(this.connection, start, Date.now());
if (apexLog) {
response.apexDebugLog = apexLog;
}
Expand All @@ -457,7 +446,6 @@ export class ScriptAgent extends AgentBase {
}
}


private setMockMode(mockMode: 'Mock' | 'Live Test'): void {
this.mockMode = mockMode;
}
Expand All @@ -484,11 +472,9 @@ export class ScriptAgent extends AgentBase {
}

// send bypassUser=false when the compiledAgent.globalConfiguration.defaultAgentUser is INVALID
// Use standard connection for SOQL query to avoid JWT token clobbering
const standardConn = this.connectionManager.getStandardConnection();
let bypassUser =
(
await standardConn.query<{ Id: string }>(
await this.connection.query<{ Id: string }>(
`SELECT Id FROM USER WHERE username='${this.agentJson.globalConfiguration.defaultAgentUser}'`
)
).totalSize === 1;
Expand Down Expand Up @@ -523,9 +509,9 @@ export class ScriptAgent extends AgentBase {
void Lifecycle.getInstance().emit('agents:simulation-starting', {});

let response: AgentPreviewStartResponse;
try {
try {
response = await requestWithEndpointFallback<AgentPreviewStartResponse>(
this.connectionManager.getJwtConnection(),
this.getJwtConnection(),
{
method: 'POST',
url: `${this.apiBase}/v1.1/preview/sessions`,
Expand Down
Loading
Loading