Skip to content

Commit b6ffe5a

Browse files
committed
Migrate nuget handler to OIDCRegistry
1 parent d088b52 commit b6ffe5a

2 files changed

Lines changed: 90 additions & 78 deletions

File tree

internal/handlers/nuget_feed.go

Lines changed: 86 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"net/http"
1010
"net/url"
1111
"strings"
12-
"sync"
1312
"time"
1413

1514
"github.com/elazarl/goproxy"
@@ -36,9 +35,8 @@ type nugetV3IndexResponse struct {
3635

3736
// NugetFeedHandler handles requests to nuget feeds, adding auth.
3837
type NugetFeedHandler struct {
39-
credentials []nugetFeedCredentials
40-
oidcCredentials map[string]*oidc.OIDCCredential
41-
mutex sync.RWMutex
38+
credentials []nugetFeedCredentials
39+
oidcRegistry *oidc.OIDCRegistry
4240
}
4341

4442
type nugetFeedCredentials struct {
@@ -52,8 +50,8 @@ type nugetFeedCredentials struct {
5250
// NewNugetFeedHandler returns a new NugetFeedHandler.
5351
func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
5452
handler := NugetFeedHandler{
55-
credentials: []nugetFeedCredentials{},
56-
oidcCredentials: make(map[string]*oidc.OIDCCredential),
53+
credentials: []nugetFeedCredentials{},
54+
oidcRegistry: oidc.NewOIDCRegistry(),
5755
}
5856

5957
httpClient := &http.Client{
@@ -72,58 +70,68 @@ func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
7270
username := cred.GetString("username")
7371
password := cred.GetString("password")
7472

75-
oidcCredential, _ := oidc.CreateOIDCCredential(cred)
76-
if oidcCredential != nil {
77-
key := url
78-
if key == "" {
79-
key = host
80-
}
73+
oidcCredential, _, ok := handler.oidcRegistry.Register(cred, []string{"url"}, "nuget feed")
74+
if ok {
75+
// Discover additional resource URLs from the nuget feed index.
76+
// Host-only credentials (from the CLI) are still registered above
77+
// for request-time matching, but discovery requires an absolute URL.
78+
// Wrapped in a closure so defer runs promptly for each credential,
79+
// ensuring the HTTP response body is always closed (pre-existing
80+
// leak fixed here: the body was previously leaked on ReadAll error
81+
// and on early-return status code paths).
82+
if url != "" {
83+
func() {
84+
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
85+
if err != nil {
86+
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
87+
return
88+
}
8189

82-
if key != "" {
83-
handler.oidcCredentials[key] = oidcCredential
84-
logging.RequestLogf(nil, "registered %s OIDC credentials for nuget feed: %s", oidcCredential.Provider(), key)
90+
if req.URL.Scheme != "https" {
91+
logging.RequestLogf(nil, "refusing to discover nuget feed over non-https URL %s", url)
92+
return
93+
}
8594

86-
// now query all resources to add to the authentication list
87-
req, err := http.NewRequestWithContext(context.Background(), "GET", key, nil)
88-
if err != nil {
89-
logging.RequestLogf(nil, "error creating http request (%s): %v", key, err)
90-
continue
91-
}
95+
if !handler.oidcRegistry.TryAuth(req, nil) {
96+
return
97+
}
9298

93-
if oidc.TryAuthOIDCRequestWithPrefix(&handler.mutex, handler.oidcCredentials, req, nil) {
9499
rawRsp, err := httpClient.Do(req)
95100
if err != nil {
96-
logging.RequestLogf(nil, "error retrieving http response (%s): %v", key, err)
97-
continue
101+
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
102+
return
98103
}
104+
defer rawRsp.Body.Close()
99105

100106
body, err := io.ReadAll(rawRsp.Body)
101107
if err != nil {
102-
logging.RequestLogf(nil, "error reading http response body")
103-
continue
108+
logging.RequestLogf(nil, "error reading http response body (%s): %v", url, err)
109+
return
104110
}
105-
rawRsp.Body.Close()
106111

107112
switch rawRsp.StatusCode {
108113
case 401, 403:
109-
logging.RequestLogf(nil, "unauthorized for nuget feed %s", key)
110-
continue
114+
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
115+
return
111116
}
112117

113118
if rawRsp.StatusCode >= 400 {
114-
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, key)
115-
continue
119+
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
120+
return
116121
}
117122

118-
urlsToAuthenticate := extraUrlsFromSourceResponse(body, key)
119-
for _, url := range urlsToAuthenticate {
120-
handler.oidcCredentials[url] = oidcCredential
121-
logging.RequestLogf(nil, " registered %s OIDC credentials for nuget resource: %s", oidcCredential.Provider(), url)
123+
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
124+
for _, discoveredURL := range urlsToAuthenticate {
125+
handler.oidcRegistry.RegisterURL(discoveredURL, oidcCredential, "nuget resource")
122126
}
123-
}
127+
}()
124128
}
125129
continue
126130
}
131+
// OIDC credentials are not used as static credentials.
132+
if oidcCredential != nil {
133+
continue
134+
}
127135

128136
feedCred := nugetFeedCredentials{
129137
url: url,
@@ -138,48 +146,52 @@ func NewNugetFeedHandler(creds config.Credentials) *NugetFeedHandler {
138146
// and authenticate them all
139147
if url != "" {
140148
logging.RequestLogf(nil, "fetching service index for nuget feed %s", url)
141-
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
142-
authenticateNugetRequest(req, feedCred, nil)
143-
if err != nil {
144-
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
145-
continue
146-
}
149+
// Same closure pattern as the OIDC block above — ensures the
150+
// HTTP response body is always closed via defer.
151+
func() {
152+
req, err := http.NewRequestWithContext(context.Background(), "GET", url, nil)
153+
if err != nil {
154+
logging.RequestLogf(nil, "error creating http request (%s): %v", url, err)
155+
return
156+
}
157+
authenticateNugetRequest(req, feedCred, nil)
147158

148-
rawRsp, err := httpClient.Do(req)
149-
if err != nil {
150-
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
151-
continue
152-
}
159+
rawRsp, err := httpClient.Do(req)
160+
if err != nil {
161+
logging.RequestLogf(nil, "error retrieving http response (%s): %v", url, err)
162+
return
163+
}
164+
defer rawRsp.Body.Close()
153165

154-
body, err := io.ReadAll(rawRsp.Body)
155-
if err != nil {
156-
logging.RequestLogf(nil, "error reading http response body")
157-
continue
158-
}
159-
rawRsp.Body.Close()
166+
body, err := io.ReadAll(rawRsp.Body)
167+
if err != nil {
168+
logging.RequestLogf(nil, "error reading http response body (%s): %v", url, err)
169+
return
170+
}
160171

161-
switch rawRsp.StatusCode {
162-
case 401, 403:
163-
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
164-
continue
165-
}
172+
switch rawRsp.StatusCode {
173+
case 401, 403:
174+
logging.RequestLogf(nil, "unauthorized for nuget feed %s", url)
175+
return
176+
}
166177

167-
if rawRsp.StatusCode >= 400 {
168-
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
169-
continue
170-
}
178+
if rawRsp.StatusCode >= 400 {
179+
logging.RequestLogf(nil, "unexpected http response %d for nuget feed %s", rawRsp.StatusCode, url)
180+
return
181+
}
171182

172-
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
173-
for _, url := range urlsToAuthenticate {
174-
feedCred := nugetFeedCredentials{
175-
url: url,
176-
token: token,
177-
username: username,
178-
password: password,
183+
urlsToAuthenticate := extraUrlsFromSourceResponse(body, url)
184+
for _, discoveredURL := range urlsToAuthenticate {
185+
feedCred := nugetFeedCredentials{
186+
url: discoveredURL,
187+
token: token,
188+
username: username,
189+
password: password,
190+
}
191+
handler.credentials = append(handler.credentials, feedCred)
192+
logging.RequestLogf(nil, " added url to authentication list: %s", discoveredURL)
179193
}
180-
handler.credentials = append(handler.credentials, feedCred)
181-
logging.RequestLogf(nil, " added url to authentication list: %s", url)
182-
}
194+
}()
183195
}
184196
}
185197

@@ -261,8 +273,8 @@ func (h *NugetFeedHandler) HandleRequest(req *http.Request, ctx *goproxy.ProxyCt
261273
return req, nil
262274
}
263275

264-
// Try OIDC credentials first
265-
if oidc.TryAuthOIDCRequestWithPrefix(&h.mutex, h.oidcCredentials, req, ctx) {
276+
// Try OIDC credentials first (HTTPS only to avoid leaking tokens over plaintext)
277+
if req.URL.Scheme == "https" && h.oidcRegistry.TryAuth(req, ctx) {
266278
return req, nil
267279
}
268280

internal/handlers/oidc_handling_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
822822
},
823823
expectedLogLines: []string{
824824
"registered aws OIDC credentials for nuget feed: https://nuget.example.com/index.json",
825-
" registered aws OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
825+
"registered aws OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
826826
},
827827
urlsToAuthenticate: []string{
828828
"https://nuget.example.com/index.json", // base url
@@ -852,7 +852,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
852852
},
853853
expectedLogLines: []string{
854854
"registered azure OIDC credentials for nuget feed: https://nuget.example.com/index.json",
855-
" registered azure OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
855+
"registered azure OIDC credentials for nuget resource: https://nuget.example.com/v3/packages",
856856
},
857857
urlsToAuthenticate: []string{
858858
"https://nuget.example.com/index.json", // base url
@@ -881,7 +881,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
881881
},
882882
expectedLogLines: []string{
883883
"registered jfrog OIDC credentials for nuget feed: https://jfrog.example.com/index.json",
884-
" registered jfrog OIDC credentials for nuget resource: https://jfrog.example.com/v3/packages",
884+
"registered jfrog OIDC credentials for nuget resource: https://jfrog.example.com/v3/packages",
885885
},
886886
urlsToAuthenticate: []string{
887887
"https://jfrog.example.com/index.json", // base url
@@ -912,7 +912,7 @@ func TestOIDCURLsAreAuthenticated(t *testing.T) {
912912
},
913913
expectedLogLines: []string{
914914
"registered cloudsmith OIDC credentials for nuget feed: https://cloudsmith.example.com/v3/index.json",
915-
" registered cloudsmith OIDC credentials for nuget resource: https://cloudsmith.example.com/v3/packages",
915+
"registered cloudsmith OIDC credentials for nuget resource: https://cloudsmith.example.com/v3/packages",
916916
},
917917
urlsToAuthenticate: []string{
918918
"https://cloudsmith.example.com/v3/index.json", // base url

0 commit comments

Comments
 (0)