From 52a9425218c03076ab775c2ff28846409047d5ed Mon Sep 17 00:00:00 2001 From: Ross Nelson Date: Tue, 24 Feb 2026 09:29:17 -0500 Subject: [PATCH] Check certFile modification time instead of keyFile The certLoader was using keyFile's modification time to detect certificate changes. When a certificate is renewed using the same key, the keyFile remains unchanged and the new cert is never loaded. This switches to checking certFile's modification time, which correctly detects renewals regardless of whether the key was rotated. Based on #2805 by @ndtretyak. --- server/server/rpc/tls.go | 6 +- server/server/rpc/tls_cert_loader_test.go | 106 +++++++++++++--------- 2 files changed, 67 insertions(+), 45 deletions(-) diff --git a/server/server/rpc/tls.go b/server/server/rpc/tls.go index 382bb3b6ab..3e53e4f2ba 100644 --- a/server/server/rpc/tls.go +++ b/server/server/rpc/tls.go @@ -61,17 +61,17 @@ type certLoader struct { } func (l *certLoader) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { - stat, err := os.Stat(l.KeyFile) + stat, err := os.Stat(l.CertFile) if err != nil { l.lock.RLock() existingCert := l.cachedCert l.lock.RUnlock() if existingCert == nil { - return nil, fmt.Errorf("statting tls key file: %w", err) + return nil, fmt.Errorf("statting tls cert file: %w", err) } - log.Printf("unable to stat tls key file, returning cached cert which may expire: %s", err) + log.Printf("unable to stat tls cert file, returning cached cert which may expire: %s", err) return existingCert, nil } diff --git a/server/server/rpc/tls_cert_loader_test.go b/server/server/rpc/tls_cert_loader_test.go index 0712a86357..4ab6543100 100644 --- a/server/server/rpc/tls_cert_loader_test.go +++ b/server/server/rpc/tls_cert_loader_test.go @@ -17,58 +17,84 @@ import ( ) func TestCertLoader_ReloadsNewKeyPair(t *testing.T) { - // Use Go testing temporary directory for test files - dir := t.TempDir() - certPath := filepath.Join(dir, "cert.pem") - keyPath := filepath.Join(dir, "key.pem") + tests := []struct { + name string + reuseKey bool + }{ + { + name: "regenerate both cert and key", + reuseKey: false, + }, + { + name: "regenerate only cert with same key", + reuseKey: true, + }, + } - // Write initial certificate and key - certPEM1, keyPEM1 := generateCertKeyPair(t, "initial") - assert.NoError(t, os.WriteFile(certPath, certPEM1, 0644)) - assert.NoError(t, os.WriteFile(keyPath, keyPEM1, 0644)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") - // Initialize the loader - loader := &certLoader{CertFile: certPath, KeyFile: keyPath} + certPEM1, keyPEM1 := generateCertKeyPair(t, "initial") + assert.NoError(t, os.WriteFile(certPath, certPEM1, 0644)) + assert.NoError(t, os.WriteFile(keyPath, keyPEM1, 0644)) - // First load should read the initial pair - loaded1, err := loader.GetClientCertificate(nil) - assert.NoError(t, err) - assert.NotNil(t, loaded1) + loader := &certLoader{CertFile: certPath, KeyFile: keyPath} - // Compare against a direct X509KeyPair parse - expect1, err := tls.X509KeyPair(certPEM1, keyPEM1) - assert.NoError(t, err) - assert.Equal(t, expect1.Certificate, loaded1.Certificate) + loaded1, err := loader.GetClientCertificate(nil) + assert.NoError(t, err) + assert.NotNil(t, loaded1) - // Wait to ensure file modification time will differ - time.Sleep(500 * time.Millisecond) + expect1, err := tls.X509KeyPair(certPEM1, keyPEM1) + assert.NoError(t, err) + assert.Equal(t, expect1.Certificate, loaded1.Certificate) - // Overwrite with a new certificate and key - certPEM2, keyPEM2 := generateCertKeyPair(t, "updated") - assert.NoError(t, os.WriteFile(certPath, certPEM2, 0644)) - assert.NoError(t, os.WriteFile(keyPath, keyPEM2, 0644)) + time.Sleep(500 * time.Millisecond) - // Second load should pick up the updated pair - loaded2, err := loader.GetClientCertificate(nil) - assert.NoError(t, err) - assert.NotNil(t, loaded2) + var certPEM2, keyPEM2 []byte + if tt.reuseKey { + certPEM2 = generateCertForKey(t, "updated", keyPEM1) + keyPEM2 = keyPEM1 + } else { + certPEM2, keyPEM2 = generateCertKeyPair(t, "updated") + } - expect2, err := tls.X509KeyPair(certPEM2, keyPEM2) - assert.NoError(t, err) + assert.NoError(t, os.WriteFile(certPath, certPEM2, 0644)) + + if !tt.reuseKey { + assert.NoError(t, os.WriteFile(keyPath, keyPEM2, 0644)) + } - // Compare the loaded certificate with the expected one - assert.Equal(t, expect2.Certificate, loaded2.Certificate) + loaded2, err := loader.GetClientCertificate(nil) + assert.NoError(t, err) + assert.NotNil(t, loaded2) - // Ensure the loader did not return the old certificate - assert.NotEqual(t, expect1.Certificate, loaded2.Certificate) + expect2, err := tls.X509KeyPair(certPEM2, keyPEM2) + assert.NoError(t, err) + + assert.Equal(t, expect2.Certificate, loaded2.Certificate) + assert.Equal(t, expect2.PrivateKey, loaded2.PrivateKey) + assert.NotEqual(t, expect1.Certificate, loaded2.Certificate) + }) + } } -// generateCertKeyPair creates a self-signed certificate and private key for testing. func generateCertKeyPair(t *testing.T, commonName string) (certPEM, keyPEM []byte) { - // Generate a private key key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - // Create a certificate template + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + certPEM = generateCertForKey(t, commonName, keyPEM) + return certPEM, keyPEM +} + +func generateCertForKey(t *testing.T, commonName string, existingKeyPEM []byte) []byte { + block, _ := pem.Decode(existingKeyPEM) + assert.NotNil(t, block) + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + assert.NoError(t, err) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) assert.NoError(t, err) @@ -82,11 +108,7 @@ func generateCertKeyPair(t *testing.T, commonName string) (certPEM, keyPEM []byt IsCA: true, BasicConstraintsValid: true, } - // Self-sign the certificate derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &key.PublicKey, key) assert.NoError(t, err) - // PEM encode the certificate and key - certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - return certPEM, keyPEM + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) }