diff --git a/api/matchingservice/v1/request_response.pb.go b/api/matchingservice/v1/request_response.pb.go index 924fc90ee2d..d3ab037ea67 100644 --- a/api/matchingservice/v1/request_response.pb.go +++ b/api/matchingservice/v1/request_response.pb.go @@ -145,9 +145,12 @@ type PollWorkflowTaskQueueResponse struct { PollerScalingDecision *v14.PollerScalingDecision `protobuf:"bytes,21,opt,name=poller_scaling_decision,json=pollerScalingDecision,proto3" json:"poller_scaling_decision,omitempty"` // Raw history bytes sent from matching service when history.sendRawHistoryBetweenInternalServices is enabled. // Matching client will deserialize this to History when it receives the response. - RawHistory *v16.History `protobuf:"bytes,22,opt,name=raw_history,json=rawHistory,proto3" json:"raw_history,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + RawHistory *v16.History `protobuf:"bytes,22,opt,name=raw_history,json=rawHistory,proto3" json:"raw_history,omitempty"` + // When true, this empty response was caused by the server completing the poll + // because the worker has been shut down via the ShutdownWorker API. + CompletedByWorkerShutdown bool `protobuf:"varint,23,opt,name=completed_by_worker_shutdown,json=completedByWorkerShutdown,proto3" json:"completed_by_worker_shutdown,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *PollWorkflowTaskQueueResponse) Reset() { @@ -327,6 +330,13 @@ func (x *PollWorkflowTaskQueueResponse) GetRawHistory() *v16.History { return nil } +func (x *PollWorkflowTaskQueueResponse) GetCompletedByWorkerShutdown() bool { + if x != nil { + return x.CompletedByWorkerShutdown + } + return false +} + // PollWorkflowTaskQueueResponseWithRawHistory is wire-compatible with PollWorkflowTaskQueueResponse. // // WIRE COMPATIBILITY PATTERN: @@ -374,9 +384,12 @@ type PollWorkflowTaskQueueResponseWithRawHistory struct { // Raw history bytes. 
Each element is a proto-encoded batch of history events. // When matching client deserializes this to PollWorkflowTaskQueueResponse, this field // will be automatically deserialized to the raw_history field as History. - RawHistory [][]byte `protobuf:"bytes,22,rep,name=raw_history,json=rawHistory,proto3" json:"raw_history,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + RawHistory [][]byte `protobuf:"bytes,22,rep,name=raw_history,json=rawHistory,proto3" json:"raw_history,omitempty"` + // When true, this empty response was caused by the server completing the poll + // because the worker has been shut down via the ShutdownWorker API. + CompletedByWorkerShutdown bool `protobuf:"varint,23,opt,name=completed_by_worker_shutdown,json=completedByWorkerShutdown,proto3" json:"completed_by_worker_shutdown,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *PollWorkflowTaskQueueResponseWithRawHistory) Reset() { @@ -556,6 +569,13 @@ func (x *PollWorkflowTaskQueueResponseWithRawHistory) GetRawHistory() [][]byte { return nil } +func (x *PollWorkflowTaskQueueResponseWithRawHistory) GetCompletedByWorkerShutdown() bool { + if x != nil { + return x.CompletedByWorkerShutdown + } + return false +} + type PollActivityTaskQueueRequest struct { state protoimpl.MessageState `protogen:"open.v1"` NamespaceId string `protobuf:"bytes,1,opt,name=namespace_id,json=namespaceId,proto3" json:"namespace_id,omitempty"` @@ -662,8 +682,11 @@ type PollActivityTaskQueueResponse struct { RetryPolicy *v11.RetryPolicy `protobuf:"bytes,19,opt,name=retry_policy,json=retryPolicy,proto3" json:"retry_policy,omitempty"` // ID of the activity run (applicable for standalone activities only) ActivityRunId string `protobuf:"bytes,20,opt,name=activity_run_id,json=activityRunId,proto3" json:"activity_run_id,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // When true, this empty response was caused by 
the server completing the poll + // because the worker has been shut down via the ShutdownWorker API. + CompletedByWorkerShutdown bool `protobuf:"varint,21,opt,name=completed_by_worker_shutdown,json=completedByWorkerShutdown,proto3" json:"completed_by_worker_shutdown,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *PollActivityTaskQueueResponse) Reset() { @@ -836,6 +859,13 @@ func (x *PollActivityTaskQueueResponse) GetActivityRunId() string { return "" } +func (x *PollActivityTaskQueueResponse) GetCompletedByWorkerShutdown() bool { + if x != nil { + return x.CompletedByWorkerShutdown + } + return false +} + type AddWorkflowTaskRequest struct { state protoimpl.MessageState `protogen:"open.v1"` NamespaceId string `protobuf:"bytes,1,opt,name=namespace_id,json=namespaceId,proto3" json:"namespace_id,omitempty"` @@ -4087,9 +4117,12 @@ func (x *PollNexusTaskQueueRequest) GetConditions() *PollConditions { type PollNexusTaskQueueResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // Response that should be delivered to the worker containing a request from DispatchNexusTaskRequest. - Response *v1.PollNexusTaskQueueResponse `protobuf:"bytes,1,opt,name=response,proto3" json:"response,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Response *v1.PollNexusTaskQueueResponse `protobuf:"bytes,1,opt,name=response,proto3" json:"response,omitempty"` + // When true, this empty response was caused by the server completing the poll + // because the worker has been shut down via the ShutdownWorker API. 
+ CompletedByWorkerShutdown bool `protobuf:"varint,2,opt,name=completed_by_worker_shutdown,json=completedByWorkerShutdown,proto3" json:"completed_by_worker_shutdown,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *PollNexusTaskQueueResponse) Reset() { @@ -4129,6 +4162,13 @@ func (x *PollNexusTaskQueueResponse) GetResponse() *v1.PollNexusTaskQueueRespons return nil } +func (x *PollNexusTaskQueueResponse) GetCompletedByWorkerShutdown() bool { + if x != nil { + return x.CompletedByWorkerShutdown + } + return false +} + type RespondNexusTaskCompletedRequest struct { state protoimpl.MessageState `protogen:"open.v1"` NamespaceId string `protobuf:"bytes,1,opt,name=namespace_id,json=namespaceId,proto3" json:"namespace_id,omitempty"` @@ -5749,7 +5789,7 @@ const file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc "\x10forwarded_source\x18\x04 \x01(\tR\x0fforwardedSource\x12V\n" + "\n" + "conditions\x18\x05 \x01(\v26.temporal.server.api.matchingservice.v1.PollConditionsR\n" + - "conditions\"\xd1\v\n" + + "conditions\"\x92\f\n" + "\x1dPollWorkflowTaskQueueResponse\x12\x1d\n" + "\n" + "task_token\x18\x01 \x01(\fR\ttaskToken\x12X\n" + @@ -5774,10 +5814,11 @@ const file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc "\x0fnext_page_token\x18\x14 \x01(\fR\rnextPageToken\x12h\n" + "\x17poller_scaling_decision\x18\x15 \x01(\v20.temporal.api.taskqueue.v1.PollerScalingDecisionR\x15pollerScalingDecision\x12A\n" + "\vraw_history\x18\x16 \x01(\v2 .temporal.api.history.v1.HistoryR\n" + - "rawHistory\x1a`\n" + + "rawHistory\x12?\n" + + "\x1ccompleted_by_worker_shutdown\x18\x17 \x01(\bR\x19completedByWorkerShutdown\x1a`\n" + "\fQueriesEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12:\n" + - "\x05value\x18\x02 \x01(\v2$.temporal.api.query.v1.WorkflowQueryR\x05value:\x028\x01J\x04\b\r\x10\x0e\"\xcb\v\n" + + "\x05value\x18\x02 
\x01(\v2$.temporal.api.query.v1.WorkflowQueryR\x05value:\x028\x01J\x04\b\r\x10\x0e\"\x8c\f\n" + "+PollWorkflowTaskQueueResponseWithRawHistory\x12\x1d\n" + "\n" + "task_token\x18\x01 \x01(\fR\ttaskToken\x12X\n" + @@ -5802,7 +5843,8 @@ const file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc "\x0fnext_page_token\x18\x14 \x01(\fR\rnextPageToken\x12h\n" + "\x17poller_scaling_decision\x18\x15 \x01(\v20.temporal.api.taskqueue.v1.PollerScalingDecisionR\x15pollerScalingDecision\x12\x1f\n" + "\vraw_history\x18\x16 \x03(\fR\n" + - "rawHistory\x1a`\n" + + "rawHistory\x12?\n" + + "\x1ccompleted_by_worker_shutdown\x18\x17 \x01(\bR\x19completedByWorkerShutdown\x1a`\n" + "\fQueriesEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12:\n" + "\x05value\x18\x02 \x01(\v2$.temporal.api.query.v1.WorkflowQueryR\x05value:\x028\x01J\x04\b\r\x10\x0e\"\xc3\x02\n" + @@ -5813,8 +5855,7 @@ const file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc "\x10forwarded_source\x18\x04 \x01(\tR\x0fforwardedSource\x12V\n" + "\n" + "conditions\x18\x05 \x01(\v26.temporal.server.api.matchingservice.v1.PollConditionsR\n" + - "conditions\"\xc0\n" + - "\n" + + "conditions\"\x81\v\n" + "\x1dPollActivityTaskQueueResponse\x12\x1d\n" + "\n" + "task_token\x18\x01 \x01(\fR\ttaskToken\x12X\n" + @@ -5838,7 +5879,8 @@ const file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc "\x17poller_scaling_decision\x18\x11 \x01(\v20.temporal.api.taskqueue.v1.PollerScalingDecisionR\x15pollerScalingDecision\x12<\n" + "\bpriority\x18\x12 \x01(\v2 .temporal.api.common.v1.PriorityR\bpriority\x12F\n" + "\fretry_policy\x18\x13 \x01(\v2#.temporal.api.common.v1.RetryPolicyR\vretryPolicy\x12&\n" + - "\x0factivity_run_id\x18\x14 \x01(\tR\ractivityRunId\"\x9d\x05\n" + + "\x0factivity_run_id\x18\x14 \x01(\tR\ractivityRunId\x12?\n" + + "\x1ccompleted_by_worker_shutdown\x18\x15 \x01(\bR\x19completedByWorkerShutdown\"\x9d\x05\n" + "\x16AddWorkflowTaskRequest\x12!\n" + 
"\fnamespace_id\x18\x01 \x01(\tR\vnamespaceId\x12G\n" + "\texecution\x18\x02 \x01(\v2).temporal.api.common.v1.WorkflowExecutionR\texecution\x12C\n" + @@ -6087,9 +6129,10 @@ const file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc "\x10forwarded_source\x18\x04 \x01(\tR\x0fforwardedSource\x12V\n" + "\n" + "conditions\x18\x05 \x01(\v26.temporal.server.api.matchingservice.v1.PollConditionsR\n" + - "conditions\"u\n" + + "conditions\"\xb6\x01\n" + "\x1aPollNexusTaskQueueResponse\x12W\n" + - "\bresponse\x18\x01 \x01(\v2;.temporal.api.workflowservice.v1.PollNexusTaskQueueResponseR\bresponse\"\x80\x02\n" + + "\bresponse\x18\x01 \x01(\v2;.temporal.api.workflowservice.v1.PollNexusTaskQueueResponseR\bresponse\x12?\n" + + "\x1ccompleted_by_worker_shutdown\x18\x02 \x01(\bR\x19completedByWorkerShutdown\"\x80\x02\n" + " RespondNexusTaskCompletedRequest\x12!\n" + "\fnamespace_id\x18\x01 \x01(\tR\vnamespaceId\x12C\n" + "\n" + diff --git a/go.mod b/go.mod index a8031c011ee..e5f84a93d23 100644 --- a/go.mod +++ b/go.mod @@ -216,3 +216,5 @@ require ( modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect ) + +replace go.temporal.io/api => /Users/krajah/api-go diff --git a/go.sum b/go.sum index fa7633c649a..65e7b38fd77 100644 --- a/go.sum +++ b/go.sum @@ -440,8 +440,6 @@ go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZY go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4= go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE= -go.temporal.io/api v1.62.8 h1:g8RAZmdebYODoNa2GLA4M4TsXNe1096WV3n26C4+fdw= -go.temporal.io/api v1.62.8/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= go.temporal.io/auto-scaled-workers v0.0.0-20260407181057-edd947d743d2 h1:1hKeH3GyR6YD6LKMHGCZ76t6h1Sgha0hXVQBxWi3dlQ= go.temporal.io/auto-scaled-workers 
v0.0.0-20260407181057-edd947d743d2/go.mod h1:T8dnzVPeO+gaUTj9eDgm/lT2lZH4+JXNvrGaQGyVi50= go.temporal.io/sdk v1.41.1 h1:yOpvsHyDD1lNuwlGBv/SUodCPhjv9nDeC9lLHW/fJUA= diff --git a/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto b/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto index 28bf1593a73..5b38770cb40 100644 --- a/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto +++ b/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto @@ -62,6 +62,9 @@ message PollWorkflowTaskQueueResponse { // Raw history bytes sent from matching service when history.sendRawHistoryBetweenInternalServices is enabled. // Matching client will deserialize this to History when it receives the response. temporal.api.history.v1.History raw_history = 22; + // When true, this empty response was caused by the server completing the poll + // because the worker has been shut down via the ShutdownWorker API. + bool completed_by_worker_shutdown = 23; } // PollWorkflowTaskQueueResponseWithRawHistory is wire-compatible with PollWorkflowTaskQueueResponse. @@ -112,6 +115,9 @@ message PollWorkflowTaskQueueResponseWithRawHistory { // When matching client deserializes this to PollWorkflowTaskQueueResponse, this field // will be automatically deserialized to the raw_history field as History. repeated bytes raw_history = 22; + // When true, this empty response was caused by the server completing the poll + // because the worker has been shut down via the ShutdownWorker API. 
+ bool completed_by_worker_shutdown = 23; } message PollActivityTaskQueueRequest { @@ -149,6 +155,9 @@ message PollActivityTaskQueueResponse { temporal.api.common.v1.RetryPolicy retry_policy = 19; // ID of the activity run (applicable for standalone activities only) string activity_run_id = 20; + // When true, this empty response was caused by the server completing the poll + // because the worker has been shut down via the ShutdownWorker API. + bool completed_by_worker_shutdown = 21; } message AddWorkflowTaskRequest { @@ -579,6 +588,9 @@ message PollNexusTaskQueueRequest { message PollNexusTaskQueueResponse { // Response that should be delivered to the worker containing a request from DispatchNexusTaskRequest. temporal.api.workflowservice.v1.PollNexusTaskQueueResponse response = 1; + // When true, this empty response was caused by the server completing the poll + // because the worker has been shut down via the ShutdownWorker API. + bool completed_by_worker_shutdown = 2; } message RespondNexusTaskCompletedRequest { diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 0f7cac9dcfa..1d0475d7305 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -1134,7 +1134,8 @@ func (wh *WorkflowHandler) PollWorkflowTaskQueue(ctx context.Context, request *w StartedTime: matchingResp.StartedTime, Queries: matchingResp.Queries, Messages: matchingResp.Messages, - PollerScalingDecision: matchingResp.PollerScalingDecision, + PollerScalingDecision: matchingResp.PollerScalingDecision, + CompletedByWorkerShutdown: matchingResp.CompletedByWorkerShutdown, }, nil } @@ -1372,6 +1373,7 @@ func (wh *WorkflowHandler) PollActivityTaskQueue(ctx context.Context, request *w PollerScalingDecision: matchingResponse.PollerScalingDecision, Priority: matchingResponse.Priority, RetryPolicy: matchingResponse.RetryPolicy, + CompletedByWorkerShutdown: matchingResponse.CompletedByWorkerShutdown, }, nil } @@ -6025,7 +6027,12 
@@ func (wh *WorkflowHandler) PollNexusTaskQueue(ctx context.Context, request *work return nil, err } - return matchingResponse.GetResponse(), nil + resp := matchingResponse.GetResponse() + if resp == nil { + resp = &workflowservice.PollNexusTaskQueueResponse{} + } + resp.CompletedByWorkerShutdown = matchingResponse.CompletedByWorkerShutdown + return resp, nil } func (wh *WorkflowHandler) RespondNexusTaskCompleted(ctx context.Context, request *workflowservice.RespondNexusTaskCompletedRequest) (_ *workflowservice.RespondNexusTaskCompletedResponse, retError error) { diff --git a/service/frontend/workflow_handler_test.go b/service/frontend/workflow_handler_test.go index 8726d6bf071..56c7b6b095e 100644 --- a/service/frontend/workflow_handler_test.go +++ b/service/frontend/workflow_handler_test.go @@ -289,6 +289,74 @@ func (s *WorkflowHandlerSuite) TestPollForTask_Failed_ContextTimeoutTooShort() { s.Equal(common.ErrContextTimeoutTooShort, err) } +func (s *WorkflowHandlerSuite) TestPollWorkflowTaskQueue_CompletedByWorkerShutdown() { + config := s.newConfig() + wh := s.getWorkflowHandler(config) + + namespaceEntry := namespace.NewLocalNamespaceForTest( + &persistencespb.NamespaceInfo{Id: testNamespaceID, Name: "test-namespace"}, + nil, + "", + ) + s.mockNamespaceCache.EXPECT().GetNamespace(namespace.Name("test-namespace")).Return(namespaceEntry, nil) + s.mockMatchingClient.EXPECT().PollWorkflowTaskQueue(gomock.Any(), gomock.Any()).Return( + &matchingservice.PollWorkflowTaskQueueResponse{ + CompletedByWorkerShutdown: true, + }, nil, + ) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + resp, err := wh.PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: "test-namespace", + TaskQueue: &taskqueuepb.TaskQueue{Name: "test-tq"}, + }) + s.NoError(err) + s.True(resp.GetCompletedByWorkerShutdown()) +} + +func (s *WorkflowHandlerSuite) TestPollActivityTaskQueue_CompletedByWorkerShutdown() { + config := 
s.newConfig() + wh := s.getWorkflowHandler(config) + + s.mockNamespaceCache.EXPECT().GetNamespaceID(namespace.Name("test-namespace")).Return(namespace.ID(testNamespaceID), nil) + s.mockMatchingClient.EXPECT().PollActivityTaskQueue(gomock.Any(), gomock.Any()).Return( + &matchingservice.PollActivityTaskQueueResponse{ + CompletedByWorkerShutdown: true, + }, nil, + ) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + resp, err := wh.PollActivityTaskQueue(ctx, &workflowservice.PollActivityTaskQueueRequest{ + Namespace: "test-namespace", + TaskQueue: &taskqueuepb.TaskQueue{Name: "test-tq"}, + }) + s.NoError(err) + s.True(resp.GetCompletedByWorkerShutdown()) +} + +func (s *WorkflowHandlerSuite) TestPollNexusTaskQueue_CompletedByWorkerShutdown() { + config := s.newConfig() + wh := s.getWorkflowHandler(config) + + s.mockNamespaceCache.EXPECT().GetNamespaceID(namespace.Name("test-namespace")).Return(namespace.ID(testNamespaceID), nil) + s.mockMatchingClient.EXPECT().PollNexusTaskQueue(gomock.Any(), gomock.Any()).Return( + &matchingservice.PollNexusTaskQueueResponse{ + CompletedByWorkerShutdown: true, + }, nil, + ) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + resp, err := wh.PollNexusTaskQueue(ctx, &workflowservice.PollNexusTaskQueueRequest{ + Namespace: "test-namespace", + TaskQueue: &taskqueuepb.TaskQueue{Name: "test-tq"}, + }) + s.NoError(err) + s.True(resp.GetCompletedByWorkerShutdown()) +} + func (s *WorkflowHandlerSuite) TestStartWorkflowExecution_Failed_StartRequestNotSet() { config := s.newConfig() config.RPS = dc.GetIntPropertyFn(10) diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index 6eab8ff798c..3caa9681f25 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -107,6 +107,12 @@ type ( workerInstanceKey string } + pollResult struct { + task *internalTask + versionSetUsed bool + completedByShutdown 
bool + } + userDataUpdate struct { taskQueue string update persistence.SingleTaskQueueUserDataUpdate @@ -687,13 +693,19 @@ pollLoop: conditions: req.Conditions, workerInstanceKey: request.WorkerInstanceKey, } - task, versionSetUsed, err := e.pollTask(pollerCtx, partition, pollMetadata) + result, err := e.pollTask(pollerCtx, partition, pollMetadata) if err != nil { if errors.Is(err, errNoTasks) { + if result.completedByShutdown { + return &matchingservice.PollWorkflowTaskQueueResponseWithRawHistory{ + CompletedByWorkerShutdown: true, + }, nil + } return emptyPollWorkflowTaskQueueResponse, nil } return nil, err } + task := result.task if task.isStarted() { // tasks received from remote are already started. So, simply forward the response // no need to emit task dispatch latency metric because the parent partition already did it. @@ -750,7 +762,7 @@ pollLoop: } requestClone := request - if versionSetUsed { + if result.versionSetUsed { // We remove build ID from workerVersionCapabilities so History can differentiate between // old and new versioning in Record*TaskStart. // TODO: remove this block after old versioning cleanup. [cleanup-old-wv] @@ -959,20 +971,26 @@ pollLoop: conditions: req.Conditions, workerInstanceKey: request.WorkerInstanceKey, } - task, versionSetUsed, err := e.pollTask(pollerCtx, partition, pollMetadata) + result, err := e.pollTask(pollerCtx, partition, pollMetadata) if err != nil { if errors.Is(err, errNoTasks) { + if result.completedByShutdown { + return &matchingservice.PollActivityTaskQueueResponse{ + CompletedByWorkerShutdown: true, + }, nil + } return emptyPollActivityTaskQueueResponse, nil } return nil, err } + task := result.task if task.isStarted() { // tasks received from remote are already started. 
So, simply forward the response return task.pollActivityTaskQueueResponse(), nil } requestClone := request - if versionSetUsed { + if result.versionSetUsed { // We remove build ID from workerVersionCapabilities so History can differentiate between // old and new versioning in Record*TaskStart. // TODO: remove this block after old versioning cleanup. [cleanup-old-wv] @@ -2550,14 +2568,20 @@ pollLoop: conditions: req.Conditions, workerInstanceKey: request.WorkerInstanceKey, } - task, _, err := e.pollTask(pollerCtx, partition, pollMetadata) + result, err := e.pollTask(pollerCtx, partition, pollMetadata) if err != nil { if errors.Is(err, errNoTasks) { + if result.completedByShutdown { + return &matchingservice.PollNexusTaskQueueResponse{ + CompletedByWorkerShutdown: true, + }, nil + } return &matchingservice.PollNexusTaskQueueResponse{}, nil } return nil, err } + task := result.task if task.isStarted() { // tasks received from remote are already started. So, simply forward the response return task.pollNexusTaskQueueResponse(), nil @@ -2807,10 +2831,10 @@ func (e *matchingEngineImpl) pollTask( ctx context.Context, partition tqid.Partition, pollMetadata *pollMetadata, -) (*internalTask, bool, error) { +) (pollResult, error) { pm, _, err := e.getTaskQueuePartitionManager(ctx, partition, true, loadCausePoll) if err != nil { - return nil, false, err + return pollResult{}, err } pollMetadata.localPollStartTime = e.timeSource.Now() @@ -2827,7 +2851,7 @@ func (e *matchingEngineImpl) pollTask( tag.WorkflowTaskQueueType(partition.TaskType()), tag.NewStringTag("worker-instance-key", workerInstanceKey), ) - return nil, false, errNoTasks + return pollResult{completedByShutdown: true}, errNoTasks } ctx, cancel := contextutil.WithDeadlineBuffer(ctx, pm.LongPollExpirationInterval(), returnEmptyTaskTimeBudget) @@ -2840,6 +2864,9 @@ func (e *matchingEngineImpl) pollTask( // Use UUID (not pollerID) because pollerID is reused when forwarded. 
pollerTrackerKey := uuid.NewString() if workerInstanceKey != "" { + if e.shutdownWorkers.Get(workerInstanceKey) != nil { + return pollResult{completedByShutdown: true}, errNoTasks + } e.workerInstancePollers.Add(workerInstanceKey, pollerTrackerKey, cancel) } @@ -2850,7 +2877,8 @@ func (e *matchingEngineImpl) pollTask( } }() } - return pm.PollTask(ctx, pollMetadata) + task, versionSetUsed, err := pm.PollTask(ctx, pollMetadata) + return pollResult{task: task, versionSetUsed: versionSetUsed}, err } // emitTaskDispatchLatency emits latency metrics for a task dispatched to a worker. diff --git a/service/matching/matching_engine_test.go b/service/matching/matching_engine_test.go index 612797e1c0c..6759c5f10fc 100644 --- a/service/matching/matching_engine_test.go +++ b/service/matching/matching_engine_test.go @@ -2154,14 +2154,16 @@ func (s *matchingEngineSuite) TestAddTaskAfterStartFailure() { s.NoError(err) s.EqualValues(1, s.taskManager.getTaskCount(dbq)) - task1, _, err := s.matchingEngine.pollTask(context.Background(), dbq.partition, &pollMetadata{}) + result1, err := s.matchingEngine.pollTask(context.Background(), dbq.partition, &pollMetadata{}) s.NoError(err) + task1 := result1.task task1.finish(serviceerror.NewInternal("test error"), true) s.EqualValues(1, s.taskManager.getTaskCount(dbq)) - task2, _, err := s.matchingEngine.pollTask(context.Background(), dbq.partition, &pollMetadata{}) + result2, err := s.matchingEngine.pollTask(context.Background(), dbq.partition, &pollMetadata{}) s.NoError(err) + task2 := result2.task protoassert.ProtoEqual(s.T(), task1.event.Data, task2.event.Data) s.NotEqual(task1.event.GetTaskId(), task2.event.GetTaskId(), "IDs should not match") @@ -2800,14 +2802,14 @@ func (s *matchingEngineSuite) TestUnknownBuildId_Match() { go func() { prtn := newRootPartition(namespaceID, tq, enumspb.TASK_QUEUE_TYPE_WORKFLOW) - task, _, err := s.matchingEngine.pollTask(ctx, prtn, &pollMetadata{ + result, err := s.matchingEngine.pollTask(ctx, prtn, 
&pollMetadata{ workerVersionCapabilities: &commonpb.WorkerVersionCapabilities{ BuildId: "unknown", UseVersioning: true, }, }) s.NoError(err) - s.Equal("wf", task.event.Data.WorkflowId) + s.Equal("wf", result.task.event.Data.WorkflowId) - s.Equal(int64(123), task.event.Data.ScheduledEventId) - task.finish(nil, true) + s.Equal(int64(123), result.task.event.Data.ScheduledEventId) + result.task.finish(nil, true) wg.Done() }() @@ -2909,16 +2911,16 @@ func (s *matchingEngineSuite) TestDemotedMatch() { s.NoError(err) // now poll for the task - task, _, err := s.matchingEngine.pollTask(ctx, prtn, &pollMetadata{ + result, err := s.matchingEngine.pollTask(ctx, prtn, &pollMetadata{ workerVersionCapabilities: &commonpb.WorkerVersionCapabilities{ BuildId: build1, UseVersioning: true, }, }) s.Require().NoError(err) - s.Equal("wf", task.event.Data.WorkflowId) - s.Equal(int64(123), task.event.Data.ScheduledEventId) - task.finish(nil, true) + s.Equal("wf", result.task.event.Data.WorkflowId) + s.Equal(int64(123), result.task.event.Data.ScheduledEventId) + result.task.finish(nil, true) } type mockRoutingMatchingClient struct { diff --git a/tests/task_queue_test.go b/tests/task_queue_test.go index 628d2f4f83b..335a33d93f7 100644 --- a/tests/task_queue_test.go +++ b/tests/task_queue_test.go @@ -1500,10 +1500,7 @@ func (s *TaskQueueSuite) TestShutdownWorkerCancelsOutstandingPolls() { s.NoError(err) s.NotNil(rePollResp) s.Empty(rePollResp.GetTaskToken(), "re-poll from shutdown worker should return empty response") - // TODO: Replace timing assertion with an explicit poll response field indicating - // shutdown rejection, so we don't rely on timing to distinguish cache rejection - // from natural poll timeout. Requires adding a field to PollWorkflowTaskQueueResponse - // and PollActivityTaskQueueResponse in the public API proto. 
+ s.True(rePollResp.GetCompletedByWorkerShutdown(), "workflow re-poll should have CompletedByWorkerShutdown set") s.Less(time.Since(wfStart), 2*time.Minute, "workflow re-poll should be rejected quickly, not wait for timeout") // Activity poll should also be rejected immediately. @@ -1519,5 +1516,6 @@ func (s *TaskQueueSuite) TestShutdownWorkerCancelsOutstandingPolls() { s.NoError(err) s.NotNil(actResp) s.Empty(actResp.GetTaskToken(), "activity re-poll from shutdown worker should return empty response") + s.True(actResp.GetCompletedByWorkerShutdown(), "activity re-poll should have CompletedByWorkerShutdown set") s.Less(time.Since(actStart), 2*time.Minute, "activity re-poll should be rejected quickly, not wait for timeout") }