diff --git a/balance_checker.go b/balance_checker.go index 446fbfe2..66230290 100644 --- a/balance_checker.go +++ b/balance_checker.go @@ -21,7 +21,6 @@ import ( "pkg.akt.dev/go/util/pubsub" netutil "pkg.akt.dev/node/v2/util/network" - "pkg.akt.dev/node/v2/util/runner" "pkg.akt.dev/node/v2/x/escrow/client/util" "github.com/akash-network/provider/event" @@ -43,6 +42,7 @@ const ( type BalanceCheckerConfig struct { WithdrawalPeriod time.Duration LeaseFundsCheckInterval time.Duration + WithdrawalBatchMaxMsgs int } type leaseState struct { @@ -60,6 +60,7 @@ type balanceChecker struct { aqc aclient.QueryClient leases map[mtypes.LeaseID]*leaseState cfg BalanceCheckerConfig + batcher *withdrawBatcher } type leaseCheckResponse struct { @@ -77,16 +78,18 @@ func newBalanceChecker( bus pubsub.Bus, cfg BalanceCheckerConfig, ) (*balanceChecker, error) { + bcLog := clientSession.Log().With("cmp", "balance-checker") bc := &balanceChecker{ ctx: ctx, session: clientSession, - log: clientSession.Log().With("cmp", "balance-checker"), + log: bcLog, bus: bus, lc: lifecycle.New(), ownAddr: accAddr, aqc: aqc, leases: make(map[mtypes.LeaseID]*leaseState), cfg: cfg, + batcher: newWithdrawBatcher(clientSession.Client().Tx(), bcLog.With("cmp", "withdraw-batcher"), withdrawTimeout, cfg.WithdrawalBatchMaxMsgs), } startCh := make(chan error, 1) @@ -199,18 +202,6 @@ func (bc *balanceChecker) doEscrowCheck(ctx context.Context, lid mtypes.LeaseID, return resp } -func (bc *balanceChecker) startWithdraw(ctx context.Context, lid mtypes.LeaseID) error { - ctx, cancel := context.WithTimeout(ctx, withdrawTimeout) - defer cancel() - - msg := &mvbeta.MsgWithdrawLease{ - ID: lid, - } - - _, err := bc.session.Client().Tx().BroadcastMsgs(ctx, []sdk.Msg{msg}, aclient.WithResultCodeAsError()) - return err -} - func (bc *balanceChecker) run(startCh chan<- error) { ctx, cancel := context.WithCancel(bc.ctx) @@ -228,7 +219,6 @@ func (bc *balanceChecker) run(startCh chan<- error) { }() leaseCheckCh := make(chan leaseCheckResponse, 1) - var resultch chan runner.Result subscriber, err := bc.bus.Subscribe() startCh <- err @@ -236,8 +226,6 @@ func (bc *balanceChecker) run(startCh chan<- error) { return } - resultch = make(chan runner.Result, 1) - loop: for { select { @@ -282,6 +270,7 @@ loop: } delete(bc.leases, ev.LeaseID) + bc.batcher.Remove(ev.LeaseID) } case res := <-leaseCheckCh: // we may have timer fired just a heart beat ahead of lease remove event. @@ -325,17 +314,20 @@ loop: } if withdraw { - go func() { - select { - case <-ctx.Done(): - case resultch <- runner.NewResult(res.lid, bc.startWithdraw(ctx, res.lid)): - } - }() + bc.batcher.Enqueue(res.lid) + bc.batcher.Flush(ctx) } - case res := <-resultch: - if err := res.Error(); err != nil { - bc.log.Error("failed to do lease withdrawal", "err", err, "LeaseID", res.Value().(mtypes.LeaseID)) + case err := <-bc.batcher.Done(): + bc.batcher.MarkDone() + if err != nil { + // Skip immediate re-flush on failure: let the next per-lease + // timer trigger Enqueue+Flush, which gives natural backoff. + // Pending ids stay queued; Enqueue dedupes so a lease that + // re-triggers while still pending won't duplicate the msg. + bc.log.Error("failed to do lease withdrawal", "err", err, "pending", bc.batcher.Pending()) + continue loop } + bc.batcher.Flush(ctx) } } } diff --git a/bidengine/order.go b/bidengine/order.go index 4896042e..25ba1406 100644 --- a/bidengine/order.go +++ b/bidengine/order.go @@ -436,7 +436,7 @@ loop: Sources: deposit.Sources{deposit.SourceBalance}, }, offer) bidch = runner.Do(func() runner.Result { - return runner.NewResult(o.session.Client().Tx().BroadcastMsgs(ctx, []sdk.Msg{msg}, aclient.WithResultCodeAsError())) + return runner.NewResult(o.session.Client().Tx().BroadcastMsgs(ctx, []sdk.Msg{msg}, aclient.WithResultCodeAsError(), aclient.WithPriority())) }) case result := <-bidch: @@ -493,7 +493,7 @@ loop: Reason: mtypes.LeaseClosedReasonUnspecified, } - _, err := o.session.Client().Tx().BroadcastMsgs(ctx, []sdk.Msg{msg}, aclient.WithResultCodeAsError()) + _, err := o.session.Client().Tx().BroadcastMsgs(ctx, []sdk.Msg{msg}, aclient.WithResultCodeAsError(), aclient.WithPriority()) if err != nil { o.log.Error("closing bid", "err", err) bidCounter.WithLabelValues("close", metricsutils.FailLabel).Inc() diff --git a/bidengine/order_test.go b/bidengine/order_test.go index f3ddc374..8db72839 100644 --- a/bidengine/order_test.go +++ b/bidengine/order_test.go @@ -121,7 +121,7 @@ func makeMocks(s *orderTestScaffold) { txMocks := &clientmocks.TxClient{} s.broadcasts = make(chan []sdk.Msg, 1) - txMocks.On("BroadcastMsgs", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + txMocks.On("BroadcastMsgs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { s.broadcasts <- args.Get(1).([]sdk.Msg) }).Return(&sdk.Result{}, nil) @@ -553,7 +553,7 @@ func Test_ShouldCloseBidWhenAlreadySetAndOld(t *testing.T) { Reason: mtypes.LeaseClosedReasonUnspecified, }} - scaffold.txClient.AssertCalled(t, "BroadcastMsgs", mock.Anything, expMsgs, mock.Anything) + scaffold.txClient.AssertCalled(t, "BroadcastMsgs", mock.Anything, expMsgs, mock.Anything, mock.Anything) } func Test_ShouldExitWhenAlreadySetAndLost(t *testing.T) { @@ -578,7 +578,7 @@ func Test_ShouldExitWhenAlreadySetAndLost(t *testing.T) { ID: mtypes.MakeBidID(order.orderID, scaffold.testAddr), } - scaffold.txClient.AssertNotCalled(t, "BroadcastMsgs", mock.Anything, expMsgs, mock.Anything) + scaffold.txClient.AssertNotCalled(t, "BroadcastMsgs", mock.Anything, expMsgs, mock.Anything, mock.Anything) } func Test_ShouldCloseBidWhenAlreadySetAndThenTimeout(t *testing.T) { @@ -605,7 +605,7 @@ func Test_ShouldCloseBidWhenAlreadySetAndThenTimeout(t *testing.T) { Reason: mtypes.LeaseClosedReasonUnspecified, }, } - scaffold.txClient.AssertCalled(t, "BroadcastMsgs", mock.Anything, expMsgs, mock.Anything) + scaffold.txClient.AssertCalled(t, "BroadcastMsgs", mock.Anything, expMsgs, mock.Anything, mock.Anything) // Should have called unreserve scaffold.cluster.AssertCalled(t, "Unreserve", scaffold.orderID) diff --git a/cluster/inventory.go b/cluster/inventory.go index 71df2d30..88733e7c 100644 --- a/cluster/inventory.go +++ b/cluster/inventory.go @@ -419,6 +419,11 @@ func countPendingIPs(state *inventoryServiceState) uint { } func (is *inventoryService) handleRequest(req inventoryRequest, state *inventoryServiceState) { + if state.inventory == nil { + req.ch <- inventoryResponse{err: errInventoryNotAvailableYet} + return + } + // convert the resources to the committed amount resourcesToCommit := is.resourcesToCommit(req.resources) // create new registration if capacity available diff --git a/cluster/monitor.go b/cluster/monitor.go index 261d6666..e6a82395 100644 --- a/cluster/monitor.go +++ b/cluster/monitor.go @@ -199,7 +199,7 @@ func (m *deploymentMonitor) runCloseLease(ctx context.Context) <-chan runner.Res ID: m.deployment.LeaseID().BidID(), Reason: mv1.LeaseClosedReasonUnstable, } - res, err := m.session.Client().Tx().BroadcastMsgs(ctx, []sdk.Msg{msg}, aclient.WithResultCodeAsError()) + res, err := m.session.Client().Tx().BroadcastMsgs(ctx, []sdk.Msg{msg}, aclient.WithResultCodeAsError(), aclient.WithPriority()) if err != nil { m.log.Error("closing deployment", "err", err) } else { diff --git a/cmd/provider-services/cmd/flags.go b/cmd/provider-services/cmd/flags.go index 8fff9a03..57a58757 100644 --- a/cmd/provider-services/cmd/flags.go +++ b/cmd/provider-services/cmd/flags.go @@ -1,6 +1,7 @@ package cmd import ( + "fmt" "time" "github.com/go-acme/lego/v4/lego" @@ -196,6 +197,11 @@ func addRunFlags(cmd *cobra.Command) error { return err } + cmd.Flags().Int(FlagWithdrawalBatchMaxMsgs, 50, fmt.Sprintf("max number of MsgWithdrawLease messages coalesced into a single broadcast. valid range [%d, %d]", withdrawalBatchMaxMsgsMin, withdrawalBatchMaxMsgsMax)) + if err := viper.BindPFlag(FlagWithdrawalBatchMaxMsgs, cmd.Flags().Lookup(FlagWithdrawalBatchMaxMsgs)); err != nil { + return err + } + cmd.Flags().Duration(FlagLeaseFundsMonitorInterval, time.Minute*10, "interval at which lease is checked for funds available on the escrow accounts. >= 1m") if err := viper.BindPFlag(FlagLeaseFundsMonitorInterval, cmd.Flags().Lookup(FlagLeaseFundsMonitorInterval)); err != nil { return err diff --git a/cmd/provider-services/cmd/run.go b/cmd/provider-services/cmd/run.go index 519e2752..2a35988e 100644 --- a/cmd/provider-services/cmd/run.go +++ b/cmd/provider-services/cmd/run.go @@ -96,6 +96,7 @@ const ( FlagManifestTimeout = "manifest-timeout" FlagMetricsListener = "metrics-listener" FlagWithdrawalPeriod = "withdrawal-period" + FlagWithdrawalBatchMaxMsgs = "withdrawal-batch-max-msgs" FlagLeaseFundsMonitorInterval = "lease-funds-monitor-interval" FlagMinimumBalance = "minimum-balance" FlagProviderConfig = "provider-config" @@ -132,8 +133,10 @@ const ( ) const ( - serviceIPOperator = "ip-operator" - serviceHostnameOperator = "hostname-operator" + serviceIPOperator = "ip-operator" + serviceHostnameOperator = "hostname-operator" + withdrawalBatchMaxMsgsMin = 10 + withdrawalBatchMaxMsgsMax = 100 ) var ( @@ -189,6 +192,10 @@ func RunCmd() *cobra.Command { Short: "run akash provider", SilenceUsage: true, PreRunE: func(cmd *cobra.Command, args []string) error { + // Store logger in context before TxPersistentPreRunE so that the + // serialBroadcaster (created during DiscoverClient) picks it up via ctxlog.Logger(ctx). + fromctx.CmdSetContextValue(cmd, fromctx.CtxKeyLogc, log.NewLogger(os.Stderr)) + err := TxPersistentPreRunE(cmd, args) if err != nil { return err @@ -205,6 +212,10 @@ func RunCmd() *cobra.Command { return fmt.Errorf(`flag "%s" value must be > "%s"`, FlagWithdrawalPeriod, FlagLeaseFundsMonitorInterval) // nolint: err113 } + if maxMsgs := viper.GetInt(FlagWithdrawalBatchMaxMsgs); maxMsgs < withdrawalBatchMaxMsgsMin || maxMsgs > withdrawalBatchMaxMsgsMax { + return fmt.Errorf(`flag "%s" contains invalid value %d. expected range [%d, %d]`, FlagWithdrawalBatchMaxMsgs, maxMsgs, withdrawalBatchMaxMsgsMin, withdrawalBatchMaxMsgsMax) // nolint: err113 + } + if viper.GetDuration(FlagMonitorRetryPeriod) < 4*time.Second { return fmt.Errorf(`flag "%s" value must be > "%s"`, FlagMonitorRetryPeriod, 4*time.Second) // nolint: err113 } @@ -265,7 +276,7 @@ func RunCmd() *cobra.Command { return err } - logger := log.NewLogger(os.Stderr) + logger := ctxlog.LogcFromCtx(cmd.Context()) kubeLog := logger.With("component", "k8s") @@ -675,6 +686,7 @@ func doRunCmd(ctx context.Context, cmd *cobra.Command, _ []string) error { config.BalanceCheckerCfg = provider.BalanceCheckerConfig{ WithdrawalPeriod: viper.GetDuration(FlagWithdrawalPeriod), LeaseFundsCheckInterval: viper.GetDuration(FlagLeaseFundsMonitorInterval), + WithdrawalBatchMaxMsgs: viper.GetInt(FlagWithdrawalBatchMaxMsgs), } config.BidPricingStrategy = pricing diff --git a/config.go b/config.go index e2248d27..9c17e5a4 100644 --- a/config.go +++ b/config.go @@ -36,6 +36,7 @@ func NewDefaultConfig() Config { BalanceCheckerCfg: BalanceCheckerConfig{ LeaseFundsCheckInterval: 1 * time.Minute, WithdrawalPeriod: 24 * time.Hour, + WithdrawalBatchMaxMsgs: 50, }, MaxGroupVolumes: constants.DefaultMaxGroupVolumes, Config: cluster.NewDefaultConfig(), diff --git a/manifest/watchdog.go b/manifest/watchdog.go index f961bab5..82a00109 100644 --- a/manifest/watchdog.go +++ b/manifest/watchdog.go @@ -76,7 +76,7 @@ func (wd *watchdog) run() { Reason: mtypes.LeaseClosedReasonManifestTimeout, } - return runner.NewResult(wd.sess.Client().Tx().BroadcastMsgs(wd.ctx, []sdk.Msg{msg}, aclient.WithResultCodeAsError())) + return runner.NewResult(wd.sess.Client().Tx().BroadcastMsgs(wd.ctx, []sdk.Msg{msg}, aclient.WithResultCodeAsError(), aclient.WithPriority())) // bidClose should be priority to release the reserved resources ASAP }) case err = <-wd.lc.ShutdownRequest(): } diff --git a/manifest/watchdog_test.go b/manifest/watchdog_test.go index 34191f18..1e8ed88f 100644 --- a/manifest/watchdog_test.go +++ b/manifest/watchdog_test.go @@ -38,7 +38,7 @@ func makeWatchdogTestScaffold(t *testing.T, timeout time.Duration) (*watchdog, * scaffold.broadcasts = make(chan []sdk.Msg, 1) txClientMock := &clientmocks.TxClient{} - txClientMock.On("BroadcastMsgs", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + txClientMock.On("BroadcastMsgs", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { scaffold.broadcasts <- args.Get(1).([]sdk.Msg) }).Return(&sdk.Result{}, nil) diff --git a/withdraw_batcher.go b/withdraw_batcher.go new file mode 100644 index 00000000..639d2f33 --- /dev/null +++ b/withdraw_batcher.go @@ -0,0 +1,166 @@ +package provider + +import ( + "context" + "fmt" + "slices" + "sync/atomic" + "time" + + "cosmossdk.io/log" + sdk "github.com/cosmos/cosmos-sdk/types" + + aclient "pkg.akt.dev/go/node/client/v1beta3" + mtypes "pkg.akt.dev/go/node/market/v1" + mvbeta "pkg.akt.dev/go/node/market/v1beta5" +) + +// withdrawBatcher coalesces MsgWithdrawLease requests into single multi-msg +// transactions using opportunistic in-flight batching: +// +// - Idle: Flush fires a 1-msg TX immediately. +// - In-flight: subsequent Enqueue calls accumulate in pending. +// - On MarkDone: callers invoke Flush which drains up to maxMsgs from pending. +// +// Not safe for concurrent use. All public methods must be called from a single +// goroutine. Concurrent calls panic to surface developer mistakes early. +type withdrawBatcher struct { + tx aclient.TxClient + log log.Logger + timeout time.Duration + maxMsgs int + + inUse atomic.Bool + + pending []mtypes.LeaseID + inFlight bool + doneCh chan error +} + +func (b *withdrawBatcher) enter() { + if b.inUse.Swap(true) { // Swap returns the previous value + panic("withdrawBatcher: concurrent use detected") + } +} + +func (b *withdrawBatcher) exit() { + b.inUse.Store(false) +} + +func newWithdrawBatcher(tx aclient.TxClient, logger log.Logger, timeout time.Duration, maxMsgs int) *withdrawBatcher { + if maxMsgs < 1 { + panic(fmt.Sprintf("withdrawBatcher: maxMsgs must be >= 1, got %d", maxMsgs)) + } + return &withdrawBatcher{ + tx: tx, + log: logger, + timeout: timeout, + maxMsgs: maxMsgs, + doneCh: make(chan error, 1), + } +} + +// After an in-flight broadcast fails, items coalesced during the in-flight +// window remain in pending (run-loop skips re-flush on error for natural +// backoff). If the same lease re-triggers before pending drains, Enqueue must +// dedupe so the next batch doesn't carry a duplicate MsgWithdrawLease, which +// would risk failing the entire atomic tx on the second message. +func (b *withdrawBatcher) Enqueue(lid mtypes.LeaseID) { + b.enter() + defer b.exit() + if slices.Contains(b.pending, lid) { + b.log.Debug("batcher: enqueue dedup", "lease", lid, "pending", len(b.pending), "inFlight", b.inFlight) + return + } + b.pending = append(b.pending, lid) + b.log.Debug("batcher: enqueue", "lease", lid, "pending", len(b.pending), "inFlight", b.inFlight) +} + +// Remove drops a lease id from the pending batch. +// Does not affect an in-flight broadcast. +func (b *withdrawBatcher) Remove(lid mtypes.LeaseID) { + b.enter() + defer b.exit() + b.pending = slices.DeleteFunc(b.pending, func(p mtypes.LeaseID) bool { + return p == lid + }) +} + +// InFlight reports whether a broadcast is currently running. +func (b *withdrawBatcher) InFlight() bool { + b.enter() + defer b.exit() + return b.inFlight +} + +// Pending reports the number of queued lease ids not yet broadcast. +func (b *withdrawBatcher) Pending() int { + b.enter() + defer b.exit() + return len(b.pending) +} + +// Flush starts a broadcast with up to maxMsgs pending lease ids when idle. +// Returns true if a broadcast was started, false if nothing to do or already in-flight. +func (b *withdrawBatcher) Flush(ctx context.Context) bool { + b.enter() + defer b.exit() + if b.inFlight { + b.log.Debug("batcher: flush skipped (in-flight)", "pending", len(b.pending)) + return false + } + if len(b.pending) == 0 { + return false + } + + n := min(len(b.pending), b.maxMsgs) + + batch := make([]mtypes.LeaseID, n) + copy(batch, b.pending[:n]) + b.pending = b.pending[n:] + b.inFlight = true + + b.log.Info("batcher: flush", "batch", n, "remaining", len(b.pending), "maxMsgs", b.maxMsgs) + + go func() { + start := time.Now() + err := b.broadcast(ctx, batch) + b.log.Info("batcher: broadcast done", "batch", n, "duration", time.Since(start), "err", err) + select { + case <-ctx.Done(): + case b.doneCh <- err: + } + }() + + return true +} + +// Done returns a channel that delivers the broadcast result of each completed batch. +// Callers must invoke MarkDone after reading to unblock the next Flush. +func (b *withdrawBatcher) Done() <-chan error { + return b.doneCh +} + +// MarkDone clears the in-flight flag. Must be called after reading Done(). +func (b *withdrawBatcher) MarkDone() { + b.enter() + defer b.exit() + b.inFlight = false +} + +func (b *withdrawBatcher) broadcast(ctx context.Context, lids []mtypes.LeaseID) error { + if len(lids) == 0 { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, b.timeout) + defer cancel() + + msgs := make([]sdk.Msg, 0, len(lids)) + for _, lid := range lids { + msgs = append(msgs, &mvbeta.MsgWithdrawLease{ID: lid}) + } + + _, err := b.tx.BroadcastMsgs(ctx, msgs, aclient.WithResultCodeAsError()) + return err +} diff --git a/withdraw_batcher_test.go b/withdraw_batcher_test.go new file mode 100644 index 00000000..c62d92f1 --- /dev/null +++ b/withdraw_batcher_test.go @@ -0,0 +1,317 @@ +package provider + +import ( + "context" + "errors" + "io" + "testing" + "time" + + "cosmossdk.io/log" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + sdk "github.com/cosmos/cosmos-sdk/types" + + clientmocks "pkg.akt.dev/go/mocks/node/client" + mtypes "pkg.akt.dev/go/node/market/v1" + mvbeta "pkg.akt.dev/go/node/market/v1beta5" + "pkg.akt.dev/go/testutil" +) + +const twoSecondsTestTimeout = 2 * time.Second + +func testLogger() log.Logger { return log.NewLogger(io.Discard) } + +// batcherFixture wires a withdrawBatcher to a mock TxClient whose broadcast +// can be blocked until released, enabling deterministic in-flight tests. +type batcherFixture struct { + t *testing.T + tx *clientmocks.TxClient + batcher *withdrawBatcher + captured chan []sdk.Msg + release chan struct{} +} + +func newBatcherFixture(t *testing.T, maxMsgs int, block bool) *batcherFixture { + t.Helper() + + f := &batcherFixture{ + t: t, + tx: &clientmocks.TxClient{}, + captured: make(chan []sdk.Msg, 16), + release: make(chan struct{}), + } + + f.tx.On("BroadcastMsgs", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + msgs := args.Get(1).([]sdk.Msg) + f.captured <- msgs + if block { + <-f.release + } + }). + Return(&sdk.TxResponse{}, nil) + + f.batcher = newWithdrawBatcher(f.tx, testLogger(), time.Second, maxMsgs) + + return f +} + +func (f *batcherFixture) waitCaptured() []sdk.Msg { + f.t.Helper() + select { + case msgs := <-f.captured: + return msgs + case <-time.After(twoSecondsTestTimeout): + f.t.Fatal("timed out waiting for broadcast") + return nil + } +} + +func (f *batcherFixture) waitDone() error { + f.t.Helper() + select { + case err := <-f.batcher.Done(): + return err + case <-time.After(twoSecondsTestTimeout): + f.t.Fatal("timed out waiting for batch completion") + return nil + } +} + +func lidsOf(t *testing.T, n int) []mtypes.LeaseID { + t.Helper() + out := make([]mtypes.LeaseID, n) + for i := 0; i < n; i++ { + out[i] = testutil.LeaseID(t) + } + return out +} + +// happy path: enqueue one withdrawal and flush it immediately +func TestWithdrawBatcher_IdleFlushImmediate(t *testing.T) { + f := newBatcherFixture(t, 50, false) + lid := testutil.LeaseID(t) + + f.batcher.Enqueue(lid) + require.True(t, f.batcher.Flush(context.Background())) + + msgs := f.waitCaptured() + require.Len(t, msgs, 1) + require.Equal(t, lid, msgs[0].(*mvbeta.MsgWithdrawLease).ID) + + require.NoError(t, f.waitDone()) + f.batcher.MarkDone() + require.False(t, f.batcher.Flush(context.Background())) +} + +// send the first withdrawal tx to force the batcher start batching the remaining withdrawals +func TestWithdrawBatcher_BurstCoalescesWhileInFlight(t *testing.T) { + f := newBatcherFixture(t, 50, true) + lids := lidsOf(t, 5) + + f.batcher.Enqueue(lids[0]) + require.True(t, f.batcher.Flush(context.Background())) + + first := f.waitCaptured() + require.Len(t, first, 1) + + for _, l := range lids[1:] { + f.batcher.Enqueue(l) + } + require.False(t, f.batcher.Flush(context.Background())) + require.Equal(t, 4, f.batcher.Pending()) + + close(f.release) + require.NoError(t, f.waitDone()) + f.batcher.MarkDone() + require.True(t, f.batcher.Flush(context.Background())) + + second := f.waitCaptured() + require.Len(t, second, 4) + for i, m := range second { + require.Equal(t, lids[i+1], m.(*mvbeta.MsgWithdrawLease).ID) + } +} + +// set the limit to 2 and send 5 withdrawals, the batcher should batch the first 2, second 2, and the last 1 +func TestWithdrawBatcher_RespectsMaxMsgs(t *testing.T) { + f := newBatcherFixture(t, 2, false) + lids := lidsOf(t, 5) + + for _, l := range lids { + f.batcher.Enqueue(l) + } + + require.True(t, f.batcher.Flush(context.Background())) + first := f.waitCaptured() + require.Len(t, first, 2) + require.NoError(t, f.waitDone()) + f.batcher.MarkDone() + + require.True(t, f.batcher.Flush(context.Background())) + second := f.waitCaptured() + require.Len(t, second, 2) + require.NoError(t, f.waitDone()) + f.batcher.MarkDone() + + require.True(t, f.batcher.Flush(context.Background())) + third := f.waitCaptured() + require.Len(t, third, 1) + require.NoError(t, f.waitDone()) + f.batcher.MarkDone() + + require.False(t, f.batcher.Flush(context.Background())) +} + +// enqueue several withdrawals and remove one. +// on flush, the batcher should batch the remaining withdrawals and the removed one should be ignored +func TestWithdrawBatcher_RemoveFiltersPending(t *testing.T) { + f := newBatcherFixture(t, 50, false) + lids := lidsOf(t, 3) + + for _, l := range lids { + f.batcher.Enqueue(l) + } + f.batcher.Remove(lids[1]) + require.Equal(t, 2, f.batcher.Pending()) + + require.True(t, f.batcher.Flush(context.Background())) + msgs := f.waitCaptured() + require.Len(t, msgs, 2) + require.Equal(t, lids[0], msgs[0].(*mvbeta.MsgWithdrawLease).ID) + require.Equal(t, lids[2], msgs[1].(*mvbeta.MsgWithdrawLease).ID) +} + +// Remove on an in-flight id is a no-op; Remove on a still-pending id drops it +// from the next batch. +func TestWithdrawBatcher_RemoveAfterFlushDoesNotAffectBatch(t *testing.T) { + f := newBatcherFixture(t, 2, true) + lids := lidsOf(t, 3) + + for _, l := range lids { + f.batcher.Enqueue(l) + } + require.True(t, f.batcher.Flush(context.Background())) + + first := f.waitCaptured() + require.Len(t, first, 2) + require.Equal(t, lids[0], first[0].(*mvbeta.MsgWithdrawLease).ID) + require.Equal(t, lids[1], first[1].(*mvbeta.MsgWithdrawLease).ID) + + f.batcher.Remove(lids[0]) + f.batcher.Remove(lids[2]) + + close(f.release) + require.NoError(t, f.waitDone()) + f.batcher.MarkDone() + + require.False(t, f.batcher.Flush(context.Background())) + require.Equal(t, 0, f.batcher.Pending()) +} + +// Re-enqueueing a still-pending lease is a no-op, so a lease that fires its +// timer again while a prior trigger is still queued does not duplicate the msg. +func TestWithdrawBatcher_EnqueueDedupes(t *testing.T) { + f := newBatcherFixture(t, 50, false) + lid := testutil.LeaseID(t) + + f.batcher.Enqueue(lid) + f.batcher.Enqueue(lid) + f.batcher.Enqueue(lid) + require.Equal(t, 1, f.batcher.Pending()) + + require.True(t, f.batcher.Flush(context.Background())) + msgs := f.waitCaptured() + require.Len(t, msgs, 1) +} + +func TestWithdrawBatcher_EnqueueDedupesAfterFailure(t *testing.T) { + wantErr := errors.New("broadcast failed") + + tx := &clientmocks.TxClient{} + captured := make(chan []sdk.Msg, 2) + release := make(chan struct{}) + + tx.On("BroadcastMsgs", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + captured <- args.Get(1).([]sdk.Msg) + <-release + }). + Return(nil, wantErr).Once() + + tx.On("BroadcastMsgs", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + captured <- args.Get(1).([]sdk.Msg) + }). + Return(&sdk.TxResponse{}, nil).Once() + + b := newWithdrawBatcher(tx, testLogger(), time.Second, 50) + lidA := testutil.LeaseID(t) + lidB := testutil.LeaseID(t) + + b.Enqueue(lidA) + require.True(t, b.Flush(context.Background())) + + first := <-captured + require.Len(t, first, 1) + require.Equal(t, lidA, first[0].(*mvbeta.MsgWithdrawLease).ID) + + b.Enqueue(lidB) + require.Equal(t, 1, b.Pending()) + + close(release) + require.ErrorIs(t, <-b.Done(), wantErr) + b.MarkDone() + + b.Enqueue(lidB) + require.Equal(t, 1, b.Pending(), "dedup must prevent stale pending entry from duplicating") + + require.True(t, b.Flush(context.Background())) + second := <-captured + require.Len(t, second, 1) + require.Equal(t, lidB, second[0].(*mvbeta.MsgWithdrawLease).ID) + + require.NoError(t, <-b.Done()) +} + +func TestWithdrawBatcher_ErrorPropagatesToDone(t *testing.T) { + wantErr := errors.New("broadcast failed") + + tx := &clientmocks.TxClient{} + tx.On("BroadcastMsgs", mock.Anything, mock.Anything, mock.Anything). + Return(nil, wantErr) + + b := newWithdrawBatcher(tx, testLogger(), time.Second, 50) + b.Enqueue(testutil.LeaseID(t)) + require.True(t, b.Flush(context.Background())) + + err := <-b.Done() + require.ErrorIs(t, err, wantErr) +} + +func TestWithdrawBatcher_FlushEmptyNoop(t *testing.T) { + f := newBatcherFixture(t, 50, false) + require.False(t, f.batcher.Flush(context.Background())) + f.tx.AssertNotCalled(t, "BroadcastMsgs") +} + +func TestWithdrawBatcher_TimeoutAppliedPerCall(t *testing.T) { + tx := &clientmocks.TxClient{} + tx.On("BroadcastMsgs", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + ctx := args.Get(0).(context.Context) + deadline, ok := ctx.Deadline() + require.True(t, ok) + require.WithinDuration(t, time.Now().Add(50*time.Millisecond), deadline, 50*time.Millisecond) + <-ctx.Done() + }). + Return(nil, context.DeadlineExceeded) + + b := newWithdrawBatcher(tx, testLogger(), 50*time.Millisecond, 50) + b.Enqueue(testutil.LeaseID(t)) + require.True(t, b.Flush(context.Background())) + err := <-b.Done() + require.ErrorIs(t, err, context.DeadlineExceeded) +}