Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions server/server/rpc/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ type certLoader struct {
}

func (l *certLoader) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
stat, err := os.Stat(l.KeyFile)
stat, err := os.Stat(l.CertFile)
if err != nil {
l.lock.RLock()
existingCert := l.cachedCert
l.lock.RUnlock()

if existingCert == nil {
return nil, fmt.Errorf("statting tls key file: %w", err)
return nil, fmt.Errorf("statting tls cert file: %w", err)
}

log.Printf("unable to stat tls key file, returning cached cert which may expire: %s", err)
log.Printf("unable to stat tls cert file, returning cached cert which may expire: %s", err)
return existingCert, nil
}

Expand Down
106 changes: 64 additions & 42 deletions server/server/rpc/tls_cert_loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,58 +17,84 @@ import (
)

func TestCertLoader_ReloadsNewKeyPair(t *testing.T) {
// Use Go testing temporary directory for test files
dir := t.TempDir()
certPath := filepath.Join(dir, "cert.pem")
keyPath := filepath.Join(dir, "key.pem")
tests := []struct {
name string
reuseKey bool
}{
{
name: "regenerate both cert and key",
reuseKey: false,
},
{
name: "regenerate only cert with same key",
reuseKey: true,
},
}

// Write initial certificate and key
certPEM1, keyPEM1 := generateCertKeyPair(t, "initial")
assert.NoError(t, os.WriteFile(certPath, certPEM1, 0644))
assert.NoError(t, os.WriteFile(keyPath, keyPEM1, 0644))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
certPath := filepath.Join(dir, "cert.pem")
keyPath := filepath.Join(dir, "key.pem")

// Initialize the loader
loader := &certLoader{CertFile: certPath, KeyFile: keyPath}
certPEM1, keyPEM1 := generateCertKeyPair(t, "initial")
assert.NoError(t, os.WriteFile(certPath, certPEM1, 0644))
assert.NoError(t, os.WriteFile(keyPath, keyPEM1, 0644))

// First load should read the initial pair
loaded1, err := loader.GetClientCertificate(nil)
assert.NoError(t, err)
assert.NotNil(t, loaded1)
loader := &certLoader{CertFile: certPath, KeyFile: keyPath}

// Compare against a direct X509KeyPair parse
expect1, err := tls.X509KeyPair(certPEM1, keyPEM1)
assert.NoError(t, err)
assert.Equal(t, expect1.Certificate, loaded1.Certificate)
loaded1, err := loader.GetClientCertificate(nil)
assert.NoError(t, err)
assert.NotNil(t, loaded1)

// Wait to ensure file modification time will differ
time.Sleep(500 * time.Millisecond)
expect1, err := tls.X509KeyPair(certPEM1, keyPEM1)
assert.NoError(t, err)
assert.Equal(t, expect1.Certificate, loaded1.Certificate)

// Overwrite with a new certificate and key
certPEM2, keyPEM2 := generateCertKeyPair(t, "updated")
assert.NoError(t, os.WriteFile(certPath, certPEM2, 0644))
assert.NoError(t, os.WriteFile(keyPath, keyPEM2, 0644))
time.Sleep(500 * time.Millisecond)

// Second load should pick up the updated pair
loaded2, err := loader.GetClientCertificate(nil)
assert.NoError(t, err)
assert.NotNil(t, loaded2)
var certPEM2, keyPEM2 []byte
if tt.reuseKey {
certPEM2 = generateCertForKey(t, "updated", keyPEM1)
keyPEM2 = keyPEM1
} else {
certPEM2, keyPEM2 = generateCertKeyPair(t, "updated")
}

expect2, err := tls.X509KeyPair(certPEM2, keyPEM2)
assert.NoError(t, err)
assert.NoError(t, os.WriteFile(certPath, certPEM2, 0644))

if !tt.reuseKey {
assert.NoError(t, os.WriteFile(keyPath, keyPEM2, 0644))
}

// Compare the loaded certificate with the expected one
assert.Equal(t, expect2.Certificate, loaded2.Certificate)
loaded2, err := loader.GetClientCertificate(nil)
assert.NoError(t, err)
assert.NotNil(t, loaded2)

// Ensure the loader did not return the old certificate
assert.NotEqual(t, expect1.Certificate, loaded2.Certificate)
expect2, err := tls.X509KeyPair(certPEM2, keyPEM2)
assert.NoError(t, err)

assert.Equal(t, expect2.Certificate, loaded2.Certificate)
assert.Equal(t, expect2.PrivateKey, loaded2.PrivateKey)
assert.NotEqual(t, expect1.Certificate, loaded2.Certificate)
})
}
}

// generateCertKeyPair creates a self-signed certificate and private key for testing.
func generateCertKeyPair(t *testing.T, commonName string) (certPEM, keyPEM []byte) {
// Generate a private key
key, err := rsa.GenerateKey(rand.Reader, 2048)
assert.NoError(t, err)
// Create a certificate template
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM = generateCertForKey(t, commonName, keyPEM)
return certPEM, keyPEM
}

func generateCertForKey(t *testing.T, commonName string, existingKeyPEM []byte) []byte {
block, _ := pem.Decode(existingKeyPEM)
assert.NotNil(t, block)
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
assert.NoError(t, err)

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
assert.NoError(t, err)
Expand All @@ -82,11 +108,7 @@ func generateCertKeyPair(t *testing.T, commonName string) (certPEM, keyPEM []byt
IsCA: true,
BasicConstraintsValid: true,
}
// Self-sign the certificate
derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &key.PublicKey, key)
assert.NoError(t, err)
// PEM encode the certificate and key
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
return certPEM, keyPEM
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
}
Loading