diff --git a/CHANGELOG.md b/CHANGELOG.md index b39dffd595e..647d7ac793c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,6 +68,7 @@ To learn more about active deprecations, we recommend checking [GitHub Discussio ### New +- **General**: Introduce new Azure Cosmos DB Change Feed Scaler ([#7556](https://github.com/kedacore/keda/issues/7556)) - TODO ([#XXX](https://github.com/kedacore/keda/issues/XXX)) #### Experimental diff --git a/pkg/scalers/azure_cosmosdb_scaler.go b/pkg/scalers/azure_cosmosdb_scaler.go new file mode 100644 index 00000000000..d440b15758c --- /dev/null +++ b/pkg/scalers/azure_cosmosdb_scaler.go @@ -0,0 +1,593 @@ +package scalers + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/go-logr/logr" + v2 "k8s.io/api/autoscaling/v2" + "k8s.io/metrics/pkg/apis/external_metrics" + + kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1" + "github.com/kedacore/keda/v2/pkg/scalers/azure" + "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" + kedautil "github.com/kedacore/keda/v2/pkg/util" +) + +const ( + cosmosDBMetricType = "External" + cosmosDBRestAPIVersion = "2018-12-31" +) + +type azureCosmosDBScaler struct { + metricType v2.MetricTargetType + metadata *azureCosmosDBMetadata + cosmosClient *cosmosDBClient + logger logr.Logger +} + +type azureCosmosDBMetadata struct { + DatabaseID string `keda:"name=databaseId, order=triggerMetadata"` + ContainerID string `keda:"name=containerId, order=triggerMetadata"` + LeaseDatabaseID string `keda:"name=leaseDatabaseId, order=triggerMetadata"` + LeaseContainerID string `keda:"name=leaseContainerId, order=triggerMetadata"` + ProcessorName string `keda:"name=processorName, order=triggerMetadata"` + Endpoint string `keda:"name=endpoint, order=authParams;triggerMetadata, optional"` + Connection string `keda:"name=connection, order=authParams;resolvedEnv;triggerMetadata, optional"` + LeaseEndpoint string `keda:"name=leaseEndpoint, order=authParams;triggerMetadata, optional"` + LeaseConnection string `keda:"name=leaseConnection, order=authParams;resolvedEnv;triggerMetadata, optional"` + CosmosDBKey string `keda:"name=cosmosDBKey, order=authParams;resolvedEnv, optional"` + LeaseCosmosDBKey string `keda:"name=leaseCosmosDBKey, order=authParams;resolvedEnv, optional"` + Threshold int64 `keda:"name=changeFeedLagThreshold, order=triggerMetadata, default=100"` + ActivationThreshold int64 `keda:"name=activationChangeFeedLagThreshold, order=triggerMetadata, default=0"` + TriggerIndex int +} + +func (m *azureCosmosDBMetadata) Validate() error { + // Default lease settings to data settings if not specified + if m.LeaseConnection == "" { + m.LeaseConnection = m.Connection + } + if m.LeaseEndpoint == "" { + m.LeaseEndpoint = m.Endpoint + } + if m.LeaseCosmosDBKey == "" { + m.LeaseCosmosDBKey = m.CosmosDBKey + } + return nil +} + +// cosmosDBClient provides low-level access to Cosmos DB via the REST API +// for querying lease documents and reading the change feed. +type cosmosDBClient struct { + httpClient *http.Client + dataEndpoint string + dataKey string + leaseEndpoint string + leaseKey string + leaseDatabaseID string + leaseContainerID string + databaseID string + containerID string + processorName string + cosmosDBResourceURL string + credential azcore.TokenCredential + logger logr.Logger +} + +type leaseDocument struct { + ID string `json:"id"` + LeaseToken string `json:"LeaseToken"` + ContinuationToken string `json:"ContinuationToken"` + Owner string `json:"Owner,omitempty"` +} + +type changeFeedResponse struct { + StatusCode int + Items []json.RawMessage + SessionToken string +} + +// NewAzureCosmosDBScaler creates a new Azure Cosmos DB change feed scaler. +func NewAzureCosmosDBScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { + metricType, err := GetMetricTargetType(config) + if err != nil { + return nil, fmt.Errorf("error getting scaler metric type: %w", err) + } + + logger := InitializeLogger(config, "azure_cosmosdb_scaler") + + meta, err := parseAzureCosmosDBMetadata(config) + if err != nil { + return nil, fmt.Errorf("error parsing azure cosmos db metadata: %w", err) + } + + cosmosClient, err := newCosmosDBClient(meta, config.TriggerMetadata, config.PodIdentity, logger, config.GlobalHTTPTimeout) + if err != nil { + return nil, fmt.Errorf("error creating cosmos db client: %w", err) + } + + return &azureCosmosDBScaler{ + metricType: metricType, + metadata: meta, + cosmosClient: cosmosClient, + logger: logger, + }, nil +} + +func parseAzureCosmosDBMetadata(config *scalersconfig.ScalerConfig) (*azureCosmosDBMetadata, error) { + meta := &azureCosmosDBMetadata{} + if err := config.TypedConfig(meta); err != nil { + return nil, fmt.Errorf("error parsing metadata: %w", err) + } + + switch config.PodIdentity.Provider { + case "", kedav1alpha1.PodIdentityProviderNone: + if meta.Connection == "" && (meta.Endpoint == "" || meta.CosmosDBKey == "") { + return nil, fmt.Errorf("connection string or endpoint+cosmosDBKey is required when not using pod identity") + } + case kedav1alpha1.PodIdentityProviderAzureWorkload: + if meta.Endpoint == "" && meta.Connection == "" { + return nil, fmt.Errorf("endpoint or connection string is required when using workload identity") + } + default: + return nil, fmt.Errorf("pod identity %s not supported for azure cosmos db", config.PodIdentity.Provider) + } + + meta.TriggerIndex = config.TriggerIndex + return meta, nil +} + +func newCosmosDBClient(meta *azureCosmosDBMetadata, triggerMetadata map[string]string, podIdentity kedav1alpha1.AuthPodIdentity, logger logr.Logger, httpTimeout time.Duration) (*cosmosDBClient, error) { + if httpTimeout == 0 { + httpTimeout = 30 * time.Second + } + + client := &cosmosDBClient{ + httpClient: kedautil.CreateHTTPClient(httpTimeout, false), + leaseDatabaseID: meta.LeaseDatabaseID, + leaseContainerID: meta.LeaseContainerID, + databaseID: meta.DatabaseID, + containerID: meta.ContainerID, + processorName: meta.ProcessorName, + logger: logger, + } + + // Resolve data endpoint and key + if meta.Connection != "" { + endpoint, key, err := parseCosmosDBConnectionString(meta.Connection) + if err != nil { + return nil, fmt.Errorf("error parsing connection string: %w", err) + } + client.dataEndpoint = endpoint + client.dataKey = key + } else if meta.Endpoint != "" { + client.dataEndpoint = meta.Endpoint + client.dataKey = meta.CosmosDBKey + } + + // Resolve lease endpoint and key + if meta.LeaseConnection != "" { + endpoint, key, err := parseCosmosDBConnectionString(meta.LeaseConnection) + if err != nil { + return nil, fmt.Errorf("error parsing lease connection string: %w", err) + } + client.leaseEndpoint = endpoint + client.leaseKey = key + } else if meta.LeaseEndpoint != "" { + client.leaseEndpoint = meta.LeaseEndpoint + client.leaseKey = meta.LeaseCosmosDBKey + } + + if client.dataEndpoint == "" || client.leaseEndpoint == "" { + return nil, fmt.Errorf("failed to determine cosmos db endpoints") + } + + // Set up workload identity credential for bearer token auth + if podIdentity.Provider == kedav1alpha1.PodIdentityProviderAzureWorkload && client.dataKey == "" { + cosmosDBResourceURLProvider := func(env azure.AzEnvironment) (string, error) { + return env.ResourceIdentifiers.CosmosDB, nil + } + cosmosDBResourceURL, err := azure.ParseEnvironmentProperty(triggerMetadata, "cosmosDBResourceURL", cosmosDBResourceURLProvider) + if err != nil { + return nil, fmt.Errorf("error resolving cosmos db resource URL: %w", err) + } + client.cosmosDBResourceURL = cosmosDBResourceURL + + cred, err := azure.NewChainedCredential(logger, podIdentity) + if err != nil { + return nil, fmt.Errorf("error creating azure credential for workload identity: %w", err) + } + client.credential = cred + } + + return client, nil +} + +func parseCosmosDBConnectionString(connectionString string) (string, string, error) { + parts := strings.Split(connectionString, ";") + var endpoint, key string + + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "AccountEndpoint=") { + endpoint = strings.TrimPrefix(part, "AccountEndpoint=") + } else if strings.HasPrefix(part, "AccountKey=") { + key = strings.TrimPrefix(part, "AccountKey=") + } + } + + if endpoint == "" || key == "" { + return "", "", fmt.Errorf("invalid connection string: missing AccountEndpoint or AccountKey") + } + + return endpoint, key, nil +} + +// setAuthHeader sets the Authorization header using either master key HMAC-SHA256 or bearer token. +func (c *cosmosDBClient) setAuthHeader(req *http.Request, verb, resourceType, resourceLink, date, key string) error { + if key != "" { + token, err := generateCosmosDBAuthToken(verb, resourceType, resourceLink, date, key) + if err != nil { + return fmt.Errorf("error generating auth token: %w", err) + } + req.Header.Set("Authorization", token) + return nil + } + + if c.credential != nil { + tk, err := c.credential.GetToken(req.Context(), policy.TokenRequestOptions{ + Scopes: []string{c.cosmosDBResourceURL + "/.default"}, + }) + if err != nil { + return fmt.Errorf("error acquiring bearer token: %w", err) + } + req.Header.Set("Authorization", url.QueryEscape(fmt.Sprintf("type=aad&ver=1.0&sig=%s", tk.Token))) + return nil + } + + return fmt.Errorf("no authentication method available: provide a key or configure workload identity") +} + +// generateCosmosDBAuthToken generates an HMAC-SHA256 auth token for Cosmos DB REST API. +// Format: type=master&ver=1.0&sig={hashsignature} +// Signature input: {verb}\n{resourceType}\n{resourceLink}\n{date}\n\n +func generateCosmosDBAuthToken(verb, resourceType, resourceLink, date, key string) (string, error) { + keyBytes, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return "", fmt.Errorf("error decoding cosmos db key: %w", err) + } + + text := fmt.Sprintf("%s\n%s\n%s\n%s\n\n", + strings.ToLower(verb), + strings.ToLower(resourceType), + resourceLink, + strings.ToLower(date)) + + h := hmac.New(sha256.New, keyBytes) + h.Write([]byte(text)) + signature := base64.StdEncoding.EncodeToString(h.Sum(nil)) + + return url.QueryEscape(fmt.Sprintf("type=master&ver=1.0&sig=%s", signature)), nil +} + +func (c *cosmosDBClient) queryLeases(ctx context.Context) ([]leaseDocument, error) { + resourceLink := fmt.Sprintf("dbs/%s/colls/%s", c.leaseDatabaseID, c.leaseContainerID) + reqURL := fmt.Sprintf("%s/%s/docs", strings.TrimRight(c.leaseEndpoint, "/"), resourceLink) + + prefixJSON, err := json.Marshal(c.processorName) + if err != nil { + return nil, fmt.Errorf("error marshaling processor name: %w", err) + } + body := fmt.Sprintf(`{"query":"SELECT * FROM c WHERE STARTSWITH(c.id, @prefix)","parameters":[{"name":"@prefix","value":%s}]}`, string(prefixJSON)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + now := time.Now().UTC().Format(http.TimeFormat) + req.Header.Set("x-ms-date", now) + req.Header.Set("x-ms-version", cosmosDBRestAPIVersion) + req.Header.Set("Content-Type", "application/query+json") + req.Header.Set("x-ms-documentdb-isquery", "true") + req.Header.Set("x-ms-documentdb-query-enablecrosspartition", "true") + + if err := c.setAuthHeader(req, http.MethodPost, "docs", resourceLink, now, c.leaseKey); err != nil { + return nil, fmt.Errorf("error setting auth header: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error executing request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("query failed with status %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Documents []json.RawMessage `json:"Documents"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error decoding response: %w", err) + } + + // Parse and filter out metadata documents (those without LeaseToken or ContinuationToken) + var leases []leaseDocument + for _, raw := range result.Documents { + var doc leaseDocument + if err := json.Unmarshal(raw, &doc); err != nil { + continue + } + if doc.LeaseToken != "" && doc.ContinuationToken != "" { + leases = append(leases, doc) + } + } + + return leases, nil +} + +func (c *cosmosDBClient) readChangeFeed(ctx context.Context, partitionKeyRangeID, continuationToken string) (*changeFeedResponse, error) { + resourceLink := fmt.Sprintf("dbs/%s/colls/%s", c.databaseID, c.containerID) + reqURL := fmt.Sprintf("%s/%s/docs", strings.TrimRight(c.dataEndpoint, "/"), resourceLink) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + now := time.Now().UTC().Format(http.TimeFormat) + req.Header.Set("x-ms-date", now) + req.Header.Set("x-ms-version", cosmosDBRestAPIVersion) + req.Header.Set("x-ms-documentdb-partitionkeyrangeid", partitionKeyRangeID) + req.Header.Set("A-IM", "Incremental feed") + req.Header.Set("x-ms-max-item-count", "1") + + if continuationToken != "" { + req.Header.Set("If-None-Match", continuationToken) + } + + if err := c.setAuthHeader(req, http.MethodGet, "docs", resourceLink, now, c.dataKey); err != nil { + return nil, fmt.Errorf("error setting auth header: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error executing request: %w", err) + } + defer resp.Body.Close() + + cfResp := &changeFeedResponse{ + StatusCode: resp.StatusCode, + SessionToken: resp.Header.Get("x-ms-session-token"), + } + + if resp.StatusCode == http.StatusNotModified || resp.StatusCode == http.StatusGone { + return cfResp, nil + } + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("read change feed failed with status %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Documents []json.RawMessage `json:"Documents"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error decoding response: %w", err) + } + + cfResp.Items = result.Documents + return cfResp, nil +} + +// estimateLag estimates the total change feed lag across all partitions and +// returns both the total lag and partition count. +// If a partition split (410 Gone) is detected, it retries once to get fresh lease data. +func (c *cosmosDBClient) estimateLag(ctx context.Context) (totalLag int64, partitionCount int64, err error) { + totalLag, partitionCount, splitDetected, err := c.estimateOnce(ctx) + if err != nil { + return 0, 0, err + } + if splitDetected { + c.logger.Info("Warning: partition split detected, re-reading leases") + totalLag, partitionCount, _, err = c.estimateOnce(ctx) + if err != nil { + return 0, 0, err + } + } + return totalLag, partitionCount, nil +} + +func (c *cosmosDBClient) estimateOnce(ctx context.Context) (int64, int64, bool, error) { + leases, err := c.queryLeases(ctx) + if err != nil { + return 0, 0, false, fmt.Errorf("error querying leases: %w", err) + } + + if len(leases) == 0 { + c.logger.V(1).Info("no lease documents found in lease container") + return 0, 0, false, nil + } + + c.logger.V(1).Info(fmt.Sprintf("found %d lease documents", len(leases))) + + totalLag := int64(0) + splitDetected := false + + for _, lease := range leases { + lag, isSplit, err := c.estimatePartitionLag(ctx, lease) + if err != nil { + return 0, 0, false, fmt.Errorf("error estimating lag for partition %s: %w", lease.LeaseToken, err) + } + if isSplit { + c.logger.Info(fmt.Sprintf("Warning: partition %s returned 410 Gone (split/merge detected)", lease.LeaseToken)) + splitDetected = true + continue + } + c.logger.V(1).Info(fmt.Sprintf("partition %s: estimated lag = %d, owner = %s", lease.LeaseToken, lag, lease.Owner)) + if lag > 0 { + totalLag += lag + } + } + + // Cap to prevent int64 overflow from summing across many partitions + if totalLag < 0 { + totalLag = math.MaxInt64 + } + + return totalLag, int64(len(leases)), splitDetected, nil +} + +// estimatePartitionLag calculates the lag for a single partition. +// Algorithm (matching .NET/Java SDKs): +// 1. Read change feed with maxItemCount=1 starting from the lease's continuation token +// 2. Extract latest LSN from session token +// 3. If items present: lag = sessionLSN - firstItem._lsn + 1 +// 4. If no items (304): lag = 0 (caught up) +// 5. If 410 Gone: report lag = -1 (split/merge) +func (c *cosmosDBClient) estimatePartitionLag(ctx context.Context, lease leaseDocument) (int64, bool, error) { + cfResp, err := c.readChangeFeed(ctx, lease.LeaseToken, lease.ContinuationToken) + if err != nil { + return 0, false, fmt.Errorf("error reading change feed for partition %s: %w", lease.LeaseToken, err) + } + + // 410 Gone indicates partition split or merge + if cfResp.StatusCode == http.StatusGone { + return 0, true, nil + } + + // 304 Not Modified or empty results means processor is caught up + if cfResp.StatusCode == http.StatusNotModified || len(cfResp.Items) == 0 { + return 0, false, nil + } + + // Calculate lag: sessionLSN - firstItemLSN + 1 + sessionLSN, err := parseLSNFromSessionToken(cfResp.SessionToken) + if err != nil || sessionLSN < 0 { + c.logger.V(1).Info(fmt.Sprintf("partition %s: could not parse session token LSN (token: %s), assuming no lag", lease.LeaseToken, cfResp.SessionToken)) + return 0, false, nil + } + + firstItemLSN, err := extractItemLSN(cfResp.Items[0]) + if err != nil || firstItemLSN < 0 { + c.logger.V(1).Info(fmt.Sprintf("partition %s: could not extract _lsn from first item, assuming no lag", lease.LeaseToken)) + return 0, false, nil + } + + lag := sessionLSN - firstItemLSN + 1 + if lag < 0 { + c.logger.V(1).Info(fmt.Sprintf("partition %s: negative lag (sessionLSN=%d, firstItemLSN=%d), assuming no lag", lease.LeaseToken, sessionLSN, firstItemLSN)) + return 0, false, nil + } + + return lag, false, nil +} + +// extractLSNFromSessionToken extracts the LSN from a Cosmos DB session token. +// Session token formats: +// - Simple: "{pkRangeId}:{lsn}" +// - Compound: "{pkRangeId}:{localLsn}#{globalLsn}" +// +// This matches the logic in both the .NET SDK (ChangeFeedEstimatorIterator.ExtractLsnFromSessionToken) +// and Java SDK (IncrementalChangeFeedProcessorImpl). +func extractLSNFromSessionToken(sessionToken string) string { + if sessionToken == "" { + return "" + } + + colonIdx := strings.IndexByte(sessionToken, ':') + if colonIdx < 0 { + return sessionToken + } + parsed := sessionToken[colonIdx+1:] + + segments := strings.Split(parsed, "#") + if len(segments) >= 2 { + return segments[1] // Global LSN + } + return segments[0] +} + +func parseLSNFromSessionToken(sessionToken string) (int64, error) { + lsnStr := extractLSNFromSessionToken(sessionToken) + if lsnStr == "" { + return -1, fmt.Errorf("empty session token") + } + return strconv.ParseInt(lsnStr, 10, 64) +} + +// extractItemLSN extracts the _lsn value from a Cosmos DB change feed document. +func extractItemLSN(item json.RawMessage) (int64, error) { + var doc struct { + LSN json.Number `json:"_lsn"` + } + if err := json.Unmarshal(item, &doc); err != nil { + return -1, fmt.Errorf("parsing item: %w", err) + } + return doc.LSN.Int64() +} + +// GetMetricSpecForScaling returns the metric spec for scaling. +func (s *azureCosmosDBScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { + metricName := kedautil.NormalizeString(fmt.Sprintf("azure-cosmosdb-%s-%s", + s.metadata.LeaseContainerID, s.metadata.ProcessorName)) + externalMetric := &v2.ExternalMetricSource{ + Metric: v2.MetricIdentifier{ + Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, metricName), + }, + Target: GetMetricTarget(s.metricType, s.metadata.Threshold), + } + metricSpec := v2.MetricSpec{External: externalMetric, Type: cosmosDBMetricType} + return []v2.MetricSpec{metricSpec} +} + +// getChangeFeedTotalLagRelatedToPartitionAmount caps the total lag to prevent scaling beyond +// the number of partitions. This matches the EventHub scaler's approach. +func getChangeFeedTotalLagRelatedToPartitionAmount(totalLag int64, partitionCount int64, threshold int64) int64 { + if threshold > 0 && (totalLag/threshold) > partitionCount { + return partitionCount * threshold + } + return totalLag +} + +// GetMetricsAndActivity returns the metric value and activity status. +func (s *azureCosmosDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { + totalLag, partitionCount, err := s.cosmosClient.estimateLag(ctx) + if err != nil { + return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("error getting cosmos db change feed lag: %w", err) + } + + // Don't scale out beyond the number of partitions + lagRelatedToPartitionCount := getChangeFeedTotalLagRelatedToPartitionAmount(totalLag, partitionCount, s.metadata.Threshold) + + s.logger.V(1).Info(fmt.Sprintf("Cosmos DB change feed total lag: %d, scaling for a lag of %d related to %d partitions", + totalLag, lagRelatedToPartitionCount, partitionCount)) + + metric := GenerateMetricInMili(metricName, float64(lagRelatedToPartitionCount)) + return []external_metrics.ExternalMetricValue{metric}, totalLag > s.metadata.ActivationThreshold, nil +} + +// Close cleans up the scaler resources. +func (s *azureCosmosDBScaler) Close(context.Context) error { + if s.cosmosClient != nil && s.cosmosClient.httpClient != nil { + s.cosmosClient.httpClient.CloseIdleConnections() + } + return nil +} diff --git a/pkg/scalers/azure_cosmosdb_scaler_test.go b/pkg/scalers/azure_cosmosdb_scaler_test.go new file mode 100644 index 00000000000..0a5df65f77d --- /dev/null +++ b/pkg/scalers/azure_cosmosdb_scaler_test.go @@ -0,0 +1,1130 @@ +package scalers + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + v2 "k8s.io/api/autoscaling/v2" + + kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1" + "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" +) + +var testCosmosDBResolvedEnv = map[string]string{ + "COSMOS_CONNECTION": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=dGVzdGtleQ==", +} + +type parseCosmosDBMetadataTestData struct { + name string + metadata map[string]string + isError bool + resolvedEnv map[string]string + authParams map[string]string + podIdentity kedav1alpha1.PodIdentityProvider +} + +type cosmosDBMetricIdentifier struct { + name string + metadataTestData *parseCosmosDBMetadataTestData + triggerIndex int + metricName string +} + +var testCosmosDBMetadata = []parseCosmosDBMetadataTestData{ + { + name: "nothing passed", + metadata: map[string]string{}, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "properly formed with connection string", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: false, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing database id", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing container id", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing lease database id", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing lease container id", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing processor name", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "missing connection and key", + metadata: map[string]string{ + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: map[string]string{}, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "connection from authParams", + metadata: map[string]string{ + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: false, + resolvedEnv: map[string]string{}, + authParams: map[string]string{ + "connection": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=dGVzdGtleQ==", + }, + podIdentity: kedav1alpha1.PodIdentityProviderNone, + }, + { + name: "endpoint with key", + metadata: map[string]string{ + "endpoint": "https://test.documents.azure.com:443/", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: false, + resolvedEnv: map[string]string{}, + authParams: map[string]string{ + "cosmosDBKey": "dGVzdGtleQ==", + }, + podIdentity: "", + }, + { + name: "podIdentity azure-workload with endpoint", + metadata: map[string]string{ + "endpoint": "https://test.documents.azure.com:443/", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: false, + resolvedEnv: map[string]string{}, + authParams: map[string]string{ + "cosmosDBKey": "dGVzdGtleQ==", + }, + podIdentity: kedav1alpha1.PodIdentityProviderAzureWorkload, + }, + { + name: "podIdentity azure-workload without endpoint or connection", + metadata: map[string]string{ + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: true, + resolvedEnv: map[string]string{}, + authParams: map[string]string{}, + podIdentity: kedav1alpha1.PodIdentityProviderAzureWorkload, + }, + { + name: "separate lease connection", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "leaseConnectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + }, + isError: false, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "invalid changeFeedLagThreshold", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + "changeFeedLagThreshold": "invalid", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, + { + name: "invalid activationChangeFeedLagThreshold", + metadata: map[string]string{ + "connectionFromEnv": "COSMOS_CONNECTION", + "databaseId": "testdb", + "containerId": "testcontainer", + "leaseDatabaseId": "testdb", + "leaseContainerId": "leases", + "processorName": "testprocessor", + "activationChangeFeedLagThreshold": "invalid", + }, + isError: true, + resolvedEnv: testCosmosDBResolvedEnv, + authParams: map[string]string{}, + podIdentity: "", + }, +} + +var cosmosDBMetricIdentifiers = []cosmosDBMetricIdentifier{ + { + name: "properly formed metric", + metadataTestData: &testCosmosDBMetadata[1], + triggerIndex: 0, + metricName: "s0-azure-cosmosdb-leases-testprocessor", + }, + { + name: "endpoint with key metric", + metadataTestData: &testCosmosDBMetadata[9], + triggerIndex: 1, + metricName: "s1-azure-cosmosdb-leases-testprocessor", + }, +} + +func TestCosmosDBParseMetadata(t *testing.T) { + for _, testData := range testCosmosDBMetadata { + t.Run(testData.name, func(t *testing.T) { + config := &scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadata, + ResolvedEnv: testData.resolvedEnv, + AuthParams: testData.authParams, + PodIdentity: kedav1alpha1.AuthPodIdentity{Provider: testData.podIdentity}, + } + + _, err := parseAzureCosmosDBMetadata(config) + if err != nil && !testData.isError { + t.Errorf("Expected success but got error: %v", err) + } + if testData.isError && err == nil { + t.Errorf("Expected error but got success. testData: %v", testData) + } + }) + } +} + +func TestCosmosDBGetMetricSpecForScaling(t *testing.T) { + for _, testData := range cosmosDBMetricIdentifiers { + t.Run(testData.name, func(t *testing.T) { + config := &scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadataTestData.metadata, + ResolvedEnv: testData.metadataTestData.resolvedEnv, + AuthParams: testData.metadataTestData.authParams, + PodIdentity: kedav1alpha1.AuthPodIdentity{Provider: testData.metadataTestData.podIdentity}, + TriggerIndex: testData.triggerIndex, + } + + meta, err := parseAzureCosmosDBMetadata(config) + if err != nil { + t.Fatal("Could not parse metadata:", err) + } + + mockScaler := azureCosmosDBScaler{ + metadata: meta, + logger: logr.Discard(), + metricType: v2.AverageValueMetricType, + } + + metricSpec := mockScaler.GetMetricSpecForScaling(context.Background()) + metricName := metricSpec[0].External.Metric.Name + assert.Equal(t, testData.metricName, metricName) + }) + } +} + +func TestCosmosDBConnectionStringParsing(t *testing.T) { + testCases := []struct { + name string + connectionStr string + expectError bool + expectedEndpoint string + }{ + { + name: "valid connection string", + connectionStr: "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=dGVzdGtleQ==", + expectError: false, + expectedEndpoint: "https://test.documents.azure.com:443/", + }, + { + name: "missing endpoint", + connectionStr: "AccountKey=dGVzdGtleQ==", + expectError: true, + }, + { + name: "missing key", + connectionStr: "AccountEndpoint=https://test.documents.azure.com:443/", + expectError: true, + }, + { + name: "empty string", + connectionStr: "", + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + endpoint, key, err := parseCosmosDBConnectionString(tc.connectionStr) + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedEndpoint, endpoint) + assert.NotEmpty(t, key) + } + }) + } +} + +func TestExtractLSNFromSessionToken(t *testing.T) { + testCases := []struct { + name string + token string + expectedLSN string + }{ + { + name: "simple format", + token: "0:123", + expectedLSN: "123", + }, + { + name: "compound format with global LSN", + token: "0:1#100#2", + expectedLSN: "100", + }, + { + name: "two segments", + token: "5:42#999", + expectedLSN: "999", + }, + { + name: "empty token", + token: "", + expectedLSN: "", + }, + { + name: "no colon", + token: "justanumber", + expectedLSN: "justanumber", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + lsn := extractLSNFromSessionToken(tc.token) + assert.Equal(t, tc.expectedLSN, lsn) + }) + } +} + +func TestExtractItemLSN(t *testing.T) { + testCases := []struct { + name string + item string + expectedLSN int64 + expectError bool + }{ + { + name: "numeric LSN", + item: `{"_lsn": 1234}`, + expectedLSN: 1234, + }, + { + name: "string LSN", + item: `{"_lsn": "5678"}`, + expectedLSN: 5678, + }, + { + name: "missing LSN", + item: `{"id": "doc1"}`, + expectedLSN: 0, + expectError: true, + }, + { + name: "invalid JSON", + item: `not json`, + expectedLSN: -1, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + lsn, err := extractItemLSN(json.RawMessage(tc.item)) + if tc.expectError { + assert.True(t, err != nil || lsn <= 0) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedLSN, lsn) + } + }) + } +} + +func TestCosmosDBAuthTokenGeneration(t *testing.T) { + token, err := generateCosmosDBAuthToken("get", "docs", "dbs/testdb/colls/testcol", "thu, 01 jan 2024 00:00:00 gmt", "dGVzdGtleQ==") + assert.NoError(t, err) + assert.Contains(t, token, "type%3Dmaster%26ver%3D1.0%26sig%3D") +} + +func TestCosmosDBAuthTokenGenerationInvalidKey(t *testing.T) { + _, err := generateCosmosDBAuthToken("get", "docs", "dbs/testdb/colls/testcol", "date", "not-valid-base64!!!") + assert.Error(t, err) +} + +func TestCosmosDBLeaseParsingDotNetFormat(t *testing.T) { + // Realistic .NET SDK lease documents have: version=0, FeedRange, Mode, properties fields. + // The scaler must parse LeaseToken and ContinuationToken and ignore the extra fields. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + // Return raw JSON matching actual .NET SDK lease format + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[ + { + "id": "host1.documents.azure.com_abc==_def=..6", + "version": 0, + "_etag": "\"08000b63-0000-0800-0000-69a8c6640000\"", + "LeaseToken": "6", + "FeedRange": {"Range": {"min": "36DB6DB6DB6DB6DB6DB6DB6DB6DB6DB6", "max": "FF"}}, + "Owner": "dotnet-host1", + "ContinuationToken": "\"511\"", + "properties": {}, + "timestamp": "2026-03-04T23:55:16.5233511Z", + "Mode": "Incremental Feed", + "_rid": "abc123", + "_self": "dbs/abc/colls/def/docs/ghi", + "_ts": 1772668516 + }, + { + "id": "host1.documents.azure.com_abc==_def=..3", + "version": 0, + "LeaseToken": "3", + "FeedRange": {"Range": {"min": "0", "max": "36DB6DB6DB6DB6DB6DB6DB6DB6DB6DB6"}}, + "Owner": "dotnet-host1", + "ContinuationToken": "\"248\"", + "properties": {}, + "Mode": "Incremental Feed" + }, + { + "id": ".metadata.lease", + "version": 0, + "Owner": "", + "properties": {} + } + ]}`)) + case "/dbs/testdb/colls/data/docs": + pkRangeID := r.Header.Get("x-ms-documentdb-partitionkeyrangeid") + switch pkRangeID { + case "6": + // Partition 6 has lag: sessionLSN=600, itemLSN=512, lag=89 + w.Header().Set("x-ms-session-token", "6:0#600") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc1","_lsn":512}]}`)) + case "3": + // Partition 3 is caught up + w.Header().Set("x-ms-session-token", "3:0#248") + w.WriteHeader(http.StatusNotModified) + } + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "data", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + } + + totalLag, _, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + // Only partition 6 has lag; partition 3 is caught up; metadata doc is filtered + assert.Equal(t, int64(89), totalLag) +} + +func TestCosmosDBLeaseParsingJavaFormat(t *testing.T) { + // Realistic Java SDK lease documents: no version field, no FeedRange/Mode/properties. + // The scaler must parse these identically. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[ + { + "id": "myhost.documents.azure.com_changefeed-estimator_data..2", + "_etag": "\"0100baf0-0000-0800-0000-69a8c5560000\"", + "LeaseToken": "2", + "ContinuationToken": "\"248\"", + "timestamp": "2026-03-04T23:50:46.219570110Z", + "Owner": "java-host1", + "_rid": "5jBSAKD6NqgELTEBAAAAAA==", + "_ts": 1772668246 + }, + { + "id": "myhost.documents.azure.com_changefeed-estimator_data..5", + "LeaseToken": "5", + "ContinuationToken": "\"100\"", + "Owner": "java-host2" + }, + { + "id": ".lock", + "_etag": "\"abc\"", + "Owner": "" + } + ]}`)) + case "/dbs/testdb/colls/data/docs": + pkRangeID := r.Header.Get("x-ms-documentdb-partitionkeyrangeid") + switch pkRangeID { + case "2": + // Partition 2 has lag: sessionLSN=400, itemLSN=249, lag=152 + w.Header().Set("x-ms-session-token", "2:0#400") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc1","_lsn":249}]}`)) + case "5": + // Partition 5 also has lag: sessionLSN=200, itemLSN=101, lag=100 + w.Header().Set("x-ms-session-token", "5:0#200") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc2","_lsn":101}]}`)) + } + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "data", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + } + + totalLag, _, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + // Both partitions 2 and 5 have lag; lock doc is filtered out + assert.Equal(t, int64(252), totalLag) +} + +func TestCosmosDBLeaseParsingMixedFormats(t *testing.T) { + // Edge case: lease container might contain docs from both SDKs (e.g. during migration). + // The scaler should handle this gracefully since it only reads common fields. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[ + { + "id": "dotnet-lease", + "version": 0, + "LeaseToken": "0", + "ContinuationToken": "\"500\"", + "Owner": "dotnet-host", + "FeedRange": {"Range": {"min": "0", "max": "80"}}, + "Mode": "Incremental Feed" + }, + { + "id": "java-lease", + "LeaseToken": "1", + "ContinuationToken": "\"300\"", + "Owner": "java-host" + } + ]}`)) + case "/dbs/testdb/colls/data/docs": + // Both partitions have lag + w.Header().Set("x-ms-session-token", "0:0#700") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc1","_lsn":550}]}`)) + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "data", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + } + + totalLag, _, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(302), totalLag) +} + +func TestCosmosDBLeaseParsingEPKBasedDotNet(t *testing.T) { + // .NET SDK EPK-based leases (version=1) use FeedRange with EPK ranges. + // ContinuationToken is still a quoted LSN for incremental feed mode. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[ + { + "id": "host1..epk..0-AA", + "version": 1, + "LeaseToken": "0", + "FeedRange": {"Range": {"min": "", "max": "AA"}}, + "Owner": "dotnet-host1", + "ContinuationToken": "\"750\"", + "Mode": "LatestVersion" + }, + { + "id": "host1..epk..AA-FF", + "version": 1, + "LeaseToken": "1", + "FeedRange": {"Range": {"min": "AA", "max": "FF"}}, + "Owner": "dotnet-host1", + "ContinuationToken": "\"320\"", + "Mode": "LatestVersion" + } + ]}`)) + case "/dbs/testdb/colls/data/docs": + pkRangeID := r.Header.Get("x-ms-documentdb-partitionkeyrangeid") + switch pkRangeID { + case "0": + w.Header().Set("x-ms-session-token", "0:0#900") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc1","_lsn":751}]}`)) + case "1": + w.Header().Set("x-ms-session-token", "1:0#320") + w.WriteHeader(http.StatusNotModified) + } + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "data", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + } + + totalLag, _, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + // Partition 0 has lag (900-751+1=150), partition 1 is caught up + assert.Equal(t, int64(150), totalLag) +} + +func TestCosmosDBLeaseParsingEPKBasedJava(t *testing.T) { + // Java SDK EPK-based leases (version=1) may use Base64-encoded ContinuationTokens. + // The scaler passes ContinuationToken as-is to If-None-Match, and Cosmos DB + // recognizes its own tokens regardless of encoding. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[ + { + "id": "java-epk-lease-0", + "version": 1, + "LeaseToken": "0", + "ContinuationToken": "eyJWIjoiMiIsIlJpZCI6ImFiYz0iLCJDb250aW51YXRpb24iOlt7InRva2VuIjoiXCI1MDBcIiIsInJhbmdlIjp7Im1pbiI6IiIsIm1heCI6IkZGIn19XX0=", + "Owner": "java-host1", + "feedRange": {"min": "", "max": "FF"} + }, + { + "id": "java-epk-lease-1", + "version": 1, + "LeaseToken": "1", + "ContinuationToken": "eyJWIjoiMiIsIlJpZCI6ImRlZj0iLCJDb250aW51YXRpb24iOlt7InRva2VuIjoiXCIyMDBcIiIsInJhbmdlIjp7Im1pbiI6IkZGIiwibWF4IjoiRkZGRiJ9fV19", + "Owner": "java-host2", + "feedRange": {"min": "FF", "max": "FFFF"} + } + ]}`)) + case "/dbs/testdb/colls/data/docs": + // Simulate Cosmos DB accepting Base64 continuation tokens and returning results + w.Header().Set("x-ms-session-token", "0:0#600") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[{"id":"doc1","_lsn":501}]}`)) + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "data", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + } + + totalLag, _, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + // Both partitions have lag; Base64 tokens are passed through to the server + assert.Equal(t, int64(200), totalLag) +} + +func TestCosmosDBLagEstimation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + { + "id": "lease1", + "LeaseToken": "0", + "ContinuationToken": `"1000"`, + "Owner": "testowner", + }, + { + "id": "lease2", + "LeaseToken": "1", + "ContinuationToken": `"2000"`, + "Owner": "testowner", + }, + { + // Metadata doc - should be filtered out + "id": "metadata", + "Owner": "metadata", + }, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + case "/dbs/testdb/colls/testcontainer/docs": + pkRangeID := r.Header.Get("x-ms-documentdb-partitionkeyrangeid") + + switch pkRangeID { + case "0": + // Partition with lag: sessionLSN=1100, itemLSN=1050, lag=51 + w.Header().Set("x-ms-session-token", "0:0#1100") + w.Header().Set("Content-Type", "application/json") + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "item1", "_lsn": 1050}, + }, + } + _ = json.NewEncoder(w).Encode(response) + default: + // Partition without lag (304 Not Modified) + w.Header().Set("x-ms-session-token", "1:0#2000") + w.WriteHeader(http.StatusNotModified) + } + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + } + + totalLag, _, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(51), totalLag) +} + +func TestCosmosDBLagEstimationEmptyLeases(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := map[string]interface{}{ + "Documents": []map[string]interface{}{}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + } + + totalLag, _, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(0), totalLag) +} + +func TestCosmosDBLagEstimationAllPartitionsLagging(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "lease1", "LeaseToken": "0", "ContinuationToken": `"100"`, "Owner": "owner1"}, + {"id": "lease2", "LeaseToken": "1", "ContinuationToken": `"200"`, "Owner": "owner2"}, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + case "/dbs/testdb/colls/testcontainer/docs": + // Both partitions have lag + w.Header().Set("x-ms-session-token", "0:0#500") + w.Header().Set("Content-Type", "application/json") + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "item1", "_lsn": 400}, + }, + } + _ = json.NewEncoder(w).Encode(response) + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + } + + totalLag, _, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(202), totalLag) +} + +func TestCosmosDBLagEstimationPartitionSplit(t *testing.T) { + changeFeedCallCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "lease1", "LeaseToken": "0", "ContinuationToken": `"100"`, "Owner": "owner1"}, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + case "/dbs/testdb/colls/testcontainer/docs": + changeFeedCallCount++ + if changeFeedCallCount <= 1 { + // First call returns 410 Gone (partition split) + w.WriteHeader(http.StatusGone) + } else { + // Retry returns caught up + w.Header().Set("x-ms-session-token", "0:0#100") + w.WriteHeader(http.StatusNotModified) + } + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + } + + totalLag, _, err := client.estimateLag(context.Background()) + assert.NoError(t, err) + assert.Equal(t, int64(0), totalLag) + // Should have retried: lease query + change feed (410) + lease query (retry) + change feed (304) + assert.GreaterOrEqual(t, changeFeedCallCount, 2) +} + +func TestCosmosDBGetMetricsAndActivity(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "lease1", "LeaseToken": "0", "ContinuationToken": `"100"`, "Owner": "owner1"}, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + case "/dbs/testdb/colls/testcontainer/docs": + w.Header().Set("x-ms-session-token", "0:0#200") + w.Header().Set("Content-Type", "application/json") + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "item1", "_lsn": 150}, + }, + } + _ = json.NewEncoder(w).Encode(response) + } + })) + defer server.Close() + + scaler := &azureCosmosDBScaler{ + metricType: v2.AverageValueMetricType, + metadata: &azureCosmosDBMetadata{ + DatabaseID: "testdb", + ContainerID: "testcontainer", + LeaseDatabaseID: "testdb", + LeaseContainerID: "leases", + ProcessorName: "testprocessor", + Threshold: 1, + ActivationThreshold: 0, + }, + cosmosClient: &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + }, + logger: logr.Discard(), + } + + metrics, isActive, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") + assert.NoError(t, err) + assert.True(t, isActive) + assert.Len(t, metrics, 1) +} + +func TestCosmosDBGetMetricsAndActivityNotActive(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + response := map[string]interface{}{ + "Documents": []map[string]interface{}{ + {"id": "lease1", "LeaseToken": "0", "ContinuationToken": `"100"`, "Owner": "owner1"}, + }, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + case "/dbs/testdb/colls/testcontainer/docs": + // Caught up + w.Header().Set("x-ms-session-token", "0:0#100") + w.WriteHeader(http.StatusNotModified) + } + })) + defer server.Close() + + scaler := &azureCosmosDBScaler{ + metricType: v2.AverageValueMetricType, + metadata: &azureCosmosDBMetadata{ + DatabaseID: "testdb", + ContainerID: "testcontainer", + LeaseDatabaseID: "testdb", + LeaseContainerID: "leases", + ProcessorName: "testprocessor", + Threshold: 1, + ActivationThreshold: 0, + }, + cosmosClient: &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + }, + logger: logr.Discard(), + } + + metrics, isActive, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") + assert.NoError(t, err) + assert.False(t, isActive) + assert.Len(t, metrics, 1) +} + +func TestCosmosDBGetMetricsAndActivityOnError(t *testing.T) { + // Server that returns 500 for all requests + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"internal server error"}`)) + })) + defer server.Close() + + scaler := &azureCosmosDBScaler{ + metricType: v2.AverageValueMetricType, + metadata: &azureCosmosDBMetadata{ + DatabaseID: "testdb", + ContainerID: "testcontainer", + LeaseDatabaseID: "testdb", + LeaseContainerID: "leases", + ProcessorName: "testprocessor", + Threshold: 100, + ActivationThreshold: 0, + }, + cosmosClient: &cosmosDBClient{ + httpClient: &http.Client{}, + dataEndpoint: server.URL, + dataKey: "dGVzdGtleQ==", + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + databaseID: "testdb", + containerID: "testcontainer", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "testprocessor", + }, + logger: logr.Discard(), + } + + // On error, propagate to KEDA and let fallback spec on ScaledObject handle it + _, _, err := scaler.GetMetricsAndActivity(context.Background(), "test-metric") + assert.Error(t, err) +} + +func TestCosmosDBLeaseQueryFiltersByProcessorName(t *testing.T) { + var capturedBody string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/dbs/testdb/colls/leases/docs": + bodyBytes, _ := io.ReadAll(r.Body) + capturedBody = string(bodyBytes) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Documents":[]}`)) + } + })) + defer server.Close() + + client := &cosmosDBClient{ + httpClient: &http.Client{}, + leaseEndpoint: server.URL, + leaseKey: "dGVzdGtleQ==", + leaseDatabaseID: "testdb", + leaseContainerID: "leases", + processorName: "myprocessor", + } + + _, _ = client.queryLeases(context.Background()) + assert.Contains(t, capturedBody, "STARTSWITH") + assert.Contains(t, capturedBody, "@prefix") + assert.Contains(t, capturedBody, "myprocessor") +} diff --git a/pkg/scaling/scalers_builder.go b/pkg/scaling/scalers_builder.go index e1146caa3d6..232de267cf8 100644 --- a/pkg/scaling/scalers_builder.go +++ b/pkg/scaling/scalers_builder.go @@ -145,6 +145,8 @@ func buildScaler(ctx context.Context, client client.Client, triggerType string, return scalers.NewAzureAppInsightsScaler(config) case "azure-blob": return scalers.NewAzureBlobScaler(config) + case "azure-cosmosdb": + return scalers.NewAzureCosmosDBScaler(config) case "azure-data-explorer": return scalers.NewAzureDataExplorerScaler(config) case "azure-eventhub": diff --git a/schema/generated/scalers-schema.json b/schema/generated/scalers-schema.json index 6441ad1f147..f28c2f0582b 100644 --- a/schema/generated/scalers-schema.json +++ b/schema/generated/scalers-schema.json @@ -841,6 +841,92 @@ } ] }, + { + "type": "azure-cosmosdb", + "parameters": [ + { + "name": "databaseId", + "type": "string", + "metadataVariableReadable": true + }, + { + "name": "containerId", + "type": "string", + "metadataVariableReadable": true + }, + { + "name": "leaseDatabaseId", + "type": "string", + "metadataVariableReadable": true + }, + { + "name": "leaseContainerId", + "type": "string", + "metadataVariableReadable": true + }, + { + "name": "processorName", + "type": "string", + "metadataVariableReadable": true + }, + { + "name": "endpoint", + "type": "string", + "optional": true, + "metadataVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "connection", + "type": "string", + "optional": true, + "metadataVariableReadable": true, + "envVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "leaseEndpoint", + "type": "string", + "optional": true, + "metadataVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "leaseConnection", + "type": "string", + "optional": true, + "metadataVariableReadable": true, + "envVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "cosmosDBKey", + "type": "string", + "optional": true, + "envVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "leaseCosmosDBKey", + "type": "string", + "optional": true, + "envVariableReadable": true, + "triggerAuthenticationVariableReadable": true + }, + { + "name": "changeFeedLagThreshold", + "type": "string", + "default": "100", + "metadataVariableReadable": true + }, + { + "name": "activationChangeFeedLagThreshold", + "type": "string", + "default": "0", + "metadataVariableReadable": true + } + ] + }, { "type": "azure-data-explorer", "parameters": [ diff --git a/schema/generated/scalers-schema.yaml b/schema/generated/scalers-schema.yaml index b99c434a54e..58b13fce967 100644 --- a/schema/generated/scalers-schema.yaml +++ b/schema/generated/scalers-schema.yaml @@ -550,6 +550,63 @@ scalers: type: string optional: true metadataVariableReadable: true + - type: azure-cosmosdb + parameters: + - name: databaseId + type: string + metadataVariableReadable: true + - name: containerId + type: string + metadataVariableReadable: true + - name: leaseDatabaseId + type: string + metadataVariableReadable: true + - name: leaseContainerId + type: string + metadataVariableReadable: true + - name: processorName + type: string + metadataVariableReadable: true + - name: endpoint + type: string + optional: true + metadataVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: connection + type: string + optional: true + metadataVariableReadable: true + envVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: leaseEndpoint + type: string + optional: true + metadataVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: leaseConnection + type: string + optional: true + metadataVariableReadable: true + envVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: cosmosDBKey + type: string + optional: true + envVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: leaseCosmosDBKey + type: string + optional: true + envVariableReadable: true + triggerAuthenticationVariableReadable: true + - name: changeFeedLagThreshold + type: string + default: "100" + metadataVariableReadable: true + - name: activationChangeFeedLagThreshold + type: string + default: "0" + metadataVariableReadable: true - type: azure-data-explorer parameters: - name: databaseName diff --git a/tests/.env b/tests/.env index efab23b86c5..d6f117770d7 100644 --- a/tests/.env +++ b/tests/.env @@ -9,6 +9,7 @@ TF_AZURE_APP_INSIGHTS_INSTRUMENTATION_KEY= TF_AZURE_APP_INSIGHTS_NAME= TF_AZURE_DATA_EXPLORER_DB= TF_AZURE_DATA_EXPLORER_ENDPOINT= +TF_AZURE_COSMOSDB_CONNECTION_STRING= AZURE_DEVOPS_BUILD_DEFINITION_ID= AZURE_DEVOPS_DEMAND_PARENT_BUILD_DEFINITION_ID= AZURE_DEVOPS_ORGANIZATION_URL= diff --git a/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go b/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go new file mode 100644 index 00000000000..5621f30942a --- /dev/null +++ b/tests/scalers/azure/azure_cosmosdb/azure_cosmosdb_test.go @@ -0,0 +1,332 @@ +//go:build e2e +// +build e2e + +package azure_cosmosdb_test + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/joho/godotenv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/client-go/kubernetes" + + . "github.com/kedacore/keda/v2/tests/helper" +) + +// Load environment variables from .env file +var _ = godotenv.Load("../../../.env") + +const ( + testName = "azure-cosmosdb-test" +) + +var ( + connectionString = os.Getenv("TF_AZURE_COSMOSDB_CONNECTION_STRING") + testNamespace = fmt.Sprintf("%s-ns", testName) + secretName = fmt.Sprintf("%s-secret", testName) + deploymentName = fmt.Sprintf("%s-deployment", testName) + scaledObjectName = fmt.Sprintf("%s-so", testName) + databaseID = "keda-test-db" + containerID = "keda-test-container" + leaseDatabaseID = "keda-test-db" + leaseContainerID = "keda-test-leases" + processorName = "keda-test-processor" +) + +type templateData struct { + TestNamespace string + SecretName string + Connection string + DeploymentName string + ScaledObjectName string + DatabaseID string + ContainerID string + LeaseDatabaseID string + LeaseContainerID string + ProcessorName string +} + +const ( + secretTemplate = ` +apiVersion: v1 +kind: Secret +metadata: + name: {{.SecretName}} + namespace: {{.TestNamespace}} +data: + connection: {{.Connection}} +` + + triggerAuthTemplate = ` +apiVersion: keda.sh/v1alpha1 +kind: TriggerAuthentication +metadata: + name: {{.SecretName}}-trigger-auth + namespace: {{.TestNamespace}} +spec: + secretTargetRef: + - parameter: connection + name: {{.SecretName}} + key: connection +` + + deploymentTemplate = ` +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{.DeploymentName}} + namespace: {{.TestNamespace}} + labels: + app: {{.DeploymentName}} +spec: + replicas: 0 + selector: + matchLabels: + app: {{.DeploymentName}} + template: + metadata: + labels: + app: {{.DeploymentName}} + spec: + containers: + - name: {{.DeploymentName}} + image: ghcr.io/kedacore/tests-azure-cosmosdb + env: + - name: COSMOS_CONNECTION + valueFrom: + secretKeyRef: + name: {{.SecretName}} + key: connection +` + + scaledObjectTemplate = ` +apiVersion: keda.sh/v1alpha1 +kind: ScaledObject +metadata: + name: {{.ScaledObjectName}} + namespace: {{.TestNamespace}} +spec: + scaleTargetRef: + name: {{.DeploymentName}} + pollingInterval: 5 + minReplicaCount: 0 + maxReplicaCount: 1 + cooldownPeriod: 10 + triggers: + - type: azure-cosmosdb + metadata: + databaseId: {{.DatabaseID}} + containerId: {{.ContainerID}} + leaseDatabaseId: {{.LeaseDatabaseID}} + leaseContainerId: {{.LeaseContainerID}} + processorName: {{.ProcessorName}} + connectionFromEnv: COSMOS_CONNECTION + activationChangeFeedLagThreshold: "0" + authenticationRef: + name: {{.SecretName}}-trigger-auth +` +) + +func TestScaler(t *testing.T) { + // setup + ctx := context.Background() + t.Log("--- setting up ---") + require.NotEmpty(t, connectionString, "TF_AZURE_COSMOSDB_CONNECTION_STRING env variable is required for azure cosmosdb test") + + // Create Cosmos DB resources (database + containers) + setupCosmosDB(ctx, t) + + // Create kubernetes resources + kc := GetKubernetesClient(t) + data, templates := getTemplateData() + + CreateKubernetesResources(t, kc, testNamespace, data, templates) + + assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1), + "replica count should be 0 after 1 minute") + + // test scaling + testActivation(t, kc) + testScaleOut(ctx, t, kc) + testScaleIn(t, kc) + + // cleanup + DeleteKubernetesResources(t, testNamespace, data, templates) +} + +func getTemplateData() (templateData, []Template) { + base64ConnectionString := base64.StdEncoding.EncodeToString([]byte(connectionString)) + + return templateData{ + TestNamespace: testNamespace, + SecretName: secretName, + Connection: base64ConnectionString, + DeploymentName: deploymentName, + ScaledObjectName: scaledObjectName, + DatabaseID: databaseID, + ContainerID: containerID, + LeaseDatabaseID: leaseDatabaseID, + LeaseContainerID: leaseContainerID, + ProcessorName: processorName, + }, []Template{ + {Name: "secretTemplate", Config: secretTemplate}, + {Name: "triggerAuthTemplate", Config: triggerAuthTemplate}, + {Name: "deploymentTemplate", Config: deploymentTemplate}, + {Name: "scaledObjectTemplate", Config: scaledObjectTemplate}, + } +} + +func testActivation(t *testing.T, kc *kubernetes.Clientset) { + t.Log("--- testing activation ---") + AssertReplicaCountNotChangeDuringTimePeriod(t, kc, deploymentName, testNamespace, 0, 60) +} + +func testScaleOut(ctx context.Context, t *testing.T, kc *kubernetes.Clientset) { + t.Log("--- testing scale out ---") + addDocuments(ctx, t, 10) + + assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 1, 60, 1), + "replica count should be 1 after 1 minute") +} + +func testScaleIn(t *testing.T, kc *kubernetes.Clientset) { + t.Log("--- testing scale in ---") + assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, 0, 60, 1), + "replica count should be 0 after 1 minute") +} + +// addDocuments inserts documents into the Cosmos DB data container via the REST API +// to generate change feed lag for the scaler to detect. +func addDocuments(ctx context.Context, t *testing.T, count int) { + t.Helper() + + endpoint, key, err := parseConnString(connectionString) + require.NoErrorf(t, err, "cannot parse connection string - %s", err) + + for i := 0; i < count; i++ { + docID := fmt.Sprintf("test-doc-%d-%d", GetRandomNumber(), i) + body := fmt.Sprintf(`{"id":"%s","partitionKey":"%s","message":"Test document %d"}`, docID, docID, i) + + resourceLink := fmt.Sprintf("dbs/%s/colls/%s", databaseID, containerID) + reqURL := fmt.Sprintf("%s/%s/docs", strings.TrimRight(endpoint, "/"), resourceLink) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(body)) + require.NoErrorf(t, err, "cannot create request - %s", err) + + now := time.Now().UTC().Format(http.TimeFormat) + req.Header.Set("Authorization", cosmosAuthToken("post", "docs", resourceLink, now, key)) + req.Header.Set("x-ms-date", now) + req.Header.Set("x-ms-version", "2018-12-31") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-ms-documentdb-partitionkey", fmt.Sprintf(`["%s"]`, docID)) + + resp, err := http.DefaultClient.Do(req) + require.NoErrorf(t, err, "cannot send request - %s", err) + + respBody, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + require.Truef(t, resp.StatusCode == http.StatusCreated || resp.StatusCode == http.StatusOK, + "unexpected status %d creating document: %s", resp.StatusCode, string(respBody)) + + t.Logf("Document created: %s", docID) + } +} + +func parseConnString(conn string) (string, string, error) { + var endpoint, key string + for _, part := range strings.Split(conn, ";") { + part = strings.TrimSpace(part) + switch { + case strings.HasPrefix(part, "AccountEndpoint="): + endpoint = strings.TrimPrefix(part, "AccountEndpoint=") + case strings.HasPrefix(part, "AccountKey="): + key = strings.TrimPrefix(part, "AccountKey=") + } + } + if endpoint == "" || key == "" { + return "", "", fmt.Errorf("invalid connection string: missing AccountEndpoint or AccountKey") + } + return endpoint, key, nil +} + +func cosmosAuthToken(verb, resourceType, resourceLink, date, key string) string { + keyBytes, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return "" + } + text := fmt.Sprintf("%s\n%s\n%s\n%s\n\n", + strings.ToLower(verb), + strings.ToLower(resourceType), + resourceLink, + strings.ToLower(date)) + h := hmac.New(sha256.New, keyBytes) + h.Write([]byte(text)) + sig := base64.StdEncoding.EncodeToString(h.Sum(nil)) + return url.QueryEscape(fmt.Sprintf("type=master&ver=1.0&sig=%s", sig)) +} + +// setupCosmosDB creates the database, data container, and lease container if they don't exist. +func setupCosmosDB(ctx context.Context, t *testing.T) { + t.Helper() + + endpoint, key, err := parseConnString(connectionString) + require.NoErrorf(t, err, "cannot parse connection string - %s", err) + + // Create database + cosmosCreateResource(ctx, t, endpoint, key, "", "dbs", fmt.Sprintf(`{"id":"%s"}`, databaseID)) + + // Create data container with /id as partition key + dbLink := fmt.Sprintf("dbs/%s", databaseID) + cosmosCreateResource(ctx, t, endpoint, key, dbLink, "colls", + fmt.Sprintf(`{"id":"%s","partitionKey":{"paths":["/id"],"kind":"Hash"}}`, containerID)) + + // Create lease container with /id as partition key + cosmosCreateResource(ctx, t, endpoint, key, dbLink, "colls", + fmt.Sprintf(`{"id":"%s","partitionKey":{"paths":["/id"],"kind":"Hash"}}`, leaseContainerID)) +} + +// cosmosCreateResource creates a Cosmos DB resource via REST API, ignoring 409 Conflict (already exists). +func cosmosCreateResource(ctx context.Context, t *testing.T, endpoint, key, parentLink, resourceType, body string) { + t.Helper() + + reqURL := fmt.Sprintf("%s/%s/%s", strings.TrimRight(endpoint, "/"), parentLink, resourceType) + if parentLink == "" { + reqURL = fmt.Sprintf("%s/%s", strings.TrimRight(endpoint, "/"), resourceType) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, strings.NewReader(body)) + require.NoErrorf(t, err, "cannot create request - %s", err) + + now := time.Now().UTC().Format(http.TimeFormat) + req.Header.Set("Authorization", cosmosAuthToken("post", resourceType, parentLink, now, key)) + req.Header.Set("x-ms-date", now) + req.Header.Set("x-ms-version", "2018-12-31") + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + require.NoErrorf(t, err, "cannot send request - %s", err) + respBody, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + if resp.StatusCode == http.StatusConflict { + t.Logf("Resource already exists (409), skipping: %s/%s", parentLink, resourceType) + return + } + + require.Truef(t, resp.StatusCode == http.StatusCreated || resp.StatusCode == http.StatusOK, + "unexpected status %d creating resource: %s", resp.StatusCode, string(respBody)) + t.Logf("Created resource: %s/%s", parentLink, resourceType) +}