Skip to content

Commit 6dfaf50

Browse files
author
Viktor Pettersson
authored
fix(edge): add server certificate revalidation for mTLS connections [BE-11609] (#456)
1 parent 61732c8 commit 6dfaf50

File tree

2 files changed

+65
-24
lines changed

2 files changed

+65
-24
lines changed

edge/client/build.go

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@ import (
1818
)
1919

2020
type edgeHTTPClient struct {
21-
httpClient *http.Client
22-
options *agent.Options
23-
revokeService *revoke.Service
24-
certMTime time.Time
25-
keyMTime time.Time
26-
caMTime time.Time
27-
mu sync.RWMutex
28-
localAddr *net.TCPAddr
21+
httpClient *http.Client
22+
options *agent.Options
23+
revokeService *revoke.Service
24+
certMTime time.Time
25+
keyMTime time.Time
26+
caMTime time.Time
27+
mu sync.RWMutex
28+
localAddr *net.TCPAddr
29+
verifiedChains [][]*x509.Certificate // verifiedChains is used to store the verified certificate chains from the mTLS handshake
30+
verifiedChainsMu sync.RWMutex
2931
}
3032

3133
func BuildHTTPClient(timeout float64, options *agent.Options) *edgeHTTPClient {
@@ -55,6 +57,15 @@ func (c *edgeHTTPClient) Do(req *http.Request) (*http.Response, error) {
5557
c.mu.Unlock()
5658
}
5759

60+
// If mTLS is enforced and there is an established connection, check validity of server certificates
61+
if !c.options.EdgeInsecurePoll {
62+
if err := c.verifyPeerCertificate(nil, c.getVerifiedChains()); err != nil {
63+
c.mu.Lock()
64+
c.httpClient.Transport = c.buildTransport()
65+
c.mu.Unlock()
66+
}
67+
}
68+
5869
c.mu.RLock()
5970
defer c.mu.RUnlock()
6071

@@ -139,22 +150,54 @@ func (c *edgeHTTPClient) buildTransport() *http.Transport {
139150
return &cert, err
140151
}
141152

142-
transport.TLSClientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
143-
for _, chain := range verifiedChains {
144-
for _, cert := range chain {
145-
revoked, err := c.revokeService.VerifyCertificate(cert)
146-
if err != nil {
147-
return err
148-
}
149-
150-
if revoked {
151-
return errors.New("certificate has been revoked")
152-
}
153-
}
154-
}
153+
transport.TLSClientConfig.VerifyPeerCertificate = c.verifyPeerCertificate
154+
155+
transport.TLSClientConfig.VerifyConnection = func(state tls.ConnectionState) error {
156+
// Note, this callback is called during the TLS handshake, but after the VerifyPeerCertificate callback.
157+
// The state argument contains the verified chains, which are used to check if certificates have been revoked, or expired.
158+
c.setVerifiedChains(state.VerifiedChains)
155159

156160
return nil
157161
}
158162

159163
return transport
160164
}
165+
166+
// verifyPeerCertificate is a callback that adheres to the tls.Config.VerifyPeerCertificate signature. It is used to
167+
// verify the peer certificate chain and check if the certificate has been revoked.
168+
func (c *edgeHTTPClient) verifyPeerCertificate(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
169+
for _, chain := range verifiedChains {
170+
for _, cert := range chain {
171+
revoked, err := c.revokeService.VerifyCertificate(cert)
172+
if err != nil {
173+
return err
174+
}
175+
176+
if revoked {
177+
return errors.New("certificate has been revoked")
178+
}
179+
}
180+
}
181+
182+
return nil
183+
}
184+
185+
func (c *edgeHTTPClient) setVerifiedChains(chains [][]*x509.Certificate) {
186+
c.verifiedChainsMu.Lock()
187+
defer c.verifiedChainsMu.Unlock()
188+
189+
c.verifiedChains = chains
190+
}
191+
192+
func (c *edgeHTTPClient) getVerifiedChains() [][]*x509.Certificate {
193+
c.verifiedChainsMu.RLock()
194+
defer c.verifiedChainsMu.RUnlock()
195+
196+
verifiedChains := make([][]*x509.Certificate, len(c.verifiedChains))
197+
for i, chain := range c.verifiedChains {
198+
verifiedChains[i] = make([]*x509.Certificate, len(chain))
199+
copy(verifiedChains[i], chain)
200+
}
201+
202+
return verifiedChains
203+
}

edge/revoke/revoke.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ func (service *Service) certIsRevokedCRL(cert *x509.Certificate, url string) (re
123123
}
124124
}
125125

126-
issuer := service.getIssuer(cert)
127-
128126
if shouldFetchCRL {
129127
var err error
130128
crl, err = service.fetchCRL(url)
@@ -135,7 +133,7 @@ func (service *Service) certIsRevokedCRL(cert *x509.Certificate, url string) (re
135133
}
136134

137135
// check CRL signature
138-
if issuer != nil {
136+
if issuer := service.getIssuer(cert); issuer != nil {
139137
err = issuer.CheckCRLSignature(crl)
140138
if err != nil {
141139
log.Warn().Str("url", url).Err(err).Msg("failed verifying CRL")

0 commit comments

Comments
 (0)