diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/connection_manager.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/connection_manager.go new file mode 100644 index 00000000000..6493c2a668f --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/connection_manager.go @@ -0,0 +1,442 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azure + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + armcognitiveservices "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices/v2" +) + +// ConnectionManager provides CRUD operations on connected resources in a Foundry project. +// It combines the ARM SDK client (for create/get/update/delete/list) with the data-plane +// client (for reading credentials/secrets). ARM never returns credential values; the +// data-plane getConnectionWithCredentials endpoint is the only way to retrieve them. +type ConnectionManager struct { + armClient *armcognitiveservices.ProjectConnectionsClient + dpClient *FoundryProjectsClient // for reading credentials; nil if not needed + rg string + account string + project string +} + +// NewConnectionManager creates a ConnectionManager with both ARM and data-plane clients. +// The dpClient is used for reading credentials via getConnectionWithCredentials. +func NewConnectionManager( + subscriptionID, rg, account, project string, + cred azcore.TokenCredential, +) (*ConnectionManager, error) { + armClient, err := armcognitiveservices.NewProjectConnectionsClient( + subscriptionID, cred, NewArmClientOptions(), + ) + if err != nil { + return nil, fmt.Errorf("creating ARM connections client: %w", err) + } + + dpClient, err := NewFoundryProjectsClient(account, project, cred) + if err != nil { + return nil, fmt.Errorf("creating data-plane client: %w", err) + } + + return &ConnectionManager{ + armClient: armClient, + dpClient: dpClient, + rg: rg, + account: account, + project: project, + }, nil +} + +// ConnectionInfo holds metadata for a connection (no secrets). +type ConnectionInfo struct { + Name string + ID string + Category string + AuthType string + Target string + IsDefault bool + Metadata map[string]string +} + +// ConnectionDetail extends ConnectionInfo with credential key-value pairs +// retrieved from the data-plane getConnectionWithCredentials endpoint. +// The Credentials map contains only the actual secret fields (the "type" +// discriminator is excluded). +type ConnectionDetail struct { + ConnectionInfo + Credentials map[string]string +} + +// CreateConnectionParams holds the parameters for creating a new connection. +type CreateConnectionParams struct { + Category string // connection category (e.g., "RemoteTool", "ApiKey", "CustomKeys") + Target string // target URL or ARM resource ID + AuthType string // "ApiKey", "CustomKeys", or "None" + Key string // API key value (used when AuthType is "ApiKey") + Keys map[string]string // custom key-value pairs (used when AuthType is "CustomKeys") + Metadata map[string]string // optional metadata key-value pairs +} + +// UpdateConnectionParams holds the parameters for updating an existing connection. +// Nil pointer fields mean "don't change". Map fields are merged with existing values. +type UpdateConnectionParams struct { + Target *string // new target URL; nil = keep existing + Key *string // new API key; nil = keep existing (ApiKey auth only) + Keys map[string]string // custom keys to add/overwrite (CustomKeys auth only) + Metadata map[string]string // metadata to add/overwrite +} + +// List returns all connections in the project, optionally filtered by category. +// Pass an empty string for category to list all connections. +func (m *ConnectionManager) List(ctx context.Context, category string) ([]ConnectionInfo, error) { + opts := &armcognitiveservices.ProjectConnectionsClientListOptions{} + if category != "" { + opts.Category = &category + } + + pager := m.armClient.NewListPager(m.rg, m.account, m.project, opts) + + var results []ConnectionInfo + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, fmt.Errorf("listing connections: %w", err) + } + for _, r := range page.Value { + results = append(results, connectionInfoFromARM(r)) + } + } + + return results, nil +} + +// Get returns metadata for a single connection (no credentials). +func (m *ConnectionManager) Get(ctx context.Context, name string) (*ConnectionInfo, error) { + resp, err := m.armClient.Get(ctx, m.rg, m.account, m.project, name, nil) + if err != nil { + return nil, fmt.Errorf("getting connection %q: %w", name, err) + } + info := connectionInfoFromARM(&resp.ConnectionPropertiesV2BasicResource) + return &info, nil +} + +// GetWithCredentials returns metadata and credentials for a single connection. +// It calls both the ARM GET (metadata) and the data-plane getConnectionWithCredentials +// (secrets), then merges the results. +func (m *ConnectionManager) GetWithCredentials(ctx context.Context, name string) (*ConnectionDetail, error) { + info, err := m.Get(ctx, name) + if err != nil { + return nil, err + } + + creds, err := m.fetchCredentials(ctx, name) + if err != nil { + return nil, fmt.Errorf("fetching credentials for %q: %w", name, err) + } + + return &ConnectionDetail{ + ConnectionInfo: *info, + Credentials: creds, + }, nil +} + +// Create creates a new connection in the project. +// Supported auth types: ApiKey, CustomKeys, None. Returns an error for unsupported types. +func (m *ConnectionManager) Create( + ctx context.Context, name string, params CreateConnectionParams, +) (*ConnectionInfo, error) { + body, err := buildCreateBody(params) + if err != nil { + return nil, err + } + + resp, err := m.armClient.Create(ctx, m.rg, m.account, m.project, name, + &armcognitiveservices.ProjectConnectionsClientCreateOptions{ + Connection: body, + }, + ) + if err != nil { + return nil, fmt.Errorf("creating connection %q: %w", name, err) + } + + info := connectionInfoFromARM(&resp.ConnectionPropertiesV2BasicResource) + return &info, nil +} + +// Update updates an existing connection. It performs a GET-then-PUT because the ARM API +// does not support PATCH for connections. The update params are merged into the existing +// connection properties before the PUT. +func (m *ConnectionManager) Update( + ctx context.Context, name string, params UpdateConnectionParams, +) (*ConnectionInfo, error) { + // GET the current connection + current, err := m.armClient.Get(ctx, m.rg, m.account, m.project, name, nil) + if err != nil { + return nil, fmt.Errorf("getting connection %q for update: %w", name, err) + } + + updated, err := mergeUpdate(¤t.ConnectionPropertiesV2BasicResource, params) + if err != nil { + return nil, err + } + + resp, err := m.armClient.Create(ctx, m.rg, m.account, m.project, name, + &armcognitiveservices.ProjectConnectionsClientCreateOptions{ + Connection: updated, + }, + ) + if err != nil { + return nil, fmt.Errorf("updating connection %q: %w", name, err) + } + + info := connectionInfoFromARM(&resp.ConnectionPropertiesV2BasicResource) + return &info, nil +} + +// Delete removes a connection from the project. +func (m *ConnectionManager) Delete(ctx context.Context, name string) error { + _, err := m.armClient.Delete(ctx, m.rg, m.account, m.project, name, nil) + if err != nil { + return fmt.Errorf("deleting connection %q: %w", name, err) + } + return nil +} + +// fetchCredentials calls the data-plane getConnectionWithCredentials endpoint +// and returns the credential fields as a flat map (excluding the "type" discriminator). +func (m *ConnectionManager) fetchCredentials(ctx context.Context, name string) (map[string]string, error) { + conn, err := m.dpClient.GetConnectionWithCredentials(ctx, name) + if err != nil { + return nil, err + } + + // The existing Connection.Credentials struct only has Type and Key fields, + // but the raw response may contain arbitrary custom keys. We need to re-fetch + // the raw response to get all credential fields. + return m.fetchRawCredentials(ctx, name, conn) +} + +// fetchRawCredentials extracts all credential key-value pairs from the data-plane +// response. For ApiKey auth, this is just {"key": "..."}. For CustomKeys, this +// includes all named keys (e.g., {"x-api-key": "...", "secret": "..."}). +// The "type" discriminator is excluded from the result. +func (m *ConnectionManager) fetchRawCredentials( + ctx context.Context, name string, conn *Connection, +) (map[string]string, error) { + // For ApiKey connections, the existing struct captures the key + if conn.Credentials.Type == CredentialTypeApiKey && conn.Credentials.Key != "" { + return map[string]string{"key": conn.Credentials.Key}, nil + } + + // For CustomKeys and other types, we need the raw JSON to get all fields. + // Re-fetch using the raw response parser. + raw, err := m.dpClient.getConnectionWithCredentialsRaw(ctx, name) + if err != nil { + return nil, err + } + + // Parse the raw credentials JSON into a flat string map + var envelope struct { + Credentials map[string]json.RawMessage `json:"credentials"` + } + if err := json.Unmarshal(raw, &envelope); err != nil { + return nil, fmt.Errorf("parsing credentials JSON: %w", err) + } + + result := make(map[string]string, len(envelope.Credentials)) + for k, v := range envelope.Credentials { + if k == "type" { + continue // exclude the discriminator + } + var s string + if err := json.Unmarshal(v, &s); err != nil { + continue // skip non-string values + } + result[k] = s + } + + return result, nil +} + +// connectionInfoFromARM converts the ARM SDK response type to our domain type. +func connectionInfoFromARM(r *armcognitiveservices.ConnectionPropertiesV2BasicResource) ConnectionInfo { + info := ConnectionInfo{} + if r.ID != nil { + info.ID = *r.ID + info.Name = lastSegment(*r.ID) + } + + if r.Properties == nil { + return info + } + + props := r.Properties.GetConnectionPropertiesV2() + if props == nil { + return info + } + + if props.Category != nil { + info.Category = string(*props.Category) + } + if props.AuthType != nil { + info.AuthType = string(*props.AuthType) + } + if props.Target != nil { + info.Target = *props.Target + } + if props.IsSharedToAll != nil { + info.IsDefault = *props.IsSharedToAll + } + if props.Metadata != nil { + info.Metadata = make(map[string]string, len(props.Metadata)) + for k, v := range props.Metadata { + if v != nil { + info.Metadata[k] = *v + } + } + } + + return info +} + +// buildCreateBody constructs the ARM request body for creating a connection. +func buildCreateBody(params CreateConnectionParams) (*armcognitiveservices.ConnectionPropertiesV2BasicResource, error) { + category := armcognitiveservices.ConnectionCategory(params.Category) + metadata := toStringPtrMap(params.Metadata) + + switch params.AuthType { + case "ApiKey": + authType := armcognitiveservices.ConnectionAuthTypeAPIKey + return &armcognitiveservices.ConnectionPropertiesV2BasicResource{ + Properties: &armcognitiveservices.APIKeyAuthConnectionProperties{ + AuthType: &authType, + Category: &category, + Target: to.Ptr(params.Target), + Credentials: &armcognitiveservices.ConnectionAPIKey{ + Key: to.Ptr(params.Key), + }, + Metadata: metadata, + }, + }, nil + + case "CustomKeys": + authType := armcognitiveservices.ConnectionAuthTypeCustomKeys + return &armcognitiveservices.ConnectionPropertiesV2BasicResource{ + Properties: &armcognitiveservices.CustomKeysConnectionProperties{ + AuthType: &authType, + Category: &category, + Target: to.Ptr(params.Target), + Credentials: &armcognitiveservices.CustomKeys{ + Keys: toStringPtrMap(params.Keys), + }, + Metadata: metadata, + }, + }, nil + + case "None": + authType := armcognitiveservices.ConnectionAuthTypeNone + return &armcognitiveservices.ConnectionPropertiesV2BasicResource{ + Properties: &armcognitiveservices.NoneAuthTypeConnectionProperties{ + AuthType: &authType, + Category: &category, + Target: to.Ptr(params.Target), + Metadata: metadata, + }, + }, nil + + default: + return nil, fmt.Errorf( + "unsupported auth type %q; supported types: ApiKey, CustomKeys, None", params.AuthType, + ) + } +} + +// mergeUpdate applies UpdateConnectionParams onto an existing ARM resource for a PUT. +func mergeUpdate( + current *armcognitiveservices.ConnectionPropertiesV2BasicResource, + params UpdateConnectionParams, +) (*armcognitiveservices.ConnectionPropertiesV2BasicResource, error) { + if current.Properties == nil { + return nil, fmt.Errorf("connection has no properties to update") + } + + props := current.Properties.GetConnectionPropertiesV2() + if props == nil { + return nil, fmt.Errorf("connection has no base properties") + } + + // Update target + if params.Target != nil { + props.Target = params.Target + } + + // Merge metadata + if len(params.Metadata) > 0 { + if props.Metadata == nil { + props.Metadata = make(map[string]*string) + } + for k, v := range params.Metadata { + props.Metadata[k] = to.Ptr(v) + } + } + + // Update credentials based on auth type + if props.AuthType != nil { + switch *props.AuthType { + case armcognitiveservices.ConnectionAuthTypeAPIKey: + if params.Key != nil { + if apiKeyProps, ok := current.Properties.(*armcognitiveservices.APIKeyAuthConnectionProperties); ok { + if apiKeyProps.Credentials == nil { + apiKeyProps.Credentials = &armcognitiveservices.ConnectionAPIKey{} + } + apiKeyProps.Credentials.Key = params.Key + } + } + case armcognitiveservices.ConnectionAuthTypeCustomKeys: + if len(params.Keys) > 0 { + if customProps, ok := current.Properties.(*armcognitiveservices.CustomKeysConnectionProperties); ok { + if customProps.Credentials == nil { + customProps.Credentials = &armcognitiveservices.CustomKeys{ + Keys: make(map[string]*string), + } + } + if customProps.Credentials.Keys == nil { + customProps.Credentials.Keys = make(map[string]*string) + } + for k, v := range params.Keys { + customProps.Credentials.Keys[k] = to.Ptr(v) + } + } + } + } + } + + return current, nil +} + +// toStringPtrMap converts map[string]string to map[string]*string for the ARM SDK. +func toStringPtrMap(m map[string]string) map[string]*string { + if m == nil { + return nil + } + result := make(map[string]*string, len(m)) + for k, v := range m { + result[k] = to.Ptr(v) + } + return result +} + +// lastSegment returns the last "/" separated segment of a path. +func lastSegment(path string) string { + for i := len(path) - 1; i >= 0; i-- { + if path[i] == '/' { + return path[i+1:] + } + } + return path +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/connection_manager_test.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/connection_manager_test.go new file mode 100644 index 00000000000..50da8bb83f4 --- /dev/null +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/connection_manager_test.go @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package azure + +import ( + "encoding/json" + "testing" + + armcognitiveservices "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices/v2" + "github.com/stretchr/testify/require" +) + +func TestBuildCreateBody_ApiKey(t *testing.T) { + params := CreateConnectionParams{ + Category: "ApiKey", + Target: "https://httpbin.org/get", + AuthType: "ApiKey", + Key: "test-key-12345", + Metadata: map[string]string{"ApiType": "Azure"}, + } + + body, err := buildCreateBody(params) + require.NoError(t, err) + require.NotNil(t, body) + require.NotNil(t, body.Properties) + + props := body.Properties.GetConnectionPropertiesV2() + require.Equal(t, "ApiKey", string(*props.AuthType)) + require.Equal(t, "ApiKey", string(*props.Category)) + require.Equal(t, "https://httpbin.org/get", *props.Target) + require.Equal(t, "Azure", *props.Metadata["ApiType"]) +} + +func TestBuildCreateBody_CustomKeys(t *testing.T) { + params := CreateConnectionParams{ + Category: "RemoteTool", + Target: "https://mcp.tavily.com/mcp", + AuthType: "CustomKeys", + Keys: map[string]string{"x-api-key": "tvly-abc123"}, + Metadata: map[string]string{"type": "custom_MCP"}, + } + + body, err := buildCreateBody(params) + require.NoError(t, err) + require.NotNil(t, body) + + props := body.Properties.GetConnectionPropertiesV2() + require.Equal(t, "CustomKeys", string(*props.AuthType)) + require.Equal(t, "RemoteTool", string(*props.Category)) + require.Equal(t, "https://mcp.tavily.com/mcp", *props.Target) +} + +func TestBuildCreateBody_None(t *testing.T) { + params := CreateConnectionParams{ + Category: "RemoteTool", + Target: "https://learn.microsoft.com/api/mcp", + AuthType: "None", + } + + body, err := buildCreateBody(params) + require.NoError(t, err) + require.NotNil(t, body) + + props := body.Properties.GetConnectionPropertiesV2() + require.Equal(t, "None", string(*props.AuthType)) + require.Equal(t, "RemoteTool", string(*props.Category)) + require.Equal(t, "https://learn.microsoft.com/api/mcp", *props.Target) + require.Nil(t, props.Metadata) +} + +func TestBuildCreateBody_UnsupportedAuthType(t *testing.T) { + params := CreateConnectionParams{ + Category: "RemoteTool", + Target: "https://example.com", + AuthType: "OAuth2", + } + + _, err := buildCreateBody(params) + require.Error(t, err) + require.Contains(t, err.Error(), "unsupported auth type") + require.Contains(t, err.Error(), "OAuth2") +} + +func TestBuildCreateBody_NilMetadata(t *testing.T) { + params := CreateConnectionParams{ + Category: "ApiKey", + Target: "https://example.com", + AuthType: "ApiKey", + Key: "key", + } + + body, err := buildCreateBody(params) + require.NoError(t, err) + + props := body.Properties.GetConnectionPropertiesV2() + require.Nil(t, props.Metadata) +} + +func TestToStringPtrMap(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + result := toStringPtrMap(nil) + require.Nil(t, result) + }) + + t.Run("empty map", func(t *testing.T) { + result := toStringPtrMap(map[string]string{}) + require.NotNil(t, result) + require.Empty(t, result) + }) + + t.Run("populated map", func(t *testing.T) { + input := map[string]string{"a": "1", "b": "2"} + result := toStringPtrMap(input) + require.Len(t, result, 2) + require.Equal(t, "1", *result["a"]) + require.Equal(t, "2", *result["b"]) + }) +} + +func TestLastSegment(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"/subscriptions/sub/resourceGroups/rg/connections/my-conn", "my-conn"}, + {"my-conn", "my-conn"}, + {"/a/b/c", "c"}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + require.Equal(t, tt.want, lastSegment(tt.input)) + }) + } +} + +func TestParseRawCredentials(t *testing.T) { + t.Run("ApiKey credentials", func(t *testing.T) { + raw := `{ + "credentials": { + "key": "my-api-key", + "type": "ApiKey" + } + }` + + var envelope struct { + Credentials map[string]json.RawMessage `json:"credentials"` + } + err := json.Unmarshal([]byte(raw), &envelope) + require.NoError(t, err) + + result := make(map[string]string) + for k, v := range envelope.Credentials { + if k == "type" { + continue + } + var s string + if err := json.Unmarshal(v, &s); err != nil { + continue + } + result[k] = s + } + + require.Equal(t, map[string]string{"key": "my-api-key"}, result) + }) + + t.Run("CustomKeys credentials", func(t *testing.T) { + raw := `{ + "credentials": { + "x-api-key": "tvly-abc123", + "secret": "another-secret", + "type": "CustomKeys" + } + }` + + var envelope struct { + Credentials map[string]json.RawMessage `json:"credentials"` + } + err := json.Unmarshal([]byte(raw), &envelope) + require.NoError(t, err) + + result := make(map[string]string) + for k, v := range envelope.Credentials { + if k == "type" { + continue + } + var s string + if err := json.Unmarshal(v, &s); err != nil { + continue + } + result[k] = s + } + + require.Equal(t, "tvly-abc123", result["x-api-key"]) + require.Equal(t, "another-secret", result["secret"]) + require.Len(t, result, 2) + }) + + t.Run("CustomKeys with key named 'key'", func(t *testing.T) { + raw := `{ + "credentials": { + "key": "value-456", + "type": "CustomKeys" + } + }` + + var envelope struct { + Credentials map[string]json.RawMessage `json:"credentials"` + } + err := json.Unmarshal([]byte(raw), &envelope) + require.NoError(t, err) + + result := make(map[string]string) + for k, v := range envelope.Credentials { + if k == "type" { + continue + } + var s string + if err := json.Unmarshal(v, &s); err != nil { + continue + } + result[k] = s + } + + require.Equal(t, map[string]string{"key": "value-456"}, result) + }) + + t.Run("no credentials", func(t *testing.T) { + raw := `{ + "credentials": { + "type": "AAD" + } + }` + + var envelope struct { + Credentials map[string]json.RawMessage `json:"credentials"` + } + err := json.Unmarshal([]byte(raw), &envelope) + require.NoError(t, err) + + result := make(map[string]string) + for k, v := range envelope.Credentials { + if k == "type" { + continue + } + var s string + if err := json.Unmarshal(v, &s); err != nil { + continue + } + result[k] = s + } + + require.Empty(t, result) + }) +} + +func TestConnectionInfoFromARM_NilProperties(t *testing.T) { + r := &armcognitiveservices.ConnectionPropertiesV2BasicResource{} + info := connectionInfoFromARM(r) + require.Empty(t, info.Name) + require.Empty(t, info.Category) +} + +func TestConnectionInfoFromARM_WithID(t *testing.T) { + id := "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.CognitiveServices/accounts/acct/projects/proj/connections/my-conn" + r := &armcognitiveservices.ConnectionPropertiesV2BasicResource{ + ID: &id, + } + info := connectionInfoFromARM(r) + require.Equal(t, "my-conn", info.Name) + require.Equal(t, id, info.ID) +} diff --git a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_projects_client.go b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_projects_client.go index 9bcbafc02e4..7146b20c11d 100644 --- a/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_projects_client.go +++ b/cli/azd/extensions/azure.ai.agents/internal/pkg/azure/foundry_projects_client.go @@ -163,6 +163,26 @@ func (c *FoundryProjectsClient) GetPagedConnections(ctx context.Context) (*Paged // GetConnectionWithCredentials retrieves a specific connection with its credentials func (c *FoundryProjectsClient) GetConnectionWithCredentials(ctx context.Context, name string) (*Connection, error) { + body, err := c.getConnectionWithCredentialsRaw(ctx, name) + if err != nil { + return nil, err + } + + var connection Connection + if err := json.Unmarshal(body, &connection); err != nil { + return nil, fmt.Errorf("failed to unmarshal connection response: %w", err) + } + + return &connection, nil +} + +// getConnectionWithCredentialsRaw returns the raw JSON response body from the +// data-plane getConnectionWithCredentials endpoint. This is used by ConnectionManager +// to parse arbitrary credential fields that the typed Connection struct doesn't capture +// (e.g., custom key names in CustomKeys auth type). +func (c *FoundryProjectsClient) getConnectionWithCredentialsRaw( + ctx context.Context, name string, +) ([]byte, error) { targetEndpoint := fmt.Sprintf( "%s/connections/%s/getConnectionWithCredentials?api-version=%s", c.baseEndpoint, url.PathEscape(name), c.apiVersion) @@ -187,12 +207,7 @@ func (c *FoundryProjectsClient) GetConnectionWithCredentials(ctx context.Context return nil, fmt.Errorf("failed to read response body: %w", err) } - var connection Connection - if err := json.Unmarshal(body, &connection); err != nil { - return nil, fmt.Errorf("failed to unmarshal connection response: %w", err) - } - - return &connection, nil + return body, nil } // GetAllConnections retrieves all connections from the project, handling pagination