Skip to content

Commit ba19fa5

Browse files
kbukum1microblag
andauthored
Add generalized OIDCRegistry for collision-free credential storage (#78)
* Add OIDCRegistry type for collision-free OIDC credential storage Co-authored-by: James Garratt <572389+microblag@users.noreply.github.com>
1 parent 475db02 commit ba19fa5

File tree

2 files changed

+713
-0
lines changed

2 files changed

+713
-0
lines changed

internal/oidc/oidc_registry.go

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
package oidc
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"strings"
7+
"sync"
8+
9+
"github.com/elazarl/goproxy"
10+
11+
"github.com/dependabot/proxy/internal/config"
12+
"github.com/dependabot/proxy/internal/helpers"
13+
"github.com/dependabot/proxy/internal/logging"
14+
)
15+
16+
// OIDCRegistry stores OIDC credentials indexed by host, with path-based
17+
// matching within each host bucket. This structure provides O(1) host lookup
18+
// and avoids key collisions when multiple registries share a host with
19+
// different paths.
20+
type OIDCRegistry struct {
21+
byHost map[string][]oidcEntry
22+
mutex sync.RWMutex
23+
}
24+
25+
type oidcEntry struct {
26+
path string // URL path prefix, e.g. "/org/_packaging/feed-A/npm/registry"
27+
port string // port, defaults to "443"
28+
credential *OIDCCredential
29+
}
30+
31+
// NewOIDCRegistry creates an empty registry.
32+
func NewOIDCRegistry() *OIDCRegistry {
33+
return &OIDCRegistry{
34+
byHost: make(map[string][]oidcEntry),
35+
}
36+
}
37+
38+
// Register attempts to create an OIDC credential from the config and store it.
39+
// urlFields are checked in order for a URL (preserving host + path);
40+
// falls back to cred.Host() (hostname only) as last resort.
41+
//
42+
// Returns:
43+
// - (credential, key, true) if an OIDC credential was created and registered
44+
// - (credential, "", false) if OIDC-configured but no URL or host could be resolved
45+
// - (nil, "", false) if the credential is not OIDC-configured
46+
func (r *OIDCRegistry) Register(
47+
cred config.Credential,
48+
urlFields []string,
49+
registryType string,
50+
) (*OIDCCredential, string, bool) {
51+
oidcCredential, _ := CreateOIDCCredential(cred)
52+
if oidcCredential == nil {
53+
return nil, "", false
54+
}
55+
56+
// Resolve the key: prefer URL fields (preserves path), fall back to host
57+
var key string
58+
for _, field := range urlFields {
59+
if v := cred.GetString(field); v != "" {
60+
key = v
61+
break
62+
}
63+
}
64+
if key == "" {
65+
key = cred.Host()
66+
}
67+
if key == "" {
68+
return oidcCredential, "", false
69+
}
70+
71+
if !r.addEntry(key, oidcCredential, registryType) {
72+
return oidcCredential, "", false
73+
}
74+
75+
return oidcCredential, key, true
76+
}
77+
78+
// RegisterURL adds an already-created credential under a URL.
79+
// Used by nuget to register HTTP-discovered resource URLs that
80+
// should share the same OIDC credential as the primary feed URL.
81+
func (r *OIDCRegistry) RegisterURL(url string, cred *OIDCCredential, registryType string) {
82+
if url == "" || cred == nil {
83+
return
84+
}
85+
r.addEntry(url, cred, registryType)
86+
}
87+
88+
// TryAuth finds the matching OIDC credential for the request and
89+
// sets the appropriate auth header.
90+
//
91+
// Callers are responsible for scheme and method checks (e.g. HTTPS-only,
92+
// GET/HEAD only) before calling TryAuth.
93+
//
94+
// Lookup:
95+
// 1. Find the host bucket via map lookup (exact hostname match)
96+
// 2. Within that bucket, find the entry whose stored path is a
97+
// prefix of the request path
98+
//
99+
// Returns true if the request was authenticated, false otherwise.
100+
func (r *OIDCRegistry) TryAuth(req *http.Request, ctx *goproxy.ProxyCtx) bool {
101+
host := strings.ToLower(helpers.GetHost(req))
102+
reqPort := req.URL.Port()
103+
if reqPort == "" {
104+
reqPort = "443"
105+
}
106+
107+
r.mutex.RLock()
108+
entries := r.byHost[host]
109+
r.mutex.RUnlock()
110+
111+
if len(entries) == 0 {
112+
return false
113+
}
114+
115+
// Find the most specific matching entry: host is already matched,
116+
// select the longest path prefix among entries with the same port.
117+
var matched *OIDCCredential
118+
bestPathLen := -1
119+
for i := range entries {
120+
e := &entries[i]
121+
if e.port != reqPort {
122+
continue
123+
}
124+
if !strings.HasPrefix(req.URL.Path, e.path) {
125+
continue
126+
}
127+
if len(e.path) > bestPathLen {
128+
matched = e.credential
129+
bestPathLen = len(e.path)
130+
}
131+
}
132+
133+
if matched == nil {
134+
return false
135+
}
136+
137+
token, err := GetOrRefreshOIDCToken(matched, req.Context())
138+
if err != nil {
139+
logging.RequestLogf(ctx, "* failed to get %s token via OIDC for %s: %v", matched.Provider(), host, err)
140+
return false
141+
}
142+
143+
switch matched.parameters.(type) {
144+
case *CloudsmithOIDCParameters:
145+
logging.RequestLogf(ctx, "* authenticating request with OIDC API key (host: %s)", host)
146+
req.Header.Set("X-Api-Key", token)
147+
default:
148+
logging.RequestLogf(ctx, "* authenticating request with OIDC token (host: %s)", host)
149+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
150+
}
151+
152+
return true
153+
}
154+
155+
// addEntry parses a URL or hostname string and adds a credential entry
156+
// to the appropriate host bucket. Returns true if the credential is stored
157+
// (either newly added or already present from a prior registration).
158+
// Returns false only when the URL cannot be parsed.
159+
// Duplicates with the same path and port are skipped (first-wins).
160+
func (r *OIDCRegistry) addEntry(urlOrHost string, cred *OIDCCredential, registryType string) bool {
161+
host, path, port := parseRegistryURL(urlOrHost)
162+
if host == "" {
163+
logging.RequestLogf(nil, "failed to parse OIDC credential URL: %s", urlOrHost)
164+
return false
165+
}
166+
167+
r.mutex.Lock()
168+
defer r.mutex.Unlock()
169+
170+
for _, e := range r.byHost[host] {
171+
if e.path == path && e.port == port {
172+
// First-wins: the credential already stored for this path+port
173+
// will be used; the new one is discarded.
174+
logging.RequestLogf(nil, "skipping duplicate OIDC credential for %s (path: %s)", host, path)
175+
return true
176+
}
177+
}
178+
179+
r.byHost[host] = append(r.byHost[host], oidcEntry{
180+
path: path,
181+
port: port,
182+
credential: cred,
183+
})
184+
185+
logging.RequestLogf(nil, "registered %s OIDC credentials for %s: %s", cred.Provider(), registryType, urlOrHost)
186+
return true
187+
}
188+
189+
// parseRegistryURL extracts host, path, and port from a URL or hostname string.
190+
// For hostname-only input, path is empty and port defaults to "443".
191+
func parseRegistryURL(urlOrHost string) (host, path, port string) {
192+
parsed, err := helpers.ParseURLLax(urlOrHost)
193+
if err != nil {
194+
return "", "", ""
195+
}
196+
197+
host = strings.ToLower(parsed.Hostname())
198+
path = strings.TrimRight(parsed.Path, "/")
199+
port = parsed.Port()
200+
if port == "" {
201+
port = "443"
202+
}
203+
204+
return host, path, port
205+
}

0 commit comments

Comments
 (0)