From 1ca6cb30455ab87b248d49d8a768f979d8b47b6b Mon Sep 17 00:00:00 2001 From: Vinod Morya Date: Tue, 7 Apr 2026 18:52:49 +0530 Subject: [PATCH] feat: Add SSE (Server-Sent Events) middleware Add production-grade Server-Sent Events middleware built natively for Fiber's fasthttp architecture with proper client disconnect detection. Features: Hub-based broker, 3 priority lanes, NATS-style topic wildcards, adaptive throttling, connection groups, JWT/ticket auth, cache invalidation helpers, Prometheus metrics, Last-Event-ID replay, Redis/NATS fan-out, and graceful Kubernetes-style drain. 91% test coverage, golangci-lint clean, go test -race clean. Resolves #4194 --- docs/middleware/sse.md | 119 ++ docs/whats_new.md | 18 + middleware/sse/auth.go | 187 +++ middleware/sse/coalescer.go | 89 ++ middleware/sse/config.go | 109 ++ middleware/sse/connection.go | 132 ++ middleware/sse/domain_event.go | 135 ++ middleware/sse/event.go | 176 +++ middleware/sse/example_test.go | 101 ++ middleware/sse/fanout.go | 144 +++ middleware/sse/invalidation.go | 118 ++ middleware/sse/metrics.go | 195 +++ middleware/sse/replayer.go | 148 +++ middleware/sse/sse.go | 645 ++++++++++ middleware/sse/sse_test.go | 2109 ++++++++++++++++++++++++++++++++ middleware/sse/stats.go | 74 ++ middleware/sse/throttle.go | 80 ++ middleware/sse/topic.go | 61 + 18 files changed, 4640 insertions(+) create mode 100644 docs/middleware/sse.md create mode 100644 middleware/sse/auth.go create mode 100644 middleware/sse/coalescer.go create mode 100644 middleware/sse/config.go create mode 100644 middleware/sse/connection.go create mode 100644 middleware/sse/domain_event.go create mode 100644 middleware/sse/event.go create mode 100644 middleware/sse/example_test.go create mode 100644 middleware/sse/fanout.go create mode 100644 middleware/sse/invalidation.go create mode 100644 middleware/sse/metrics.go create mode 100644 middleware/sse/replayer.go create mode 100644 middleware/sse/sse.go create mode 100644 
middleware/sse/sse_test.go create mode 100644 middleware/sse/stats.go create mode 100644 middleware/sse/throttle.go create mode 100644 middleware/sse/topic.go diff --git a/docs/middleware/sse.md b/docs/middleware/sse.md new file mode 100644 index 00000000000..09f119cd86f --- /dev/null +++ b/docs/middleware/sse.md @@ -0,0 +1,119 @@ +--- +id: sse +--- + +# SSE + +Server-Sent Events middleware for [Fiber](https://github.com/gofiber/fiber) that provides a production-grade SSE broker built natively on Fiber's fasthttp architecture. It includes a Hub-based event broker with topic routing, event coalescing (last-writer-wins), three priority lanes (instant/batched/coalesced), NATS-style topic wildcards, adaptive per-connection throttling, connection groups, built-in JWT and ticket auth helpers, Prometheus metrics, graceful Kubernetes-style drain, auto fan-out from Redis/NATS, and pluggable Last-Event-ID replay. + +## Signatures + +```go +func New(config ...Config) fiber.Handler +func NewWithHub(config ...Config) (fiber.Handler, *Hub) +``` + +## Examples + +Import the middleware package: + +```go +import ( + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/middleware/sse" +) +``` + +Once your Fiber app is initialized, create an SSE handler and hub: + +```go +// Basic usage — subscribe all clients to "notifications" +handler, hub := sse.NewWithHub(sse.Config{ + OnConnect: func(c fiber.Ctx, conn *sse.Connection) error { + conn.Topics = []string{"notifications"} + return nil + }, +}) +app.Get("/events", handler) + +// Publish an event from any goroutine +hub.Publish(sse.Event{ + Type: "update", + Data: "hello", + Topics: []string{"notifications"}, +}) +``` + +Use JWT authentication and metadata-based groups for multi-tenant isolation: + +```go +handler, hub := sse.NewWithHub(sse.Config{ + OnConnect: sse.JWTAuth(func(token string) (map[string]string, error) { + claims, err := validateJWT(token) + if err != nil { + return nil, err + } + return map[string]string{ + 
"user_id": claims.UserID, + "tenant_id": claims.TenantID, + }, nil + }), +}) +app.Get("/events", handler) + +// Publish only to a specific tenant +hub.DomainEvent("orders", "created", orderID, tenantID, nil) +``` + +Use event coalescing to reduce traffic for high-frequency updates: + +```go +// Progress events use PriorityCoalesced — if progress goes 5%→8% +// in one flush window, only 8% is sent to the client. +hub.Progress("import", importID, tenantID, current, total, nil) + +// Completion events use PriorityInstant — always delivered immediately. +hub.Complete("import", importID, tenantID, true, map[string]any{ + "rows_imported": 1500, +}) +``` + +Use fan-out to bridge an external pub/sub system into the SSE hub: + +```go +cancel := hub.FanOut(sse.FanOutConfig{ + Subscriber: redisSubscriber, + Channel: "events:orders", + EventType: "order-update", + Topic: "orders", +}) +defer cancel() +``` + +## Config + +| Property | Type | Description | Default | +| :---------------- | :------------------------------------------------ | :------------------------------------------------------------------------------------------------------------------- | :------------- | +| Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when returned true. | `nil` | +| OnConnect | `func(fiber.Ctx, *Connection) error` | Called when a new client connects. Set `conn.Topics` and `conn.Metadata` here. Return error to reject (sends 403). | `nil` | +| OnDisconnect | `func(*Connection)` | Called after a client disconnects. | `nil` | +| OnPause | `func(*Connection)` | Called when a connection is paused (browser tab hidden). | `nil` | +| OnResume | `func(*Connection)` | Called when a connection is resumed (browser tab visible). | `nil` | +| Replayer | `Replayer` | Enables Last-Event-ID replay. If nil, replay is disabled. | `nil` | +| FlushInterval | `time.Duration` | How often batched (P1) and coalesced (P2) events are flushed to clients. Instant (P0) events bypass this. 
| `2s` | +| HeartbeatInterval | `time.Duration` | How often a comment is sent to idle connections to detect disconnects and prevent proxy timeouts. | `30s` | +| MaxLifetime | `time.Duration` | Maximum duration a single SSE connection can stay open. Set to -1 for unlimited. | `30m` | +| SendBufferSize | `int` | Per-connection channel buffer. If full, events are dropped. | `256` | +| RetryMS | `int` | Reconnection interval hint sent to clients via the `retry:` directive on connect. | `3000` | + +## Default Config + +```go +var ConfigDefault = Config{ + FlushInterval: 2 * time.Second, + SendBufferSize: 256, + HeartbeatInterval: 30 * time.Second, + MaxLifetime: 30 * time.Minute, + RetryMS: 3000, +} +``` diff --git a/docs/whats_new.md b/docs/whats_new.md index 20c7781d092..8409e1bb36e 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -57,6 +57,7 @@ Here's a quick overview of the changes in Fiber `v3`: - [Proxy](#proxy) - [Recover](#recover) - [Session](#session) + - [SSE](#sse) - [🔌 Addons](#-addons) - [📋 Migration guide](#-migration-guide) @@ -3138,3 +3139,20 @@ app.Use(session.New(session.Config{ See the [Session Middleware Migration Guide](./middleware/session.md#migration-guide) for complete details. + +#### SSE + +The new SSE middleware provides production-grade Server-Sent Events for Fiber. It includes a Hub-based broker with topic routing, event coalescing, NATS-style wildcards, JWT/ticket auth, and Prometheus metrics. 
+ +```go +handler, hub := sse.NewWithHub(sse.Config{ + OnConnect: func(c fiber.Ctx, conn *sse.Connection) error { + conn.Topics = []string{"notifications"} + return nil + }, +}) +app.Get("/events", handler) + +// Replace polling with real-time push +hub.Invalidate("orders", order.ID, "created") +``` diff --git a/middleware/sse/auth.go b/middleware/sse/auth.go new file mode 100644 index 00000000000..811479e31b0 --- /dev/null +++ b/middleware/sse/auth.go @@ -0,0 +1,187 @@ +package sse + +import ( + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "maps" + "runtime" + "strings" + "sync" + "time" + + "github.com/gofiber/fiber/v3" +) + +// JWTAuth returns an OnConnect handler that validates a JWT Bearer token +// from the Authorization header or a token query parameter. +// +// The validateFunc receives the raw token string and should return the +// claims as a map. Return an error to reject the connection. +func JWTAuth(validateFunc func(token string) (map[string]string, error)) func(fiber.Ctx, *Connection) error { + return func(c fiber.Ctx, conn *Connection) error { + token := "" + + const bearerPrefix = "Bearer " + auth := c.Get("Authorization") + if len(auth) > len(bearerPrefix) && strings.EqualFold(auth[:len(bearerPrefix)], bearerPrefix) { + token = auth[len(bearerPrefix):] + } + + if token == "" { + token = c.Query("token") + } + + if token == "" { + return errors.New("missing authentication token") + } + + claims, err := validateFunc(token) + if err != nil { + return fmt.Errorf("authentication failed: %w", err) + } + + maps.Copy(conn.Metadata, claims) + + return nil + } +} + +// TicketStore is the interface for ticket-based SSE authentication. +// Implement this with Redis, in-memory, or any key-value store. +type TicketStore interface { + // Set stores a ticket with the given value and TTL. + Set(ticket, value string, ttl time.Duration) error + + // GetDel atomically retrieves and deletes a ticket (one-time use). 
// TicketStore is the interface for ticket-based SSE authentication.
// Implement this with Redis, in-memory, or any key-value store.
type TicketStore interface {
	// Set stores a ticket with the given value and TTL.
	Set(ticket, value string, ttl time.Duration) error

	// GetDel atomically retrieves and deletes a ticket (one-time use).
	// Returns empty string and nil error if not found.
	GetDel(ticket string) (string, error)
}

// MemoryTicketStore is an in-memory TicketStore for development and testing.
// Call Close to stop the background cleanup goroutine; a finalizer also
// stops it if the store becomes unreachable without Close being called.
type MemoryTicketStore struct {
	tickets map[string]memTicket
	done    chan struct{}
	// mu is a pointer so the janitor goroutine can share the lock without
	// holding a reference to the store itself (see NewMemoryTicketStore).
	mu        *sync.Mutex
	closeOnce sync.Once
}

// memTicket is a stored ticket value together with its expiry deadline.
type memTicket struct {
	expires time.Time
	value   string
}

// NewMemoryTicketStore creates an in-memory ticket store with a background
// cleanup goroutine that evicts expired tickets every 30 seconds.
//
// BUG FIX: the janitor goroutine deliberately captures only the map, the
// shared mutex, and the done channel — never the *MemoryTicketStore itself.
// If it captured the store (as it previously did via s.mu/s.tickets), the
// store would stay reachable for the goroutine's whole lifetime and the
// runtime.SetFinalizer below could never fire, so the goroutine leaked
// whenever a caller forgot to call Close.
func NewMemoryTicketStore() *MemoryTicketStore {
	s := &MemoryTicketStore{
		tickets: make(map[string]memTicket),
		done:    make(chan struct{}),
		mu:      &sync.Mutex{},
	}

	tickets, mu, done := s.tickets, s.mu, s.done
	go func() {
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				mu.Lock()
				now := time.Now()
				for k, v := range tickets {
					if now.After(v.expires) {
						delete(tickets, k)
					}
				}
				mu.Unlock()
			case <-done:
				return
			}
		}
	}()

	// Backstop against a forgotten Close: stop the janitor when the store
	// is garbage-collected. This works only because the janitor above does
	// not reference s.
	runtime.SetFinalizer(s, (*MemoryTicketStore).Close)

	return s
}

// Close stops the background cleanup goroutine. Safe to call multiple times.
func (s *MemoryTicketStore) Close() {
	s.closeOnce.Do(func() {
		close(s.done)
	})
}

// Set stores a ticket with the given value and TTL.
func (s *MemoryTicketStore) Set(ticket, value string, ttl time.Duration) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tickets[ticket] = memTicket{value: value, expires: time.Now().Add(ttl)}
	return nil
}
+func (s *MemoryTicketStore) GetDel(ticket string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + t, ok := s.tickets[ticket] + if !ok { + return "", nil + } + delete(s.tickets, ticket) + if time.Now().After(t.expires) { + return "", nil + } + return t.value, nil +} + +// TicketAuth returns an OnConnect handler that validates a one-time ticket +// from the ticket query parameter. +func TicketAuth( + store TicketStore, + parseValue func(value string) (metadata map[string]string, topics []string, err error), +) func(fiber.Ctx, *Connection) error { + return func(c fiber.Ctx, conn *Connection) error { + ticket := c.Query("ticket") + if ticket == "" { + return errors.New("missing ticket parameter") + } + + value, err := store.GetDel(ticket) + if err != nil { + return fmt.Errorf("ticket validation error: %w", err) + } + if value == "" { + return errors.New("invalid or expired ticket") + } + + metadata, topics, err := parseValue(value) + if err != nil { + return fmt.Errorf("ticket parse error: %w", err) + } + + maps.Copy(conn.Metadata, metadata) + if len(topics) > 0 { + conn.Topics = topics + } + + return nil + } +} + +// IssueTicket creates a one-time ticket and stores it. Returns the +// ticket string that the client should pass as ?ticket=. +func IssueTicket(store TicketStore, value string, ttl time.Duration) (string, error) { + b := make([]byte, 24) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("failed to generate ticket: %w", err) + } + ticket := hex.EncodeToString(b) + if err := store.Set(ticket, value, ttl); err != nil { + return "", err + } + return ticket, nil +} diff --git a/middleware/sse/coalescer.go b/middleware/sse/coalescer.go new file mode 100644 index 00000000000..811c27686ef --- /dev/null +++ b/middleware/sse/coalescer.go @@ -0,0 +1,89 @@ +package sse + +import ( + "sync" + "time" +) + +// coalescer buffers P1 (batched) and P2 (coalesced) events per connection. +// The hub's flush ticker drains these buffers periodically. 
+type coalescer struct { + // coalesced holds P2 events keyed by CoalesceKey — only the latest per key survives. + coalesced map[string]MarshaledEvent + + // batched holds P1 events in insertion order — all are sent on flush. + batched []MarshaledEvent + + // coalescedOrder preserves first-seen order of coalesce keys for deterministic output. + coalescedOrder []string + + mu sync.Mutex + + // flushInterval is the target flush cadence (informational). + flushInterval time.Duration +} + +// newCoalescer creates a coalescer with the given flush interval hint. +func newCoalescer(flushInterval time.Duration) *coalescer { + return &coalescer{ + coalesced: make(map[string]MarshaledEvent), + batched: make([]MarshaledEvent, 0, 16), + flushInterval: flushInterval, + } +} + +// addBatched appends a P1 event to the batch buffer. +func (c *coalescer) addBatched(me MarshaledEvent) { //nolint:gocritic // hugeParam: value semantics match flush() return type + c.mu.Lock() + c.batched = append(c.batched, me) + c.mu.Unlock() +} + +// addCoalesced upserts a P2 event by its coalesce key. If the key already +// exists, the previous event is overwritten (last-writer-wins). +func (c *coalescer) addCoalesced(key string, me MarshaledEvent) { //nolint:gocritic // hugeParam: value semantics match flush() return type + c.mu.Lock() + if _, exists := c.coalesced[key]; !exists { + c.coalescedOrder = append(c.coalescedOrder, key) + } + c.coalesced[key] = me + c.mu.Unlock() +} + +// flush drains both buffers and returns the events to send. +func (c *coalescer) flush() []MarshaledEvent { + c.mu.Lock() + defer c.mu.Unlock() + + batchLen := len(c.batched) + coalLen := len(c.coalescedOrder) + + if batchLen == 0 && coalLen == 0 { + return nil + } + + result := make([]MarshaledEvent, 0, batchLen+coalLen) + + if batchLen > 0 { + result = append(result, c.batched...) 
+ c.batched = c.batched[:0] + } + + if coalLen > 0 { + for _, key := range c.coalescedOrder { + result = append(result, c.coalesced[key]) + } + c.coalesced = make(map[string]MarshaledEvent, coalLen) + c.coalescedOrder = c.coalescedOrder[:0] + } + + return result +} + +// pending returns the total number of buffered events. +func (c *coalescer) pending() int { + c.mu.Lock() + n := len(c.batched) + len(c.coalescedOrder) + c.mu.Unlock() + return n +} diff --git a/middleware/sse/config.go b/middleware/sse/config.go new file mode 100644 index 00000000000..582bfa6c189 --- /dev/null +++ b/middleware/sse/config.go @@ -0,0 +1,109 @@ +package sse + +import ( + "time" + + "github.com/gofiber/fiber/v3" +) + +// Config defines the configuration for the SSE middleware. +type Config struct { + // Next defines a function to skip this middleware when returned true. + // + // Optional. Default: nil + Next func(c fiber.Ctx) bool + + // OnConnect is called when a new client connects, before the SSE + // stream begins. Use it for authentication, topic selection, and + // connection limits. Set conn.Topics and conn.Metadata here. + // Return a non-nil error to reject the connection (sends 403). + // + // Optional. Default: nil + OnConnect func(c fiber.Ctx, conn *Connection) error + + // OnDisconnect is called after a client disconnects. + // + // Optional. Default: nil + OnDisconnect func(conn *Connection) + + // OnPause is called when a connection is paused (browser tab hidden). + // + // Optional. Default: nil + OnPause func(conn *Connection) + + // OnResume is called when a connection is resumed (browser tab visible). + // + // Optional. Default: nil + OnResume func(conn *Connection) + + // Replayer enables Last-Event-ID replay. If nil, replay is disabled. + // + // Optional. Default: nil + Replayer Replayer + + // FlushInterval is how often batched (P1) and coalesced (P2) events + // are flushed to clients. Instant (P0) events bypass this. + // + // Optional. 
Default: 2s + FlushInterval time.Duration + + // HeartbeatInterval is how often a comment is sent to idle connections + // to detect disconnects and prevent proxy timeouts. + // + // Optional. Default: 30s + HeartbeatInterval time.Duration + + // MaxLifetime is the maximum duration a single SSE connection can + // stay open. After this, the connection is closed gracefully. + // Set to -1 for unlimited. + // + // Optional. Default: 30m + MaxLifetime time.Duration + + // SendBufferSize is the per-connection channel buffer. If full, + // events are dropped and the client should reconnect. + // + // Optional. Default: 256 + SendBufferSize int + + // RetryMS is the reconnection interval hint sent to clients via the + // retry: directive on connect. + // + // Optional. Default: 3000 + RetryMS int +} + +// ConfigDefault is the default config. +var ConfigDefault = Config{ + FlushInterval: 2 * time.Second, + SendBufferSize: 256, + HeartbeatInterval: 30 * time.Second, + MaxLifetime: 30 * time.Minute, + RetryMS: 3000, +} + +func configDefault(config ...Config) Config { + if len(config) < 1 { + return ConfigDefault + } + + cfg := config[0] + + if cfg.FlushInterval <= 0 { + cfg.FlushInterval = ConfigDefault.FlushInterval + } + if cfg.SendBufferSize <= 0 { + cfg.SendBufferSize = ConfigDefault.SendBufferSize + } + if cfg.HeartbeatInterval <= 0 { + cfg.HeartbeatInterval = ConfigDefault.HeartbeatInterval + } + if cfg.MaxLifetime == 0 { + cfg.MaxLifetime = ConfigDefault.MaxLifetime + } + if cfg.RetryMS <= 0 { + cfg.RetryMS = ConfigDefault.RetryMS + } + + return cfg +} diff --git a/middleware/sse/connection.go b/middleware/sse/connection.go new file mode 100644 index 00000000000..e500507a584 --- /dev/null +++ b/middleware/sse/connection.go @@ -0,0 +1,132 @@ +package sse + +import ( + "bufio" + "sync" + "sync/atomic" + "time" +) + +// Connection represents a single SSE client connection managed by the hub. 
+type Connection struct { + CreatedAt time.Time + LastEventID atomic.Value + lastWrite atomic.Value + send chan MarshaledEvent + heartbeat chan struct{} + done chan struct{} + coalescer *coalescer + // Metadata holds connection metadata set during OnConnect. + // It is frozen (defensive-copied) after OnConnect returns -- do not + // mutate it from other goroutines after the connection is registered. + Metadata map[string]string + ID string + Topics []string + MessagesSent atomic.Int64 + MessagesDropped atomic.Int64 + once sync.Once + paused atomic.Bool +} + +// newConnection creates a Connection with the given buffer size. +func newConnection(id string, topics []string, bufferSize int, flushInterval time.Duration) *Connection { + c := &Connection{ + ID: id, + Topics: topics, + Metadata: make(map[string]string), + CreatedAt: time.Now(), + send: make(chan MarshaledEvent, bufferSize), + heartbeat: make(chan struct{}, 1), + done: make(chan struct{}), + } + c.lastWrite.Store(time.Now()) + c.LastEventID.Store("") + c.coalescer = newCoalescer(flushInterval) + return c +} + +// Close terminates the connection. Safe to call multiple times. +func (c *Connection) Close() { + c.once.Do(func() { + close(c.done) + }) +} + +// IsClosed returns true if the connection has been terminated. +func (c *Connection) IsClosed() bool { + select { + case <-c.done: + return true + default: + return false + } +} + +// trySend attempts to deliver an event to the connection's send channel. +// Returns false if the buffer is full (backpressure). +func (c *Connection) trySend(me MarshaledEvent) bool { //nolint:gocritic // hugeParam: value semantics for channel send + select { + case c.send <- me: + return true + default: + c.MessagesDropped.Add(1) + return false + } +} + +// sendHeartbeat sends a heartbeat signal to the connection. +// Non-blocking — if a heartbeat is already pending it is silently dropped. 
+func (c *Connection) sendHeartbeat() { + select { + case c.heartbeat <- struct{}{}: + default: + } +} + +// writeLoop runs inside Fiber's SendStreamWriter. It reads from the send +// and heartbeat channels, writing SSE-formatted events to the bufio.Writer. +func (c *Connection) writeLoop(w *bufio.Writer) { + for { + select { + case <-c.done: + return + case <-c.heartbeat: + if err := writeComment(w, "heartbeat"); err != nil { + c.Close() + return + } + if err := w.Flush(); err != nil { + c.Close() + return + } + case me, ok := <-c.send: + if !ok { + return + } + if _, err := me.WriteTo(w); err != nil { + c.Close() + return + } + if err := w.Flush(); err != nil { + c.Close() + return + } + c.MessagesSent.Add(1) + c.lastWrite.Store(time.Now()) + if me.ID != "" { + c.LastEventID.Store(me.ID) + } + } + } +} + +// connMatchesGroup returns true if ALL key-value pairs in the group +// match the connection's metadata. +func connMatchesGroup(conn *Connection, group map[string]string) bool { + for k, v := range group { + if conn.Metadata[k] != v { + return false + } + } + return true +} diff --git a/middleware/sse/domain_event.go b/middleware/sse/domain_event.go new file mode 100644 index 00000000000..2c2b2db01e3 --- /dev/null +++ b/middleware/sse/domain_event.go @@ -0,0 +1,135 @@ +package sse + +import ( + "maps" +) + +// DomainEvent publishes a domain event to the hub. This is the primary +// method for triggering real-time UI updates from your backend code. 
+// +// Parameters: +// - resource: what changed ("orders", "products", "customers") +// - action: what happened ("created", "updated", "deleted", "refresh") +// - resourceID: specific item ID (empty for collection-level events) +// - tenantID: tenant scope (empty for global events) +// - hint: optional small payload (nil if not needed) +func (h *Hub) DomainEvent(resource, action, resourceID, tenantID string, hint map[string]any) { + evt := InvalidationEvent{ + Resource: resource, + Action: action, + ResourceID: resourceID, + Hint: hint, + } + + event := Event{ + Type: "invalidate", + Topics: []string{resource}, + Data: evt, + Priority: PriorityInstant, + } + + if tenantID != "" { + event.Group = map[string]string{"tenant_id": tenantID} + } + + h.Publish(event) +} + +// Progress publishes a progress update for a long-running operation. +// Uses PriorityCoalesced — if progress goes 5%→8% in one flush +// window, only 8% is sent to the client. +func (h *Hub) Progress(topic, resourceID, tenantID string, current, total int, hint ...map[string]any) { + pct := 0 + if total > 0 { + pct = (current * 100) / total + } + + data := map[string]any{ + "resource_id": resourceID, + "current": current, + "total": total, + "pct": pct, + } + if len(hint) > 0 && hint[0] != nil { + maps.Copy(data, hint[0]) + } + + event := Event{ + Type: "progress", + Topics: []string{topic}, + Data: data, + Priority: PriorityCoalesced, + CoalesceKey: "progress:" + topic + ":" + resourceID, + } + + if tenantID != "" { + event.Group = map[string]string{"tenant_id": tenantID} + } + + h.Publish(event) +} + +// Complete publishes a completion signal for a long-running operation. +// Uses PriorityInstant — completion always delivers immediately. 
+func (h *Hub) Complete(topic, resourceID, tenantID string, success bool, hint map[string]any) { //nolint:revive // flag-parameter: public API toggle + action := "completed" + if !success { + action = "failed" + } + + data := map[string]any{ + "resource_id": resourceID, + "status": action, + } + maps.Copy(data, hint) + + event := Event{ + Type: "complete", + Topics: []string{topic}, + Data: data, + Priority: PriorityInstant, + } + + if tenantID != "" { + event.Group = map[string]string{"tenant_id": tenantID} + } + + h.Publish(event) +} + +// DomainEventSpec describes a single domain event within a batch. +type DomainEventSpec struct { + Hint map[string]any `json:"hint,omitempty"` + Resource string `json:"resource"` + Action string `json:"action"` + ResourceID string `json:"resource_id,omitempty"` +} + +// BatchDomainEvents publishes multiple domain events as a single SSE frame. +// The event is delivered to any connection subscribed to ANY of the resources +// in the batch. This is by design — batches target clients subscribed to +// multiple topics (e.g., a dashboard). Clients should filter the specs array +// locally by resource if they only care about a subset. 
+func (h *Hub) BatchDomainEvents(tenantID string, specs []DomainEventSpec) { + if len(specs) == 0 { + return + } + topicSet := make(map[string]struct{}) + for _, s := range specs { + topicSet[s.Resource] = struct{}{} + } + topics := make([]string, 0, len(topicSet)) + for t := range topicSet { + topics = append(topics, t) + } + batchEvt := Event{ + Type: "batch", + Topics: topics, + Data: specs, + Priority: PriorityInstant, + } + if tenantID != "" { + batchEvt.Group = map[string]string{"tenant_id": tenantID} + } + h.Publish(batchEvt) +} diff --git a/middleware/sse/event.go b/middleware/sse/event.go new file mode 100644 index 00000000000..c26cbd71812 --- /dev/null +++ b/middleware/sse/event.go @@ -0,0 +1,176 @@ +package sse + +import ( + "encoding/json" + "fmt" + "io" + "strings" + "sync/atomic" + "time" +) + +// Priority controls how an event is delivered to clients. +type Priority int + +const ( + // PriorityInstant bypasses all buffering — the event is written to the + // client connection immediately. Use for errors, auth revocations, + // force-refresh commands, and chat messages. + PriorityInstant Priority = 0 + + // PriorityBatched collects events in a time window (FlushInterval) and + // sends them all at once. Use for status changes, media updates. + PriorityBatched Priority = 1 + + // PriorityCoalesced uses last-writer-wins per CoalesceKey. Multiple + // events with the same key within a flush window are merged — only the + // latest is sent. Use for progress bars, live counters, typing indicators. + PriorityCoalesced Priority = 2 +) + +// Event represents a single SSE event to be published through the hub. +type Event struct { + CreatedAt time.Time + Data any + Group map[string]string + Type string + ID string + CoalesceKey string + Topics []string + TTL time.Duration + Priority Priority +} + +// globalEventID is an auto-incrementing counter for event IDs. 
// globalEventID is a process-wide auto-incrementing counter used to mint
// unique event IDs when the publisher does not supply one.
var globalEventID atomic.Uint64

// nextEventID returns a monotonically increasing event ID string of the
// form "evt_<n>".
func nextEventID() string {
	return fmt.Sprintf("evt_%d", globalEventID.Add(1))
}

// MarshaledEvent is the wire-ready representation of an SSE event.
// External Replayer implementations receive and return this type.
type MarshaledEvent struct {
	// CreatedAt is the timestamp of the source Event (zero if unset).
	CreatedAt time.Time
	ID        string
	Type      string
	Data      string
	// TTL is the maximum age for this event. Zero means no expiry.
	TTL   time.Duration
	Retry int // -1 means omit
}

// sseFieldSanitizer strips CR/LF sequences from SSE control fields.
// PERF FIX: built once at package scope — strings.NewReplacer constructs an
// internal matching machine, so building it on every sanitizeSSEField call
// (as before) wasted work on the per-event hot path. A Replacer is safe for
// concurrent use by multiple goroutines.
var sseFieldSanitizer = strings.NewReplacer("\r\n", "", "\r", "", "\n", "")

// sanitizeSSEField strips carriage returns and newlines from SSE control
// fields (id, event) to prevent SSE injection attacks. An attacker-controlled
// value containing \r or \n could break SSE framing and inject fake events.
func sanitizeSSEField(s string) string {
	return sseFieldSanitizer.Replace(s)
}
+func marshalEvent(e *Event) MarshaledEvent { + me := MarshaledEvent{ + ID: sanitizeSSEField(e.ID), + Type: sanitizeSSEField(e.Type), + CreatedAt: e.CreatedAt, + TTL: e.TTL, + Retry: -1, + } + + if me.ID == "" { + me.ID = nextEventID() + } + + switch v := e.Data.(type) { + case nil: + me.Data = "" + case string: + me.Data = v + case []byte: + me.Data = string(v) + case json.Marshaler: + b, err := v.MarshalJSON() + if err != nil { + errJSON, _ := json.Marshal(err.Error()) //nolint:errcheck,errchkjson // encoding a string never fails + me.Data = fmt.Sprintf(`{"error":%s}`, string(errJSON)) + } else { + me.Data = string(b) + } + default: + b, err := json.Marshal(v) + if err != nil { + errJSON, _ := json.Marshal(err.Error()) //nolint:errcheck,errchkjson // encoding a string never fails + me.Data = fmt.Sprintf(`{"error":%s}`, string(errJSON)) + } else { + me.Data = string(b) + } + } + + return me +} + +// WriteTo writes the SSE-formatted event to w following the Server-Sent +// Events specification. +func (me *MarshaledEvent) WriteTo(w io.Writer) (int64, error) { + var total int64 + + if me.ID != "" { + n, err := fmt.Fprintf(w, "id: %s\n", me.ID) + total += int64(n) + if err != nil { + return total, fmt.Errorf("sse: write id: %w", err) + } + } + + if me.Type != "" { + n, err := fmt.Fprintf(w, "event: %s\n", me.Type) + total += int64(n) + if err != nil { + return total, fmt.Errorf("sse: write event: %w", err) + } + } + + if me.Retry >= 0 { + n, err := fmt.Fprintf(w, "retry: %d\n", me.Retry) + total += int64(n) + if err != nil { + return total, fmt.Errorf("sse: write retry: %w", err) + } + } + + // strings.SplitSeq("", "\n") yields "", correctly writing "data: \n" for empty data. 
+ for line := range strings.SplitSeq(me.Data, "\n") { + n, err := fmt.Fprintf(w, "data: %s\n", line) + total += int64(n) + if err != nil { + return total, fmt.Errorf("sse: write data: %w", err) + } + } + + n, err := fmt.Fprint(w, "\n") + total += int64(n) + if err != nil { + return total, fmt.Errorf("sse: write terminator: %w", err) + } + return total, nil +} + +// writeComment writes an SSE comment line. +func writeComment(w io.Writer, text string) error { + _, err := fmt.Fprintf(w, ": %s\n\n", text) + if err != nil { + return fmt.Errorf("sse: write comment: %w", err) + } + return nil +} + +// writeRetry writes the retry directive. +func writeRetry(w io.Writer, ms int) error { + _, err := fmt.Fprintf(w, "retry: %d\n\n", ms) + if err != nil { + return fmt.Errorf("sse: write retry: %w", err) + } + return nil +} diff --git a/middleware/sse/example_test.go b/middleware/sse/example_test.go new file mode 100644 index 00000000000..f046e2efd1d --- /dev/null +++ b/middleware/sse/example_test.go @@ -0,0 +1,101 @@ +package sse + +import ( + "context" + "fmt" + "time" + + "github.com/gofiber/fiber/v3" +) + +func Example() { + app := fiber.New() + + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, conn *Connection) error { + conn.Topics = []string{"notifications"} + conn.Metadata["user_id"] = "example" + return nil + }, + }) + + app.Get("/events", handler) + + // Publish from any handler or worker + hub.Publish(Event{ + Type: "update", + Data: map[string]string{"message": "hello"}, + Topics: []string{"notifications"}, + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if err := hub.Shutdown(ctx); err != nil { + panic(err) + } + + fmt.Println("Hub created and shut down successfully") //nolint:errcheck // example test output + // Output: Hub created and shut down successfully +} + +func Example_invalidation() { + _, hub := NewWithHub() + + // Replace polling: instead of clients polling every 30s, + // push an 
invalidation signal when data changes. + hub.Invalidate("orders", "ord_123", "created") + + // Multi-tenant + hub.InvalidateForTenant("t_1", "orders", "ord_456", "updated") + + // With hints (small extra data) + hub.InvalidateWithHint("orders", "ord_789", "created", map[string]any{ + "total": 149.99, + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if err := hub.Shutdown(ctx); err != nil { + panic(err) + } + + fmt.Println("Invalidation events published") //nolint:errcheck // example test output + // Output: Invalidation events published +} + +func Example_progress() { + _, hub := NewWithHub() + + // Coalesced: if progress goes 1%→2%→3%→4% in one flush window, + // only 4% is sent to the client. + for i := 1; i <= 100; i++ { + hub.Progress("import", "imp_1", "t_1", i, 100) + } + hub.Complete("import", "imp_1", "t_1", true, nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if err := hub.Shutdown(ctx); err != nil { + panic(err) + } + + fmt.Println("Progress tracking complete") //nolint:errcheck // example test output + // Output: Progress tracking complete +} + +func Example_ticketAuth() { + store := NewMemoryTicketStore() + defer store.Close() + + // Issue a ticket (typically in a POST handler after JWT validation) + ticket, err := IssueTicket(store, `{"tenant":"t_1","topics":"orders,products"}`, 30*time.Second) + if err != nil { + panic(err) + } + + fmt.Println("Ticket issued, length:", len(ticket)) //nolint:errcheck // example test output + // Output: Ticket issued, length: 48 +} diff --git a/middleware/sse/fanout.go b/middleware/sse/fanout.go new file mode 100644 index 00000000000..2a9a361af62 --- /dev/null +++ b/middleware/sse/fanout.go @@ -0,0 +1,144 @@ +package sse + +import ( + "context" + "time" + + "github.com/gofiber/fiber/v3/log" +) + +// PubSubSubscriber abstracts a pub/sub system (Redis, NATS, etc.) 
for +// auto-fan-out from an external message broker into the SSE hub. +type PubSubSubscriber interface { + // Subscribe listens on the given channel and sends received messages + // to the provided callback. It blocks until ctx is canceled. + Subscribe(ctx context.Context, channel string, onMessage func(payload string)) error +} + +// FanOutConfig configures auto-fan-out from an external pub/sub to the hub. +type FanOutConfig struct { + // Subscriber is the pub/sub implementation (Redis, NATS, etc.). + Subscriber PubSubSubscriber + + // Transform optionally transforms the raw pub/sub message before + // publishing to the hub. Return nil to skip the message. + Transform func(payload string) *Event + + // Channel is the pub/sub channel to subscribe to. + Channel string + + // Topic is the SSE topic to publish events to. If empty, Channel is used. + Topic string + + // EventType is the SSE event type. Required. + EventType string + + // CoalesceKey for PriorityCoalesced events. + CoalesceKey string + + // TTL for events. Zero means no expiration. + TTL time.Duration + + // Priority for delivered events. Note: PriorityInstant is 0 (the zero value), + // so it is always the default if not set explicitly. + Priority Priority +} + +// FanOut starts a goroutine that subscribes to an external pub/sub channel +// and automatically publishes received messages to the SSE hub. +// Returns a cancel function to stop the fan-out. 
+func (h *Hub) FanOut(cfg FanOutConfig) context.CancelFunc { //nolint:gocritic // hugeParam: public API, value semantics preferred + if cfg.Subscriber == nil { + panic("sse: FanOut requires a non-nil Subscriber") + } + + ctx, cancel := context.WithCancel(context.Background()) + + topic := cfg.Topic + if topic == "" { + topic = cfg.Channel + } + + go func() { + for { + select { + case <-ctx.Done(): + return + default: + } + + err := cfg.Subscriber.Subscribe(ctx, cfg.Channel, func(payload string) { + event := h.buildFanOutEvent(&cfg, topic, payload) + if event != nil { + h.Publish(*event) + } + }) + + if err != nil && ctx.Err() == nil { + h.logFanOutError(cfg.Channel, err) + select { + case <-time.After(3 * time.Second): + case <-ctx.Done(): + return + } + } + } + }() + + return cancel +} + +// buildFanOutEvent creates an Event from a raw pub/sub payload. +// When Transform is set, the transform function controls all event fields; +// only missing Topics and Type are filled in from the config defaults. +// When Transform is not set, the event is built entirely from config defaults. +func (*Hub) buildFanOutEvent(cfg *FanOutConfig, topic, payload string) *Event { + if cfg.Transform != nil { + transformed := cfg.Transform(payload) + if transformed == nil { + return nil + } + event := *transformed + // Only fill in missing Topics and Type — Transform controls everything else. + if len(event.Topics) == 0 { + event.Topics = []string{topic} + } + if event.Type == "" { + event.Type = cfg.EventType + } + return &event + } + + // Non-transform: build entirely from config defaults. + event := Event{ + Type: cfg.EventType, + Data: payload, + Topics: []string{topic}, + Priority: cfg.Priority, + CoalesceKey: cfg.CoalesceKey, + TTL: cfg.TTL, + } + + return &event +} + +// logFanOutError logs a fan-out subscriber error. 
+func (*Hub) logFanOutError(channel string, err error) { + log.Warnf("sse: fan-out subscriber error, retrying channel=%s error=%v", channel, err) +} + +// FanOutMulti starts multiple fan-out goroutines at once. +// Returns a single cancel function that stops all of them. +func (h *Hub) FanOutMulti(configs ...FanOutConfig) context.CancelFunc { + ctx, cancel := context.WithCancel(context.Background()) + + for _, cfg := range configs { + innerCancel := h.FanOut(cfg) + go func() { + <-ctx.Done() + innerCancel() + }() + } + + return cancel +} diff --git a/middleware/sse/invalidation.go b/middleware/sse/invalidation.go new file mode 100644 index 00000000000..5800919ca4e --- /dev/null +++ b/middleware/sse/invalidation.go @@ -0,0 +1,118 @@ +package sse + +import ( + "time" +) + +// InvalidationEvent is a lightweight signal telling the client to refetch +// a specific resource. +type InvalidationEvent struct { + // Hint is optional extra data for the client. + Hint map[string]any `json:"hint,omitempty"` + + // Resource is what changed (e.g., "orders", "products"). + Resource string `json:"resource"` + + // Action is what happened (e.g., "created", "updated", "deleted"). + Action string `json:"action"` + + // ResourceID is the specific item that changed (optional). + ResourceID string `json:"resource_id,omitempty"` +} + +// Invalidate publishes a cache invalidation signal to all connections +// subscribed to the given topic. +func (h *Hub) Invalidate(topic, resourceID, action string) { + h.Publish(Event{ + Type: "invalidate", + Topics: []string{topic}, + Data: InvalidationEvent{ + Resource: topic, + Action: action, + ResourceID: resourceID, + }, + Priority: PriorityInstant, + }) +} + +// InvalidateForTenant publishes a tenant-scoped cache invalidation signal. 
+func (h *Hub) InvalidateForTenant(tenantID, topic, resourceID, action string) { + h.Publish(Event{ + Type: "invalidate", + Topics: []string{topic}, + Group: map[string]string{"tenant_id": tenantID}, + Data: InvalidationEvent{ + Resource: topic, + Action: action, + ResourceID: resourceID, + }, + Priority: PriorityInstant, + }) +} + +// InvalidateWithHint publishes an invalidation signal with extra data hints. +func (h *Hub) InvalidateWithHint(topic, resourceID, action string, hint map[string]any) { + h.Publish(Event{ + Type: "invalidate", + Topics: []string{topic}, + Data: InvalidationEvent{ + Resource: topic, + Action: action, + ResourceID: resourceID, + Hint: hint, + }, + Priority: PriorityInstant, + }) +} + +// InvalidateForTenantWithHint publishes a tenant-scoped invalidation signal +// with extra data hints. +func (h *Hub) InvalidateForTenantWithHint(tenantID, topic, resourceID, action string, hint map[string]any) { + h.Publish(Event{ + Type: "invalidate", + Topics: []string{topic}, + Group: map[string]string{"tenant_id": tenantID}, + Data: InvalidationEvent{ + Resource: topic, + Action: action, + ResourceID: resourceID, + Hint: hint, + }, + Priority: PriorityInstant, + }) +} + +// Signal publishes a simple refresh signal. +func (h *Hub) Signal(topic string) { + h.Publish(Event{ + Type: "signal", + Topics: []string{topic}, + Data: map[string]string{"signal": "refresh"}, + Priority: PriorityCoalesced, + CoalesceKey: "signal:" + topic, + }) +} + +// SignalForTenant publishes a tenant-scoped refresh signal. +func (h *Hub) SignalForTenant(tenantID, topic string) { + h.Publish(Event{ + Type: "signal", + Topics: []string{topic}, + Group: map[string]string{"tenant_id": tenantID}, + Data: map[string]string{"signal": "refresh"}, + Priority: PriorityCoalesced, + CoalesceKey: "signal:" + topic + ":" + tenantID, + }) +} + +// SignalThrottled publishes a signal with a TTL. 
+func (h *Hub) SignalThrottled(topic string, ttl time.Duration) { + h.Publish(Event{ + Type: "signal", + Topics: []string{topic}, + Data: map[string]string{"signal": "refresh"}, + Priority: PriorityCoalesced, + CoalesceKey: "signal:" + topic, + TTL: ttl, + }) +} diff --git a/middleware/sse/metrics.go b/middleware/sse/metrics.go new file mode 100644 index 00000000000..64032366c25 --- /dev/null +++ b/middleware/sse/metrics.go @@ -0,0 +1,195 @@ +package sse + +import ( + "math" + "strconv" + "strings" + "time" + + "github.com/gofiber/fiber/v3" +) + +// MetricsSnapshot is a detailed point-in-time view of the hub for monitoring. +type MetricsSnapshot struct { + ConnectionsByTopic map[string]int `json:"connections_by_topic"` + EventsByType map[string]int64 `json:"events_by_type"` + Timestamp string `json:"timestamp"` + Connections []ConnectionInfo `json:"connections,omitempty"` + EventsPublished int64 `json:"events_published"` + EventsDropped int64 `json:"events_dropped"` + AvgBufferSaturation float64 `json:"avg_buffer_saturation"` + MaxBufferSaturation float64 `json:"max_buffer_saturation"` + ActiveConnections int `json:"active_connections"` + PausedConnections int `json:"paused_connections"` + TotalPendingEvents int `json:"total_pending_events"` +} + +// ConnectionInfo is per-connection detail for the metrics snapshot. +type ConnectionInfo struct { + Metadata map[string]string `json:"metadata"` + ID string `json:"id"` + CreatedAt string `json:"created_at"` + Uptime string `json:"uptime"` + LastEventID string `json:"last_event_id"` + Topics []string `json:"topics"` + MessagesSent int64 `json:"messages_sent"` + MessagesDropped int64 `json:"messages_dropped"` + BufferUsage int `json:"buffer_usage"` + BufferCapacity int `json:"buffer_capacity"` + Paused bool `json:"paused"` +} + +// Metrics returns a detailed snapshot of the hub for monitoring dashboards. 
+func (h *Hub) Metrics(includeConnections bool) MetricsSnapshot { //nolint:revive // flag-parameter: public API toggle + h.mu.RLock() + defer h.mu.RUnlock() + + now := time.Now() + snap := MetricsSnapshot{ + Timestamp: now.Format(time.RFC3339), + ActiveConnections: len(h.connections), + ConnectionsByTopic: make(map[string]int, len(h.topicIndex)), + EventsPublished: h.metrics.eventsPublished.Load(), + EventsDropped: h.metrics.eventsDropped.Load(), + } + + for topic, conns := range h.topicIndex { + snap.ConnectionsByTopic[topic] = len(conns) + } + + snap.EventsByType = h.metrics.snapshotEventsByType() + + var totalSat float64 + var maxSat float64 + for _, conn := range h.connections { + if conn.paused.Load() { + snap.PausedConnections++ + } + + pending := conn.coalescer.pending() + snap.TotalPendingEvents += pending + + bufCap := cap(conn.send) + sat := float64(0) + if bufCap > 0 { + sat = float64(len(conn.send)) / float64(bufCap) + } + totalSat += sat + if sat > maxSat { + maxSat = sat + } + + if includeConnections { + lastID, _ := conn.LastEventID.Load().(string) //nolint:errcheck // type assertion on atomic.Value + snap.Connections = append(snap.Connections, ConnectionInfo{ + ID: conn.ID, + Topics: conn.Topics, + Metadata: conn.Metadata, + CreatedAt: conn.CreatedAt.Format(time.RFC3339), + Uptime: now.Sub(conn.CreatedAt).Round(time.Second).String(), + MessagesSent: conn.MessagesSent.Load(), + MessagesDropped: conn.MessagesDropped.Load(), + LastEventID: lastID, + BufferUsage: len(conn.send), + BufferCapacity: cap(conn.send), + Paused: conn.paused.Load(), + }) + } + } + + if len(h.connections) > 0 { + snap.AvgBufferSaturation = totalSat / float64(len(h.connections)) + } + snap.MaxBufferSaturation = maxSat + + return snap +} + +// MetricsHandler returns a Fiber handler that serves the metrics snapshot +// as JSON. 
Mount it on an admin route: +// +// app.Get("/admin/sse/metrics", hub.MetricsHandler()) +func (h *Hub) MetricsHandler() fiber.Handler { + return func(c fiber.Ctx) error { + includeConns := c.Query("connections") == "true" + snap := h.Metrics(includeConns) + return c.JSON(snap) + } +} + +// PrometheusHandler returns a Fiber handler that serves Prometheus-formatted +// metrics. Mount on your /metrics endpoint: +// +// app.Get("/metrics/sse", hub.PrometheusHandler()) +func (h *Hub) PrometheusHandler() fiber.Handler { + return func(c fiber.Ctx) error { + snap := h.Metrics(false) + c.Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8") + + lines := []byte("") + lines = appendProm(lines, "sse_connections_active", "", float64(snap.ActiveConnections)) + lines = appendProm(lines, "sse_connections_paused", "", float64(snap.PausedConnections)) + lines = appendProm(lines, "sse_events_published_total", "", float64(snap.EventsPublished)) + lines = appendProm(lines, "sse_events_dropped_total", "", float64(snap.EventsDropped)) + lines = appendProm(lines, "sse_pending_events", "", float64(snap.TotalPendingEvents)) + lines = appendProm(lines, "sse_buffer_saturation_avg", "", snap.AvgBufferSaturation) + lines = appendProm(lines, "sse_buffer_saturation_max", "", snap.MaxBufferSaturation) + + for topic, count := range snap.ConnectionsByTopic { + lines = appendProm(lines, "sse_connections_by_topic", `topic="`+escapePromLabelValue(topic)+`"`, float64(count)) + } + + for eventType, count := range snap.EventsByType { + lines = appendProm(lines, "sse_events_by_type_total", `type="`+escapePromLabelValue(eventType)+`"`, float64(count)) + } + + return c.Send(lines) + } +} + +func appendProm(buf []byte, name, labels string, value float64) []byte { + if labels != "" { + return append(buf, []byte(name+"{"+labels+"} "+formatFloat(value)+"\n")...) + } + return append(buf, []byte(name+" "+formatFloat(value)+"\n")...) 
+} + +// escapePromLabelValue escapes backslashes, double quotes, and newlines in +// Prometheus label values per the exposition format spec. +func escapePromLabelValue(s string) string { + var needsEscape bool + for _, c := range s { + if c == '\\' || c == '"' || c == '\n' { + needsEscape = true + break + } + } + if !needsEscape { + return s + } + var b strings.Builder + b.Grow(len(s) + 4) + for _, c := range s { + switch c { + case '\\': + b.WriteString(`\\`) //nolint:errcheck // strings.Builder.WriteString never fails + case '"': + b.WriteString(`\"`) //nolint:errcheck // strings.Builder.WriteString never fails + case '\n': + b.WriteString(`\n`) //nolint:errcheck // strings.Builder.WriteString never fails + default: + b.WriteRune(c) //nolint:errcheck // strings.Builder.WriteRune never fails + } + } + return b.String() +} + +func formatFloat(f float64) string { + if math.IsNaN(f) || math.IsInf(f, 0) { + return "0" + } + if f == float64(int64(f)) { + return strconv.FormatInt(int64(f), 10) + } + return strconv.FormatFloat(f, 'f', 6, 64) +} diff --git a/middleware/sse/replayer.go b/middleware/sse/replayer.go new file mode 100644 index 00000000000..7d1e79f9415 --- /dev/null +++ b/middleware/sse/replayer.go @@ -0,0 +1,148 @@ +package sse + +import ( + "sync" + "time" +) + +// Replayer stores events for replay when a client reconnects with Last-Event-ID. +// Implement this interface to use Redis Streams, a database, or any durable store. +type Replayer interface { + // Store persists an event for potential future replay. + Store(event MarshaledEvent, topics []string) error + + // Replay returns all events after lastEventID that match any of the given topics. + Replay(lastEventID string, topics []string) ([]MarshaledEvent, error) +} + +// replayEntry pairs an event with its topic set for filtering. 
+type replayEntry struct { + timestamp time.Time + topics map[string]struct{} + event MarshaledEvent +} + +// MemoryReplayer is an in-memory Replayer backed by a fixed-size circular buffer. +// Events older than TTL or exceeding MaxEvents are evicted. Once the buffer is +// full, new events overwrite the oldest entry with zero allocations. +// +// For production deployments with high event throughput, use a persistent +// replayer backed by Redis Streams or a database. +type MemoryReplayer struct { + entries []replayEntry + mu sync.RWMutex + ttl time.Duration + head int // write position (wraps around) + count int // number of valid entries + maxEvents int +} + +// MemoryReplayerConfig configures the in-memory replayer. +type MemoryReplayerConfig struct { + // MaxEvents is the maximum number of events to retain (default: 1000). + MaxEvents int + + // TTL is how long events are kept before eviction (default: 5m). + TTL time.Duration +} + +// NewMemoryReplayer creates an in-memory replayer. +func NewMemoryReplayer(cfg ...MemoryReplayerConfig) *MemoryReplayer { + c := MemoryReplayerConfig{ + MaxEvents: 1000, + TTL: 5 * time.Minute, + } + if len(cfg) > 0 { + if cfg[0].MaxEvents > 0 { + c.MaxEvents = cfg[0].MaxEvents + } + if cfg[0].TTL > 0 { + c.TTL = cfg[0].TTL + } + } + return &MemoryReplayer{ + entries: make([]replayEntry, c.MaxEvents), + maxEvents: c.MaxEvents, + ttl: c.TTL, + } +} + +// Store adds an event to the replay buffer. Once full, overwrites the +// oldest entry (O(1), zero allocations). 
+func (r *MemoryReplayer) Store(event MarshaledEvent, topics []string) error { //nolint:gocritic // hugeParam: matches Replayer interface, value semantics + topicSet := make(map[string]struct{}, len(topics)) + for _, t := range topics { + topicSet[t] = struct{}{} + } + + r.mu.Lock() + defer r.mu.Unlock() + + r.entries[r.head] = replayEntry{ + event: event, + topics: topicSet, + timestamp: time.Now(), + } + r.head = (r.head + 1) % r.maxEvents + if r.count < r.maxEvents { + r.count++ + } + + return nil +} + +// Replay returns events after lastEventID matching the given topics. +func (r *MemoryReplayer) Replay(lastEventID string, topics []string) ([]MarshaledEvent, error) { + if lastEventID == "" { + return nil, nil + } + + r.mu.RLock() + defer r.mu.RUnlock() + + cutoff := time.Now().Add(-r.ttl) + + // Walk the ring buffer in chronological order to find lastEventID. + start := (r.head - r.count + r.maxEvents) % r.maxEvents + foundIdx := -1 + for i := range r.count { + idx := (start + i) % r.maxEvents + if r.entries[idx].event.ID == lastEventID { + foundIdx = i + 1 // start from the NEXT entry + break + } + } + + if foundIdx < 0 { + return nil, nil + } + + var result []MarshaledEvent + for i := foundIdx; i < r.count; i++ { + idx := (start + i) % r.maxEvents + entry := r.entries[idx] + + if entry.timestamp.Before(cutoff) { + continue + } + + if matchesAnyTopicWithWildcards(topics, entry.topics) { + result = append(result, entry.event) + } + } + + return result, nil +} + +// matchesAnyTopicWithWildcards returns true if any subscription pattern +// matches any of the stored event topics. 
+func matchesAnyTopicWithWildcards(subscriptionPatterns []string, eventTopics map[string]struct{}) bool { + for _, pattern := range subscriptionPatterns { + for topic := range eventTopics { + if topicMatch(pattern, topic) { + return true + } + } + } + return false +} diff --git a/middleware/sse/sse.go b/middleware/sse/sse.go new file mode 100644 index 00000000000..37abcf6a916 --- /dev/null +++ b/middleware/sse/sse.go @@ -0,0 +1,645 @@ +// Package sse provides Server-Sent Events middleware for Fiber. +// +// It is the only SSE implementation built natively for Fiber's +// fasthttp architecture — no net/http adapters, no broken disconnect +// detection. +// +// Features: event coalescing (last-writer-wins), three priority lanes +// (instant/batched/coalesced), NATS-style topic wildcards, adaptive +// per-connection throttling, connection groups (publish by metadata), +// built-in JWT and ticket auth helpers, Prometheus metrics, graceful +// Kubernetes-style drain, auto fan-out from Redis/NATS, and pluggable +// Last-Event-ID replay. +// +// Quick start: +// +// handler, hub := sse.NewWithHub(sse.Config{ +// OnConnect: func(c fiber.Ctx, conn *sse.Connection) error { +// conn.Topics = []string{"notifications"} +// return nil +// }, +// }) +// app.Get("/events", handler) +// hub.Publish(sse.Event{Type: "ping", Data: "hello", Topics: []string{"notifications"}}) +package sse + +import ( + "bufio" + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "maps" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/log" +) + +// Hub is the central SSE event broker. It manages client connections, +// event routing, coalescing, and delivery. All methods are goroutine-safe. 
+type Hub struct { + throttler *adaptiveThrottler + connections map[string]*Connection + topicIndex map[string]map[string]struct{} + wildcardConns map[string]struct{} + register chan *Connection + unregister chan *Connection + events chan Event + shutdown chan struct{} + stopped chan struct{} + cfg Config + metrics hubMetrics + mu sync.RWMutex + shutdownOnce sync.Once + draining atomic.Bool +} + +// New creates a new SSE middleware handler. Use this when you don't need +// direct access to the Hub (e.g., simple streaming without Publish). +// +// For most use cases, prefer [NewWithHub] instead. +func New(config ...Config) fiber.Handler { + handler, _ := NewWithHub(config...) + return handler +} + +// NewWithHub creates a new SSE middleware handler and returns it along +// with the Hub for publishing events. This is the primary entry point. +// +// handler, hub := sse.NewWithHub(sse.Config{ +// OnConnect: func(c fiber.Ctx, conn *sse.Connection) error { +// conn.Topics = []string{"notifications", "live"} +// conn.Metadata["tenant_id"] = c.Locals("tenant_id").(string) +// return nil +// }, +// }) +// app.Get("/events", handler) +// +// // From any handler or worker: +// hub.Publish(sse.Event{Type: "update", Data: "hello", Topics: []string{"live"}}) +func NewWithHub(config ...Config) (fiber.Handler, *Hub) { + cfg := configDefault(config...) 
+ + hub := &Hub{ + cfg: cfg, + register: make(chan *Connection, 64), + unregister: make(chan *Connection, 64), + events: make(chan Event, 1024), + shutdown: make(chan struct{}), + connections: make(map[string]*Connection), + topicIndex: make(map[string]map[string]struct{}), + wildcardConns: make(map[string]struct{}), + throttler: newAdaptiveThrottler(cfg.FlushInterval), + metrics: hubMetrics{eventsByType: make(map[string]*atomic.Int64)}, + stopped: make(chan struct{}), + } + + go hub.run() + + handler := func(c fiber.Ctx) error { + // Skip middleware if Next returns true + if cfg.Next != nil && cfg.Next(c) { + return c.Next() + } + + // Reject during graceful drain + if hub.draining.Load() { + c.Set("Retry-After", "5") + return c.Status(fiber.StatusServiceUnavailable).SendString("server draining, please reconnect") + } + + conn := newConnection( + generateID(), + nil, + cfg.SendBufferSize, + cfg.FlushInterval, + ) + + // Let the application authenticate and configure the connection + if cfg.OnConnect != nil { + if err := cfg.OnConnect(c, conn); err != nil { + return c.Status(fiber.StatusForbidden).SendString(err.Error()) + } + } + + // Freeze metadata — defensive copy to prevent concurrent mutation + // after the connection is registered with the hub. + frozen := make(map[string]string, len(conn.Metadata)) + maps.Copy(frozen, conn.Metadata) + conn.Metadata = frozen + + if len(conn.Topics) == 0 { + return c.Status(fiber.StatusBadRequest).SendString("no topics subscribed") + } + + // Set SSE headers + c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("X-Accel-Buffering", "no") + + // Capture Last-Event-ID before entering the stream writer + lastEventID := c.Get("Last-Event-ID") + if lastEventID == "" { + lastEventID = c.Query("lastEventID") + } + + return c.SendStreamWriter(func(w *bufio.Writer) { + defer func() { + // Use select to avoid blocking forever if hub.run() has exited (CRITICAL-3). 
+ select { + case hub.unregister <- conn: + case <-hub.shutdown: + } + conn.Close() + if cfg.OnDisconnect != nil { + cfg.OnDisconnect(conn) + } + }() + + if err := hub.initStream(w, conn, lastEventID); err != nil { + return + } + + // Register AFTER initStream to avoid duplicate events from + // replay + live delivery race (MAJOR-7). + select { + case hub.register <- conn: + case <-hub.shutdown: + return + } + + hub.watchLifetime(conn) + hub.watchShutdown(conn) + conn.writeLoop(w) + }) + } + + return handler, hub +} + +// generateID produces a random 32-character hex string for connection IDs. +func generateID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + panic("sse: failed to generate connection ID: " + err.Error()) + } + return hex.EncodeToString(b) +} + +// Publish sends an event to all connections subscribed to the event's topics. +// This method is goroutine-safe and non-blocking. If the internal event buffer +// is full, the event is dropped and eventsDropped is incremented. +func (h *Hub) Publish(event Event) { //nolint:gocritic // hugeParam: public API, value semantics preferred + if event.TTL > 0 && event.CreatedAt.IsZero() { + event.CreatedAt = time.Now() + } + select { + case h.events <- event: + h.metrics.eventsPublished.Add(1) + case <-h.shutdown: + // Hub is shutting down, discard + default: + // Buffer full — drop event to avoid blocking callers (MAJOR-5). + h.metrics.eventsDropped.Add(1) + } +} + +// SetPaused pauses or resumes a connection by ID. Paused connections +// skip P1/P2 events (visibility hint for hidden browser tabs). +// P0 (instant) events are always delivered regardless. 
+func (h *Hub) SetPaused(connID string, paused bool) { //nolint:revive // flag-parameter: public API toggle + h.mu.RLock() + conn, ok := h.connections[connID] + h.mu.RUnlock() + if ok { + wasPaused := conn.paused.Swap(paused) + if paused && !wasPaused && h.cfg.OnPause != nil { + h.cfg.OnPause(conn) + } + if !paused && wasPaused && h.cfg.OnResume != nil { + h.cfg.OnResume(conn) + } + } +} + +// Shutdown gracefully drains all connections and stops the hub. +// It enters drain mode (rejects new connections), sends a server-shutdown +// event to all clients, then closes the hub. +// Safe to call multiple times — subsequent calls are no-ops. +// Pass context.Background() for an unbounded wait. +func (h *Hub) Shutdown(ctx context.Context) error { + h.draining.Store(true) + h.shutdownOnce.Do(func() { + close(h.shutdown) + }) + + select { + case <-h.stopped: + return nil + case <-ctx.Done(): + return fmt.Errorf("sse: shutdown: %w", ctx.Err()) + } +} + +// Stats returns a snapshot of the hub's current state. +func (h *Hub) Stats() HubStats { + h.mu.RLock() + defer h.mu.RUnlock() + + byTopic := make(map[string]int, len(h.topicIndex)) + for topic, conns := range h.topicIndex { + byTopic[topic] = len(conns) + } + + return HubStats{ + ActiveConnections: len(h.connections), + TotalTopics: len(h.topicIndex), + EventsPublished: h.metrics.eventsPublished.Load(), + EventsDropped: h.metrics.eventsDropped.Load(), + ConnectionsByTopic: byTopic, + EventsByType: h.metrics.snapshotEventsByType(), + } +} + +// initStream writes the initial SSE preamble: retry hint, replayed events, +// and the connected event. +func (h *Hub) initStream(w *bufio.Writer, conn *Connection, lastEventID string) error { + if err := writeRetry(w, h.cfg.RetryMS); err != nil { + return err + } + + if err := h.replayEvents(w, conn, lastEventID); err != nil { + return err + } + + return sendConnectedEvent(w, conn) +} + +// replayEvents replays missed events if the client sent a Last-Event-ID. 
+func (h *Hub) replayEvents(w *bufio.Writer, conn *Connection, lastEventID string) error { + if lastEventID == "" || h.cfg.Replayer == nil { + return nil + } + events, err := h.cfg.Replayer.Replay(lastEventID, conn.Topics) + if err != nil { + log.Warnf("sse: replay error for conn %s: %v", conn.ID, err) + return nil + } + if len(events) == 0 { + return nil + } + for _, me := range events { + if _, werr := me.WriteTo(w); werr != nil { + return werr + } + } + if err := w.Flush(); err != nil { + return fmt.Errorf("sse: flush replay: %w", err) + } + return nil +} + +// sendConnectedEvent writes the connected event with the connection ID +// and subscribed topics. +func sendConnectedEvent(w *bufio.Writer, conn *Connection) error { + topicsJSON, err := json.Marshal(conn.Topics) + if err != nil { + topicsJSON = []byte("[]") + } + connected := MarshaledEvent{ + ID: nextEventID(), + Type: "connected", + Data: fmt.Sprintf(`{"connection_id":%q,"topics":%s}`, conn.ID, string(topicsJSON)), + Retry: -1, + } + if _, err := connected.WriteTo(w); err != nil { + return err + } + if err := w.Flush(); err != nil { + return fmt.Errorf("sse: flush connected event: %w", err) + } + return nil +} + +// watchLifetime starts a goroutine that closes the connection after +// MaxLifetime has elapsed. +func (h *Hub) watchLifetime(conn *Connection) { + if h.cfg.MaxLifetime <= 0 { + return + } + go func() { + timer := time.NewTimer(h.cfg.MaxLifetime) + defer timer.Stop() + select { + case <-timer.C: + conn.Close() + case <-conn.done: + } + }() +} + +// shutdownDrainDelay is the time between sending the server-shutdown event +// and closing the connection, allowing the client to process the event. +const shutdownDrainDelay = 200 * time.Millisecond + +// watchShutdown starts a goroutine that sends a server-shutdown event +// and closes the connection when the hub begins draining. 
func (h *Hub) watchShutdown(conn *Connection) {
	go func() {
		select {
		case <-h.shutdown:
			// Best-effort courtesy notice: tell the client the server
			// is going away so it can reconnect elsewhere. Retry: -1
			// suppresses the retry directive for this event.
			if !conn.IsClosed() {
				shutdownEvt := MarshaledEvent{
					ID:    nextEventID(),
					Type:  "server-shutdown",
					Data:  "{}",
					Retry: -1,
				}
				conn.trySend(shutdownEvt)
				// Give the write loop a moment to flush the event
				// before tearing the connection down.
				time.Sleep(shutdownDrainDelay)
			}
			conn.Close()
		case <-conn.done:
			// Connection ended on its own; nothing to do.
		}
	}()
}

// run is the hub's main event loop. It owns all mutations of the connection
// and topic indexes (serialized through the register/unregister channels) and
// drives the periodic flush, heartbeat, and throttler-cleanup tickers.
// It exits — closing h.stopped — when h.shutdown is closed.
func (h *Hub) run() {
	defer close(h.stopped)

	flushTicker := time.NewTicker(h.cfg.FlushInterval)
	defer flushTicker.Stop()

	heartbeatTicker := time.NewTicker(h.cfg.HeartbeatInterval)
	defer heartbeatTicker.Stop()

	cleanupTicker := time.NewTicker(5 * time.Minute)
	defer cleanupTicker.Stop()

	for {
		select {
		case conn := <-h.register:
			h.addConnection(conn)

		case conn := <-h.unregister:
			h.removeConnection(conn)

		case event := <-h.events:
			h.routeEvent(&event)

		case <-flushTicker.C:
			h.flushAll()

		case <-heartbeatTicker.C:
			h.sendHeartbeats()

		case <-cleanupTicker.C:
			// Drop throttler state for connections idle > 10 minutes.
			h.throttler.cleanup(time.Now().Add(-10 * time.Minute))

		case <-h.shutdown:
			// Close every live connection; their write loops observe
			// conn.done / the closed shutdown channel and unwind.
			h.mu.Lock()
			for _, conn := range h.connections {
				conn.Close()
			}
			h.mu.Unlock()
			return
		}
	}
}

// addConnection registers a new connection and indexes it by topic.
// Topics containing NATS-style wildcards ('*' or '>') are not put in the
// exact-match topic index; the connection is tracked in wildcardConns and
// matched per-event instead.
func (h *Hub) addConnection(conn *Connection) {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.connections[conn.ID] = conn

	hasWildcard := false
	for _, topic := range conn.Topics {
		if strings.ContainsAny(topic, "*>") {
			hasWildcard = true
		} else {
			if h.topicIndex[topic] == nil {
				h.topicIndex[topic] = make(map[string]struct{})
			}
			h.topicIndex[topic][conn.ID] = struct{}{}
		}
	}
	if hasWildcard {
		h.wildcardConns[conn.ID] = struct{}{}
	}

	log.Infof("sse: connection opened conn_id=%s topics=%v total=%d", conn.ID, conn.Topics, len(h.connections))
}

// removeConnection unregisters a connection and removes it from topic indexes.
func (h *Hub) removeConnection(conn *Connection) {
	h.mu.Lock()
	defer h.mu.Unlock()

	// Idempotent: unregister may race with shutdown-time teardown, so a
	// second removal of the same connection is a no-op.
	if _, exists := h.connections[conn.ID]; !exists {
		return
	}

	for _, topic := range conn.Topics {
		if idx, ok := h.topicIndex[topic]; ok {
			delete(idx, conn.ID)
			// Drop empty topic buckets so the index doesn't grow
			// unboundedly with transient topics.
			if len(idx) == 0 {
				delete(h.topicIndex, topic)
			}
		}
	}

	delete(h.wildcardConns, conn.ID)
	delete(h.connections, conn.ID)
	h.throttler.remove(conn.ID)

	log.Infof("sse: connection closed conn_id=%s sent=%d dropped=%d total=%d",
		conn.ID, conn.MessagesSent.Load(), conn.MessagesDropped.Load(), len(h.connections))
}

// routeEvent delivers an event to all connections subscribed to its topics.
// Events whose TTL has already elapsed are counted as dropped and discarded.
// The event is marshaled once and the same MarshaledEvent is fanned out to
// every matching connection.
func (h *Hub) routeEvent(event *Event) {
	if event.TTL > 0 && !event.CreatedAt.IsZero() {
		if time.Since(event.CreatedAt) > event.TTL {
			h.metrics.eventsDropped.Add(1)
			return
		}
	}

	me := marshalEvent(event)
	h.metrics.trackEventType(event.Type)

	// Skip replay storage for group-scoped events — replaying them without
	// tenant context would leak data across tenants (CRITICAL-2).
	if h.cfg.Replayer != nil && len(event.Group) == 0 {
		_ = h.cfg.Replayer.Store(me, event.Topics) //nolint:errcheck // best-effort replay storage
	}

	// Read lock only: routing never mutates the indexes, so it can run
	// concurrently with Stats/Metrics readers.
	h.mu.RLock()
	defer h.mu.RUnlock()

	seen := h.matchConnections(event)

	for connID := range seen {
		conn, ok := h.connections[connID]
		if !ok || conn.IsClosed() {
			continue
		}
		// Paused connections (hidden browser tabs) still get P0/instant
		// events; batched and coalesced traffic is suppressed.
		if conn.paused.Load() && event.Priority != PriorityInstant {
			continue
		}
		h.deliverToConn(conn, event, me)
	}
}

// matchConnections collects all connection IDs that should receive the event.
// When an event has BOTH Topics AND Group set, only connections matching BOTH
// are included (intersection semantics for tenant isolation). When only one
// dimension is set, the existing OR behavior applies.
+func (h *Hub) matchConnections(event *Event) map[string]struct{} { + seen := make(map[string]struct{}) + + for _, topic := range event.Topics { + if idx, ok := h.topicIndex[topic]; ok { + for connID := range idx { + seen[connID] = struct{}{} + } + } + } + + h.matchWildcardConns(event, seen) + + // When both Topics and Group are present, filter topic-matched connections + // down to those also matching the group (AND semantics). + if len(event.Group) > 0 && len(event.Topics) > 0 { + for connID := range seen { + conn, ok := h.connections[connID] + if !ok || !connMatchesGroup(conn, event.Group) { + delete(seen, connID) + } + } + } else { + h.matchGroupConns(event, seen) + } + + return seen +} + +// matchWildcardConns adds wildcard-subscribed connections that match the event topics. +func (h *Hub) matchWildcardConns(event *Event, seen map[string]struct{}) { + for connID := range h.wildcardConns { + if _, already := seen[connID]; already { + continue + } + conn, ok := h.connections[connID] + if !ok { + continue + } + for _, eventTopic := range event.Topics { + if connMatchesTopic(conn, eventTopic) { + seen[connID] = struct{}{} + break + } + } + } +} + +// matchGroupConns adds connections that match the event's group metadata. +func (h *Hub) matchGroupConns(event *Event, seen map[string]struct{}) { + if len(event.Group) == 0 { + return + } + for connID, conn := range h.connections { + if _, already := seen[connID]; already { + continue + } + if connMatchesGroup(conn, event.Group) { + seen[connID] = struct{}{} + } + } +} + +// deliverToConn routes an event to a connection based on priority. 
// deliverToConn routes an event to a connection based on priority:
// instant events go straight to the send buffer, batched events queue in
// arrival order, and coalesced events are last-writer-wins per key.
func (h *Hub) deliverToConn(conn *Connection, event *Event, me MarshaledEvent) { //nolint:gocritic // hugeParam: value semantics preferred for event routing
	switch event.Priority {
	case PriorityInstant:
		// Bypass the coalescer; a full send buffer drops the event.
		if !conn.trySend(me) {
			h.metrics.eventsDropped.Add(1)
		}
	case PriorityBatched:
		conn.coalescer.addBatched(me)
	default: // PriorityCoalesced
		// Events without an explicit CoalesceKey collapse per event type,
		// so repeated events of one type keep only the newest payload.
		key := event.CoalesceKey
		if key == "" {
			key = event.Type
		}
		conn.coalescer.addCoalesced(key, me)
	}
}

// flushAll drains each connection's coalescer and sends buffered events.
func (h *Hub) flushAll() {
	// Snapshot open, unpaused connections under the read lock, then do the
	// per-connection flush work without holding the hub lock.
	h.mu.RLock()
	conns := make([]*Connection, 0, len(h.connections))
	for _, conn := range h.connections {
		if !conn.IsClosed() && !conn.paused.Load() {
			conns = append(conns, conn)
		}
	}
	h.mu.RUnlock()

	for _, conn := range conns {
		// The connection may have closed between snapshot and flush.
		if conn.IsClosed() {
			continue
		}

		// Send-buffer saturation in [0, 1] drives the adaptive flush
		// interval: fuller buffers flush less often.
		bufCap := cap(conn.send)
		saturation := float64(0)
		if bufCap > 0 {
			saturation = float64(len(conn.send)) / float64(bufCap)
		}

		// If the throttler declines, buffered events simply stay in the
		// coalescer until a later flush cycle.
		if !h.throttler.shouldFlush(conn.ID, saturation) {
			continue
		}

		events := conn.coalescer.flush()
		now := time.Now()
		for _, me := range events {
			// Drop coalesced events that have expired while buffered (MAJOR-6).
			if me.TTL > 0 && !me.CreatedAt.IsZero() && now.Sub(me.CreatedAt) > me.TTL {
				h.metrics.eventsDropped.Add(1)
				continue
			}
			if !conn.trySend(me) {
				h.metrics.eventsDropped.Add(1)
			}
		}
	}
}

// sendHeartbeats sends a comment to connections that haven't received
// real data recently.
// sendHeartbeats sends a comment to connections that haven't received
// real data recently, keeping idle streams alive. A connection is due for a
// heartbeat when its last write is at least HeartbeatInterval old.
func (h *Hub) sendHeartbeats() {
	h.mu.RLock()
	defer h.mu.RUnlock()

	now := time.Now()
	for _, conn := range h.connections {
		if conn.IsClosed() {
			continue
		}
		// lastWrite holds a time.Time in an atomic.Value; the zero value
		// (nothing stored yet) also makes the heartbeat fire.
		lastWrite, _ := conn.lastWrite.Load().(time.Time) //nolint:errcheck // type assertion on atomic.Value
		if now.Sub(lastWrite) >= h.cfg.HeartbeatInterval {
			conn.sendHeartbeat()
		}
	}
}
diff --git a/middleware/sse/sse_test.go b/middleware/sse/sse_test.go
new file mode 100644
index 00000000000..90452304bb2
--- /dev/null
+++ b/middleware/sse/sse_test.go
@@ -0,0 +1,2109 @@
package sse

import (
	"bufio"
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"
	"strings"
	"testing"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/stretchr/testify/require"
)

// Verifies NewWithHub returns a non-nil handler and hub, and that the hub
// shuts down cleanly within the deadline.
func Test_SSE_New(t *testing.T) {
	t.Parallel()

	handler, hub := NewWithHub()
	require.NotNil(t, handler)
	require.NotNil(t, hub)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	require.NoError(t, hub.Shutdown(ctx))
}

// Asserts the documented defaults applied when no Config is supplied.
func Test_SSE_New_DefaultConfig(t *testing.T) {
	t.Parallel()

	_, hub := NewWithHub()
	require.Equal(t, 2*time.Second, hub.cfg.FlushInterval)
	require.Equal(t, 256, hub.cfg.SendBufferSize)
	require.Equal(t, 30*time.Second, hub.cfg.HeartbeatInterval)
	require.Equal(t, 30*time.Minute, hub.cfg.MaxLifetime)
	require.Equal(t, 3000, hub.cfg.RetryMS)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	require.NoError(t, hub.Shutdown(ctx))
}

// Asserts user-supplied Config values override every default.
func Test_SSE_New_CustomConfig(t *testing.T) {
	t.Parallel()

	_, hub := NewWithHub(Config{
		FlushInterval:     5 * time.Second,
		SendBufferSize:    128,
		HeartbeatInterval: 10 * time.Second,
		MaxLifetime:       time.Hour,
		RetryMS:           5000,
	})
	require.Equal(t, 5*time.Second, hub.cfg.FlushInterval)
	require.Equal(t, 128, hub.cfg.SendBufferSize)
	require.Equal(t, 10*time.Second, hub.cfg.HeartbeatInterval)
	require.Equal(t, time.Hour, hub.cfg.MaxLifetime)
require.Equal(t, 5000, hub.cfg.RetryMS) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_Next(t *testing.T) { + t.Parallel() + + app := fiber.New() + handler, hub := NewWithHub(Config{ + Next: func(c fiber.Ctx) bool { + return c.Query("skip") == "true" + }, + OnConnect: func(_ fiber.Ctx, conn *Connection) error { + conn.Topics = []string{"test"} + return nil + }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + app.Get("/events", handler) + app.Get("/events", func(c fiber.Ctx) error { + return c.SendString("skipped") + }) + + req, err := http.NewRequest(fiber.MethodGet, "/events?skip=true", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "skipped", string(body)) +} + +func Test_SSE_NoTopics(t *testing.T) { + t.Parallel() + + app := fiber.New() + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, _ *Connection) error { + // Don't set any topics + return nil + }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + app.Get("/events", handler) + + req, err := http.NewRequest(fiber.MethodGet, "/events", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) +} + +func Test_SSE_OnConnectReject(t *testing.T) { + t.Parallel() + + app := fiber.New() + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, _ *Connection) error { + return errors.New("unauthorized") + }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 
time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + app.Get("/events", handler) + + req, err := http.NewRequest(fiber.MethodGet, "/events", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusForbidden, resp.StatusCode) +} + +func Test_SSE_GenerateID(t *testing.T) { + t.Parallel() + + ids := make(map[string]struct{}) + for range 1000 { + id := generateID() + require.Len(t, id, 32) + _, exists := ids[id] + require.False(t, exists, "duplicate ID generated") + ids[id] = struct{}{} + } +} + +func Test_SSE_TopicMatch(t *testing.T) { + t.Parallel() + + tests := []struct { + pattern string + topic string + want bool + }{ + {"events", "events", true}, + {"events", "events.sub", false}, + {"notifications.*", "notifications.orders", true}, + {"notifications.*", "notifications.orders.new", false}, + {"analytics.>", "analytics.live", true}, + {"analytics.>", "analytics.live.visitors", true}, + {"analytics.>", "analytics", false}, + {"*", "anything", true}, + {">", "anything", true}, + {">", "a.b.c", true}, + // > must be last token — invalid patterns should not match + {"a.>.c", "a.b.c", false}, + {">.b", "a.b", false}, + } + + for _, tt := range tests { + got := topicMatch(tt.pattern, tt.topic) + require.Equal(t, tt.want, got, "topicMatch(%q, %q)", tt.pattern, tt.topic) + } +} + +func Test_SSE_MarshalEvent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data any + want string + }{ + {"nil", nil, ""}, + {"string", "hello", "hello"}, + {"bytes", []byte("world"), "world"}, + {"struct", map[string]string{"key": "val"}, `{"key":"val"}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + me := marshalEvent(&Event{Data: tt.data}) + require.Equal(t, tt.want, me.Data) + require.NotEmpty(t, me.ID) // auto-generated + }) + } +} + +func Test_SSE_MarshaledEvent_WriteTo(t *testing.T) { + t.Parallel() + + me := MarshaledEvent{ + 
ID: "evt_1", + Type: "test", + Data: "hello world", + } + + var buf bytes.Buffer + n, err := me.WriteTo(&buf) + require.NoError(t, err) + require.Positive(t, n) + + output := buf.String() + require.Contains(t, output, "id: evt_1\n") + require.Contains(t, output, "event: test\n") + require.Contains(t, output, "data: hello world\n") + require.True(t, strings.HasSuffix(output, "\n\n")) +} + +func Test_SSE_MarshaledEvent_WriteTo_Multiline(t *testing.T) { + t.Parallel() + + me := MarshaledEvent{ + ID: "evt_2", + Type: "test", + Data: "line1\nline2\nline3", + } + + var buf bytes.Buffer + _, err := me.WriteTo(&buf) + require.NoError(t, err) + + output := buf.String() + require.Contains(t, output, "data: line1\n") + require.Contains(t, output, "data: line2\n") + require.Contains(t, output, "data: line3\n") +} + +func Test_SSE_MarshaledEvent_WriteTo_Retry(t *testing.T) { + t.Parallel() + + me := MarshaledEvent{ + ID: "evt_3", + Type: "test", + Data: "x", + Retry: 3000, + } + + var buf bytes.Buffer + _, err := me.WriteTo(&buf) + require.NoError(t, err) + require.Contains(t, buf.String(), "retry: 3000\n") +} + +func Test_SSE_Coalescer(t *testing.T) { + t.Parallel() + + c := newCoalescer(time.Second) + + // Add batched events + c.addBatched(MarshaledEvent{ID: "1", Data: "a"}) + c.addBatched(MarshaledEvent{ID: "2", Data: "b"}) + + // Add coalesced events (last wins) + c.addCoalesced("key1", MarshaledEvent{ID: "3", Data: "old"}) + c.addCoalesced("key1", MarshaledEvent{ID: "4", Data: "new"}) + c.addCoalesced("key2", MarshaledEvent{ID: "5", Data: "other"}) + + require.Equal(t, 4, c.pending()) + + events := c.flush() + require.Len(t, events, 4) + + // Batched first + require.Equal(t, "a", events[0].Data) + require.Equal(t, "b", events[1].Data) + + // Coalesced: key1 = "new" (last wins), key2 = "other" + require.Equal(t, "new", events[2].Data) + require.Equal(t, "other", events[3].Data) + + // Should be empty now + require.Nil(t, c.flush()) +} + +func Test_SSE_AdaptiveThrottler(t 
*testing.T) { + t.Parallel() + + at := newAdaptiveThrottler(2 * time.Second) + + // First flush always passes + require.True(t, at.shouldFlush("conn1", 0.0)) + + // Second flush immediately — should fail (too soon) + require.False(t, at.shouldFlush("conn1", 0.0)) + + // Clean up + at.remove("conn1") +} + +func Test_SSE_MemoryReplayer(t *testing.T) { + t.Parallel() + + replayer := NewMemoryReplayer(MemoryReplayerConfig{MaxEvents: 5}) + + for i := range 5 { + require.NoError(t, replayer.Store( + MarshaledEvent{ID: fmt.Sprintf("evt_%d", i), Data: fmt.Sprintf("data_%d", i)}, + []string{"topic1"}, + )) + } + + events, err := replayer.Replay("evt_2", []string{"topic1"}) + require.NoError(t, err) + require.Len(t, events, 2) // evt_3 and evt_4 + + // Unknown ID returns nil + events, err = replayer.Replay("unknown", []string{"topic1"}) + require.NoError(t, err) + require.Nil(t, events) +} + +func Test_SSE_MemoryReplayer_MaxEvents(t *testing.T) { + t.Parallel() + + replayer := NewMemoryReplayer(MemoryReplayerConfig{MaxEvents: 3}) + + for i := range 10 { + require.NoError(t, replayer.Store( + MarshaledEvent{ID: fmt.Sprintf("evt_%d", i)}, + []string{"t"}, + )) + } + + // Only last 3 events should be retained (ring buffer count) + replayer.mu.RLock() + require.Equal(t, 3, replayer.count) + replayer.mu.RUnlock() + + // Replay from evt_7 should return evt_8 and evt_9 + events, err := replayer.Replay("evt_7", []string{"t"}) + require.NoError(t, err) + require.Len(t, events, 2) + require.Equal(t, "evt_8", events[0].ID) + require.Equal(t, "evt_9", events[1].ID) +} + +func Test_SSE_TicketAuth(t *testing.T) { + t.Parallel() + + store := NewMemoryTicketStore() + + ticket, err := IssueTicket(store, `{"tenant":"t_1"}`, 5*time.Minute) + require.NoError(t, err) + require.Len(t, ticket, 48) // 24 bytes = 48 hex chars + + // Consume ticket + value, err := store.GetDel(ticket) + require.NoError(t, err) + require.JSONEq(t, `{"tenant":"t_1"}`, value) + + // Second use should fail + value, err = 
store.GetDel(ticket) + require.NoError(t, err) + require.Empty(t, value) +} + +func Test_SSE_Publish_Stats(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + hub.Publish(Event{Type: "test", Topics: []string{"t"}, Data: "hello"}) + time.Sleep(50 * time.Millisecond) // let run loop process + + stats := hub.Stats() + require.Equal(t, int64(1), stats.EventsPublished) +} + +func Test_SSE_Shutdown(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := hub.Shutdown(ctx) + require.NoError(t, err) + require.True(t, hub.draining.Load()) +} + +func Test_SSE_Shutdown_Idempotent(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // First call shuts down + require.NoError(t, hub.Shutdown(ctx)) + + // Second call must not panic (sync.Once guards close) + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_Shutdown_Background_Context(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + + err := hub.Shutdown(context.Background()) + require.NoError(t, err) +} + +func Test_SSE_Invalidation(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // These should not panic + hub.Invalidate("orders", "ord_123", "created") + hub.InvalidateForTenant("t_1", "orders", "ord_123", "created") + hub.InvalidateWithHint("orders", "ord_123", "created", map[string]any{"total": 99.99}) + hub.InvalidateForTenantWithHint("t_1", "orders", "ord_123", "created", nil) + hub.Signal("dashboard") + hub.SignalForTenant("t_1", "dashboard") + hub.SignalThrottled("analytics", 
time.Minute) + + time.Sleep(50 * time.Millisecond) + stats := hub.Stats() + require.Equal(t, int64(7), stats.EventsPublished) +} + +func Test_SSE_DomainEvent(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + hub.DomainEvent("orders", "created", "ord_1", "t_1", nil) + hub.Progress("import", "imp_1", "t_1", 50, 100) + hub.Complete("import", "imp_1", "t_1", true, nil) + hub.BatchDomainEvents("t_1", []DomainEventSpec{ + {Resource: "orders", Action: "created", ResourceID: "o1"}, + {Resource: "products", Action: "updated", ResourceID: "p1"}, + }) + + time.Sleep(50 * time.Millisecond) + stats := hub.Stats() + require.Equal(t, int64(4), stats.EventsPublished) +} + +func Test_SSE_Draining_RejectsConnection(t *testing.T) { + t.Parallel() + + app := fiber.New() + handler, hub := NewWithHub(Config{ + OnConnect: func(_ fiber.Ctx, conn *Connection) error { + conn.Topics = []string{"test"} + return nil + }, + }) + + app.Get("/events", handler) + + // Start draining + hub.draining.Store(true) + defer func() { + close(hub.shutdown) + <-hub.stopped + }() + + req, err := http.NewRequest(fiber.MethodGet, "/events", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode) +} + +func Test_SSE_Connection_Lifecycle(t *testing.T) { + t.Parallel() + + conn := newConnection("test-id", []string{"t"}, 10, time.Second) + require.Equal(t, "test-id", conn.ID) + require.False(t, conn.IsClosed()) + + conn.Close() + require.True(t, conn.IsClosed()) + + // Double close should not panic + conn.Close() +} + +func Test_SSE_Connection_TrySend_Backpressure(t *testing.T) { + t.Parallel() + + conn := newConnection("test", nil, 2, time.Second) + + require.True(t, conn.trySend(MarshaledEvent{Data: "1"})) + require.True(t, 
conn.trySend(MarshaledEvent{Data: "2"})) + + // Buffer full + require.False(t, conn.trySend(MarshaledEvent{Data: "3"})) + require.Equal(t, int64(1), conn.MessagesDropped.Load()) +} + +func Test_SSE_WriteComment(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + err := writeComment(&buf, "heartbeat") + require.NoError(t, err) + require.Equal(t, ": heartbeat\n\n", buf.String()) +} + +func Test_SSE_WriteRetry(t *testing.T) { + t.Parallel() + + var buf bytes.Buffer + err := writeRetry(&buf, 3000) + require.NoError(t, err) + require.Equal(t, "retry: 3000\n\n", buf.String()) +} + +func Test_SSE_Metrics(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + snap := hub.Metrics(false) + require.Equal(t, 0, snap.ActiveConnections) + require.NotEmpty(t, snap.Timestamp) +} + +func Test_SSE_MaxLifetime_Unlimited(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{ + MaxLifetime: -1, // unlimited + }) + require.Equal(t, time.Duration(-1), hub.cfg.MaxLifetime) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) +} + +func Test_SSE_JWTAuth_Valid(t *testing.T) { + t.Parallel() + + app := fiber.New() + handler, hub := NewWithHub(Config{ + OnConnect: JWTAuth(func(token string) (map[string]string, error) { + if token == "valid-token" { + return map[string]string{"user_id": "u_1", "tenant_id": "t_1"}, nil + } + return nil, errors.New("invalid token") + }), + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + app.Get("/events", handler) + + // No token → 403 + req, err := http.NewRequest(fiber.MethodGet, "/events", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, 
fiber.StatusForbidden, resp.StatusCode) + + // Invalid token → 403 + req, err = http.NewRequest(fiber.MethodGet, "/events?token=bad", http.NoBody) + require.NoError(t, err) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusForbidden, resp.StatusCode) + + // Valid token but no topics set → 400 + req, err = http.NewRequest(fiber.MethodGet, "/events?token=valid-token", http.NoBody) + require.NoError(t, err) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) +} + +func Test_SSE_JWTAuth_BearerHeader(t *testing.T) { + t.Parallel() + + authHandler := JWTAuth(func(token string) (map[string]string, error) { + if token == "my-jwt" { + return map[string]string{"user": "test"}, nil + } + return nil, errors.New("bad") + }) + + app := fiber.New() + handler, hub := NewWithHub(Config{ + OnConnect: func(c fiber.Ctx, conn *Connection) error { + if err := authHandler(c, conn); err != nil { + return err + } + conn.Topics = []string{"test"} + return nil + }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + app.Get("/events", handler) + + // Bearer header should work — SSE streams never end, so use short timeout + req, err := http.NewRequest(fiber.MethodGet, "/events", http.NoBody) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer my-jwt") + resp, err := app.Test(req, fiber.TestConfig{Timeout: 500 * time.Millisecond}) + // Timeout is expected for SSE — the stream opened successfully + if err != nil { + require.ErrorContains(t, err, "timeout") + } + if resp != nil { + require.Equal(t, fiber.StatusOK, resp.StatusCode) + } +} + +func Test_SSE_TicketAuth_Full(t *testing.T) { + t.Parallel() + + store := NewMemoryTicketStore() + defer store.Close() + + ticket, err := IssueTicket(store, `test-value`, 5*time.Minute) + require.NoError(t, err) + + app := fiber.New() + 
handler, hub := NewWithHub(Config{ + OnConnect: TicketAuth(store, func(_ string) (map[string]string, []string, error) { + return map[string]string{"source": "ticket"}, []string{"notifications"}, nil + }), + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + app.Get("/events", handler) + + // Valid ticket → SSE stream starts (timeout expected) + req, err := http.NewRequest(fiber.MethodGet, "/events?ticket="+ticket, http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req, fiber.TestConfig{Timeout: 500 * time.Millisecond}) + if err != nil { + require.ErrorContains(t, err, "timeout") + } + if resp != nil { + require.Equal(t, fiber.StatusOK, resp.StatusCode) + } + + // Same ticket again → 403 (one-time use, already consumed) + req, err = http.NewRequest(fiber.MethodGet, "/events?ticket="+ticket, http.NoBody) + require.NoError(t, err) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusForbidden, resp.StatusCode) + + // No ticket → 403 + req, err = http.NewRequest(fiber.MethodGet, "/events", http.NoBody) + require.NoError(t, err) + resp, err = app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusForbidden, resp.StatusCode) +} + +func Test_SSE_TicketStore_Close(t *testing.T) { + t.Parallel() + + store := NewMemoryTicketStore() + require.NoError(t, store.Set("test", "value", time.Minute)) + + // Close should not panic + store.Close() + + // Double close should not panic + store.Close() + + // Operations after close still work (just no cleanup goroutine) + v, err := store.GetDel("test") + require.NoError(t, err) + require.Equal(t, "value", v) +} + +func Test_SSE_MetricsHandler(t *testing.T) { + t.Parallel() + + app := fiber.New() + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + 
app.Get("/metrics", hub.MetricsHandler()) + + req, err := http.NewRequest(fiber.MethodGet, "/metrics", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), `"active_connections"`) + require.Contains(t, string(body), `"events_published"`) +} + +func Test_SSE_MetricsHandler_WithConnections(t *testing.T) { + t.Parallel() + + app := fiber.New() + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + app.Get("/metrics", hub.MetricsHandler()) + + req, err := http.NewRequest(fiber.MethodGet, "/metrics?connections=true", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_SSE_PrometheusHandler(t *testing.T) { + t.Parallel() + + app := fiber.New() + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Publish some events so metrics have data + hub.Publish(Event{Type: "test", Topics: []string{"t"}, Data: "x"}) + time.Sleep(50 * time.Millisecond) + + app.Get("/prom", hub.PrometheusHandler()) + + req, err := http.NewRequest(fiber.MethodGet, "/prom", http.NoBody) + require.NoError(t, err) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + output := string(body) + require.Contains(t, output, "sse_connections_active") + require.Contains(t, output, "sse_events_published_total") + require.Contains(t, output, "sse_events_dropped_total") +} + +func Test_SSE_Prometheus_LabelEscaping(t *testing.T) { + t.Parallel() + + 
// No special chars → pass through + require.Equal(t, "normal", escapePromLabelValue("normal")) + + // Quotes get escaped + require.Equal(t, `say \"hello\"`, escapePromLabelValue(`say "hello"`)) + + // Backslashes get escaped + require.Equal(t, `path\\to`, escapePromLabelValue(`path\to`)) + + // Newlines get escaped + require.Equal(t, `line1\nline2`, escapePromLabelValue("line1\nline2")) +} + +func Test_SSE_FanOut(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Mock subscriber that sends one message then blocks + received := make(chan string, 1) + mockSub := &mockPubSubSubscriber{ + onSubscribe: func(ctx context.Context, _ string, onMessage func(string)) error { + onMessage("test-payload") + received <- "delivered" + <-ctx.Done() + return ctx.Err() + }, + } + + cancel := hub.FanOut(FanOutConfig{ + Subscriber: mockSub, + Channel: "test-channel", + EventType: "notification", + }) + + // Wait for message delivery + select { + case <-received: + // success + case <-time.After(2 * time.Second): + t.Fatal("FanOut did not deliver message in time") + } + + cancel() + time.Sleep(50 * time.Millisecond) + + stats := hub.Stats() + require.Equal(t, int64(1), stats.EventsPublished) +} + +func Test_SSE_FanOut_Cancel(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + subscribeCalled := make(chan struct{}, 1) + mockSub := &mockPubSubSubscriber{ + onSubscribe: func(ctx context.Context, _ string, _ func(string)) error { + subscribeCalled <- struct{}{} + <-ctx.Done() + return ctx.Err() + }, + } + + cancel := hub.FanOut(FanOutConfig{ + Subscriber: mockSub, + Channel: "ch", + EventType: "evt", + }) + + <-subscribeCalled + cancel() + // Should not hang — 
goroutine exits cleanly +} + +func Test_SSE_FanOutMulti(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + count := make(chan struct{}, 2) + mockSub := &mockPubSubSubscriber{ + onSubscribe: func(ctx context.Context, channel string, onMessage func(string)) error { + onMessage("msg-from-" + channel) + count <- struct{}{} + <-ctx.Done() + return ctx.Err() + }, + } + + cancel := hub.FanOutMulti( + FanOutConfig{Subscriber: mockSub, Channel: "ch1", EventType: "e1"}, + FanOutConfig{Subscriber: mockSub, Channel: "ch2", EventType: "e2"}, + ) + + // Wait for both + <-count + <-count + cancel() + + time.Sleep(50 * time.Millisecond) + stats := hub.Stats() + require.Equal(t, int64(2), stats.EventsPublished) +} + +func Test_SSE_FanOut_Transform(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + done := make(chan struct{}, 1) + mockSub := &mockPubSubSubscriber{ + onSubscribe: func(ctx context.Context, _ string, onMessage func(string)) error { + onMessage("raw-data") + done <- struct{}{} + <-ctx.Done() + return ctx.Err() + }, + } + + cancel := hub.FanOut(FanOutConfig{ + Subscriber: mockSub, + Channel: "ch", + EventType: "default", + Transform: func(payload string) *Event { + return &Event{ + Type: "transformed", + Data: "transformed:" + payload, + Topics: []string{"custom-topic"}, + } + }, + }) + + <-done + cancel() + time.Sleep(50 * time.Millisecond) + + stats := hub.Stats() + require.Equal(t, int64(1), stats.EventsPublished) +} + +func Test_SSE_FanOut_TransformNil(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, 
hub.Shutdown(ctx)) + }() + + done := make(chan struct{}, 1) + mockSub := &mockPubSubSubscriber{ + onSubscribe: func(ctx context.Context, _ string, onMessage func(string)) error { + onMessage("skip-this") + done <- struct{}{} + <-ctx.Done() + return ctx.Err() + }, + } + + cancel := hub.FanOut(FanOutConfig{ + Subscriber: mockSub, + Channel: "ch", + EventType: "evt", + Transform: func(_ string) *Event { + return nil // skip message + }, + }) + + <-done + cancel() + time.Sleep(50 * time.Millisecond) + + stats := hub.Stats() + require.Equal(t, int64(0), stats.EventsPublished) +} + +func Test_SSE_SetPaused(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // SetPaused on non-existent connection should not panic + hub.SetPaused("nonexistent", true) + + // Add a connection manually + conn := newConnection("test-conn", []string{"t"}, 10, time.Second) + hub.mu.Lock() + hub.connections["test-conn"] = conn + hub.mu.Unlock() + + hub.SetPaused("test-conn", true) + require.True(t, conn.paused.Load()) + + hub.SetPaused("test-conn", false) + require.False(t, conn.paused.Load()) +} + +func Test_SSE_BatchDomainEvents_Empty(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Empty batch should be a no-op + hub.BatchDomainEvents("t_1", nil) + hub.BatchDomainEvents("t_1", []DomainEventSpec{}) + + time.Sleep(50 * time.Millisecond) + stats := hub.Stats() + require.Equal(t, int64(0), stats.EventsPublished) +} + +func Test_SSE_Replayer_EmptyID(t *testing.T) { + t.Parallel() + + replayer := NewMemoryReplayer() + + // Empty lastEventID returns nil + events, err := replayer.Replay("", []string{"t"}) + require.NoError(t, err) + require.Nil(t, events) +} + +func 
Test_SSE_Replayer_WildcardTopics(t *testing.T) { + t.Parallel() + + replayer := NewMemoryReplayer() + + require.NoError(t, replayer.Store(MarshaledEvent{ID: "e1"}, []string{"orders.created"})) + require.NoError(t, replayer.Store(MarshaledEvent{ID: "e2"}, []string{"orders.updated"})) + require.NoError(t, replayer.Store(MarshaledEvent{ID: "e3"}, []string{"products.created"})) + + // Wildcard replay + events, err := replayer.Replay("e1", []string{"orders.*"}) + require.NoError(t, err) + require.Len(t, events, 1) // e2 matches orders.*, e3 doesn't + require.Equal(t, "e2", events[0].ID) +} + +// --------------------------------------------------------------------------- +// Coverage-boost tests +// --------------------------------------------------------------------------- + +func Test_SSE_New_Wrapper(t *testing.T) { + t.Parallel() + handler := New() + require.NotNil(t, handler) +} + +func Test_SSE_SanitizeSSEField(t *testing.T) { + t.Parallel() + + require.Equal(t, "clean", sanitizeSSEField("clean")) + require.Equal(t, "ab", sanitizeSSEField("a\nb")) + require.Equal(t, "ab", sanitizeSSEField("a\rb")) + require.Equal(t, "ab", sanitizeSSEField("a\r\nb")) + require.Equal(t, "abc", sanitizeSSEField("a\r\nb\nc")) +} + +func Test_SSE_MarshalEvent_SanitizesIDAndType(t *testing.T) { + t.Parallel() + + me := marshalEvent(&Event{ + ID: "id\r\ninjected", + Type: "type\ninjected", + Data: "safe", + }) + require.Equal(t, "idinjected", me.ID) + require.Equal(t, "typeinjected", me.Type) +} + +func Test_SSE_MarshalEvent_JsonMarshalerError(t *testing.T) { + t.Parallel() + + me := marshalEvent(&Event{Data: badMarshaler{}}) + require.Contains(t, me.Data, "error") +} + +func Test_SSE_MarshalEvent_DefaultMarshalError(t *testing.T) { + t.Parallel() + + // A channel cannot be JSON-marshaled + me := marshalEvent(&Event{Data: make(chan int)}) + require.Contains(t, me.Data, "error") +} + +// badMarshaler implements json.Marshaler and always returns an error. 
+type badMarshaler struct{} + +func (badMarshaler) MarshalJSON() ([]byte, error) { + return nil, errors.New("marshal failed") +} + +func Test_SSE_WriteTo_EmptyFields(t *testing.T) { + t.Parallel() + + me := MarshaledEvent{ + Data: "x", + Retry: -1, + } + var buf bytes.Buffer + _, err := me.WriteTo(&buf) + require.NoError(t, err) + output := buf.String() + // No id: or event: lines + require.NotContains(t, output, "id:") + require.NotContains(t, output, "event:") + require.Contains(t, output, "data: x\n") +} + +func Test_SSE_ConnMatchesGroup(t *testing.T) { + t.Parallel() + + conn := newConnection("c1", []string{"t"}, 10, time.Second) + conn.Metadata["tenant_id"] = "t_1" + conn.Metadata["role"] = "admin" + + require.True(t, connMatchesGroup(conn, map[string]string{"tenant_id": "t_1"})) + require.True(t, connMatchesGroup(conn, map[string]string{"tenant_id": "t_1", "role": "admin"})) + require.False(t, connMatchesGroup(conn, map[string]string{"tenant_id": "t_2"})) + require.False(t, connMatchesGroup(conn, map[string]string{"missing": "key"})) + require.True(t, connMatchesGroup(conn, map[string]string{})) // empty group matches all +} + +func Test_SSE_SendHeartbeat(t *testing.T) { + t.Parallel() + + conn := newConnection("hb", []string{"t"}, 10, time.Second) + + // First heartbeat should succeed + conn.sendHeartbeat() + // Second should be silently dropped (buffer 1) + conn.sendHeartbeat() + + // Drain the heartbeat channel + select { + case <-conn.heartbeat: + default: + t.Fatal("expected heartbeat in channel") + } +} + +func Test_SSE_WriteLoop_Events(t *testing.T) { + t.Parallel() + + conn := newConnection("wl", []string{"t"}, 10, time.Second) + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + // Send an event + heartbeat, then close + conn.trySend(MarshaledEvent{ID: "e1", Type: "test", Data: "hello", Retry: -1}) + conn.sendHeartbeat() + + go func() { + time.Sleep(50 * time.Millisecond) + conn.Close() + }() + + conn.writeLoop(w) + + output := buf.String() + 
require.Contains(t, output, "id: e1\n") + require.Contains(t, output, "event: test\n") + require.Contains(t, output, "data: hello\n") + require.Contains(t, output, ": heartbeat\n") + require.Equal(t, int64(1), conn.MessagesSent.Load()) +} + +func Test_SSE_WriteLoop_ChannelClose(t *testing.T) { + t.Parallel() + + conn := newConnection("wlc", []string{"t"}, 10, time.Second) + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + // Close the send channel directly to test the !ok path + close(conn.send) + conn.writeLoop(w) + // Should return without panic +} + +func Test_SSE_TopicMatchesAny(t *testing.T) { + t.Parallel() + + require.True(t, topicMatchesAny([]string{"orders", "products"}, "orders")) + require.True(t, topicMatchesAny([]string{"orders.*"}, "orders.created")) + require.False(t, topicMatchesAny([]string{"orders", "products"}, "users")) + require.False(t, topicMatchesAny(nil, "anything")) +} + +func Test_SSE_ConnMatchesTopic(t *testing.T) { + t.Parallel() + + conn := newConnection("ct", []string{"orders.*", "products"}, 10, time.Second) + require.True(t, connMatchesTopic(conn, "orders.created")) + require.True(t, connMatchesTopic(conn, "products")) + require.False(t, connMatchesTopic(conn, "users")) +} + +func Test_SSE_EffectiveInterval_AllBranches(t *testing.T) { + t.Parallel() + + at := newAdaptiveThrottler(2 * time.Second) + + // saturation > 0.8 → maxInterval + require.Equal(t, at.maxInterval, at.effectiveInterval(0.9)) + // saturation > 0.5 → baseInterval * 2 + require.Equal(t, at.baseInterval*2, at.effectiveInterval(0.6)) + // saturation < 0.1 → minInterval + require.Equal(t, at.minInterval, at.effectiveInterval(0.05)) + // default → baseInterval + require.Equal(t, at.baseInterval, at.effectiveInterval(0.3)) +} + +func Test_SSE_Throttler_Cleanup(t *testing.T) { + t.Parallel() + + at := newAdaptiveThrottler(time.Second) + at.shouldFlush("old-conn", 0.0) + at.shouldFlush("new-conn", 0.0) + + // Make "old-conn" stale + at.mu.Lock() + 
at.lastFlush["old-conn"] = time.Now().Add(-20 * time.Minute) + at.mu.Unlock() + + at.cleanup(time.Now().Add(-10 * time.Minute)) + + at.mu.Lock() + _, oldExists := at.lastFlush["old-conn"] + _, newExists := at.lastFlush["new-conn"] + at.mu.Unlock() + + require.False(t, oldExists, "old conn should be cleaned up") + require.True(t, newExists, "new conn should remain") +} + +func Test_SSE_FormatFloat_AllBranches(t *testing.T) { + t.Parallel() + + require.Equal(t, "42", formatFloat(42.0)) + require.Equal(t, "3.140000", formatFloat(3.14)) + require.Equal(t, "0", formatFloat(math.NaN())) + require.Equal(t, "0", formatFloat(math.Inf(1))) + require.Equal(t, "0", formatFloat(math.Inf(-1))) +} + +func Test_SSE_Metrics_WithConnections(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Add a connection manually to test metrics with connections + conn := newConnection("metrics-conn", []string{"orders", "products"}, 10, time.Second) + conn.Metadata["tenant_id"] = "t_1" + + hub.mu.Lock() + hub.connections[conn.ID] = conn + hub.topicIndex["orders"] = map[string]struct{}{conn.ID: {}} + hub.topicIndex["products"] = map[string]struct{}{conn.ID: {}} + hub.mu.Unlock() + + snap := hub.Metrics(true) + require.Equal(t, 1, snap.ActiveConnections) + require.Len(t, snap.Connections, 1) + require.Equal(t, "metrics-conn", snap.Connections[0].ID) + require.Equal(t, 1, snap.ConnectionsByTopic["orders"]) + + // Test with paused connection + conn.paused.Store(true) + snap = hub.Metrics(true) + require.Equal(t, 1, snap.PausedConnections) +} + +func Test_SSE_FanOut_RetryOnError(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + attempts := make(chan struct{}, 5) + mockSub := 
&mockPubSubSubscriber{ + onSubscribe: func(_ context.Context, _ string, _ func(string)) error { + select { + case attempts <- struct{}{}: + default: + } + return errors.New("connection failed") + }, + } + + cancel := hub.FanOut(FanOutConfig{ + Subscriber: mockSub, + Channel: "retry-ch", + EventType: "evt", + }) + + // Wait for at least one retry attempt + <-attempts + cancel() +} + +func Test_SSE_FanOut_BuildEvent_ConfigDefaults(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Non-transform with all config defaults + cfg := &FanOutConfig{ + EventType: "test-event", + Priority: PriorityBatched, + CoalesceKey: "my-key", + TTL: 5 * time.Minute, + } + event := hub.buildFanOutEvent(cfg, "my-topic", "payload") + require.NotNil(t, event) + require.Equal(t, "test-event", event.Type) + require.Equal(t, []string{"my-topic"}, event.Topics) + require.Equal(t, PriorityBatched, event.Priority) + require.Equal(t, "my-key", event.CoalesceKey) + require.Equal(t, 5*time.Minute, event.TTL) + require.Equal(t, "payload", event.Data) + + // Transform that sets its own priority — should be respected + cfgT := &FanOutConfig{ + EventType: "default-type", + Priority: PriorityBatched, + Transform: func(payload string) *Event { + return &Event{ + Type: "custom-type", + Data: "custom:" + payload, + Priority: PriorityCoalesced, + Topics: []string{"custom-topic"}, + } + }, + } + event = hub.buildFanOutEvent(cfgT, "fallback-topic", "raw") + require.NotNil(t, event) + require.Equal(t, "custom-type", event.Type) + require.Equal(t, PriorityCoalesced, event.Priority) // Transform's priority preserved + require.Equal(t, []string{"custom-topic"}, event.Topics) + + // Transform returning event without Topics or Type — should use defaults + cfgT2 := &FanOutConfig{ + EventType: "fallback-type", + Transform: func(_ string) *Event { + return 
&Event{Data: "minimal"} + }, + } + event = hub.buildFanOutEvent(cfgT2, "default-topic", "x") + require.NotNil(t, event) + require.Equal(t, "fallback-type", event.Type) + require.Equal(t, []string{"default-topic"}, event.Topics) +} + +func Test_SSE_SetPaused_Callbacks(t *testing.T) { + t.Parallel() + + var paused, resumed bool + _, hub := NewWithHub(Config{ + OnPause: func(_ *Connection) { paused = true }, + OnResume: func(_ *Connection) { resumed = true }, + }) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("cb-conn", []string{"t"}, 10, time.Second) + hub.mu.Lock() + hub.connections["cb-conn"] = conn + hub.mu.Unlock() + + hub.SetPaused("cb-conn", true) + require.True(t, paused) + + hub.SetPaused("cb-conn", false) + require.True(t, resumed) +} + +func Test_SSE_RouteEvent_WithGroup(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Add two connections with different tenants + conn1 := newConnection("c1", []string{"orders"}, 10, time.Second) + conn1.Metadata["tenant_id"] = "t_1" + conn2 := newConnection("c2", []string{"orders"}, 10, time.Second) + conn2.Metadata["tenant_id"] = "t_2" + + hub.mu.Lock() + hub.connections["c1"] = conn1 + hub.connections["c2"] = conn2 + hub.topicIndex["orders"] = map[string]struct{}{"c1": {}, "c2": {}} + hub.mu.Unlock() + + // Publish with group targeting t_1 only + hub.Publish(Event{ + Type: "test", + Topics: []string{"orders"}, + Data: "for-t1", + Group: map[string]string{"tenant_id": "t_1"}, + Priority: PriorityInstant, + }) + + time.Sleep(100 * time.Millisecond) + + // conn1 should have received the event, conn2 should not + require.Equal(t, int64(0), conn1.MessagesDropped.Load()) + // Check send channel + select { + case me := <-conn1.send: + 
require.Contains(t, me.Data, "for-t1") + default: + t.Fatal("expected event in conn1 send channel") + } + + select { + case <-conn2.send: + t.Fatal("conn2 should NOT have received the event") + default: + // correct + } +} + +func Test_SSE_RouteEvent_GroupOnly(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Connection with metadata but no topic match — group-only delivery + conn := newConnection("g1", []string{"unrelated"}, 10, time.Second) + conn.Metadata["role"] = "admin" + + hub.mu.Lock() + hub.connections["g1"] = conn + hub.topicIndex["unrelated"] = map[string]struct{}{"g1": {}} + hub.mu.Unlock() + + // Publish with group only (no topic overlap) + hub.Publish(Event{ + Type: "admin-alert", + Data: "alert", + Group: map[string]string{"role": "admin"}, + Priority: PriorityInstant, + }) + + time.Sleep(100 * time.Millisecond) + + select { + case me := <-conn.send: + require.Contains(t, me.Data, "alert") + default: + t.Fatal("expected event via group match") + } +} + +func Test_SSE_RouteEvent_WildcardConn(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("wc1", []string{"orders.*"}, 10, time.Second) + + hub.mu.Lock() + hub.connections["wc1"] = conn + hub.wildcardConns["wc1"] = struct{}{} + hub.mu.Unlock() + + hub.Publish(Event{ + Type: "test", + Topics: []string{"orders.created"}, + Data: "wildcard-match", + Priority: PriorityInstant, + }) + + time.Sleep(100 * time.Millisecond) + + select { + case me := <-conn.send: + require.Contains(t, me.Data, "wildcard-match") + default: + t.Fatal("wildcard connection should have received the event") + } +} + +func Test_SSE_RouteEvent_PausedSkipsNonInstant(t *testing.T) { + 
t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("p1", []string{"t"}, 10, time.Second) + conn.paused.Store(true) + + hub.mu.Lock() + hub.connections["p1"] = conn + hub.topicIndex["t"] = map[string]struct{}{"p1": {}} + hub.mu.Unlock() + + // P1 event should be skipped for paused connection + hub.Publish(Event{ + Type: "batch", + Topics: []string{"t"}, + Data: "batched", + Priority: PriorityBatched, + }) + + time.Sleep(100 * time.Millisecond) + + require.Equal(t, 0, conn.coalescer.pending()) + + // P0 (instant) should still deliver + hub.Publish(Event{ + Type: "urgent", + Topics: []string{"t"}, + Data: "instant", + Priority: PriorityInstant, + }) + + time.Sleep(100 * time.Millisecond) + + select { + case me := <-conn.send: + require.Contains(t, me.Data, "instant") + default: + t.Fatal("P0 event should deliver to paused connection") + } +} + +func Test_SSE_RouteEvent_TTLExpired(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Publish an expired event + hub.Publish(Event{ + Type: "old", + Topics: []string{"t"}, + Data: "expired", + Priority: PriorityInstant, + TTL: time.Millisecond, + CreatedAt: time.Now().Add(-time.Second), + }) + + time.Sleep(100 * time.Millisecond) + + stats := hub.Stats() + require.Equal(t, int64(1), stats.EventsDropped) +} + +func Test_SSE_DeliverToConn_AllPriorities(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("dc", []string{"t"}, 10, time.Second) + + me := MarshaledEvent{ID: "e1", Data: "test"} + + // Test instant delivery + 
hub.deliverToConn(conn, &Event{Priority: PriorityInstant}, me) + select { + case <-conn.send: + default: + t.Fatal("instant event should be in send channel") + } + + // Test batched delivery + hub.deliverToConn(conn, &Event{Priority: PriorityBatched}, me) + require.Equal(t, 1, conn.coalescer.pending()) + conn.coalescer.flush() + + // Test coalesced delivery + hub.deliverToConn(conn, &Event{Priority: PriorityCoalesced, Type: "progress", CoalesceKey: "k1"}, me) + require.Equal(t, 1, conn.coalescer.pending()) + conn.coalescer.flush() + + // Test coalesced without explicit key — uses Type + hub.deliverToConn(conn, &Event{Priority: PriorityCoalesced, Type: "counter"}, me) + require.Equal(t, 1, conn.coalescer.pending()) +} + +func Test_SSE_FlushAll(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{FlushInterval: 50 * time.Millisecond}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("fl", []string{"t"}, 10, 50*time.Millisecond) + + hub.mu.Lock() + hub.connections["fl"] = conn + hub.topicIndex["t"] = map[string]struct{}{"fl": {}} + hub.mu.Unlock() + + // Add batched events to the coalescer + conn.coalescer.addBatched(MarshaledEvent{ID: "b1", Data: "batch1"}) + conn.coalescer.addBatched(MarshaledEvent{ID: "b2", Data: "batch2"}) + + // Wait for throttler to allow flush, then flush + time.Sleep(100 * time.Millisecond) + hub.flushAll() + + // Events should now be in the send channel + require.Len(t, conn.send, 2) +} + +func Test_SSE_FlushAll_TTLExpiry(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{FlushInterval: 50 * time.Millisecond}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("fle", []string{"t"}, 10, 50*time.Millisecond) + + hub.mu.Lock() + hub.connections["fle"] = conn + 
hub.topicIndex["t"] = map[string]struct{}{"fle": {}} + hub.mu.Unlock() + + // Add an expired event to the coalescer + conn.coalescer.addBatched(MarshaledEvent{ + ID: "exp", + Data: "expired", + TTL: time.Millisecond, + CreatedAt: time.Now().Add(-time.Second), + }) + + time.Sleep(100 * time.Millisecond) + hub.flushAll() + + // Event should be dropped, not delivered + require.Empty(t, conn.send) + require.Equal(t, int64(1), hub.metrics.eventsDropped.Load()) +} + +func Test_SSE_SendHeartbeats(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{HeartbeatInterval: 50 * time.Millisecond}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("hb", []string{"t"}, 10, time.Second) + // Set lastWrite to long ago + conn.lastWrite.Store(time.Now().Add(-time.Minute)) + + hub.mu.Lock() + hub.connections["hb"] = conn + hub.mu.Unlock() + + hub.sendHeartbeats() + + // Should have a heartbeat pending + select { + case <-conn.heartbeat: + default: + t.Fatal("expected heartbeat") + } +} + +func Test_SSE_SendHeartbeats_SkipsClosed(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub(Config{HeartbeatInterval: 50 * time.Millisecond}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("closed-hb", []string{"t"}, 10, time.Second) + conn.lastWrite.Store(time.Now().Add(-time.Minute)) + conn.Close() + + hub.mu.Lock() + hub.connections["closed-hb"] = conn + hub.mu.Unlock() + + // Should not panic + hub.sendHeartbeats() +} + +func Test_SSE_RemoveConnection(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("rm", []string{"orders", "products"}, 10, 
time.Second) + + hub.addConnection(conn) + + stats := hub.Stats() + require.Equal(t, 1, stats.ActiveConnections) + require.Equal(t, 2, stats.TotalTopics) + + hub.removeConnection(conn) + + stats = hub.Stats() + require.Equal(t, 0, stats.ActiveConnections) + require.Equal(t, 0, stats.TotalTopics) + + // Remove again should be no-op + hub.removeConnection(conn) +} + +func Test_SSE_RemoveConnection_Wildcard(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("rmw", []string{"orders.*"}, 10, time.Second) + + hub.addConnection(conn) + + hub.mu.RLock() + _, hasWildcard := hub.wildcardConns["rmw"] + hub.mu.RUnlock() + require.True(t, hasWildcard) + + hub.removeConnection(conn) + + hub.mu.RLock() + _, hasWildcard = hub.wildcardConns["rmw"] + hub.mu.RUnlock() + require.False(t, hasWildcard) +} + +func Test_SSE_Progress_WithHint(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + hub.Progress("import", "imp_1", "t_1", 50, 100, map[string]any{"filename": "data.csv"}) + hub.Progress("import", "imp_2", "", 0, 0) // zero total + + time.Sleep(50 * time.Millisecond) + stats := hub.Stats() + require.Equal(t, int64(2), stats.EventsPublished) +} + +func Test_SSE_Complete_Failure(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + hub.Complete("import", "imp_1", "t_1", false, map[string]any{"error": "timeout"}) + hub.Complete("import", "imp_2", "", true, nil) // no tenant, no hint + + time.Sleep(50 * time.Millisecond) + stats := hub.Stats() + require.Equal(t, int64(2), 
stats.EventsPublished) +} + +func Test_SSE_Publish_BufferFull(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Fill the event buffer (size 1024) + for range 2000 { + hub.Publish(Event{Type: "flood", Topics: []string{"t"}, Data: "x"}) + } + + time.Sleep(100 * time.Millisecond) + stats := hub.Stats() + // Some events should have been dropped + require.Positive(t, stats.EventsPublished) +} + +func Test_SSE_ReplayEvents(t *testing.T) { + t.Parallel() + + replayer := NewMemoryReplayer() + require.NoError(t, replayer.Store(MarshaledEvent{ID: "r1", Data: "d1", Retry: -1}, []string{"t"})) + require.NoError(t, replayer.Store(MarshaledEvent{ID: "r2", Data: "d2", Retry: -1}, []string{"t"})) + + _, hub := NewWithHub(Config{Replayer: replayer}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("replay-conn", []string{"t"}, 10, time.Second) + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + err := hub.replayEvents(w, conn, "r1") + require.NoError(t, err) + require.Contains(t, buf.String(), "id: r2") +} + +func Test_SSE_ReplayEvents_NoReplayer(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("no-replay", []string{"t"}, 10, time.Second) + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + err := hub.replayEvents(w, conn, "some-id") + require.NoError(t, err) + require.Empty(t, buf.String()) +} + +func Test_SSE_InitStream(t *testing.T) { + t.Parallel() + + replayer := NewMemoryReplayer() + require.NoError(t, replayer.Store(MarshaledEvent{ID: "i1", Data: "d1", Retry: -1}, []string{"t"})) + 
require.NoError(t, replayer.Store(MarshaledEvent{ID: "i2", Data: "d2", Retry: -1}, []string{"t"})) + + _, hub := NewWithHub(Config{Replayer: replayer}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + conn := newConnection("init-conn", []string{"t"}, 10, time.Second) + var buf bytes.Buffer + w := bufio.NewWriter(&buf) + + err := hub.initStream(w, conn, "i1") + require.NoError(t, err) + + output := buf.String() + require.Contains(t, output, "retry: 3000") + require.Contains(t, output, "id: i2") + require.Contains(t, output, `event: connected`) +} + +func Test_SSE_RouteEvent_ReplayerStore(t *testing.T) { + t.Parallel() + + replayer := NewMemoryReplayer() + _, hub := NewWithHub(Config{Replayer: replayer}) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(t, hub.Shutdown(ctx)) + }() + + // Publish a non-group event — should be stored in replayer + hub.Publish(Event{Type: "test", Topics: []string{"t"}, Data: "stored"}) + time.Sleep(100 * time.Millisecond) + + events, err := replayer.Replay("", []string{"t"}) + require.NoError(t, err) + require.Nil(t, events) // empty lastEventID + + // Publish a group event — should NOT be stored in replayer + hub.Publish(Event{ + Type: "test", + Topics: []string{"t"}, + Data: "not-stored", + Group: map[string]string{"tenant_id": "t_1"}, + }) + time.Sleep(100 * time.Millisecond) + + // The replayer should only have 1 event (the non-group one) + replayer.mu.RLock() + count := replayer.count + replayer.mu.RUnlock() + require.Equal(t, 1, count) +} + +func Test_SSE_Shutdown_Timeout(t *testing.T) { + t.Parallel() + + _, hub := NewWithHub() + + // Create a context that's already canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // Close the hub so it can stop + hub.shutdownOnce.Do(func() { + close(hub.shutdown) + }) + + // With 
already-canceled context, it might return an error if stopped hasn't been signaled + _ = hub.Shutdown(ctx) //nolint:errcheck // testing shutdown with canceled context +} + +// mockPubSubSubscriber implements PubSubSubscriber for testing. +type mockPubSubSubscriber struct { + onSubscribe func(ctx context.Context, channel string, onMessage func(string)) error +} + +func (m *mockPubSubSubscriber) Subscribe(ctx context.Context, channel string, onMessage func(string)) error { + return m.onSubscribe(ctx, channel, onMessage) +} + +func Benchmark_SSE_Publish(b *testing.B) { + _, hub := NewWithHub() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + require.NoError(b, hub.Shutdown(ctx)) + }() + + event := Event{ + Type: "test", + Topics: []string{"benchmark"}, + Data: "hello", + } + + b.ResetTimer() + for b.Loop() { + hub.Publish(event) + } +} + +func Benchmark_SSE_TopicMatch(b *testing.B) { + b.ResetTimer() + for b.Loop() { + topicMatch("notifications.*", "notifications.orders") + } +} + +func Benchmark_SSE_TopicMatch_Exact(b *testing.B) { + b.ResetTimer() + for b.Loop() { + topicMatch("notifications.orders", "notifications.orders") + } +} + +func Benchmark_SSE_MarshalEvent(b *testing.B) { + event := &Event{ + Type: "test", + Data: map[string]string{"key": "value", "foo": "bar"}, + } + + b.ResetTimer() + for b.Loop() { + marshalEvent(event) + } +} + +func Benchmark_SSE_WriteTo(b *testing.B) { + me := MarshaledEvent{ + ID: "evt_1", + Type: "test", + Data: `{"key":"value"}`, + } + + w := bufio.NewWriter(io.Discard) + + b.ResetTimer() + for b.Loop() { + me.WriteTo(w) //nolint:errcheck // benchmark: error irrelevant for perf measurement + } +} + +func Benchmark_SSE_Coalescer(b *testing.B) { + c := newCoalescer(time.Second) + me := MarshaledEvent{ID: "1", Data: "test"} + + b.ResetTimer() + for b.Loop() { + c.addCoalesced("key", me) + c.flush() + } +} + +func Benchmark_SSE_GenerateID(b *testing.B) { + b.ResetTimer() + for 
b.Loop() { + generateID() + } +} diff --git a/middleware/sse/stats.go b/middleware/sse/stats.go new file mode 100644 index 00000000000..3b7cabf1904 --- /dev/null +++ b/middleware/sse/stats.go @@ -0,0 +1,74 @@ +package sse + +import ( + "sync" + "sync/atomic" +) + +// HubStats provides a snapshot of the hub's current state. +type HubStats struct { + // ConnectionsByTopic maps each topic to its subscriber count. + ConnectionsByTopic map[string]int `json:"connections_by_topic"` + + // EventsByType maps each SSE event type to its lifetime count. + EventsByType map[string]int64 `json:"events_by_type"` + + // EventsPublished is the lifetime count of events published to the hub. + EventsPublished int64 `json:"events_published"` + + // EventsDropped is the lifetime count of events dropped due to backpressure. + EventsDropped int64 `json:"events_dropped"` + + // ActiveConnections is the total number of open SSE connections. + ActiveConnections int `json:"active_connections"` + + // TotalTopics is the number of unique topics with at least one subscriber. + TotalTopics int `json:"total_topics"` +} + +// hubMetrics tracks lifetime counters for the hub. +type hubMetrics struct { + eventsByType map[string]*atomic.Int64 + eventsByTypeMu sync.RWMutex + eventsPublished atomic.Int64 + eventsDropped atomic.Int64 +} + +// trackEventType increments the counter for a specific event type. +func (m *hubMetrics) trackEventType(eventType string) { + if eventType == "" { + eventType = "message" + } + + m.eventsByTypeMu.RLock() + counter, ok := m.eventsByType[eventType] + m.eventsByTypeMu.RUnlock() + + if ok { + counter.Add(1) + return + } + + m.eventsByTypeMu.Lock() + if counter, ok = m.eventsByType[eventType]; ok { + m.eventsByTypeMu.Unlock() + counter.Add(1) + return + } + counter = &atomic.Int64{} + counter.Add(1) + m.eventsByType[eventType] = counter + m.eventsByTypeMu.Unlock() +} + +// snapshotEventsByType returns a copy of the per-event-type counters. 
+func (m *hubMetrics) snapshotEventsByType() map[string]int64 {
+	m.eventsByTypeMu.RLock()
+	defer m.eventsByTypeMu.RUnlock()
+
+	result := make(map[string]int64, len(m.eventsByType))
+	for k, v := range m.eventsByType { // Load is atomic; the RLock only guards the map structure
+		result[k] = v.Load()
+	}
+	return result
+}
diff --git a/middleware/sse/throttle.go b/middleware/sse/throttle.go
new file mode 100644
index 00000000000..194a947d782
--- /dev/null
+++ b/middleware/sse/throttle.go
@@ -0,0 +1,80 @@
+package sse
+
+import (
+	"sync"
+	"time"
+)
+
+// adaptiveThrottler monitors per-connection buffer saturation and adjusts
+// the effective flush interval. Connections with high buffer usage get
+// longer flush intervals (fewer sends), reducing backpressure.
+type adaptiveThrottler struct {
+	lastFlush    map[string]time.Time // connID -> time of last allowed flush
+	mu           sync.Mutex           // guards lastFlush
+	baseInterval time.Duration
+	minInterval  time.Duration
+	maxInterval  time.Duration
+}
+
+func newAdaptiveThrottler(baseInterval time.Duration) *adaptiveThrottler {
+	minInt := max(baseInterval/4, 100*time.Millisecond)
+	maxInt := max(minInt, min(baseInterval*4, 10*time.Second)) // clamp: for baseInterval < 25ms the 100ms floor would otherwise exceed maxInt
+	return &adaptiveThrottler{
+		lastFlush:    make(map[string]time.Time),
+		baseInterval: baseInterval,
+		minInterval:  minInt,
+		maxInterval:  maxInt,
+	}
+}
+
+// effectiveInterval calculates the flush interval for a connection based
+// on its buffer saturation (0.0 = empty, 1.0 = full).
+func (at *adaptiveThrottler) effectiveInterval(saturation float64) time.Duration {
+	switch {
+	case saturation > 0.8:
+		return at.maxInterval
+	case saturation > 0.5:
+		return min(at.baseInterval*2, at.maxInterval) // never exceed maxInterval (matters when baseInterval > maxInterval/2)
+	case saturation < 0.1:
+		return at.minInterval
+	default:
+		return at.baseInterval
+	}
+}
+
+// shouldFlush returns true if enough time has passed since the last flush.
+func (at *adaptiveThrottler) shouldFlush(connID string, saturation float64) bool {
+	at.mu.Lock()
+	defer at.mu.Unlock()
+
+	interval := at.effectiveInterval(saturation)
+	last, ok := at.lastFlush[connID]
+	if !ok { // first flush for a new connection is allowed immediately
+		at.lastFlush[connID] = time.Now()
+		return true
+	}
+
+	if time.Since(last) >= interval {
+		at.lastFlush[connID] = time.Now()
+		return true
+	}
+	return false
+}
+
+// remove cleans up tracking for a disconnected connection.
+func (at *adaptiveThrottler) remove(connID string) {
+	at.mu.Lock()
+	delete(at.lastFlush, connID)
+	at.mu.Unlock()
+}
+
+// cleanup removes stale entries older than the given cutoff.
+func (at *adaptiveThrottler) cleanup(cutoff time.Time) {
+	at.mu.Lock()
+	defer at.mu.Unlock()
+	for k, v := range at.lastFlush { // deleting during range is safe for Go maps
+		if v.Before(cutoff) {
+			delete(at.lastFlush, k)
+		}
+	}
+}
diff --git a/middleware/sse/topic.go b/middleware/sse/topic.go
new file mode 100644
index 00000000000..dd1f87f4cd1
--- /dev/null
+++ b/middleware/sse/topic.go
@@ -0,0 +1,61 @@
+package sse
+
+import (
+	"strings"
+)
+
+// topicMatch checks if a subscription pattern matches a concrete topic.
+// Supports NATS-style wildcards:
+//
+//   - * matches exactly one segment (between dots)
+//   - > matches one or more trailing segments (must be last token)
+//   - No wildcards = exact match
+//
+// Examples:
+//
+//	topicMatch("notifications.*", "notifications.orders") → true
+//	topicMatch("notifications.*", "notifications.orders.new") → false
+//	topicMatch("analytics.>", "analytics.live") → true
+//	topicMatch("analytics.>", "analytics.live.visitors") → true
+func topicMatch(pattern, topic string) bool {
+	if !strings.ContainsAny(pattern, "*>") { // fast path: literal pattern, plain string compare
+		return pattern == topic
+	}
+
+	patParts := strings.Split(pattern, ".")
+	topParts := strings.Split(topic, ".")
+
+	for i, pp := range patParts {
+		switch pp {
+		case ">":
+			// > must be the last token and matches 1+ remaining segments
+			return i == len(patParts)-1 && i < len(topParts)
+		case "*":
+			if i >= len(topParts) { // pattern has more segments than the topic
+				return false
+			}
+		default:
+			if i >= len(topParts) || pp != topParts[i] {
+				return false
+			}
+		}
+	}
+
+	return len(patParts) == len(topParts) // no '>' consumed the tail, so segment counts must match exactly
+}
+
+// topicMatchesAny returns true if the concrete topic matches any of the patterns.
+func topicMatchesAny(patterns []string, topic string) bool {
+	for _, p := range patterns {
+		if topicMatch(p, topic) {
+			return true
+		}
+	}
+	return false
+}
+
+// connMatchesTopic returns true if a connection's subscription patterns
+// match the given concrete topic.
+func connMatchesTopic(conn *Connection, topic string) bool {
+	return topicMatchesAny(conn.Topics, topic)
+}