From ea211653961f4704eaa38e7344dd83d6cfec9b8c Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Tue, 17 Mar 2026 11:57:16 -0700 Subject: [PATCH 01/12] feat: Introduce new Azure Cosmos DB Change Feed Scaler Add internal KEDA scaler for Azure Cosmos DB change feed processor lag estimation. Translates the existing C# external scaler to a native Go internal scaler. - REST API client with HMAC-SHA256 and workload identity auth - Supports .NET and Java SDK lease formats (PK range and EPK based) - Configurable lag and activation thresholds - Separate data and lease container connection support - Partition split detection with automatic retry - Unit tests with httptest mocks for all lease formats - E2E test scaffold - Auto-generated scaler metadata schema Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- CHANGELOG.md | 1 + pkg/scalers/azure_cosmosdb_scaler.go | 541 +++++++++ pkg/scalers/azure_cosmosdb_scaler_test.go | 1045 +++++++++++++++++ pkg/scaling/scalers_builder.go | 2 + schema/generated/scalers-schema.json | 86 ++ schema/generated/scalers-schema.yaml | 57 + tests/.env | 1 + .../azure_cosmosdb/azure_cosmosdb_test.go | 216 ++++ 8 files changed, 1949 insertions(+) create mode 100644 pkg/scalers/azure_cosmosdb_scaler.go create mode 100644 pkg/scalers/azure_cosmosdb_scaler_test.go create mode 100644 tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 22652fcdf62..1e70752c182 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,6 +68,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio ### New +- **General**: Introduce new Azure Cosmos DB Change Feed Scaler ([#7556](https://github.com/kedacore/keda/issues/7556)) - TODO ([#XXX](https://github.com/kedacore/keda/issues/XXX)) #### Experimental diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go new file mode 100644 index 00000000000..9c423b002c8 --- /dev/null +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -0,0 +1,541 @@ +package scalers + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/go-logr/logr" + v2 "k8s.io/api/autoscaling/v2" + "k8s.io/metrics/pkg/apis/external_metrics" + + kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1" + "github.com/kedacore/keda/v2/pkg/scalers/azure" + "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" + kedautil "github.com/kedacore/keda/v2/pkg/util" +) + +const ( + cosmosDBMetricType = "External" + cosmosDBRestAPIVersion = "2018-12-31" +) + +type azureCosmosDBScaler struct { + metricType v2.MetricTargetType + metadata *azureCosmosDBMetadata + cosmosClient *cosmosDBClient + logger logr.Logger +} + +type azureCosmosDBMetadata struct { + DatabaseID string `keda:"name=databaseId, order=triggerMetadata"` + ContainerID string `keda:"name=containerId, order=triggerMetadata"` + LeaseDatabaseID string `keda:"name=leaseDatabaseId, order=triggerMetadata"` + LeaseContainerID string `keda:"name=leaseContainerId, order=triggerMetadata"` + ProcessorName string `keda:"name=processorName, order=triggerMetadata"` + Endpoint string `keda:"name=endpoint, order=authParams;triggerMetadata, optional"` + Connection string `keda:"name=connection, order=authParams;resolvedEnv;triggerMetadata, optional"` + LeaseEndpoint string 
`keda:"name=leaseEndpoint, order=authParams;triggerMetadata, optional"` + LeaseConnection string `keda:"name=leaseConnection, order=authParams;resolvedEnv;triggerMetadata, optional"` + CosmosDBKey string `keda:"name=cosmosDBKey, order=authParams;resolvedEnv, optional"` + LeaseCosmosDBKey string `keda:"name=leaseCosmosDBKey, order=authParams;resolvedEnv, optional"` + Threshold int64 `keda:"name=lagThreshold, order=triggerMetadata, default=1"` + ActivationThreshold int64 `keda:"name=activationLagThreshold, order=triggerMetadata, default=0"` + TriggerIndex int +} + +// cosmosDBClient provides low-level access to Cosmos DB via the REST API +// for querying lease documents and reading the change feed. +type cosmosDBClient struct { + httpClient *http.Client + dataEndpoint string + dataKey string + leaseEndpoint string + leaseKey string + leaseDatabaseID string + leaseContainerID string + databaseID string + containerID string + credential azcore.TokenCredential +} + +type leaseDocument struct { + ID string `json:"id"` + LeaseToken string `json:"LeaseToken"` + ContinuationToken string `json:"ContinuationToken"` + Owner string `json:"Owner,omitempty"` +} + +type changeFeedResponse struct { + StatusCode int + Items []json.RawMessage + SessionToken string +} + +// NewAzureCosmosDBScaler creates a new Azure Cosmos DB change feed scaler. +func NewAzureCosmosDBScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { + metricType, err := GetMetricTargetType(config) + if err != nil { + return nil, fmt.Errorf("error getting scaler metric type: %w", err) + } + + logger := InitializeLogger(config, "azure_cosmosdb_scaler") + + meta, err := parseAzureCosmosDBMetadata(config) + if err != nil { + return nil, fmt.Errorf("error parsing azure cosmos db metadata: %w", err) + } + + cosmosClient, err := newCosmosDBClient(meta, config.PodIdentity, logger, config.GlobalHTTPTimeout) + if err != nil { + return nil, fmt.Errorf("error creating cosmos db client: %w", err) + } + + return &azureCosmosDBScaler{ + metricType: metricType, + metadata: meta, + cosmosClient: cosmosClient, + logger: logger, + }, nil +} + +func parseAzureCosmosDBMetadata(config *scalersconfig.ScalerConfig) (*azureCosmosDBMetadata, error) { + meta := &azureCosmosDBMetadata{} + if err := config.TypedConfig(meta); err != nil { + return nil, fmt.Errorf("error parsing metadata: %w", err) + } + + switch config.PodIdentity.Provider { + case "", kedav1alpha1.PodIdentityProviderNone: + if meta.Connection == "" && (meta.Endpoint == "" || meta.CosmosDBKey == "") { + return nil, fmt.Errorf("connection string or endpoint+cosmosDBKey is required when not using pod identity") + } + case kedav1alpha1.PodIdentityProviderAzureWorkload: + if meta.Endpoint == "" && meta.Connection == "" { + return nil, fmt.Errorf("endpoint or connection string is required when using workload identity") + } + default: + return nil, fmt.Errorf("pod identity %s not supported for azure cosmos db", config.PodIdentity.Provider) + } + + // Default lease settings to data settings if not specified + if meta.LeaseConnection == "" { + meta.LeaseConnection = meta.Connection + } + if meta.LeaseEndpoint == "" { + meta.LeaseEndpoint = meta.Endpoint + } + if meta.LeaseCosmosDBKey == "" { + meta.LeaseCosmosDBKey = meta.CosmosDBKey + } + + meta.TriggerIndex = config.TriggerIndex + return meta, nil +} + +func newCosmosDBClient(meta *azureCosmosDBMetadata, podIdentity kedav1alpha1.AuthPodIdentity, logger logr.Logger, httpTimeout time.Duration) (*cosmosDBClient, error) { + if httpTimeout == 0 { + 
httpTimeout = 30 * time.Second + } + + client := &cosmosDBClient{ + httpClient: kedautil.CreateHTTPClient(httpTimeout, false), + leaseDatabaseID: meta.LeaseDatabaseID, + leaseContainerID: meta.LeaseContainerID, + databaseID: meta.DatabaseID, + containerID: meta.ContainerID, + } + + // Resolve data endpoint and key + if meta.Connection != "" { + endpoint, key, err := parseCosmosDBConnectionString(meta.Connection) + if err != nil { + return nil, fmt.Errorf("error parsing connection string: %w", err) + } + client.dataEndpoint = endpoint + client.dataKey = key + } else if meta.Endpoint != "" { + client.dataEndpoint = meta.Endpoint + client.dataKey = meta.CosmosDBKey + } + + // Resolve lease endpoint and key + if meta.LeaseConnection != "" { + endpoint, key, err := parseCosmosDBConnectionString(meta.LeaseConnection) + if err != nil { + return nil, fmt.Errorf("error parsing lease connection string: %w", err) + } + client.leaseEndpoint = endpoint + client.leaseKey = key + } else if meta.LeaseEndpoint != "" { + client.leaseEndpoint = meta.LeaseEndpoint + client.leaseKey = meta.LeaseCosmosDBKey + } + + if client.dataEndpoint == "" || client.leaseEndpoint == "" { + return nil, fmt.Errorf("failed to determine cosmos db endpoints") + } + + // Set up workload identity credential for bearer token auth + if podIdentity.Provider == kedav1alpha1.PodIdentityProviderAzureWorkload && client.dataKey == "" { + cred, err := azure.NewChainedCredential(logger, podIdentity) + if err != nil { + return nil, fmt.Errorf("error creating azure credential for workload identity: %w", err) + } + client.credential = cred + } + + return client, nil +} + +func parseCosmosDBConnectionString(connectionString string) (string, string, error) { + parts := strings.Split(connectionString, ";") + var endpoint, key string + + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "AccountEndpoint=") { + endpoint = strings.TrimPrefix(part, "AccountEndpoint=") + } else if strings.HasPrefix(part, "AccountKey=") { + key = strings.TrimPrefix(part, "AccountKey=") + } + } + + if endpoint == "" || key == "" { + return "", "", fmt.Errorf("invalid connection string: missing AccountEndpoint or AccountKey") + } + + return endpoint, key, nil +} + +// setAuthHeader sets the Authorization header using either master key HMAC-SHA256 or bearer token. +func (c *cosmosDBClient) setAuthHeader(req *http.Request, verb, resourceType, resourceLink, date, key string) error { + if key != "" { + token := generateCosmosDBAuthToken(verb, resourceType, resourceLink, date, key) + req.Header.Set("Authorization", token) + return nil + } + + if c.credential != nil { + tk, err := c.credential.GetToken(req.Context(), policy.TokenRequestOptions{ + Scopes: []string{azure.PublicCloud.ResourceIdentifiers.CosmosDB + "/.default"}, + }) + if err != nil { + return fmt.Errorf("error acquiring bearer token: %w", err) + } + req.Header.Set("Authorization", "type=aad&ver=1.0&sig="+tk.Token) + return nil + } + + return fmt.Errorf("no authentication method available: provide a key or configure workload identity") +} + +// generateCosmosDBAuthToken generates an HMAC-SHA256 auth token for Cosmos DB REST API. 
+// Format: type=master&ver=1.0&sig={hashsignature} +// Signature input: {verb}\n{resourceType}\n{resourceLink}\n{date}\n\n +func generateCosmosDBAuthToken(verb, resourceType, resourceLink, date, key string) string { + keyBytes, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return "" + } + + text := fmt.Sprintf("%s\n%s\n%s\n%s\n\n", + strings.ToLower(verb), + strings.ToLower(resourceType), + resourceLink, + strings.ToLower(date)) + + h := hmac.New(sha256.New, keyBytes) + h.Write([]byte(text)) + signature := base64.StdEncoding.EncodeToString(h.Sum(nil)) + + return url.QueryEscape(fmt.Sprintf("type=master&ver=1.0&sig=%s", signature)) +} + +func (c *cosmosDBClient) queryLeases(ctx context.Context) ([]leaseDocument, error) { + resourceLink := fmt.Sprintf("dbs/%s/colls/%s", c.leaseDatabaseID, c.leaseContainerID) + reqURL := fmt.Sprintf("%s/%s/docs", strings.TrimRight(c.leaseEndpoint, "/"), resourceLink) + + body := `{"query":"SELECT * FROM c"}` + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + now := time.Now().UTC().Format(http.TimeFormat) + req.Header.Set("x-ms-date", now) + req.Header.Set("x-ms-version", cosmosDBRestAPIVersion) + req.Header.Set("Content-Type", "application/query+json") + req.Header.Set("x-ms-documentdb-isquery", "true") + req.Header.Set("x-ms-documentdb-query-enablecrosspartition", "true") + + if err := c.setAuthHeader(req, http.MethodPost, "docs", resourceLink, now, c.leaseKey); err != nil { + return nil, fmt.Errorf("error setting auth header: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error executing request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("query failed with status %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Documents []json.RawMessage `json:"Documents"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error decoding response: %w", err) + } + + // Parse and filter out metadata documents (those without LeaseToken or ContinuationToken) + var leases []leaseDocument + for _, raw := range result.Documents { + var doc leaseDocument + if err := json.Unmarshal(raw, &doc); err != nil { + continue + } + if doc.LeaseToken != "" && doc.ContinuationToken != "" { + leases = append(leases, doc) + } + } + + return leases, nil +} + +func (c *cosmosDBClient) readChangeFeed(ctx context.Context, partitionKeyRangeID, continuationToken string) (*changeFeedResponse, error) { + resourceLink := fmt.Sprintf("dbs/%s/colls/%s", c.databaseID, c.containerID) + reqURL := fmt.Sprintf("%s/%s/docs", strings.TrimRight(c.dataEndpoint, "/"), resourceLink) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + now := time.Now().UTC().Format(http.TimeFormat) + req.Header.Set("x-ms-date", now) + req.Header.Set("x-ms-version", cosmosDBRestAPIVersion) + req.Header.Set("x-ms-documentdb-partitionkeyrangeid", partitionKeyRangeID) + req.Header.Set("A-IM", "Incremental feed") + req.Header.Set("x-ms-max-item-count", "1") + + if continuationToken != "" { + req.Header.Set("If-None-Match", continuationToken) + } + + if err := c.setAuthHeader(req, http.MethodGet, "docs", resourceLink, now, c.dataKey); err != nil { + return nil, 
fmt.Errorf("error setting auth header: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error executing request: %w", err) + } + defer resp.Body.Close() + + cfResp := &changeFeedResponse{ + StatusCode: resp.StatusCode, + SessionToken: resp.Header.Get("x-ms-session-token"), + } + + if resp.StatusCode == http.StatusNotModified || resp.StatusCode == http.StatusGone { + return cfResp, nil + } + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("read change feed failed with status %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Documents []json.RawMessage `json:"Documents"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error decoding response: %w", err) + } + + cfResp.Items = result.Documents + return cfResp, nil +} + +// estimateLag estimates the change feed lag and returns the count of partitions with lag > 0. +// If a partition split (410 Gone) is detected, it retries once to get fresh lease data. +func (c *cosmosDBClient) estimateLag(ctx context.Context) (int64, error) { + count, splitDetected, err := c.estimateOnce(ctx) + if err != nil { + return 0, err + } + if splitDetected { + count, _, err = c.estimateOnce(ctx) + if err != nil { + return 0, err + } + } + return count, nil +} + +func (c *cosmosDBClient) estimateOnce(ctx context.Context) (int64, bool, error) { + leases, err := c.queryLeases(ctx) + if err != nil { + return 0, false, fmt.Errorf("error querying leases: %w", err) + } + + if len(leases) == 0 { + return 0, false, nil + } + + partitionsWithLag := int64(0) + splitDetected := false + + for _, lease := range leases { + lag, isSplit, err := c.estimatePartitionLag(ctx, lease) + if err != nil { + return 0, false, fmt.Errorf("error estimating lag for partition %s: %w", lease.LeaseToken, err) + } + if isSplit { + splitDetected = true + continue + } + if lag > 0 { + partitionsWithLag++ + } + } + + return partitionsWithLag, splitDetected, nil +} + +// estimatePartitionLag calculates the lag for a single partition. +// Algorithm (matching .NET/Java SDKs): +// 1. Read change feed with maxItemCount=1 starting from the lease's continuation token +// 2. Extract latest LSN from session token +// 3. If items present: lag = sessionLSN - firstItem._lsn + 1 +// 4. If no items (304): lag = 0 (caught up) +// 5. If 410 Gone: report lag = -1 (split/merge) +func (c *cosmosDBClient) estimatePartitionLag(ctx context.Context, lease leaseDocument) (int64, bool, error) { + cfResp, err := c.readChangeFeed(ctx, lease.LeaseToken, lease.ContinuationToken) + if err != nil { + return 0, false, err + } + + // 410 Gone indicates partition split or merge + if cfResp.StatusCode == http.StatusGone { + return -1, true, nil + } + + // 304 Not Modified or empty results means processor is caught up + if cfResp.StatusCode == http.StatusNotModified || len(cfResp.Items) == 0 { + return 0, false, nil + } + + // Calculate lag: sessionLSN - firstItemLSN + 1 + sessionLSN, err := parseLSNFromSessionToken(cfResp.SessionToken) + if err != nil || sessionLSN < 0 { + return 0, false, nil + } + + firstItemLSN, err := extractItemLSN(cfResp.Items[0]) + if err != nil || firstItemLSN < 0 { + return 0, false, nil + } + + lag := sessionLSN - firstItemLSN + 1 + if lag < 0 { + return 0, false, nil + } + + return lag, false, nil +} + +// extractLSNFromSessionToken extracts the LSN from a Cosmos DB session token. 
+// Session token formats: +// - Simple: "{pkRangeId}:{lsn}" +// - Compound: "{pkRangeId}:{localLsn}#{globalLsn}" +// +// This matches the logic in both the .NET SDK (ChangeFeedEstimatorIterator.ExtractLsnFromSessionToken) +// and Java SDK (IncrementalChangeFeedProcessorImpl). +func extractLSNFromSessionToken(sessionToken string) string { + if sessionToken == "" { + return "" + } + + colonIdx := strings.IndexByte(sessionToken, ':') + if colonIdx < 0 { + return sessionToken + } + parsed := sessionToken[colonIdx+1:] + + segments := strings.Split(parsed, "#") + if len(segments) >= 2 { + return segments[1] // Global LSN + } + return segments[0] +} + +func parseLSNFromSessionToken(sessionToken string) (int64, error) { + lsnStr := extractLSNFromSessionToken(sessionToken) + if lsnStr == "" { + return -1, fmt.Errorf("empty session token") + } + return strconv.ParseInt(lsnStr, 10, 64) +} + +// extractItemLSN extracts the _lsn value from a Cosmos DB change feed document. +func extractItemLSN(item json.RawMessage) (int64, error) { + var doc struct { + LSN json.Number `json:"_lsn"` + } + if err := json.Unmarshal(item, &doc); err != nil { + return -1, fmt.Errorf("parsing item: %w", err) + } + return doc.LSN.Int64() +} + +// GetMetricSpecForScaling returns the metric spec for scaling. +func (s *azureCosmosDBScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { + metricName := kedautil.NormalizeString(fmt.Sprintf("azure-cosmosdb-%s-%s", + s.metadata.LeaseContainerID, s.metadata.ProcessorName)) + externalMetric := &v2.ExternalMetricSource{ + Metric: v2.MetricIdentifier{ + Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, metricName), + }, + Target: GetMetricTarget(s.metricType, s.metadata.Threshold), + } + metricSpec := v2.MetricSpec{External: externalMetric, Type: cosmosDBMetricType} + return []v2.MetricSpec{metricSpec} +} + +// GetMetricsAndActivity returns the metric value and activity status. +func (s *azureCosmosDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { + partitionsWithLag, err := s.cosmosClient.estimateLag(ctx) + if err != nil { + s.logger.Error(err, "error getting cosmos db change feed lag") + return []external_metrics.ExternalMetricValue{}, false, err + } + + s.logger.V(1).Info(fmt.Sprintf("Cosmos DB partitions with lag: %d", partitionsWithLag)) + + metric := GenerateMetricInMili(metricName, float64(partitionsWithLag)) + return []external_metrics.ExternalMetricValue{metric}, partitionsWithLag > s.metadata.ActivationThreshold, nil +} + +// Close cleans up the scaler resources. 
+func (s *azureCosmosDBScaler) Close(context.Context) error { + if s.cosmosClient != nil && s.cosmosClient.httpClient != nil { + s.cosmosClient.httpClient.CloseIdleConnections() + } + return nil +} diff --git a/pkg/scalers/azure_cosmosdb_scaler_test.go b/pkg/scalers/azure_cosmosdb_scaler_test.go new file mode 100644 index 00000000000..23ae0a3b9e9 --- /dev/null +++ b/pkg/scalers/azure_cosmosdb_scaler_test.go @@ -0,0 +1,1045 @@ +package scalers + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + v2 "k8s.io/api/autoscaling/v2" + + kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1" + "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" +) + +var testCosmosDBResolvedEnv = map[string]string{ + "COSMOS_CONNECTION": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=dGVzdGtleQ==", +} + +type parseCosmosDBMetadataTestData struct { + name string + metadata map[string]string + isError bool + resolvedEnv map[string]string + authParams map[string]string + podIdentity kedav1alpha1.PodIdentityProvider +} + +type cosmosDBMetricIdentifier struct { + name string + metadataTestData *parseCosmosDBMetadataTestData + triggerIndex int + metricName string +} + +var testCosmosDBMetadata = []parseCosmosDBMetadataTestData{ + { + name: "nothing passed", + metadata: map[string]string{}, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "properly formed with connection string", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: false, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing database id", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing container id", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing lease database id", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing lease container id", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing processor name", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": 
"testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing connection and key", + metadata: map[string]string{ + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: map[string]string{}, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "connection from authParams", + metadata: map[string]string{ + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: false, + resolvedEnv: map[string]string{}, + authParams: map[string]string{ + "connection": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=dGVzdGtleQ==", + }, + podIdentity: kedav1alpha1.PodIdentityProviderNone, + }, + { + name: "endpoint with key", + metadata: map[string]string{ + "endpoint": "https://test.documents.azure.com:443/", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: false, + resolvedEnv: map[string]string{}, + authParams: map[string]string{ + "cosmosDBKey": "dGVzdGtleQ==", + }, + podIdentity: "", + }, + { + name: "podIdentity azure-workload with endpoint", + metadata: map[string]string{ + "endpoint": "https://test.documents.azure.com:443/", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: false, + resolvedEnv: map[string]string{}, + authParams: map[string]string{ + "cosmosDBKey": "dGVzdGtleQ==", + }, + podIdentity: kedav1alpha1.PodIdentityProviderAzureWorkload, + }, + { + name: "podIdentity azure-workload without endpoint or connection", + metadata: map[string]string{ + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: map[string]string{}, + authParams: map[string]string{}, + podIdentity: kedav1alpha1.PodIdentityProviderAzureWorkload, + }, + { + name: "separate lease connection", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "leaseConnectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: false, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "invalid lagThreshold", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + "lagThreshold": "invalid", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "invalid activationLagThreshold", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + 
"activationLagThreshold": "invalid", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, +} + +var cosmosDBMetricIdentifiers = []cosmosDBMetricIdentifier{ + { + name: "properly formed metric", + metadataTestData: &testCosmosDBMetadata[1], + triggerIndex: 0, + metricName: "s0-azure-cosmosdb-leases-testprocessor", + }, + { + name: "endpoint with key metric", + metadataTestData: &testCosmosDBMetadata[9], + triggerIndex: 1, + metricName: "s1-azure-cosmosdb-leases-testprocessor", + }, +} + +func TestCosmosDBParseMetadata(t *testing.T) { + for _, testData := range testCosmosDBMetadata { + t.Run(testData.name, func(t *testing.T) { + config := &scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadata, + ResolvedEnv: testData.resolvedEnv, + AuthParams: testData.authParams, + PodIdentity: kedav1alpha1.AuthPodIdentity{Provider: testData.podIdentity}, + } + + _, err := parseAzureCosmosDBMetadata(config) + if err != nil && !testData.isError { + t.Errorf("Expected success but got error: %v", err) + } + if testData.isError && err == nil { + t.Errorf("Expected error but got success. testData: %v", testData) + } + }) + } +} + +func TestCosmosDBGetMetricSpecForScaling(t *testing.T) { + for _, testData := range cosmosDBMetricIdentifiers { + t.Run(testData.name, func(t *testing.T) { + config := &scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadataTestData.metadata, + ResolvedEnv: testData.metadataTestData.resolvedEnv, + AuthParams: testData.metadataTestData.authParams, + PodIdentity: kedav1alpha1.AuthPodIdentity{Provider: testData.metadataTestData.podIdentity}, + TriggerIndex: testData.triggerIndex, + } + + meta, err := parseAzureCosmosDBMetadata(config) + if err != nil { + t.Fatal("Could not parse metadata:", err) + } + + mockScaler := azureCosmosDBScaler{ + metadata: meta, + logger: logr.Discard(), + metricType: v2.AverageValueMetricType, + } + + metricSpec := mockScaler.GetMetricSpecForScaling(context.Background()) + metricName := metricSpec[0].External.Metric.Name + assert.Equal(t, testData.metricName, metricName) + }) + } +} + +func TestCosmosDBConnectionStringParsing(t *testing.T) { + testCases := []struct { + name string + connectionStr string + expectError bool + expectedEndpoint string + }{ + { + name: "valid connection string", + connectionStr: "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=dGVzdGtleQ==", + expectError: false, + expectedEndpoint: "https://test.documents.azure.com:443/", + }, + { + name: "missing endpoint", + connectionStr: "AccountKey=dGVzdGtleQ==", + expectError: true, + }, + { + name: "missing key", + connectionStr: "AccountEndpoint=https://test.documents.azure.com:443/", + expectError: true, + }, + { + name: "empty string", + connectionStr: "", + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + endpoint, key, err := parseCosmosDBConnectionString(tc.connectionStr) + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedEndpoint, endpoint) + assert.NotEmpty(t, key) + } + }) + } +} + +func TestExtractLSNFromSessionToken(t *testing.T) { + testCases := []struct { + name string + token string + expectedLSN string + }{ + { + name: "simple format", + token: "0:123", + expectedLSN: "123", + }, + { + name: "compound format with global LSN", + token: "0:1#100#2", + expectedLSN: "100", + }, + { + name: "two segments", + token: "5:42#999", + expectedLSN: "999", + }, + { + name: 
"empty token", + token: "", + expectedLSN: "", + }, + { + name: "no colon", + token: "justanumber", + expectedLSN: "justanumber", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + lsn := extractLSNFromSessionToken(tc.token) + assert.Equal(t, tc.expectedLSN, lsn) + }) + } +} + +func TestExtractItemLSN(t *testing.T) { + testCases := []struct { + name string + item string + expectedLSN int64 + expectError bool + }{ + { + name: "numeric LSN", + item: `{"_lsn": 1234}`, + expectedLSN: 1234, + }, + { + name: "string LSN", + item: `{"_lsn": "5678"}`, + expectedLSN: 5678, + }, + { + name: "missing LSN", + item: `{"id": "doc1"}`, + expectedLSN: 0, + expectError: true, + }, + { + name: "invalid JSON", + item: `not json`, + expectedLSN: -1, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + lsn, err := extractItemLSN(json.RawMessage(tc.item)) + if tc.expectError { + assert.True(t, err != nil || lsn <= 0) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedLSN, lsn) + } + }) + } +} + +func TestCosmosDBAuthTokenGeneration(t *testing.T) { + token := generateCosmosDBAuthToken("get", "docs", "dbs/testdb/colls/testcol", "thu, 01 jan 2024 00:00:00 gmt", "dGVzdGtleQ==") + assert.Contains(t, token, "type%3Dmaster%26ver%3D1.0%26sig%3D") +} + +func TestCosmosDBLeaseParsingDotNetFormat(t *testing.T) { + // Realistic .NET SDK lease documents have: version=0, FeedRange, Mode, properties fields. + // The scaler must parse LeaseToken and ContinuationToken and ignore the extra fields. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + // Return raw JSON matching actual .NET SDK lease format + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[ + { + "id": "host1.documents.azure.com_abc==_def=..6", + "version": 0, + "_etag": "\"08000b63-0000-0800-0000-69a8c6640000\"", + "LeaseToken": "6", + "FeedRange": {"Range": {"min": "36DB6DB6DB6DB6DB6DB6DB6DB6DB6DB6", "max": "FF"}}, + "Owner": "dotnet-host1", + "ContinuationToken": "\"511\"", + "properties": {}, + "timestamp": "2026-03-04T23:55:16.5233511Z", + "Mode": "Incremental Feed", + "_rid": "abc123", + "_self": "dbs/abc/colls/def/docs/ghi", + "_ts": 1772668516 + }, + { + "id": "host1.documents.azure.com_abc==_def=..3", + "version": 0, + "LeaseToken": "3", + "FeedRange": {"Range": {"min": "0", "max": "36DB6DB6DB6DB6DB6DB6DB6DB6DB6DB6"}}, + "Owner": "dotnet-host1", + "ContinuationToken": "\"248\"", + "properties": {}, + "Mode": "Incremental Feed" + }, + { + "id": ".metadata.lease", + "version": 0, + "Owner": "", + "properties": {} + } + ]}`)) + case "/dbs/testdb/colls/data/docs": + pkRangeID := r.Header.Get("x-ms-documentdb-partitionkeyrangeid") + switch pkRangeID { + case "6": + // Partition 6 has lag: sessionLSN=600, itemLSN=512, lag=89 + w.Header().Set("x-ms-session-token", "6:0#600") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc1","_lsn":512}]}`)) + case "3": + // Partition 3 is caught up + w.Header().Set("x-ms-session-token", "3:0#248") + w.WriteHeader(http.StatusNotModified) + } + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "data", + leaseDatabaseID: "testdb", + 
leaseContainerID: "leases", + } + + partitionsWithLag, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + // Only partition 6 has lag; partition 3 is caught up; metadata doc is filtered + assert.Equal(t, int64(1), partitionsWithLag) +} + +func TestCosmosDBLeaseParsingJavaFormat(t *testing.T) { + // Realistic Java SDK lease documents: no version field, no FeedRange/Mode/properties. + // The scaler must parse these identically. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[ + { + "id": "myhost.documents.azure.com_changefeed-estimator_data..2", + "_etag": "\"0100baf0-0000-0800-0000-69a8c5560000\"", + "LeaseToken": "2", + "ContinuationToken": "\"248\"", + "timestamp": "2026-03-04T23:50:46.219570110Z", + "Owner": "java-host1", + "_rid": "5jBSAKD6NqgELTEBAAAAAA==", + "_ts": 1772668246 + }, + { + "id": "myhost.documents.azure.com_changefeed-estimator_data..5", + "LeaseToken": "5", + "ContinuationToken": "\"100\"", + "Owner": "java-host2" + }, + { + "id": ".lock", + "_etag": "\"abc\"", + "Owner": "" + } + ]}`)) + case "/dbs/testdb/colls/data/docs": + pkRangeID := r.Header.Get("x-ms-documentdb-partitionkeyrangeid") + switch pkRangeID { + case "2": + // Partition 2 has lag: sessionLSN=400, itemLSN=249, lag=152 + w.Header().Set("x-ms-session-token", "2:0#400") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc1","_lsn":249}]}`)) + case "5": + // Partition 5 also has lag: sessionLSN=200, itemLSN=101, lag=100 + w.Header().Set("x-ms-session-token", "5:0#200") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc2","_lsn":101}]}`)) + } + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "data", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + } + + partitionsWithLag, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + // Both partitions 2 and 5 have lag; lock doc is filtered out + assert.Equal(t, int64(2), partitionsWithLag) +} + +func TestCosmosDBLeaseParsingMixedFormats(t *testing.T) { + // Edge case: lease container might contain docs from both SDKs (e.g. during migration). + // The scaler should handle this gracefully since it only reads common fields. 
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[ + { + "id": "dotnet-lease", + "version": 0, + "LeaseToken": "0", + "ContinuationToken": "\"500\"", + "Owner": "dotnet-host", + "FeedRange": {"Range": {"min": "0", "max": "80"}}, + "Mode": "Incremental Feed" + }, + { + "id": "java-lease", + "LeaseToken": "1", + "ContinuationToken": "\"300\"", + "Owner": "java-host" + } + ]}`)) + case "/dbs/testdb/colls/data/docs": + // Both partitions have lag + w.Header().Set("x-ms-session-token", "0:0#700") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc1","_lsn":550}]}`)) + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "data", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + } + + partitionsWithLag, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(2), partitionsWithLag) +} + +func TestCosmosDBLeaseParsingEPKBasedDotNet(t *testing.T) { + // .NET SDK EPK-based leases (version=1) use FeedRange with EPK ranges. + // ContinuationToken is still a quoted LSN for incremental feed mode. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[ + { + "id": "host1..epk..0-AA", + "version": 1, + "LeaseToken": "0", + "FeedRange": {"Range": {"min": "", "max": "AA"}}, + "Owner": "dotnet-host1", + "ContinuationToken": "\"750\"", + "Mode": "LatestVersion" + }, + { + "id": "host1..epk..AA-FF", + "version": 1, + "LeaseToken": "1", + "FeedRange": {"Range": {"min": "AA", "max": "FF"}}, + "Owner": "dotnet-host1", + "ContinuationToken": "\"320\"", + "Mode": "LatestVersion" + } + ]}`)) + case "/dbs/testdb/colls/data/docs": + pkRangeID := r.Header.Get("x-ms-documentdb-partitionkeyrangeid") + switch pkRangeID { + case "0": + w.Header().Set("x-ms-session-token", "0:0#900") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc1","_lsn":751}]}`)) + case "1": + w.Header().Set("x-ms-session-token", "1:0#320") + w.WriteHeader(http.StatusNotModified) + } + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "data", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + } + + partitionsWithLag, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + // Partition 0 has lag (900-751+1=150), partition 1 is caught up + assert.Equal(t, int64(1), partitionsWithLag) +} + +func TestCosmosDBLeaseParsingEPKBasedJava(t *testing.T) { + // Java SDK EPK-based leases (version=1) may use Base64-encoded ContinuationTokens. + // The scaler passes ContinuationToken as-is to If-None-Match, and Cosmos DB + // recognizes its own tokens regardless of encoding. 
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[ + { + "id": "java-epk-lease-0", + "version": 1, + "LeaseToken": "0", + "ContinuationToken": "eyJWIjoiMiIsIlJpZCI6ImFiYz0iLCJDb250aW51YXRpb24iOlt7InRva2VuIjoiXCI1MDBcIiIsInJhbmdlIjp7Im1pbiI6IiIsIm1heCI6IkZGIn19XX0=", + "Owner": "java-host1", + "feedRange": {"min": "", "max": "FF"} + }, + { + "id": "java-epk-lease-1", + "version": 1, + "LeaseToken": "1", + "ContinuationToken": "eyJWIjoiMiIsIlJpZCI6ImRlZj0iLCJDb250aW51YXRpb24iOlt7InRva2VuIjoiXCIyMDBcIiIsInJhbmdlIjp7Im1pbiI6IkZGIiwibWF4IjoiRkZGRiJ9fV19", + "Owner": "java-host2", + "feedRange": {"min": "FF", "max": "FFFF"} + } + ]}`)) + case "/dbs/testdb/colls/data/docs": + // Simulate Cosmos DB accepting Base64 continuation tokens and returning results + w.Header().Set("x-ms-session-token", "0:0#600") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc1","_lsn":501}]}`)) + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "data", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + } + + partitionsWithLag, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + // Both partitions have lag; Base64 tokens are passed through to the server + assert.Equal(t, int64(2), partitionsWithLag) +} + +func TestCosmosDBLagEstimation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + { + "id": "lease1", + "LeaseToken": "0", + "ContinuationToken": `"1000"`, + "Owner": "testowner", + }, + { + "id": "lease2", + "LeaseToken": "1", + "ContinuationToken": `"2000"`, + "Owner": "testowner", + }, + { + // Metadata doc - should be filtered out + "id": "metadata", + "Owner": "metadata", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + case "/dbs/testdb/colls/testcontainer/docs": + pkRangeID := r.Header.Get("x-ms-documentdb-partitionkeyrangeid") + + switch pkRangeID { + case "0": + // Partition with lag: sessionLSN=1100, itemLSN=1050, lag=51 + w.Header().Set("x-ms-session-token", "0:0#1100") + w.Header().Set("Content-Type", "application/json") + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "item1", "_lsn": 1050}, + }, + } + _ = json.NewEncoder(w).Encode(response) + default: + // Partition without lag (304 Not Modified) + w.Header().Set("x-ms-session-token", "1:0#2000") + w.WriteHeader(http.StatusNotModified) + } + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + } + + partitionsWithLag, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(1), partitionsWithLag) +} + +func TestCosmosDBLagEstimationEmptyLeases(t *testing.T) { + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := map[string]interface{}{ + "Documents": []map[string]interface{}{}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + } + + partitionsWithLag, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(0), partitionsWithLag) +} + +func TestCosmosDBLagEstimationAllPartitionsLagging(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "lease1", "LeaseToken": "0", "ContinuationToken": `"100"`, "Owner": "owner1"}, + {"id": "lease2", "LeaseToken": "1", "ContinuationToken": `"200"`, "Owner": "owner2"}, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + case "/dbs/testdb/colls/testcontainer/docs": + // Both partitions have lag + w.Header().Set("x-ms-session-token", "0:0#500") + w.Header().Set("Content-Type", "application/json") + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "item1", "_lsn": 400}, + }, + } + _ = json.NewEncoder(w).Encode(response) + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + } + + partitionsWithLag, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(2), partitionsWithLag) +} + +func TestCosmosDBLagEstimationPartitionSplit(t *testing.T) { + changeFeedCallCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "lease1", "LeaseToken": "0", "ContinuationToken": `"100"`, "Owner": "owner1"}, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + case "/dbs/testdb/colls/testcontainer/docs": + changeFeedCallCount++ + if changeFeedCallCount <= 1 { + // First call returns 410 Gone (partition split) + w.WriteHeader(http.StatusGone) + } else { + // Retry returns caught up + w.Header().Set("x-ms-session-token", "0:0#100") + w.WriteHeader(http.StatusNotModified) + } + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + } + + partitionsWithLag, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(0), partitionsWithLag) + // Should have retried: lease query + change feed (410) + lease query (retry) + change feed (304) + assert.GreaterOrEqual(t, changeFeedCallCount, 2) +} + +func TestCosmosDBGetMetricsAndActivity(t *testing.T) { + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "lease1", "LeaseToken": "0", "ContinuationToken": `"100"`, "Owner": "owner1"}, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + case "/dbs/testdb/colls/testcontainer/docs": + w.Header().Set("x-ms-session-token", "0:0#200") + w.Header().Set("Content-Type", "application/json") + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "item1", "_lsn": 150}, + }, + } + _ = json.NewEncoder(w).Encode(response) + } + })) + defer server.Close() + + scaler := &azureCosmosDBScaler{ + metricType: v2.AverageValueMetricType, + metadata: &azureCosmosDBMetadata{ + DatabaseID: "testdb", + ContainerID: "testcontainer", + LeaseDatabaseID: "testdb", + LeaseContainerID: "leases", + ProcessorName: "testprocessor", + Threshold: 1, + ActivationThreshold: 0, + }, + cosmosClient: &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + }, + logger: logr.Discard(), + } + + metrics, isActive, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") + assert.NoError(t, err) + assert.True(t, isActive) + assert.Len(t, metrics, 1) +} + +func TestCosmosDBGetMetricsAndActivityNotActive(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "lease1", "LeaseToken": "0", "ContinuationToken": `"100"`, "Owner": "owner1"}, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + case "/dbs/testdb/colls/testcontainer/docs": + // Caught up + w.Header().Set("x-ms-session-token", "0:0#100") + w.WriteHeader(http.StatusNotModified) + } + })) + defer server.Close() + + scaler := &azureCosmosDBScaler{ + metricType: v2.AverageValueMetricType, + metadata: &azureCosmosDBMetadata{ + DatabaseID: "testdb", + ContainerID: "testcontainer", + LeaseDatabaseID: "testdb", + LeaseContainerID: "leases", + ProcessorName: "testprocessor", + Threshold: 1, + ActivationThreshold: 0, + }, + cosmosClient: &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + }, + logger: logr.Discard(), + } + + metrics, isActive, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") + assert.NoError(t, err) + assert.False(t, isActive) + assert.Len(t, metrics, 1) +} diff --git a/pkg/scaling/scalers_builder.go b/pkg/scaling/scalers_builder.go index e1146caa3d6..232de267cf8 100644 --- a/pkg/scaling/scalers_builder.go +++ b/pkg/scaling/scalers_builder.go @@ -145,6 +145,8 @@ func buildScaler(ctx context.Context, client client.Client, triggerType string, return scalers.NewAzureAppInsightsScaler(config) case "azure-blob": return scalers.NewAzureBlobScaler(config) + case "azure-cosmosdb": + return scalers.NewAzureCosmosDBScaler(config) case 
"azure-data-explorer": return scalers.NewAzureDataExplorerScaler(config) case "azure-eventhub": diff --git a/schema/generated/scalers-schema.json b/schema/generated/scalers-schema.json index dcfccdcd4a2..ba10008f2d8 100644 --- a/schema/generated/scalers-schema.json +++ b/schema/generated/scalers-schema.json @@ -756,6 +756,92 @@ } ] }, + { + "type": "azure-cosmosdb", + "parameters": [ + { + "name": "databaseId", + "type": "string", + "metadataVariableReadable": true + }, + { + "name": "containerId", + "type": "string", + "metadataVariableReadable": true + }, + { + "name": "leaseDatabaseId", + "type": "string", + "metadataVariableReadable": true + }, + { + "name": "leaseContainerId", + "type": "string", + "metadataVariableReadable": true + }, + { + "name": "processorName", + "type": "string", + "metadataVariableReadable": true + }, + { + "name": "endpoint", + "type": "string", + "optional": true, + "metadataVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "connection", + "type": "string", + "optional": true, + "metadataVariableReadable": true, + "envVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "leaseEndpoint", + "type": "string", + "optional": true, + "metadataVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "leaseConnection", + "type": "string", + "optional": true, + "metadataVariableReadable": true, + "envVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "cosmosDBKey", + "type": "string", + "optional": true, + "envVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "leaseCosmosDBKey", + "type": "string", + "optional": true, + "envVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "lagThreshold", + "type": "string", + "default": "1", + "metadataVariableReadable": true + }, + { + "name": "activationLagThreshold", + "type": "string", + "default": "0", + "metadataVariableReadable": true + } + ] + }, { "type": "azure-eventhub", "parameters": [ diff --git a/schema/generated/scalers-schema.yaml b/schema/generated/scalers-schema.yaml index 79ffb5f0e17..31eb41b1beb 100644 --- a/schema/generated/scalers-schema.yaml +++ b/schema/generated/scalers-schema.yaml @@ -494,6 +494,63 @@ scalers: type: string optional: true metadataVariableReadable: true + - type: azure-cosmosdb + parameters: + - name: databaseId + type: string + metadataVariableReadable: true + - name: containerId + type: string + metadataVariableReadable: true + - name: leaseDatabaseId + type: string + metadataVariableReadable: true + - name: leaseContainerId + type: string + metadataVariableReadable: true + - name: processorName + type: string + metadataVariableReadable: true + - name: endpoint + type: string + optional: true + metadataVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: connection + type: string + optional: true + metadataVariableReadable: true + envVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: leaseEndpoint + type: string + optional: true + metadataVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: leaseConnection + type: string + optional: true + metadataVariableReadable: true + envVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: cosmosDBKey + type: string + optional: true + envVariableReadable: true + triggerAuthenticationVariableReadable: true + - 
name: leaseCosmosDBKey + type: string + optional: true + envVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: lagThreshold + type: string + default: "1" + metadataVariableReadable: true + - name: activationLagThreshold + type: string + default: "0" + metadataVariableReadable: true - type: azure-eventhub parameters: - name: unprocessedEventThreshold diff --git a/tests/.env b/tests/.env index efab23b86c5..d6f117770d7 100644 --- a/tests/.env +++ b/tests/.env @@ -9,6 +9,7 @@ TF_AZURE_APP_INSIGHTS_INSTRUMENTATION_KEY= TF_AZURE_APP_INSIGHTS_NAME= TF_AZURE_DATA_EXPLORER_DB= TF_AZURE_DATA_EXPLORER_ENDPOINT= +TF_AZURE_COSMOSDB_CONNECTION_STRING= AZURE_DEVOPS_BUILD_DEFINITION_ID= AZURE_DEVOPS_DEMAND_PARENT_BUILD_DEFINITION_ID= AZURE_DEVOPS_ORGANIZATION_URL= diff --git a/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go b/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go new file mode 100644 index 00000000000..fc41bc8f802 --- /dev/null +++ b/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go @@ -0,0 +1,216 @@ +//go:build e2e +// +build e2e + +package azure_cosmosdb_test + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "os" + "testing" + + "github.com/joho/godotenv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/client-go/kubernetes" + + . "github.com/kedacore/keda/v2/tests/helper" +) + +// Load environment variables from .env file +var _ = godotenv.Load("../../../.env") + +const ( + testName = "azure-cosmosdb-test" +) + +var ( + connectionString = os.Getenv("TF_AZURE_COSMOSDB_CONNECTION_STRING") + testNamespace = fmt.Sprintf("%s-ns", testName) + secretName = fmt.Sprintf("%s-secret", testName) + deploymentName = fmt.Sprintf("%s-deployment", testName) + scaledObjectName = fmt.Sprintf("%s-so", testName) + databaseID = "keda-test-db" + containerID = "keda-test-container" + leaseDatabaseID = "keda-test-db" + leaseContainerID = "keda-test-leases" + processorName = "keda-test-processor" +) + +type templateData struct { + TestNamespace string + SecretName string + Connection string + DeploymentName string + ScaledObjectName string + DatabaseID string + ContainerID string + LeaseDatabaseID string + LeaseContainerID string + ProcessorName string +} + +const ( + secretTemplate = ` +apiVersion: v1 +kind: Secret +metadata: + name: {{.SecretName}} + namespace: {{.TestNamespace}} +data: + connection: {{.Connection}} +` + + triggerAuthTemplate = ` +apiVersion: keda.sh/v1alpha1 +kind: TriggerAuthentication +metadata: + name: {{.SecretName}}-trigger-auth + namespace: {{.TestNamespace}} +spec: + secretTargetRef: + - parameter: connection + name: {{.SecretName}} + key: connection +` + + deploymentTemplate = ` +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{.DeploymentName}} + namespace: {{.TestNamespace}} + labels: + app: {{.DeploymentName}} +spec: + replicas: 0 + selector: + matchLabels: + app: {{.DeploymentName}} + template: + metadata: + labels: + app: {{.DeploymentName}} + spec: + containers: + - name: {{.DeploymentName}} + image: ghcr.io/kedacore/tests-azure-cosmosdb + env: + - name: COSMOS_CONNECTION + valueFrom: + secretKeyRef: + name: {{.SecretName}} + key: connection +` + + scaledObjectTemplate = ` +apiVersion: keda.sh/v1alpha1 +kind: ScaledObject +metadata: + name: {{.ScaledObjectName}} + namespace: {{.TestNamespace}} +spec: + scaleTargetRef: + name: {{.DeploymentName}} + pollingInterval: 5 + minReplicaCount: 0 + maxReplicaCount: 1 + cooldownPeriod: 10 + triggers: + - type: 
azure-cosmosdb + metadata: + databaseId: {{.DatabaseID}} + containerId: {{.ContainerID}} + leaseDatabaseId: {{.LeaseDatabaseID}} + leaseContainerId: {{.LeaseContainerID}} + processorName: {{.ProcessorName}} + connectionFromEnv: COSMOS_CONNECTION + activationLagThreshold: "0" + authenticationRef: + name: {{.SecretName}}-trigger-auth +` +) + +func TestScaler(t *testing.T) { + // setup + ctx := context.Background() + t.Log("--- setting up ---") + require.NotEmpty(t, connectionString, "TF_AZURE_COSMOSDB_CONNECTION_STRING env variable is required for azure cosmosdb test") + + // Create kubernetes resources + kc := GetKubernetesClient(t) + data, templates := getTemplateData() + + CreateKubernetesResources(t, kc, testNamespace, data, templates) + + assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1), + "replica count should be 0 after 1 minute") + + // test scaling + testActivation(t, kc) + testScaleOut(ctx, t, kc) + testScaleIn(t, kc) + + // cleanup + DeleteKubernetesResources(t, testNamespace, data, templates) +} + +func getTemplateData() (templateData, []Template) { + base64ConnectionString := base64.StdEncoding.EncodeToString([]byte(connectionString)) + + return templateData{ + TestNamespace: testNamespace, + SecretName: secretName, + Connection: base64ConnectionString, + DeploymentName: deploymentName, + ScaledObjectName: scaledObjectName, + DatabaseID: databaseID, + ContainerID: containerID, + LeaseDatabaseID: leaseDatabaseID, + LeaseContainerID: leaseContainerID, + ProcessorName: processorName, + }, []Template{ + {Name: "secretTemplate", Config: secretTemplate}, + {Name: "triggerAuthTemplate", Config: triggerAuthTemplate}, + {Name: "deploymentTemplate", Config: deploymentTemplate}, + {Name: "scaledObjectTemplate", Config: scaledObjectTemplate}, + } +} + +func testActivation(t *testing.T, kc *kubernetes.Clientset) { + t.Log("--- testing activation ---") + // With no documents being processed, the change feed lag should be 0 + // and the deployment should not scale + AssertReplicaCountNotChangeDuringTimePeriod(t, kc, deploymentName, testNamespace, 0, 60) +} + +func testScaleOut(ctx context.Context, t *testing.T, kc *kubernetes.Clientset) { + t.Log("--- testing scale out ---") + // Insert documents to create change feed lag + addDocuments(ctx, t, 10) + + assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 1, 60, 1), + "replica count should be 1 after 1 minute") +} + +func testScaleIn(t *testing.T, kc *kubernetes.Clientset) { + t.Log("--- testing scale in ---") + // After processing completes, lag returns to 0 and deployment scales down + assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1), + "replica count should be 0 after 1 minute") +} + +func addDocuments(_ context.Context, t *testing.T, count int) { + t.Helper() + for i := 0; i < count; i++ { + doc := map[string]interface{}{ + "id": fmt.Sprintf("test-doc-%d-%d", GetRandomNumber(), i), + "message": fmt.Sprintf("Test document %d", i), + } + docBytes, err := json.Marshal(doc) + assert.NoErrorf(t, err, "cannot marshal document - %s", err) + t.Logf("Document prepared: %s", string(docBytes)) + } +} From c79b9e738ed4fcc236e0a72499841cb0091ef158 Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Wed, 18 Mar 2026 09:46:04 -0700 Subject: [PATCH 02/12] fix: complete E2E test with actual Cosmos DB document insertion Implement addDocuments using Cosmos DB REST API with HMAC-SHA256 auth. 
Fix changelog format, golangci-lint errcheck/staticcheck/unparam issues. Signed-off-by: Yash Trivedi Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- .../azure_cosmosdb/azure_cosmosdb_test.go | 84 ++++++++++++++++--- 1 file changed, 72 insertions(+), 12 deletions(-) diff --git a/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go b/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go index fc41bc8f802..105cf113520 100644 --- a/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go +++ b/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go @@ -5,11 +5,17 @@ package azure_cosmosdb_test import ( "context" + "crypto/hmac" + "crypto/sha256" "encoding/base64" - "encoding/json" "fmt" + "io" + "net/http" + "net/url" "os" + "strings" "testing" + "time" "github.com/joho/godotenv" "github.com/stretchr/testify/assert" @@ -181,14 +187,11 @@ func getTemplateData() (templateData, []Template) { func testActivation(t *testing.T, kc *kubernetes.Clientset) { t.Log("--- testing activation ---") - // With no documents being processed, the change feed lag should be 0 - // and the deployment should not scale AssertReplicaCountNotChangeDuringTimePeriod(t, kc, deploymentName, testNamespace, 0, 60) } func testScaleOut(ctx context.Context, t *testing.T, kc *kubernetes.Clientset) { t.Log("--- testing scale out ---") - // Insert documents to create change feed lag addDocuments(ctx, t, 10) assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 1, 60, 1), @@ -197,20 +200,77 @@ func testScaleOut(ctx context.Context, t *testing.T, kc *kubernetes.Clientset) { func testScaleIn(t *testing.T, kc *kubernetes.Clientset) { t.Log("--- testing scale in ---") - // After processing completes, lag returns to 0 and deployment scales down assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1), "replica count should be 0 after 1 minute") } -func addDocuments(_ context.Context, t *testing.T, count int) { +// addDocuments inserts documents into the Cosmos DB data container via the REST API +// to generate change feed lag for the scaler to detect. 
+func addDocuments(ctx context.Context, t *testing.T, count int) { t.Helper() + + endpoint, key, err := parseConnString(connectionString) + require.NoErrorf(t, err, "cannot parse connection string - %s", err) + for i := 0; i < count; i++ { - doc := map[string]interface{}{ - "id": fmt.Sprintf("test-doc-%d-%d", GetRandomNumber(), i), - "message": fmt.Sprintf("Test document %d", i), + docID := fmt.Sprintf("test-doc-%d-%d", GetRandomNumber(), i) + body := fmt.Sprintf(`{"id":"%s","partitionKey":"%s","message":"Test document %d"}`, docID, docID, i) + + resourceLink := fmt.Sprintf("dbs/%s/colls/%s", databaseID, containerID) + reqURL := fmt.Sprintf("%s/%s/docs", strings.TrimRight(endpoint, "/"), resourceLink) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(body)) + require.NoErrorf(t, err, "cannot create request - %s", err) + + now := time.Now().UTC().Format(http.TimeFormat) + req.Header.Set("Authorization", cosmosAuthToken("post", "docs", resourceLink, now, key)) + req.Header.Set("x-ms-date", now) + req.Header.Set("x-ms-version", "2018-12-31") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-ms-documentdb-partitionkey", fmt.Sprintf(`["%s"]`, docID)) + + resp, err := http.DefaultClient.Do(req) + require.NoErrorf(t, err, "cannot send request - %s", err) + + respBody, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + require.Truef(t, resp.StatusCode == http.StatusCreated || resp.StatusCode == http.StatusOK, + "unexpected status %d creating document: %s", resp.StatusCode, string(respBody)) + + t.Logf("Document created: %s", docID) + } +} + +func parseConnString(conn string) (string, string, error) { + var endpoint, key string + for _, part := range strings.Split(conn, ";") { + part = strings.TrimSpace(part) + switch { + case strings.HasPrefix(part, "AccountEndpoint="): + endpoint = strings.TrimPrefix(part, "AccountEndpoint=") + case strings.HasPrefix(part, "AccountKey="): + key = strings.TrimPrefix(part, "AccountKey=") } - docBytes, err := json.Marshal(doc) - assert.NoErrorf(t, err, "cannot marshal document - %s", err) - t.Logf("Document prepared: %s", string(docBytes)) } + if endpoint == "" || key == "" { + return "", "", fmt.Errorf("invalid connection string: missing AccountEndpoint or AccountKey") + } + return endpoint, key, nil +} + +func cosmosAuthToken(verb, resourceType, resourceLink, date, key string) string { + keyBytes, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return "" + } + text := fmt.Sprintf("%s\n%s\n%s\n%s\n\n", + strings.ToLower(verb), + strings.ToLower(resourceType), + resourceLink, + strings.ToLower(date)) + h := hmac.New(sha256.New, keyBytes) + h.Write([]byte(text)) + sig := base64.StdEncoding.EncodeToString(h.Sum(nil)) + return url.QueryEscape(fmt.Sprintf("type=master&ver=1.0&sig=%s", sig)) } From 5f5cae8f10cc36844086f551ea4d0086c034d089 Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Thu, 19 Mar 2026 10:02:09 -0700 Subject: [PATCH 03/12] refactor: use total lag metric instead of partition count Change the scaler metric from counting partitions-with-lag to summing total estimated lag across all partitions. 
This matches the EventHub scaler's approach and provides better scaling behavior: - Small lag across many partitions no longer over-provisions replicas - HPA formula: replicas = ceil(totalLag / changeFeedLagThreshold) - Capped at partition count to prevent over-scaling - Renamed metadata: lagThreshold -> changeFeedLagThreshold (default: 100) - Added getChangeFeedTotalLagRelatedToPartitionAmount partition cap Signed-off-by: Yash Trivedi Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- pkg/scalers/azure_cosmosdb_scaler.go | 60 ++++++++++------ pkg/scalers/azure_cosmosdb_scaler_test.go | 68 +++++++++---------- schema/generated/scalers-schema.json | 6 +- schema/generated/scalers-schema.yaml | 6 +- .../azure_cosmosdb/azure_cosmosdb_test.go | 2 +- 5 files changed, 81 insertions(+), 61 deletions(-) diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go index 9c423b002c8..c7426464654 100644 --- a/pkg/scalers/azure_cosmosdb_scaler.go +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "io" + "math" "net/http" "net/url" "strconv" @@ -50,8 +51,8 @@ type azureCosmosDBMetadata struct { LeaseConnection string `keda:"name=leaseConnection, order=authParams;resolvedEnv;triggerMetadata, optional"` CosmosDBKey string `keda:"name=cosmosDBKey, order=authParams;resolvedEnv, optional"` LeaseCosmosDBKey string `keda:"name=leaseCosmosDBKey, order=authParams;resolvedEnv, optional"` - Threshold int64 `keda:"name=lagThreshold, order=triggerMetadata, default=1"` - ActivationThreshold int64 `keda:"name=activationLagThreshold, order=triggerMetadata, default=0"` + Threshold int64 `keda:"name=changeFeedLagThreshold, order=triggerMetadata, default=100"` + ActivationThreshold int64 `keda:"name=activationChangeFeedLagThreshold, order=triggerMetadata, default=0"` TriggerIndex int } @@ -372,50 +373,56 @@ func (c *cosmosDBClient) readChangeFeed(ctx context.Context, partitionKeyRangeID return cfResp, nil } -// estimateLag estimates the change feed lag and returns the count of partitions with lag > 0. +// estimateLag estimates the total change feed lag across all partitions and +// returns both the total lag and partition count. // If a partition split (410 Gone) is detected, it retries once to get fresh lease data. 
-func (c *cosmosDBClient) estimateLag(ctx context.Context) (int64, error) { - count, splitDetected, err := c.estimateOnce(ctx) +func (c *cosmosDBClient) estimateLag(ctx context.Context) (totalLag int64, partitionCount int64, err error) { + totalLag, partitionCount, splitDetected, err := c.estimateOnce(ctx) if err != nil { - return 0, err + return 0, 0, err } if splitDetected { - count, _, err = c.estimateOnce(ctx) + totalLag, partitionCount, _, err = c.estimateOnce(ctx) if err != nil { - return 0, err + return 0, 0, err } } - return count, nil + return totalLag, partitionCount, nil } -func (c *cosmosDBClient) estimateOnce(ctx context.Context) (int64, bool, error) { +func (c *cosmosDBClient) estimateOnce(ctx context.Context) (int64, int64, bool, error) { leases, err := c.queryLeases(ctx) if err != nil { - return 0, false, fmt.Errorf("error querying leases: %w", err) + return 0, 0, false, fmt.Errorf("error querying leases: %w", err) } if len(leases) == 0 { - return 0, false, nil + return 0, 0, false, nil } - partitionsWithLag := int64(0) + totalLag := int64(0) splitDetected := false for _, lease := range leases { lag, isSplit, err := c.estimatePartitionLag(ctx, lease) if err != nil { - return 0, false, fmt.Errorf("error estimating lag for partition %s: %w", lease.LeaseToken, err) + return 0, 0, false, fmt.Errorf("error estimating lag for partition %s: %w", lease.LeaseToken, err) } if isSplit { splitDetected = true continue } if lag > 0 { - partitionsWithLag++ + totalLag += lag } } - return partitionsWithLag, splitDetected, nil + // Cap to prevent int64 overflow from summing across many partitions + if totalLag < 0 { + totalLag = math.MaxInt64 + } + + return totalLag, int64(len(leases)), splitDetected, nil } // estimatePartitionLag calculates the lag for a single partition. @@ -518,18 +525,31 @@ func (s *azureCosmosDBScaler) GetMetricSpecForScaling(context.Context) []v2.Metr return []v2.MetricSpec{metricSpec} } +// getChangeFeedTotalLagRelatedToPartitionAmount caps the total lag to prevent scaling beyond +// the number of partitions. This matches the EventHub scaler's approach. +func getChangeFeedTotalLagRelatedToPartitionAmount(totalLag int64, partitionCount int64, threshold int64) int64 { + if threshold > 0 && (totalLag/threshold) > partitionCount { + return partitionCount * threshold + } + return totalLag +} + // GetMetricsAndActivity returns the metric value and activity status. 
func (s *azureCosmosDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { - partitionsWithLag, err := s.cosmosClient.estimateLag(ctx) + totalLag, partitionCount, err := s.cosmosClient.estimateLag(ctx) if err != nil { s.logger.Error(err, "error getting cosmos db change feed lag") return []external_metrics.ExternalMetricValue{}, false, err } - s.logger.V(1).Info(fmt.Sprintf("Cosmos DB partitions with lag: %d", partitionsWithLag)) + // Don't scale out beyond the number of partitions + lagRelatedToPartitionCount := getChangeFeedTotalLagRelatedToPartitionAmount(totalLag, partitionCount, s.metadata.Threshold) + + s.logger.V(1).Info(fmt.Sprintf("Cosmos DB change feed total lag: %d, scaling for a lag of %d related to %d partitions", + totalLag, lagRelatedToPartitionCount, partitionCount)) - metric := GenerateMetricInMili(metricName, float64(partitionsWithLag)) - return []external_metrics.ExternalMetricValue{metric}, partitionsWithLag > s.metadata.ActivationThreshold, nil + metric := GenerateMetricInMili(metricName, float64(lagRelatedToPartitionCount)) + return []external_metrics.ExternalMetricValue{metric}, totalLag > s.metadata.ActivationThreshold, nil } // Close cleans up the scaler resources. diff --git a/pkg/scalers/azure_cosmosdb_scaler_test.go b/pkg/scalers/azure_cosmosdb_scaler_test.go index 23ae0a3b9e9..8e0c25a95be 100644 --- a/pkg/scalers/azure_cosmosdb_scaler_test.go +++ b/pkg/scalers/azure_cosmosdb_scaler_test.go @@ -224,15 +224,15 @@ var testCosmosDBMetadata = []parseCosmosDBMetadataTestData{ podIdentity: "", }, { - name: "invalid lagThreshold", + name: "invalid changeFeedLagThreshold", metadata: map[string]string{ - "connectionFromEnv": "COSMOS_CONNECTION", - "databaseId": "testdb", - "containerId": "testcontainer", - "leaseDatabaseId": "testdb", - "leaseContainerId": "leases", - "processorName": "testprocessor", - "lagThreshold": "invalid", + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + "changeFeedLagThreshold": "invalid", }, isError: true, resolvedEnv: testCosmosDBResolvedEnv, @@ -240,15 +240,15 @@ var testCosmosDBMetadata = []parseCosmosDBMetadataTestData{ podIdentity: "", }, { - name: "invalid activationLagThreshold", + name: "invalid activationChangeFeedLagThreshold", metadata: map[string]string{ - "connectionFromEnv": "COSMOS_CONNECTION", - "databaseId": "testdb", - "containerId": "testcontainer", - "leaseDatabaseId": "testdb", - "leaseContainerId": "leases", - "processorName": "testprocessor", - "activationLagThreshold": "invalid", + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + "activationChangeFeedLagThreshold": "invalid", }, isError: true, resolvedEnv: testCosmosDBResolvedEnv, @@ -526,10 +526,10 @@ func TestCosmosDBLeaseParsingDotNetFormat(t *testing.T) { leaseContainerID: "leases", } - partitionsWithLag, err := client.estimateLag(context.Background()) + totalLag, _, err := client.estimateLag(context.Background()) assert.NoError(t, err) // Only partition 6 has lag; partition 3 is caught up; metadata doc is filtered - assert.Equal(t, int64(1), partitionsWithLag) + assert.Equal(t, int64(89), totalLag) } func TestCosmosDBLeaseParsingJavaFormat(t *testing.T) { @@ -592,10 +592,10 @@ func 
TestCosmosDBLeaseParsingJavaFormat(t *testing.T) { leaseContainerID: "leases", } - partitionsWithLag, err := client.estimateLag(context.Background()) + totalLag, _, err := client.estimateLag(context.Background()) assert.NoError(t, err) // Both partitions 2 and 5 have lag; lock doc is filtered out - assert.Equal(t, int64(2), partitionsWithLag) + assert.Equal(t, int64(252), totalLag) } func TestCosmosDBLeaseParsingMixedFormats(t *testing.T) { @@ -643,9 +643,9 @@ func TestCosmosDBLeaseParsingMixedFormats(t *testing.T) { leaseContainerID: "leases", } - partitionsWithLag, err := client.estimateLag(context.Background()) + totalLag, _, err := client.estimateLag(context.Background()) assert.NoError(t, err) - assert.Equal(t, int64(2), partitionsWithLag) + assert.Equal(t, int64(302), totalLag) } func TestCosmosDBLeaseParsingEPKBasedDotNet(t *testing.T) { @@ -702,10 +702,10 @@ func TestCosmosDBLeaseParsingEPKBasedDotNet(t *testing.T) { leaseContainerID: "leases", } - partitionsWithLag, err := client.estimateLag(context.Background()) + totalLag, _, err := client.estimateLag(context.Background()) assert.NoError(t, err) // Partition 0 has lag (900-751+1=150), partition 1 is caught up - assert.Equal(t, int64(1), partitionsWithLag) + assert.Equal(t, int64(150), totalLag) } func TestCosmosDBLeaseParsingEPKBasedJava(t *testing.T) { @@ -755,10 +755,10 @@ func TestCosmosDBLeaseParsingEPKBasedJava(t *testing.T) { leaseContainerID: "leases", } - partitionsWithLag, err := client.estimateLag(context.Background()) + totalLag, _, err := client.estimateLag(context.Background()) assert.NoError(t, err) // Both partitions have lag; Base64 tokens are passed through to the server - assert.Equal(t, int64(2), partitionsWithLag) + assert.Equal(t, int64(200), totalLag) } func TestCosmosDBLagEstimation(t *testing.T) { @@ -823,9 +823,9 @@ func TestCosmosDBLagEstimation(t *testing.T) { leaseContainerID: "leases", } - partitionsWithLag, err := client.estimateLag(context.Background()) + totalLag, _, err := client.estimateLag(context.Background()) assert.NoError(t, err) - assert.Equal(t, int64(1), partitionsWithLag) + assert.Equal(t, int64(51), totalLag) } func TestCosmosDBLagEstimationEmptyLeases(t *testing.T) { @@ -846,9 +846,9 @@ func TestCosmosDBLagEstimationEmptyLeases(t *testing.T) { leaseContainerID: "leases", } - partitionsWithLag, err := client.estimateLag(context.Background()) + totalLag, _, err := client.estimateLag(context.Background()) assert.NoError(t, err) - assert.Equal(t, int64(0), partitionsWithLag) + assert.Equal(t, int64(0), totalLag) } func TestCosmosDBLagEstimationAllPartitionsLagging(t *testing.T) { @@ -889,9 +889,9 @@ func TestCosmosDBLagEstimationAllPartitionsLagging(t *testing.T) { leaseContainerID: "leases", } - partitionsWithLag, err := client.estimateLag(context.Background()) + totalLag, _, err := client.estimateLag(context.Background()) assert.NoError(t, err) - assert.Equal(t, int64(2), partitionsWithLag) + assert.Equal(t, int64(202), totalLag) } func TestCosmosDBLagEstimationPartitionSplit(t *testing.T) { @@ -932,9 +932,9 @@ func TestCosmosDBLagEstimationPartitionSplit(t *testing.T) { leaseContainerID: "leases", } - partitionsWithLag, err := client.estimateLag(context.Background()) + totalLag, _, err := client.estimateLag(context.Background()) assert.NoError(t, err) - assert.Equal(t, int64(0), partitionsWithLag) + assert.Equal(t, int64(0), totalLag) // Should have retried: lease query + change feed (410) + lease query (retry) + change feed (304) assert.GreaterOrEqual(t, changeFeedCallCount, 2) } diff 
--git a/schema/generated/scalers-schema.json b/schema/generated/scalers-schema.json index ba10008f2d8..d43ff78c10c 100644 --- a/schema/generated/scalers-schema.json +++ b/schema/generated/scalers-schema.json @@ -829,13 +829,13 @@ "triggerAuthenticationVariableReadable": true }, { - "name": "lagThreshold", + "name": "changeFeedLagThreshold", "type": "string", - "default": "1", + "default": "100", "metadataVariableReadable": true }, { - "name": "activationLagThreshold", + "name": "activationChangeFeedLagThreshold", "type": "string", "default": "0", "metadataVariableReadable": true diff --git a/schema/generated/scalers-schema.yaml b/schema/generated/scalers-schema.yaml index 31eb41b1beb..2002074f0f7 100644 --- a/schema/generated/scalers-schema.yaml +++ b/schema/generated/scalers-schema.yaml @@ -543,11 +543,11 @@ scalers: optional: true envVariableReadable: true triggerAuthenticationVariableReadable: true - - name: lagThreshold + - name: changeFeedLagThreshold type: string - default: "1" + default: "100" metadataVariableReadable: true - - name: activationLagThreshold + - name: activationChangeFeedLagThreshold type: string default: "0" metadataVariableReadable: true diff --git a/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go b/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go index 105cf113520..3925f1d875a 100644 --- a/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go +++ b/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go @@ -133,7 +133,7 @@ spec: leaseContainerId: {{.LeaseContainerID}} processorName: {{.ProcessorName}} connectionFromEnv: COSMOS_CONNECTION - activationLagThreshold: "0" + activationChangeFeedLagThreshold: "0" authenticationRef: name: {{.SecretName}}-trigger-auth ` From f49e6d65f7b1d48b14cb8df00a59b4249142e206 Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Fri, 20 Mar 2026 11:30:14 -0700 Subject: [PATCH 04/12] feat: scale to max on error instead of returning error On failure to read lease documents or change feed, return partitions * threshold as the metric (scale to max replicas) instead of propagating the error. This ensures the system is not under-provisioned during transient failures. Caches last known partition count for use during errors. Falls back to threshold value if no partition count is cached. Signed-off-by: Yash Trivedi Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- pkg/scalers/azure_cosmosdb_scaler.go | 27 ++++++-- pkg/scalers/azure_cosmosdb_scaler_test.go | 80 +++++++++++++++++++++++ 2 files changed, 101 insertions(+), 6 deletions(-) diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go index c7426464654..d67fc97866c 100644 --- a/pkg/scalers/azure_cosmosdb_scaler.go +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -33,10 +33,11 @@ const ( ) type azureCosmosDBScaler struct { - metricType v2.MetricTargetType - metadata *azureCosmosDBMetadata - cosmosClient *cosmosDBClient - logger logr.Logger + metricType v2.MetricTargetType + metadata *azureCosmosDBMetadata + cosmosClient *cosmosDBClient + logger logr.Logger + lastPartitionCount int64 } type azureCosmosDBMetadata struct { @@ -535,11 +536,25 @@ func getChangeFeedTotalLagRelatedToPartitionAmount(totalLag int64, partitionCoun } // GetMetricsAndActivity returns the metric value and activity status. +// On error, returns the maximum possible metric (partitions * threshold) to scale +// to max replicas, ensuring the system is not under-provisioned during failures. 
func (s *azureCosmosDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { totalLag, partitionCount, err := s.cosmosClient.estimateLag(ctx) if err != nil { - s.logger.Error(err, "error getting cosmos db change feed lag") - return []external_metrics.ExternalMetricValue{}, false, err + s.logger.Error(err, "error getting cosmos db change feed lag, scaling to max") + if s.lastPartitionCount > 0 { + maxLag := s.lastPartitionCount * s.metadata.Threshold + metric := GenerateMetricInMili(metricName, float64(maxLag)) + return []external_metrics.ExternalMetricValue{metric}, true, nil + } + // No cached partition count — fall back to threshold as a safe default + metric := GenerateMetricInMili(metricName, float64(s.metadata.Threshold)) + return []external_metrics.ExternalMetricValue{metric}, true, nil + } + + // Cache partition count for error fallback + if partitionCount > 0 { + s.lastPartitionCount = partitionCount } // Don't scale out beyond the number of partitions diff --git a/pkg/scalers/azure_cosmosdb_scaler_test.go b/pkg/scalers/azure_cosmosdb_scaler_test.go index 8e0c25a95be..c76d2d8af03 100644 --- a/pkg/scalers/azure_cosmosdb_scaler_test.go +++ b/pkg/scalers/azure_cosmosdb_scaler_test.go @@ -1043,3 +1043,83 @@ func TestCosmosDBGetMetricsAndActivityNotActive(t *testing.T) { assert.False(t, isActive) assert.Len(t, metrics, 1) } + +func TestCosmosDBGetMetricsAndActivityOnError(t *testing.T) { + // Server that returns 500 for all requests + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"internal server error"}`)) + })) + defer server.Close() + + scaler := &azureCosmosDBScaler{ + metricType: v2.AverageValueMetricType, + metadata: &azureCosmosDBMetadata{ + DatabaseID: "testdb", + ContainerID: "testcontainer", + LeaseDatabaseID: "testdb", + LeaseContainerID: "leases", + ProcessorName: "testprocessor", + Threshold: 100, + ActivationThreshold: 0, + }, + cosmosClient: &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + }, + logger: logr.Discard(), + lastPartitionCount: 4, // Simulate cached partition count from previous successful poll + } + + // On error, should return max lag (4 * 100 = 400) and active=true + metrics, isActive, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") + assert.NoError(t, err) // No error returned — we handle it internally + assert.True(t, isActive) + assert.Len(t, metrics, 1) +} + +func TestCosmosDBGetMetricsAndActivityOnErrorNoCachedPartitions(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + scaler := &azureCosmosDBScaler{ + metricType: v2.AverageValueMetricType, + metadata: &azureCosmosDBMetadata{ + DatabaseID: "testdb", + ContainerID: "testcontainer", + LeaseDatabaseID: "testdb", + LeaseContainerID: "leases", + ProcessorName: "testprocessor", + Threshold: 100, + ActivationThreshold: 0, + }, + cosmosClient: &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + 
containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + }, + logger: logr.Discard(), + lastPartitionCount: 0, // No cached partition count + } + + // On error with no cached partitions, should return threshold as fallback + metrics, isActive, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") + assert.NoError(t, err) + assert.True(t, isActive) + assert.Len(t, metrics, 1) +} From 09e5db416d981421d1cadcd7de2a8af5ff86a654 Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Mon, 23 Mar 2026 16:14:18 -0700 Subject: [PATCH 05/12] fix: use larger fallback value when no cached partition count When error occurs with no prior successful poll (e.g. fresh operator restart with bad credentials), return 100*threshold instead of just threshold to ensure HPA scales to maxReplicaCount. Signed-off-by: Yash Trivedi Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- pkg/scalers/azure_cosmosdb_scaler.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go index d67fc97866c..e0229edad18 100644 --- a/pkg/scalers/azure_cosmosdb_scaler.go +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -547,8 +547,9 @@ func (s *azureCosmosDBScaler) GetMetricsAndActivity(ctx context.Context, metricN metric := GenerateMetricInMili(metricName, float64(maxLag)) return []external_metrics.ExternalMetricValue{metric}, true, nil } - // No cached partition count — fall back to threshold as a safe default - metric := GenerateMetricInMili(metricName, float64(s.metadata.Threshold)) + // No cached partition count — return a large value to trigger max scaling. + // Use 100 * threshold to ensure HPA scales well beyond 1 replica. 
+ metric := GenerateMetricInMili(metricName, float64(100*s.metadata.Threshold)) return []external_metrics.ExternalMetricValue{metric}, true, nil } From a029eba40c79c254a98a264aadb0acee68cd0ec7 Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Tue, 24 Mar 2026 12:54:19 -0700 Subject: [PATCH 06/12] feat: add comprehensive logging throughout CosmosDB scaler - Log partition count and per-partition lag at debug level - Log partition split detection as warning - Log which error fallback path is taken (cached vs uncached) - Log fallback lag value used during errors - Log unparseable session tokens and LSN values - Log empty lease container at debug level Signed-off-by: Yash Trivedi Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- pkg/scalers/azure_cosmosdb_scaler.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go index e0229edad18..c7207232da9 100644 --- a/pkg/scalers/azure_cosmosdb_scaler.go +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -70,6 +70,7 @@ type cosmosDBClient struct { databaseID string containerID string credential azcore.TokenCredential + logger logr.Logger } type leaseDocument struct { @@ -157,6 +158,7 @@ func newCosmosDBClient(meta *azureCosmosDBMetadata, podIdentity kedav1alpha1.Aut leaseContainerID: meta.LeaseContainerID, databaseID: meta.DatabaseID, containerID: meta.ContainerID, + logger: logger, } // Resolve data endpoint and key @@ -383,6 +385,7 @@ func (c *cosmosDBClient) estimateLag(ctx context.Context) (totalLag int64, parti return 0, 0, err } if splitDetected { + c.logger.Info("Warning: partition split detected, re-reading leases") totalLag, partitionCount, _, err = c.estimateOnce(ctx) if err != nil { return 0, 0, err @@ -398,9 +401,12 @@ func (c *cosmosDBClient) estimateOnce(ctx context.Context) (int64, int64, bool, } if len(leases) == 0 { + c.logger.V(1).Info("no lease documents found in lease container") return 0, 0, false, nil } + c.logger.V(1).Info(fmt.Sprintf("found %d lease documents", len(leases))) + totalLag := int64(0) splitDetected := false @@ -410,9 +416,11 @@ func (c *cosmosDBClient) estimateOnce(ctx context.Context) (int64, int64, bool, return 0, 0, false, fmt.Errorf("error estimating lag for partition %s: %w", lease.LeaseToken, err) } if isSplit { + c.logger.Info(fmt.Sprintf("Warning: partition %s returned 410 Gone (split/merge detected)", lease.LeaseToken)) splitDetected = true continue } + c.logger.V(1).Info(fmt.Sprintf("partition %s: estimated lag = %d, owner = %s", lease.LeaseToken, lag, lease.Owner)) if lag > 0 { totalLag += lag } @@ -436,7 +444,7 @@ func (c *cosmosDBClient) estimateOnce(ctx context.Context) (int64, int64, bool, func (c *cosmosDBClient) estimatePartitionLag(ctx context.Context, lease leaseDocument) (int64, bool, error) { cfResp, err := c.readChangeFeed(ctx, lease.LeaseToken, lease.ContinuationToken) if err != nil { - return 0, false, err + return 0, false, fmt.Errorf("error reading change feed for partition %s: %w", lease.LeaseToken, err) } // 410 Gone indicates partition split or merge @@ -452,16 +460,19 @@ func (c *cosmosDBClient) estimatePartitionLag(ctx context.Context, lease leaseDo // Calculate lag: sessionLSN - firstItemLSN + 1 sessionLSN, err := parseLSNFromSessionToken(cfResp.SessionToken) if err != nil || sessionLSN < 0 { + c.logger.V(1).Info(fmt.Sprintf("partition %s: could not parse session token LSN (token: %s), assuming no lag", lease.LeaseToken, 
cfResp.SessionToken)) return 0, false, nil } firstItemLSN, err := extractItemLSN(cfResp.Items[0]) if err != nil || firstItemLSN < 0 { + c.logger.V(1).Info(fmt.Sprintf("partition %s: could not extract _lsn from first item, assuming no lag", lease.LeaseToken)) return 0, false, nil } lag := sessionLSN - firstItemLSN + 1 if lag < 0 { + c.logger.V(1).Info(fmt.Sprintf("partition %s: negative lag (sessionLSN=%d, firstItemLSN=%d), assuming no lag", lease.LeaseToken, sessionLSN, firstItemLSN)) return 0, false, nil } @@ -544,12 +555,16 @@ func (s *azureCosmosDBScaler) GetMetricsAndActivity(ctx context.Context, metricN s.logger.Error(err, "error getting cosmos db change feed lag, scaling to max") if s.lastPartitionCount > 0 { maxLag := s.lastPartitionCount * s.metadata.Threshold + s.logger.Info(fmt.Sprintf("Warning: using cached partition count (%d) for error fallback, reporting lag=%d", + s.lastPartitionCount, maxLag)) metric := GenerateMetricInMili(metricName, float64(maxLag)) return []external_metrics.ExternalMetricValue{metric}, true, nil } // No cached partition count — return a large value to trigger max scaling. // Use 100 * threshold to ensure HPA scales well beyond 1 replica. - metric := GenerateMetricInMili(metricName, float64(100*s.metadata.Threshold)) + fallbackLag := 100 * s.metadata.Threshold + s.logger.Info(fmt.Sprintf("Warning: no cached partition count available, using large fallback lag=%d", fallbackLag)) + metric := GenerateMetricInMili(metricName, float64(fallbackLag)) return []external_metrics.ExternalMetricValue{metric}, true, nil } From a598a8b9b5414f8065a7980e16de7e1caaaa2260 Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Tue, 24 Mar 2026 15:10:48 -0700 Subject: [PATCH 07/12] fix: return error when no cached partition count on failure Instead of returning a large fallback value when there is no cached partition count, propagate the error to KEDA. This lets KEDA's standard error handling (keep current replicas) and optional fallback config handle the situation. Cached partition fallback still returns max lag when available. 
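For context, a minimal sketch of that optional fallback config on the ScaledObject looks like the following (illustrative failureThreshold and replicas values, not part of this patch):

    spec:
      fallback:
        failureThreshold: 3   # consecutive failed metric reads before the fallback applies
        replicas: 5           # replica count to hold while the scaler keeps erroring

With such a block in place, propagating the error lets operators choose between keeping the current replica count (no fallback configured) and pinning a known-safe replica count (fallback configured).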
Signed-off-by: Yash Trivedi Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- pkg/scalers/azure_cosmosdb_scaler.go | 22 +++++++++------------- pkg/scalers/azure_cosmosdb_scaler_test.go | 8 +++----- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go index c7207232da9..8e1e240e0e6 100644 --- a/pkg/scalers/azure_cosmosdb_scaler.go +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -33,11 +33,11 @@ const ( ) type azureCosmosDBScaler struct { - metricType v2.MetricTargetType - metadata *azureCosmosDBMetadata - cosmosClient *cosmosDBClient - logger logr.Logger - lastPartitionCount int64 + metricType v2.MetricTargetType + metadata *azureCosmosDBMetadata + cosmosClient *cosmosDBClient + logger logr.Logger + lastPartitionCount int64 } type azureCosmosDBMetadata struct { @@ -552,20 +552,16 @@ func getChangeFeedTotalLagRelatedToPartitionAmount(totalLag int64, partitionCoun func (s *azureCosmosDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { totalLag, partitionCount, err := s.cosmosClient.estimateLag(ctx) if err != nil { - s.logger.Error(err, "error getting cosmos db change feed lag, scaling to max") if s.lastPartitionCount > 0 { maxLag := s.lastPartitionCount * s.metadata.Threshold - s.logger.Info(fmt.Sprintf("Warning: using cached partition count (%d) for error fallback, reporting lag=%d", + s.logger.Error(err, fmt.Sprintf("error getting cosmos db change feed lag, using cached partition count (%d) for fallback, reporting lag=%d", s.lastPartitionCount, maxLag)) metric := GenerateMetricInMili(metricName, float64(maxLag)) return []external_metrics.ExternalMetricValue{metric}, true, nil } - // No cached partition count — return a large value to trigger max scaling. - // Use 100 * threshold to ensure HPA scales well beyond 1 replica. 
- fallbackLag := 100 * s.metadata.Threshold - s.logger.Info(fmt.Sprintf("Warning: no cached partition count available, using large fallback lag=%d", fallbackLag)) - metric := GenerateMetricInMili(metricName, float64(fallbackLag)) - return []external_metrics.ExternalMetricValue{metric}, true, nil + // No cached partition count — propagate error to KEDA (standard behavior) + s.logger.Error(err, "error getting cosmos db change feed lag, no cached partition count available") + return []external_metrics.ExternalMetricValue{}, false, err } // Cache partition count for error fallback diff --git a/pkg/scalers/azure_cosmosdb_scaler_test.go b/pkg/scalers/azure_cosmosdb_scaler_test.go index c76d2d8af03..d57b6845892 100644 --- a/pkg/scalers/azure_cosmosdb_scaler_test.go +++ b/pkg/scalers/azure_cosmosdb_scaler_test.go @@ -1117,9 +1117,7 @@ func TestCosmosDBGetMetricsAndActivityOnErrorNoCachedPartitions(t *testing.T) { lastPartitionCount: 0, // No cached partition count } - // On error with no cached partitions, should return threshold as fallback - metrics, isActive, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") - assert.NoError(t, err) - assert.True(t, isActive) - assert.Len(t, metrics, 1) + // On error with no cached partitions, should return error + _, _, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") + assert.Error(t, err) } From e7d4a73084a316f8eec51ea0644a1366d0a829fe Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Thu, 26 Mar 2026 12:25:09 -0700 Subject: [PATCH 08/12] fix: create Cosmos DB database and containers in E2E test setup Add setupCosmosDB function that creates the database, data container, and lease container via REST API before running tests. Handles 409 Conflict (already exists) gracefully. Signed-off-by: Yash Trivedi Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- .../azure_cosmosdb/azure_cosmosdb_test.go | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go b/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go index 3925f1d875a..5621f30942a 100644 --- a/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go +++ b/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go @@ -145,6 +145,9 @@ func TestScaler(t *testing.T) { t.Log("--- setting up ---") require.NotEmpty(t, connectionString, "TF_AZURE_COSMOSDB_CONNECTION_STRING env variable is required for azure cosmosdb test") + // Create Cosmos DB resources (database + containers) + setupCosmosDB(ctx, t) + // Create kubernetes resources kc := GetKubernetesClient(t) data, templates := getTemplateData() @@ -274,3 +277,56 @@ func cosmosAuthToken(verb, resourceType, resourceLink, date, key string) string sig := base64.StdEncoding.EncodeToString(h.Sum(nil)) return url.QueryEscape(fmt.Sprintf("type=master&ver=1.0&sig=%s", sig)) } + +// setupCosmosDB creates the database, data container, and lease container if they don't exist. 
+func setupCosmosDB(ctx context.Context, t *testing.T) { + t.Helper() + + endpoint, key, err := parseConnString(connectionString) + require.NoErrorf(t, err, "cannot parse connection string - %s", err) + + // Create database + cosmosCreateResource(ctx, t, endpoint, key, "", "dbs", fmt.Sprintf(`{"id":"%s"}`, databaseID)) + + // Create data container with /id as partition key + dbLink := fmt.Sprintf("dbs/%s", databaseID) + cosmosCreateResource(ctx, t, endpoint, key, dbLink, "colls", + fmt.Sprintf(`{"id":"%s","partitionKey":{"paths":["/id"],"kind":"Hash"}}`, containerID)) + + // Create lease container with /id as partition key + cosmosCreateResource(ctx, t, endpoint, key, dbLink, "colls", + fmt.Sprintf(`{"id":"%s","partitionKey":{"paths":["/id"],"kind":"Hash"}}`, leaseContainerID)) +} + +// cosmosCreateResource creates a Cosmos DB resource via REST API, ignoring 409 Conflict (already exists). +func cosmosCreateResource(ctx context.Context, t *testing.T, endpoint, key, parentLink, resourceType, body string) { + t.Helper() + + reqURL := fmt.Sprintf("%s/%s/%s", strings.TrimRight(endpoint, "/"), parentLink, resourceType) + if parentLink == "" { + reqURL = fmt.Sprintf("%s/%s", strings.TrimRight(endpoint, "/"), resourceType) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(body)) + require.NoErrorf(t, err, "cannot create request - %s", err) + + now := time.Now().UTC().Format(http.TimeFormat) + req.Header.Set("Authorization", cosmosAuthToken("post", resourceType, parentLink, now, key)) + req.Header.Set("x-ms-date", now) + req.Header.Set("x-ms-version", "2018-12-31") + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + require.NoErrorf(t, err, "cannot send request - %s", err) + respBody, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + if resp.StatusCode == http.StatusConflict { + t.Logf("Resource already exists (409), skipping: %s/%s", parentLink, resourceType) + return + } + + require.Truef(t, resp.StatusCode == http.StatusCreated || resp.StatusCode == http.StatusOK, + "unexpected status %d creating resource: %s", resp.StatusCode, string(respBody)) + t.Logf("Created resource: %s/%s", parentLink, resourceType) +} From 370699920c57883251a8b2e509a2ac061cd8afca Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Tue, 31 Mar 2026 11:33:14 -0700 Subject: [PATCH 09/12] fix: address PR review comments for CosmosDB scaler - generateCosmosDBAuthToken returns (string, error) instead of silently returning empty string on base64 decode failure - Add url.QueryEscape to AAD bearer token Authorization header - Filter lease documents by processorName using parameterized STARTSWITH query to prevent over-counting when multiple processors share a lease container - Add processorName field to cosmosDBClient struct Signed-off-by: Yash Trivedi Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- pkg/scalers/azure_cosmosdb_scaler.go | 17 +++++--- pkg/scalers/azure_cosmosdb_scaler_test.go | 50 ++++++++++++++++++++++- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go index 8e1e240e0e6..4e805bea0a3 100644 --- a/pkg/scalers/azure_cosmosdb_scaler.go +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -69,6 +69,7 @@ type cosmosDBClient struct { leaseContainerID string databaseID string containerID string + processorName string credential azcore.TokenCredential logger logr.Logger } @@ -158,6 +159,7 @@ func 
newCosmosDBClient(meta *azureCosmosDBMetadata, podIdentity kedav1alpha1.Aut leaseContainerID: meta.LeaseContainerID, databaseID: meta.DatabaseID, containerID: meta.ContainerID, + processorName: meta.ProcessorName, logger: logger, } @@ -226,7 +228,10 @@ func parseCosmosDBConnectionString(connectionString string) (string, string, err // setAuthHeader sets the Authorization header using either master key HMAC-SHA256 or bearer token. func (c *cosmosDBClient) setAuthHeader(req *http.Request, verb, resourceType, resourceLink, date, key string) error { if key != "" { - token := generateCosmosDBAuthToken(verb, resourceType, resourceLink, date, key) + token, err := generateCosmosDBAuthToken(verb, resourceType, resourceLink, date, key) + if err != nil { + return fmt.Errorf("error generating auth token: %w", err) + } req.Header.Set("Authorization", token) return nil } @@ -238,7 +243,7 @@ func (c *cosmosDBClient) setAuthHeader(req *http.Request, verb, resourceType, re if err != nil { return fmt.Errorf("error acquiring bearer token: %w", err) } - req.Header.Set("Authorization", "type=aad&ver=1.0&sig="+tk.Token) + req.Header.Set("Authorization", url.QueryEscape(fmt.Sprintf("type=aad&ver=1.0&sig=%s", tk.Token))) return nil } @@ -248,10 +253,10 @@ func (c *cosmosDBClient) setAuthHeader(req *http.Request, verb, resourceType, re // generateCosmosDBAuthToken generates an HMAC-SHA256 auth token for Cosmos DB REST API. // Format: type=master&ver=1.0&sig={hashsignature} // Signature input: {verb}\n{resourceType}\n{resourceLink}\n{date}\n\n -func generateCosmosDBAuthToken(verb, resourceType, resourceLink, date, key string) string { +func generateCosmosDBAuthToken(verb, resourceType, resourceLink, date, key string) (string, error) { keyBytes, err := base64.StdEncoding.DecodeString(key) if err != nil { - return "" + return "", fmt.Errorf("error decoding cosmos db key: %w", err) } text := fmt.Sprintf("%s\n%s\n%s\n%s\n\n", @@ -264,14 +269,14 @@ func generateCosmosDBAuthToken(verb, resourceType, resourceLink, date, key strin h.Write([]byte(text)) signature := base64.StdEncoding.EncodeToString(h.Sum(nil)) - return url.QueryEscape(fmt.Sprintf("type=master&ver=1.0&sig=%s", signature)) + return url.QueryEscape(fmt.Sprintf("type=master&ver=1.0&sig=%s", signature)), nil } func (c *cosmosDBClient) queryLeases(ctx context.Context) ([]leaseDocument, error) { resourceLink := fmt.Sprintf("dbs/%s/colls/%s", c.leaseDatabaseID, c.leaseContainerID) reqURL := fmt.Sprintf("%s/%s/docs", strings.TrimRight(c.leaseEndpoint, "/"), resourceLink) - body := `{"query":"SELECT * FROM c"}` + body := fmt.Sprintf(`{"query":"SELECT * FROM c WHERE STARTSWITH(c.id, @prefix)","parameters":[{"name":"@prefix","value":"%s"}]}`, c.processorName) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(body)) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) diff --git a/pkg/scalers/azure_cosmosdb_scaler_test.go b/pkg/scalers/azure_cosmosdb_scaler_test.go index d57b6845892..e3bf6b9c306 100644 --- a/pkg/scalers/azure_cosmosdb_scaler_test.go +++ b/pkg/scalers/azure_cosmosdb_scaler_test.go @@ -3,6 +3,7 @@ package scalers import ( "context" "encoding/json" + "io" "net/http" "net/http/httptest" "testing" @@ -452,10 +453,16 @@ func TestExtractItemLSN(t *testing.T) { } func TestCosmosDBAuthTokenGeneration(t *testing.T) { - token := generateCosmosDBAuthToken("get", "docs", "dbs/testdb/colls/testcol", "thu, 01 jan 2024 00:00:00 gmt", "dGVzdGtleQ==") + token, err := generateCosmosDBAuthToken("get", "docs", 
"dbs/testdb/colls/testcol", "thu, 01 jan 2024 00:00:00 gmt", "dGVzdGtleQ==") + assert.NoError(t, err) assert.Contains(t, token, "type%3Dmaster%26ver%3D1.0%26sig%3D") } +func TestCosmosDBAuthTokenGenerationInvalidKey(t *testing.T) { + _, err := generateCosmosDBAuthToken("get", "docs", "dbs/testdb/colls/testcol", "date", "not-valid-base64!!!") + assert.Error(t, err) +} + func TestCosmosDBLeaseParsingDotNetFormat(t *testing.T) { // Realistic .NET SDK lease documents have: version=0, FeedRange, Mode, properties fields. // The scaler must parse LeaseToken and ContinuationToken and ignore the extra fields. @@ -524,6 +531,7 @@ func TestCosmosDBLeaseParsingDotNetFormat(t *testing.T) { containerID: "data", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", } totalLag, _, err := client.estimateLag(context.Background()) @@ -590,6 +598,7 @@ func TestCosmosDBLeaseParsingJavaFormat(t *testing.T) { containerID: "data", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", } totalLag, _, err := client.estimateLag(context.Background()) @@ -641,6 +650,7 @@ func TestCosmosDBLeaseParsingMixedFormats(t *testing.T) { containerID: "data", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", } totalLag, _, err := client.estimateLag(context.Background()) @@ -700,6 +710,7 @@ func TestCosmosDBLeaseParsingEPKBasedDotNet(t *testing.T) { containerID: "data", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", } totalLag, _, err := client.estimateLag(context.Background()) @@ -753,6 +764,7 @@ func TestCosmosDBLeaseParsingEPKBasedJava(t *testing.T) { containerID: "data", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", } totalLag, _, err := client.estimateLag(context.Background()) @@ -821,6 +833,7 @@ func TestCosmosDBLagEstimation(t *testing.T) { containerID: "testcontainer", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", } totalLag, _, err := client.estimateLag(context.Background()) @@ -844,6 +857,7 @@ func TestCosmosDBLagEstimationEmptyLeases(t *testing.T) { leaseKey: "dGVzdGtleQ==", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", } totalLag, _, err := client.estimateLag(context.Background()) @@ -887,6 +901,7 @@ func TestCosmosDBLagEstimationAllPartitionsLagging(t *testing.T) { containerID: "testcontainer", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", } totalLag, _, err := client.estimateLag(context.Background()) @@ -930,6 +945,7 @@ func TestCosmosDBLagEstimationPartitionSplit(t *testing.T) { containerID: "testcontainer", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", } totalLag, _, err := client.estimateLag(context.Background()) @@ -984,6 +1000,7 @@ func TestCosmosDBGetMetricsAndActivity(t *testing.T) { containerID: "testcontainer", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", }, logger: logr.Discard(), } @@ -1034,6 +1051,7 @@ func TestCosmosDBGetMetricsAndActivityNotActive(t *testing.T) { containerID: "testcontainer", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", }, logger: logr.Discard(), } @@ -1073,6 +1091,7 @@ func TestCosmosDBGetMetricsAndActivityOnError(t *testing.T) { containerID: "testcontainer", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: 
"testprocessor", }, logger: logr.Discard(), lastPartitionCount: 4, // Simulate cached partition count from previous successful poll @@ -1112,6 +1131,7 @@ func TestCosmosDBGetMetricsAndActivityOnErrorNoCachedPartitions(t *testing.T) { containerID: "testcontainer", leaseDatabaseID: "testdb", leaseContainerID: "leases", + processorName: "testprocessor", }, logger: logr.Discard(), lastPartitionCount: 0, // No cached partition count @@ -1121,3 +1141,31 @@ func TestCosmosDBGetMetricsAndActivityOnErrorNoCachedPartitions(t *testing.T) { _, _, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") assert.Error(t, err) } + +func TestCosmosDBLeaseQueryFiltersByProcessorName(t *testing.T) { + var capturedBody string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + bodyBytes, _ := io.ReadAll(r.Body) + capturedBody = string(bodyBytes) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[]}`)) + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "myprocessor", + } + + _, _ = client.queryLeases(context.Background()) + assert.Contains(t, capturedBody, "STARTSWITH") + assert.Contains(t, capturedBody, "@prefix") + assert.Contains(t, capturedBody, "myprocessor") +} From c2a6a749aeb92f6ee3d4e657a6019616ef5b9cd1 Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Tue, 28 Apr 2026 11:18:52 -0700 Subject: [PATCH 10/12] fix: resolve CosmosDB resource URL via cloud environment instead of hardcoding Replace hardcoded azure.PublicCloud.ResourceIdentifiers.CosmosDB with cloud-aware resolution using azure.ParseEnvironmentProperty, consistent with azure_eventhub_scaler.go and azure_servicebus_scaler.go. This ensures workload identity authentication uses the correct token scope for sovereign clouds (Azure China, US Gov, German) and supports Private cloud configurations via the 'cosmosDBResourceURL' metadata key. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- pkg/scalers/azure_cosmosdb_scaler.go | 40 +++++++++++++++++----------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go index 4e805bea0a3..c878d93d99e 100644 --- a/pkg/scalers/azure_cosmosdb_scaler.go +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -60,18 +60,19 @@ type azureCosmosDBMetadata struct { // cosmosDBClient provides low-level access to Cosmos DB via the REST API // for querying lease documents and reading the change feed. 
type cosmosDBClient struct { - httpClient *http.Client - dataEndpoint string - dataKey string - leaseEndpoint string - leaseKey string - leaseDatabaseID string - leaseContainerID string - databaseID string - containerID string - processorName string - credential azcore.TokenCredential - logger logr.Logger + httpClient *http.Client + dataEndpoint string + dataKey string + leaseEndpoint string + leaseKey string + leaseDatabaseID string + leaseContainerID string + databaseID string + containerID string + processorName string + cosmosDBResourceURL string + credential azcore.TokenCredential + logger logr.Logger } type leaseDocument struct { @@ -101,7 +102,7 @@ func NewAzureCosmosDBScaler(config *scalersconfig.ScalerConfig) (Scaler, error) return nil, fmt.Errorf("error parsing azure cosmos db metadata: %w", err) } - cosmosClient, err := newCosmosDBClient(meta, config.PodIdentity, logger, config.GlobalHTTPTimeout) + cosmosClient, err := newCosmosDBClient(meta, config.TriggerMetadata, config.PodIdentity, logger, config.GlobalHTTPTimeout) if err != nil { return nil, fmt.Errorf("error creating cosmos db client: %w", err) } @@ -148,7 +149,7 @@ func parseAzureCosmosDBMetadata(config *scalersconfig.ScalerConfig) (*azureCosmo return meta, nil } -func newCosmosDBClient(meta *azureCosmosDBMetadata, podIdentity kedav1alpha1.AuthPodIdentity, logger logr.Logger, httpTimeout time.Duration) (*cosmosDBClient, error) { +func newCosmosDBClient(meta *azureCosmosDBMetadata, triggerMetadata map[string]string, podIdentity kedav1alpha1.AuthPodIdentity, logger logr.Logger, httpTimeout time.Duration) (*cosmosDBClient, error) { if httpTimeout == 0 { httpTimeout = 30 * time.Second } @@ -195,6 +196,15 @@ func newCosmosDBClient(meta *azureCosmosDBMetadata, podIdentity kedav1alpha1.Aut // Set up workload identity credential for bearer token auth if podIdentity.Provider == kedav1alpha1.PodIdentityProviderAzureWorkload && client.dataKey == "" { + cosmosDBResourceURLProvider := func(env azure.AzEnvironment) (string, error) { + return env.ResourceIdentifiers.CosmosDB, nil + } + cosmosDBResourceURL, err := azure.ParseEnvironmentProperty(triggerMetadata, "cosmosDBResourceURL", cosmosDBResourceURLProvider) + if err != nil { + return nil, fmt.Errorf("error resolving cosmos db resource URL: %w", err) + } + client.cosmosDBResourceURL = cosmosDBResourceURL + cred, err := azure.NewChainedCredential(logger, podIdentity) if err != nil { return nil, fmt.Errorf("error creating azure credential for workload identity: %w", err) @@ -238,7 +248,7 @@ func (c *cosmosDBClient) setAuthHeader(req *http.Request, verb, resourceType, re if c.credential != nil { tk, err := c.credential.GetToken(req.Context(), policy.TokenRequestOptions{ - Scopes: []string{azure.PublicCloud.ResourceIdentifiers.CosmosDB + "/.default"}, + Scopes: []string{c.cosmosDBResourceURL + "/.default"}, }) if err != nil { return fmt.Errorf("error acquiring bearer token: %w", err) From 1df6bac3522e092a38512e2f7bac81ab2bc641e4 Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Tue, 28 Apr 2026 11:29:07 -0700 Subject: [PATCH 11/12] refactor: move lease defaults to Validate() on metadata struct Follow the convention used by azureBlobMetadata.Validate() and azureServiceBusMetadata.Validate() cross-field normalization lives in Validate() which is called automatically by TypedConfig via the CustomValidator interface. 
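As an illustrative sketch (hypothetical database and container names, not a tested manifest), a trigger that relies on this defaulting can omit all lease-side connection settings:

    triggers:
      - type: azure-cosmosdb
        metadata:
          databaseId: orders-db
          containerId: orders
          leaseDatabaseId: orders-db
          leaseContainerId: leases
          processorName: orders-processor
          connectionFromEnv: COSMOS_CONNECTION
          # leaseConnection, leaseEndpoint and leaseCosmosDBKey omitted:
          # Validate() copies the data-side values before the client is built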
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- pkg/scalers/azure_cosmosdb_scaler.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go index c878d93d99e..f53fbcf62ed 100644 --- a/pkg/scalers/azure_cosmosdb_scaler.go +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -57,6 +57,20 @@ type azureCosmosDBMetadata struct { TriggerIndex int } +func (m *azureCosmosDBMetadata) Validate() error { + // Default lease settings to data settings if not specified + if m.LeaseConnection == "" { + m.LeaseConnection = m.Connection + } + if m.LeaseEndpoint == "" { + m.LeaseEndpoint = m.Endpoint + } + if m.LeaseCosmosDBKey == "" { + m.LeaseCosmosDBKey = m.CosmosDBKey + } + return nil +} + // cosmosDBClient provides low-level access to Cosmos DB via the REST API // for querying lease documents and reading the change feed. type cosmosDBClient struct { @@ -134,17 +148,6 @@ func parseAzureCosmosDBMetadata(config *scalersconfig.ScalerConfig) (*azureCosmo return nil, fmt.Errorf("pod identity %s not supported for azure cosmos db", config.PodIdentity.Provider) } - // Default lease settings to data settings if not specified - if meta.LeaseConnection == "" { - meta.LeaseConnection = meta.Connection - } - if meta.LeaseEndpoint == "" { - meta.LeaseEndpoint = meta.Endpoint - } - if meta.LeaseCosmosDBKey == "" { - meta.LeaseCosmosDBKey = meta.CosmosDBKey - } - meta.TriggerIndex = config.TriggerIndex return meta, nil } From ecccf9f8fb0e52ae2b3e0aece5a5cf718a49cc67 Mon Sep 17 00:00:00 2001 From: Yash Trivedi Date: Tue, 28 Apr 2026 14:19:28 -0700 Subject: [PATCH 12/12] fix: address PR review - json.Marshal for processorName, return 0 on 410, propagate errors - Use json.Marshal for processorName in query body to safely handle special characters (e.g. 
quotes) in processor names - Return 0 instead of -1 for partition lag on 410 Gone (split/merge) - Remove custom error fallback logic in GetMetricsAndActivity; just propagate errors to let KEDA's fallback spec on ScaledObject handle it, consistent with EventHub, ServiceBus, and Blob scalers - Remove lastPartitionCount field that was only used for the removed fallback logic Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Yash Trivedi --- pkg/scalers/azure_cosmosdb_scaler.go | 35 ++++++----------- pkg/scalers/azure_cosmosdb_scaler_test.go | 47 ++--------------------- 2 files changed, 14 insertions(+), 68 deletions(-) diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go index f53fbcf62ed..d440b15758c 100644 --- a/pkg/scalers/azure_cosmosdb_scaler.go +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -33,11 +33,10 @@ const ( ) type azureCosmosDBScaler struct { - metricType v2.MetricTargetType - metadata *azureCosmosDBMetadata - cosmosClient *cosmosDBClient - logger logr.Logger - lastPartitionCount int64 + metricType v2.MetricTargetType + metadata *azureCosmosDBMetadata + cosmosClient *cosmosDBClient + logger logr.Logger } type azureCosmosDBMetadata struct { @@ -289,7 +288,11 @@ func (c *cosmosDBClient) queryLeases(ctx context.Context) ([]leaseDocument, erro resourceLink := fmt.Sprintf("dbs/%s/colls/%s", c.leaseDatabaseID, c.leaseContainerID) reqURL := fmt.Sprintf("%s/%s/docs", strings.TrimRight(c.leaseEndpoint, "/"), resourceLink) - body := fmt.Sprintf(`{"query":"SELECT * FROM c WHERE STARTSWITH(c.id, @prefix)","parameters":[{"name":"@prefix","value":"%s"}]}`, c.processorName) + prefixJSON, err := json.Marshal(c.processorName) + if err != nil { + return nil, fmt.Errorf("error marshaling processor name: %w", err) + } + body := fmt.Sprintf(`{"query":"SELECT * FROM c WHERE STARTSWITH(c.id, @prefix)","parameters":[{"name":"@prefix","value":%s}]}`, string(prefixJSON)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(body)) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) @@ -467,7 +470,7 @@ func (c *cosmosDBClient) estimatePartitionLag(ctx context.Context, lease leaseDo // 410 Gone indicates partition split or merge if cfResp.StatusCode == http.StatusGone { - return -1, true, nil + return 0, true, nil } // 304 Not Modified or empty results means processor is caught up @@ -565,26 +568,10 @@ func getChangeFeedTotalLagRelatedToPartitionAmount(totalLag int64, partitionCoun } // GetMetricsAndActivity returns the metric value and activity status. -// On error, returns the maximum possible metric (partitions * threshold) to scale -// to max replicas, ensuring the system is not under-provisioned during failures. 
func (s *azureCosmosDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { totalLag, partitionCount, err := s.cosmosClient.estimateLag(ctx) if err != nil { - if s.lastPartitionCount > 0 { - maxLag := s.lastPartitionCount * s.metadata.Threshold - s.logger.Error(err, fmt.Sprintf("error getting cosmos db change feed lag, using cached partition count (%d) for fallback, reporting lag=%d", - s.lastPartitionCount, maxLag)) - metric := GenerateMetricInMili(metricName, float64(maxLag)) - return []external_metrics.ExternalMetricValue{metric}, true, nil - } - // No cached partition count — propagate error to KEDA (standard behavior) - s.logger.Error(err, "error getting cosmos db change feed lag, no cached partition count available") - return []external_metrics.ExternalMetricValue{}, false, err - } - - // Cache partition count for error fallback - if partitionCount > 0 { - s.lastPartitionCount = partitionCount + return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("error getting cosmos db change feed lag: %w", err) } // Don't scale out beyond the number of partitions diff --git a/pkg/scalers/azure_cosmosdb_scaler_test.go b/pkg/scalers/azure_cosmosdb_scaler_test.go index e3bf6b9c306..0a5df65f77d 100644 --- a/pkg/scalers/azure_cosmosdb_scaler_test.go +++ b/pkg/scalers/azure_cosmosdb_scaler_test.go @@ -1091,53 +1091,12 @@ func TestCosmosDBGetMetricsAndActivityOnError(t *testing.T) { containerID: "testcontainer", leaseDatabaseID: "testdb", leaseContainerID: "leases", - processorName: "testprocessor", + processorName: "testprocessor", }, - logger: logr.Discard(), - lastPartitionCount: 4, // Simulate cached partition count from previous successful poll - } - - // On error, should return max lag (4 * 100 = 400) and active=true - metrics, isActive, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") - assert.NoError(t, err) // No error returned — we handle it internally - assert.True(t, isActive) - assert.Len(t, metrics, 1) -} - -func TestCosmosDBGetMetricsAndActivityOnErrorNoCachedPartitions(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) - defer server.Close() - - scaler := &azureCosmosDBScaler{ - metricType: v2.AverageValueMetricType, - metadata: &azureCosmosDBMetadata{ - DatabaseID: "testdb", - ContainerID: "testcontainer", - LeaseDatabaseID: "testdb", - LeaseContainerID: "leases", - ProcessorName: "testprocessor", - Threshold: 100, - ActivationThreshold: 0, - }, - cosmosClient: &cosmosDBClient{ - httpClient: &http.Client{}, - dataEndpoint: server.URL, - dataKey: "dGVzdGtleQ==", - leaseEndpoint: server.URL, - leaseKey: "dGVzdGtleQ==", - databaseID: "testdb", - containerID: "testcontainer", - leaseDatabaseID: "testdb", - leaseContainerID: "leases", - processorName: "testprocessor", - }, - logger: logr.Discard(), - lastPartitionCount: 0, // No cached partition count + logger: logr.Discard(), } - // On error with no cached partitions, should return error + // On error, propagate to KEDA and let fallback spec on ScaledObject handle it _, _, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") assert.Error(t, err) }
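Note for reviewers: a minimal standalone sketch (not part of the patch) of
why the lease query body is now built with json.Marshal rather than plain
string interpolation — a quote in the processor name would otherwise
produce an invalid JSON body:

// Sketch only — demonstrates the escaping rationale for the queryLeases change.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// A processor name containing a double quote breaks naive interpolation.
	processorName := `my"processor`

	// Naive interpolation: the embedded quote terminates the JSON string early.
	naive := fmt.Sprintf(`{"parameters":[{"name":"@prefix","value":"%s"}]}`, processorName)
	fmt.Println("naive body is valid JSON:", json.Valid([]byte(naive))) // false

	// json.Marshal escapes the quote, so the body stays well-formed.
	escaped, _ := json.Marshal(processorName)
	safe := fmt.Sprintf(`{"parameters":[{"name":"@prefix","value":%s}]}`, string(escaped))
	fmt.Println("escaped body is valid JSON:", json.Valid([]byte(safe))) // true
}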