diff --git a/gen/iwfidl/model_inter_state_channel_result.go b/gen/iwfidl/model_inter_state_channel_result.go index 1e6ff3a3..45598748 100644 --- a/gen/iwfidl/model_inter_state_channel_result.go +++ b/gen/iwfidl/model_inter_state_channel_result.go @@ -23,6 +23,10 @@ type InterStateChannelResult struct { RequestStatus ChannelRequestStatus `json:"requestStatus"` ChannelName string `json:"channelName"` Value *EncodedObject `json:"value,omitempty"` + // Values contains all consumed messages when AtLeast/AtMost is used. + // For single-message commands, this contains the same single value. + // NOTE: This field requires a corresponding update to the iwf-idl OpenAPI spec. + Values []EncodedObject `json:"values,omitempty"` } // NewInterStateChannelResult instantiates a new InterStateChannelResult object @@ -165,6 +169,9 @@ func (o InterStateChannelResult) ToMap() (map[string]interface{}, error) { if !IsNil(o.Value) { toSerialize["value"] = o.Value } + if len(o.Values) > 0 { + toSerialize["values"] = o.Values + } return toSerialize, nil } diff --git a/integ/interstate_consume_n_test.go b/integ/interstate_consume_n_test.go new file mode 100644 index 00000000..0133a8fb --- /dev/null +++ b/integ/interstate_consume_n_test.go @@ -0,0 +1,129 @@ +package integ + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/indeedeng/iwf/gen/iwfidl" + "github.com/indeedeng/iwf/integ/workflow/interstate_consume_n" + "github.com/indeedeng/iwf/service" + "github.com/indeedeng/iwf/service/common/ptr" + "github.com/stretchr/testify/assert" +) + +func TestInterStateConsumeNWorkflowTemporal(t *testing.T) { + if !*temporalIntegTest { + t.Skip() + } + for i := 0; i < *repeatIntegTest; i++ { + doTestInterStateConsumeNWorkflow(t, service.BackendTypeTemporal, nil) + smallWaitForFastTest() + doTestInterStateConsumeNWorkflow(t, service.BackendTypeTemporal, minimumContinueAsNewConfig(true)) + smallWaitForFastTest() + } +} + +func TestInterStateConsumeNWorkflowCadence(t *testing.T) { + if !*cadenceIntegTest { + t.Skip() + } + for i := 0; i < *repeatIntegTest; i++ { + doTestInterStateConsumeNWorkflow(t, service.BackendTypeCadence, nil) + smallWaitForFastTest() + doTestInterStateConsumeNWorkflow(t, service.BackendTypeCadence, minimumContinueAsNewConfig(false)) + smallWaitForFastTest() + } +} + +func doTestInterStateConsumeNWorkflow(t *testing.T, backendType service.BackendType, config *iwfidl.WorkflowConfig) { + wfHandler := interstate_consume_n.NewHandler() + closeFunc1 := startWorkflowWorker(wfHandler, t) + defer closeFunc1() + + closeFunc2 := startIwfService(backendType) + defer closeFunc2() + + apiClient := iwfidl.NewAPIClient(&iwfidl.Configuration{ + Servers: []iwfidl.ServerConfiguration{ + { + URL: "http://localhost:" + testIwfServerPort, + }, + }, + }) + + wfId := interstate_consume_n.WorkflowType + strconv.Itoa(int(time.Now().UnixNano())) + req := apiClient.DefaultApi.ApiV1WorkflowStartPost(context.Background()) + _, httpResp, err := req.WorkflowStartRequest(iwfidl.WorkflowStartRequest{ + WorkflowId: wfId, + IwfWorkflowType: interstate_consume_n.WorkflowType, + WorkflowTimeoutSeconds: 20, + IwfWorkerUrl: "http://localhost:" + testWorkflowServerPort, + StartStateId: ptr.Any(interstate_consume_n.State1), + WorkflowStartOptions: &iwfidl.WorkflowStartOptions{ + WorkflowConfigOverride: config, + }, + }).Execute() + failTestAtHttpError(err, httpResp, t) + + req2 := apiClient.DefaultApi.ApiV1WorkflowGetWithWaitPost(context.Background()) + resp2, httpResp, err := req2.WorkflowGetRequest(iwfidl.WorkflowGetRequest{ + WorkflowId: wfId, + }).Execute() + failTestAtHttpError(err, httpResp, t) + + assertions := assert.New(t) + assertions.Equal(iwfidl.COMPLETED, resp2.GetWorkflowStatus()) + + history, data := wfHandler.GetTestResult() + assertions.Equalf(map[string]int64{ + "S1_start": 1, + "S1_decide": 1, + "S2_start": 1, + "S2_decide": 1, + "S3_start": 1, + "S3_decide": 1, + "S4_start": 1, + "S4_decide": 1, + "S5_start": 1, + "S5_decide": 1, + "S6_start": 1, + "S6_decide": 1, + }, history, "consume N test fail, %v", history) + + // ExactN (AtLeast=3, AtMost=3): should consume exactly 3 of the 5 published messages + exactNValues := data["exactN_values"].([]iwfidl.EncodedObject) + assertions.Equal(3, len(exactNValues), "ExactN should consume exactly 3 messages") + assertions.Equal(*interstate_consume_n.TestValues[0].Data, *exactNValues[0].Data) + assertions.Equal(*interstate_consume_n.TestValues[1].Data, *exactNValues[1].Data) + assertions.Equal(*interstate_consume_n.TestValues[2].Data, *exactNValues[2].Data) + // Value field backward compat: first message + exactNValue := data["exactN_value"].(iwfidl.EncodedObject) + assertions.Equal(*interstate_consume_n.TestValues[0].Data, *exactNValue.Data) + + // OneToAll (AtLeast=1, no AtMost): should consume all remaining 2 messages + oneToAllValues := data["oneToAll_values"].([]iwfidl.EncodedObject) + assertions.Equal(2, len(oneToAllValues), "OneToAll should consume all remaining messages") + assertions.Equal(*interstate_consume_n.TestValues[3].Data, *oneToAllValues[0].Data) + assertions.Equal(*interstate_consume_n.TestValues[4].Data, *oneToAllValues[1].Data) + oneToAllValue := data["oneToAll_value"].(iwfidl.EncodedObject) + assertions.Equal(*interstate_consume_n.TestValues[3].Data, *oneToAllValue.Data) + + // ZeroToAll (AtLeast=0, no AtMost): channel empty, should consume 0 messages + zeroToAllValues := data["zeroToAll_values"] + if zeroToAllValues == nil { + assertions.Nil(zeroToAllValues) + } else { + vals := zeroToAllValues.([]iwfidl.EncodedObject) + assertions.Equal(0, len(vals), "ZeroToAll should consume 0 messages from empty channel") + } + + // AtMostOnly (AtMost=2, no AtLeast): waits for late messages from S6, should consume 2 of 3 + atMostOnlyValues := data["atMostOnly_values"].([]iwfidl.EncodedObject) + assertions.Equal(2, len(atMostOnlyValues), "AtMostOnly should consume exactly 2 messages") + assertions.Equal(*interstate_consume_n.TestValuesCh2[0].Data, *atMostOnlyValues[0].Data) + assertions.Equal(*interstate_consume_n.TestValuesCh2[1].Data, *atMostOnlyValues[1].Data) + atMostOnlyValue := data["atMostOnly_value"].(iwfidl.EncodedObject) + assertions.Equal(*interstate_consume_n.TestValuesCh2[0].Data, *atMostOnlyValue.Data) +} diff --git a/integ/workflow/interstate_consume_n/routers.go b/integ/workflow/interstate_consume_n/routers.go new file mode 100644 index 00000000..72152b24 --- /dev/null +++ b/integ/workflow/interstate_consume_n/routers.go @@ -0,0 +1,320 @@ +package interstate_consume_n + +import ( + "log" + "net/http" + "sync" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/indeedeng/iwf/gen/iwfidl" + "github.com/indeedeng/iwf/service" + "github.com/indeedeng/iwf/service/common/ptr" +) + +/** + * This test workflow verifies consuming N messages from an inter-state channel using AtLeast/AtMost. + * + * S1 (Setup): + * - Start: no commands + * - Decide: publishes 5 messages to "ch", moves to S2 + * + * S2 (ExactN — AtLeast=3, AtMost=3): + * - Start: channel command on "ch" with AtLeast=3, AtMost=3 + * - Decide: verifies 3 values consumed, moves to S3 + * + * S3 (OneToAll — AtLeast=1, no AtMost): + * - Start: channel command on "ch" with AtLeast=1 + * - Decide: verifies remaining 2 values consumed, moves to S4 + * + * S4 (ZeroToAll — AtLeast=0, no AtMost on empty channel): + * - Start: channel command on "ch" with AtLeast=0 + * - Decide: verifies 0 values consumed, moves to S5 + S6 + * + * S5 (AtMostOnly — AtMost=2, no AtLeast; also tests late message arrival): + * - Start: channel command on "ch2" with AtMost=2 only + * - Decide: verifies 2 values consumed (out of 3 published by S6), completes workflow + * + * S6 (Delayed publisher): + * - Start: delays 2s, publishes 3 messages to "ch2", no commands + * - Decide: dead-end + */ +const ( + WorkflowType = "interstate_consume_n" + State1 = "S1" + State2 = "S2" + State3 = "S3" + State4 = "S4" + State5 = "S5" + State6 = "S6" + + channel1 = "ch" + channel2 = "ch2" +) + +var TestValues = []iwfidl.EncodedObject{ + {Encoding: iwfidl.PtrString("json"), Data: iwfidl.PtrString("val-0")}, + {Encoding: iwfidl.PtrString("json"), Data: iwfidl.PtrString("val-1")}, + {Encoding: iwfidl.PtrString("json"), Data: iwfidl.PtrString("val-2")}, + {Encoding: iwfidl.PtrString("json"), Data: iwfidl.PtrString("val-3")}, + {Encoding: iwfidl.PtrString("json"), Data: iwfidl.PtrString("val-4")}, +} + +var TestValuesCh2 = []iwfidl.EncodedObject{ + {Encoding: iwfidl.PtrString("json"), Data: iwfidl.PtrString("ch2-val-0")}, + {Encoding: iwfidl.PtrString("json"), Data: iwfidl.PtrString("ch2-val-1")}, + {Encoding: iwfidl.PtrString("json"), Data: iwfidl.PtrString("ch2-val-2")}, +} + +type handler struct { + invokeHistory sync.Map + invokeData sync.Map +} + +func NewHandler() *handler { + return &handler{ + invokeHistory: sync.Map{}, + invokeData: sync.Map{}, + } +} + +func (h *handler) ApiV1WorkflowStateStart(c *gin.Context, t *testing.T) { + var req iwfidl.WorkflowStateStartRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + log.Println("received state start request, ", req) + + if req.GetWorkflowType() == WorkflowType { + h.recordInvoke(req.GetWorkflowStateId() + "_start") + + switch req.GetWorkflowStateId() { + case State1: + c.JSON(http.StatusOK, iwfidl.WorkflowStateStartResponse{ + CommandRequest: &iwfidl.CommandRequest{ + DeciderTriggerType: iwfidl.ALL_COMMAND_COMPLETED.Ptr(), + }, + }) + return + + case State2: + // ExactN: wait for exactly 3, consume exactly 3 + c.JSON(http.StatusOK, iwfidl.WorkflowStateStartResponse{ + CommandRequest: &iwfidl.CommandRequest{ + DeciderTriggerType: iwfidl.ALL_COMMAND_COMPLETED.Ptr(), + InterStateChannelCommands: []iwfidl.InterStateChannelCommand{ + { + CommandId: ptr.Any("cmd-1"), + ChannelName: channel1, + AtLeast: iwfidl.PtrInt32(3), + AtMost: iwfidl.PtrInt32(3), + }, + }, + }, + }) + return + + case State3: + // OneToAll: wait for at least 1, consume all available + c.JSON(http.StatusOK, iwfidl.WorkflowStateStartResponse{ + CommandRequest: &iwfidl.CommandRequest{ + DeciderTriggerType: iwfidl.ALL_COMMAND_COMPLETED.Ptr(), + InterStateChannelCommands: []iwfidl.InterStateChannelCommand{ + { + CommandId: ptr.Any("cmd-2"), + ChannelName: channel1, + AtLeast: iwfidl.PtrInt32(1), + }, + }, + }, + }) + return + + case State4: + // ZeroToAll: don't wait, consume all available (channel is empty) + c.JSON(http.StatusOK, iwfidl.WorkflowStateStartResponse{ + CommandRequest: &iwfidl.CommandRequest{ + DeciderTriggerType: iwfidl.ALL_COMMAND_COMPLETED.Ptr(), + InterStateChannelCommands: []iwfidl.InterStateChannelCommand{ + { + CommandId: ptr.Any("cmd-3"), + ChannelName: channel1, + AtLeast: iwfidl.PtrInt32(0), + }, + }, + }, + }) + return + + case State5: + // AtMostOnly: only AtMost set (no AtLeast), waits for late messages from S6 + c.JSON(http.StatusOK, iwfidl.WorkflowStateStartResponse{ + CommandRequest: &iwfidl.CommandRequest{ + DeciderTriggerType: iwfidl.ALL_COMMAND_COMPLETED.Ptr(), + InterStateChannelCommands: []iwfidl.InterStateChannelCommand{ + { + CommandId: ptr.Any("cmd-4"), + ChannelName: channel2, + AtMost: iwfidl.PtrInt32(2), + }, + }, + }, + }) + return + + case State6: + // Delayed publisher: wait 2s then publish 3 messages to ch2 + time.Sleep(time.Second * 2) + publishes := make([]iwfidl.InterStateChannelPublishing, len(TestValuesCh2)) + for i := range TestValuesCh2 { + v := TestValuesCh2[i] + publishes[i] = iwfidl.InterStateChannelPublishing{ + ChannelName: channel2, + Value: &v, + } + } + c.JSON(http.StatusOK, iwfidl.WorkflowStateStartResponse{ + CommandRequest: &iwfidl.CommandRequest{ + DeciderTriggerType: iwfidl.ALL_COMMAND_COMPLETED.Ptr(), + }, + PublishToInterStateChannel: publishes, + }) + return + } + } + + c.JSON(http.StatusBadRequest, struct{}{}) +} + +func (h *handler) ApiV1WorkflowStateDecide(c *gin.Context, t *testing.T) { + var req iwfidl.WorkflowStateDecideRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + log.Println("received state decide request, ", req) + + if req.GetWorkflowType() == WorkflowType { + h.recordInvoke(req.GetWorkflowStateId() + "_decide") + + switch req.GetWorkflowStateId() { + case State1: + // Publish 5 messages to ch and move to S2 + publishes := make([]iwfidl.InterStateChannelPublishing, len(TestValues)) + for i := range TestValues { + v := TestValues[i] + publishes[i] = iwfidl.InterStateChannelPublishing{ + ChannelName: channel1, + Value: &v, + } + } + c.JSON(http.StatusOK, iwfidl.WorkflowStateDecideResponse{ + StateDecision: &iwfidl.StateDecision{ + NextStates: []iwfidl.StateMovement{ + {StateId: State2}, + }, + }, + PublishToInterStateChannel: publishes, + }) + return + + case State2: + results := req.GetCommandResults() + channelResult := results.GetInterStateChannelResults()[0] + h.invokeData.Store("exactN_values", channelResult.Values) + h.invokeData.Store("exactN_value", channelResult.GetValue()) + + c.JSON(http.StatusOK, iwfidl.WorkflowStateDecideResponse{ + StateDecision: &iwfidl.StateDecision{ + NextStates: []iwfidl.StateMovement{ + {StateId: State3}, + }, + }, + }) + return + + case State3: + results := req.GetCommandResults() + channelResult := results.GetInterStateChannelResults()[0] + h.invokeData.Store("oneToAll_values", channelResult.Values) + h.invokeData.Store("oneToAll_value", channelResult.GetValue()) + + c.JSON(http.StatusOK, iwfidl.WorkflowStateDecideResponse{ + StateDecision: &iwfidl.StateDecision{ + NextStates: []iwfidl.StateMovement{ + {StateId: State4}, + }, + }, + }) + return + + case State4: + results := req.GetCommandResults() + channelResult := results.GetInterStateChannelResults()[0] + h.invokeData.Store("zeroToAll_values", channelResult.Values) + + // Move to S5 (waiter) and S6 (delayed publisher) concurrently + c.JSON(http.StatusOK, iwfidl.WorkflowStateDecideResponse{ + StateDecision: &iwfidl.StateDecision{ + NextStates: []iwfidl.StateMovement{ + {StateId: State5}, + {StateId: State6}, + }, + }, + }) + return + + case State5: + results := req.GetCommandResults() + channelResult := results.GetInterStateChannelResults()[0] + h.invokeData.Store("atMostOnly_values", channelResult.Values) + h.invokeData.Store("atMostOnly_value", channelResult.GetValue()) + + c.JSON(http.StatusOK, iwfidl.WorkflowStateDecideResponse{ + StateDecision: &iwfidl.StateDecision{ + NextStates: []iwfidl.StateMovement{ + {StateId: service.GracefulCompletingWorkflowStateId}, + }, + }, + }) + return + + case State6: + // Dead-end after publishing + c.JSON(http.StatusOK, iwfidl.WorkflowStateDecideResponse{ + StateDecision: &iwfidl.StateDecision{ + NextStates: []iwfidl.StateMovement{ + {StateId: service.DeadEndWorkflowStateId}, + }, + }, + }) + return + } + } + + c.JSON(http.StatusBadRequest, struct{}{}) +} + +func (h *handler) GetTestResult() (map[string]int64, map[string]interface{}) { + invokeHistory := make(map[string]int64) + h.invokeHistory.Range(func(key, value interface{}) bool { + invokeHistory[key.(string)] = value.(int64) + return true + }) + invokeData := make(map[string]interface{}) + h.invokeData.Range(func(key, value interface{}) bool { + invokeData[key.(string)] = value + return true + }) + return invokeHistory, invokeData +} + +func (h *handler) recordInvoke(key string) { + if value, ok := h.invokeHistory.Load(key); ok { + h.invokeHistory.Store(key, value.(int64)+1) + } else { + h.invokeHistory.Store(key, int64(1)) + } +} diff --git a/service/interfaces.go b/service/interfaces.go index 7fc15548..461dde2e 100644 --- a/service/interfaces.go +++ b/service/interfaces.go @@ -160,6 +160,9 @@ type ( CompletedTimerCommands map[int]InternalTimerStatus `json:"completedTimerCommands"` CompletedSignalCommands map[int]*iwfidl.EncodedObject `json:"completedSignalCommands"` CompletedInterStateChannelCommands map[int]*iwfidl.EncodedObject `json:"completedInterStateChannelCommands"` + // CompletedInterStateChannelMultiCmds stores multiple messages consumed by a single command + // when AtLeast/AtMost are set. Keyed by command index. + CompletedInterStateChannelMultiCmds map[int][]*iwfidl.EncodedObject `json:"completedInterStateChannelMultiCmds,omitempty"` } StaleSkipTimerSignal struct { diff --git a/service/interpreter/InternalChannel.go b/service/interpreter/InternalChannel.go index 77788d4c..4789d51c 100644 --- a/service/interpreter/InternalChannel.go +++ b/service/interpreter/InternalChannel.go @@ -41,6 +41,15 @@ func (i *InternalChannel) HasData(channelName string) bool { return len(l) > 0 } +func (i *InternalChannel) HasAtLeastN(channelName string, n int) bool { + l := i.receivedData[channelName] + return len(l) >= n +} + +func (i *InternalChannel) Size(channelName string) int { + return len(i.receivedData[channelName]) +} + func (i *InternalChannel) ProcessPublishing(publishes []iwfidl.InterStateChannelPublishing) { for _, pub := range publishes { i.receive(pub.ChannelName, pub.Value) @@ -68,3 +77,25 @@ func (i *InternalChannel) Retrieve(channelName string) *iwfidl.EncodedObject { return data } + +// RetrieveUpToN atomically retrieves up to n messages from the channel. +// It consumes min(n, available) messages. +func (i *InternalChannel) RetrieveUpToN(channelName string, n int) []*iwfidl.EncodedObject { + l := i.receivedData[channelName] + if len(l) == 0 { + return []*iwfidl.EncodedObject{} + } + count := n + if count > len(l) { + count = len(l) + } + data := make([]*iwfidl.EncodedObject, count) + copy(data, l[:count]) + l = l[count:] + if len(l) == 0 { + delete(i.receivedData, channelName) + } else { + i.receivedData[channelName] = l + } + return data +} diff --git a/service/interpreter/continueAsNewer.go b/service/interpreter/continueAsNewer.go index a33fd39d..640d339a 100644 --- a/service/interpreter/continueAsNewer.go +++ b/service/interpreter/continueAsNewer.go @@ -172,6 +172,7 @@ func (c *ContinueAsNewer) AddPotentialStateExecutionToResume( commandRequest iwfidl.CommandRequest, completedTimerCommands map[int]service.InternalTimerStatus, completedSignalCommands, completedInterStateChannelCommands map[int]*iwfidl.EncodedObject, + completedInterStateChannelMultiCmds map[int][]*iwfidl.EncodedObject, ) { c.StateExecutionToResumeMap[stateExecutionId] = service.StateExecutionResumeInfo{ StateExecutionId: stateExecutionId, @@ -179,9 +180,10 @@ func (c *ContinueAsNewer) AddPotentialStateExecutionToResume( StateExecutionLocals: stateExecLocals, CommandRequest: commandRequest, StateExecutionCompletedCommands: service.StateExecutionCompletedCommands{ - CompletedTimerCommands: completedTimerCommands, - CompletedSignalCommands: completedSignalCommands, - CompletedInterStateChannelCommands: completedInterStateChannelCommands, + CompletedTimerCommands: completedTimerCommands, + CompletedSignalCommands: completedSignalCommands, + CompletedInterStateChannelCommands: completedInterStateChannelCommands, + CompletedInterStateChannelMultiCmds: completedInterStateChannelMultiCmds, }, } } diff --git a/service/interpreter/deciderTriggerer.go b/service/interpreter/deciderTriggerer.go index 7fd743b5..a0b524b1 100644 --- a/service/interpreter/deciderTriggerer.go +++ b/service/interpreter/deciderTriggerer.go @@ -12,17 +12,20 @@ func IsDeciderTriggerConditionMet( completedTimerCmds map[int]service.InternalTimerStatus, completedSignalCmds map[int]*iwfidl.EncodedObject, completedInterStateChannelCmds map[int]*iwfidl.EncodedObject, + completedInterStateChannelMultiCmds map[int][]*iwfidl.EncodedObject, ) bool { + completedInterStateChannelCount := countCompletedInterStateChannelCmds(completedInterStateChannelCmds, completedInterStateChannelMultiCmds) + if len(commandReq.GetTimerCommands())+len(commandReq.GetSignalCommands())+len(commandReq.GetInterStateChannelCommands()) > 0 { triggerType := compatibility.GetDeciderTriggerType(commandReq) if triggerType == iwfidl.ALL_COMMAND_COMPLETED { return len(completedTimerCmds) == len(commandReq.GetTimerCommands()) && len(completedSignalCmds) == len(commandReq.GetSignalCommands()) && - len(completedInterStateChannelCmds) == len(commandReq.GetInterStateChannelCommands()) + completedInterStateChannelCount == len(commandReq.GetInterStateChannelCommands()) } else if triggerType == iwfidl.ANY_COMMAND_COMPLETED { return len(completedTimerCmds)+ len(completedSignalCmds)+ - len(completedInterStateChannelCmds) > 0 + completedInterStateChannelCount > 0 } else if triggerType == iwfidl.ANY_COMMAND_COMBINATION_COMPLETED { var completedCmdIds []string for _, idx := range DeterministicKeys(completedTimerCmds) { @@ -33,7 +36,7 @@ func IsDeciderTriggerConditionMet( cmdId := commandReq.GetSignalCommands()[idx].CommandId completedCmdIds = append(completedCmdIds, *cmdId) } - for _, idx := range DeterministicKeys(completedInterStateChannelCmds) { + for _, idx := range getCompletedInterStateChannelIndices(completedInterStateChannelCmds, completedInterStateChannelMultiCmds) { cmdId := commandReq.GetInterStateChannelCommands()[idx].CommandId completedCmdIds = append(completedCmdIds, *cmdId) } @@ -63,3 +66,35 @@ func IsDeciderTriggerConditionMet( } return true } + +// countCompletedInterStateChannelCmds returns the total number of unique completed +// inter-state channel commands across both single and multi maps. +func countCompletedInterStateChannelCmds( + single map[int]*iwfidl.EncodedObject, + multi map[int][]*iwfidl.EncodedObject, +) int { + seen := make(map[int]bool) + for idx := range single { + seen[idx] = true + } + for idx := range multi { + seen[idx] = true + } + return len(seen) +} + +// getCompletedInterStateChannelIndices returns sorted unique indices of completed +// inter-state channel commands from both single and multi maps. +func getCompletedInterStateChannelIndices( + single map[int]*iwfidl.EncodedObject, + multi map[int][]*iwfidl.EncodedObject, +) []int { + merged := make(map[int]bool) + for idx := range single { + merged[idx] = true + } + for idx := range multi { + merged[idx] = true + } + return DeterministicKeys(merged) +} diff --git a/service/interpreter/globalVersioner.go b/service/interpreter/globalVersioner.go index efc32269..3c57a03e 100644 --- a/service/interpreter/globalVersioner.go +++ b/service/interpreter/globalVersioner.go @@ -40,7 +40,12 @@ const SyncUpdateRPCUseLocalActivity = 9 // This ensures that commands don't get lost during continueAsNew operations. const StartingVersionWaitingCommandThreads = 10 -const MaxOfAllVersions = StartingVersionWaitingCommandThreads +// StartingVersionChannelConsumeN supports consuming N messages from a channel in one command +// via AtLeast/AtMost fields on InterStateChannelCommand. +// See: https://github.com/indeedeng/iwf/issues/301 +const StartingVersionChannelConsumeN = 11 + +const MaxOfAllVersions = StartingVersionChannelConsumeN // GlobalVersioner see https://stackoverflow.com/questions/73941723/what-is-a-good-way-pattern-to-use-temporal-cadence-versioning-api type GlobalVersioner struct { @@ -103,6 +108,10 @@ func (p *GlobalVersioner) IsAfterVersionOfWaitingCommandThreads() bool { return p.version >= StartingVersionWaitingCommandThreads } +func (p *GlobalVersioner) IsAfterVersionOfChannelConsumeN() bool { + return p.version >= StartingVersionChannelConsumeN +} + // methods checking feature/functionality availability func (p *GlobalVersioner) IsUsingGlobalVersionSearchAttribute() bool { diff --git a/service/interpreter/workflowImpl.go b/service/interpreter/workflowImpl.go index f424d161..34f56977 100644 --- a/service/interpreter/workflowImpl.go +++ b/service/interpreter/workflowImpl.go @@ -575,6 +575,7 @@ func processStateExecution( completedTimerCmds := map[int]service.InternalTimerStatus{} completedSignalCmds := map[int]*iwfidl.EncodedObject{} completedInterStateChannelCmds := map[int]*iwfidl.EncodedObject{} + completedInterStateChannelMultiCmds := map[int][]*iwfidl.EncodedObject{} state := stateReq.GetStateMovement() isResumeFromContinueAsNew := stateReq.IsResumeRequest() @@ -592,6 +593,9 @@ func processStateExecution( commandReq = resumeStateRequest.CommandRequest completedCmds := resumeStateRequest.StateExecutionCompletedCommands completedTimerCmds, completedSignalCmds, completedInterStateChannelCmds = completedCmds.CompletedTimerCommands, completedCmds.CompletedSignalCommands, completedCmds.CompletedInterStateChannelCommands + if completedCmds.CompletedInterStateChannelMultiCmds != nil { + completedInterStateChannelMultiCmds = completedCmds.CompletedInterStateChannelMultiCmds + } } else { if state.StateOptions != nil { startApiTimeout := compatibility.GetStartApiTimeoutSeconds(state.StateOptions) @@ -758,6 +762,10 @@ func processStateExecution( // skip completed interStateChannelCommand(from continueAsNew) continue } + if _, ok := completedInterStateChannelMultiCmds[idx]; ok { + // skip completed multi-message interStateChannelCommand(from continueAsNew) + continue + } cmdCtx := provider.ExtendContextWithValue(ctx, "cmd", cmd) cmdCtx = provider.ExtendContextWithValue(cmdCtx, "idx", idx) //Process interstate channel command in a new thread. @@ -773,9 +781,11 @@ func processStateExecution( panic("critical code bug") } + atLeast, atMost := getChannelCommandLimits(cmd, globalVersioner) + received := false _ = provider.Await(ctx, func() bool { - received = interStateChannel.HasData(cmd.ChannelName) + received = interStateChannel.HasAtLeastN(cmd.ChannelName, atLeast) // Note that commandReqDoneOrCanceled is needed for two cases: // 1. will be true when trigger type of the commandReq is completed(e.g. AnyCommandCompleted) so we don't need to wait for all commands. Returning the thread to avoid thread leakage. // 2. will be true to cancel the wait for unblocking continueAsNew(continueAsNew will wait for all threads to complete) @@ -783,7 +793,12 @@ func processStateExecution( }) if received { - completedInterStateChannelCmds[idx] = interStateChannel.Retrieve(cmd.ChannelName) + if atMost > 1 || atLeast == 0 { + values := interStateChannel.RetrieveUpToN(cmd.ChannelName, atMost) + completedInterStateChannelMultiCmds[idx] = values + } else { + completedInterStateChannelCmds[idx] = interStateChannel.Retrieve(cmd.ChannelName) + } } waitForThreads[threadName] = true }) @@ -796,11 +811,12 @@ func processStateExecution( continueAsNewer.AddPotentialStateExecutionToResume( stateExeId, state, stateExecutionLocal, commandReq, completedTimerCmds, completedSignalCmds, completedInterStateChannelCmds, + completedInterStateChannelMultiCmds, ) // Wait for decider trigger (ANY/ALL command completed) OR continue-as-new threshold _ = provider.Await(ctx, func() bool { - return IsDeciderTriggerConditionMet(commandReq, completedTimerCmds, completedSignalCmds, completedInterStateChannelCmds) || continueAsNewCounter.IsThresholdMet() + return IsDeciderTriggerConditionMet(commandReq, completedTimerCmds, completedSignalCmds, completedInterStateChannelCmds, completedInterStateChannelMultiCmds) || continueAsNewCounter.IsThresholdMet() }) //This variable tells all command threads to stop waiting and exit, even if their specific command has not been completed. @@ -826,7 +842,7 @@ func processStateExecution( } } - if !IsDeciderTriggerConditionMet(commandReq, completedTimerCmds, completedSignalCmds, completedInterStateChannelCmds) { + if !IsDeciderTriggerConditionMet(commandReq, completedTimerCmds, completedSignalCmds, completedInterStateChannelCmds, completedInterStateChannelMultiCmds) { // this means continueAsNewCounter.IsThresholdMet == true // not using continueAsNewCounter.IsThresholdMet because deciderTrigger is higher prioritized // it won't continueAsNew in those cases 1. start Api fail with proceed policy, 2. empty commands, 3. both commands and continueAsNew are met @@ -878,8 +894,27 @@ func processStateExecution( var interStateChannelResults []iwfidl.InterStateChannelResult for idx, cmd := range commandReq.GetInterStateChannelCommands() { status := iwfidl.RECEIVED - result, completed := completedInterStateChannelCmds[idx] - if !completed { + var firstValue *iwfidl.EncodedObject + var allValues []iwfidl.EncodedObject + + multiResult, multiCompleted := completedInterStateChannelMultiCmds[idx] + singleResult, singleCompleted := completedInterStateChannelCmds[idx] + + if multiCompleted { + if len(multiResult) > 0 { + firstValue = multiResult[0] + } + for _, v := range multiResult { + if v != nil { + allValues = append(allValues, *v) + } + } + } else if singleCompleted { + firstValue = singleResult + if singleResult != nil { + allValues = append(allValues, *singleResult) + } + } else { status = iwfidl.WAITING } @@ -887,7 +922,8 @@ func processStateExecution( CommandId: cmd.GetCommandId(), ChannelName: cmd.ChannelName, RequestStatus: status, - Value: result, + Value: firstValue, + Values: allValues, }) } commandRes.SetInterStateChannelResults(interStateChannelResults) @@ -1092,6 +1128,42 @@ func getCommandThreadName(prefix string, stateExecId, cmdId string, idx int) str return fmt.Sprintf("%v-%v-%v-%v", prefix, stateExecId, cmdId, idx) } +// getChannelCommandLimits returns the effective AtLeast and AtMost values for a channel command. +// For old workflow versions or commands without AtLeast/AtMost, defaults to (1, 1) for backward compat. +func getChannelCommandLimits(cmd iwfidl.InterStateChannelCommand, globalVersioner *GlobalVersioner) (atLeast int, atMost int) { + if !globalVersioner.IsAfterVersionOfChannelConsumeN() { + return 1, 1 + } + + atLeast = 1 + atMost = 1 + + if cmd.HasAtLeast() { + atLeast = int(cmd.GetAtLeast()) + } + if cmd.HasAtMost() { + atMost = int(cmd.GetAtMost()) + } + + // If only AtLeast is set, AtMost defaults to max (consume up to all) + if cmd.HasAtLeast() && !cmd.HasAtMost() { + atMost = int(^uint(0) >> 1) // max int + } + // If only AtMost is set, AtLeast defaults to AtMost (exact count) + if !cmd.HasAtLeast() && cmd.HasAtMost() { + atLeast = atMost + } + + if atLeast < 0 { + atLeast = 0 + } + if atMost < atLeast { + atMost = atLeast + } + + return atLeast, atMost +} + func createUserWorkflowError(provider interfaces.WorkflowProvider, message string) error { return provider.NewApplicationError( string(iwfidl.INVALID_USER_WORKFLOW_CODE_ERROR_TYPE),