Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import { inspect } from 'node:util';
import * as path from 'node:path';
import { readdir, stat } from 'node:fs/promises';
import {
AuthInfo,
Connection,
generateApiName,
Lifecycle,
Expand All @@ -43,9 +42,10 @@ import {
ScriptAgentOptions,
} from './types';
import { MaybeMock } from './maybe-mock';
import { decodeHtmlEntities, findLocalAgents, useNamedUserJwt } from './utils';
import { decodeHtmlEntities, findLocalAgents } from './utils';
import { ScriptAgent } from './agents/scriptAgent';
import { ProductionAgent } from './agents/productionAgent';
import { ConnectionManager } from './connectionManager';

/** Instance type returned from Agent.init(); has setSessionId, getHistoryDir, preview, etc. */
export type AgentInstance = ScriptAgent | ProductionAgent;
Expand Down Expand Up @@ -108,24 +108,16 @@ export class Agent {
public static async init(
options: ProductionAgentOptions | ScriptAgentOptions
): Promise<ScriptAgent | ProductionAgent> {
const username = options.connection.getUsername();

// Create a fresh connection instance for agent operations
// This ensures we don't modify the original connection passed in
// The original connection remains unchanged and can be used for other operations, mid agent-operation
const authInfo = await AuthInfo.create({ username });
const isolatedConnection = await Connection.create({ authInfo });

// Upgrade the isolated connection with JWT
const jwtConnection = await useNamedUserJwt(isolatedConnection);
// Create ConnectionManager which handles JWT and standard connections
const connectionManager = await ConnectionManager.create(options.connection);

// Type guard: check if it's ScriptAgentOptions by looking for 'aabName'
if ('aabName' in options) {
// TypeScript now knows this is ScriptAgentOptions
return new ScriptAgent({ ...options, connection: jwtConnection });
return new ScriptAgent({ ...options, connectionManager });
} else {
// TypeScript now knows this is ProductionAgentOptions
const agent = new ProductionAgent({ ...options, connection: jwtConnection });
const agent = new ProductionAgent({ ...options, connectionManager });
await agent.getBotMetadata();
return agent;
}
Expand Down Expand Up @@ -246,7 +238,11 @@ export class Agent {
config: AgentCreateConfig
): Promise<AgentCreateResponse> {
const url = '/connect/ai-assist/create-agent';
const maybeMock = new MaybeMock(connection);

// Create ConnectionManager to get JWT connection for SFAP API calls
const connectionManager = await ConnectionManager.create(connection);
const jwtConnection = connectionManager.getJwtConnection();
const maybeMock = new MaybeMock(jwtConnection);

// When previewing agent creation just return the response.
if (!config.saveAgent) {
Expand All @@ -272,19 +268,20 @@ export class Agent {
if (response.isSuccess) {
await Lifecycle.getInstance().emit(AgentCreateLifecycleStages.Retrieving, {});
const defaultPackagePath = project.getDefaultPackage().path ?? 'force-app';
const standardConnection = connectionManager.getStandardConnection();
try {
const cs = await ComponentSetBuilder.build({
metadata: {
metadataEntries: [`Agent:${config.agentSettings.agentApiName}`],
directoryPaths: [defaultPackagePath],
},
org: {
username: connection.getUsername() as string,
username: standardConnection.getUsername() as string,
exclude: [],
},
});
const retrieve = await cs.retrieve({
usernameOrConnection: connection,
usernameOrConnection: standardConnection,
merge: true,
format: 'source',
output: path.resolve(project.getPath(), defaultPackagePath),
Expand Down Expand Up @@ -324,7 +321,10 @@ export class Agent {
* @returns the agent job spec
*/
public static async createSpec(connection: Connection, config: AgentJobSpecCreateConfig): Promise<AgentJobSpec> {
const maybeMock = new MaybeMock(connection);
// Create ConnectionManager to get JWT connection for SFAP API calls
const connectionManager = await ConnectionManager.create(connection);
const jwtConnection = connectionManager.getJwtConnection();
const maybeMock = new MaybeMock(jwtConnection);
verifyAgentSpecConfig(config);

const url = '/connect/ai-assist/draft-agent-topics';
Expand Down
10 changes: 3 additions & 7 deletions src/agents/agentBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
*/
import { readFile, readdir, cp, mkdir } from 'node:fs/promises';
import { join } from 'node:path';
import { Connection, SfError } from '@salesforce/core';
import { SfError } from '@salesforce/core';
import { AgentPreviewInterface, type AgentPreviewSendResponse, type PlannerResponse, PreviewMetadata } from '../types';
import { getHistoryDir, SessionHistoryBuffer, TranscriptEntry } from '../utils';
import { ConnectionManager } from '../connectionManager';

/**
* Abstract base class for agent preview functionality.
Expand All @@ -36,12 +37,7 @@ export abstract class AgentBase {
protected planIds = new Set<string>();
public abstract preview: AgentPreviewInterface;

protected constructor(protected readonly connection: Connection) {}

public async restoreConnection(): Promise<void> {
Comment thread
shetzel marked this conversation as resolved.
delete this.connection.accessToken;
await this.connection.refreshAuth();
}
protected constructor(protected readonly connectionManager: ConnectionManager) {}

public setSessionId(sessionId: string): void {
this.sessionId = sessionId;
Expand Down
82 changes: 51 additions & 31 deletions src/agents/productionAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,16 @@ export class ProductionAgent extends AgentBase {
private apiName: string | undefined;
private readonly apiBase: string;

public constructor(private options: ProductionAgentOptions) {
super(options.connection);
public constructor(options: ProductionAgentOptions) {
// ConnectionManager should be provided by Agent.init(), but fallback to creating one if needed
const connectionManager = options.connectionManager;
if (!connectionManager) {
throw SfError.create({
name: 'MissingConnectionManager',
message: 'ConnectionManager is required. Use Agent.init() to create ProductionAgent instances.',
});
}
super(connectionManager);
this.apiBase = 'https://api.salesforce.com/einstein/ai-agent/v1';
if (!options.apiNameOrId) {
throw messages.createError('missingAgentNameOrId');
Expand Down Expand Up @@ -82,7 +90,8 @@ export class ProductionAgent extends AgentBase {
'Id, IsDeleted, DeveloperName, MasterLabel, CreatedDate, CreatedById, LastModifiedDate, LastModifiedById, SystemModstamp, BotUserId, Description, Type, AgentType, AgentTemplate';
const botVersionFields =
'Id, Status, IsDeleted, BotDefinitionId, DeveloperName, CreatedDate, CreatedById, LastModifiedDate, LastModifiedById, SystemModstamp, VersionNumber, CopilotPrimaryLanguage, ToneType, CopilotSecondaryLanguages';
this.botMetadata = await this.connection.singleRecordQuery<BotMetadata>(
const standardConn = this.connectionManager.getStandardConnection();
this.botMetadata = await standardConn.singleRecordQuery<BotMetadata>(
`SELECT ${botDefinitionFields}, (SELECT ${botVersionFields} FROM BotVersions WHERE IsDeleted = false ORDER BY VersionNumber) FROM BotDefinition WHERE ${whereClause} LIMIT 1`
);
this.id = this.botMetadata.Id;
Expand Down Expand Up @@ -208,9 +217,10 @@ export class ProductionAgent extends AgentBase {
protected async handleApexDebuggingSetup(): Promise<void> {
const botMetadata = await this.getBotMetadata();
if (botMetadata.BotUserId) {
const traceFlag = await findTraceFlag(this.connection, botMetadata.BotUserId);
const standardConn = this.connectionManager.getStandardConnection();
const traceFlag = await findTraceFlag(standardConn, botMetadata.BotUserId);
if (!traceFlag) {
await createTraceFlag(this.connection, botMetadata.BotUserId);
await createTraceFlag(standardConn, botMetadata.BotUserId);
}
}
}
Expand Down Expand Up @@ -264,14 +274,17 @@ export class ProductionAgent extends AgentBase {
};
await logTurnToHistory(userEntry, ++this.turnCounter, this.historyDir, this.historyBuffer);

const response = await requestWithEndpointFallback<AgentPreviewSendResponse>(this.connection, {
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
});
const response = await requestWithEndpointFallback<AgentPreviewSendResponse>(
this.connectionManager.getJwtConnection(),
{
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
}
);

const planId = response.messages.at(0)!.planId;
this.planIds.add(planId);
Expand All @@ -296,7 +309,8 @@ export class ProductionAgent extends AgentBase {
await this.historyBuffer.flush();

if (this.apexDebugging && this.canApexDebug()) {
const apexLog = await getDebugLog(this.connection, start, Date.now());
const standardConn = this.connectionManager.getStandardConnection();
const apexLog = await getDebugLog(standardConn, start, Date.now());
if (apexLog) {
response.apexDebugLog = apexLog;
}
Expand All @@ -322,7 +336,7 @@ export class ProductionAgent extends AgentBase {
}

const url = `/connect/bot-versions/${botVersionMetadata.Id}/activation`;
const maybeMock = new MaybeMock(this.connection);
const maybeMock = new MaybeMock(this.connectionManager.getJwtConnection());
const response = await maybeMock.request<BotActivationResponse>('POST', url, { status: desiredState });
if (response.success) {
const versionToUpdate = this.botMetadata!.BotVersions.records.find(
Expand Down Expand Up @@ -351,7 +365,7 @@ export class ProductionAgent extends AgentBase {
const body = {
externalSessionKey: randomUUID(),
instanceConfig: {
endpoint: this.options.connection.instanceUrl,
endpoint: this.connectionManager.getStandardConnection().instanceUrl,
},
streamingCapabilities: {
chunkTypes: ['Text'],
Expand All @@ -360,14 +374,17 @@ export class ProductionAgent extends AgentBase {
};

try {
const response = await requestWithEndpointFallback<AgentPreviewStartResponse>(this.connection, {
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
});
const response = await requestWithEndpointFallback<AgentPreviewStartResponse>(
this.connectionManager.getJwtConnection(),
{
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
}
);
this.sessionId = response.sessionId;

const agentId = this.id!;
Expand Down Expand Up @@ -422,13 +439,16 @@ export class ProductionAgent extends AgentBase {
const url = `${this.apiBase}/sessions/${this.sessionId}`;
try {
// https://developer.salesforce.com/docs/einstein/genai/guide/agent-api-examples.html#end-session
const response = await requestWithEndpointFallback<AgentPreviewEndResponse>(this.connection, {
method: 'DELETE',
url,
headers: {
'x-session-end-reason': reason,
},
});
const response = await requestWithEndpointFallback<AgentPreviewEndResponse>(
this.connectionManager.getJwtConnection(),
{
method: 'DELETE',
url,
headers: {
'x-session-end-reason': reason,
},
}
);

// Write end entry and flush buffer
if (this.historyDir && this.historyBuffer) {
Expand Down
48 changes: 32 additions & 16 deletions src/agents/scriptAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,15 @@ export class ScriptAgent extends AgentBase {
private readonly metaContent: string;
private readonly agentFilePath: string;
public constructor(private options: ScriptAgentOptions) {
super(options.connection);
// ConnectionManager should be provided by Agent.init(), but fallback to creating one if needed
const connectionManager = options.connectionManager;
if (!connectionManager) {
throw SfError.create({
name: 'MissingConnectionManager',
message: 'ConnectionManager is required. Use Agent.init() to create ScriptAgent instances.',
});
}
super(connectionManager);
this.options = options;
this.apiBase = 'https://api.salesforce.com/einstein/ai-agent';

Expand Down Expand Up @@ -169,7 +177,7 @@ export class ScriptAgent extends AgentBase {
}

public async getTrace(planId: string): Promise<PlannerResponse> {
return requestWithEndpointFallback<PlannerResponse>(this.connection, {
return requestWithEndpointFallback<PlannerResponse>(this.connectionManager.getJwtConnection(), {
method: 'GET',
url: `${this.apiBase}/v1.1/preview/sessions/${this.sessionId!}/plans/${planId}`,
headers: {
Expand Down Expand Up @@ -205,7 +213,7 @@ export class ScriptAgent extends AgentBase {

try {
const response = await requestWithEndpointFallback<CompileAgentScriptResponse>(
this.connection,
this.connectionManager.getJwtConnection(),
{
method: 'POST',
url,
Expand Down Expand Up @@ -271,7 +279,7 @@ export class ScriptAgent extends AgentBase {
}

const publisher = new ScriptAgentPublisher(
this.connection,
this.connectionManager,
this.options.project,
this.agentJson!,
skipMetadataRetrieve
Expand Down Expand Up @@ -395,14 +403,17 @@ export class ScriptAgent extends AgentBase {
this.historyBuffer
);

const response = await requestWithEndpointFallback<AgentPreviewSendResponse>(this.connection, {
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
});
const response = await requestWithEndpointFallback<AgentPreviewSendResponse>(
this.connectionManager.getJwtConnection(),
{
method: 'POST',
url,
body: JSON.stringify(body),
headers: {
'x-client-name': 'afdx',
},
}
);

const planId = response.messages.at(0)!.planId;
this.planIds.add(planId);
Expand Down Expand Up @@ -432,7 +443,9 @@ export class ScriptAgent extends AgentBase {
await this.historyBuffer.flush();

if (this.apexDebugging && this.canApexDebug()) {
const apexLog = await getDebugLog(this.connection, start, Date.now());
// Use standard connection for tooling API query to avoid JWT token clobbering
const standardConn = this.connectionManager.getStandardConnection();
const apexLog = await getDebugLog(standardConn, start, Date.now());
if (apexLog) {
response.apexDebugLog = apexLog;
}
Expand All @@ -444,6 +457,7 @@ export class ScriptAgent extends AgentBase {
}
}


private setMockMode(mockMode: 'Mock' | 'Live Test'): void {
this.mockMode = mockMode;
}
Expand All @@ -470,9 +484,11 @@ export class ScriptAgent extends AgentBase {
}

// send bypassUser=false when the compiledAgent.globalConfiguration.defaultAgentUser is INVALID
// Use standard connection for SOQL query to avoid JWT token clobbering
const standardConn = this.connectionManager.getStandardConnection();
let bypassUser =
(
await this.connection.query(
await standardConn.query<{ Id: string }>(
`SELECT Id FROM USER WHERE username='${this.agentJson.globalConfiguration.defaultAgentUser}'`
)
).totalSize === 1;
Expand Down Expand Up @@ -507,9 +523,9 @@ export class ScriptAgent extends AgentBase {
void Lifecycle.getInstance().emit('agents:simulation-starting', {});

let response: AgentPreviewStartResponse;
try {
try {
response = await requestWithEndpointFallback<AgentPreviewStartResponse>(
this.connection,
this.connectionManager.getJwtConnection(),
{
method: 'POST',
url: `${this.apiBase}/v1.1/preview/sessions`,
Expand Down
Loading
Loading