Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions middleware/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,14 @@ func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if req.Header.Get(echo.HeaderXForwardedProto) == "" {
req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
}
if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
if c.IsWebSocket() { // For HTTP, this is set by Go HTTP reverse proxy.
// Append, not set, to preserve the incoming chain from upstream proxies.
prior := req.Header[echo.HeaderXForwardedFor]
if len(prior) > 0 {
req.Header.Set(echo.HeaderXForwardedFor, strings.Join(prior, ", ")+", "+c.RealIP())
} else {
req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
}
}

retries := config.RetryCount
Expand Down
123 changes: 123 additions & 0 deletions middleware/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net/http/httptest"
"net/url"
"regexp"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -1164,3 +1165,125 @@ func TestProxyWithConfigWebSocketTLS2NonTLS(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, sendMsg, recvMsg)
}

// TestProxyWebSocketXForwardedFor verifies that for WebSocket Upgrade requests,
// the proxy middleware appends c.RealIP() to any existing X-Forwarded-For chain,
// mirroring net/http/httputil.(*ProxyRequest).SetXForwarded used by the HTTP path.
//
// Regression guard for the previous "set only if empty" behavior, which dropped
// the proxy's own peer IP from the chain whenever upstream proxies had already
// added entries.
func TestProxyWebSocketXForwardedFor(t *testing.T) {
var (
mu sync.Mutex
capturedHeader http.Header
)

// Upstream that captures the request header at WS Upgrade time, then echoes.
upstreamHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wsHandler := func(conn *websocket.Conn) {
mu.Lock()
capturedHeader = conn.Request().Header.Clone()
mu.Unlock()

defer conn.Close()

var msg string
if err := websocket.Message.Receive(conn, &msg); err == nil {
_ = websocket.Message.Send(conn, msg)
}
}
websocket.Server{Handler: wsHandler}.ServeHTTP(w, r)
})
upstream := httptest.NewServer(upstreamHandler)
defer upstream.Close()

tgtURL, _ := url.Parse(upstream.URL)
e := echo.New()
balancer := NewRandomBalancer([]*ProxyTarget{{URL: tgtURL}})
e.Use(ProxyWithConfig(ProxyConfig{Balancer: balancer}))
proxySrv := httptest.NewServer(e)
defer proxySrv.Close()

proxyWSURL, _ := url.Parse(proxySrv.URL)
proxyWSURL.Scheme = "ws"

tests := []struct {
name string
incomingXFF []string // nil = no incoming X-Forwarded-For header at all
wantPrefix string // expected join of entries preceding the appended proxy RealIP
}{
{
name: "no incoming XFF, only proxy RealIP is set",
incomingXFF: nil,
wantPrefix: "",
},
{
name: "single-line single-entry XFF is preserved with proxy RealIP appended",
incomingXFF: []string{"203.0.113.1"},
wantPrefix: "203.0.113.1",
},
{
name: "single-line comma-separated XFF is preserved with proxy RealIP appended",
incomingXFF: []string{"203.0.113.1, 10.0.0.5"},
wantPrefix: "203.0.113.1, 10.0.0.5",
},
{
name: "multi-line XFF (multiple header occurrences) is joined with proxy RealIP appended",
incomingXFF: []string{"203.0.113.1", "10.0.0.5"},
wantPrefix: "203.0.113.1, 10.0.0.5",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mu.Lock()
capturedHeader = nil
mu.Unlock()

origin, _ := url.Parse(proxySrv.URL)
cfg := &websocket.Config{
Location: proxyWSURL,
Origin: origin,
Version: websocket.ProtocolVersionHybi13,
Header: http.Header{},
}

for _, v := range tt.incomingXFF {
cfg.Header.Add(echo.HeaderXForwardedFor, v)
}

wsConn, err := websocket.DialConfig(cfg)
assert.NoError(t, err)

defer wsConn.Close()

// Exchange one message to ensure the upstream handler has captured the header.
assert.NoError(t, websocket.Message.Send(wsConn, "ping"))

var got string
assert.NoError(t, websocket.Message.Receive(wsConn, &got))

mu.Lock()
xff := capturedHeader.Get(echo.HeaderXForwardedFor)
mu.Unlock()

// The middleware uses Header.Set, so the upstream sees exactly one
// X-Forwarded-For header line. Split it back into entries.
entries := strings.Split(xff, ", ")
assert.NotEmpty(t, entries, "X-Forwarded-For must be set by the proxy middleware")

// The tail entry is the proxy's c.RealIP(). When the test client dials
// via httptest.NewServer the proxy sees 127.0.0.1.
tail := entries[len(entries)-1]
assert.Equal(t, "127.0.0.1", tail,
"proxy RealIP must be appended at the tail of X-Forwarded-For")

// The remaining entries must equal the prior chain, preserving order
// and joining multi-line headers with ", ".
gotPrefix := strings.Join(entries[:len(entries)-1], ", ")
assert.Equal(t, tt.wantPrefix, gotPrefix,
"prior X-Forwarded-For entries must be preserved before the appended RealIP")
})
}
}
Loading