Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions middleware/session/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ var dataPool = sync.Pool{
//
// d := acquireData()
func acquireData() *data {
obj := dataPool.Get()
if d, ok := obj.(*data); ok {
return d
d, ok := dataPool.Get().(*data)
if !ok {
d = new(data)
d.Data = make(map[any]any)
}
// Handle unexpected type in the pool
panic("unexpected type in data pool")
return d
}

// Reset clears the data map and resets the data object.
Expand Down
4 changes: 2 additions & 2 deletions middleware/session/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func NewWithStore(config ...Config) (fiber.Handler, *Store) {

m.finalizeSession()

c.Locals(middlewareContextKey, nil)
fiber.StoreInContext(c, middlewareContextKey, nil)
releaseMiddleware(m)
return stackErr
}
Expand Down Expand Up @@ -160,7 +160,7 @@ func (m *Middleware) finalizeSession() {
func acquireMiddleware() *Middleware {
m, ok := middlewarePool.Get().(*Middleware)
if !ok {
panic(ErrTypeAssertionFailed.Error())
return &Middleware{}
}
return m
}
Expand Down
27 changes: 27 additions & 0 deletions middleware/session/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,33 @@ func Test_Session_Middleware_ClearsContextLocalsOnRelease(t *testing.T) {
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

func Test_Session_Middleware_ClearsContextOnRelease_PassLocalsToContext(t *testing.T) {
t.Parallel()

app := fiber.New(fiber.Config{PassLocalsToContext: true})

app.Use(func(c fiber.Ctx) error {
err := c.Next()
// Verify cleared via all context types
require.Nil(t, FromContext(c))
require.Nil(t, FromContext(c.Context()))
return err
})
app.Use(New())

app.Get("/", func(c fiber.Ctx) error {
// Session should be available from all context types
require.NotNil(t, FromContext(c))
require.NotNil(t, FromContext(c.Context()))
return c.SendStatus(fiber.StatusOK)
})

req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
resp, err := app.Test(req)
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

type failingStorage struct {
getErr error
getCalls int
Expand Down
5 changes: 4 additions & 1 deletion middleware/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ var sessionPool = sync.Pool{
//
// s := acquireSession()
func acquireSession() *Session {
s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
s, ok := sessionPool.Get().(*Session)
if !ok {
s = &Session{}
}
if s.data == nil {
s.data = acquireData()
}
Expand Down
25 changes: 24 additions & 1 deletion middleware/session/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ var (
ErrSessionIDNotFoundInStore = errors.New("session ID not found in session store")
)

// maxSessionIDLen is the upper bound on extracted session IDs.
// IDs exceeding this length are treated as invalid and discarded.
const maxSessionIDLen = 4096

// sessionIDKey is the local key type used to store and retrieve the session ID in context.
type sessionIDKey int

Expand Down Expand Up @@ -204,9 +208,28 @@ func (s *Store) getSessionID(c fiber.Ctx) string {
// If extraction fails, return empty string to generate a new session
return ""
}
if !isValidSessionID(sessionID) {
return ""
}
return sessionID
}

// isValidSessionID reports whether id is safe for use as a storage key.
// It rejects empty values, values longer than maxSessionIDLen, and any
// byte outside the visible-ASCII range (0x21–0x7E).
func isValidSessionID(id string) bool {
if id == "" || len(id) > maxSessionIDLen {
return false
}
for i := 0; i < len(id); i++ {
c := id[i]
if c <= 0x20 || c > 0x7e {
return false
}
}
return true
}

// Reset deletes all sessions from the storage.
//
// Returns:
Expand Down Expand Up @@ -309,7 +332,7 @@ func (s *Store) GetByID(ctx context.Context, id string) (*Session, error) {

if s.AbsoluteTimeout > 0 {
if sess.isAbsExpired() {
if err := sess.Destroy(); err != nil { //nolint:contextcheck // it is not right
if err := sess.Destroy(); err != nil { //nolint:contextcheck // sess.gctx is set to ctx above; Destroy honors it.
log.Errorf("failed to destroy session: %v", err)
}
sess.Release()
Expand Down
63 changes: 63 additions & 0 deletions middleware/session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,66 @@ func Test_Store_GetByID_DestroyUsesContext(t *testing.T) {
require.ErrorIs(t, err, ErrSessionIDNotFoundInStore)
require.ErrorIs(t, storage.lastCtxErr, context.Canceled)
}

func Test_isValidSessionID(t *testing.T) {
t.Parallel()

tests := []struct {
name string
id string
valid bool
}{
{name: "empty", id: "", valid: false},
{name: "normal alphanumeric", id: "abc123", valid: true},
{name: "base64url token", id: "dGVzdC10b2tlbi12YWx1ZQ==", valid: true},
{name: "uuid", id: "550e8400-e29b-41d4-a716-446655440000", valid: true},
{name: "contains space", id: "abc 123", valid: false},
{name: "contains tab", id: "abc\t123", valid: false},
{name: "contains newline", id: "abc\n123", valid: false},
{name: "contains null byte", id: "abc\x00123", valid: false},
{name: "non-ascii", id: "abc\x80xyz", valid: false},
{name: "del character", id: "abc\x7fxyz", valid: false},
{name: "too long", id: string(make([]byte, maxSessionIDLen+1)), valid: false},
{name: "max length", id: string(makeVisibleASCII(maxSessionIDLen)), valid: true},
{name: "visible ascii symbols", id: "!@#$%^&*()_+-=[]{}|;':\",./<>?", valid: true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.valid, isValidSessionID(tt.id))
})
}
}

// makeVisibleASCII returns a byte slice of length n filled with visible ASCII characters.
func makeVisibleASCII(n int) []byte {
b := make([]byte, n)
for i := range b {
b[i] = 'a'
}
return b
}

func Test_Store_getSessionID_RejectsInvalidIDs(t *testing.T) {
t.Parallel()

app := fiber.New()
store := NewStore()

t.Run("control characters rejected", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetCookie("session_id", "abc\x00def")
require.Empty(t, store.getSessionID(ctx))
})

t.Run("valid id accepted", func(t *testing.T) {
t.Parallel()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(ctx)
ctx.Request().Header.SetCookie("session_id", "valid-session-id")
require.Equal(t, "valid-session-id", store.getSessionID(ctx))
})
}
Loading