diff --git a/server/server/rpc/tls.go b/server/server/rpc/tls.go index 382bb3b6ab..3e53e4f2ba 100644 --- a/server/server/rpc/tls.go +++ b/server/server/rpc/tls.go @@ -61,17 +61,17 @@ type certLoader struct { } func (l *certLoader) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { - stat, err := os.Stat(l.KeyFile) + stat, err := os.Stat(l.CertFile) if err != nil { l.lock.RLock() existingCert := l.cachedCert l.lock.RUnlock() if existingCert == nil { - return nil, fmt.Errorf("statting tls key file: %w", err) + return nil, fmt.Errorf("statting tls cert file: %w", err) } - log.Printf("unable to stat tls key file, returning cached cert which may expire: %s", err) + log.Printf("unable to stat tls cert file, returning cached cert which may expire: %s", err) return existingCert, nil } diff --git a/server/server/rpc/tls_cert_loader_test.go b/server/server/rpc/tls_cert_loader_test.go index 0712a86357..4ab6543100 100644 --- a/server/server/rpc/tls_cert_loader_test.go +++ b/server/server/rpc/tls_cert_loader_test.go @@ -17,58 +17,84 @@ import ( ) func TestCertLoader_ReloadsNewKeyPair(t *testing.T) { - // Use Go testing temporary directory for test files - dir := t.TempDir() - certPath := filepath.Join(dir, "cert.pem") - keyPath := filepath.Join(dir, "key.pem") + tests := []struct { + name string + reuseKey bool + }{ + { + name: "regenerate both cert and key", + reuseKey: false, + }, + { + name: "regenerate only cert with same key", + reuseKey: true, + }, + } - // Write initial certificate and key - certPEM1, keyPEM1 := generateCertKeyPair(t, "initial") - assert.NoError(t, os.WriteFile(certPath, certPEM1, 0644)) - assert.NoError(t, os.WriteFile(keyPath, keyPEM1, 0644)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") - // Initialize the loader - loader := &certLoader{CertFile: certPath, KeyFile: keyPath} + certPEM1, keyPEM1 := generateCertKeyPair(t, "initial") + assert.NoError(t, os.WriteFile(certPath, certPEM1, 0644)) + assert.NoError(t, os.WriteFile(keyPath, keyPEM1, 0644)) - // First load should read the initial pair - loaded1, err := loader.GetClientCertificate(nil) - assert.NoError(t, err) - assert.NotNil(t, loaded1) + loader := &certLoader{CertFile: certPath, KeyFile: keyPath} - // Compare against a direct X509KeyPair parse - expect1, err := tls.X509KeyPair(certPEM1, keyPEM1) - assert.NoError(t, err) - assert.Equal(t, expect1.Certificate, loaded1.Certificate) + loaded1, err := loader.GetClientCertificate(nil) + assert.NoError(t, err) + assert.NotNil(t, loaded1) - // Wait to ensure file modification time will differ - time.Sleep(500 * time.Millisecond) + expect1, err := tls.X509KeyPair(certPEM1, keyPEM1) + assert.NoError(t, err) + assert.Equal(t, expect1.Certificate, loaded1.Certificate) - // Overwrite with a new certificate and key - certPEM2, keyPEM2 := generateCertKeyPair(t, "updated") - assert.NoError(t, os.WriteFile(certPath, certPEM2, 0644)) - assert.NoError(t, os.WriteFile(keyPath, keyPEM2, 0644)) + time.Sleep(500 * time.Millisecond) - // Second load should pick up the updated pair - loaded2, err := loader.GetClientCertificate(nil) - assert.NoError(t, err) - assert.NotNil(t, loaded2) + var certPEM2, keyPEM2 []byte + if tt.reuseKey { + certPEM2 = generateCertForKey(t, "updated", keyPEM1) + keyPEM2 = keyPEM1 + } else { + certPEM2, keyPEM2 = generateCertKeyPair(t, "updated") + } - expect2, err := tls.X509KeyPair(certPEM2, keyPEM2) - assert.NoError(t, err) + assert.NoError(t, os.WriteFile(certPath, certPEM2, 0644)) + + if !tt.reuseKey { + assert.NoError(t, os.WriteFile(keyPath, keyPEM2, 0644)) + } - // Compare the loaded certificate with the expected one - assert.Equal(t, expect2.Certificate, loaded2.Certificate) + loaded2, err := loader.GetClientCertificate(nil) + assert.NoError(t, err) + assert.NotNil(t, loaded2) - // Ensure the loader did not return the old certificate - assert.NotEqual(t, expect1.Certificate, loaded2.Certificate) + expect2, err := tls.X509KeyPair(certPEM2, keyPEM2) + assert.NoError(t, err) + + assert.Equal(t, expect2.Certificate, loaded2.Certificate) + assert.Equal(t, expect2.PrivateKey, loaded2.PrivateKey) + assert.NotEqual(t, expect1.Certificate, loaded2.Certificate) + }) + } } -// generateCertKeyPair creates a self-signed certificate and private key for testing. func generateCertKeyPair(t *testing.T, commonName string) (certPEM, keyPEM []byte) { - // Generate a private key key, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) - // Create a certificate template + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + certPEM = generateCertForKey(t, commonName, keyPEM) + return certPEM, keyPEM +} + +func generateCertForKey(t *testing.T, commonName string, existingKeyPEM []byte) []byte { + block, _ := pem.Decode(existingKeyPEM) + assert.NotNil(t, block) + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + assert.NoError(t, err) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) assert.NoError(t, err) @@ -82,11 +108,7 @@ func generateCertKeyPair(t *testing.T, commonName string) (certPEM, keyPEM []byt IsCA: true, BasicConstraintsValid: true, } - // Self-sign the certificate derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &key.PublicKey, key) assert.NoError(t, err) - // PEM encode the certificate and key - certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - return certPEM, keyPEM + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) }