diff --git a/middleware/session/data.go b/middleware/session/data.go index 2a26b785bbc..d28ad0eb878 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -31,12 +31,12 @@ var dataPool = sync.Pool{ // // d := acquireData() func acquireData() *data { - obj := dataPool.Get() - if d, ok := obj.(*data); ok { - return d + d, ok := dataPool.Get().(*data) + if !ok { + d = new(data) + d.Data = make(map[any]any) } - // Handle unexpected type in the pool - panic("unexpected type in data pool") + return d } // Reset clears the data map and resets the data object. diff --git a/middleware/session/middleware.go b/middleware/session/middleware.go index c4a17f8966f..57bba6d8228 100644 --- a/middleware/session/middleware.go +++ b/middleware/session/middleware.go @@ -106,7 +106,7 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) { m.finalizeSession() - c.Locals(middlewareContextKey, nil) + fiber.StoreInContext(c, middlewareContextKey, nil) releaseMiddleware(m) return stackErr } @@ -160,7 +160,7 @@ func (m *Middleware) finalizeSession() { func acquireMiddleware() *Middleware { m, ok := middlewarePool.Get().(*Middleware) if !ok { - panic(ErrTypeAssertionFailed.Error()) + return &Middleware{} } return m } diff --git a/middleware/session/middleware_test.go b/middleware/session/middleware_test.go index 91c4d864d4e..25b085a8880 100644 --- a/middleware/session/middleware_test.go +++ b/middleware/session/middleware_test.go @@ -620,6 +620,33 @@ func Test_Session_Middleware_ClearsContextLocalsOnRelease(t *testing.T) { require.Equal(t, fiber.StatusOK, resp.StatusCode) } +func Test_Session_Middleware_ClearsContextOnRelease_PassLocalsToContext(t *testing.T) { + t.Parallel() + + app := fiber.New(fiber.Config{PassLocalsToContext: true}) + + app.Use(func(c fiber.Ctx) error { + err := c.Next() + // Verify cleared via all context types + require.Nil(t, FromContext(c)) + require.Nil(t, FromContext(c.Context())) + return err + }) + app.Use(New()) + + app.Get("/", func(c fiber.Ctx) error { + // Session should be available from all context types + require.NotNil(t, FromContext(c)) + require.NotNil(t, FromContext(c.Context())) + return c.SendStatus(fiber.StatusOK) + }) + + req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) + resp, err := app.Test(req) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + type failingStorage struct { getErr error getCalls int diff --git a/middleware/session/session.go b/middleware/session/session.go index 37566e5b154..58e08d69b70 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -55,7 +55,10 @@ var sessionPool = sync.Pool{ // // s := acquireSession() func acquireSession() *Session { - s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + s, ok := sessionPool.Get().(*Session) + if !ok { + s = &Session{} + } if s.data == nil { s.data = acquireData() } diff --git a/middleware/session/store.go b/middleware/session/store.go index d7d670c52e6..8e67963a017 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -19,6 +19,10 @@ var ( ErrSessionIDNotFoundInStore = errors.New("session ID not found in session store") ) +// maxSessionIDLen is the upper bound on extracted session IDs. +// IDs exceeding this length are treated as invalid and discarded. +const maxSessionIDLen = 4096 + // sessionIDKey is the local key type used to store and retrieve the session ID in context. type sessionIDKey int @@ -204,9 +208,28 @@ func (s *Store) getSessionID(c fiber.Ctx) string { // If extraction fails, return empty string to generate a new session return "" } + if !isValidSessionID(sessionID) { + return "" + } return sessionID } +// isValidSessionID reports whether id is safe for use as a storage key. +// It rejects empty values, values longer than maxSessionIDLen, and any +// byte outside the visible-ASCII range (0x21–0x7E). +func isValidSessionID(id string) bool { + if id == "" || len(id) > maxSessionIDLen { + return false + } + for i := 0; i < len(id); i++ { + c := id[i] + if c <= 0x20 || c > 0x7e { + return false + } + } + return true +} + // Reset deletes all sessions from the storage. // // Returns: @@ -309,7 +332,7 @@ func (s *Store) GetByID(ctx context.Context, id string) (*Session, error) { if s.AbsoluteTimeout > 0 { if sess.isAbsExpired() { - if err := sess.Destroy(); err != nil { //nolint:contextcheck // it is not right + if err := sess.Destroy(); err != nil { //nolint:contextcheck // sess.gctx is set to ctx above; Destroy honors it. log.Errorf("failed to destroy session: %v", err) } sess.Release() diff --git a/middleware/session/store_test.go b/middleware/session/store_test.go index 3780bc01df5..421eea6f3ac 100644 --- a/middleware/session/store_test.go +++ b/middleware/session/store_test.go @@ -375,3 +375,66 @@ func Test_Store_GetByID_DestroyUsesContext(t *testing.T) { require.ErrorIs(t, err, ErrSessionIDNotFoundInStore) require.ErrorIs(t, storage.lastCtxErr, context.Canceled) } + +func Test_isValidSessionID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + id string + valid bool + }{ + {name: "empty", id: "", valid: false}, + {name: "normal alphanumeric", id: "abc123", valid: true}, + {name: "base64url token", id: "dGVzdC10b2tlbi12YWx1ZQ==", valid: true}, + {name: "uuid", id: "550e8400-e29b-41d4-a716-446655440000", valid: true}, + {name: "contains space", id: "abc 123", valid: false}, + {name: "contains tab", id: "abc\t123", valid: false}, + {name: "contains newline", id: "abc\n123", valid: false}, + {name: "contains null byte", id: "abc\x00123", valid: false}, + {name: "non-ascii", id: "abc\x80xyz", valid: false}, + {name: "del character", id: "abc\x7fxyz", valid: false}, + {name: "too long", id: string(make([]byte, maxSessionIDLen+1)), valid: false}, + {name: "max length", id: string(makeVisibleASCII(maxSessionIDLen)), valid: true}, + {name: "visible ascii symbols", id: "!@#$%^&*()_+-=[]{}|;':\",./<>?", valid: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.valid, isValidSessionID(tt.id)) + }) + } +} + +// makeVisibleASCII returns a byte slice of length n filled with visible ASCII characters. +func makeVisibleASCII(n int) []byte { + b := make([]byte, n) + for i := range b { + b[i] = 'a' + } + return b +} + +func Test_Store_getSessionID_RejectsInvalidIDs(t *testing.T) { + t.Parallel() + + app := fiber.New() + store := NewStore() + + t.Run("control characters rejected", func(t *testing.T) { + t.Parallel() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + ctx.Request().Header.SetCookie("session_id", "abc\x00def") + require.Empty(t, store.getSessionID(ctx)) + }) + + t.Run("valid id accepted", func(t *testing.T) { + t.Parallel() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + ctx.Request().Header.SetCookie("session_id", "valid-session-id") + require.Equal(t, "valid-session-id", store.getSessionID(ctx)) + }) +}