Skip to content

Commit

Permalink
Fix race condition in pkg/apiserver/certificate unit tests (antrea-io…
Browse files Browse the repository at this point in the history
…#6004)

There was some "interference" between TestSelfSignedCertProviderRotate
and TestSelfSignedCertProviderRun. The root cause is that the
certutil.GenerateSelfSignedCertKey does not support a custom clock
implementation and always calls time.Now() to determine the current
time. It then adds a year to the current time to set the expiration time
of the certificate. This means that when rotateSelfSignedCertificate()
is called as part of TestSelfSignedCertProviderRotate, the new
certificate is already expired, and rotateSelfSignedCertificate() will
be called immediately a second time. By this time however,
TestSelfSignedCertProviderRotate has already exited, and we are already
running the next test, TestSelfSignedCertProviderRun. This creates a
race condition because the next test will overwrite
generateSelfSignedCertKey with a mock version, right as it is called by
the second call to rotateSelfSignedCertificate() from the previous
test's provider.

To avoid this race condition, we make generateSelfSignedCertKey a member
of selfSignedCertProvider.

Fixes antrea-io#5977

Signed-off-by: Antonin Bas <[email protected]>
Co-authored-by: Quan Tian <[email protected]>
antoninbas and tnqn authored Feb 23, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 5789889 commit 56f2fb5
Showing 2 changed files with 44 additions and 29 deletions.
47 changes: 35 additions & 12 deletions pkg/apiserver/certificate/selfsignedcert_provider.go
Original file line number Diff line number Diff line change
@@ -44,11 +44,11 @@ import (
"antrea.io/antrea/pkg/util/env"
)

var (
loopbackAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.IPv6loopback}
// Declared for unit testing.
generateSelfSignedCertKey = certutil.GenerateSelfSignedCertKey
)
var loopbackAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.IPv6loopback}

// generateSelfSignedCertKeyFn represents a function which can create a self-signed certificate and
// key for the given host.
type generateSelfSignedCertKeyFn func(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error)

type selfSignedCertProvider struct {
client kubernetes.Interface
@@ -69,23 +69,46 @@ type selfSignedCertProvider struct {
cert []byte
key []byte
verifyOptions *x509.VerifyOptions

// generateSelfSignedCertKey is the function used to generate self-signed certificates and keys.
// We use a struct member for unit testing.
generateSelfSignedCertKey generateSelfSignedCertKeyFn
}

var _ dynamiccertificates.CAContentProvider = &selfSignedCertProvider{}
var _ dynamiccertificates.ControllerRunner = &selfSignedCertProvider{}

func newSelfSignedCertProvider(client kubernetes.Interface, secureServing *options.SecureServingOptionsWithLoopback, caConfig *CAConfig) (*selfSignedCertProvider, error) {
type providerOption func(p *selfSignedCertProvider)

func withGenerateSelfSignedCertKeyFn(fn generateSelfSignedCertKeyFn) providerOption {
return func(p *selfSignedCertProvider) {
p.generateSelfSignedCertKey = fn
}
}

func withClock(clock clockutils.Clock) providerOption {
return func(p *selfSignedCertProvider) {
p.clock = clock
}
}

func newSelfSignedCertProvider(client kubernetes.Interface, secureServing *options.SecureServingOptionsWithLoopback, caConfig *CAConfig, options ...providerOption) (*selfSignedCertProvider, error) {
// Set the CertKey and CertDirectory to generate the certificate files.
secureServing.ServerCert.CertDirectory = caConfig.SelfSignedCertDir
secureServing.ServerCert.CertKey.CertFile = filepath.Join(caConfig.SelfSignedCertDir, caConfig.PairName+".crt")
secureServing.ServerCert.CertKey.KeyFile = filepath.Join(caConfig.SelfSignedCertDir, caConfig.PairName+".key")

provider := &selfSignedCertProvider{
client: client,
secureServing: secureServing,
caConfig: caConfig,
queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "selfSignedCertProvider"),
clock: clockutils.RealClock{},
client: client,
secureServing: secureServing,
caConfig: caConfig,
queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "selfSignedCertProvider"),
clock: clockutils.RealClock{},
generateSelfSignedCertKey: certutil.GenerateSelfSignedCertKey,
}

for _, option := range options {
option(provider)
}

if caConfig.TLSSecretName != "" {
@@ -233,7 +256,7 @@ func (p *selfSignedCertProvider) rotateSelfSignedCertificate() error {
}
if p.shouldRotateCertificate(cert) {
klog.InfoS("Generating self-signed cert")
if cert, key, err = generateSelfSignedCertKey(p.caConfig.ServiceName, loopbackAddresses, GetAntreaServerNames(p.caConfig.ServiceName)); err != nil {
if cert, key, err = p.generateSelfSignedCertKey(p.caConfig.ServiceName, loopbackAddresses, GetAntreaServerNames(p.caConfig.ServiceName)); err != nil {
return fmt.Errorf("unable to generate self-signed cert: %v", err)
}
// If Secret is specified, we should save the new certificate and key to it.
26 changes: 9 additions & 17 deletions pkg/apiserver/certificate/selfsignedcert_provider_test.go
Original file line number Diff line number Diff line change
@@ -48,7 +48,7 @@ var (
testOneYearCert3, testOneYearKey3, _ = certutil.GenerateSelfSignedCertKeyWithFixtures("localhost", loopbackAddresses, nil, "")
)

func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset, tlsSecretName string, minValidDuration time.Duration) *selfSignedCertProvider {
func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset, tlsSecretName string, minValidDuration time.Duration, options ...providerOption) *selfSignedCertProvider {
secureServing := genericoptions.NewSecureServingOptions().WithLoopback()
caConfig := &CAConfig{
TLSSecretName: tlsSecretName,
@@ -57,7 +57,7 @@ func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset
ServiceName: testServiceName,
PairName: testPairName,
}
p, err := newSelfSignedCertProvider(client, secureServing, caConfig)
p, err := newSelfSignedCertProvider(client, secureServing, caConfig, options...)
require.NoError(t, err)
return p
}
@@ -107,8 +107,7 @@ func TestSelfSignedCertProviderRotate(t *testing.T) {
defer cancel()
client := fakeclientset.NewSimpleClientset()
fakeClock := clocktesting.NewFakeClock(time.Now())
p := newTestSelfSignedCertProvider(t, client, testSecretName, time.Hour*24*90)
p.clock = fakeClock
p := newTestSelfSignedCertProvider(t, client, testSecretName, time.Hour*24*90, withClock(fakeClock))
certInFile, err := os.ReadFile(p.secureServing.ServerCert.CertKey.CertFile)
require.NoError(t, err)
keyInFile, _ := os.ReadFile(p.secureServing.ServerCert.CertKey.KeyFile)
@@ -161,7 +160,7 @@ func TestSelfSignedCertProviderRotate(t *testing.T) {
assert.NotEqual(c, map[string][]byte{
corev1.TLSCertKey: testOneYearCert,
corev1.TLSPrivateKeyKey: testOneYearKey,
}, gotSecret.Data, "Secret doesn't match")
}, gotSecret.Data, "Secret should not match")
}, 2*time.Second, 50*time.Millisecond)
}

@@ -264,15 +263,18 @@ func TestSelfSignedCertProviderRun(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer mockGenerateSelfSignedCertKey(testOneYearCert2, testOneYearKey2)()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var objs []runtime.Object
if tt.existingSecret != nil {
objs = append(objs, tt.existingSecret)
}
client := fakeclientset.NewSimpleClientset(objs...)
p := newTestSelfSignedCertProvider(t, client, tt.tlsSecretName, tt.minValidDuration)
// mock the generateSelfSignedCertKey fuction
generateSelfSignedCertKey := func(_ string, _ []net.IP, _ []string) ([]byte, []byte, error) {
return testOneYearCert2, testOneYearKey2, nil
}
p := newTestSelfSignedCertProvider(t, client, tt.tlsSecretName, tt.minValidDuration, withGenerateSelfSignedCertKeyFn(generateSelfSignedCertKey))
go p.Run(ctx, 1)
if tt.updatedSecret != nil {
client.CoreV1().Secrets(tt.updatedSecret.Namespace).Update(ctx, tt.updatedSecret, metav1.UpdateOptions{})
@@ -291,13 +293,3 @@ func TestSelfSignedCertProviderRun(t *testing.T) {
})
}
}

func mockGenerateSelfSignedCertKey(cert, key []byte) func() {
originalFn := generateSelfSignedCertKey
generateSelfSignedCertKey = func(_ string, _ []net.IP, _ []string) ([]byte, []byte, error) {
return cert, key, nil
}
return func() {
generateSelfSignedCertKey = originalFn
}
}

0 comments on commit 56f2fb5

Please sign in to comment.