diff --git a/broker/Makefile b/broker/Makefile index 2d2bb792..59fc967d 100644 --- a/broker/Makefile +++ b/broker/Makefile @@ -30,7 +30,8 @@ SQL_GEN_OUT_ILL_DB = ill_db/ill_db_gen.go ill_db/ill_models_gen.go ill_db/ill_qu SQL_GEN_OUT_EVENT = events/event_db_gen.go events/event_models_gen.go events/event_query.sql_gen.go SQL_GEN_OUT_PR = patron_request/db/pr_db_gen.go patron_request/db/pr_models_gen.go patron_request/db/pr_query.sql_gen.go SQL_GEN_OUT_PS = pullslip/db/ps_db_gen.go pullslip/db/ps_models_gen.go pullslip/db/ps_query.sql_gen.go -SQL_GEN_OUT = $(SQL_GEN_OUT_ILL_DB) $(SQL_GEN_OUT_EVENT) $(SQL_GEN_OUT_PR) $(SQL_GEN_OUT_PS) +SQL_GEN_OUT_SCHED = scheduler/db/sched_db_gen.go scheduler/db/sched_models_gen.go scheduler/db/sched_query.sql_gen.go +SQL_GEN_OUT = $(SQL_GEN_OUT_ILL_DB) $(SQL_GEN_OUT_EVENT) $(SQL_GEN_OUT_PR) $(SQL_GEN_OUT_PS) $(SQL_GEN_OUT_SCHED) SQL_GEN_IN = sqlc/*.sql # OpenAPI diff --git a/broker/README.md b/broker/README.md index 16d954bc..eabb70b2 100644 --- a/broker/README.md +++ b/broker/README.md @@ -119,6 +119,7 @@ Configuration is provided via environment variables: | | the `{tenant}` token is replaced by the `X-Okapi-Tenant` header value | | | `SUPPLIER_PATRON_PATTERN` | Pattern used to create patron ID when receiving Request on supplier side | `%v_user` | | `LANGUAGE` | Language parameter used for ts_vector search in DB | `english` | +| `SCHEDULER_RETRY_DELAY` | Delay for rescheduling failed scheduled task | `5m` | # Build diff --git a/broker/app/app.go b/broker/app/app.go index b454e4d8..f54ccd7d 100644 --- a/broker/app/app.go +++ b/broker/app/app.go @@ -21,6 +21,8 @@ import ( psapi "github.com/indexdata/crosslink/broker/pullslip/api" ps_db "github.com/indexdata/crosslink/broker/pullslip/db" psoapi "github.com/indexdata/crosslink/broker/pullslip/oapi" + sched_db "github.com/indexdata/crosslink/broker/scheduler/db" + sched_service "github.com/indexdata/crosslink/broker/scheduler/service" "github.com/indexdata/crosslink/broker/tenant" "github.com/dustin/go-humanize" @@ -196,6 +198,12 @@ func Init(ctx context.Context) (Context, error) { if err != nil { return Context{}, err } + + skdRepo := sched_db.CreateSkdRepo(pool) + if err = StartScheduler(ctx, skdRepo, eventBus); err != nil { + return Context{}, err + } + return Context{ EventBus: eventBus, IllRepo: illRepo, @@ -357,6 +365,18 @@ func StartEventBus(ctx context.Context, eventBus events.EventBus) error { return nil } +// StartScheduler creates the scheduler service, begins listening on +// sched_db.SchedulerChannel, and launches the scheduling loop in a background goroutine. +func StartScheduler(ctx context.Context, skdRepo sched_db.SchedRepo, eventBus events.EventBus) error { + extCtx := common.CreateExtCtxWithArgs(ctx, nil) + svc := sched_service.NewSchedulerService(skdRepo, eventBus, ConnectionString) + if err := svc.Listen(extCtx); err != nil { + return fmt.Errorf("starting scheduler listener failed: %w", err) + } + go svc.Run(extCtx) + return nil +} + func HandleHealthz(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) } diff --git a/broker/events/eventmodels.go b/broker/events/eventmodels.go index 9dcf274f..0c037e27 100644 --- a/broker/events/eventmodels.go +++ b/broker/events/eventmodels.go @@ -28,6 +28,7 @@ type EventDomain string const ( EventDomainPatronRequest EventDomain = "PATRON_REQUEST" EventDomainIllTransaction EventDomain = "ILL_TRANSACTION" + EventDomainScheduler EventDomain = "SCHEDULER" ) type EventName string diff --git a/broker/go.mod b/broker/go.mod index 7f0728ed..1bfe0b2a 100644 --- a/broker/go.mod +++ b/broker/go.mod @@ -39,6 +39,7 @@ require ( github.com/jackc/pgx/v5 v5.9.2 github.com/lib/pq v1.12.3 github.com/oapi-codegen/runtime v1.4.0 + github.com/robfig/cron/v3 v3.0.1 github.com/stretchr/testify v1.11.1 github.com/testcontainers/testcontainers-go v0.42.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0 diff --git a/broker/go.sum b/broker/go.sum index 65a349fd..ccbc91c1 100644 --- a/broker/go.sum +++ b/broker/go.sum @@ -225,6 +225,8 @@ github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/riza-io/grpc-go v0.2.0 h1:2HxQKFVE7VuYstcJ8zqpN84VnAoJ4dCL6YFhJewNcHQ= github.com/riza-io/grpc-go v0.2.0/go.mod h1:2bDvR9KkKC3KhtlSHfR3dAXjUMT86kg4UfWFyVGWqi8= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= diff --git a/broker/migrations/038_add_scheduler.down.sql b/broker/migrations/038_add_scheduler.down.sql new file mode 100644 index 00000000..491868b2 --- /dev/null +++ b/broker/migrations/038_add_scheduler.down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS idx_scheduled_task_run_at; +DROP TABLE IF EXISTS scheduled_task; diff --git a/broker/migrations/038_add_scheduler.up.sql b/broker/migrations/038_add_scheduler.up.sql new file mode 100644 index 00000000..a24d8dba --- /dev/null +++ b/broker/migrations/038_add_scheduler.up.sql @@ -0,0 +1,14 @@ +CREATE TABLE scheduled_task +( + id TEXT PRIMARY KEY, + event_name TEXT NOT NULL, + cron_expr TEXT NOT NULL, + payload JSONB, + run_at TIMESTAMPTZ, + status TEXT NOT NULL DEFAULT 'pending', + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ, + FOREIGN KEY (event_name) REFERENCES event_config (event_name) +); + +CREATE INDEX idx_scheduled_task_run_at ON scheduled_task (run_at) WHERE status = 'pending' AND run_at IS NOT NULL; diff --git a/broker/scheduler/db/models.go b/broker/scheduler/db/models.go new file mode 100644 index 00000000..46efd103 --- /dev/null +++ b/broker/scheduler/db/models.go @@ -0,0 +1,9 @@ +package sched_db + +type ScheduledTaskStatus string + +const ( + ScheduledTaskStatusPending ScheduledTaskStatus = "pending" + ScheduledTaskStatusRunning ScheduledTaskStatus = "running" + ScheduledTaskStatusStopped ScheduledTaskStatus = "stopped" +) diff --git a/broker/scheduler/db/repo.go b/broker/scheduler/db/repo.go new file mode 100644 index 00000000..c4397dfd --- /dev/null +++ b/broker/scheduler/db/repo.go @@ -0,0 +1,86 @@ +package sched_db + +import ( + "time" + + "github.com/indexdata/crosslink/broker/common" + "github.com/indexdata/crosslink/broker/repo" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" +) + +const SchedulerChannel = "crosslink_sched_channel" + +type SchedRepo interface { + repo.Transactional[SchedRepo] + SaveScheduledTask(ctx common.ExtendedContext, params SaveScheduledTaskParams) (ScheduledTask, error) + ClaimNextScheduledTask(ctx common.ExtendedContext) (ScheduledTask, error) + GetNextRunAt(ctx common.ExtendedContext) (pgtype.Timestamptz, error) + GetStuckRunningTasks(ctx common.ExtendedContext, stuckAfter time.Duration) ([]ScheduledTask, error) +} + +type PgSchedRepo struct { + repo.PgBaseRepo[SchedRepo] + queries Queries +} + +// WithTxFunc delegates transaction handling to PgBaseRepo. +func (r *PgSchedRepo) WithTxFunc(ctx common.ExtendedContext, fn func(SchedRepo) error) error { + return r.PgBaseRepo.WithTxFunc(ctx, r, fn) +} + +// CreateWithPgBaseRepo creates a derived repo bound to the provided tx-aware base. +func (r *PgSchedRepo) CreateWithPgBaseRepo(base *repo.PgBaseRepo[SchedRepo]) SchedRepo { + derived := new(PgSchedRepo) + derived.PgBaseRepo = *base + return derived +} + +// CreateSkdRepo creates a new SchedRepo backed by the given connection pool. +func CreateSkdRepo(dbPool *pgxpool.Pool) SchedRepo { + r := new(PgSchedRepo) + r.Pool = dbPool + return r +} + +func (r *PgSchedRepo) SaveScheduledTask(ctx common.ExtendedContext, params SaveScheduledTaskParams) (ScheduledTask, error) { + row, err := r.queries.SaveScheduledTask(ctx, r.GetConnOrTx(), params) + if err == nil { + r.notify(ctx) + } + return row.ScheduledTask, err +} + +func (r *PgSchedRepo) ClaimNextScheduledTask(ctx common.ExtendedContext) (ScheduledTask, error) { + row, err := r.queries.ClaimNextScheduledTask(ctx, r.GetConnOrTx()) + return row.ScheduledTask, err +} + +func (r *PgSchedRepo) GetNextRunAt(ctx common.ExtendedContext) (pgtype.Timestamptz, error) { + return r.queries.GetNextRunAt(ctx, r.GetConnOrTx()) +} + +// GetStuckRunningTasks returns tasks that have been in 'running' state for +// longer than stuckAfter, indicating they were claimed but never completed. +func (r *PgSchedRepo) GetStuckRunningTasks(ctx common.ExtendedContext, stuckAfter time.Duration) ([]ScheduledTask, error) { + rows, err := r.queries.GetStuckRunningTasks(ctx, r.GetConnOrTx(), pgDuration(stuckAfter)) + if err != nil { + return nil, err + } + tasks := make([]ScheduledTask, 0, len(rows)) + for _, row := range rows { + tasks = append(tasks, row.ScheduledTask) + } + return tasks, nil +} + +func pgDuration(d time.Duration) pgtype.Interval { + return pgtype.Interval{Microseconds: d.Microseconds(), Valid: true} +} + +func (r *PgSchedRepo) notify(ctx common.ExtendedContext) { + _, err := r.GetConnOrTx().Exec(ctx, "NOTIFY "+SchedulerChannel) + if err != nil { + ctx.Logger().Error("failed to notify scheduler channel", "channel", SchedulerChannel, "error", err) + } +} diff --git a/broker/scheduler/service/scheduler.go b/broker/scheduler/service/scheduler.go new file mode 100644 index 00000000..2441e6d4 --- /dev/null +++ b/broker/scheduler/service/scheduler.go @@ -0,0 +1,290 @@ +package sched_service + +import ( + "errors" + "fmt" + "strings" + "time" + + "github.com/indexdata/crosslink/broker/common" + "github.com/indexdata/crosslink/broker/events" + sched_db "github.com/indexdata/crosslink/broker/scheduler/db" + "github.com/indexdata/go-utils/utils" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/robfig/cron/v3" +) + +var SCHEDULER_RETRY_DELAY, _ = utils.GetEnvAny("SCHEDULER_RETRY_DELAY", time.Duration(5*float64(time.Minute)), func(val string) (time.Duration, error) { + d, err := time.ParseDuration(val) + if err != nil { + return 0, fmt.Errorf("invalid SCHEDULER_RETRY_DELAY value: %s", val) + } + return d, nil +}) + +type SchedulerService struct { + skdRepo sched_db.SchedRepo + eventBus events.EventBus + connString string + // notifyCh is written by Listen and read by schedulerLoop via waitUntil. + notifyCh chan struct{} + notify <-chan struct{} +} + +// NewSchedulerService creates a SchedulerService wired to the given repo, +// event bus, and Postgres connection string (used for the LISTEN connection). +func NewSchedulerService(skdRepo sched_db.SchedRepo, eventBus events.EventBus, connString string) *SchedulerService { + ch := make(chan struct{}, 1) + return &SchedulerService{ + skdRepo: skdRepo, + eventBus: eventBus, + connString: connString, + notifyCh: ch, + notify: ch, + } +} + +// Listen opens a dedicated Postgres connection and listens on sched_db.SchedulerChannel. +// Each notification wakes the scheduler loop. Reconnects with exponential +// backoff on connection loss. Blocks until ctx is cancelled. +func (s *SchedulerService) Listen(ctx common.ExtendedContext) error { + // openConn establishes a fresh connection and registers the LISTEN. + // The caller is responsible for closing the returned connection. + openConn := func() (*pgx.Conn, error) { + c, err := pgx.Connect(ctx, s.connString) + if err != nil { + ctx.Logger().Error("scheduler: unable to connect to database", "error", err) + return nil, err + } + if _, err = c.Exec(ctx, "LISTEN "+sched_db.SchedulerChannel); err != nil { + ctx.Logger().Error("scheduler: unable to listen to channel", "channel", sched_db.SchedulerChannel, "error", err) + _ = c.Close(ctx) + return nil, err + } + ctx.Logger().Info("scheduler: listening on channel", "channel", sched_db.SchedulerChannel) + return c, nil + } + + // Verify we can connect before spawning the goroutine. + conn, err := openConn() + if err != nil { + return err + } + + go func() { + // conn is fully local to this goroutine; always close before returning. + defer func() { _ = conn.Close(ctx) }() + + for { + _, er := conn.WaitForNotification(ctx) + if er != nil { + if strings.Contains(er.Error(), "context canceled") { + ctx.Logger().Info("scheduler: context cancelled, stopping listener") + return + } + + ctx.Logger().Warn("scheduler: notification error, reconnecting", "error", er) + + // Close the broken connection before attempting to reconnect + // so we don't leak the old socket or its LISTEN registration. + _ = conn.Close(ctx) + conn = nil + + baseDelay := 1 * time.Second + maxDelay := 30 * time.Second + delay := baseDelay + + for { + select { + case <-ctx.Done(): + return + case <-time.After(delay): + } + newConn, connErr := openConn() + if connErr == nil { + conn = newConn + break + } + ctx.Logger().Error("scheduler: reconnect failed", "error", connErr, "next_try_in", delay) + delay = time.Duration(float64(delay) * 1.5) + if delay > maxDelay { + delay = maxDelay + } + } + continue + } + select { + case s.notifyCh <- struct{}{}: + default: + } + } + }() + + return nil +} + +// Run starts the scheduler loop, blocking until ctx is cancelled. +// Call Listen before Run to enable early wake-up via Postgres notifications. +func (s *SchedulerService) Run(ctx common.ExtendedContext) { + s.schedulerLoop(ctx) +} + +func (s *SchedulerService) schedulerLoop(ctx common.ExtendedContext) { + for { + s.rescheduleLongRunningTasks(ctx) + madeProgress := s.runDueTasks(ctx) + + nextRunAt := s.getNextRunAt(ctx) + if !waitUntil(ctx, nextRunAt, s.notify, SCHEDULER_RETRY_DELAY, madeProgress) { + return // context cancelled + } + } +} + +// runDueTasks processes all currently claimable tasks. +// Returns true if at least one task was successfully claimed and dispatched. +func (s *SchedulerService) runDueTasks(ctx common.ExtendedContext) bool { + madeProgress := false + for { + task, err := s.skdRepo.ClaimNextScheduledTask(ctx) + if err != nil { + if !errors.Is(err, pgx.ErrNoRows) { + ctx.Logger().Error("failed to claim next scheduled task", "error", err) + } + return madeProgress + } + madeProgress = true + + _, err = s.eventBus.CreateTask(events.DEFAULT_ILL_TRANSACTION_ID, task.EventName, task.Payload, events.EventDomainScheduler, nil, events.SignalConsumers) + + if err != nil { + task.RunAt = pgtype.Timestamptz{Time: time.Now().Add(SCHEDULER_RETRY_DELAY), Valid: true} + s.unlockAndReschedule(ctx, task) + continue + } + + if task.CronExpr != "" { + next, err := nextCronTime(task.CronExpr) + if err != nil { + ctx.Logger().Error("invalid cron expression, disabling task", "error", err, "taskId", task.ID) + s.disableTask(ctx, task) + continue + } + task.RunAt = next + s.unlockAndReschedule(ctx, task) + } else { + s.disableTask(ctx, task) + } + } +} + +func (s *SchedulerService) disableTask(ctx common.ExtendedContext, task sched_db.ScheduledTask) { + task.Status = sched_db.ScheduledTaskStatusStopped + task.RunAt = pgtype.Timestamptz{Valid: false} + _, err := s.skdRepo.SaveScheduledTask(ctx, sched_db.SaveScheduledTaskParams(task)) + if err != nil { + ctx.Logger().Error("failed to update scheduled task", "error", err, "taskId", task.ID) + } +} + +func (s *SchedulerService) unlockAndReschedule(ctx common.ExtendedContext, task sched_db.ScheduledTask) { + task.Status = sched_db.ScheduledTaskStatusPending + _, err := s.skdRepo.SaveScheduledTask(ctx, sched_db.SaveScheduledTaskParams(task)) + if err != nil { + ctx.Logger().Error("failed to reschedule scheduled task", "error", err, "taskId", task.ID) + } +} + +// getNextRunAt returns the run_at timestamp of the earliest pending scheduled +// task, or a zero Timestamptz if no pending tasks exist. +func (s *SchedulerService) getNextRunAt(ctx common.ExtendedContext) pgtype.Timestamptz { + next, err := s.skdRepo.GetNextRunAt(ctx) + if err != nil { + if !errors.Is(err, pgx.ErrNoRows) { + ctx.Logger().Error("failed to get next run", "error", err) + } + // No pending tasks or query error — return invalid (zero) value. + return pgtype.Timestamptz{} + } + return next +} + +// waitUntil blocks until one of three conditions is met: +// 1. nextRunAt is reached (next scheduled task is due). An overdue nextRunAt +// only causes an immediate return when madeProgress is true — i.e. the +// previous runDueTasks call actually claimed a task. This prevents a tight +// spin loop when ClaimNextScheduledTask keeps returning a persistent error +// while GetNextRunAt still reports an overdue timestamp. +// 2. a signal arrives on notifyChanged (schedule table was modified) +// 3. the fallback duration elapses (safety poll) +// +// Returns false if the context was cancelled (caller should stop the loop). +func waitUntil(ctx common.ExtendedContext, nextRunAt pgtype.Timestamptz, notifyChanged <-chan struct{}, fallback time.Duration, madeProgress bool) bool { + sleepDuration := fallback + if nextRunAt.Valid { + until := time.Until(nextRunAt.Time) + if until <= 0 && madeProgress { + // Overdue and we just successfully processed tasks — more may be ready. + return true + } else if until > 0 && until < fallback { + sleepDuration = until + } + // If overdue but no progress was made (persistent claim errors), fall + // through to sleep the full fallback to avoid a CPU-spinning tight loop. + } + + timer := time.NewTimer(sleepDuration) + defer timer.Stop() + + select { + case <-ctx.Done(): + return false + case <-timer.C: + return true + case <-notifyChanged: + return true + } +} + +// nextCronTime parses a standard 5-field cron expression and returns the next +// scheduled execution time after now as a pgtype.Timestamptz. +// Returns an error if the expression is invalid. +func nextCronTime(cronExpr string) (pgtype.Timestamptz, error) { + parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) + schedule, err := parser.Parse(cronExpr) + if err != nil { + return pgtype.Timestamptz{}, fmt.Errorf("invalid cron expression %q: %w", cronExpr, err) + } + next := schedule.Next(time.Now()) + return pgtype.Timestamptz{ + Time: next, + Valid: true, + }, nil +} + +// rescheduleLongRunningTasks finds tasks that have been in 'running' state for +// longer than hour (indicating a crashed or lost worker) and +// resets them to 'pending' so they are picked up again on the next loop tick. +func (s *SchedulerService) rescheduleLongRunningTasks(ctx common.ExtendedContext) { + tasks, err := s.skdRepo.GetStuckRunningTasks(ctx, time.Hour) + if err != nil { + ctx.Logger().Error("failed to query stuck running tasks", "error", err) + return + } + for _, task := range tasks { + ctx.Logger().Info("rescheduling stuck task", "taskId", task.ID, "eventName", task.EventName) + if task.CronExpr != "" { + next, err := nextCronTime(task.CronExpr) + if err != nil { + ctx.Logger().Error("invalid cron expression, disabling task", "error", err, "taskId", task.ID) + s.disableTask(ctx, task) + continue + } + task.RunAt = next + } else { + task.RunAt = pgtype.Timestamptz{Time: time.Now().Add(SCHEDULER_RETRY_DELAY), Valid: true} + } + s.unlockAndReschedule(ctx, task) + } +} diff --git a/broker/scheduler/service/scheduler_test.go b/broker/scheduler/service/scheduler_test.go new file mode 100644 index 00000000..3e7a391f --- /dev/null +++ b/broker/scheduler/service/scheduler_test.go @@ -0,0 +1,452 @@ +package sched_service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/indexdata/crosslink/broker/common" + "github.com/indexdata/crosslink/broker/events" + sched_db "github.com/indexdata/crosslink/broker/scheduler/db" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" +) + +// --------------------------------------------------------------------------- +// Test helpers +// --------------------------------------------------------------------------- + +var testCtx = common.CreateExtCtxWithArgs(context.Background(), nil) + +func tstz(t time.Time) pgtype.Timestamptz { + return pgtype.Timestamptz{Time: t, Valid: true} +} + +func invalidTstz() pgtype.Timestamptz { + return pgtype.Timestamptz{Valid: false} +} + +// --------------------------------------------------------------------------- +// Mock SkdRepo +// --------------------------------------------------------------------------- + +type mockSkdRepo struct { + claimResults []sched_db.ScheduledTask + claimErrors []error + claimIndex int + savedTasks []sched_db.SaveScheduledTaskParams + saveError error + nextRunAt pgtype.Timestamptz + nextRunAtErr error + stuckTasks []sched_db.ScheduledTask + stuckTasksErr error +} + +func (m *mockSkdRepo) WithTxFunc(ctx common.ExtendedContext, fn func(sched_db.SchedRepo) error) error { + return fn(m) +} + +func (m *mockSkdRepo) ClaimNextScheduledTask(_ common.ExtendedContext) (sched_db.ScheduledTask, error) { + if m.claimIndex >= len(m.claimResults) { + return sched_db.ScheduledTask{}, pgx.ErrNoRows + } + task := m.claimResults[m.claimIndex] + var err error + if m.claimIndex < len(m.claimErrors) { + err = m.claimErrors[m.claimIndex] + } + m.claimIndex++ + return task, err +} + +func (m *mockSkdRepo) SaveScheduledTask(_ common.ExtendedContext, p sched_db.SaveScheduledTaskParams) (sched_db.ScheduledTask, error) { + m.savedTasks = append(m.savedTasks, p) + return sched_db.ScheduledTask(p), m.saveError +} + +func (m *mockSkdRepo) GetNextRunAt(_ common.ExtendedContext) (pgtype.Timestamptz, error) { + return m.nextRunAt, m.nextRunAtErr +} + +func (m *mockSkdRepo) GetStuckRunningTasks(_ common.ExtendedContext, _ time.Duration) ([]sched_db.ScheduledTask, error) { + return m.stuckTasks, m.stuckTasksErr +} + +// --------------------------------------------------------------------------- +// Mock EventBus — only CreateTask is exercised by the scheduler +// --------------------------------------------------------------------------- + +type mockEventBus struct { + events.EventBus + createTaskErr error + createdTaskNames []events.EventName +} + +func (m *mockEventBus) CreateTask(_ string, name events.EventName, _ events.EventData, _ events.EventDomain, _ *string, _ events.SignalTarget) (string, error) { + m.createdTaskNames = append(m.createdTaskNames, name) + return "task-id", m.createTaskErr +} + +// --------------------------------------------------------------------------- +// nextCronTime +// --------------------------------------------------------------------------- + +func TestNextCronTime_ValidExpr(t *testing.T) { + ts, err := nextCronTime("* * * * *") // every minute + assert.NoError(t, err) + assert.True(t, ts.Valid) + assert.True(t, ts.Time.After(time.Now())) +} + +func TestNextCronTime_InvalidExpr(t *testing.T) { + _, err := nextCronTime("not-a-cron") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid cron expression") +} + +func TestNextCronTime_SpecificSchedule(t *testing.T) { + // "0 9 * * 1" = every Monday at 09:00 — just verify it's in the future + ts, err := nextCronTime("0 9 * * 1") + assert.NoError(t, err) + assert.True(t, ts.Valid) + assert.True(t, ts.Time.After(time.Now())) +} + +// --------------------------------------------------------------------------- +// waitUntil +// --------------------------------------------------------------------------- + +func TestWaitUntil_ContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + extCtx := common.CreateExtCtxWithArgs(ctx, nil) + cancel() + + result := waitUntil(extCtx, invalidTstz(), make(chan struct{}), 10*time.Second, false) + assert.False(t, result, "should return false when context is cancelled") +} + +func TestWaitUntil_NotifyWakes(t *testing.T) { + extCtx := common.CreateExtCtxWithArgs(context.Background(), nil) + ch := make(chan struct{}, 1) + ch <- struct{}{} // pre-signal + + result := waitUntil(extCtx, invalidTstz(), ch, 10*time.Second, false) + assert.True(t, result) +} + +func TestWaitUntil_FallbackElapsed(t *testing.T) { + extCtx := common.CreateExtCtxWithArgs(context.Background(), nil) + start := time.Now() + result := waitUntil(extCtx, invalidTstz(), make(chan struct{}), 20*time.Millisecond, false) + assert.True(t, result) + assert.WithinDuration(t, start.Add(20*time.Millisecond), time.Now(), 50*time.Millisecond) +} + +func TestWaitUntil_NextRunAtSooner(t *testing.T) { + extCtx := common.CreateExtCtxWithArgs(context.Background(), nil) + nextRunAt := tstz(time.Now().Add(20 * time.Millisecond)) + start := time.Now() + result := waitUntil(extCtx, nextRunAt, make(chan struct{}), 10*time.Second, false) + assert.True(t, result) + assert.WithinDuration(t, start.Add(20*time.Millisecond), time.Now(), 50*time.Millisecond) +} + +// TestWaitUntil_NextRunAtAlreadyOverdue_MadeProgress verifies that an overdue +// timestamp returns immediately when the caller made progress (claimed a task). +func TestWaitUntil_NextRunAtAlreadyOverdue_MadeProgress(t *testing.T) { + extCtx := common.CreateExtCtxWithArgs(context.Background(), nil) + nextRunAt := tstz(time.Now().Add(-1 * time.Second)) + result := waitUntil(extCtx, nextRunAt, make(chan struct{}), 10*time.Second, true) + assert.True(t, result) // returns immediately +} + +// TestWaitUntil_NextRunAtAlreadyOverdue_NoProgress verifies that an overdue +// timestamp does NOT return immediately when no progress was made, to prevent +// a tight spin loop on persistent claim errors. +func TestWaitUntil_NextRunAtAlreadyOverdue_NoProgress(t *testing.T) { + extCtx := common.CreateExtCtxWithArgs(context.Background(), nil) + nextRunAt := tstz(time.Now().Add(-1 * time.Second)) + start := time.Now() + result := waitUntil(extCtx, nextRunAt, make(chan struct{}), 20*time.Millisecond, false) + assert.True(t, result) + // Must have slept the fallback, not returned immediately. + assert.WithinDuration(t, start.Add(20*time.Millisecond), time.Now(), 50*time.Millisecond) +} + +func TestWaitUntil_NextRunAtLongerThanFallback(t *testing.T) { + extCtx := common.CreateExtCtxWithArgs(context.Background(), nil) + nextRunAt := tstz(time.Now().Add(10 * time.Second)) + start := time.Now() + result := waitUntil(extCtx, nextRunAt, make(chan struct{}), 20*time.Millisecond, false) + assert.True(t, result) + assert.WithinDuration(t, start.Add(20*time.Millisecond), time.Now(), 50*time.Millisecond) +} + +// --------------------------------------------------------------------------- +// getNextRunAt +// --------------------------------------------------------------------------- + +func TestGetNextRunAt_ReturnsValue(t *testing.T) { + expected := tstz(time.Now().Add(5 * time.Minute)) + svc := &SchedulerService{skdRepo: &mockSkdRepo{nextRunAt: expected}} + + got := svc.getNextRunAt(testCtx) + assert.Equal(t, expected, got) +} + +func TestGetNextRunAt_ErrorReturnsInvalid(t *testing.T) { + svc := &SchedulerService{skdRepo: &mockSkdRepo{nextRunAtErr: errors.New("no rows")}} + + got := svc.getNextRunAt(testCtx) + assert.False(t, got.Valid) +} + +// --------------------------------------------------------------------------- +// runDueTasks +// --------------------------------------------------------------------------- + +func TestRunDueTasks_NoTasks(t *testing.T) { + repo := &mockSkdRepo{} + bus := &mockEventBus{} + svc := &SchedulerService{skdRepo: repo, eventBus: bus} + + progress := svc.runDueTasks(testCtx) + assert.False(t, progress) + assert.Empty(t, bus.createdTaskNames) + assert.Empty(t, repo.savedTasks) +} + +func TestRunDueTasks_ClaimError_NonNoRows(t *testing.T) { + repo := &mockSkdRepo{ + claimResults: []sched_db.ScheduledTask{{}}, + claimErrors: []error{errors.New("db error")}, + } + svc := &SchedulerService{skdRepo: repo, eventBus: &mockEventBus{}} + + progress := svc.runDueTasks(testCtx) + assert.False(t, progress) + assert.Empty(t, repo.savedTasks) +} + +func TestRunDueTasks_OneShot_DisablesAfterFiring(t *testing.T) { + task := sched_db.ScheduledTask{ID: "t1", EventName: "my-event", CronExpr: "", RunAt: pgtype.Timestamptz{Time: time.Now(), Valid: true}} + repo := &mockSkdRepo{claimResults: []sched_db.ScheduledTask{task}} + bus := &mockEventBus{} + svc := &SchedulerService{skdRepo: repo, eventBus: bus} + + progress := svc.runDueTasks(testCtx) + + assert.True(t, progress) + assert.Equal(t, []events.EventName{"my-event"}, bus.createdTaskNames) + assert.Len(t, repo.savedTasks, 1) + assert.False(t, repo.savedTasks[0].RunAt.Valid, "one-shot task should be disabled") +} + +func TestRunDueTasks_Recurring_ReschedulesWithNextCronTime(t *testing.T) { + task := sched_db.ScheduledTask{ID: "t2", EventName: "cron-ev", CronExpr: "* * * * *"} + repo := &mockSkdRepo{claimResults: []sched_db.ScheduledTask{task}} + bus := &mockEventBus{} + svc := &SchedulerService{skdRepo: repo, eventBus: bus} + + progress := svc.runDueTasks(testCtx) + + assert.True(t, progress) + assert.Equal(t, []events.EventName{"cron-ev"}, bus.createdTaskNames) + assert.Len(t, repo.savedTasks, 1) + saved := repo.savedTasks[0] + assert.True(t, saved.RunAt.Valid) + assert.True(t, saved.RunAt.Time.After(time.Now())) + assert.Equal(t, sched_db.ScheduledTaskStatusPending, saved.Status) +} + +func TestRunDueTasks_Recurring_InvalidCronExpr_DisablesTask(t *testing.T) { + task := sched_db.ScheduledTask{ID: "t3", EventName: "bad", CronExpr: "not-valid"} + repo := &mockSkdRepo{claimResults: []sched_db.ScheduledTask{task}} + bus := &mockEventBus{} + svc := &SchedulerService{skdRepo: repo, eventBus: bus} + + progress := svc.runDueTasks(testCtx) + + assert.True(t, progress) + assert.Len(t, repo.savedTasks, 1) + assert.False(t, repo.savedTasks[0].RunAt.Valid) +} + +func TestRunDueTasks_CreateTaskError_ReschedulesWithRetryDelay(t *testing.T) { + task := sched_db.ScheduledTask{ID: "t4", EventName: "fail-ev"} + repo := &mockSkdRepo{claimResults: []sched_db.ScheduledTask{task}} + bus := &mockEventBus{createTaskErr: errors.New("bus down")} + svc := &SchedulerService{skdRepo: repo, eventBus: bus} + + progress := svc.runDueTasks(testCtx) + + assert.True(t, progress) + assert.Len(t, repo.savedTasks, 1) + saved := repo.savedTasks[0] + assert.True(t, saved.RunAt.Valid) + assert.True(t, saved.RunAt.Time.After(time.Now())) + assert.Equal(t, sched_db.ScheduledTaskStatusPending, saved.Status) +} + +func TestRunDueTasks_MultipleTasks_ProcessedInOrder(t *testing.T) { + tasks := []sched_db.ScheduledTask{ + {ID: "t1", EventName: "event-a"}, + {ID: "t2", EventName: "event-b"}, + } + repo := &mockSkdRepo{claimResults: tasks} + bus := &mockEventBus{} + svc := &SchedulerService{skdRepo: repo, eventBus: bus} + + progress := svc.runDueTasks(testCtx) + + assert.True(t, progress) + assert.Equal(t, []events.EventName{"event-a", "event-b"}, bus.createdTaskNames) + assert.Len(t, repo.savedTasks, 2) +} + +func TestRunDueTasks_ValidJsonPayload_Dispatched(t *testing.T) { + task := sched_db.ScheduledTask{ + ID: "t5", + EventName: "payload-ev", + Payload: events.EventData{}, + } + repo := &mockSkdRepo{claimResults: []sched_db.ScheduledTask{task}} + bus := &mockEventBus{} + svc := &SchedulerService{skdRepo: repo, eventBus: bus} + + progress := svc.runDueTasks(testCtx) + + assert.True(t, progress) + assert.Equal(t, []events.EventName{"payload-ev"}, bus.createdTaskNames) +} + +// --------------------------------------------------------------------------- +// rescheduleLongRunningTasks +// --------------------------------------------------------------------------- + +// TestRescheduleLongRunning_NoStuckTasks verifies that nothing is saved when +// there are no stuck tasks. +func TestRescheduleLongRunning_NoStuckTasks(t *testing.T) { + repo := &mockSkdRepo{stuckTasks: nil} + svc := &SchedulerService{skdRepo: repo, eventBus: &mockEventBus{}} + + svc.rescheduleLongRunningTasks(testCtx) + + assert.Empty(t, repo.savedTasks) +} + +// TestRescheduleLongRunning_RepoError_DoesNotSave verifies that a repo error +// is handled gracefully without saving anything. +func TestRescheduleLongRunning_RepoError_DoesNotSave(t *testing.T) { + repo := &mockSkdRepo{stuckTasksErr: errors.New("db error")} + svc := &SchedulerService{skdRepo: repo, eventBus: &mockEventBus{}} + + svc.rescheduleLongRunningTasks(testCtx) + + assert.Empty(t, repo.savedTasks) +} + +// TestRescheduleLongRunning_OneShot_ReschedulesWithRetryDelay verifies that a +// stuck one-shot task (no cron) is reset to pending with run_at = now + retry. +func TestRescheduleLongRunning_OneShot_ReschedulesWithRetryDelay(t *testing.T) { + stuck := sched_db.ScheduledTask{ + ID: "stuck-1", + EventName: "one-shot", + CronExpr: "", + Status: sched_db.ScheduledTaskStatusRunning, + } + repo := &mockSkdRepo{stuckTasks: []sched_db.ScheduledTask{stuck}} + svc := &SchedulerService{skdRepo: repo, eventBus: &mockEventBus{}} + + before := time.Now() + svc.rescheduleLongRunningTasks(testCtx) + after := time.Now() + + assert.Len(t, repo.savedTasks, 1) + saved := repo.savedTasks[0] + assert.Equal(t, sched_db.ScheduledTaskStatusPending, saved.Status) + assert.True(t, saved.RunAt.Valid) + assert.True(t, saved.RunAt.Time.After(before)) + assert.True(t, saved.RunAt.Time.After(after)) // run_at is in the future +} + +// TestRescheduleLongRunning_Recurring_ReschedulesWithNextCronTime verifies that +// a stuck recurring task is reset to pending with the next cron-computed run_at. +func TestRescheduleLongRunning_Recurring_ReschedulesWithNextCronTime(t *testing.T) { + stuck := sched_db.ScheduledTask{ + ID: "stuck-2", + EventName: "cron-ev", + CronExpr: "* * * * *", // every minute + Status: sched_db.ScheduledTaskStatusRunning, + } + repo := &mockSkdRepo{stuckTasks: []sched_db.ScheduledTask{stuck}} + svc := &SchedulerService{skdRepo: repo, eventBus: &mockEventBus{}} + + svc.rescheduleLongRunningTasks(testCtx) + + assert.Len(t, repo.savedTasks, 1) + saved := repo.savedTasks[0] + assert.Equal(t, sched_db.ScheduledTaskStatusPending, saved.Status) + assert.True(t, saved.RunAt.Valid) + assert.True(t, saved.RunAt.Time.After(time.Now())) +} + +// TestRescheduleLongRunning_InvalidCron_DisablesTask verifies that a stuck task +// with an invalid cron expression is disabled rather than rescheduled. +func TestRescheduleLongRunning_InvalidCron_DisablesTask(t *testing.T) { + stuck := sched_db.ScheduledTask{ + ID: "stuck-3", + EventName: "bad-cron", + CronExpr: "not-a-cron", + Status: sched_db.ScheduledTaskStatusRunning, + } + repo := &mockSkdRepo{stuckTasks: []sched_db.ScheduledTask{stuck}} + svc := &SchedulerService{skdRepo: repo, eventBus: &mockEventBus{}} + + svc.rescheduleLongRunningTasks(testCtx) + + assert.Len(t, repo.savedTasks, 1) + saved := repo.savedTasks[0] + assert.Equal(t, sched_db.ScheduledTaskStatusStopped, saved.Status) + assert.False(t, saved.RunAt.Valid) +} + +// TestRescheduleLongRunning_MultipleStuck_AllRescheduled verifies that all +// stuck tasks in the result set are processed. +func TestRescheduleLongRunning_MultipleStuck_AllRescheduled(t *testing.T) { + stuckTasks := []sched_db.ScheduledTask{ + {ID: "s1", EventName: "ev-a", CronExpr: "", Status: sched_db.ScheduledTaskStatusRunning}, + {ID: "s2", EventName: "ev-b", CronExpr: "* * * * *", Status: sched_db.ScheduledTaskStatusRunning}, + } + repo := &mockSkdRepo{stuckTasks: stuckTasks} + svc := &SchedulerService{skdRepo: repo, eventBus: &mockEventBus{}} + + svc.rescheduleLongRunningTasks(testCtx) + + assert.Len(t, repo.savedTasks, 2) + for _, saved := range repo.savedTasks { + assert.Equal(t, sched_db.ScheduledTaskStatusPending, saved.Status) + assert.True(t, saved.RunAt.Valid) + } +} + +// --------------------------------------------------------------------------- +// NewSchedulerService — channel wiring +// --------------------------------------------------------------------------- + +func TestNewSchedulerService_ChannelWired(t *testing.T) { + svc := NewSchedulerService(&mockSkdRepo{}, &mockEventBus{}, "") + + assert.NotNil(t, svc.notifyCh) + assert.NotNil(t, svc.notify) + + svc.notifyCh <- struct{}{} + select { + case <-svc.notify: + // OK — same underlying channel + default: + t.Fatal("notify channel is not wired to notifyCh") + } +} diff --git a/broker/sqlc/sched_query.sql b/broker/sqlc/sched_query.sql new file mode 100644 index 00000000..eacce93e --- /dev/null +++ b/broker/sqlc/sched_query.sql @@ -0,0 +1,41 @@ +-- name: SaveScheduledTask :one +INSERT INTO scheduled_task (id, event_name, cron_expr, payload, run_at, status, created_at, updated_at) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +ON CONFLICT (id) DO UPDATE + SET event_name = EXCLUDED.event_name, + cron_expr = EXCLUDED.cron_expr, + payload = EXCLUDED.payload, + run_at = EXCLUDED.run_at, + status = EXCLUDED.status, + updated_at = now() +RETURNING sqlc.embed(scheduled_task); + +-- name: GetNextRunAt :one +SELECT run_at +FROM scheduled_task +WHERE status = 'pending' + AND run_at IS NOT NULL +ORDER BY run_at +LIMIT 1; + +-- name: GetStuckRunningTasks :many +SELECT sqlc.embed(scheduled_task) +FROM scheduled_task +WHERE status = 'running' + AND updated_at <= now() - $1::interval; + +-- name: ClaimNextScheduledTask :one +UPDATE scheduled_task +SET status = 'running', + updated_at = now() +WHERE id = (SELECT id + FROM scheduled_task + WHERE status = 'pending' + AND run_at <= now() AND run_at IS NOT NULL + ORDER BY run_at + LIMIT 1 + FOR +UPDATE SKIP LOCKED + ) + RETURNING sqlc.embed(scheduled_task); + diff --git a/broker/sqlc/sched_schema.sql b/broker/sqlc/sched_schema.sql new file mode 100644 index 00000000..59600175 --- /dev/null +++ b/broker/sqlc/sched_schema.sql @@ -0,0 +1,15 @@ +CREATE TABLE scheduled_task +( + id TEXT PRIMARY KEY, + event_name TEXT NOT NULL, + cron_expr TEXT NOT NULL, + payload JSONB, + run_at TIMESTAMPTZ, + status TEXT NOT NULL DEFAULT 'pending', + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ, + FOREIGN KEY (event_name) REFERENCES event_config (event_name) +); + +CREATE INDEX idx_scheduled_task_run_at ON scheduled_task (run_at) WHERE status = 'pending' AND run_at IS NOT NULL; + diff --git a/broker/sqlc/sqlc.yaml b/broker/sqlc/sqlc.yaml index 47533043..5f2b7264 100644 --- a/broker/sqlc/sqlc.yaml +++ b/broker/sqlc/sqlc.yaml @@ -132,3 +132,27 @@ sql: - column: "pull_slip.type" go_type: type: "PullSlipType" + - engine: "postgresql" + queries: "sched_query.sql" + schema: "sched_schema.sql" + gen: + go: + package: "sched_db" + out: "../scheduler/db" + output_db_file_name: "sched_db_gen.go" + output_models_file_name: "sched_models_gen.go" + output_files_suffix: "_gen" + sql_package: "pgx/v5" + emit_methods_with_db_argument: true + overrides: + - column: "scheduled_task.status" + go_type: + type: "ScheduledTaskStatus" + - column: "scheduled_task.event_name" + go_type: + import: "github.com/indexdata/crosslink/broker/events" + type: "EventName" + - column: "scheduled_task.payload" + go_type: + import: "github.com/indexdata/crosslink/broker/events" + type: "EventData" diff --git a/broker/test/scheduler/db/skdrepo_test.go b/broker/test/scheduler/db/skdrepo_test.go new file mode 100644 index 00000000..6b3e0053 --- /dev/null +++ b/broker/test/scheduler/db/skdrepo_test.go @@ -0,0 +1,349 @@ +package sched_db + +import ( + "context" + "os" + "testing" + "time" + + "github.com/google/uuid" + "github.com/indexdata/crosslink/broker/app" + "github.com/indexdata/crosslink/broker/common" + "github.com/indexdata/crosslink/broker/events" + sched_db "github.com/indexdata/crosslink/broker/scheduler/db" + test "github.com/indexdata/crosslink/broker/test/utils" + "github.com/indexdata/go-utils/utils" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" +) + +var skdRepo sched_db.SchedRepo +var appCtx = common.CreateExtCtxWithArgs(context.Background(), nil) + +func TestMain(m *testing.M) { + ctx := context.Background() + + pgContainer, err := postgres.Run(ctx, "postgres", + postgres.WithDatabase("crosslink"), + postgres.WithUsername("crosslink"), + postgres.WithPassword("crosslink"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2).WithStartupTimeout(30*time.Second)), + ) + test.Expect(err, "failed to start db container") + + connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable") + test.Expect(err, "failed to get conn string") + + app.ConnectionString = connStr + app.MigrationsFolder = "file://../../../migrations" + app.HTTP_PORT = utils.Must(test.GetFreePort()) + app.DB_PROVISION = true + + test.Expect(app.RunDbUp(), "failed to run db migrations") + + pool, err := app.InitDbPool() + test.Expect(err, "failed to init db pool") + + skdRepo = sched_db.CreateSkdRepo(pool) + + code := m.Run() + + test.Expect(pgContainer.Terminate(ctx), "failed to stop db container") + os.Exit(code) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func newTask(cronExpr string, runAt pgtype.Timestamptz) sched_db.SaveScheduledTaskParams { + return sched_db.SaveScheduledTaskParams{ + ID: uuid.NewString(), + EventName: events.EventNameSendNotification, + CronExpr: cronExpr, + RunAt: runAt, + Status: sched_db.ScheduledTaskStatusPending, + CreatedAt: pgtype.Timestamptz{Time: time.Now(), Valid: true}, + } +} + +func tstz(t time.Time) pgtype.Timestamptz { + return pgtype.Timestamptz{Time: t, Valid: true} +} + +func stopTask(t *testing.T, task sched_db.ScheduledTask) { + task.Status = sched_db.ScheduledTaskStatusStopped + _, err := skdRepo.SaveScheduledTask(appCtx, sched_db.SaveScheduledTaskParams(task)) + assert.NoError(t, err) +} + +// --------------------------------------------------------------------------- +// SaveScheduledTask +// --------------------------------------------------------------------------- + +func TestSaveScheduledTask_Insert(t *testing.T) { + params := newTask("* * * * *", tstz(time.Now().Add(1*time.Minute))) + + saved, err := skdRepo.SaveScheduledTask(appCtx, params) + + assert.NoError(t, err) + assert.Equal(t, params.ID, saved.ID) + assert.Equal(t, params.EventName, saved.EventName) + assert.Equal(t, params.CronExpr, saved.CronExpr) + assert.Equal(t, sched_db.ScheduledTaskStatusPending, saved.Status) + assert.True(t, saved.CreatedAt.Valid) + + stopTask(t, saved) +} + +func TestSaveScheduledTask_Upsert_UpdatesFields(t *testing.T) { + params := newTask("0 * * * *", tstz(time.Now().Add(1*time.Hour))) + _, err := skdRepo.SaveScheduledTask(appCtx, params) + assert.NoError(t, err) + + params.CronExpr = "0 9 * * 1" + params.RunAt = tstz(time.Now().Add(2 * time.Hour)) + + updated, err := skdRepo.SaveScheduledTask(appCtx, params) + + assert.NoError(t, err) + assert.Equal(t, params.ID, updated.ID) + assert.Equal(t, "0 9 * * 1", updated.CronExpr) + + stopTask(t, updated) +} + +func TestSaveScheduledTask_WithPayload(t *testing.T) { + params := newTask("", tstz(time.Now().Add(1*time.Minute))) + params.Payload = events.EventData{ + CommonEventData: events.CommonEventData{Note: "hello from scheduler"}, + } + + saved, err := skdRepo.SaveScheduledTask(appCtx, params) + + assert.NoError(t, err) + assert.Equal(t, "hello from scheduler", saved.Payload.Note) + + stopTask(t, saved) +} + +// --------------------------------------------------------------------------- +// GetNextRunAt +// --------------------------------------------------------------------------- + +func TestGetNextRunAt_ReturnsPendingTask(t *testing.T) { + params := newTask("* * * * *", tstz(time.Now().Add(5*time.Minute))) + saved, err := skdRepo.SaveScheduledTask(appCtx, params) + assert.NoError(t, err) + + next, err := skdRepo.GetNextRunAt(appCtx) + + assert.NoError(t, err) + assert.True(t, next.Valid) + assert.True(t, next.Time.After(time.Now())) + + stopTask(t, saved) +} + +func TestGetNextRunAt_ReturnsEarliestPendingTask(t *testing.T) { + earlier := tstz(time.Now().Add(2 * time.Minute)) + later := tstz(time.Now().Add(4 * time.Hour)) + + p1 := newTask("", earlier) + p2 := newTask("", later) + + s1, err := skdRepo.SaveScheduledTask(appCtx, p1) + assert.NoError(t, err) + s2, err := skdRepo.SaveScheduledTask(appCtx, p2) + assert.NoError(t, err) + + next, err := skdRepo.GetNextRunAt(appCtx) + + assert.NoError(t, err) + assert.True(t, next.Valid) + assert.WithinDuration(t, earlier.Time, next.Time, time.Second) + assert.True(t, next.Time.Before(later.Time), "returned run_at should be the earlier of the two tasks") + + stopTask(t, s1) + stopTask(t, s2) +} + +// --------------------------------------------------------------------------- +// ClaimNextScheduledTask +// --------------------------------------------------------------------------- + +func TestClaimNextScheduledTask_OverdueTask_ClaimedAndSetToRunning(t *testing.T) { + overdue := newTask("", tstz(time.Now().Add(-1*time.Second))) + _, err := skdRepo.SaveScheduledTask(appCtx, overdue) + assert.NoError(t, err) + + claimed, err := skdRepo.ClaimNextScheduledTask(appCtx) + + assert.NoError(t, err) + assert.Equal(t, sched_db.ScheduledTaskStatusRunning, claimed.Status) + assert.True(t, claimed.UpdatedAt.Valid) + + stopTask(t, claimed) +} + +func TestClaimNextScheduledTask_SetsStatusToRunning(t *testing.T) { + params := newTask("* * * * *", tstz(time.Now().Add(-30*time.Second))) + _, err := skdRepo.SaveScheduledTask(appCtx, params) + assert.NoError(t, err) + + claimed, err := skdRepo.ClaimNextScheduledTask(appCtx) + + assert.NoError(t, err) + assert.Equal(t, sched_db.ScheduledTaskStatusRunning, claimed.Status) + + stopTask(t, claimed) +} + +func TestClaimNextScheduledTask_FutureTask_NotClaimed(t *testing.T) { + i := 0 + for { + i++ + _, err := skdRepo.ClaimNextScheduledTask(appCtx) + if err != nil || i > 100 { + break + } + } + assert.True(t, i < 100, "too many claimed tasks") + + params := newTask("", tstz(time.Now().Add(24*time.Hour))) + saved, err := skdRepo.SaveScheduledTask(appCtx, params) + assert.NoError(t, err) + + _, err = skdRepo.ClaimNextScheduledTask(appCtx) + assert.ErrorIs(t, err, pgx.ErrNoRows) + + stopTask(t, saved) +} + +// --------------------------------------------------------------------------- +// Reschedule flow (claim → save pending with updated run_at) +// --------------------------------------------------------------------------- + +func TestRescheduleAfterClaim(t *testing.T) { + params := newTask("* * * * *", tstz(time.Now().Add(-1*time.Second))) + _, err := skdRepo.SaveScheduledTask(appCtx, params) + assert.NoError(t, err) + + claimed, err := skdRepo.ClaimNextScheduledTask(appCtx) + assert.NoError(t, err) + assert.Equal(t, sched_db.ScheduledTaskStatusRunning, claimed.Status) + + claimed.Status = sched_db.ScheduledTaskStatusPending + claimed.RunAt = tstz(time.Now().Add(5 * time.Minute)) + rescheduled, err := skdRepo.SaveScheduledTask(appCtx, sched_db.SaveScheduledTaskParams(claimed)) + + assert.NoError(t, err) + assert.Equal(t, sched_db.ScheduledTaskStatusPending, rescheduled.Status) + assert.True(t, rescheduled.RunAt.Time.After(time.Now())) + + stopTask(t, rescheduled) +} + +// --------------------------------------------------------------------------- +// GetStuckRunningTasks +// --------------------------------------------------------------------------- + +// insertRunning inserts a task directly in 'running' status with the given +// updated_at so we can simulate a task that has been stuck for a known duration. +func insertRunning(t *testing.T, updatedAt time.Time) sched_db.ScheduledTask { + t.Helper() + params := newTask("", tstz(time.Now().Add(-10*time.Second))) + params.Status = sched_db.ScheduledTaskStatusRunning + params.UpdatedAt = pgtype.Timestamptz{Time: updatedAt, Valid: true} + saved, err := skdRepo.SaveScheduledTask(appCtx, params) + assert.NoError(t, err) + return saved +} + +func TestGetStuckRunningTasks_ReturnsTaskStuckLongerThanThreshold(t *testing.T) { + // Insert a task that has been running for 2 hours. + stuck := insertRunning(t, time.Now().Add(-2*time.Hour)) + + tasks, err := skdRepo.GetStuckRunningTasks(appCtx, 1*time.Hour) + + assert.NoError(t, err) + ids := make([]string, len(tasks)) + for i, tk := range tasks { + ids[i] = tk.ID + } + assert.Contains(t, ids, stuck.ID) + + stopTask(t, stuck) +} + +func TestGetStuckRunningTasks_DoesNotReturnRecentTask(t *testing.T) { + // Insert a task that has been running for only 10 seconds — well within threshold. + recent := insertRunning(t, time.Now().Add(-10*time.Second)) + + tasks, err := skdRepo.GetStuckRunningTasks(appCtx, 1*time.Hour) + + assert.NoError(t, err) + for _, tk := range tasks { + assert.NotEqual(t, recent.ID, tk.ID, "recently started task should not be returned as stuck") + } + + stopTask(t, recent) +} + +func TestGetStuckRunningTasks_DoesNotReturnPendingTask(t *testing.T) { + // A pending task (not running) should never appear. + params := newTask("", tstz(time.Now().Add(-2*time.Hour))) + saved, err := skdRepo.SaveScheduledTask(appCtx, params) + assert.NoError(t, err) + + tasks, err := skdRepo.GetStuckRunningTasks(appCtx, 1*time.Hour) + + assert.NoError(t, err) + for _, tk := range tasks { + assert.NotEqual(t, saved.ID, tk.ID, "pending task should not appear in stuck results") + } + + stopTask(t, saved) +} + +func TestGetStuckRunningTasks_MultipleStuckTasks_AllReturned(t *testing.T) { + stuck1 := insertRunning(t, time.Now().Add(-3*time.Hour)) + stuck2 := insertRunning(t, time.Now().Add(-2*time.Hour)) + + tasks, err := skdRepo.GetStuckRunningTasks(appCtx, 1*time.Hour) + + assert.NoError(t, err) + ids := make(map[string]bool, len(tasks)) + for _, tk := range tasks { + ids[tk.ID] = true + } + assert.True(t, ids[stuck1.ID], "stuck1 should be returned") + assert.True(t, ids[stuck2.ID], "stuck2 should be returned") + + stopTask(t, stuck1) + stopTask(t, stuck2) +} + +// --------------------------------------------------------------------------- +// Disable flow (save with invalid RunAt) +// --------------------------------------------------------------------------- + +func TestDisableTask_InvalidRunAt(t *testing.T) { + params := newTask("", tstz(time.Now().Add(1*time.Minute))) + saved, err := skdRepo.SaveScheduledTask(appCtx, params) + assert.NoError(t, err) + + saved.RunAt = pgtype.Timestamptz{Valid: false} + disabled, err := skdRepo.SaveScheduledTask(appCtx, sched_db.SaveScheduledTaskParams(saved)) + + assert.NoError(t, err) + assert.False(t, disabled.RunAt.Valid) + + stopTask(t, disabled) +}