Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20240827150818-7e3bb234dfed // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

replace go.temporal.io/api => ../temporal-api-go
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
go.temporal.io/api v1.62.5 h1:9R/9CeyM7xqHSlsNt+QIvapQLcRxCZ38bnXQx4mCN6I=
go.temporal.io/api v1.62.5/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
Expand Down
32 changes: 18 additions & 14 deletions internal/internal_nexus_task_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ func newNexusTaskHandler(

func (h *nexusTaskHandler) Execute(task *workflowservice.PollNexusTaskQueueResponse) (*workflowservice.RespondNexusTaskCompletedRequest, *workflowservice.RespondNexusTaskFailedRequest, error) {
failureReasonSupport := getEffectiveTemporalFailureResponses(task.GetRequest().GetCapabilities().GetTemporalFailureResponses())
pollerGroupId := task.GetPollerGroupId()
nctx, handlerErr := h.newNexusOperationContext(task)
if handlerErr != nil {
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport)
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport, pollerGroupId)
if err != nil {
return nil, nil, err
}
Expand All @@ -107,13 +108,13 @@ func (h *nexusTaskHandler) Execute(task *workflowservice.PollNexusTaskQueueRespo
return nil, nil, err
}
if handlerErr != nil {
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport)
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport, pollerGroupId)
if err != nil {
return nil, nil, err
}
return nil, failureRequest, nil
}
completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport)
completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport, pollerGroupId)
if err != nil {
return nil, nil, err
}
Expand All @@ -122,18 +123,19 @@ func (h *nexusTaskHandler) Execute(task *workflowservice.PollNexusTaskQueueRespo

func (h *nexusTaskHandler) ExecuteContext(nctx *NexusOperationContext, task *workflowservice.PollNexusTaskQueueResponse) (*workflowservice.RespondNexusTaskCompletedRequest, *workflowservice.RespondNexusTaskFailedRequest, error) {
failureReasonSupport := getEffectiveTemporalFailureResponses(task.GetRequest().GetCapabilities().GetTemporalFailureResponses())
pollerGroupId := task.GetPollerGroupId()
res, handlerErr, err := h.execute(nctx, task)
if err != nil {
return nil, nil, err
}
if handlerErr != nil {
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport)
failureRequest, err := h.fillInFailure(task.TaskToken, handlerErr, failureReasonSupport, pollerGroupId)
if err != nil {
return nil, nil, err
}
return nil, failureRequest, nil
}
completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport)
completedRequest, err := h.fillInCompletion(task.TaskToken, res, failureReasonSupport, pollerGroupId)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -458,7 +460,7 @@ func (h *nexusTaskHandler) newNexusOperationContext(response *workflowservice.Po
}, nil
}

func (h *nexusTaskHandler) fillInCompletion(taskToken []byte, res *nexuspb.Response, failureReasonSupport bool) (*workflowservice.RespondNexusTaskCompletedRequest, error) {
func (h *nexusTaskHandler) fillInCompletion(taskToken []byte, res *nexuspb.Response, failureReasonSupport bool, pollerGroupId string) (*workflowservice.RespondNexusTaskCompletedRequest, error) {
// Handle conversion of Failure to OperationError for backwards compatibility with old servers.
if res.GetStartOperation().GetFailure() != nil && !failureReasonSupport {
// Convert to operation error for backwards compatibility.
Expand Down Expand Up @@ -487,18 +489,20 @@ func (h *nexusTaskHandler) fillInCompletion(taskToken []byte, res *nexuspb.Respo
}
}
return &workflowservice.RespondNexusTaskCompletedRequest{
Identity: h.identity,
Namespace: h.namespace,
TaskToken: taskToken,
Response: res,
Identity: h.identity,
Namespace: h.namespace,
TaskToken: taskToken,
Response: res,
PollerGroupId: pollerGroupId,
}, nil
}

func (h *nexusTaskHandler) fillInFailure(taskToken []byte, handlerError *nexus.HandlerError, failureReasonSupport bool) (*workflowservice.RespondNexusTaskFailedRequest, error) {
func (h *nexusTaskHandler) fillInFailure(taskToken []byte, handlerError *nexus.HandlerError, failureReasonSupport bool, pollerGroupId string) (*workflowservice.RespondNexusTaskFailedRequest, error) {
r := &workflowservice.RespondNexusTaskFailedRequest{
Identity: h.identity,
Namespace: h.namespace,
TaskToken: taskToken,
Identity: h.identity,
Namespace: h.namespace,
TaskToken: taskToken,
PollerGroupId: pollerGroupId,
}
if failureReasonSupport {
r.Failure = h.failureConverter.ErrorToFailure(handlerError)
Expand Down
26 changes: 17 additions & 9 deletions internal/internal_nexus_task_poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ import (

type nexusTaskPoller struct {
basePoller
namespace string
taskQueueName string
identity string
service workflowservice.WorkflowServiceClient
taskHandler *nexusTaskHandler
logger log.Logger
numPollerMetric *numPollerMetric
namespace string
taskQueueName string
identity string
service workflowservice.WorkflowServiceClient
taskHandler *nexusTaskHandler
logger log.Logger
numPollerMetric *numPollerMetric
pollerGroupTracker *pollerGroupTracker
}

type nexusTask struct {
Expand Down Expand Up @@ -53,7 +54,8 @@ func newNexusTaskPoller(
taskQueueName: params.TaskQueue,
identity: params.Identity,
logger: params.Logger,
numPollerMetric: newNumPollerMetric(params.MetricsHandler, metrics.PollerTypeNexusTask),
numPollerMetric: newNumPollerMetric(params.MetricsHandler, metrics.PollerTypeNexusTask),
pollerGroupTracker: newPollerGroupTracker(),
}
}

Expand All @@ -69,6 +71,10 @@ func (ntp *nexusTaskPoller) poll(ctx context.Context) (taskForWorker, error) {
traceLog(func() {
ntp.logger.Debug("nexusTaskPoller::Poll")
})

groupId := ntp.pollerGroupTracker.getNextGroupId()
defer ntp.pollerGroupTracker.release(groupId)

request := &workflowservice.PollNexusTaskQueueRequest{
Namespace: ntp.namespace,
TaskQueue: &taskqueuepb.TaskQueue{Name: ntp.taskQueueName, Kind: enumspb.TASK_QUEUE_KIND_NORMAL},
Expand All @@ -83,12 +89,14 @@ func (ntp *nexusTaskPoller) poll(ctx context.Context) (taskForWorker, error) {
ntp.workerDeploymentVersion,
),
WorkerInstanceKey: ntp.workerInstanceKey,
PollerGroupId: groupId,
}

response, err := ntp.pollNexusTaskQueue(ctx, request)
if err != nil {
return nil, err
}
ntp.pollerGroupTracker.updateGroups(response.GetPollerGroupInfos())
if response == nil || len(response.TaskToken) == 0 {
// No operation info is available on empty poll. Emit using base scope.
ntp.metricsHandler.Counter(metrics.NexusPollNoTaskCounter).Inc(1)
Expand Down Expand Up @@ -131,7 +139,7 @@ func (ntp *nexusTaskPoller) ProcessTask(task interface{}) error {
nctx, handlerErr := ntp.taskHandler.newNexusOperationContext(response)
if handlerErr != nil {
// context wasn't propagated to us, use a background context.
failedRequest, err := ntp.taskHandler.fillInFailure(response.TaskToken, handlerErr, getEffectiveTemporalFailureResponses(response.GetRequest().GetCapabilities().GetTemporalFailureResponses()))
failedRequest, err := ntp.taskHandler.fillInFailure(response.TaskToken, handlerErr, getEffectiveTemporalFailureResponses(response.GetRequest().GetCapabilities().GetTemporalFailureResponses()), response.GetPollerGroupId())
if err != nil {
return err
}
Expand Down
5 changes: 3 additions & 2 deletions internal/internal_task_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -1834,8 +1834,9 @@ func (wth *workflowTaskHandlerImpl) completeWorkflow(
// for query task
if task.Query != nil {
queryCompletedRequest := &workflowservice.RespondQueryTaskCompletedRequest{
TaskToken: task.TaskToken,
Namespace: wth.namespace,
TaskToken: task.TaskToken,
Namespace: wth.namespace,
PollerGroupId: task.GetPollerGroupId(),
}
var panicErr *PanicError
if errors.As(workflowContext.err, &panicErr) {
Expand Down
19 changes: 18 additions & 1 deletion internal/internal_task_pollers.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ type (
numNormalPollerMetric *numPollerMetric
numStickyPollerMetric *numPollerMetric

pollerGroupTracker *pollerGroupTracker

inboundPayloadVisitor PayloadVisitor
}

Expand Down Expand Up @@ -155,7 +157,7 @@ type (
outboundPayloadVisitor PayloadVisitor
}

// activityTaskPoller implements polling/processing a workflow task
// activityTaskPoller implements polling/processing an activity task
activityTaskPoller struct {
basePoller
namespace string
Expand All @@ -166,6 +168,7 @@ type (
logger log.Logger
activitiesPerSecond float64
numPollerMetric *numPollerMetric
pollerGroupTracker *pollerGroupTracker
inboundPayloadVisitor PayloadVisitor
outboundPayloadVisitor PayloadVisitor
}
Expand Down Expand Up @@ -424,6 +427,7 @@ func (wtp *workflowTaskProcessor) createPoller(mode workflowTaskPollerMode) task
eagerActivityExecutor: wtp.eagerActivityExecutor,
numNormalPollerMetric: wtp.numNormalPollerMetric,
numStickyPollerMetric: wtp.numStickyPollerMetric,
pollerGroupTracker: newPollerGroupTracker(),
inboundPayloadVisitor: wtp.inboundPayloadVisitor,
}
}
Expand Down Expand Up @@ -741,6 +745,7 @@ func (wtp *workflowTaskProcessor) reportGrpcMessageTooLarge(
Namespace: wtp.namespace,
Failure: wtp.failureConverter.ErrorToFailure(sendErr),
Cause: enumspb.WORKFLOW_TASK_FAILED_CAUSE_GRPC_MESSAGE_TOO_LARGE,
PollerGroupId: task.GetPollerGroupId(),
}
if err = visitProtoPayloads(ctx, wtp.outboundPayloadVisitor, request); err != nil {
wtp.logger.Error("Failed to visit payloads for GRPC message too large query failure response.", tagError, err)
Expand Down Expand Up @@ -1105,14 +1110,19 @@ func (wtp *workflowTaskPoller) poll(ctx context.Context) (taskForWorker, error)
wtp.logger.Debug("workflowTaskPoller::Poll")
})

groupId := wtp.pollerGroupTracker.getNextGroupId()
defer wtp.pollerGroupTracker.release(groupId)

request := wtp.getNextPollRequest()
request.PollerGroupId = groupId
defer wtp.release(request.TaskQueue.GetKind())

response, err := wtp.pollWorkflowTaskQueue(ctx, request)
if err != nil {
wtp.updateBacklog(request.TaskQueue.GetKind(), 0)
return nil, err
}
wtp.pollerGroupTracker.updateGroups(response.GetPollerGroupInfos())

if response == nil || len(response.TaskToken) == 0 {
// Emit using base scope as no workflow type information is available in the case of empty poll
Expand Down Expand Up @@ -1302,6 +1312,7 @@ func newActivityTaskPoller(taskHandler ActivityTaskHandler, service workflowserv
logger: params.Logger,
activitiesPerSecond: params.TaskQueueActivitiesPerSecond,
numPollerMetric: newNumPollerMetric(params.MetricsHandler, metrics.PollerTypeActivityTask),
pollerGroupTracker: newPollerGroupTracker(),
inboundPayloadVisitor: params.inboundPayloadVisitor,
outboundPayloadVisitor: params.outboundPayloadVisitor,
}
Expand All @@ -1320,6 +1331,10 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (taskForWorker, error)
traceLog(func() {
atp.logger.Debug("activityTaskPoller::Poll")
})

groupId := atp.pollerGroupTracker.getNextGroupId()
defer atp.pollerGroupTracker.release(groupId)

request := &workflowservice.PollActivityTaskQueueRequest{
Namespace: atp.namespace,
TaskQueue: &taskqueuepb.TaskQueue{Name: atp.taskQueueName, Kind: enumspb.TASK_QUEUE_KIND_NORMAL},
Expand All @@ -1335,12 +1350,14 @@ func (atp *activityTaskPoller) poll(ctx context.Context) (taskForWorker, error)
atp.workerDeploymentVersion,
),
WorkerInstanceKey: atp.workerInstanceKey,
PollerGroupId: groupId,
}

response, err := atp.pollActivityTaskQueue(ctx, request)
if err != nil {
return nil, err
}
atp.pollerGroupTracker.updateGroups(response.GetPollerGroupInfos())
if response == nil || len(response.TaskToken) == 0 {
// No activity info is available on empty poll. Emit using base scope.
atp.metricsHandler.Counter(metrics.ActivityPollNoTaskCounter).Inc(1)
Expand Down
2 changes: 1 addition & 1 deletion internal/internal_workflow_testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -2731,7 +2731,7 @@ func (env *testWorkflowEnvironmentImpl) ExecuteNexusOperation(
response, failure, err := taskHandler.Execute(task)
if err != nil {
// No retries for operations, fail the operation immediately.
failure, err = taskHandler.fillInFailure(task.TaskToken, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, err.Error()), false)
failure, err = taskHandler.fillInFailure(task.TaskToken, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, err.Error()), false, "")
}
if failure != nil {
// Convert to a nexus HandlerError first to simulate the flow in the server.
Expand Down
Loading
Loading