diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go index 81b276f1189..d2b291689c3 100644 --- a/common/dynamicconfig/constants.go +++ b/common/dynamicconfig/constants.go @@ -475,6 +475,14 @@ Deleted Redirect Rules will be kept in the DB (with DeleteTimestamp). After this `PollerHistoryTTL is the time to live for poller histories in the pollerHistory cache of a physical task queue. Poller histories are fetched when requiring a list of pollers that polled a given task queue.`, ) + ShutdownWorkerCacheTTL = NewGlobalDurationSetting( + "matching.ShutdownWorkerCacheTTL", + 70*time.Second, + `ShutdownWorkerCacheTTL is the time to live for entries in the shutdown worker cache. When a worker calls + ShutdownWorker, its WorkerInstanceKey is cached for this duration. Any poll arriving with a cached + WorkerInstanceKey returns empty immediately, preventing task dispatch to a shutting-down worker. + This should be longer than MatchingLongPollExpirationInterval (1 min default) to catch in-flight polls.`, + ) ReachabilityBuildIdVisibilityGracePeriod = NewNamespaceDurationSetting( "matching.wv.ReachabilityBuildIdVisibilityGracePeriod", 3*time.Minute, diff --git a/service/matching/config.go b/service/matching/config.go index 0e400dd8f2a..5c4104fd3da 100644 --- a/service/matching/config.go +++ b/service/matching/config.go @@ -72,6 +72,7 @@ type ( RedirectRuleMaxUpstreamBuildIDsPerQueue dynamicconfig.IntPropertyFnWithNamespaceFilter DeletedRuleRetentionTime dynamicconfig.DurationPropertyFnWithNamespaceFilter PollerHistoryTTL dynamicconfig.DurationPropertyFnWithNamespaceFilter + ShutdownWorkerCacheTTL dynamicconfig.DurationPropertyFn ReachabilityBuildIdVisibilityGracePeriod dynamicconfig.DurationPropertyFnWithNamespaceFilter ReachabilityCacheOpenWFsTTL dynamicconfig.DurationPropertyFn ReachabilityCacheClosedWFsTTL dynamicconfig.DurationPropertyFn @@ -310,6 +311,7 @@ func NewConfig( RedirectRuleMaxUpstreamBuildIDsPerQueue: 
dynamicconfig.RedirectRuleMaxUpstreamBuildIDsPerQueue.Get(dc), DeletedRuleRetentionTime: dynamicconfig.MatchingDeletedRuleRetentionTime.Get(dc), PollerHistoryTTL: dynamicconfig.PollerHistoryTTL.Get(dc), + ShutdownWorkerCacheTTL: dynamicconfig.ShutdownWorkerCacheTTL.Get(dc), ReachabilityBuildIdVisibilityGracePeriod: dynamicconfig.ReachabilityBuildIdVisibilityGracePeriod.Get(dc), ReachabilityCacheOpenWFsTTL: dynamicconfig.ReachabilityCacheOpenWFsTTL.Get(dc), ReachabilityCacheClosedWFsTTL: dynamicconfig.ReachabilityCacheClosedWFsTTL.Get(dc), diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index b09c10338c0..4f97cd755ff 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -33,6 +33,7 @@ import ( "go.temporal.io/server/client/matching" "go.temporal.io/server/common" "go.temporal.io/server/common/backoff" + "go.temporal.io/server/common/cache" "go.temporal.io/server/common/clock" hlc "go.temporal.io/server/common/clock/hybrid_logical_clock" "go.temporal.io/server/common/cluster" @@ -165,6 +166,10 @@ type ( outstandingPollers collection.SyncMap[string, context.CancelFunc] // workerInstancePollers tracks pollers by worker instance key for bulk cancellation during shutdown. workerInstancePollers workerPollerTracker + // shutdownWorkers caches WorkerInstanceKeys of workers that have initiated shutdown. + // Polls arriving with a cached key return empty immediately to prevent task dispatch + // to a shutting-down worker (handles race where poll arrives after cancellation completes). + shutdownWorkers cache.Cache // Only set if global namespaces are enabled on the cluster. namespaceReplicationQueue persistence.NamespaceReplicationQueue // Lock to serialize replication queue updates. 
@@ -293,6 +298,8 @@ func NewEngine( nexusResults: collection.NewSyncMap[string, chan *nexusResult](), outstandingPollers: collection.NewSyncMap[string, context.CancelFunc](), workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + // 50000 entries ≈ 10MB (each entry ~200 bytes: UUID key + cache overhead). NOTE(review): config.ShutdownWorkerCacheTTL() is evaluated once here, so the cache TTL is fixed at engine construction; later dynamic-config changes presumably do not take effect until the engine is recreated — confirm this is intended. + shutdownWorkers: cache.New(50000, &cache.Options{TTL: config.ShutdownWorkerCacheTTL()}), namespaceReplicationQueue: namespaceReplicationQueue, userDataUpdateBatchers: collection.NewSyncMap[namespace.ID, *stream_batcher.Batcher[*userDataUpdate, error]](), rateLimiter: rateLimiter, @@ -650,6 +657,14 @@ func (e *matchingEngineImpl) PollWorkflowTaskQueue( request := req.PollRequest taskQueueName := request.TaskQueue.GetName() + // Return empty immediately if this worker has already initiated shutdown. + // This guards against polls that arrive after CancelOutstandingWorkerPolls completed. + if workerInstanceKey := request.GetWorkerInstanceKey(); workerInstanceKey != "" { + if e.shutdownWorkers.Get(workerInstanceKey) != nil { + return emptyPollWorkflowTaskQueueResponse, nil + } + } + // Namespace field is not populated for forwarded requests. if len(request.Namespace) == 0 { ns, err := e.namespaceRegistry.GetNamespaceName(namespace.ID(req.GetNamespaceId())) @@ -952,6 +967,14 @@ func (e *matchingEngineImpl) PollActivityTaskQueue( request := req.PollRequest taskQueueName := request.TaskQueue.GetName() + // Return empty immediately if this worker has already initiated shutdown. + // This guards against polls that arrive after CancelOutstandingWorkerPolls completed. + if workerInstanceKey := request.GetWorkerInstanceKey(); workerInstanceKey != "" { + if e.shutdownWorkers.Get(workerInstanceKey) != nil { + return emptyPollActivityTaskQueueResponse, nil + } + } + // Namespace field is not populated for forwarded requests.
if len(request.Namespace) == 0 { ns, err := e.namespaceRegistry.GetNamespaceName(namespace.ID(req.GetNamespaceId())) @@ -1212,7 +1235,16 @@ func (e *matchingEngineImpl) CancelOutstandingWorkerPolls( ctx context.Context, request *matchingservice.CancelOutstandingWorkerPollsRequest, ) (*matchingservice.CancelOutstandingWorkerPollsResponse, error) { - cancelledCount := e.workerInstancePollers.CancelAll(request.WorkerInstanceKey) + workerInstanceKey := request.WorkerInstanceKey + + // Record the WorkerInstanceKey in the shutdown cache BEFORE cancelling, so that any poll whose + // shutdown-cache check runs from this point on (including while CancelAll is still running) + // returns empty immediately. Polls tracked for this key are then cancelled by CancelAll below. + if workerInstanceKey != "" { + e.shutdownWorkers.Put(workerInstanceKey, struct{}{}) + } + + cancelledCount := e.workerInstancePollers.CancelAll(workerInstanceKey) e.removePollerFromHistory(ctx, request) return &matchingservice.CancelOutstandingWorkerPollsResponse{CancelledCount: cancelledCount}, nil } diff --git a/service/matching/matching_engine_test.go b/service/matching/matching_engine_test.go index 688987a7365..2b52672fb7b 100644 --- a/service/matching/matching_engine_test.go +++ b/service/matching/matching_engine_test.go @@ -42,6 +42,7 @@ import ( taskqueuespb "go.temporal.io/server/api/taskqueue/v1" tokenspb "go.temporal.io/server/api/token/v1" "go.temporal.io/server/common" + "go.temporal.io/server/common/cache" "go.temporal.io/server/common/clock" hlc "go.temporal.io/server/common/clock/hybrid_logical_clock" "go.temporal.io/server/common/cluster" @@ -5700,6 +5701,7 @@ func TestCancelOutstandingWorkerPolls(t *testing.T) { t.Parallel() engine := &matchingEngineImpl{ workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(100, &cache.Options{TTL: time.Minute}), } resp, err := engine.CancelOutstandingWorkerPolls(context.Background(), @@ -5715,6 +5717,7 @@ func
TestCancelOutstandingWorkerPolls(t *testing.T) { t.Parallel() engine := &matchingEngineImpl{ workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(100, &cache.Options{TTL: time.Minute}), } workerKey := "test-worker" @@ -5741,6 +5744,7 @@ func TestCancelOutstandingWorkerPolls(t *testing.T) { worker2Cancelled := false engine := &matchingEngineImpl{ workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(100, &cache.Options{TTL: time.Minute}), } // Set up pollers for worker1 and worker2 @@ -5763,6 +5767,7 @@ func TestCancelOutstandingWorkerPolls(t *testing.T) { t.Parallel() engine := &matchingEngineImpl{ workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: cache.New(100, &cache.Options{TTL: time.Minute}), } workerKey := "test-worker" @@ -5785,4 +5790,94 @@ func TestCancelOutstandingWorkerPolls(t *testing.T) { require.True(t, childCancelled, "child partition poll should be cancelled") require.True(t, parentCancelled, "parent partition poll should be cancelled") }) + + t.Run("adds worker to shutdown cache", func(t *testing.T) { + t.Parallel() + shutdownCache := cache.New(100, &cache.Options{TTL: time.Minute}) + engine := &matchingEngineImpl{ + workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: shutdownCache, + } + + workerKey := "test-worker" + + // Verify worker is not in cache initially + require.Nil(t, shutdownCache.Get(workerKey)) + + _, err := engine.CancelOutstandingWorkerPolls(context.Background(), + &matchingservice.CancelOutstandingWorkerPollsRequest{ + WorkerInstanceKey: workerKey, + }) + + require.NoError(t, err) + // Verify worker is now in the shutdown cache + require.NotNil(t, shutdownCache.Get(workerKey), "worker should be added to shutdown cache") + }) + + t.Run("empty 
worker key not added to shutdown cache", func(t *testing.T) { + t.Parallel() + shutdownCache := cache.New(100, &cache.Options{TTL: time.Minute}) + engine := &matchingEngineImpl{ + workerInstancePollers: workerPollerTracker{pollers: make(map[string]map[string]context.CancelFunc)}, + shutdownWorkers: shutdownCache, + } + + _, err := engine.CancelOutstandingWorkerPolls(context.Background(), + &matchingservice.CancelOutstandingWorkerPollsRequest{ + WorkerInstanceKey: "", // empty + }) + + require.NoError(t, err) + // Verify empty key was not added + require.Nil(t, shutdownCache.Get("")) + }) +} + +func TestPollReturnsEmptyAfterWorkerShutdown(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockNamespaceCache := namespace.NewMockRegistry(ctrl) + mockNamespaceCache.EXPECT().GetNamespaceName(gomock.Any()).Return(namespace.Name("test-namespace"), nil).AnyTimes() + + shutdownCache := cache.New(100, &cache.Options{TTL: time.Minute}) + workerKey := "shutdown-worker-key" + + // Pre-populate shutdown cache (simulating CancelOutstandingWorkerPolls was called) + shutdownCache.Put(workerKey, struct{}{}) + + engine := &matchingEngineImpl{ + shutdownWorkers: shutdownCache, + namespaceRegistry: mockNamespaceCache, + } + + t.Run("PollWorkflowTaskQueue returns empty for shutdown worker", func(t *testing.T) { + resp, err := engine.PollWorkflowTaskQueue(context.Background(), &matchingservice.PollWorkflowTaskQueueRequest{ + NamespaceId: "test-namespace-id", + PollerId: "poller-1", + PollRequest: &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: "test-namespace", + WorkerInstanceKey: workerKey, + TaskQueue: &taskqueuepb.TaskQueue{Name: "test-queue"}, + }, + }, metrics.NoopMetricsHandler) + + require.NoError(t, err) + require.Equal(t, emptyPollWorkflowTaskQueueResponse, resp) + }) + + t.Run("PollActivityTaskQueue returns empty for shutdown worker", func(t *testing.T) { + resp, err := engine.PollActivityTaskQueue(context.Background(), 
&matchingservice.PollActivityTaskQueueRequest{ + NamespaceId: "test-namespace-id", + PollerId: "poller-1", + PollRequest: &workflowservice.PollActivityTaskQueueRequest{ + Namespace: "test-namespace", + WorkerInstanceKey: workerKey, + TaskQueue: &taskqueuepb.TaskQueue{Name: "test-queue"}, + }, + }, metrics.NoopMetricsHandler) + + require.NoError(t, err) + require.Equal(t, emptyPollActivityTaskQueueResponse, resp) + }) }