Skip to content

Commit c8bf1a1

Browse files
committed
Migrate nuget handler to OIDCRegistry
Replace manual OIDC credential map and mutex with the shared OIDCRegistry type. Nuget uses Register() for the primary feed URL and RegisterURL() for HTTP-discovered resource URLs. The old code stored OIDC credentials by url-with-host-fallback. OIDCRegistry preserves the full URL with path-prefix matching, fixing potential collisions when multiple nuget feeds share a host. Test log lines updated: RegisterURL uses consistent log format without the leading indent that the old code used.
1 parent 5328230 commit c8bf1a1

2 files changed

Lines changed: 86 additions & 78 deletions

File tree

internal/handlers/nuget_feed.go

Lines changed: 82 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"net/http"
1010
"net/url"
1111
"strings"
12-
"sync"
1312
"time"
1413

1514
"github.com/elazarl/goproxy"
@@ -36,9 +35,8 @@ type nugetV3IndexResponse struct {
3635

3736
// NugetFeedHandler handles requests to nuget feeds, adding auth.
3837
type NugetFeedHandler struct {
39-
credentials []nugetFeedCredentials
40-
oidcCredentials map[string]*oidc.OIDCCredential
41-
mutex sync.RWMutex
38+
credentials []nugetFeedCredentials
39+
oidcRegistry *oidc.OIDCRegistry
4240
}
4341

4442
type nugetFeedCredentials struct {
@@ -52,8 +50,8 @@ type nugetFeedCredentials struct {
5250
// NewNugetFeedHandler returns a new NugetFeedHandler.
5351
func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
5452
handler := NugetFeedHandler{
55-
credentials: []nugetFeedCredentials{},
56-
oidcCredentials: make(map[string]*oidc.OIDCCredential),
53+
credentials: []nugetFeedCredentials{},
54+
oidcRegistry: oidc.NewOIDCRegistry(),
5755
}
5856

5957
httpClient := &http.Client{
@@ -72,55 +70,61 @@ func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
7270
username := cred.GetString("username")
7371
password := cred.GetString("password")
7472

75-
oidcCredential, _ := oidc.CreateOIDCCredential(cred)
76-
if oidcCredential != nil {
77-
key := url
78-
if key == "" {
79-
key = host
80-
}
73+
oidcCredential, _, ok := handler.oidcRegistry.Register(cred, []string{"url"}, "nuget feed")
74+
if ok {
75+
// Discover additional resource URLs from the nuget feed index.
76+
// Host-only credentials (from the CLI) are still registered above
77+
// for request-time matching, but discovery requires an absolute URL.
78+
// Wrapped in a closure so defer runs promptly for each credential,
79+
// ensuring the HTTP response body is always closed (pre-existing
80+
// leak fixed here: the body was previously leaked on ReadAll error
81+
// and on early-return status code paths).
82+
if url != "" {
83+
func() {
84+
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
85+
if err != nil {
86+
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
87+
return
88+
}
8189

82-
if key != "" {
83-
handler.oidcCredentials[key] = oidcCredential
84-
logging.RequestLogf(nil, "registered %s OIDC credentials for nuget feed: %s", oidcCredential.Provider(), key)
90+
if req.URL.Scheme != "https" {
91+
logging.RequestLogf(nil, "refusing to discover nuget feed over non-https URL %s", url)
92+
return
93+
}
8594

86-
// now query all resources to add to the authentication list
87-
req, err := http.NewRequestWithContext(context.Background(), "GET", key, nil)
88-
if err != nil {
89-
logging.RequestLogf(nil, "error creating http request (%s): %v", key, err)
90-
continue
91-
}
95+
if !handler.oidcRegistry.TryAuth(req, nil) {
96+
return
97+
}
9298

93-
if oidc.TryAuthOIDCRequestWithPrefix(&handler.mutex, handler.oidcCredentials, req, nil) {
9499
rawRsp, err := httpClient.Do(req)
95100
if err != nil {
96-
logging.RequestLogf(nil, "error retrieving http response (%s): %v", key, err)
97-
continue
101+
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
102+
return
98103
}
104+
defer rawRsp.Body.Close()
99105

100106
body, err := io.ReadAll(rawRsp.Body)
101107
if err != nil {
102-
logging.RequestLogf(nil, "error reading http response body")
103-
continue
108+
logging.RequestLogf(nil, "error reading http response body (%s): %v", url, err)
109+
return
104110
}
105-
rawRsp.Body.Close()
106111

107112
switch rawRsp.StatusCode {
108113
case 401, 403:
109-
logging.RequestLogf(nil, "unauthorized for nuget feed %s", key)
110-
continue
114+
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
115+
return
111116
}
112117

113118
if rawRsp.StatusCode >= 400 {
114-
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, key)
115-
continue
119+
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
120+
return
116121
}
117122

118-
urlsToAuthenticate := extraUrlsFromSourceResponse(body, key)
119-
for _, url := range urlsToAuthenticate {
120-
handler.oidcCredentials[url] = oidcCredential
121-
logging.RequestLogf(nil, " registered %s OIDC credentials for nuget resource: %s", oidcCredential.Provider(), url)
123+
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
124+
for _, discoveredURL := range urlsToAuthenticate {
125+
handler.oidcRegistry.RegisterURL(discoveredURL, oidcCredential, "nuget resource")
122126
}
123-
}
127+
}()
124128
}
125129
continue
126130
}
@@ -138,48 +142,52 @@ func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
138142
// and authenticate them all
139143
if url != "" {
140144
logging.RequestLogf(nil, "fetching service index for nuget feed %s", url)
141-
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
142-
authenticateNugetRequest(req, feedCred, nil)
143-
if err != nil {
144-
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
145-
continue
146-
}
145+
// Same closure pattern as the OIDC block above — ensures the
146+
// HTTP response body is always closed via defer.
147+
func() {
148+
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
149+
if err != nil {
150+
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
151+
return
152+
}
153+
authenticateNugetRequest(req, feedCred, nil)
147154

148-
rawRsp, err := httpClient.Do(req)
149-
if err != nil {
150-
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
151-
continue
152-
}
155+
rawRsp, err := httpClient.Do(req)
156+
if err != nil {
157+
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
158+
return
159+
}
160+
defer rawRsp.Body.Close()
153161

154-
body, err := io.ReadAll(rawRsp.Body)
155-
if err != nil {
156-
logging.RequestLogf(nil, "error reading http response body")
157-
continue
158-
}
159-
rawRsp.Body.Close()
162+
body, err := io.ReadAll(rawRsp.Body)
163+
if err != nil {
164+
logging.RequestLogf(nil, "error reading http response body (%s): %v", url, err)
165+
return
166+
}
160167

161-
switch rawRsp.StatusCode {
162-
case 401, 403:
163-
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
164-
continue
165-
}
168+
switch rawRsp.StatusCode {
169+
case 401, 403:
170+
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
171+
return
172+
}
166173

167-
if rawRsp.StatusCode >= 400 {
168-
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
169-
continue
170-
}
174+
if rawRsp.StatusCode >= 400 {
175+
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
176+
return
177+
}
171178

172-
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
173-
for _, url := range urlsToAuthenticate {
174-
feedCred := nugetFeedCredentials{
175-
url: url,
176-
token: token,
177-
username: username,
178-
password: password,
179+
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
180+
for _, discoveredURL := range urlsToAuthenticate {
181+
feedCred := nugetFeedCredentials{
182+
url: discoveredURL,
183+
token: token,
184+
username: username,
185+
password: password,
186+
}
187+
handler.credentials = append(handler.credentials, feedCred)
188+
logging.RequestLogf(nil, " added url to authentication list: %s", discoveredURL)
179189
}
180-
handler.credentials = append(handler.credentials, feedCred)
181-
logging.RequestLogf(nil, " added url to authentication list: %s", url)
182-
}
190+
}()
183191
}
184192
}
185193

@@ -261,8 +269,8 @@ func (h *NugetFeedHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCt
261269
return req, nil
262270
}
263271

264-
// Try OIDC credentials first
265-
if oidc.TryAuthOIDCRequestWithPrefix(&h.mutex, h.oidcCredentials, req, ctx) {
272+
// Try OIDC credentials first (HTTPS only to avoid leaking tokens over plaintext)
273+
if req.URL.Scheme == "https" && h.oidcRegistry.TryAuth(req, ctx) {
266274
return req, nil
267275
}
268276

internal/handlers/oidc_handling_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
822822
},
823823
expectedLogLines: []string{
824824
"registered aws OIDC credentials for nuget feed: https://nuget.example.com/index.json",
825-
" registered aws OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
825+
"registered aws OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
826826
},
827827
urlsToAuthenticate: []string{
828828
"https://nuget.example.com/index.json", // base url
@@ -852,7 +852,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
852852
},
853853
expectedLogLines: []string{
854854
"registered azure OIDC credentials for nuget feed: https://nuget.example.com/index.json",
855-
" registered azure OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
855+
"registered azure OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
856856
},
857857
urlsToAuthenticate: []string{
858858
"https://nuget.example.com/index.json", // base url
@@ -881,7 +881,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
881881
},
882882
expectedLogLines: []string{
883883
"registered jfrog OIDC credentials for nuget feed: https://jfrog.example.com/index.json",
884-
" registered jfrog OIDC credentials for nuget resource: https://jfrog.example.com/v3/packages",
884+
"registered jfrog OIDC credentials for nuget resource: https://jfrog.example.com/v3/packages",
885885
},
886886
urlsToAuthenticate: []string{
887887
"https://jfrog.example.com/index.json", // base url
@@ -912,7 +912,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
912912
},
913913
expectedLogLines: []string{
914914
"registered cloudsmith OIDC credentials for nuget feed: https://cloudsmith.example.com/v3/index.json",
915-
" registered cloudsmith OIDC credentials for nuget resource: https://cloudsmith.example.com/v3/packages",
915+
"registered cloudsmith OIDC credentials for nuget resource: https://cloudsmith.example.com/v3/packages",
916916
},
917917
urlsToAuthenticate: []string{
918918
"https://cloudsmith.example.com/v3/index.json", // base url

0 commit comments

Comments
 (0)