diff --git a/middleware/session/data.go b/middleware/session/data.go index 2a26b785bbc..d28ad0eb878 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -31,12 +31,12 @@ var dataPool = sync.Pool{ // // d := acquireData() func acquireData() *data { - obj := dataPool.Get() - if d, ok := obj.(*data); ok { - return d + d, ok := dataPool.Get().(*data) + if !ok { + d = new(data) + d.Data = make(map[any]any) } - // Handle unexpected type in the pool - panic("unexpected type in data pool") + return d } // Reset clears the data map and resets the data object. diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index 78c2ef05384..57bba6d8228 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -88,18 +88,25 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { // Acquire session middleware m := acquireMiddleware() - m.initialize(c, &cfg) + if err := m.initialize(c, &cfg); err != nil { + if cfg.ErrorHandler != nil { + cfg.ErrorHandler(c, err) + } else { + DefaultErrorHandler(c, err) + } + + releaseMiddleware(m) + if c.Response().StatusCode() == fiber.StatusOK && len(c.Response().Body()) == 0 { + return err + } + return nil + } stackErr := c.Next() - m.mu.RLock() - destroyed := m.destroyed - m.mu.RUnlock() - - if !destroyed { - m.saveSession() - } + m.finalizeSession() + fiber.StoreInContext(c, middlewareContextKey, nil) releaseMiddleware(m) return stackErr } @@ -108,13 +115,13 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { } // initialize sets up middleware for the request. -func (m *Middleware) initialize(c fiber.Ctx, cfg *Config) { +func (m *Middleware) initialize(c fiber.Ctx, cfg *Config) error { m.mu.Lock() defer m.mu.Unlock() session, err := cfg.Store.getSession(c) if err != nil { - panic(err) // handle or log this error appropriately in production + return err } m.config = *cfg @@ -122,6 +129,7 @@ func (m *Middleware) initialize(c fiber.Ctx, cfg *Config) { m.ctx = c fiber.StoreInContext(c, middlewareContextKey, m) + return nil } // saveSession handles session saving and error management after the response. @@ -133,6 +141,17 @@ func (m *Middleware) saveSession() { DefaultErrorHandler(m.ctx, err) } } +} + +// finalizeSession handles session persistence and always releases the session object. +func (m *Middleware) finalizeSession() { + m.mu.RLock() + destroyed := m.destroyed + m.mu.RUnlock() + + if !destroyed { + m.saveSession() + } releaseSession(m.Session) } @@ -141,7 +160,7 @@ func (m *Middleware) saveSession() { func acquireMiddleware() *Middleware { m, ok := middlewarePool.Get().(*Middleware) if !ok { - panic(ErrTypeAssertionFailed.Error()) + return &Middleware{} } return m } diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index a48b6b1b49c..25b085a8880 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -1,12 +1,16 @@ package session import ( + "context" + "errors" "fmt" + "io" "net/http" "net/http/httptest" "sort" "strings" "sync" + "sync/atomic" "testing" "time" @@ -590,3 +594,271 @@ func Test_Session_Middleware_Store(t *testing.T) { h(ctx) require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) } + +func Test_Session_Middleware_ClearsContextLocalsOnRelease(t *testing.T) { + t.Parallel() + + app := fiber.New() + + app.Use(func(c fiber.Ctx) error { + err := c.Next() + require.Nil(t, FromContext(c)) + return err + }) + app.Use(New()) + + app.Get("/", func(c fiber.Ctx) error { + sess := FromContext(c) + require.NotNil(t, sess) + sess.Set("key", "value") + return c.SendStatus(fiber.StatusOK) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +func Test_Session_Middleware_ClearsContextOnRelease_PassLocalsToContext(t *testing.T) { + t.Parallel() + + app := fiber.New(fiber.Config{PassLocalsToContext: true}) + + app.Use(func(c fiber.Ctx) error { + err := c.Next() + // Verify cleared via all context types + require.Nil(t, FromContext(c)) + require.Nil(t, FromContext(c.Context())) + return err + }) + app.Use(New()) + + app.Get("/", func(c fiber.Ctx) error { + // Session should be available from all context types + require.NotNil(t, FromContext(c)) + require.NotNil(t, FromContext(c.Context())) + return c.SendStatus(fiber.StatusOK) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +type failingStorage struct { + getErr error + getCalls int + setCalls int +} + +func (s *failingStorage) GetWithContext(_ context.Context, _ string) ([]byte, error) { + s.getCalls++ + return nil, s.getErr +} + +func (s *failingStorage) Get(_ string) ([]byte, error) { + return nil, s.getErr +} + +func (s *failingStorage) SetWithContext(_ context.Context, _ string, _ []byte, _ time.Duration) error { + s.setCalls++ + return nil +} + +func (*failingStorage) Set(_ string, _ []byte, _ time.Duration) error { + return nil +} + +func (*failingStorage) DeleteWithContext(context.Context, string) error { return nil } +func (*failingStorage) Delete(string) error { return nil } +func (*failingStorage) ResetWithContext(context.Context) error { return nil } +func (*failingStorage) Reset() error { return nil } +func (*failingStorage) Close() error { return nil } + +func Test_Session_Middleware_InitializeError_WithCustomErrorHandler(t *testing.T) { + t.Parallel() + errStorage := &failingStorage{getErr: errors.New("storage down")} + app := fiber.New() + + var ( + nextCalled bool + localsClearedOnReturn atomic.Bool + ) + + app.Use(func(c fiber.Ctx) error { + err := c.Next() + localsClearedOnReturn.Store(FromContext(c) == nil) + return err + }) + + app.Use(New(Config{ + Storage: errStorage, + ErrorHandler: func(c fiber.Ctx, err error) { + require.ErrorContains(t, err, "storage down") + require.NoError(t, c.Status(fiber.StatusServiceUnavailable).SendString("session unavailable")) + }, + })) + app.Get("/", func(c fiber.Ctx) error { + nextCalled = true + return c.SendStatus(fiber.StatusOK) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.AddCookie(&http.Cookie{Name: "session_id", Value: "existing-id"}) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusServiceUnavailable, resp.StatusCode) + + body, readErr := io.ReadAll(resp.Body) + require.NoError(t, readErr) + require.Equal(t, "session unavailable", string(body)) + require.False(t, nextCalled) + require.Equal(t, 1, errStorage.getCalls) + require.Equal(t, 0, errStorage.setCalls) + require.True(t, localsClearedOnReturn.Load()) +} + +func Test_Session_Middleware_InitializeError_DefaultErrorHandler(t *testing.T) { + t.Parallel() + errStorage := &failingStorage{getErr: errors.New("storage down")} + app := fiber.New() + + app.Use(New(Config{Storage: errStorage})) + app.Get("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.AddCookie(&http.Cookie{Name: "session_id", Value: "existing-id"}) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) + + body, readErr := io.ReadAll(resp.Body) + require.NoError(t, readErr) + require.Equal(t, "Internal Server Error", string(body)) + require.Equal(t, 1, errStorage.getCalls) + require.Equal(t, 0, errStorage.setCalls) +} + +func Test_Session_Middleware_InitializeError_ReturnsHandlerErrorWhenUnwritten(t *testing.T) { + t.Parallel() + errStorage := &failingStorage{getErr: errors.New("storage down")} + app := fiber.New() + + app.Use(New(Config{ + Storage: errStorage, + ErrorHandler: func(fiber.Ctx, error) { + // Intentionally do not write a response to assert return semantics. + }, + })) + app.Get("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + req.AddCookie(&http.Cookie{Name: "session_id", Value: "existing-id"}) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) + require.Equal(t, 1, errStorage.getCalls) + require.Equal(t, 0, errStorage.setCalls) +} + +type lifecycleStorage struct { + setErr error + setCalls int32 + deleteCalls int32 +} + +func (*lifecycleStorage) GetWithContext(context.Context, string) ([]byte, error) { return nil, nil } +func (*lifecycleStorage) Get(string) ([]byte, error) { return nil, nil } + +func (s *lifecycleStorage) SetWithContext(context.Context, string, []byte, time.Duration) error { + s.setCalls++ + return s.setErr +} + +func (s *lifecycleStorage) Set(string, []byte, time.Duration) error { return s.setErr } + +func (s *lifecycleStorage) DeleteWithContext(context.Context, string) error { + s.deleteCalls++ + return nil +} + +func (*lifecycleStorage) Delete(string) error { return nil } +func (*lifecycleStorage) ResetWithContext(context.Context) error { return nil } +func (*lifecycleStorage) Reset() error { return nil } +func (*lifecycleStorage) Close() error { return nil } + +func newMiddlewareSessionForFinalize(t *testing.T, storage *lifecycleStorage) (*Middleware, *Session) { + t.Helper() + + store := NewStore(Config{Storage: storage}) + sess := acquireSession() + sess.mu.Lock() + sess.id = "session-id" + sess.config = store + sess.mu.Unlock() + sess.Set("k", "v") + + m := &Middleware{ + Session: sess, + config: Config{Store: store, ErrorHandler: func(fiber.Ctx, error) {}}, + } + + return m, sess +} + +func Test_Middleware_FinalizeSession_NormalSavePath(t *testing.T) { + t.Parallel() + + storage := &lifecycleStorage{} + m, sess := newMiddlewareSessionForFinalize(t, storage) + + m.finalizeSession() + + require.EqualValues(t, 1, storage.setCalls) + require.EqualValues(t, 0, storage.deleteCalls) + require.Nil(t, sess.ctx) + require.Nil(t, sess.config) + require.Empty(t, sess.id) +} + +func Test_Middleware_FinalizeSession_DestroyedPathSkipsSave(t *testing.T) { + t.Parallel() + + storage := &lifecycleStorage{} + m, sess := newMiddlewareSessionForFinalize(t, storage) + + m.destroyed = true + m.finalizeSession() + + require.EqualValues(t, 0, storage.setCalls) + require.EqualValues(t, 0, storage.deleteCalls) + require.Nil(t, sess.ctx) + require.Nil(t, sess.config) + require.Empty(t, sess.id) +} + +func Test_Middleware_FinalizeSession_SaveErrorStillReleasesSession(t *testing.T) { + t.Parallel() + + storage := &lifecycleStorage{setErr: errors.New("set failed")} + m, sess := newMiddlewareSessionForFinalize(t, storage) + + errCalls := 0 + m.config.ErrorHandler = func(fiber.Ctx, error) { + errCalls++ + } + + m.finalizeSession() + + require.EqualValues(t, 1, storage.setCalls) + require.Equal(t, 1, errCalls) + require.Nil(t, sess.ctx) + require.Nil(t, sess.config) + require.Empty(t, sess.id) +} diff --git a/middleware/session/session.go b/middleware/session/session.go index 03b1f86b4aa..58e08d69b70 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -16,13 +16,14 @@ import ( // Session represents a user session. type Session struct { - ctx fiber.Ctx // fiber context - config *Store // store configuration - data *data // key value data - id string // session id - idleTimeout time.Duration // idleTimeout of this session - mu sync.RWMutex // Mutex to protect non-data fields - fresh bool // if new session + ctx fiber.Ctx // fiber context + gctx context.Context //nolint:containedctx // Stored to honor GetByID context during Destroy. + config *Store // store configuration + data *data // key value data + id string // session id + idleTimeout time.Duration // idleTimeout of this session + mu sync.RWMutex // Mutex to protect non-data fields + fresh bool // if new session } type absExpirationKeyType int @@ -54,7 +55,10 @@ var sessionPool = sync.Pool{ // // s := acquireSession() func acquireSession() *Session { - s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + s, ok := sessionPool.Get().(*Session) + if !ok { + s = &Session{} + } if s.data == nil { s.data = acquireData() } @@ -91,6 +95,7 @@ func releaseSession(s *Session) { s.id = "" s.idleTimeout = 0 s.ctx = nil + s.gctx = nil s.config = nil if s.data != nil { s.data.Reset() @@ -197,7 +202,9 @@ func (s *Session) Destroy() error { // Use external Storage if exist var ctx context.Context = s.ctx - if ctx == nil { + if s.gctx != nil { + ctx = s.gctx + } else if ctx == nil { ctx = context.Background() } if err := s.config.Storage.DeleteWithContext(ctx, s.id); err != nil { diff --git a/middleware/session/store.go b/middleware/session/store.go index c32c58f79ff..8e67963a017 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -19,6 +19,10 @@ var ( ErrSessionIDNotFoundInStore = errors.New("session ID not found in session store") ) +// maxSessionIDLen is the upper bound on extracted session IDs. +// IDs exceeding this length are treated as invalid and discarded. +const maxSessionIDLen = 4096 + // sessionIDKey is the local key type used to store and retrieve the session ID in context. type sessionIDKey int @@ -177,6 +181,7 @@ func (s *Store) getSession(c fiber.Ctx) (*Session, error) { sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout)) } else if sess.isAbsExpired() { if err := sess.Reset(); err != nil { + sess.Release() return nil, fmt.Errorf("failed to reset session: %w", err) } sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout)) @@ -203,9 +208,28 @@ func (s *Store) getSessionID(c fiber.Ctx) string { // If extraction fails, return empty string to generate a new session return "" } + if !isValidSessionID(sessionID) { + return "" + } return sessionID } +// isValidSessionID reports whether id is safe for use as a storage key. +// It rejects empty values, values longer than maxSessionIDLen, and any +// byte outside the visible-ASCII range (0x21–0x7E). +func isValidSessionID(id string) bool { + if id == "" || len(id) > maxSessionIDLen { + return false + } + for i := 0; i < len(id); i++ { + c := id[i] + if c <= 0x20 || c > 0x7e { + return false + } + } + return true +} + // Reset deletes all sessions from the storage. // // Returns: @@ -292,6 +316,7 @@ func (s *Store) GetByID(ctx context.Context, id string) (*Session, error) { sess.mu.Lock() + sess.gctx = ctx sess.config = s sess.id = id sess.fresh = false @@ -307,10 +332,10 @@ func (s *Store) GetByID(ctx context.Context, id string) (*Session, error) { if s.AbsoluteTimeout > 0 { if sess.isAbsExpired() { - if err := sess.Destroy(); err != nil { //nolint:contextcheck // it is not right - sess.Release() + if err := sess.Destroy(); err != nil { //nolint:contextcheck // sess.gctx is set to ctx above; Destroy honors it. log.Errorf("failed to destroy session: %v", err) } + sess.Release() return nil, ErrSessionIDNotFoundInStore } } diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index 24428d56510..421eea6f3ac 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -2,8 +2,10 @@ package session import ( "context" + "errors" "fmt" "testing" + "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" @@ -225,3 +227,214 @@ func Test_Store_GetByID(t *testing.T) { }) }) } + +type trackingStorage struct { + data map[string][]byte + lastCtxErr error + deleteErr error + deleteCalls int +} + +func newTrackingStorage() *trackingStorage { + return &trackingStorage{data: make(map[string][]byte)} +} + +func (s *trackingStorage) GetWithContext(_ context.Context, key string) ([]byte, error) { + if v, ok := s.data[key]; ok { + copied := make([]byte, len(v)) + copy(copied, v) + return copied, nil + } + return nil, nil +} + +func (s *trackingStorage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +func (s *trackingStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error { + copied := make([]byte, len(val)) + copy(copied, val) + s.data[key] = copied + return nil +} + +func (s *trackingStorage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +func (s *trackingStorage) DeleteWithContext(ctx context.Context, key string) error { + s.deleteCalls++ + if ctx != nil { + s.lastCtxErr = ctx.Err() + } + if s.deleteErr != nil { + return s.deleteErr + } + delete(s.data, key) + return nil +} + +func (s *trackingStorage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +func (*trackingStorage) ResetWithContext(context.Context) error { return nil } +func (*trackingStorage) Reset() error { return nil } +func (*trackingStorage) Close() error { return nil } + +func seedExpiredSessionInStore(t *testing.T, store *Store, sessionID string) { + t.Helper() + + sess := acquireSession() + sess.mu.Lock() + sess.config = store + sess.id = sessionID + sess.fresh = false + sess.mu.Unlock() + sess.Set("name", "john") + sess.Set(absExpirationKey, time.Now().Add(-time.Minute)) + require.NoError(t, sess.Save()) + sess.Release() +} + +func Test_Store_getSession_ExpiredResetFailureReleasesSession(t *testing.T) { + t.Parallel() + + storage := newTrackingStorage() + store := NewStore(Config{ + Storage: storage, + IdleTimeout: time.Minute, + AbsoluteTimeout: time.Minute, + }) + + const sessionID = "existing-session-id" + seedExpiredSessionInStore(t, store, sessionID) + storage.deleteErr = errors.New("delete failed") + + app := fiber.New() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + ctx.Request().Header.SetCookie("session_id", sessionID) + + sess, err := store.Get(ctx) + require.Nil(t, sess) + require.ErrorContains(t, err, "failed to reset session") + require.Equal(t, 1, storage.deleteCalls) + + reused := acquireSession() + require.Nil(t, reused.ctx) + require.Nil(t, reused.config) + require.Empty(t, reused.id) + reused.Release() +} + +func Test_Store_GetByID_ExpiredDestroySuccessReleasesSession(t *testing.T) { + t.Parallel() + + storage := newTrackingStorage() + store := NewStore(Config{ + Storage: storage, + IdleTimeout: time.Minute, + AbsoluteTimeout: time.Minute, + }) + + const sessionID = "expired-session-id" + seedExpiredSessionInStore(t, store, sessionID) + + sess, err := store.GetByID(context.Background(), sessionID) + require.Nil(t, sess) + require.ErrorIs(t, err, ErrSessionIDNotFoundInStore) + require.Equal(t, 1, storage.deleteCalls) + + reused := acquireSession() + require.Nil(t, reused.ctx) + require.Nil(t, reused.config) + require.Empty(t, reused.id) + reused.Release() +} + +func Test_Store_GetByID_DestroyUsesContext(t *testing.T) { + t.Parallel() + + storage := newTrackingStorage() + store := NewStore(Config{ + Storage: storage, + IdleTimeout: time.Minute, + AbsoluteTimeout: time.Minute, + }) + + const sessionID = "expired-session-id" + seedExpiredSessionInStore(t, store, sessionID) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + sess, err := store.GetByID(ctx, sessionID) + require.Nil(t, sess) + require.ErrorIs(t, err, ErrSessionIDNotFoundInStore) + require.ErrorIs(t, storage.lastCtxErr, context.Canceled) +} + +func Test_isValidSessionID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + id string + valid bool + }{ + {name: "empty", id: "", valid: false}, + {name: "normal alphanumeric", id: "abc123", valid: true}, + {name: "base64url token", id: "dGVzdC10b2tlbi12YWx1ZQ==", valid: true}, + {name: "uuid", id: "550e8400-e29b-41d4-a716-446655440000", valid: true}, + {name: "contains space", id: "abc 123", valid: false}, + {name: "contains tab", id: "abc\t123", valid: false}, + {name: "contains newline", id: "abc\n123", valid: false}, + {name: "contains null byte", id: "abc\x00123", valid: false}, + {name: "non-ascii", id: "abc\x80xyz", valid: false}, + {name: "del character", id: "abc\x7fxyz", valid: false}, + {name: "too long", id: string(make([]byte, maxSessionIDLen+1)), valid: false}, + {name: "max length", id: string(makeVisibleASCII(maxSessionIDLen)), valid: true}, + {name: "visible ascii symbols", id: "!@#$%^&*()_+-=[]{}|;':\",./<>?", valid: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.valid, isValidSessionID(tt.id)) + }) + } +} + +// makeVisibleASCII returns a byte slice of length n filled with visible ASCII characters. +func makeVisibleASCII(n int) []byte { + b := make([]byte, n) + for i := range b { + b[i] = 'a' + } + return b +} + +func Test_Store_getSessionID_RejectsInvalidIDs(t *testing.T) { + t.Parallel() + + app := fiber.New() + store := NewStore() + + t.Run("control characters rejected", func(t *testing.T) { + t.Parallel() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + ctx.Request().Header.SetCookie("session_id", "abc\x00def") + require.Empty(t, store.getSessionID(ctx)) + }) + + t.Run("valid id accepted", func(t *testing.T) { + t.Parallel() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + ctx.Request().Header.SetCookie("session_id", "valid-session-id") + require.Equal(t, "valid-session-id", store.getSessionID(ctx)) + }) +}