diff --git a/cli/azd/extensions/azure.ai.agents/cspell.yaml b/cli/azd/extensions/azure.ai.agents/cspell.yaml index 08483e3b4d5..6dbb468566a 100644 --- a/cli/azd/extensions/azure.ai.agents/cspell.yaml +++ b/cli/azd/extensions/azure.ai.agents/cspell.yaml @@ -1,5 +1,9 @@ import: ../../.vscode/cspell.yaml words: + # Connection commands + - tavily + - tvly + - conncmd # Azure region names - australiaeast - brazilsouth diff --git a/cli/azd/extensions/azure.ai.agents/go.sum b/cli/azd/extensions/azure.ai.agents/go.sum index 2d1a8679e86..e7b4fc1dace 100644 --- a/cli/azd/extensions/azure.ai.agents/go.sum +++ b/cli/azd/extensions/azure.ai.agents/go.sum @@ -17,6 +17,8 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthoriza github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0/go.mod h1:/pz8dyNQe+Ey3yBp/XuYz7oqX8YDNWVpPB0hH3XWfbc= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3 v3.0.0-beta.2 h1:qiir/pptnHqp6hV8QwV+IExYIf6cPsXBfUDUXQ27t2Y= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3 v3.0.0-beta.2/go.mod h1:jVRrRDLCOuif95HDYC23ADTMlvahB7tMdl519m9Iyjc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices v1.8.0 h1:ZMGAqCZov8+7iFUPWKVcTaLgNXUeTlz20sIuWkQWNfg= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices v1.8.0/go.mod h1:BElPQ/GZtrdQ2i5uDZw3OKLE1we75W0AEWyeBR1TWQA= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices/v2 v2.0.0 h1:pxphC/uRZKNHNPbZ0duDDgKkefju2F03OkG5xF6byHQ= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices/v2 v2.0.0/go.mod h1:twcwRey+l1znKBL5TEzYiZMtiVkWfM7Pq8a9vY04xYc= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerregistry/armcontainerregistry v1.3.0-beta.3 h1:4qfc7os3wRQcl+ImfeH9z0abWJzuV9IGcN1B9olmPTU= diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/connection_credentials.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/connection_credentials.go new file mode 100644 index 00000000000..7368e91e2d7 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/connection_credentials.go @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "log" + "os" + "path/filepath" + "regexp" + "strings" + + "azureaiagent/internal/connections/pkg/connections" + "azureaiagent/internal/pkg/agents/agent_yaml" + + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "go.yaml.in/yaml/v3" +) + +// connectionRefPattern matches ${{connections..credentials.}} references +// in agent manifest environment variable values. +var connectionRefPattern = regexp.MustCompile( + `\$\{\{connections\.([^.]+)\.credentials\.([^}]+)\}\}`, +) + +// connRef represents a single connection credential reference found in an +// agent manifest's environment_variables section. +type connRef struct { + EnvName string // the env var name (e.g., TAVILY_API_KEY) + ConnName string // connection name (e.g., my-test-conn) + CredKey string // credential key (e.g., x-api-key) +} + +// extractConnectionRefs scans environment variable definitions for +// ${{connections..credentials.}} patterns and returns the parsed refs. +func extractConnectionRefs( + envVars []agent_yaml.EnvironmentVariable, +) []connRef { + var refs []connRef + for _, ev := range envVars { + matches := connectionRefPattern.FindStringSubmatch(ev.Value) + if matches != nil { + refs = append(refs, connRef{ + EnvName: ev.Name, + ConnName: matches[1], + CredKey: matches[2], + }) + } + } + return refs +} + +// lookupCredentialValue finds the value of a credential key on a connection. +// Returns the value and true if found, or empty string and false if not. +func lookupCredentialValue( + conn *connections.Connection, + credKey string, +) (string, bool) { + if conn == nil || conn.Credentials == nil { + return "", false + } + if credKey == "key" && conn.Credentials.Key != "" { + return conn.Credentials.Key, true + } + if v, ok := conn.Credentials.CustomKeys[credKey]; ok { + return v, true + } + return "", false +} + +// resolveConnectionCredentials reads the agent manifest from projectDir, +// scans environment_variables for ${{connections..credentials.}} patterns, +// fetches credential values from the Foundry data plane, and returns them as +// KEY=VALUE strings ready to inject into the agent process environment. +// +// This is additive to existing env var handling in run.go: +// - ${VAR} references are already resolved via loadAzdEnvironment +// - ${{connections...}} references are resolved here via data-plane API +// - Literal values pass through unchanged +// +// Returns nil (no error) if no manifest is found, no env vars are declared, +// or no connection references are present — the agent still starts normally. +func resolveConnectionCredentials( + ctx context.Context, + projectDir string, + endpoint string, +) ([]string, error) { + if endpoint == "" { + return nil, nil + } + + // Find and parse the agent manifest + manifestPath := findManifestInDir(projectDir) + if manifestPath == "" { + return nil, nil + } + + manifestBytes, err := os.ReadFile(manifestPath) //nolint:gosec // G304: path is from findManifestInDir which only checks known filenames in the project directory + if err != nil { + log.Printf("run: could not read manifest %s: %v", manifestPath, err) + return nil, nil + } + + // Try parsing as AgentManifest (agent.manifest.yaml — has "template:" wrapper) + var envVars []agent_yaml.EnvironmentVariable + + manifest, err := agent_yaml.LoadAndValidateAgentManifest(manifestBytes) + if err == nil { + if containerAgent, ok := manifest.Template.(agent_yaml.ContainerAgent); ok && + containerAgent.EnvironmentVariables != nil { + envVars = *containerAgent.EnvironmentVariables + } + } + + // Fall back to parsing as ContainerAgent directly (agent.yaml — no wrapper) + if len(envVars) == 0 { + var agentDef agent_yaml.ContainerAgent + if yamlErr := yaml.Unmarshal(manifestBytes, &agentDef); yamlErr == nil && + agentDef.EnvironmentVariables != nil { + envVars = *agentDef.EnvironmentVariables + } + } + + if len(envVars) == 0 { + return nil, nil + } + + // Scan for connection references + refs := extractConnectionRefs(envVars) + if len(refs) == 0 { + return nil, nil + } + + // Create data-plane credential and client + cred, err := azidentity.NewAzureDeveloperCLICredential( + &azidentity.AzureDeveloperCLICredentialOptions{}, + ) + if err != nil { + return nil, fmt.Errorf( + "failed to create credential for connection resolution: %w", err, + ) + } + + dpClient := connections.NewDataClient(endpoint, cred) + + // Resolve each reference, caching per connection name + connCache := map[string]*connections.Connection{} + var result []string + + for _, ref := range refs { + conn, cached := connCache[ref.ConnName] + if !cached { + conn, err = dpClient.GetConnectionWithCredentials(ctx, ref.ConnName) + if err != nil { + return nil, fmt.Errorf( + "failed to resolve credential for %s (connection %q): %w", + ref.EnvName, ref.ConnName, err, + ) + } + connCache[ref.ConnName] = conn + } + + credValue, found := lookupCredentialValue(conn, ref.CredKey) + if !found { + return nil, fmt.Errorf( + "credential key %q not found on connection %q (for env var %s)", + ref.CredKey, ref.ConnName, ref.EnvName, + ) + } + + result = append(result, fmt.Sprintf("%s=%s", ref.EnvName, credValue)) + // Log the key name only — NEVER log the value + log.Printf( + "run: resolved connection credential: %s (connection: %s, key: %s)", + ref.EnvName, ref.ConnName, ref.CredKey, + ) + } + + if len(result) > 0 { + fmt.Fprintf(os.Stderr, " %d connection credential(s) resolved\n", len(result)) + } + + return result, nil +} + +// findManifestInDir looks for an agent manifest or definition file in the given directory. +// Checks agent.yaml first (the definition the agent app uses), then agent.manifest.yaml. +// Returns the first file that exists and contains environment_variables with connection references. +func findManifestInDir(dir string) string { + // Check agent.yaml first — this is the file the agent app code references + candidates := []string{ + "agent.yaml", + "agent.manifest.yaml", + "agent.yml", + "agent.manifest.yml", + } + for _, name := range candidates { + path := filepath.Join(dir, name) + if _, err := os.Stat(path); err == nil { + data, err := os.ReadFile(path) //nolint:gosec // G304: path is constructed from known candidate filenames joined with the project directory + if err == nil && strings.Contains(string(data), "${{connections.") { + return path + } + } + } + return "" +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/connection_credentials_test.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/connection_credentials_test.go new file mode 100644 index 00000000000..3bdb362e229 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/connection_credentials_test.go @@ -0,0 +1,312 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "os" + "path/filepath" + "testing" + + "azureaiagent/internal/connections/pkg/connections" + "azureaiagent/internal/pkg/agents/agent_yaml" + + "github.com/stretchr/testify/require" +) + +func TestExtractConnectionRefs(t *testing.T) { + tests := []struct { + name string + envVars []agent_yaml.EnvironmentVariable + want []connRef + }{ + { + name: "single connection ref", + envVars: []agent_yaml.EnvironmentVariable{ + { + Name: "TAVILY_API_KEY", + Value: "${{connections.my-tavily.credentials.x-api-key}}", + }, + }, + want: []connRef{ + { + EnvName: "TAVILY_API_KEY", + ConnName: "my-tavily", + CredKey: "x-api-key", + }, + }, + }, + { + name: "multiple refs", + envVars: []agent_yaml.EnvironmentVariable{ + { + Name: "KEY1", + Value: "${{connections.conn-a.credentials.key}}", + }, + { + Name: "KEY2", + Value: "${{connections.conn-b.credentials.token}}", + }, + }, + want: []connRef{ + {EnvName: "KEY1", ConnName: "conn-a", CredKey: "key"}, + {EnvName: "KEY2", ConnName: "conn-b", CredKey: "token"}, + }, + }, + { + name: "no refs — literal values", + envVars: []agent_yaml.EnvironmentVariable{ + {Name: "PORT", Value: "8080"}, + {Name: "HOST", Value: "localhost"}, + }, + want: nil, + }, + { + name: "mixed — only refs extracted", + envVars: []agent_yaml.EnvironmentVariable{ + {Name: "PORT", Value: "8080"}, + { + Name: "SECRET", + Value: "${{connections.my-conn.credentials.api-key}}", + }, + {Name: "ENV_REF", Value: "${SOME_VAR}"}, + }, + want: []connRef{ + {EnvName: "SECRET", ConnName: "my-conn", CredKey: "api-key"}, + }, + }, + { + name: "empty env vars", + envVars: []agent_yaml.EnvironmentVariable{}, + want: nil, + }, + { + name: "nil env vars", + envVars: nil, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractConnectionRefs(tt.envVars) + require.Equal(t, tt.want, result) + }) + } +} + +func TestLookupCredentialValue(t *testing.T) { + tests := []struct { + name string + conn *connections.Connection + credKey string + wantValue string + wantFound bool + }{ + { + name: "api key lookup", + conn: &connections.Connection{ + Credentials: &connections.ConnectionCredentials{ + Key: "my-api-key-value", + }, + }, + credKey: "key", + wantValue: "my-api-key-value", + wantFound: true, + }, + { + name: "custom key lookup", + conn: &connections.Connection{ + Credentials: &connections.ConnectionCredentials{ + CustomKeys: map[string]string{ + "x-api-key": "tavily-secret", + "token": "bearer-token", + }, + }, + }, + credKey: "x-api-key", + wantValue: "tavily-secret", + wantFound: true, + }, + { + name: "key not found in custom keys", + conn: &connections.Connection{ + Credentials: &connections.ConnectionCredentials{ + CustomKeys: map[string]string{ + "other": "value", + }, + }, + }, + credKey: "missing-key", + wantValue: "", + wantFound: false, + }, + { + name: "nil credentials", + conn: &connections.Connection{ + Credentials: nil, + }, + credKey: "key", + wantValue: "", + wantFound: false, + }, + { + name: "nil connection", + conn: nil, + credKey: "key", + wantValue: "", + wantFound: false, + }, + { + name: "empty key field — falls through to custom keys", + conn: &connections.Connection{ + Credentials: &connections.ConnectionCredentials{ + Key: "", + CustomKeys: map[string]string{"key": "from-custom"}, + }, + }, + credKey: "key", + wantValue: "from-custom", + wantFound: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + value, found := lookupCredentialValue(tt.conn, tt.credKey) + require.Equal(t, tt.wantFound, found) + require.Equal(t, tt.wantValue, value) + }) + } +} + +func TestFindManifestInDir(t *testing.T) { + t.Run("finds agent.yaml with connection refs", func(t *testing.T) { + dir := t.TempDir() + content := `environment_variables: + - name: MY_KEY + value: "${{connections.test.credentials.api-key}}" +` + err := os.WriteFile( + filepath.Join(dir, "agent.yaml"), []byte(content), 0600, + ) + require.NoError(t, err) + + result := findManifestInDir(dir) + require.Equal(t, filepath.Join(dir, "agent.yaml"), result) + }) + + t.Run("finds agent.manifest.yaml with connection refs", func(t *testing.T) { + dir := t.TempDir() + content := `template: + environment_variables: + - name: SECRET + value: "${{connections.conn1.credentials.key}}" +` + err := os.WriteFile( + filepath.Join(dir, "agent.manifest.yaml"), + []byte(content), 0600, + ) + require.NoError(t, err) + + result := findManifestInDir(dir) + require.Equal(t, + filepath.Join(dir, "agent.manifest.yaml"), result) + }) + + t.Run("prefers agent.yaml over agent.manifest.yaml", func(t *testing.T) { + dir := t.TempDir() + agentYAML := `environment_variables: + - name: A + value: "${{connections.c.credentials.k}}" +` + manifestYAML := `template: + environment_variables: + - name: B + value: "${{connections.c.credentials.k}}" +` + require.NoError(t, os.WriteFile( + filepath.Join(dir, "agent.yaml"), + []byte(agentYAML), 0600, + )) + require.NoError(t, os.WriteFile( + filepath.Join(dir, "agent.manifest.yaml"), + []byte(manifestYAML), 0600, + )) + + result := findManifestInDir(dir) + require.Equal(t, filepath.Join(dir, "agent.yaml"), result) + }) + + t.Run("skips yaml without connection refs", func(t *testing.T) { + dir := t.TempDir() + content := `environment_variables: + - name: PORT + value: "8080" +` + require.NoError(t, os.WriteFile( + filepath.Join(dir, "agent.yaml"), + []byte(content), 0600, + )) + + result := findManifestInDir(dir) + require.Empty(t, result) + }) + + t.Run("returns empty for empty directory", func(t *testing.T) { + dir := t.TempDir() + result := findManifestInDir(dir) + require.Empty(t, result) + }) +} + +func TestConnectionRefPattern(t *testing.T) { + tests := []struct { + name string + input string + wantConn string + wantKey string + wantNil bool + }{ + { + name: "standard ref", + input: "${{connections.my-conn.credentials.x-api-key}}", + wantConn: "my-conn", + wantKey: "x-api-key", + }, + { + name: "simple key name", + input: "${{connections.conn1.credentials.key}}", + wantConn: "conn1", + wantKey: "key", + }, + { + name: "not a connection ref", + input: "${SOME_ENV_VAR}", + wantNil: true, + }, + { + name: "partial pattern", + input: "${{connections.only-name}}", + wantNil: true, + }, + { + name: "empty string", + input: "", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := connectionRefPattern.FindStringSubmatch(tt.input) + if tt.wantNil { + require.Nil(t, matches) + return + } + require.Len(t, matches, 3) + require.Equal(t, tt.wantConn, matches[1]) + require.Equal(t, tt.wantKey, matches[2]) + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go index d65d1c0b8e5..23c71da3720 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/root.go @@ -6,6 +6,8 @@ package cmd import ( "fmt" + conncmd "azureaiagent/internal/connections/cmd" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" "github.com/fatih/color" "github.com/spf13/cobra" @@ -61,5 +63,10 @@ func NewRootCommand() *cobra.Command { rootCmd.AddCommand(newFilesCommand(extCtx)) rootCmd.AddCommand(newSessionCommand(extCtx)) + // Connection commands — in separate package for easy lift-and-shift later. + // When the azd core namespace change lands, move this AddCommand call + // to the new root and update the import path. + rootCmd.AddCommand(conncmd.NewConnectionRootCommand(extCtx)) + return rootCmd } diff --git a/cli/azd/extensions/azure.ai.agents/internal/cmd/run.go b/cli/azd/extensions/azure.ai.agents/internal/cmd/run.go index ee2ef0b530c..c684cbb8fce 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/cmd/run.go +++ b/cli/azd/extensions/azure.ai.agents/internal/cmd/run.go @@ -160,6 +160,18 @@ func runRun(ctx context.Context, flags *runFlags, noPrompt bool) error { env = appendFoundryEnvVars(env, azdEnvVars, runCtx.ServiceName) } + // Resolve ${{connections..credentials.}} references from the + // agent manifest's environment_variables section. These are fetched from + // the Foundry data plane at runtime and injected into the agent process. + // Uses the same endpoint resolution as other agent commands. + if endpoint, err := resolveAgentEndpoint(ctx, "", ""); err == nil { + if connEnv, err := resolveConnectionCredentials(ctx, projectDir, endpoint); err != nil { + fmt.Fprintf(os.Stderr, "Warning: connection credential resolution failed: %s\n", err) + } else { + env = append(env, connEnv...) + } + } + url := fmt.Sprintf("http://localhost:%d", flags.port) fmt.Println() fmt.Println("After startup, in another terminal, try:") diff --git a/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/connection.go b/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/connection.go new file mode 100644 index 00000000000..a6d8639ce26 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/connection.go @@ -0,0 +1,801 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "log" + "maps" + "os" + "text/tabwriter" + + "azureaiagent/internal/connections/exterrors" + "azureaiagent/internal/connections/pkg/connections" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices/v2" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// --- LIST --- + +// connectionListFlags holds validated input for ConnectionListAction. +type connectionListFlags struct { + kind string + output string + projectEndpoint string +} + +// ConnectionListAction implements connection listing. +type ConnectionListAction struct { + flags *connectionListFlags +} + +// Run executes the list operation. +func (a *ConnectionListAction) Run(ctx context.Context) error { + normalizedKind := normalizeKind(a.flags.kind) + + connCtx, err := resolveConnectionContext(ctx, a.flags.projectEndpoint) + if err != nil { + return err + } + + pager := connCtx.armClient.NewListPager( + connCtx.rg, connCtx.account, connCtx.project, nil, + ) + + var results []connectionListItem + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpListConnections) + } + for _, conn := range page.Value { + props := conn.Properties.GetConnectionPropertiesV2() + if props == nil { + continue + } + if normalizedKind != "" && + (props.Category == nil || string(*props.Category) != normalizedKind) { + continue + } + results = append(results, connectionListItem{ + Name: deref(conn.Name), + Kind: categoryStr(props.Category), + AuthType: authTypeStr(props.AuthType), + Target: deref(props.Target), + }) + } + } + + return printList(results, a.flags.output) +} + +func newConnectionListCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + flags := &connectionListFlags{} + action := &ConnectionListAction{flags: flags} + + cmd := &cobra.Command{ + Use: "list", + Short: "List connections in the Foundry project.", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + flags.output = extCtx.OutputFormat + flags.projectEndpoint, _ = cmd.Flags().GetString("project-endpoint") + + ctx := azdext.WithAccessToken(cmd.Context()) + return action.Run(ctx) + }, + } + + cmd.Flags().StringVar(&flags.kind, "kind", "", + "Filter by connection kind (e.g., remote-tool)") + azdext.RegisterFlagOptions(cmd, azdext.FlagOptions{ + Name: "output", AllowedValues: []string{"json", "table"}, Default: "table", + }) + return cmd +} + +// --- SHOW --- + +// connectionShowFlags holds validated input for ConnectionShowAction. +type connectionShowFlags struct { + name string + showCredentials bool + output string + projectEndpoint string +} + +// ConnectionShowAction implements connection show. +type ConnectionShowAction struct { + flags *connectionShowFlags +} + +// Run executes the show operation. +func (a *ConnectionShowAction) Run(ctx context.Context) error { + connCtx, err := resolveConnectionContext(ctx, a.flags.projectEndpoint) + if err != nil { + return err + } + + armResp, err := connCtx.armClient.Get( + ctx, connCtx.rg, connCtx.account, connCtx.project, a.flags.name, nil, + ) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpGetConnection) + } + + props := armResp.Properties.GetConnectionPropertiesV2() + if props == nil { + return fmt.Errorf("connection %q: unexpected response format", a.flags.name) + } + + result := connectionDetailResult{ + Name: deref(armResp.Name), + Kind: categoryStr(props.Category), + AuthType: authTypeStr(props.AuthType), + Target: deref(props.Target), + Metadata: props.Metadata, + } + + if a.flags.showCredentials { + dpConn, dpErr := connCtx.dpClient.GetConnectionWithCredentials( + ctx, a.flags.name, + ) + if dpErr != nil { + fmt.Fprintf(os.Stderr, + "Warning: could not fetch credentials: %s\n", dpErr) + } else if dpConn.Credentials != nil { + result.Credentials = dpConn.Credentials.RawFields + result.CredentialRefs = buildCredentialReferences( + a.flags.name, dpConn.Credentials, + ) + } + } + + return printDetail(result, a.flags.output) +} + +func newConnectionShowCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + flags := &connectionShowFlags{} + action := &ConnectionShowAction{flags: flags} + + cmd := &cobra.Command{ + Use: "show ", + Short: "Show connection details.", + Long: "Show connection details. Use --show-credentials to fetch secret values.", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + flags.name = args[0] + flags.output = extCtx.OutputFormat + flags.projectEndpoint, _ = cmd.Flags().GetString("project-endpoint") + + ctx := azdext.WithAccessToken(cmd.Context()) + return action.Run(ctx) + }, + } + + cmd.Flags().BoolVar(&flags.showCredentials, "show-credentials", false, + "Fetch credential values from the data plane") + azdext.RegisterFlagOptions(cmd, azdext.FlagOptions{ + Name: "output", AllowedValues: []string{"json", "table"}, Default: "table", + }) + return cmd +} + +// --- CREATE --- + +// connectionCreateFlags holds validated input for ConnectionCreateAction. +type connectionCreateFlags struct { + name string + kind string + target string + authType string + key string + customKeys []string + metadata []string + force bool + projectEndpoint string +} + +// ConnectionCreateAction implements connection creation. +type ConnectionCreateAction struct { + flags *connectionCreateFlags +} + +// Run executes the create operation. +func (a *ConnectionCreateAction) Run(ctx context.Context) error { + if a.flags.kind == "" { + return exterrors.Validation( + exterrors.CodeMissingConnectionField, + "Missing required flag --kind.", + "Specify the connection kind (e.g., --kind remote-tool).", + ) + } + if a.flags.target == "" { + return exterrors.Validation( + exterrors.CodeMissingConnectionField, + "Missing required flag --target.", + "Specify the target URL (e.g., --target https://example.com).", + ) + } + if a.flags.authType == "api-key" && a.flags.key == "" { + return exterrors.Validation( + exterrors.CodeMissingConnectionField, + "Missing required flag --key for api-key auth.", + "Specify the API key value.", + ) + } + if a.flags.authType == "custom-keys" && len(a.flags.customKeys) == 0 { + return exterrors.Validation( + exterrors.CodeMissingConnectionField, + "Missing required flag --custom-key for custom-keys auth.", + "Specify at least one custom key (e.g., --custom-key x-api-key=value).", + ) + } + + connCtx, err := resolveConnectionContext(ctx, a.flags.projectEndpoint) + if err != nil { + return err + } + + // Pre-check: fail if connection exists and --force not set + if !a.flags.force { + if _, err := connCtx.armClient.Get( + ctx, connCtx.rg, connCtx.account, connCtx.project, + a.flags.name, nil, + ); err == nil { + return exterrors.Validation( + exterrors.CodeConnectionAlreadyExists, + fmt.Sprintf("Connection %q already exists.", a.flags.name), + "Use --force to replace the existing connection.", + ) + } + } + + body, err := buildConnectionBody( + a.flags.kind, a.flags.target, a.flags.authType, + a.flags.key, a.flags.customKeys, a.flags.metadata, + ) + if err != nil { + return err + } + + _, err = connCtx.armClient.Create( + ctx, connCtx.rg, connCtx.account, connCtx.project, + a.flags.name, + &armcognitiveservices.ProjectConnectionsClientCreateOptions{ + Connection: body, + }, + ) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpCreateConnection) + } + + fmt.Printf("Connection %q created in project %q.\n", + a.flags.name, connCtx.project) + return nil +} + +func newConnectionCreateCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + flags := &connectionCreateFlags{} + action := &ConnectionCreateAction{flags: flags} + + cmd := &cobra.Command{ + Use: "create ", + Short: "Create a new Foundry project connection.", + Example: ` azd ai connection create my-search \ + --kind cognitive-search --target https://my-search.search.windows.net/ \ + --auth-type api-key --key "abc123..." + + azd ai connection create my-tavily \ + --kind remote-tool --target https://mcp.tavily.com/mcp \ + --auth-type custom-keys --custom-key "x-api-key=tvly-abc123"`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + flags.name = args[0] + flags.projectEndpoint, _ = cmd.Flags().GetString("project-endpoint") + + ctx := azdext.WithAccessToken(cmd.Context()) + return action.Run(ctx) + }, + } + + cmd.Flags().StringVar(&flags.kind, "kind", "", + "Connection kind (e.g., remote-tool, cognitive-search)") + cmd.Flags().StringVar(&flags.target, "target", "", + "Target URL or ARM resource ID") + cmd.Flags().StringVar(&flags.authType, "auth-type", "none", + "Auth type: api-key, custom-keys, none") + cmd.Flags().StringVar(&flags.key, "key", "", + "API key (for api-key auth)") + cmd.Flags().StringArrayVar(&flags.customKeys, "custom-key", nil, + "Custom key=value (repeatable, for custom-keys auth)") + cmd.Flags().StringArrayVar(&flags.metadata, "metadata", nil, + "Metadata key=value (repeatable)") + cmd.Flags().BoolVar(&flags.force, "force", false, + "Replace existing connection (upsert)") + return cmd +} + +// --- UPDATE --- + +// connectionUpdateFlags holds validated input for ConnectionUpdateAction. +type connectionUpdateFlags struct { + name string + target string + key string + customKeys []string + targetChanged bool + keyChanged bool + customKeyChanged bool + projectEndpoint string +} + +// ConnectionUpdateAction implements connection update. +type ConnectionUpdateAction struct { + flags *connectionUpdateFlags +} + +// Run executes the update operation. +func (a *ConnectionUpdateAction) Run(ctx context.Context) error { + if !a.flags.targetChanged && !a.flags.keyChanged && + !a.flags.customKeyChanged { + return exterrors.Validation( + exterrors.CodeMissingConnectionField, + "No fields to update.", + "Specify --target, --key, or --custom-key.", + ) + } + + connCtx, err := resolveConnectionContext(ctx, a.flags.projectEndpoint) + if err != nil { + return err + } + + // GET current connection metadata from ARM + current, err := connCtx.armClient.Get( + ctx, connCtx.rg, connCtx.account, connCtx.project, + a.flags.name, nil, + ) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpGetConnection) + } + + // Fetch current credentials from data-plane (ARM never returns credentials) + dpConn, err := connCtx.dpClient.GetConnectionWithCredentials( + ctx, a.flags.name, + ) + if err != nil { + return fmt.Errorf("failed to fetch current credentials: %w", err) + } + + props := current.Properties.GetConnectionPropertiesV2() + + // Apply target change + newTarget := deref(props.Target) + if a.flags.targetChanged { + newTarget = a.flags.target + } + + // Build merged credentials + newKey := "" + newCustomKeys := map[string]string{} + if dpConn.Credentials != nil { + newKey = dpConn.Credentials.Key + maps.Copy(newCustomKeys, dpConn.Credentials.CustomKeys) + } + if a.flags.keyChanged { + newKey = a.flags.key + } + if a.flags.customKeyChanged { + for _, kv := range a.flags.customKeys { + for i := range len(kv) { + if kv[i] == '=' { + newCustomKeys[kv[:i]] = kv[i+1:] + break + } + } + } + } + + // Rebuild the full connection body with credentials + normalizedAuth := normalizeAuthType(authTypeStr(props.AuthType)) + kindStr := categoryStr(props.Category) + metaPairs := []string{} + for k, v := range props.Metadata { + if v != nil { + metaPairs = append(metaPairs, k+"="+*v) + } + } + + var credKey string + var credCustomKeys []string + if newKey != "" { + credKey = newKey + } + for k, v := range newCustomKeys { + credCustomKeys = append(credCustomKeys, k+"="+v) + } + + body, err := buildConnectionBody( + kindStr, newTarget, normalizedAuth, + credKey, credCustomKeys, metaPairs, + ) + if err != nil { + return err + } + + _, err = connCtx.armClient.Create( + ctx, connCtx.rg, connCtx.account, connCtx.project, + a.flags.name, + &armcognitiveservices.ProjectConnectionsClientCreateOptions{ + Connection: body, + }, + ) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpUpdateConnection) + } + + fmt.Printf("Connection %q updated.\n", a.flags.name) + return nil +} + +func newConnectionUpdateCommand( + extCtx *azdext.ExtensionContext, +) *cobra.Command { + flags := &connectionUpdateFlags{} + action := &ConnectionUpdateAction{flags: flags} + + cmd := &cobra.Command{ + Use: "update ", + Short: "Update a connection's target or credentials.", + Long: `Update a connection's target URL or credential values. + +Only the specified flags are changed; all other fields are preserved. +Does not accept --auth-type (delete and recreate to change auth type). +For metadata changes, use the 'metadata' subcommand.`, + Example: ` azd ai agent connection update prod-search --key "$NEW_SEARCH_KEY" + azd ai agent connection update my-conn --target https://new-endpoint.com + azd ai agent connection update my-mcp --custom-key "x-api-key=new-key"`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + flags.name = args[0] + flags.projectEndpoint, _ = cmd.Flags().GetString("project-endpoint") + flags.targetChanged = cmd.Flags().Changed("target") + flags.keyChanged = cmd.Flags().Changed("key") + flags.customKeyChanged = cmd.Flags().Changed("custom-key") + + ctx := azdext.WithAccessToken(cmd.Context()) + return action.Run(ctx) + }, + } + + cmd.Flags().StringVar(&flags.target, "target", "", + "New target URL or ARM resource ID") + cmd.Flags().StringVar(&flags.key, "key", "", + "New API key value (for api-key auth)") + cmd.Flags().StringArrayVar(&flags.customKeys, "custom-key", nil, + "Update custom key=value (repeatable, for custom-keys auth)") + return cmd +} + +// --- DELETE --- + +// connectionDeleteFlags holds validated input for ConnectionDeleteAction. +type connectionDeleteFlags struct { + name string + force bool + noPrompt bool + projectEndpoint string +} + +// ConnectionDeleteAction implements connection deletion. +type ConnectionDeleteAction struct { + flags *connectionDeleteFlags +} + +// Run executes the delete operation. +func (a *ConnectionDeleteAction) Run(ctx context.Context) error { + connCtx, err := resolveConnectionContext(ctx, a.flags.projectEndpoint) + if err != nil { + return err + } + + resp, err := connCtx.armClient.Get( + ctx, connCtx.rg, connCtx.account, connCtx.project, + a.flags.name, nil, + ) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpGetConnection) + } + + props := resp.Properties.GetConnectionPropertiesV2() + fmt.Printf("Connection: %s (%s)\n", + a.flags.name, categoryStr(props.Category)) + fmt.Printf("Target: %s\n", deref(props.Target)) + + if !a.flags.force { + if a.flags.noPrompt { + return exterrors.Validation( + exterrors.CodeMissingForceFlag, + fmt.Sprintf( + "Deleting %q requires confirmation.", a.flags.name, + ), + "Use --force to skip confirmation in non-interactive mode.", + ) + } + azdClient, err := azdext.NewAzdClient() + if err != nil { + return fmt.Errorf("failed to create azd client: %w", err) + } + defer azdClient.Close() + + confirmResp, err := azdClient.Prompt().Confirm( + ctx, &azdext.ConfirmRequest{ + Options: &azdext.ConfirmOptions{ + Message: "Are you sure you want to delete this connection?", + DefaultValue: new(false), + }, + }, + ) + if err != nil { + return err + } + if !*confirmResp.Value { + fmt.Println("Cancelled.") + return nil + } + } + + _, err = connCtx.armClient.Delete( + ctx, connCtx.rg, connCtx.account, connCtx.project, + a.flags.name, nil, + ) + if err != nil { + return exterrors.ServiceFromAzure(err, exterrors.OpDeleteConnection) + } + + fmt.Printf("Connection %q deleted.\n", a.flags.name) + return nil +} + +func newConnectionDeleteCommand( + extCtx *azdext.ExtensionContext, +) *cobra.Command { + flags := &connectionDeleteFlags{} + action := &ConnectionDeleteAction{flags: flags} + + cmd := &cobra.Command{ + Use: "delete ", + Short: "Delete a connection.", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + flags.name = args[0] + flags.noPrompt = extCtx.NoPrompt + flags.projectEndpoint, _ = cmd.Flags().GetString("project-endpoint") + + ctx := azdext.WithAccessToken(cmd.Context()) + return action.Run(ctx) + }, + } + + cmd.Flags().BoolVar(&flags.force, "force", false, + "Skip confirmation prompt") + return cmd +} + +// --- Helpers --- + +type connectionListItem struct { + Name string `json:"name"` + Kind string `json:"kind"` + AuthType string `json:"authType"` + Target string `json:"target"` +} + +type connectionDetailResult struct { + Name string `json:"name"` + Kind string `json:"kind"` + AuthType string `json:"authType"` + Target string `json:"target"` + Metadata map[string]*string `json:"metadata,omitempty"` + Credentials map[string]string `json:"credentials,omitempty"` + CredentialRefs map[string]string `json:"credentialReferences,omitempty"` +} + +func buildCredentialReferences( + connName string, creds *connections.ConnectionCredentials, +) map[string]string { + if creds == nil { + return nil + } + refs := map[string]string{} + if creds.Key != "" { + refs["key"] = fmt.Sprintf("${{connections.%s.credentials.key}}", connName) + } + for k := range creds.CustomKeys { + refs[k] = fmt.Sprintf("${{connections.%s.credentials.%s}}", connName, k) + } + if len(refs) == 0 { + return nil + } + return refs +} + +func buildConnectionBody( + kind, target, authType, key string, + customKeys, metadata []string, +) (*armcognitiveservices.ConnectionPropertiesV2BasicResource, error) { + metaMap := parseKVPtrMap(metadata) + cat := armcognitiveservices.ConnectionCategory(normalizeKind(kind)) + + // Map CLI kebab-case auth types to ARM SDK values + switch authType { + case "api-key": + at := armcognitiveservices.ConnectionAuthTypeAPIKey + return &armcognitiveservices.ConnectionPropertiesV2BasicResource{ + Properties: &armcognitiveservices.APIKeyAuthConnectionProperties{ + AuthType: &at, + Category: &cat, + Target: &target, + Credentials: &armcognitiveservices.ConnectionAPIKey{Key: &key}, + Metadata: metaMap, + }, + }, nil + + case "custom-keys": + at := armcognitiveservices.ConnectionAuthTypeCustomKeys + keysMap := parseKVPtrMap(customKeys) + return &armcognitiveservices.ConnectionPropertiesV2BasicResource{ + Properties: &armcognitiveservices.CustomKeysConnectionProperties{ + AuthType: &at, + Category: &cat, + Target: &target, + Credentials: &armcognitiveservices.CustomKeys{Keys: keysMap}, + Metadata: metaMap, + }, + }, nil + + case "none", "": + at := armcognitiveservices.ConnectionAuthTypeNone + return &armcognitiveservices.ConnectionPropertiesV2BasicResource{ + Properties: &armcognitiveservices.NoneAuthTypeConnectionProperties{ + AuthType: &at, + Category: &cat, + Target: &target, + Metadata: metaMap, + }, + }, nil + + default: + return nil, exterrors.Validation( + exterrors.CodeInvalidAuthType, + fmt.Sprintf("Unsupported auth type %q.", authType), + "Supported: api-key, custom-keys, none", + ) + } +} + +func printList(items []connectionListItem, format string) error { + if format == "json" { + data, err := json.MarshalIndent(items, "", " ") + if err != nil { + return err + } + fmt.Println(string(data)) + return nil + } + w := tabwriter.NewWriter(os.Stdout, 0, 4, 2, ' ', 0) + fmt.Fprintln(w, "Name\tKind\tAuth Type\tTarget") + fmt.Fprintln(w, "----\t----\t---------\t------") + for _, item := range items { + fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", item.Name, item.Kind, item.AuthType, item.Target) + } + return w.Flush() +} + +func printDetail(result connectionDetailResult, format string) error { + if format == "json" { + data, err := json.MarshalIndent(result, "", " ") + if err != nil { + return err + } + fmt.Println(string(data)) + return nil + } + fmt.Printf("Name: %s\n", result.Name) + fmt.Printf("Kind: %s\n", result.Kind) + fmt.Printf("Auth Type: %s\n", result.AuthType) + fmt.Printf("Target: %s\n", result.Target) + if len(result.Credentials) > 0 { + fmt.Println("\nCredentials:") + for k, v := range result.Credentials { + fmt.Printf(" %s: %s\n", k, v) + } + } + if len(result.CredentialRefs) > 0 { + fmt.Println("\nCredential References (for agent.yaml):") + for k, v := range result.CredentialRefs { + fmt.Printf(" %s: %s\n", k, v) + } + } + return nil +} + +func parseKVPtrMap(pairs []string) map[string]*string { + if len(pairs) == 0 { + return nil + } + result := make(map[string]*string, len(pairs)) + for _, pair := range pairs { + found := false + for i := range len(pair) { + if pair[i] == '=' { + v := pair[i+1:] + result[pair[:i]] = &v + found = true + break + } + } + if !found { + log.Printf("warning: ignoring malformed key=value pair: %q", pair) + } + } + return result +} + +func deref(s *string) string { + if s == nil { + return "" + } + return *s +} + +func categoryStr(c *armcognitiveservices.ConnectionCategory) string { + if c == nil { + return "" + } + return string(*c) +} + +func authTypeStr(a *armcognitiveservices.ConnectionAuthType) string { + if a == nil { + return "" + } + return string(*a) +} + +func normalizeKind(cliKind string) string { + mapping := map[string]string{ + "remote-tool": "RemoteTool", + "cognitive-search": "CognitiveSearch", + "api-key": "ApiKey", + "app-insights": "AppInsights", + "grounding-with-bing-search": "GroundingWithBingSearch", + "ai-services": "AIServices", + "container-registry": "ContainerRegistry", + "custom-keys": "CustomKeys", + } + if mapped, ok := mapping[cliKind]; ok { + return mapped + } + return cliKind +} + +// normalizeAuthType converts ARM SDK auth type values to CLI kebab-case format. +func normalizeAuthType(armAuthType string) string { + switch armAuthType { + case "ApiKey": + return "api-key" + case "CustomKeys": + return "custom-keys" + case "None": + return "none" + default: + return armAuthType + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/connection_test.go b/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/connection_test.go new file mode 100644 index 00000000000..a840d164389 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/connection_test.go @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "testing" + + "azureaiagent/internal/connections/pkg/connections" + + "github.com/stretchr/testify/require" +) + +func TestParseEndpointComponents(t *testing.T) { + tests := []struct { + name string + endpoint string + wantAccount string + wantProject string + wantErr bool + }{ + { + name: "standard endpoint", + endpoint: "https://myaccount.services.ai.azure.com/api/projects/myproject", + wantAccount: "myaccount", + wantProject: "myproject", + }, + { + name: "endpoint with trailing slash", + endpoint: "https://myaccount.services.ai.azure.com/api/projects/myproject/", + wantAccount: "myaccount", + wantProject: "myproject", + }, + { + name: "missing project segment", + endpoint: "https://myaccount.services.ai.azure.com/api/", + wantErr: true, + }, + { + name: "empty endpoint", + endpoint: "", + wantErr: true, + }, + { + name: "no host", + endpoint: "/api/projects/myproject", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account, project, err := parseEndpointComponents(tt.endpoint) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantAccount, account) + require.Equal(t, tt.wantProject, project) + }) + } +} + +func TestParseARMResourceID(t *testing.T) { + tests := []struct { + name string + resourceID string + wantSub string + wantRG string + wantAcct string + wantProj string + wantErr bool + }{ + { + name: "full resource ID", + resourceID: "/subscriptions/sub-123/resourceGroups/rg-test/" + + "providers/Microsoft.CognitiveServices/accounts/acct1/projects/proj1/" + + "connections/conn1", + wantSub: "sub-123", + wantRG: "rg-test", + wantAcct: "acct1", + wantProj: "proj1", + }, + { + name: "missing subscription", + resourceID: "/resourceGroups/rg/providers/Microsoft.CognitiveServices/accounts/a/projects/p", + wantErr: true, + }, + { + name: "empty string", + resourceID: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseARMResourceID(tt.resourceID) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantSub, result.SubscriptionID) + require.Equal(t, tt.wantRG, result.ResourceGroup) + require.Equal(t, tt.wantAcct, result.AccountName) + require.Equal(t, tt.wantProj, result.ProjectName) + }) + } +} + +func TestNormalizeKind(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"remote-tool", "RemoteTool"}, + {"cognitive-search", "CognitiveSearch"}, + {"api-key", "ApiKey"}, + {"app-insights", "AppInsights"}, + {"ai-services", "AIServices"}, + {"container-registry", "ContainerRegistry"}, + {"custom-keys", "CustomKeys"}, + // Already PascalCase — pass through + {"RemoteTool", "RemoteTool"}, + // Unknown kind — pass through + {"my-custom-kind", "my-custom-kind"}, + // Empty + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + require.Equal(t, tt.want, normalizeKind(tt.input)) + }) + } +} + +func TestNormalizeAuthType(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"ApiKey", "api-key"}, + {"CustomKeys", "custom-keys"}, + {"None", "none"}, + // Unknown — pass through + {"AAD", "AAD"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + require.Equal(t, tt.want, normalizeAuthType(tt.input)) + }) + } +} + +func TestParseKVPtrMap(t *testing.T) { + tests := []struct { + name string + pairs []string + want map[string]string // compare dereferenced values + }{ + { + name: "single pair", + pairs: []string{"key1=value1"}, + want: map[string]string{"key1": "value1"}, + }, + { + name: "multiple pairs", + pairs: []string{"a=1", "b=2"}, + want: map[string]string{"a": "1", "b": "2"}, + }, + { + name: "value with equals sign", + pairs: []string{"key=val=ue"}, + want: map[string]string{"key": "val=ue"}, + }, + { + name: "empty value", + pairs: []string{"key="}, + want: map[string]string{"key": ""}, + }, + { + name: "malformed pair skipped", + pairs: []string{"noequals", "good=val"}, + want: map[string]string{"good": "val"}, + }, + { + name: "nil input", + pairs: nil, + want: nil, + }, + { + name: "empty slice", + pairs: []string{}, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseKVPtrMap(tt.pairs) + if tt.want == nil { + require.Nil(t, result) + return + } + require.Len(t, result, len(tt.want)) + for k, wantV := range tt.want { + require.NotNil(t, result[k], "missing key %q", k) + require.Equal(t, wantV, *result[k]) + } + }) + } +} + +func TestBuildCredentialReferences(t *testing.T) { + tests := []struct { + name string + connName string + creds *connections.ConnectionCredentials + want map[string]string + }{ + { + name: "api key only", + connName: "my-conn", + creds: &connections.ConnectionCredentials{ + Key: "secret", + }, + want: map[string]string{ + "key": "${{connections.my-conn.credentials.key}}", + }, + }, + { + name: "custom keys", + connName: "test-conn", + creds: &connections.ConnectionCredentials{ + CustomKeys: map[string]string{ + "x-api-key": "val1", + "token": "val2", + }, + }, + want: map[string]string{ + "x-api-key": "${{connections.test-conn.credentials.x-api-key}}", + "token": "${{connections.test-conn.credentials.token}}", + }, + }, + { + name: "both key and custom keys", + connName: "mixed", + creds: &connections.ConnectionCredentials{ + Key: "apikey", + CustomKeys: map[string]string{"extra": "v"}, + }, + want: map[string]string{ + "key": "${{connections.mixed.credentials.key}}", + "extra": "${{connections.mixed.credentials.extra}}", + }, + }, + { + name: "nil creds", + connName: "x", + creds: nil, + want: nil, + }, + { + name: "empty creds", + connName: "x", + creds: &connections.ConnectionCredentials{}, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildCredentialReferences(tt.connName, tt.creds) + if tt.want == nil { + require.Nil(t, result) + return + } + require.Equal(t, tt.want, result) + }) + } +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/context.go b/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/context.go new file mode 100644 index 00000000000..47da15e6321 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/context.go @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + + "azureaiagent/internal/connections/exterrors" + "azureaiagent/internal/connections/pkg/connections" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices/v2" +) + +// dataClient is a type alias for the data-plane client (used in endpoint.go). +type dataClient = connections.DataClient + +// connectionContext holds the resolved clients and project info for connection operations. +type connectionContext struct { + armClient *armcognitiveservices.ProjectConnectionsClient + dpClient *connections.DataClient + rg string + account string + project string +} + +// resolveConnectionContext resolves the project endpoint, discovers ARM context, +// and creates both clients needed for connection operations. +func resolveConnectionContext( + ctx context.Context, + flagEndpoint string, +) (*connectionContext, error) { + endpoint, err := resolveProjectEndpoint(ctx, flagEndpoint) + if err != nil { + return nil, err + } + + account, project, err := parseEndpointComponents(endpoint) + if err != nil { + return nil, err + } + + cred, err := newCredential() + if err != nil { + return nil, err + } + + // Data-plane client (for list, get-with-credentials, and ARM discovery) + dpClient := connections.NewDataClient(endpoint, cred) + + // Discover subscription + resource group from data-plane response + armCtx, err := discoverARMContext(ctx, dpClient) + if err != nil { + return nil, err + } + + // ARM SDK client for CRUD + armClient, err := armcognitiveservices.NewProjectConnectionsClient( + armCtx.SubscriptionID, cred, nil, + ) + if err != nil { + return nil, fmt.Errorf("failed to create ARM connections client: %w", err) + } + + return &connectionContext{ + armClient: armClient, + dpClient: dpClient, + rg: armCtx.ResourceGroup, + account: account, + project: project, + }, nil +} + +// newCredential creates an Azure credential for API calls. +func newCredential() (azcore.TokenCredential, error) { + cred, err := azidentity.NewAzureDeveloperCLICredential( + &azidentity.AzureDeveloperCLICredentialOptions{}, + ) + if err != nil { + return nil, exterrors.Auth( + exterrors.CodeCredentialCreationFailed, + fmt.Sprintf("Failed to create Azure credential: %s", err), + "Run 'azd auth login' to authenticate.", + ) + } + + return cred, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/endpoint.go b/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/endpoint.go new file mode 100644 index 00000000000..e79d56181c4 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/endpoint.go @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "context" + "fmt" + "net/url" + "os" + "strings" + + "azureaiagent/internal/connections/exterrors" + + "github.com/azure/azure-dev/cli/azd/pkg/azdext" +) + +// TODO: Unify endpoint resolution with the project set/unset commands being added +// to avoid duplicating the resolution cascade logic. + +// resolveProjectEndpoint implements the 5-level resolution cascade from the spec. +// +// 1. -p / --project-endpoint flag (passed as flagEndpoint) +// 2. Active azd env → AZURE_AI_PROJECT_ENDPOINT +// 3. Global config → extensions.ai-agents.context.endpoint +// 4. FOUNDRY_PROJECT_ENDPOINT environment variable +// 5. Structured error +func resolveProjectEndpoint(ctx context.Context, flagEndpoint string) (string, error) { + // 1. Flag + if flagEndpoint != "" { + return flagEndpoint, nil + } + + // 2 & 3. Try azd host (env value + global config) — best-effort + azdClient, err := azdext.NewAzdClient() + if err == nil { + defer azdClient.Close() + + // 2. Active azd env → AZURE_AI_PROJECT_ENDPOINT + if envResp, err := azdClient.Environment().GetCurrent(ctx, &azdext.EmptyRequest{}); err == nil { + if valResp, err := azdClient.Environment().GetValue(ctx, &azdext.GetEnvRequest{ + EnvName: envResp.Environment.Name, + Key: "AZURE_AI_PROJECT_ENDPOINT", + }); err == nil && valResp.Value != "" { + return valResp.Value, nil + } + } + + // 3. Global config → extensions.ai-agents.context.endpoint + ch, cfgErr := azdext.NewConfigHelper(azdClient) + if cfgErr == nil { + var endpoint string + if found, err := ch.GetUserJSON( + ctx, "extensions.ai-agents.context.endpoint", &endpoint, + ); err == nil && found && endpoint != "" { + return endpoint, nil + } + } + } + + // 4. FOUNDRY_PROJECT_ENDPOINT environment variable + // TODO: Document FOUNDRY_PROJECT_ENDPOINT in cli/azd/docs/environment-variables.md + if ep := os.Getenv("FOUNDRY_PROJECT_ENDPOINT"); ep != "" { + return ep, nil + } + + // 5. Structured error + return "", exterrors.Dependency( + exterrors.CodeMissingProjectEndpoint, + "No Foundry project endpoint resolved.", + "Pass '--project-endpoint', set FOUNDRY_PROJECT_ENDPOINT env var, or run 'azd ai agent init' in an azd project.", + ) +} + +// parseEndpointComponents extracts account and project names from the endpoint URL. +// Expected format: https://{account}.services.ai.azure.com/api/projects/{project} +func parseEndpointComponents(endpoint string) (account, project string, err error) { + u, err := url.Parse(endpoint) + if err != nil { + return "", "", fmt.Errorf("invalid endpoint URL: %w", err) + } + + account, _, _ = strings.Cut(u.Hostname(), ".") + + parts := strings.Split(strings.Trim(u.Path, "/"), "/") + for i, p := range parts { + if p == "projects" && i+1 < len(parts) { + project = parts[i+1] + break + } + } + + if account == "" || project == "" { + return "", "", fmt.Errorf("could not parse account/project from endpoint %q", endpoint) + } + + return account, project, nil +} + +// armContext holds the ARM components needed for SDK calls. +type armContext struct { + SubscriptionID string + ResourceGroup string + AccountName string + ProjectName string +} + +// discoverARMContext makes a data-plane list call to discover subscription and +// resource group from the ARM resource IDs embedded in connection responses. +func discoverARMContext( + ctx context.Context, + dpClient *dataClient, +) (*armContext, error) { + conns, err := dpClient.ListConnections(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list connections for ARM discovery: %w", err) + } + + if len(conns) == 0 { + return nil, fmt.Errorf( + "no connections found in project; cannot discover ARM context. " + + "Create a connection via the Foundry portal first, or pass the project endpoint that already has connections", + ) + } + + return parseARMResourceID(conns[0].ID) +} + +// parseARMResourceID extracts ARM components from a full resource ID string. +func parseARMResourceID(resourceID string) (*armContext, error) { + parts := strings.Split(resourceID, "/") + result := &armContext{} + + for i, part := range parts { + switch { + case part == "subscriptions" && i+1 < len(parts): + result.SubscriptionID = parts[i+1] + case part == "resourceGroups" && i+1 < len(parts): + result.ResourceGroup = parts[i+1] + case part == "accounts" && i+1 < len(parts): + result.AccountName = parts[i+1] + case part == "projects" && i+1 < len(parts): + result.ProjectName = parts[i+1] + } + } + + if result.SubscriptionID == "" || result.ResourceGroup == "" || + result.AccountName == "" || result.ProjectName == "" { + return nil, fmt.Errorf("could not extract ARM context from resource ID: %s", resourceID) + } + + return result, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/root.go b/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/root.go new file mode 100644 index 00000000000..e813c79e56c --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/connections/cmd/root.go @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package cmd + +import ( + "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "github.com/spf13/cobra" +) + +// NewConnectionRootCommand creates the "connection" subcommand group under "azd ai". +func NewConnectionRootCommand(extCtx *azdext.ExtensionContext) *cobra.Command { + cmd := &cobra.Command{ + Use: "connection [options]", + Short: "Manage Foundry project connections. (Preview)", + Long: `Manage connections (connected resources) in a Foundry project. + +Connections link a Foundry project to external services such as MCP servers, +AI Search, Bing, ACR, App Insights, AI Services, and custom APIs. + +Each connection has a kind, target URL, auth type, optional credentials, +and optional metadata.`, + } + + // Register -p / --project-endpoint as a persistent flag so all subcommands inherit it + cmd.PersistentFlags().StringP("project-endpoint", "p", "", + "Foundry project endpoint URL (overrides env var and config)") + + cmd.AddCommand(newConnectionListCommand(extCtx)) + cmd.AddCommand(newConnectionShowCommand(extCtx)) + cmd.AddCommand(newConnectionCreateCommand(extCtx)) + cmd.AddCommand(newConnectionUpdateCommand(extCtx)) + cmd.AddCommand(newConnectionDeleteCommand(extCtx)) + + return cmd +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/connections/exterrors/codes.go b/cli/azd/extensions/azure.ai.agents/internal/connections/exterrors/codes.go new file mode 100644 index 00000000000..0d2478cc1c9 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/connections/exterrors/codes.go @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exterrors + +// Error codes for connection validation. +const ( + CodeConflictingArguments = "conflicting_arguments" + CodeMissingConnectionField = "missing_connection_field" + CodeInvalidConnectionKind = "invalid_connection_kind" + CodeInvalidAuthType = "invalid_auth_type" + CodeInvalidFromFile = "invalid_from_file" + CodeMissingForceFlag = "missing_force_flag" + CodeConnectionAlreadyExists = "connection_already_exists" +) + +// Error codes for endpoint resolution. +const ( + CodeMissingProjectEndpoint = "missing_project_endpoint" +) + +// Error codes for auth. +const ( + //nolint:gosec // error code identifier, not a credential + CodeCredentialCreationFailed = "credential_creation_failed" +) + +// Operation names for ServiceFromAzure errors. +const ( + OpCreateConnection = "create_connection" + OpUpdateConnection = "update_connection" + OpDeleteConnection = "delete_connection" + OpGetConnection = "get_connection" + OpGetConnectionCredentials = "get_connection_credentials" + OpListConnections = "list_connections" +) diff --git a/cli/azd/extensions/azure.ai.agents/internal/connections/exterrors/errors.go b/cli/azd/extensions/azure.ai.agents/internal/connections/exterrors/errors.go new file mode 100644 index 00000000000..c0effc647cc --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/connections/exterrors/errors.go @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exterrors + +import ( + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/azure/azure-dev/cli/azd/pkg/azdext" +) + +// Validation returns a validation error for user-input or flag errors. +func Validation(code, message, suggestion string) error { + return &azdext.LocalError{ + Message: message, + Code: code, + Category: azdext.LocalErrorCategoryValidation, + Suggestion: suggestion, + } +} + +// Dependency returns a dependency error for missing resources or services. +func Dependency(code, message, suggestion string) error { + return &azdext.LocalError{ + Message: message, + Code: code, + Category: azdext.LocalErrorCategoryDependency, + Suggestion: suggestion, + } +} + +// Auth returns an auth error for authentication or authorization failures. +func Auth(code, message, suggestion string) error { + return &azdext.LocalError{ + Message: message, + Code: code, + Category: azdext.LocalErrorCategoryAuth, + Suggestion: suggestion, + } +} + +// ServiceFromAzure converts an Azure SDK error into a structured service error. +func ServiceFromAzure(err error, operation string) error { + if respErr, ok := errors.AsType[*azcore.ResponseError](err); ok { + return &azdext.ServiceError{ + Message: respErr.Error(), + ErrorCode: respErr.ErrorCode, + StatusCode: respErr.StatusCode, + ServiceName: operation, + } + } + return err +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/connections/pkg/connections/data_client.go b/cli/azd/extensions/azure.ai.agents/internal/connections/pkg/connections/data_client.go new file mode 100644 index 00000000000..5fa66eae3c3 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/connections/pkg/connections/data_client.go @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package connections + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/azure/azure-dev/cli/azd/pkg/azsdk" +) + +const dataPlaneAPIVersion = "2025-11-15-preview" + +// DataClient provides read operations via the Foundry data plane. +// Used for listing connections (including ARM ID discovery) and fetching credentials. +type DataClient struct { + endpoint string + pipeline runtime.Pipeline +} + +// NewDataClient creates a new data-plane client for connection operations. +func NewDataClient(endpoint string, cred azcore.TokenCredential) *DataClient { + clientOptions := &policy.ClientOptions{ + PerCallPolicies: []policy.Policy{ + runtime.NewBearerTokenPolicy( + cred, + []string{"https://ai.azure.com/.default"}, + nil, + ), + azsdk.NewMsCorrelationPolicy(), + azsdk.NewUserAgentPolicy("azd-ext-azure-ai-connection/0.1.0"), + }, + } + + pipeline := runtime.NewPipeline( + "azure-ai-connection-data", + "v1.0.0", + runtime.PipelineOptions{}, + clientOptions, + ) + + return &DataClient{endpoint: endpoint, pipeline: pipeline} +} + +// ListConnections retrieves all connections from the project via data-plane GET. +func (c *DataClient) ListConnections(ctx context.Context) ([]Connection, error) { + var allConnections []Connection + + paged, err := c.getPage( + ctx, + fmt.Sprintf("%s/connections?api-version=%s", c.endpoint, dataPlaneAPIVersion), + ) + if err != nil { + return nil, err + } + + allConnections = append(allConnections, paged.Value...) + nextLink := paged.NextLink + + for nextLink != nil && *nextLink != "" { + if err := c.validateNextLinkOrigin(*nextLink); err != nil { + return nil, fmt.Errorf("refusing to follow pagination link: %w", err) + } + + paged, err = c.getPage(ctx, *nextLink) + if err != nil { + return nil, err + } + + allConnections = append(allConnections, paged.Value...) + nextLink = paged.NextLink + } + + return allConnections, nil +} + +func (c *DataClient) validateNextLinkOrigin(nextLink string) error { + endpointURL, err := url.Parse(c.endpoint) + if err != nil { + return fmt.Errorf("invalid endpoint URL: %w", err) + } + + linkURL, err := url.Parse(nextLink) + if err != nil { + return fmt.Errorf("invalid nextLink URL: %w", err) + } + + if linkURL.Scheme == "" { + return fmt.Errorf("nextLink must have an explicit scheme, got %q", nextLink) + } + + if !strings.EqualFold(linkURL.Scheme, endpointURL.Scheme) || + !strings.EqualFold(linkURL.Host, endpointURL.Host) { + return fmt.Errorf( + "nextLink origin mismatch: expected %s://%s, got %s://%s", + endpointURL.Scheme, endpointURL.Host, linkURL.Scheme, linkURL.Host, + ) + } + + return nil +} + +func (c *DataClient) getPage(ctx context.Context, targetURL string) (*PagedConnection, error) { + req, err := runtime.NewRequest(ctx, http.MethodGet, targetURL) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var paged PagedConnection + if err := json.Unmarshal(body, &paged); err != nil { + return nil, fmt.Errorf("failed to unmarshal connections: %w", err) + } + + return &paged, nil +} + +// GetConnectionWithCredentials retrieves a specific connection with its credentials +// via the data-plane POST endpoint. +func (c *DataClient) GetConnectionWithCredentials( + ctx context.Context, + name string, +) (*Connection, error) { + targetURL := fmt.Sprintf( + "%s/connections/%s/getConnectionWithCredentials?api-version=%s", + c.endpoint, url.PathEscape(name), dataPlaneAPIVersion, + ) + + req, err := runtime.NewRequest(ctx, http.MethodPost, targetURL) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.pipeline.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + if !runtime.HasStatusCode(resp, http.StatusOK) { + return nil, runtime.NewResponseError(resp) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var raw struct { + Name string `json:"name"` + ID string `json:"id"` + Type string `json:"type"` + Target string `json:"target"` + IsDefault bool `json:"isDefault"` + Credentials map[string]any `json:"credentials"` + Metadata map[string]string `json:"metadata"` + } + if err := json.Unmarshal(body, &raw); err != nil { + return nil, fmt.Errorf("failed to unmarshal connection: %w", err) + } + + conn := &Connection{ + Name: raw.Name, + ID: raw.ID, + Type: raw.Type, + Target: raw.Target, + IsDefault: raw.IsDefault, + Credentials: ParseCredentials(raw.Credentials), + Metadata: raw.Metadata, + } + + return conn, nil +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/connections/pkg/connections/models.go b/cli/azd/extensions/azure.ai.agents/internal/connections/pkg/connections/models.go new file mode 100644 index 00000000000..871372e42d6 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/connections/pkg/connections/models.go @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package connections + +// Connection represents a Foundry project connection from the data-plane API. +type Connection struct { + Name string `json:"name"` + ID string `json:"id"` + Type string `json:"type"` + Target string `json:"target"` + IsDefault bool `json:"isDefault"` + Credentials *ConnectionCredentials `json:"credentials,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// ConnectionCredentials holds credential values returned by the data-plane +// getConnectionWithCredentials endpoint. +// +// The API returns credentials as a flat JSON object where "type" identifies +// the auth type and all other fields are credential key-value pairs: +// +// ApiKey: {"type": "ApiKey", "key": "abc123"} +// CustomKeys: {"type": "CustomKeys", "my-secret": "val", "x-api-key": "val"} +// AAD/None: {"type": "AAD"} or {"type": "None"} — no secret fields +type ConnectionCredentials struct { + Type string `json:"-"` + Key string `json:"-"` + CustomKeys map[string]string `json:"-"` + // RawFields holds all fields from the JSON response for flexible access. + RawFields map[string]string `json:"-"` +} + +// ParseCredentials parses a raw credentials JSON object into a typed struct. +// The "type" field is extracted and remaining fields become either Key (for ApiKey) +// or CustomKeys entries. +func ParseCredentials(raw map[string]any) *ConnectionCredentials { + if raw == nil { + return nil + } + + creds := &ConnectionCredentials{ + CustomKeys: make(map[string]string), + RawFields: make(map[string]string), + } + + for k, v := range raw { + strVal, ok := v.(string) + if !ok { + continue + } + + switch k { + case "type": + creds.Type = strVal + case "key": + creds.Key = strVal + creds.RawFields[k] = strVal + default: + creds.CustomKeys[k] = strVal + creds.RawFields[k] = strVal + } + } + + return creds +} + +// PagedConnection represents a paged collection of connections. +type PagedConnection struct { + Value []Connection `json:"value"` + NextLink *string `json:"nextLink,omitempty"` +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/connections/pkg/connections/models_test.go b/cli/azd/extensions/azure.ai.agents/internal/connections/pkg/connections/models_test.go new file mode 100644 index 00000000000..fc1deefc4fd --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/connections/pkg/connections/models_test.go @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package connections + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseCredentials(t *testing.T) { + tests := []struct { + name string + raw map[string]any + wantType string + wantKey string + wantCustomKeys map[string]string + wantNil bool + }{ + { + name: "nil input", + raw: nil, + wantNil: true, + }, + { + name: "ApiKey credentials", + raw: map[string]any{ + "type": "ApiKey", + "key": "my-secret-key", + }, + wantType: "ApiKey", + wantKey: "my-secret-key", + wantCustomKeys: map[string]string{}, + }, + { + name: "CustomKeys credentials", + raw: map[string]any{ + "type": "CustomKeys", + "x-api-key": "tavily-key", + "token": "bearer-token", + }, + wantType: "CustomKeys", + wantKey: "", + wantCustomKeys: map[string]string{ + "x-api-key": "tavily-key", + "token": "bearer-token", + }, + }, + { + name: "AAD credentials (no secrets)", + raw: map[string]any{ + "type": "AAD", + }, + wantType: "AAD", + wantKey: "", + wantCustomKeys: map[string]string{}, + }, + { + name: "mixed key and custom keys", + raw: map[string]any{ + "type": "ApiKey", + "key": "primary", + "extra": "bonus", + }, + wantType: "ApiKey", + wantKey: "primary", + wantCustomKeys: map[string]string{ + "extra": "bonus", + }, + }, + { + name: "non-string values skipped", + raw: map[string]any{ + "type": "Custom", + "key": "valid", + "numeric": 42, + "nested": map[string]any{"a": "b"}, + }, + wantType: "Custom", + wantKey: "valid", + wantCustomKeys: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseCredentials(tt.raw) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.Equal(t, tt.wantType, result.Type) + require.Equal(t, tt.wantKey, result.Key) + require.Equal(t, tt.wantCustomKeys, result.CustomKeys) + + // Verify RawFields contains non-type string fields + for k, v := range tt.raw { + if k == "type" { + continue + } + if strVal, ok := v.(string); ok { + require.Equal(t, strVal, result.RawFields[k], + "RawFields[%q] mismatch", k) + } + } + }) + } +}