diff --git a/go.mod b/go.mod index 2b47cf567..dcdd8614a 100644 --- a/go.mod +++ b/go.mod @@ -33,3 +33,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace go.temporal.io/api => ../temporal-api-go diff --git a/go.sum b/go.sum index fd9193a8e..eccdaab65 100644 --- a/go.sum +++ b/go.sum @@ -35,8 +35,6 @@ github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -go.temporal.io/api v1.62.5 h1:9R/9CeyM7xqHSlsNt+QIvapQLcRxCZ38bnXQx4mCN6I= -go.temporal.io/api v1.62.5/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/internal/internal_nexus_task_handler.go b/internal/internal_nexus_task_handler.go index 0d1ff8ae6..3fbfd4b42 100644 --- a/internal/internal_nexus_task_handler.go +++ b/internal/internal_nexus_task_handler.go @@ -94,9 +94,10 @@ func newNexusTaskHandler( func (h *nexusTaskHandler) Execute(task *workflowservice.PollNexusTaskQueueResponse) (*workflowservice.RespondNexusTaskCompletedRequest, *workflowservice.RespondNexusTaskFailedRequest, error) { failureReasonSupport := getEffectiveTemporalFailureResponses(task.GetRequest().GetCapabilities().GetTemporalFailureResponses()) + pollerGroupId := task.GetPollerGroupId() nctx, handlerErr := h.newNexusOperationContext(task) if handlerErr != nil { - failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport) + failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport, pollerGroupId) if err != nil { return nil, nil, err } @@ -107,13 +108,13 @@ func (h *nexusTaskHandler) Execute(task *workflowservice.PollNexusTaskQueueRespo return nil, nil, err } if handlerErr != nil { - failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport) + failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport, pollerGroupId) if err != nil { return nil, nil, err } return nil, failureRequest, nil } - completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport) + completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport, pollerGroupId) if err != nil { return nil, nil, err } @@ -122,18 +123,19 @@ func (h *nexusTaskHandler) Execute(task *workflowservice.PollNexusTaskQueueRespo func (h *nexusTaskHandler) ExecuteContext(nctx *NexusOperationContext, task *workflowservice.PollNexusTaskQueueResponse) (*workflowservice.RespondNexusTaskCompletedRequest, *workflowservice.RespondNexusTaskFailedRequest, error) { failureReasonSupport := getEffectiveTemporalFailureResponses(task.GetRequest().GetCapabilities().GetTemporalFailureResponses()) + pollerGroupId := task.GetPollerGroupId() res, handlerErr, err := h.execute(nctx, task) if err != nil { return nil, nil, err } if handlerErr != nil { - failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport) + failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport, pollerGroupId) if err != nil { return nil, nil, err } return nil, failureRequest, nil } - completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport) + completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport, pollerGroupId) if err != nil { return nil, nil, err } @@ -458,7 +460,7 @@ func (h *nexusTaskHandler) newNexusOperationContext(response *workflowservice.Po }, nil } -func (h *nexusTaskHandler) fillInCompletion(taskToken []byte, res *nexuspb.Response, failureReasonSupport bool) (*workflowservice.RespondNexusTaskCompletedRequest, error) { +func (h *nexusTaskHandler) fillInCompletion(taskToken []byte, res *nexuspb.Response, failureReasonSupport bool, pollerGroupId string) (*workflowservice.RespondNexusTaskCompletedRequest, error) { // Handle conversion of Failure to OperationError for backwards compatibility with old servers. if res.GetStartOperation().GetFailure() != nil && !failureReasonSupport { // Convert to operation error for backwards compatibility. @@ -487,18 +489,20 @@ func (h *nexusTaskHandler) fillInCompletion(taskToken []byte, res *nexuspb.Respo } } return &workflowservice.RespondNexusTaskCompletedRequest{ - Identity: h.identity, - Namespace: h.namespace, - TaskToken: taskToken, - Response: res, + Identity: h.identity, + Namespace: h.namespace, + TaskToken: taskToken, + Response: res, + PollerGroupId: pollerGroupId, }, nil } -func (h *nexusTaskHandler) fillInFailure(taskToken []byte, handlerError *nexus.HandlerError, failureReasonSupport bool) (*workflowservice.RespondNexusTaskFailedRequest, error) { +func (h *nexusTaskHandler) fillInFailure(taskToken []byte, handlerError *nexus.HandlerError, failureReasonSupport bool, pollerGroupId string) (*workflowservice.RespondNexusTaskFailedRequest, error) { r := &workflowservice.RespondNexusTaskFailedRequest{ - Identity: h.identity, - Namespace: h.namespace, - TaskToken: taskToken, + Identity: h.identity, + Namespace: h.namespace, + TaskToken: taskToken, + PollerGroupId: pollerGroupId, } if failureReasonSupport { r.Failure = h.failureConverter.ErrorToFailure(handlerError) diff --git a/internal/internal_nexus_task_poller.go b/internal/internal_nexus_task_poller.go index 0f1fc6922..6b6be850d 100644 --- a/internal/internal_nexus_task_poller.go +++ b/internal/internal_nexus_task_poller.go @@ -15,13 +15,14 @@ import ( type nexusTaskPoller struct { basePoller - namespace string - taskQueueName string - identity string - service workflowservice.WorkflowServiceClient - taskHandler *nexusTaskHandler - logger log.Logger - numPollerMetric *numPollerMetric + namespace string + taskQueueName string + identity string + service workflowservice.WorkflowServiceClient + taskHandler *nexusTaskHandler + logger log.Logger + numPollerMetric *numPollerMetric + pollerGroupTracker *pollerGroupTracker } type nexusTask struct { @@ -53,7 +54,8 @@ func newNexusTaskPoller( taskQueueName: params.TaskQueue, identity: params.Identity, logger: params.Logger, - numPollerMetric: newNumPollerMetric(params.MetricsHandler, metrics.PollerTypeNexusTask), + numPollerMetric: newNumPollerMetric(params.MetricsHandler, metrics.PollerTypeNexusTask), + pollerGroupTracker: newPollerGroupTracker(), } } @@ -69,6 +71,10 @@ func (ntp *nexusTaskPoller) poll(ctx context.Context) (taskForWorker, error) { traceLog(func() { ntp.logger.Debug("nexusTaskPoller::Poll") }) + + groupId := ntp.pollerGroupTracker.getNextGroupId() + defer ntp.pollerGroupTracker.release(groupId) + request := &workflowservice.PollNexusTaskQueueRequest{ Namespace: ntp.namespace, TaskQueue: &taskqueuepb.TaskQueue{Name: ntp.taskQueueName, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, @@ -83,12 +89,14 @@ func (ntp *nexusTaskPoller) poll(ctx context.Context) (taskForWorker, error) { ntp.workerDeploymentVersion, ), WorkerInstanceKey: ntp.workerInstanceKey, + PollerGroupId: groupId, } response, err := ntp.pollNexusTaskQueue(ctx, request) if err != nil { return nil, err } + ntp.pollerGroupTracker.updateGroups(response.GetPollerGroupInfos()) if response == nil || len(response.TaskToken) == 0 { // No operation info is available on empty poll. Emit using base scope. ntp.metricsHandler.Counter(metrics.NexusPollNoTaskCounter).Inc(1) @@ -131,7 +139,7 @@ func (ntp *nexusTaskPoller) ProcessTask(task interface{}) error { nctx, handlerErr := ntp.taskHandler.newNexusOperationContext(response) if handlerErr != nil { // context wasn't propagated to us, use a background context. - failedRequest, err := ntp.taskHandler.fillInFailure(response.TaskToken, handlerErr, getEffectiveTemporalFailureResponses(response.GetRequest().GetCapabilities().GetTemporalFailureResponses())) + failedRequest, err := ntp.taskHandler.fillInFailure(response.TaskToken, handlerErr, getEffectiveTemporalFailureResponses(response.GetRequest().GetCapabilities().GetTemporalFailureResponses()), response.GetPollerGroupId()) if err != nil { return err } diff --git a/internal/internal_task_handlers.go b/internal/internal_task_handlers.go index 5979ad95a..a9558c8cf 100644 --- a/internal/internal_task_handlers.go +++ b/internal/internal_task_handlers.go @@ -1834,8 +1834,9 @@ func (wth *workflowTaskHandlerImpl) completeWorkflow( // for query task if task.Query != nil { queryCompletedRequest := &workflowservice.RespondQueryTaskCompletedRequest{ - TaskToken: task.TaskToken, - Namespace: wth.namespace, + TaskToken: task.TaskToken, + Namespace: wth.namespace, + PollerGroupId: task.GetPollerGroupId(), } var panicErr *PanicError if errors.As(workflowContext.err, &panicErr) { diff --git a/internal/internal_task_pollers.go b/internal/internal_task_pollers.go index bc148edf9..7388c1406 100644 --- a/internal/internal_task_pollers.go +++ b/internal/internal_task_pollers.go @@ -122,6 +122,8 @@ type ( numNormalPollerMetric *numPollerMetric numStickyPollerMetric *numPollerMetric + pollerGroupTracker *pollerGroupTracker + inboundPayloadVisitor PayloadVisitor } @@ -155,7 +157,7 @@ type ( outboundPayloadVisitor PayloadVisitor } - // activityTaskPoller implements polling/processing a workflow task + // activityTaskPoller implements polling/processing an activity task activityTaskPoller struct { basePoller namespace string @@ -166,6 +168,7 @@ type ( logger log.Logger activitiesPerSecond float64 numPollerMetric *numPollerMetric + pollerGroupTracker *pollerGroupTracker inboundPayloadVisitor PayloadVisitor outboundPayloadVisitor PayloadVisitor } @@ -424,6 +427,7 @@ func (wtp *workflowTaskProcessor) createPoller(mode workflowTaskPollerMode) task eagerActivityExecutor: wtp.eagerActivityExecutor, numNormalPollerMetric: wtp.numNormalPollerMetric, numStickyPollerMetric: wtp.numStickyPollerMetric, + pollerGroupTracker: newPollerGroupTracker(), inboundPayloadVisitor: wtp.inboundPayloadVisitor, } } @@ -741,6 +745,7 @@ func (wtp *workflowTaskProcessor) reportGrpcMessageTooLarge( Namespace: wtp.namespace, Failure: wtp.failureConverter.ErrorToFailure(sendErr), Cause: enumspb.WORKFLOW_TASK_FAILED_CAUSE_GRPC_MESSAGE_TOO_LARGE, + PollerGroupId: task.GetPollerGroupId(), } if err = visitProtoPayloads(ctx, wtp.outboundPayloadVisitor, request); err != nil { wtp.logger.Error("Failed to visit payloads for GRPC message too large query failure response.", tagError, err) @@ -1105,7 +1110,11 @@ func (wtp *workflowTaskPoller) poll(ctx context.Context) (taskForWorker, error) wtp.logger.Debug("workflowTaskPoller::Poll") }) + groupId := wtp.pollerGroupTracker.getNextGroupId() + defer wtp.pollerGroupTracker.release(groupId) + request := wtp.getNextPollRequest() + request.PollerGroupId = groupId defer wtp.release(request.TaskQueue.GetKind()) response, err := wtp.pollWorkflowTaskQueue(ctx, request) @@ -1113,6 +1122,7 @@ func (wtp *workflowTaskPoller) poll(ctx context.Context) (taskForWorker, error) wtp.updateBacklog(request.TaskQueue.GetKind(), 0) return nil, err } + wtp.pollerGroupTracker.updateGroups(response.GetPollerGroupInfos()) if response == nil || len(response.TaskToken) == 0 { // Emit using base scope as no workflow type information is available in the case of empty poll @@ -1302,6 +1312,7 @@ func newActivityTaskPoller(taskHandler ActivityTaskHandler, service workflowserv logger: params.Logger, activitiesPerSecond: params.TaskQueueActivitiesPerSecond, numPollerMetric: newNumPollerMetric(params.MetricsHandler, metrics.PollerTypeActivityTask), + pollerGroupTracker: newPollerGroupTracker(), inboundPayloadVisitor: params.inboundPayloadVisitor, outboundPayloadVisitor: params.outboundPayloadVisitor, } @@ -1320,6 +1331,10 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (taskForWorker, error) traceLog(func() { atp.logger.Debug("activityTaskPoller::Poll") }) + + groupId := atp.pollerGroupTracker.getNextGroupId() + defer atp.pollerGroupTracker.release(groupId) + request := &workflowservice.PollActivityTaskQueueRequest{ Namespace: atp.namespace, TaskQueue: &taskqueuepb.TaskQueue{Name: atp.taskQueueName, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, @@ -1335,12 +1350,14 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (taskForWorker, error) atp.workerDeploymentVersion, ), WorkerInstanceKey: atp.workerInstanceKey, + PollerGroupId: groupId, } response, err := atp.pollActivityTaskQueue(ctx, request) if err != nil { return nil, err } + atp.pollerGroupTracker.updateGroups(response.GetPollerGroupInfos()) if response == nil || len(response.TaskToken) == 0 { // No activity info is available on empty poll. Emit using base scope. atp.metricsHandler.Counter(metrics.ActivityPollNoTaskCounter).Inc(1) diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index 032cc17aa..224becbb5 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -2731,7 +2731,7 @@ func (env *testWorkflowEnvironmentImpl) ExecuteNexusOperation( response, failure, err := taskHandler.Execute(task) if err != nil { // No retries for operations, fail the operation immediately. - failure, err = taskHandler.fillInFailure(task.TaskToken, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, err.Error()), false) + failure, err = taskHandler.fillInFailure(task.TaskToken, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, err.Error()), false, "") } if failure != nil { // Convert to a nexus HandlerError first to simulate the flow in the server. diff --git a/internal/poller_group_id_test.go b/internal/poller_group_id_test.go new file mode 100644 index 000000000..476107169 --- /dev/null +++ b/internal/poller_group_id_test.go @@ -0,0 +1,397 @@ +package internal + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/nexus-rpc/sdk-go/nexus" + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" + enumspb "go.temporal.io/api/enums/v1" + historypb "go.temporal.io/api/history/v1" + nexuspb "go.temporal.io/api/nexus/v1" + taskqueuepb "go.temporal.io/api/taskqueue/v1" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/api/workflowservicemock/v1" + "google.golang.org/grpc" +) + +func TestActivityPoll_SetsPollerGroupIdAndUpdatesTracker(t *testing.T) { + ctrl := gomock.NewController(t) + service := workflowservicemock.NewMockWorkflowServiceClient(ctrl) + + params := workerExecutionParameters{ + Namespace: "test-ns", + TaskQueue: "test-tq", + cache: NewWorkerCache(), + } + ensureRequiredParams(¶ms) + + poller := newActivityTaskPoller(&noopActivityTaskHandler{}, service, params) + + // First poll: no groups known yet, so PollerGroupId should be empty. + // Server returns groups in the response. + service.EXPECT().PollActivityTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + _ context.Context, + req *workflowservice.PollActivityTaskQueueRequest, + _ ...grpc.CallOption, + ) (*workflowservice.PollActivityTaskQueueResponse, error) { + require.Empty(t, req.PollerGroupId, "first poll should have empty poller group id") + return &workflowservice.PollActivityTaskQueueResponse{ + TaskToken: []byte("token"), + PollerGroupInfos: []*taskqueuepb.PollerGroupInfo{ + {Id: "group-1", Weight: 1.0}, + {Id: "group-2", Weight: 1.0}, + }, + }, nil + }) + + ctx := context.Background() + _, err := poller.poll(ctx) + require.NoError(t, err) + + // Second poll: tracker now has groups, so PollerGroupId should be set. + service.EXPECT().PollActivityTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + _ context.Context, + req *workflowservice.PollActivityTaskQueueRequest, + _ ...grpc.CallOption, + ) (*workflowservice.PollActivityTaskQueueResponse, error) { + require.NotEmpty(t, req.PollerGroupId, "second poll should have a poller group id") + require.Contains(t, []string{"group-1", "group-2"}, req.PollerGroupId) + return &workflowservice.PollActivityTaskQueueResponse{}, nil + }) + + _, err = poller.poll(ctx) + require.NoError(t, err) +} + +type noopActivityTaskHandler struct{} + +func (h *noopActivityTaskHandler) Execute(string, *workflowservice.PollActivityTaskQueueResponse) (interface{}, error) { + return nil, nil +} + +func TestWorkflowPoll_SetsPollerGroupIdAndUpdatesTracker(t *testing.T) { + ctrl := gomock.NewController(t) + service := workflowservicemock.NewMockWorkflowServiceClient(ctrl) + + params := workerExecutionParameters{ + Namespace: "test-ns", + TaskQueue: "test-tq", + cache: NewWorkerCache(), + } + ensureRequiredParams(¶ms) + + processor := newWorkflowTaskProcessor( + newWorkflowTaskHandler(params, nil, newRegistry()), + nil, + service, + params, + "sticky-uuid", + ) + poller := processor.createPoller(NonSticky).(*workflowTaskPoller) + + // First poll: no groups known yet. + service.EXPECT().PollWorkflowTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + _ context.Context, + req *workflowservice.PollWorkflowTaskQueueRequest, + _ ...grpc.CallOption, + ) (*workflowservice.PollWorkflowTaskQueueResponse, error) { + require.Empty(t, req.PollerGroupId, "first poll should have empty poller group id") + return &workflowservice.PollWorkflowTaskQueueResponse{ + PollerGroupInfos: []*taskqueuepb.PollerGroupInfo{ + {Id: "wf-group-a", Weight: 1.0}, + }, + }, nil + }) + + ctx := context.Background() + _, err := poller.poll(ctx) + require.NoError(t, err) + + // Second poll: tracker has groups. + service.EXPECT().PollWorkflowTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + _ context.Context, + req *workflowservice.PollWorkflowTaskQueueRequest, + _ ...grpc.CallOption, + ) (*workflowservice.PollWorkflowTaskQueueResponse, error) { + require.Equal(t, "wf-group-a", req.PollerGroupId) + return &workflowservice.PollWorkflowTaskQueueResponse{}, nil + }) + + _, err = poller.poll(ctx) + require.NoError(t, err) +} + +func TestNexusPoll_SetsPollerGroupIdAndUpdatesTracker(t *testing.T) { + ctrl := gomock.NewController(t) + service := workflowservicemock.NewMockWorkflowServiceClient(ctrl) + + params := workerExecutionParameters{ + Namespace: "test-ns", + TaskQueue: "test-tq", + cache: NewWorkerCache(), + } + ensureRequiredParams(¶ms) + + poller := newNexusTaskPoller(nil, service, params) + + // First poll: no groups known yet. + service.EXPECT().PollNexusTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + _ context.Context, + req *workflowservice.PollNexusTaskQueueRequest, + _ ...grpc.CallOption, + ) (*workflowservice.PollNexusTaskQueueResponse, error) { + require.Empty(t, req.PollerGroupId, "first poll should have empty poller group id") + return &workflowservice.PollNexusTaskQueueResponse{ + PollerGroupInfos: []*taskqueuepb.PollerGroupInfo{ + {Id: "nexus-group-x", Weight: 1.0}, + }, + }, nil + }) + + ctx := context.Background() + _, err := poller.poll(ctx) + require.NoError(t, err) + + // Second poll: tracker has groups. + service.EXPECT().PollNexusTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + _ context.Context, + req *workflowservice.PollNexusTaskQueueRequest, + _ ...grpc.CallOption, + ) (*workflowservice.PollNexusTaskQueueResponse, error) { + require.Equal(t, "nexus-group-x", req.PollerGroupId) + return &workflowservice.PollNexusTaskQueueResponse{}, nil + }) + + _, err = poller.poll(ctx) + require.NoError(t, err) +} + +func TestQueryResponse_ForwardsPollerGroupId(t *testing.T) { + taskQueue := "tq1" + testEvents := []*historypb.HistoryEvent{ + createTestEventWorkflowExecutionStarted(1, &historypb.WorkflowExecutionStartedEventAttributes{ + TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue}, + }), + createTestEventWorkflowTaskScheduled(2, &historypb.WorkflowTaskScheduledEventAttributes{ + TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue}, + }), + createTestEventWorkflowTaskStarted(3), + } + + task := createQueryTask(testEvents, 3, "HelloWorld_Workflow", queryType) + task.PollerGroupId = "test-poller-group-42" + + params := workerExecutionParameters{ + Namespace: "test-ns", + TaskQueue: taskQueue, + cache: NewWorkerCache(), + } + ensureRequiredParams(¶ms) + + reg := newRegistry() + reg.RegisterWorkflowWithOptions(helloWorldWorkflowFunc, RegisterWorkflowOptions{Name: "HelloWorld_Workflow"}) + + taskHandler := newWorkflowTaskHandler(params, nil, reg) + wftask := workflowTask{task: task} + + wfctx, err := taskHandler.GetOrCreateWorkflowContext(task, wftask.historyIterator) + require.NoError(t, err) + + response, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock(err) + require.NoError(t, err) + require.NotNil(t, response) + + queryResp, ok := response.rawRequest.(*workflowservice.RespondQueryTaskCompletedRequest) + require.True(t, ok) + require.Equal(t, "test-poller-group-42", queryResp.PollerGroupId) +} + +func TestQueryResponse_ForwardsPollerGroupIdOnError(t *testing.T) { + taskQueue := "tq1" + testEvents := []*historypb.HistoryEvent{ + createTestEventWorkflowExecutionStarted(1, &historypb.WorkflowExecutionStartedEventAttributes{ + TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue}, + }), + createTestEventWorkflowTaskScheduled(2, &historypb.WorkflowTaskScheduledEventAttributes{ + TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue}, + }), + createTestEventWorkflowTaskStarted(3), + } + + task := createQueryTask(testEvents, 3, "HelloWorld_Workflow", "nonexistent-query-type") + task.PollerGroupId = "test-poller-group-err" + + params := workerExecutionParameters{ + Namespace: "test-ns", + TaskQueue: taskQueue, + cache: NewWorkerCache(), + } + ensureRequiredParams(¶ms) + + reg := newRegistry() + reg.RegisterWorkflowWithOptions(helloWorldWorkflowFunc, RegisterWorkflowOptions{Name: "HelloWorld_Workflow"}) + + taskHandler := newWorkflowTaskHandler(params, nil, reg) + wftask := workflowTask{task: task} + + wfctx, err := taskHandler.GetOrCreateWorkflowContext(task, wftask.historyIterator) + require.NoError(t, err) + + response, err := taskHandler.ProcessWorkflowTask(&wftask, wfctx, nil) + wfctx.Unlock(err) + require.NoError(t, err) + require.NotNil(t, response) + + queryResp, ok := response.rawRequest.(*workflowservice.RespondQueryTaskCompletedRequest) + require.True(t, ok) + require.Equal(t, enumspb.QUERY_RESULT_TYPE_FAILED, queryResp.CompletedType) + require.Equal(t, "test-poller-group-err", queryResp.PollerGroupId) +} + +// noopNexusHandler is a minimal nexus.Handler that succeeds on CancelOperation. +type noopNexusHandler struct { + nexus.UnimplementedHandler +} + +func (h *noopNexusHandler) CancelOperation(_ context.Context, service, operation, token string, _ nexus.CancelOperationOptions) error { + return nil +} + +func TestNexusCompletion_ForwardsPollerGroupId(t *testing.T) { + params := workerExecutionParameters{ + Namespace: "test-ns", + TaskQueue: "test-tq", + cache: NewWorkerCache(), + } + ensureRequiredParams(¶ms) + + handler := newNexusTaskHandler( + &noopNexusHandler{}, + params.Identity, + params.Namespace, + params.TaskQueue, + nil, + params.DataConverter, + params.FailureConverter, + params.Logger, + params.MetricsHandler, + nil, + ) + + // Call Execute with a CancelOperation request. The noopNexusHandler succeeds, + // so we get a completion. Verify PollerGroupId is forwarded from the response. + task := &workflowservice.PollNexusTaskQueueResponse{ + TaskToken: []byte("token"), + PollerGroupId: "nexus-pg-complete", + Request: &nexuspb.Request{ + Variant: &nexuspb.Request_CancelOperation{ + CancelOperation: &nexuspb.CancelOperationRequest{ + Service: "test-service", + Operation: "test-op", + }, + }, + }, + } + + completedReq, failedReq, err := handler.Execute(task) + require.NoError(t, err) + require.Nil(t, failedReq) + require.NotNil(t, completedReq) + require.Equal(t, "nexus-pg-complete", completedReq.PollerGroupId) +} + +func TestNexusFailure_ForwardsPollerGroupId(t *testing.T) { + params := workerExecutionParameters{ + Namespace: "test-ns", + TaskQueue: "test-tq", + cache: NewWorkerCache(), + } + ensureRequiredParams(¶ms) + + handler := newNexusTaskHandler( + nil, // no nexus handler needed — nil request variant triggers failure + params.Identity, + params.Namespace, + params.TaskQueue, + nil, + params.DataConverter, + params.FailureConverter, + params.Logger, + params.MetricsHandler, + nil, + ) + + // Call Execute with no valid request variant. newNexusOperationContext returns + // a handler error, which goes through fillInFailure. Verify PollerGroupId is + // forwarded from the poll response. + task := &workflowservice.PollNexusTaskQueueResponse{ + TaskToken: []byte("token"), + PollerGroupId: "nexus-pg-fail", + } + + completedReq, failedReq, err := handler.Execute(task) + require.NoError(t, err) + require.Nil(t, completedReq) + require.NotNil(t, failedReq) + require.Equal(t, "nexus-pg-fail", failedReq.PollerGroupId) +} + +func TestGrpcTooLargeQueryResponse_ForwardsPollerGroupId(t *testing.T) { + ctrl := gomock.NewController(t) + service := workflowservicemock.NewMockWorkflowServiceClient(ctrl) + + params := workerExecutionParameters{ + Namespace: "test-ns", + TaskQueue: "test-tq", + cache: NewWorkerCache(), + } + ensureRequiredParams(¶ms) + + processor := newWorkflowTaskProcessor( + newWorkflowTaskHandler(params, nil, newRegistry()), + nil, + service, + params, + "sticky-uuid", + ) + + task := &workflowservice.PollWorkflowTaskQueueResponse{ + TaskToken: []byte("token"), + PollerGroupId: "grpc-too-large-pg", + WorkflowType: &commonpb.WorkflowType{Name: "test-wf"}, + } + + // Capture the RespondQueryTaskCompleted request. + service.EXPECT().RespondQueryTaskCompleted(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func( + _ context.Context, + req *workflowservice.RespondQueryTaskCompletedRequest, + _ ...grpc.CallOption, + ) (*workflowservice.RespondQueryTaskCompletedResponse, error) { + require.Equal(t, "grpc-too-large-pg", req.PollerGroupId) + require.Equal(t, enumspb.QUERY_RESULT_TYPE_FAILED, req.CompletedType) + return &workflowservice.RespondQueryTaskCompletedResponse{}, nil + }) + + // Simulate the GRPC too large fallback path for a query task. + queryCompletion := &workflowTaskCompletion{ + rawRequest: &workflowservice.RespondQueryTaskCompletedRequest{}, + } + _, _ = processor.reportGrpcMessageTooLarge( + context.Background(), + queryCompletion, + task, + fmt.Errorf("grpc message too large"), + ) +} diff --git a/internal/poller_group_tracker.go b/internal/poller_group_tracker.go new file mode 100644 index 000000000..263c4b248 --- /dev/null +++ b/internal/poller_group_tracker.go @@ -0,0 +1,112 @@ +package internal + +import ( + "math/rand" + "sync" + + taskqueuepb "go.temporal.io/api/taskqueue/v1" +) + +// pollerGroupTracker distributes pollers across server-provided poller groups +// based on weights. Each call to getNextGroupId returns a group ID such that +// every group has at least one pending (unreleased) request, and beyond that +// minimum the distribution follows group weights. +type pollerGroupTracker struct { + mu sync.Mutex + groups []*taskqueuepb.PollerGroupInfo + pending map[string]int // number of unreleased requests per group ID +} + +func newPollerGroupTracker() *pollerGroupTracker { + return &pollerGroupTracker{ + pending: make(map[string]int), + } +} + +// updateGroups updates the available groups from a server response and +// cleans up pending entries for groups that no longer exist. +func (t *pollerGroupTracker) updateGroups(groups []*taskqueuepb.PollerGroupInfo) { + if len(groups) == 0 { + return + } + t.mu.Lock() + defer t.mu.Unlock() + t.groups = groups + // Remove pending entries for groups that no longer exist. + valid := make(map[string]bool, len(groups)) + for _, g := range groups { + valid[g.GetId()] = true + } + for id := range t.pending { + if !valid[id] { + delete(t.pending, id) + } + } +} + +// getNextGroupId returns the group ID that should be used for the next poll +// request. Groups with zero pending polls are prioritized as candidates. If all +// groups have pending polls, all groups become candidates. Among candidates, one +// is selected randomly weighted by group weights. +func (t *pollerGroupTracker) getNextGroupId() string { + t.mu.Lock() + defer t.mu.Unlock() + + if len(t.groups) == 0 { + return "" + } + + // Candidate set: groups with zero pending polls. + var candidates []*taskqueuepb.PollerGroupInfo + for _, g := range t.groups { + if t.pending[g.GetId()] == 0 { + candidates = append(candidates, g) + } + } + // If all groups have pending polls, all groups are candidates. + if len(candidates) == 0 { + candidates = t.groups + } + + chosen := weightedRandom(candidates) + t.pending[chosen]++ + return chosen +} + +// weightedRandom selects a group ID from candidates randomly based on weights. +// candidates must be non-empty. +func weightedRandom(candidates []*taskqueuepb.PollerGroupInfo) string { + if len(candidates) == 1 { + return candidates[0].GetId() + } + + var totalWeight float64 + for _, g := range candidates { + totalWeight += float64(g.GetWeight()) + } + + // If all weights are zero, pick uniformly at random. + if totalWeight <= 0 { + return candidates[rand.Intn(len(candidates))].GetId() + } + + r := rand.Float64() * totalWeight + for _, g := range candidates { + r -= float64(g.GetWeight()) + if r <= 0 { + return g.GetId() + } + } + // Floating-point rounding fallback. + return candidates[len(candidates)-1].GetId() +} + +// release marks one pending request for the given group as completed. +func (t *pollerGroupTracker) release(groupId string) { + t.mu.Lock() + defer t.mu.Unlock() + + if t.pending[groupId] > 0 { + t.pending[groupId]-- + } +} diff --git a/internal/poller_group_tracker_test.go b/internal/poller_group_tracker_test.go new file mode 100644 index 000000000..30b7c92ff --- /dev/null +++ b/internal/poller_group_tracker_test.go @@ -0,0 +1,201 @@ +package internal + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + taskqueuepb "go.temporal.io/api/taskqueue/v1" +) + +func makeGroups(ids ...string) []*taskqueuepb.PollerGroupInfo { + groups := make([]*taskqueuepb.PollerGroupInfo, len(ids)) + for i, id := range ids { + groups[i] = &taskqueuepb.PollerGroupInfo{Id: id, Weight: 1.0} + } + return groups +} + +func makeWeightedGroups(pairs ...interface{}) []*taskqueuepb.PollerGroupInfo { + var groups []*taskqueuepb.PollerGroupInfo + for i := 0; i < len(pairs); i += 2 { + groups = append(groups, &taskqueuepb.PollerGroupInfo{ + Id: pairs[i].(string), + Weight: pairs[i+1].(float32), + }) + } + return groups +} + +func TestPollerGroupTracker_NoGroups(t *testing.T) { + tracker := newPollerGroupTracker() + // With no groups, should return empty string. + assert.Equal(t, "", tracker.getNextGroupId()) +} + +func TestPollerGroupTracker_SingleGroup(t *testing.T) { + tracker := newPollerGroupTracker() + tracker.updateGroups(makeGroups("group-a")) + + id := tracker.getNextGroupId() + assert.Equal(t, "group-a", id) + + // Second call should also return group-a (only option). + id2 := tracker.getNextGroupId() + assert.Equal(t, "group-a", id2) + + tracker.release(id) + tracker.release(id2) +} + +func TestPollerGroupTracker_ZeroPendingThenFallback(t *testing.T) { + tracker := newPollerGroupTracker() + tracker.updateGroups(makeGroups("a", "b", "c")) + + // Each call must pick a different group while zero-pending groups remain. + id1 := tracker.getNextGroupId() + id2 := tracker.getNextGroupId() + id3 := tracker.getNextGroupId() + require.NotEqual(t, id1, id2) + require.NotEqual(t, id1, id3) + require.NotEqual(t, id2, id3) + + // All groups now have pending=1. Next call falls back to any group. + id4 := tracker.getNextGroupId() + assert.Contains(t, []string{"a", "b", "c"}, id4) + + tracker.release(id1) + tracker.release(id2) + tracker.release(id3) + tracker.release(id4) +} + +func TestPollerGroupTracker_Release(t *testing.T) { + tracker := newPollerGroupTracker() + tracker.updateGroups(makeGroups("a", "b")) + + id1 := tracker.getNextGroupId() + id2 := tracker.getNextGroupId() + // Both groups have pending=1. + + // Release one and get next: should pick the released one (zero pending). + tracker.release(id1) + id3 := tracker.getNextGroupId() + assert.Equal(t, id1, id3, "should prefer the group with zero pending after release") + + tracker.release(id2) + tracker.release(id3) +} + +func TestPollerGroupTracker_ReleaseFloorAtZero(t *testing.T) { + tracker := newPollerGroupTracker() + tracker.updateGroups(makeGroups("a")) + + // Release without ever getting should not panic or go negative. + tracker.release("a") + tracker.release("nonexistent") + + id := tracker.getNextGroupId() + assert.Equal(t, "a", id) + tracker.release(id) +} + +func TestPollerGroupTracker_UpdateGroupsCleansStale(t *testing.T) { + tracker := newPollerGroupTracker() + tracker.updateGroups(makeGroups("a", "b")) + + // Create pending for both. + id1 := tracker.getNextGroupId() + id2 := tracker.getNextGroupId() + + // Update to only have "b" and "c". + tracker.updateGroups(makeGroups("b", "c")) + + // Pending for "a" should be cleaned up. + tracker.mu.Lock() + _, hasPendingA := tracker.pending["a"] + tracker.mu.Unlock() + assert.False(t, hasPendingA, "pending for removed group should be cleaned up") + + tracker.release(id1) + tracker.release(id2) +} + +func TestPollerGroupTracker_UpdateGroupsEmptyNoOp(t *testing.T) { + tracker := newPollerGroupTracker() + tracker.updateGroups(makeGroups("a")) + + // Empty update should not clear groups. + tracker.updateGroups(nil) + + id := tracker.getNextGroupId() + assert.Equal(t, "a", id) + tracker.release(id) +} + +func TestWeightedRandom_SingleCandidate(t *testing.T) { + groups := makeGroups("only") + assert.Equal(t, "only", weightedRandom(groups)) +} + +func TestWeightedRandom_ZeroWeights(t *testing.T) { + groups := makeWeightedGroups("a", float32(0), "b", float32(0)) + // Should still return a valid group. + id := weightedRandom(groups) + assert.Contains(t, []string{"a", "b"}, id) +} + +func TestPollerGroupTracker_UpdateGroupsPreservesSurvivingPending(t *testing.T) { + tracker := newPollerGroupTracker() + tracker.updateGroups(makeGroups("a", "b", "c")) + + // Create pending for all three. + id1 := tracker.getNextGroupId() + id2 := tracker.getNextGroupId() + id3 := tracker.getNextGroupId() + + // Update groups: drop "a", keep "b" and "c", add "d". + tracker.updateGroups(makeGroups("b", "c", "d")) + + // "d" has zero pending, so it must be picked next. + id4 := tracker.getNextGroupId() + assert.Equal(t, "d", id4) + + // "b" and "c" still have pending=1 from before the update. + // Release one of them and verify it becomes preferred over the other. + // Find which of id1/id2/id3 was "b". + var bId string + for _, id := range []string{id1, id2, id3} { + if id == "b" { + bId = id + break + } + } + require.NotEmpty(t, bId, "b should have been selected in initial round") + tracker.release(bId) + tracker.release(id4) + + // Now "b" and "d" have zero pending, "c" has pending=1. Next pick must not be "c". + id5 := tracker.getNextGroupId() + assert.NotEqual(t, "c", id5, "should prefer zero-pending groups b or d over c") + + tracker.release(id1) + tracker.release(id2) + tracker.release(id3) + tracker.release(id5) +} + +func TestWeightedRandom_DistributionConverges(t *testing.T) { + groups := makeWeightedGroups("a", float32(3.0), "b", float32(1.0)) + + counts := map[string]int{} + iterations := 10000 + for i := 0; i < iterations; i++ { + counts[weightedRandom(groups)]++ + } + + // With weights 3:1, "a" should get ~75% of selections. + ratioA := float64(counts["a"]) / float64(iterations) + assert.InDelta(t, 0.75, ratioA, 0.05, "expected ~75%% for weight-3 group, got %.2f%%", ratioA*100) +} +