Skip to content

Commit 65234e2

Browse files
committed
Migrate nuget handler to OIDCRegistry
Replace manual OIDC credential map and mutex with the shared OIDCRegistry type. Nuget uses Register() for the primary feed URL and RegisterURL() for HTTP-discovered resource URLs. The old code stored OIDC credentials by url-with-host-fallback. OIDCRegistry preserves the full URL with path-prefix matching, fixing potential collisions when multiple nuget feeds share a host. Test log lines updated: RegisterURL uses consistent log format without the leading indent that the old code used.
1 parent 5328230 commit 65234e2

File tree

2 files changed

+90
-78
lines changed

2 files changed

+90
-78
lines changed

internal/handlers/nuget_feed.go

Lines changed: 86 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"net/http"
1010
"net/url"
1111
"strings"
12-
"sync"
1312
"time"
1413

1514
"github.com/elazarl/goproxy"
@@ -36,9 +35,8 @@ type nugetV3IndexResponse struct {
3635

3736
// NugetFeedHandler handles requests to nuget feeds, adding auth.
3837
type NugetFeedHandler struct {
39-
credentials []nugetFeedCredentials
40-
oidcCredentials map[string]*oidc.OIDCCredential
41-
mutex sync.RWMutex
38+
credentials []nugetFeedCredentials
39+
oidcRegistry *oidc.OIDCRegistry
4240
}
4341

4442
type nugetFeedCredentials struct {
@@ -52,8 +50,8 @@ type nugetFeedCredentials struct {
5250
// NewNugetFeedHandler returns a new NugetFeedHandler.
5351
func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
5452
handler := NugetFeedHandler{
55-
credentials: []nugetFeedCredentials{},
56-
oidcCredentials: make(map[string]*oidc.OIDCCredential),
53+
credentials: []nugetFeedCredentials{},
54+
oidcRegistry: oidc.NewOIDCRegistry(),
5755
}
5856

5957
httpClient := &http.Client{
@@ -72,58 +70,68 @@ func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
7270
username := cred.GetString("username")
7371
password := cred.GetString("password")
7472

75-
oidcCredential, _ := oidc.CreateOIDCCredential(cred)
76-
if oidcCredential != nil {
77-
key := url
78-
if key == "" {
79-
key = host
80-
}
73+
oidcCredential, _, ok := handler.oidcRegistry.Register(cred, []string{"url"}, "nuget feed")
74+
if ok {
75+
// Discover additional resource URLs from the nuget feed index.
76+
// Host-only credentials (from the CLI) are still registered above
77+
// for request-time matching, but discovery requires an absolute URL.
78+
// Wrapped in a closure so defer runs promptly for each credential,
79+
// ensuring the HTTP response body is always closed (pre-existing
80+
// leak fixed here: the body was previously leaked on ReadAll error
81+
// and on early-return status code paths).
82+
if url != "" {
83+
func() {
84+
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
85+
if err != nil {
86+
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
87+
return
88+
}
8189

82-
if key != "" {
83-
handler.oidcCredentials[key] = oidcCredential
84-
logging.RequestLogf(nil, "registered %s OIDC credentials for nuget feed: %s", oidcCredential.Provider(), key)
90+
if req.URL.Scheme != "https" {
91+
logging.RequestLogf(nil, "refusing to discover nuget feed over non-https URL %s", url)
92+
return
93+
}
8594

86-
// now query all resources to add to the authentication list
87-
req, err := http.NewRequestWithContext(context.Background(), "GET", key, nil)
88-
if err != nil {
89-
logging.RequestLogf(nil, "error creating http request (%s): %v", key, err)
90-
continue
91-
}
95+
if !handler.oidcRegistry.TryAuth(req, nil) {
96+
return
97+
}
9298

93-
if oidc.TryAuthOIDCRequestWithPrefix(&handler.mutex, handler.oidcCredentials, req, nil) {
9499
rawRsp, err := httpClient.Do(req)
95100
if err != nil {
96-
logging.RequestLogf(nil, "error retrieving http response (%s): %v", key, err)
97-
continue
101+
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
102+
return
98103
}
104+
defer rawRsp.Body.Close()
99105

100106
body, err := io.ReadAll(rawRsp.Body)
101107
if err != nil {
102-
logging.RequestLogf(nil, "error reading http response body")
103-
continue
108+
logging.RequestLogf(nil, "error reading http response body (%s): %v", url, err)
109+
return
104110
}
105-
rawRsp.Body.Close()
106111

107112
switch rawRsp.StatusCode {
108113
case 401, 403:
109-
logging.RequestLogf(nil, "unauthorized for nuget feed %s", key)
110-
continue
114+
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
115+
return
111116
}
112117

113118
if rawRsp.StatusCode >= 400 {
114-
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, key)
115-
continue
119+
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
120+
return
116121
}
117122

118-
urlsToAuthenticate := extraUrlsFromSourceResponse(body, key)
119-
for _, url := range urlsToAuthenticate {
120-
handler.oidcCredentials[url] = oidcCredential
121-
logging.RequestLogf(nil, " registered %s OIDC credentials for nuget resource: %s", oidcCredential.Provider(), url)
123+
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
124+
for _, discoveredURL := range urlsToAuthenticate {
125+
handler.oidcRegistry.RegisterURL(discoveredURL, oidcCredential, "nuget resource")
122126
}
123-
}
127+
}()
124128
}
125129
continue
126130
}
131+
// OIDC credentials are not used as static credentials.
132+
if oidcCredential != nil {
133+
continue
134+
}
127135

128136
feedCred := nugetFeedCredentials{
129137
url: url,
@@ -138,48 +146,52 @@ func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
138146
// and authenticate them all
139147
if url != "" {
140148
logging.RequestLogf(nil, "fetching service index for nuget feed %s", url)
141-
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
142-
authenticateNugetRequest(req, feedCred, nil)
143-
if err != nil {
144-
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
145-
continue
146-
}
149+
// Same closure pattern as the OIDC block above — ensures the
150+
// HTTP response body is always closed via defer.
151+
func() {
152+
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
153+
if err != nil {
154+
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
155+
return
156+
}
157+
authenticateNugetRequest(req, feedCred, nil)
147158

148-
rawRsp, err := httpClient.Do(req)
149-
if err != nil {
150-
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
151-
continue
152-
}
159+
rawRsp, err := httpClient.Do(req)
160+
if err != nil {
161+
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
162+
return
163+
}
164+
defer rawRsp.Body.Close()
153165

154-
body, err := io.ReadAll(rawRsp.Body)
155-
if err != nil {
156-
logging.RequestLogf(nil, "error reading http response body")
157-
continue
158-
}
159-
rawRsp.Body.Close()
166+
body, err := io.ReadAll(rawRsp.Body)
167+
if err != nil {
168+
logging.RequestLogf(nil, "error reading http response body (%s): %v", url, err)
169+
return
170+
}
160171

161-
switch rawRsp.StatusCode {
162-
case 401, 403:
163-
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
164-
continue
165-
}
172+
switch rawRsp.StatusCode {
173+
case 401, 403:
174+
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
175+
return
176+
}
166177

167-
if rawRsp.StatusCode >= 400 {
168-
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
169-
continue
170-
}
178+
if rawRsp.StatusCode >= 400 {
179+
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
180+
return
181+
}
171182

172-
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
173-
for _, url := range urlsToAuthenticate {
174-
feedCred := nugetFeedCredentials{
175-
url: url,
176-
token: token,
177-
username: username,
178-
password: password,
183+
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
184+
for _, discoveredURL := range urlsToAuthenticate {
185+
feedCred := nugetFeedCredentials{
186+
url: discoveredURL,
187+
token: token,
188+
username: username,
189+
password: password,
190+
}
191+
handler.credentials = append(handler.credentials, feedCred)
192+
logging.RequestLogf(nil, " added url to authentication list: %s", discoveredURL)
179193
}
180-
handler.credentials = append(handler.credentials, feedCred)
181-
logging.RequestLogf(nil, " added url to authentication list: %s", url)
182-
}
194+
}()
183195
}
184196
}
185197

@@ -261,8 +273,8 @@ func (h *NugetFeedHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCt
261273
return req, nil
262274
}
263275

264-
// Try OIDC credentials first
265-
if oidc.TryAuthOIDCRequestWithPrefix(&h.mutex, h.oidcCredentials, req, ctx) {
276+
// Try OIDC credentials first (HTTPS only to avoid leaking tokens over plaintext)
277+
if req.URL.Scheme == "https" && h.oidcRegistry.TryAuth(req, ctx) {
266278
return req, nil
267279
}
268280

internal/handlers/oidc_handling_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
822822
},
823823
expectedLogLines: []string{
824824
"registered aws OIDC credentials for nuget feed: https://nuget.example.com/index.json",
825-
" registered aws OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
825+
"registered aws OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
826826
},
827827
urlsToAuthenticate: []string{
828828
"https://nuget.example.com/index.json", // base url
@@ -852,7 +852,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
852852
},
853853
expectedLogLines: []string{
854854
"registered azure OIDC credentials for nuget feed: https://nuget.example.com/index.json",
855-
" registered azure OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
855+
"registered azure OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
856856
},
857857
urlsToAuthenticate: []string{
858858
"https://nuget.example.com/index.json", // base url
@@ -881,7 +881,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
881881
},
882882
expectedLogLines: []string{
883883
"registered jfrog OIDC credentials for nuget feed: https://jfrog.example.com/index.json",
884-
" registered jfrog OIDC credentials for nuget resource: https://jfrog.example.com/v3/packages",
884+
"registered jfrog OIDC credentials for nuget resource: https://jfrog.example.com/v3/packages",
885885
},
886886
urlsToAuthenticate: []string{
887887
"https://jfrog.example.com/index.json", // base url
@@ -912,7 +912,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
912912
},
913913
expectedLogLines: []string{
914914
"registered cloudsmith OIDC credentials for nuget feed: https://cloudsmith.example.com/v3/index.json",
915-
" registered cloudsmith OIDC credentials for nuget resource: https://cloudsmith.example.com/v3/packages",
915+
"registered cloudsmith OIDC credentials for nuget resource: https://cloudsmith.example.com/v3/packages",
916916
},
917917
urlsToAuthenticate: []string{
918918
"https://cloudsmith.example.com/v3/index.json", // base url

0 commit comments

Comments
 (0)