diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_endpoint.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_endpoint.go new file mode 100644 index 00000000000..f6fceb434d1 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_endpoint.go @@ -0,0 +1,186 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "fmt" + "net/url" + "regexp" + "strings" + + "azureaiagent/internal/exterrors" + "azureaiagent/internal/pkg/agents/agent_api" + "azureaiagent/internal/pkg/agents/agent_yaml" +) + +// agentEndpointHostSuffix is the required Foundry host suffix for endpoint URLs. +const agentEndpointHostSuffix = ".services.ai.azure.com" + +// agentEndpointHint is the suggestion appended to most --agent-endpoint validation errors. +// `azd ai agent show` persistently prints the agent endpoint URL, so it's the right +// thing to point users at any time after a deploy. +const agentEndpointHint = "run `azd ai agent show` to see the agent endpoint URL" + +// agentEndpointPathRegex matches the full Foundry agent-endpoint path. Captures: +// +// [1] project name (URL-escaped), +// [2] agent name (URL-escaped), +// [3] protocol tail ("invocations" or "openai/responses"). +var agentEndpointPathRegex = regexp.MustCompile( + `^/api/projects/([^/]+)/agents/([^/]+)/endpoint/protocols/(invocations|openai/responses)/?$`, +) + +// parsedAgentEndpoint describes a deployed agent invocation endpoint. +type parsedAgentEndpoint struct { + // ProjectEndpoint is the Foundry project root: https://.services.ai.azure.com/api/projects/. + ProjectEndpoint string + AgentName string + Protocol agent_api.AgentProtocol + // APIVersion is the api-version query parameter from the URL, or empty if absent. + APIVersion string +} + +// parseAgentEndpoint parses the full agent invocation URL printed by `azd ai agent show`. +// +// Accepted shapes: +// +// https://.services.ai.azure.com/api/projects//agents//endpoint/protocols/invocations[?api-version=…] +// https://.services.ai.azure.com/api/projects//agents//endpoint/protocols/openai/responses[?api-version=…] +// +// The host must be a `*.services.ai.azure.com` Foundry host. The path must include the +// protocol-specific suffix; the protocol is derived from the URL. +func parseAgentEndpoint(rawURL string) (*parsedAgentEndpoint, error) { + if strings.TrimSpace(rawURL) == "" { + return nil, exterrors.Validation( + exterrors.CodeInvalidParameter, + "--agent-endpoint requires a non-empty URL", + agentEndpointHint, + ) + } + + u, err := url.Parse(rawURL) + if err != nil { + return nil, exterrors.Validation( + exterrors.CodeInvalidParameter, + fmt.Sprintf("invalid --agent-endpoint URL: %v", err), + agentEndpointHint, + ) + } + + if !strings.EqualFold(u.Scheme, "https") { + return nil, exterrors.Validation( + exterrors.CodeInvalidParameter, + "--agent-endpoint must use https", + agentEndpointHint, + ) + } + + host := strings.ToLower(u.Hostname()) + if host == "" || !strings.HasSuffix(host, agentEndpointHostSuffix) { + return nil, exterrors.Validation( + exterrors.CodeInvalidParameter, + fmt.Sprintf("--agent-endpoint host %q is not a Foundry host (*%s)", u.Hostname(), agentEndpointHostSuffix), + agentEndpointHint, + ) + } + + // Reject explicit ports — Foundry endpoints always use the default HTTPS port, + // and silently dropping a non-default port would route requests to a different origin. + if u.Port() != "" { + return nil, exterrors.Validation( + exterrors.CodeInvalidParameter, + fmt.Sprintf("--agent-endpoint host %q must not include a port", u.Host), + agentEndpointHint+" (no explicit port)", + ) + } + + // Match the full path against the canonical Foundry agent-endpoint shape and pull + // the project name, agent name, and protocol tail out in one pass. + matches := agentEndpointPathRegex.FindStringSubmatch(u.EscapedPath()) + if matches == nil { + return nil, exterrors.Validation( + exterrors.CodeInvalidParameter, + "--agent-endpoint path must match /api/projects//agents//endpoint/protocols/", + agentEndpointHint, + ) + } + projectSegment, agentSegment, protocolTail := matches[1], matches[2], matches[3] + + projectName, err := url.PathUnescape(projectSegment) + if err != nil || projectName == "" || strings.ContainsAny(projectName, "/\\") { + return nil, exterrors.Validation( + exterrors.CodeInvalidParameter, + "--agent-endpoint project segment is invalid", + agentEndpointHint, + ) + } + + agentName, err := url.PathUnescape(agentSegment) + if err != nil || agent_yaml.ValidateAgentName(agentName) != nil { + return nil, exterrors.Validation( + exterrors.CodeInvalidAgentName, + fmt.Sprintf("--agent-endpoint agent name %q is invalid", agentSegment), + "agent names must start and end with an alphanumeric character, "+ + "may contain hyphens in the middle, and be 1-63 characters long", + ) + } + + var protocol agent_api.AgentProtocol + switch protocolTail { + case "invocations": + protocol = agent_api.AgentProtocolInvocations + case "openai/responses": + protocol = agent_api.AgentProtocolResponses + } + + // Reject an explicit but empty api-version query parameter; the default fallback would + // otherwise silently invoke a different version than the user pasted. + apiVersion := "" + query := u.Query() + if values, present := query["api-version"]; present { + if len(values) == 0 || values[0] == "" { + return nil, exterrors.Validation( + exterrors.CodeInvalidParameter, + "--agent-endpoint api-version query parameter is empty", + "include a non-empty api-version value or omit the parameter to use the default", + ) + } + apiVersion = values[0] + } + + projectEndpoint := fmt.Sprintf("https://%s/api/projects/%s", host, projectSegment) + + return &parsedAgentEndpoint{ + ProjectEndpoint: projectEndpoint, + AgentName: agentName, + Protocol: protocol, + APIVersion: apiVersion, + }, nil +} + +// buildResponsesURL builds the Foundry "openai/responses" protocol URL for an agent. +// apiVersion is URL-encoded so unusual characters cannot break out of the query value. +func buildResponsesURL(projectEndpoint, agentName, apiVersion string) string { + return fmt.Sprintf( + "%s/agents/%s/endpoint/protocols/openai/responses?api-version=%s", + projectEndpoint, agentName, url.QueryEscape(apiVersion), + ) +} + +// buildInvocationsURL builds the Foundry "invocations" protocol URL for an agent. +// When sid is non-empty, an agent_session_id query parameter is appended (URL-encoded). +func buildInvocationsURL(projectEndpoint, agentName, apiVersion, sid string) string { + invURL := fmt.Sprintf( + "%s/agents/%s/endpoint/protocols/invocations?api-version=%s", + projectEndpoint, agentName, url.QueryEscape(apiVersion), + ) + if sid != "" { + invURL += "&agent_session_id=" + url.QueryEscape(sid) + } + return invURL +} + +// (isValidAgentNameSegment was removed — agent name validation now delegates +// to agent_yaml.ValidateAgentName so --agent-endpoint enforces the same +// deployable-name format as the rest of the extension.) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_endpoint_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_endpoint_test.go new file mode 100644 index 00000000000..7dd027ae3a7 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/agent_endpoint_test.go @@ -0,0 +1,314 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "strings" + "testing" + + "azureaiagent/internal/pkg/agents/agent_api" +) + +func TestParseAgentEndpoint(t *testing.T) { + t.Parallel() + tests := []struct { + name string + raw string + wantProj string + wantAgent string + wantProto agent_api.AgentProtocol + wantAPIVer string + wantErr bool + errContains string + }{ + { + name: "invocations with api-version", + raw: "https://acct.services.ai.azure.com/api/projects/proj/agents/hello/endpoint/protocols/invocations?api-version=2025-11-15-preview", + wantProj: "https://acct.services.ai.azure.com/api/projects/proj", + wantAgent: "hello", + wantProto: agent_api.AgentProtocolInvocations, + wantAPIVer: "2025-11-15-preview", + }, + { + name: "invocations without api-version", + raw: "https://acct.services.ai.azure.com/api/projects/proj/agents/hello/endpoint/protocols/invocations", + wantProj: "https://acct.services.ai.azure.com/api/projects/proj", + wantAgent: "hello", + wantProto: agent_api.AgentProtocolInvocations, + }, + { + name: "responses (openai/responses)", + raw: "https://acct.services.ai.azure.com/api/projects/proj/agents/echo/endpoint/protocols/openai/responses?api-version=2025-11-15-preview", + wantProj: "https://acct.services.ai.azure.com/api/projects/proj", + wantAgent: "echo", + wantProto: agent_api.AgentProtocolResponses, + wantAPIVer: "2025-11-15-preview", + }, + { + name: "trailing slash tolerated", + raw: "https://acct.services.ai.azure.com/api/projects/proj/agents/hello/endpoint/protocols/invocations/", + wantProj: "https://acct.services.ai.azure.com/api/projects/proj", + wantAgent: "hello", + wantProto: agent_api.AgentProtocolInvocations, + }, + { + name: "empty url", + raw: "", + wantErr: true, + errContains: "non-empty URL", + }, + { + name: "http scheme rejected", + raw: "http://acct.services.ai.azure.com/api/projects/proj/agents/hello/endpoint/protocols/invocations", + wantErr: true, + errContains: "https", + }, + { + name: "non-foundry host rejected", + raw: "https://evil.com/api/projects/proj/agents/hello/endpoint/protocols/invocations", + wantErr: true, + errContains: "Foundry host", + }, + { + name: "host suffix injection rejected", + raw: "https://services.ai.azure.com.evil.com/api/projects/proj/agents/hello/endpoint/protocols/invocations", + wantErr: true, + errContains: "Foundry host", + }, + { + name: "missing api/projects prefix", + raw: "https://acct.services.ai.azure.com/agents/hello/endpoint/protocols/invocations", + wantErr: true, + errContains: "path must match", + }, + { + name: "unknown protocol tail", + raw: "https://acct.services.ai.azure.com/api/projects/proj/agents/hello/endpoint/protocols/grpc", + wantErr: true, + errContains: "path must match", + }, + { + name: "missing protocol tail", + raw: "https://acct.services.ai.azure.com/api/projects/proj/agents/hello/endpoint/protocols", + wantErr: true, + errContains: "path must match", + }, + { + name: "invalid agent name (chars)", + raw: "https://acct.services.ai.azure.com/api/projects/proj/agents/hel%20lo/endpoint/protocols/invocations", + wantErr: true, + errContains: "agent name", + }, + { + name: "encoded slash in project segment rejected", + raw: "https://acct.services.ai.azure.com/api/projects/proj%2Fother/agents/hello/endpoint/protocols/invocations", + wantErr: true, + errContains: "project segment is invalid", + }, + { + name: "malformed url", + raw: "https://%zz/foo", + wantErr: true, + errContains: "invalid", + }, + { + name: "explicit port rejected", + raw: "https://acct.services.ai.azure.com:444/api/projects/proj/agents/hello/endpoint/protocols/invocations", + wantErr: true, + errContains: "must not include a port", + }, + { + name: "empty api-version rejected", + raw: "https://acct.services.ai.azure.com/api/projects/proj/agents/hello/endpoint/protocols/invocations?api-version=", + wantErr: true, + errContains: "api-version query parameter is empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseAgentEndpoint(tt.raw) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got nil; result=%+v", got) + } + if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.ProjectEndpoint != tt.wantProj { + t.Errorf("ProjectEndpoint = %q, want %q", got.ProjectEndpoint, tt.wantProj) + } + if got.AgentName != tt.wantAgent { + t.Errorf("AgentName = %q, want %q", got.AgentName, tt.wantAgent) + } + if got.Protocol != tt.wantProto { + t.Errorf("Protocol = %q, want %q", got.Protocol, tt.wantProto) + } + if got.APIVersion != tt.wantAPIVer { + t.Errorf("APIVersion = %q, want %q", got.APIVersion, tt.wantAPIVer) + } + }) + } +} + +// TestParseAgentEndpoint_RejectsInvalidAgentNames covers names that pass the +// regex's `[^/]+` capture but fail the canonical agent_yaml.ValidateAgentName +// check (which enforces the deployable-name format). Without this delegation +// these inputs would previously have been accepted locally and only failed +// later as 404s on the wire. +func TestParseAgentEndpoint_RejectsInvalidAgentNames(t *testing.T) { + t.Parallel() + cases := []string{ + // underscore — disallowed by the canonical validator + "agent_v2", + // 64 characters — exceeds the 63-char limit + strings.Repeat("a", 64), + // trailing hyphen — must end alphanumeric + "agent-", + // leading hyphen — must start alphanumeric + "-agent", + } + for _, name := range cases { + t.Run(name, func(t *testing.T) { + endpoint := "https://acct.services.ai.azure.com/api/projects/proj/agents/" + + name + "/endpoint/protocols/invocations?api-version=2025-11-15-preview" + _, err := parseAgentEndpoint(endpoint) + if err == nil { + t.Fatalf("parseAgentEndpoint(%q) = nil, want error", name) + } + }) + } +} + +// TestBuildResponsesURL verifies that the responses URL builder uses the parsed +// api-version (rather than the default fallback) and URL-encodes it. +func TestBuildResponsesURL(t *testing.T) { + t.Parallel() + parsed, err := parseAgentEndpoint( + "https://acct.services.ai.azure.com/api/projects/proj/agents/echo/endpoint/protocols/openai/responses?api-version=2025-11-15-preview", + ) + if err != nil { + t.Fatalf("parseAgentEndpoint: %v", err) + } + got := buildResponsesURL(parsed.ProjectEndpoint, parsed.AgentName, parsed.APIVersion) + want := "https://acct.services.ai.azure.com/api/projects/proj/agents/echo/endpoint/protocols/openai/responses?api-version=2025-11-15-preview" + if got != want { + t.Errorf("buildResponsesURL = %q, want %q", got, want) + } + + // api-version must be query-escaped so unusual characters cannot break out. + gotEscaped := buildResponsesURL("https://acct.services.ai.azure.com/api/projects/proj", "echo", "weird value&x=1") + if !strings.Contains(gotEscaped, "api-version=weird+value%26x%3D1") { + t.Errorf("buildResponsesURL did not escape api-version: %q", gotEscaped) + } +} + +// TestBuildInvocationsURL verifies that the invocations URL builder propagates +// the parsed api-version, URL-encodes it, and URL-encodes any session id. +func TestBuildInvocationsURL(t *testing.T) { + t.Parallel() + parsed, err := parseAgentEndpoint( + "https://acct.services.ai.azure.com/api/projects/proj/agents/hello/endpoint/protocols/invocations?api-version=2025-11-15-preview", + ) + if err != nil { + t.Fatalf("parseAgentEndpoint: %v", err) + } + + t.Run("no session id", func(t *testing.T) { + got := buildInvocationsURL(parsed.ProjectEndpoint, parsed.AgentName, parsed.APIVersion, "") + want := "https://acct.services.ai.azure.com/api/projects/proj/agents/hello/endpoint/protocols/invocations?api-version=2025-11-15-preview" + if got != want { + t.Errorf("buildInvocationsURL = %q, want %q", got, want) + } + }) + + t.Run("session id is escaped", func(t *testing.T) { + got := buildInvocationsURL(parsed.ProjectEndpoint, parsed.AgentName, parsed.APIVersion, "a b/c?d&e") + if !strings.Contains(got, "agent_session_id=a+b%2Fc%3Fd%26e") { + t.Errorf("buildInvocationsURL did not escape session id: %q", got) + } + }) + + t.Run("api-version is escaped", func(t *testing.T) { + got := buildInvocationsURL("https://acct.services.ai.azure.com/api/projects/proj", "hello", "weird value&x=1", "") + if !strings.Contains(got, "api-version=weird+value%26x%3D1") { + t.Errorf("buildInvocationsURL did not escape api-version: %q", got) + } + }) +} + +// TestResolveRemoteContext_EphemeralMode exercises the ephemeral branch of +// resolveRemoteContext (--agent-endpoint path). It pins the api-version +// fallback (default applied when the URL omits the parameter) and the +// override (parsed value used when present), plus verifies that name, +// projectEndpoint, and agentKey are populated from the parsed endpoint. +// +// The project-mode branch is intentionally not covered here: it depends on +// the azd gRPC client and is exercised end-to-end by the functional/live +// tests in this PR's verification. +func TestResolveRemoteContext_EphemeralMode(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw string + wantAPIVersion string + wantName string + wantProject string + }{ + { + name: "api-version omitted falls back to default", + raw: "https://acct.services.ai.azure.com/api/projects/proj/agents/" + + "hello/endpoint/protocols/openai/responses", + wantAPIVersion: DefaultAgentAPIVersion, + wantName: "hello", + wantProject: "https://acct.services.ai.azure.com/api/projects/proj", + }, + { + name: "explicit api-version overrides the default", + raw: "https://acct.services.ai.azure.com/api/projects/proj/agents/" + + "hello/endpoint/protocols/invocations?api-version=2025-09-01-preview", + wantAPIVersion: "2025-09-01-preview", + wantName: "hello", + wantProject: "https://acct.services.ai.azure.com/api/projects/proj", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + parsed, err := parseAgentEndpoint(tc.raw) + if err != nil { + t.Fatalf("parseAgentEndpoint: %v", err) + } + a := &InvokeAction{flags: &invokeFlags{}, endpoint: parsed} + + rc, err := a.resolveRemoteContext(t.Context()) + if err != nil { + t.Fatalf("resolveRemoteContext: %v", err) + } + if rc.azdClient != nil { + defer rc.azdClient.Close() + } + + if rc.apiVersion != tc.wantAPIVersion { + t.Errorf("apiVersion = %q, want %q", rc.apiVersion, tc.wantAPIVersion) + } + if rc.name != tc.wantName { + t.Errorf("name = %q, want %q", rc.name, tc.wantName) + } + if rc.projectEndpoint != tc.wantProject { + t.Errorf("projectEndpoint = %q, want %q", rc.projectEndpoint, tc.wantProject) + } + if rc.agentKey == "" { + t.Errorf("agentKey is empty; should be populated for ephemeral persistence") + } + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke.go index 9f65bffe856..9a35ee535f9 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke.go @@ -38,11 +38,13 @@ type invokeFlags struct { conversation string newConversation bool protocol string + agentEndpoint string } type InvokeAction struct { flags *invokeFlags noPrompt bool + endpoint *parsedAgentEndpoint } func newInvokeCommand(extCtx *azdext.ExtensionContext) *cobra.Command { @@ -87,7 +89,12 @@ session automatically. Pass --new-session to force a reset.`, azd ai agent invoke --local "Hello!" # Start a new session (discard conversation history) - azd ai agent invoke --new-session "Hello!"`, + azd ai agent invoke --new-session "Hello!" + + # Invoke a deployed agent from any directory using the endpoint URL shown by 'azd ai agent show' + azd ai agent invoke \ + --agent-endpoint https://.services.ai.azure.com/api/projects//agents//endpoint/protocols/openai/responses?api-version=2025-11-15-preview \ + "Hello!"`, Args: cobra.RangeArgs(0, 2), RunE: func(cmd *cobra.Command, args []string) error { ctx := azdext.WithAccessToken(cmd.Context()) @@ -107,6 +114,23 @@ session automatically. Pass --new-session to force a reset.`, // Only valid when -f is provided } + action := &InvokeAction{flags: flags, noPrompt: extCtx.NoPrompt} + + // Agent-endpoint structural conflicts are surfaced first so the user sees + // the precise reason their invocation cannot proceed. + if flags.agentEndpoint != "" { + if err := validateAgentEndpointFlags(cmd, flags); err != nil { + return err + } + parsed, err := parseAgentEndpoint(flags.agentEndpoint) + if err != nil { + return err + } + flags.protocol = string(parsed.Protocol) + flags.name = parsed.AgentName + action.endpoint = parsed + } + if flags.inputFile != "" && flags.message != "" { return exterrors.Validation( exterrors.CodeInvalidParameter, @@ -144,10 +168,6 @@ session automatically. Pass --new-session to force a reset.`, } } - action := &InvokeAction{ - flags: flags, - noPrompt: extCtx.NoPrompt, - } return action.Run(ctx) }, } @@ -161,10 +181,49 @@ session automatically. Pass --new-session to force a reset.`, cmd.Flags().BoolVar(&flags.newSession, "new-session", false, "Force a new session (discard saved one)") cmd.Flags().StringVar(&flags.conversation, "conversation-id", "", "Explicit conversation ID override") cmd.Flags().BoolVar(&flags.newConversation, "new-conversation", false, "Force a new conversation (discard saved one)") + cmd.Flags().StringVar( + &flags.agentEndpoint, + "agent-endpoint", + "", + "Full endpoint URL of a deployed agent (run 'azd ai agent show' to see it). "+ + "Invokes without requiring an azd project; protocol is derived from the URL.", + ) return cmd } +// validateAgentEndpointFlags rejects flags that have no effect (or conflict) when --agent-endpoint +// is used. Ephemeral mode has no project, no local persistence, and no localhost target. +func validateAgentEndpointFlags(cmd *cobra.Command, flags *invokeFlags) error { + // Disallowed companion flags for --agent-endpoint, in the order checked. + // `set` is true when the flag is meaningfully present on the command line. + checks := []struct { + name string + set bool + suggestion string + }{ + {"--local", flags.local, "omit --local to invoke the deployed agent at the given URL"}, + { + "a positional agent name", + flags.name != "", + "the agent name is read from the --agent-endpoint URL; remove the positional argument", + }, + {"--port", cmd.Flags().Changed("port"), "--port targets a local agent; omit it when using --agent-endpoint"}, + {"--protocol", cmd.Flags().Changed("protocol"), "the protocol is read from the --agent-endpoint URL; omit --protocol"}, + } + + for _, c := range checks { + if c.set { + return exterrors.Validation( + exterrors.CodeInvalidParameter, + fmt.Sprintf("--agent-endpoint cannot be combined with %s", c.name), + c.suggestion, + ) + } + } + return nil +} + func (a *InvokeAction) Run(ctx context.Context) error { protocol, err := a.resolveProtocol(ctx) if err != nil { @@ -342,43 +401,138 @@ func (a *InvokeAction) responsesLocal(ctx context.Context) error { return printAgentResponse(result, "local") } -func (a *InvokeAction) responsesRemote(ctx context.Context) error { +// remoteContext holds the resolved inputs for a remote (Foundry) invoke. +// In ephemeral mode (--agent-endpoint) the project endpoint / agent name / +// api-version come from the parsed URL. +// +// agentKey is the persistence key used by the global UserConfig store. It is +// non-empty whenever session/conversation IDs should be saved or resumed: +// - project mode: derived from AGENT_{SVC}_ENDPOINT +// - ephemeral mode: derived from the parsed --agent-endpoint URL +// (independent of api-version / trailing slash / fragment) +// +// In standalone mode (no parent azd daemon, e.g. running the extension binary +// directly outside an azd command) azdClient is nil and persistence helpers +// no-op. agentKey may still be non-empty in that case. +type remoteContext struct { + name string + agentKey string + projectEndpoint string + apiVersion string + azdClient *azdext.AzdClient + bearerToken string +} + +// resolveRemoteContext returns the inputs required to invoke a remote agent. +// In project mode it opens an azd client and reads the environment; in ephemeral +// mode (--agent-endpoint) it skips both. Auth token acquisition is intentionally +// deferred to acquireBearerToken so callers can validate the request body first +// and avoid unnecessary token round-trips on invalid input. Callers must close +// rc.azdClient when non-nil. +func (a *InvokeAction) resolveRemoteContext(ctx context.Context) (*remoteContext, error) { + rc := &remoteContext{apiVersion: DefaultAgentAPIVersion} + + if a.endpoint != nil { + rc.name = a.endpoint.AgentName + rc.projectEndpoint = a.endpoint.ProjectEndpoint + if a.endpoint.APIVersion != "" { + rc.apiVersion = a.endpoint.APIVersion + } + rc.agentKey = buildAgentKey(a.endpoint.ProjectEndpoint, a.endpoint.AgentName, "", false) + // Best-effort attach to the parent azd daemon so session/conversation IDs + // persist across invokes via global UserConfig. When running the extension + // binary directly (standalone), this fails and we proceed without persistence. + if azdClient, err := azdext.NewAzdClient(); err == nil { + rc.azdClient = azdClient + } + return rc, nil + } + azdClient, err := azdext.NewAzdClient() if err != nil { - return fmt.Errorf("failed to create azd client: %w", err) + return nil, fmt.Errorf("failed to create azd client: %w", err) } - defer azdClient.Close() - - name := a.flags.name - var agentEndpoint string + rc.azdClient = azdClient - // Auto-resolve agent name and version from azure.yaml - if info, err := resolveAgentServiceFromProject(ctx, azdClient, name, a.noPrompt); err == nil { - if name == "" && info.AgentName != "" { - name = info.AgentName + rc.name = a.flags.name + if info, err := resolveAgentServiceFromProject(ctx, azdClient, rc.name, a.noPrompt); err == nil { + if rc.name == "" && info.AgentName != "" { + rc.name = info.AgentName } - agentEndpoint = info.AgentEndpoint + if info.AgentEndpoint != "" { + rc.agentKey = buildRemoteAgentKeyFromEndpoint(info.AgentEndpoint) + } + } + if rc.name == "" { + azdClient.Close() + return nil, fmt.Errorf( + "agent name is required; provide as the first argument or " + + "define an azure.ai.agent service in azure.yaml", + ) } - if name == "" { - return fmt.Errorf("agent name is required; provide as the first argument or define an azure.ai.agent service in azure.yaml") + ep, err := resolveAgentEndpoint(ctx, "", "") + if err != nil { + azdClient.Close() + return nil, err } + rc.projectEndpoint = ep + return rc, nil +} - projectEndpoint, err := resolveAgentEndpoint(ctx, "", "") +// acquireBearerToken obtains a Foundry bearer token. Called after request body +// validation so that local errors (e.g., a missing --input-file) are surfaced +// before any auth round-trip is attempted. +func (a *InvokeAction) acquireBearerToken(ctx context.Context) (string, error) { + credential, err := newAgentCredential() + if err != nil { + return "", err + } + token, err := credential.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{"https://ai.azure.com/.default"}, + }) + if err != nil { + return "", ephemeralAuthError(a.endpoint != nil, err) + } + return token.Token, nil +} + +// ephemeralAuthError wraps a token-acquisition failure with a login suggestion when +// the user is invoking outside an azd project (where mis-configured credentials are common). +func ephemeralAuthError(ephemeral bool, err error) error { + if !ephemeral { + return fmt.Errorf("failed to get auth token: %w", err) + } + return exterrors.Auth( + exterrors.CodeAuthFailed, + fmt.Sprintf("failed to get auth token: %v", err), + "run `azd auth login` and try again", + ) +} + +func (a *InvokeAction) responsesRemote(ctx context.Context) error { + rc, err := a.resolveRemoteContext(ctx) if err != nil { return err } + if rc.azdClient != nil { + defer rc.azdClient.Close() + } - // Build the structured agent key for config store lookups. - // When the endpoint is unavailable (pre-deploy), skip session/conversation persistence. - var agentKey string - if agentEndpoint != "" { - agentKey = buildRemoteAgentKeyFromEndpoint(agentEndpoint) - } else { + body, bodyLabel, err := a.resolveBody() + if err != nil { + return err + } + + agentKey := rc.agentKey + if agentKey == "" && rc.azdClient != nil { log.Printf("warning: agent endpoint not available, session state will not be persisted") } - body, bodyLabel, err := a.resolveBody() + // Acquire the bearer token after body validation so a local input error + // (e.g., unreadable --input-file) does not pay an unnecessary auth round-trip + // and is surfaced before any auth failure. + rc.bearerToken, err = a.acquireBearerToken(ctx) if err != nil { return err } @@ -394,52 +548,49 @@ func (a *InvokeAction) responsesRemote(ctx context.Context) error { // Session ID — routes to the same microVM container instance. // When empty, let the server assign one. var sid string - if agentKey != "" { + if agentKey != "" && rc.azdClient != nil { sid, err = resolveStoredID( - ctx, azdClient, agentKey, a.flags.session, a.flags.newSession, "sessions", false, - legacyKeysForRemote(name)..., + ctx, rc.azdClient, agentKey, a.flags.session, a.flags.newSession, "sessions", false, + legacyKeysForRemote(rc.name)..., ) if err != nil { return err } - } else if a.flags.session != "" { + } else { sid = a.flags.session } if sid != "" { reqBody["session_id"] = sid } - // Acquire credential and token — used for both conversation creation and the invoke request - credential, err := newAgentCredential() - if err != nil { - return err - } - - token, err := credential.GetToken(ctx, policy.TokenRequestOptions{ - Scopes: []string{"https://ai.azure.com/.default"}, - }) - if err != nil { - return fmt.Errorf("failed to get auth token: %w", err) - } - - // Conversation ID — enables multi-turn memory via Foundry Conversations API - convID, err := resolveConversationID( - ctx, - azdClient, - agentKey, - a.flags.conversation, - a.flags.newConversation, - projectEndpoint, - token.Token, - name, - legacyKeysForRemote(name)..., - ) - if err != nil { - return err + // Conversation ID — enables multi-turn memory via Foundry Conversations API. + var convID string + if agentKey != "" && rc.azdClient != nil { + convID, err = resolveConversationID( + ctx, + rc.azdClient, + agentKey, + a.flags.conversation, + a.flags.newConversation, + rc.projectEndpoint, + rc.bearerToken, + rc.name, + legacyKeysForRemote(rc.name)..., + ) + if err != nil { + return err + } + } else if a.flags.conversation != "" { + convID = a.flags.conversation + } else { + convID, err = createConversation(ctx, rc.projectEndpoint, rc.name, rc.bearerToken) + if err != nil { + return err + } } reqBody["conversation"] = map[string]string{"id": convID} - fmt.Printf("Agent: %s (remote)\n", name) + fmt.Printf("Agent: %s (remote)\n", rc.name) fmt.Printf("Message: %s\n", bodyLabel) printSessionStatus("Session: ", sid) fmt.Printf("Conversation: %s\n", convID) @@ -450,21 +601,19 @@ func (a *InvokeAction) responsesRemote(ctx context.Context) error { return fmt.Errorf("failed to marshal request: %w", err) } - url := fmt.Sprintf( - "%s/agents/%s/endpoint/protocols/openai/responses?api-version=%s", - projectEndpoint, name, DefaultAgentAPIVersion, - ) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) + respURL := buildResponsesURL(rc.projectEndpoint, rc.name, rc.apiVersion) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, respURL, bytes.NewReader(payload)) if err != nil { return fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+token.Token) + req.Header.Set("Authorization", "Bearer "+rc.bearerToken) client := &http.Client{Timeout: a.httpTimeout()} - resp, err := client.Do(req) //nolint:gosec // G704: endpoint is resolved from azd environment configuration + //nolint:gosec // G704: URL is built from a validated Foundry endpoint (env or --agent-endpoint) + resp, err := client.Do(req) if err != nil { - return fmt.Errorf("POST %s failed: %w", url, err) + return fmt.Errorf("POST %s failed: %w", respURL, err) } defer resp.Body.Close() @@ -473,15 +622,22 @@ func (a *InvokeAction) responsesRemote(ctx context.Context) error { fmt.Printf("Trace ID: %s\n", requestID) } - captureResponseSession(ctx, azdClient, agentKey, sid, resp, "Session: ") + captureResponseSession(ctx, rc.azdClient, agentKey, sid, resp, "Session: ") if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("POST %s failed with HTTP %d: %s\n%s", url, resp.StatusCode, resp.Status, string(respBody)) + return fmt.Errorf("POST %s failed with HTTP %d: %s\n%s", respURL, resp.StatusCode, resp.Status, string(respBody)) } // Parse SSE stream for agent output - return readSSEStream(resp.Body, name) + if err := readSSEStream(resp.Body, rc.name); err != nil { + return err + } + + if agentKey != "" && rc.azdClient != nil { + fmt.Println("\n(tip: pass --new-session or --new-conversation to reset; see `azd ai agent invoke --help`)") + } + return nil } func (a *InvokeAction) invocationsLocal(ctx context.Context) error { @@ -555,107 +711,102 @@ func (a *InvokeAction) invocationsLocal(ctx context.Context) error { // invocationsRemote sends the user's message to Foundry using // the invocations protocol (POST /agents/{name}/endpoint/protocols/invocations). func (a *InvokeAction) invocationsRemote(ctx context.Context) error { - azdClient, err := azdext.NewAzdClient() + rc, err := a.resolveRemoteContext(ctx) if err != nil { - return fmt.Errorf("failed to create azd client: %w", err) + return err + } + if rc.azdClient != nil { + defer rc.azdClient.Close() } - defer azdClient.Close() - - name := a.flags.name - var agentEndpoint string - // Auto-resolve agent name from azure.yaml / azd environment - if info, err := resolveAgentServiceFromProject(ctx, azdClient, name, a.noPrompt); err == nil { - if name == "" && info.AgentName != "" { - name = info.AgentName - } - agentEndpoint = info.AgentEndpoint + agentKey := rc.agentKey + if agentKey == "" && rc.azdClient != nil { + log.Printf("warning: agent endpoint not available, session state will not be persisted") } - if name == "" { - return fmt.Errorf( - "agent name is required; provide as the first argument or define an azure.ai.agent service in azure.yaml", - ) + if a.flags.newConversation { + fmt.Fprintln(os.Stderr, + "note: --new-conversation has no effect for the invocations protocol "+ + "(memory is bound to the session; use --new-session to reset).") } - endpoint, err := resolveAgentEndpoint(ctx, "", "") + body, bodyLabel, err := a.resolveBody() if err != nil { return err } - var agentKey string - if agentEndpoint != "" { - agentKey = buildRemoteAgentKeyFromEndpoint(agentEndpoint) - } else { - log.Printf("warning: agent endpoint not available, session state will not be persisted") - } - - body, bodyLabel, err := a.resolveBody() + // Acquire the bearer token after body validation so a local input error + // (e.g., unreadable --input-file) does not pay an unnecessary auth round-trip + // and is surfaced before any auth failure. + rc.bearerToken, err = a.acquireBearerToken(ctx) if err != nil { return err } - // Session ID — routes to the same container instance + // Session ID — routes to the same container instance. var sid string - if agentKey != "" { - sid, err = resolveStoredID(ctx, azdClient, agentKey, a.flags.session, a.flags.newSession, "sessions", false) + if agentKey != "" && rc.azdClient != nil { + sid, err = resolveStoredID( + ctx, rc.azdClient, agentKey, a.flags.session, a.flags.newSession, "sessions", false, + legacyKeysForRemote(rc.name)..., + ) if err != nil { return err } - } else if a.flags.session != "" { + } else { sid = a.flags.session } - // Acquire credential and token - credential, err := newAgentCredential() - if err != nil { - return err - } - - token, err := credential.GetToken(ctx, policy.TokenRequestOptions{ - Scopes: []string{"https://ai.azure.com/.default"}, - }) - if err != nil { - return fmt.Errorf("failed to get auth token: %w", err) - } - - fmt.Printf("Agent: %s (remote, invocations protocol)\n", name) + fmt.Printf("Agent: %s (remote, invocations protocol)\n", rc.name) fmt.Printf("Input: %s\n", bodyLabel) printSessionStatus("Session: ", sid) fmt.Println() - remoteBaseURL := fmt.Sprintf("%s/agents/%s/endpoint/protocols", endpoint, name) - - // Fetch and cache the agent's OpenAPI spec (skip if already cached). - fetchOpenAPISpec(ctx, azdClient, remoteBaseURL, name, "remote", token.Token, false) + remoteBaseURL := fmt.Sprintf("%s/agents/%s/endpoint/protocols", rc.projectEndpoint, rc.name) - invURL := fmt.Sprintf("%s/invocations?api-version=%s", remoteBaseURL, DefaultAgentAPIVersion) - if sid != "" { - invURL += "&agent_session_id=" + url.QueryEscape(sid) + // Fetch and cache the agent's OpenAPI spec only in project mode. In ephemeral + // mode (--agent-endpoint) we deliberately avoid the on-disk side effect since + // the user is one-off targeting a remote endpoint. + if rc.azdClient != nil && a.endpoint == nil { + fetchOpenAPISpec(ctx, rc.azdClient, remoteBaseURL, rc.name, "remote", rc.bearerToken, false) } + invURL := buildInvocationsURL(rc.projectEndpoint, rc.name, rc.apiVersion, sid) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, invURL, bytes.NewReader(body)) if err != nil { return fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", contentTypeForBody(body)) - req.Header.Set("Authorization", "Bearer "+token.Token) + req.Header.Set("Authorization", "Bearer "+rc.bearerToken) client := &http.Client{Timeout: a.httpTimeout()} - resp, err := client.Do(req) //nolint:gosec // G704: endpoint is resolved from azd environment configuration + //nolint:gosec // G704: URL is built from a validated Foundry endpoint (env or --agent-endpoint) + resp, err := client.Do(req) if err != nil { return fmt.Errorf("POST %s failed: %w", invURL, err) } defer resp.Body.Close() - // Print the invocation ID if the agent returned one. + // Print the invocation ID if the agent returned one. We do not persist it + // to the per-user config: the config store only supports the "sessions" + // and "conversations" maps (see validateStoreField), and invocation IDs + // are not used to drive any subsequent invoke — they are emitted purely + // for trace correlation. if invID := resp.Header.Get("x-agent-invocation-id"); invID != "" { - fmt.Printf("Invocation: %s\n", invID) + fmt.Printf("Invocation: %s\n", invID) } - captureResponseSession(ctx, azdClient, agentKey, sid, resp, "Session: ") + captureResponseSession(ctx, rc.azdClient, agentKey, sid, resp, "Session: ") + + if err := handleInvocationResponse(ctx, resp, rc.projectEndpoint, rc.bearerToken, rc.name, a.httpTimeout()); err != nil { + return err + } - return handleInvocationResponse(ctx, resp, endpoint, token.Token, name, a.httpTimeout()) + if agentKey != "" && rc.azdClient != nil { + fmt.Println("\n(tip: pass --new-session to reset; see `azd ai agent invoke --help`)") + } + return nil } // handleInvocationResponse dispatches the response from a POST /invocations call @@ -938,11 +1089,11 @@ func handleInvocationLRO( // createConversation creates a new Foundry conversation for multi-turn memory. func createConversation(ctx context.Context, projectEndpoint, agentName, bearerToken string) (string, error) { - url := fmt.Sprintf( + convURL := fmt.Sprintf( "%s/agents/%s/endpoint/protocols/openai/conversations?api-version=%s", projectEndpoint, agentName, ConversationsAPIVersion, ) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader([]byte("{}"))) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, convURL, bytes.NewReader([]byte("{}"))) if err != nil { return "", err } @@ -952,13 +1103,13 @@ func createConversation(ctx context.Context, projectEndpoint, agentName, bearerT client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) //nolint:gosec // G704: endpoint is resolved from azd environment configuration if err != nil { - return "", fmt.Errorf("POST %s failed: %w", url, err) + return "", fmt.Errorf("POST %s failed: %w", convURL, err) } defer resp.Body.Close() if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(resp.Body) - return "", fmt.Errorf("POST %s failed with HTTP %d: %s\n%s", url, resp.StatusCode, resp.Status, string(respBody)) + return "", fmt.Errorf("POST %s failed with HTTP %d: %s\n%s", convURL, resp.StatusCode, resp.Status, string(respBody)) } body, err := io.ReadAll(resp.Body) diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke_test.go index b5bdda30971..27aa72d4ab7 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke_test.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/invoke_test.go @@ -342,6 +342,81 @@ func TestProtocolFlagValidation(t *testing.T) { } } +// TestAgentEndpointFlagValidation covers the up-front validation rules for --agent-endpoint. +// These run before any network call, so they exercise the cobra RunE error path directly. +func TestAgentEndpointFlagValidation(t *testing.T) { + t.Parallel() + + const validURL = "https://acct.services.ai.azure.com/api/projects/proj/agents/hello/endpoint/protocols/invocations?api-version=2025-11-15-preview" + + tests := []struct { + name string + args []string + wantErr bool + errSub string + }{ + { + name: "rejects --local", + args: []string{"--agent-endpoint", validURL, "--local", "hi"}, + wantErr: true, + errSub: "cannot be combined with --local", + }, + { + name: "rejects positional name", + args: []string{"--agent-endpoint", validURL, "myagent", "hi"}, + wantErr: true, + errSub: "positional agent name", + }, + { + name: "rejects --port", + args: []string{"--agent-endpoint", validURL, "--port", "9999", "hi"}, + wantErr: true, + errSub: "cannot be combined with --port", + }, + { + name: "rejects explicit --port at default value", + args: []string{"--agent-endpoint", validURL, "--port", "8088", "hi"}, + wantErr: true, + errSub: "cannot be combined with --port", + }, + { + name: "rejects --protocol", + args: []string{"--agent-endpoint", validURL, "--protocol", "responses", "hi"}, + wantErr: true, + errSub: "cannot be combined with --protocol", + }, + { + name: "rejects --protocol even when matching", + args: []string{"--agent-endpoint", validURL, "--protocol", "invocations", "hi"}, + wantErr: true, + errSub: "cannot be combined with --protocol", + }, + { + name: "rejects malformed url", + args: []string{"--agent-endpoint", "https://evil.com/foo", "hi"}, + wantErr: true, + errSub: "Foundry host", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + cmd := newInvokeCommand(nil) + cmd.SetArgs(tt.args) + err := cmd.Execute() + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + if tt.errSub != "" && !strings.Contains(err.Error(), tt.errSub) { + t.Errorf("error %q should contain %q", err.Error(), tt.errSub) + } + } + }) + } +} + func TestHandleInvocationSync(t *testing.T) { t.Parallel() @@ -927,9 +1002,7 @@ func TestCreateConversation(t *testing.T) { })) defer srv.Close() - id, err := createConversation( - t.Context(), srv.URL, tt.agentName, "test-token", - ) + id, err := createConversation(t.Context(), srv.URL, tt.agentName, "test-token") if tt.wantErr { if err == nil {