Skip to content

Commit

Permalink
refactor: separate cert operations and webhook cert manager (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
zoumo authored Dec 24, 2024
1 parent 131ac3d commit 8618d81
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 113 deletions.
14 changes: 12 additions & 2 deletions webhook/cert/cert.go → cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ type (
AltNames = cert.AltNames
)

// ServingCerts is a set of serving certificates.
type ServingCerts struct {
Key []byte
Cert []byte
CAKey []byte
CACert []byte
}

// Validate checks if the serving certificates are valid for given host.
func (c *ServingCerts) Validate(host string) error {
if len(c.Key) == 0 {
return fmt.Errorf("private key is empty")
Expand Down Expand Up @@ -75,6 +77,7 @@ func (c *ServingCerts) Validate(host string) error {
return err
}

// GenerateSelfSignedCerts generates a self-signed certificate and key for the given host.
func GenerateSelfSignedCerts(cfg Config) (*ServingCerts, error) {
caKey, caCert, key, cert, err := generateSelfSignedCertKey(cfg)
if err != nil {
Expand All @@ -94,15 +97,22 @@ func GenerateSelfSignedCerts(cfg Config) (*ServingCerts, error) {
}, nil
}

// GenerateSelfSignedCertKeyIfNotExist generates a self-signed certificate and
// write them to the given path if not exist.
func GenerateSelfSignedCertKeyIfNotExist(path string, cfg cert.Config) error {
fscerts, err := NewFSProvider(path, FSOptions{})
fscerts, err := NewFSCertProvider(path, FSOptions{})
if err != nil {
return err
}
return fscerts.Ensure(context.Background(), cfg)
_, err = fscerts.Ensure(context.Background(), cfg)
return err
}

func generateSelfSignedCertKey(cfg Config) (*rsa.PrivateKey, *x509.Certificate, *rsa.PrivateKey, *x509.Certificate, error) {
if len(cfg.CommonName) == 0 {
return nil, nil, nil, nil, fmt.Errorf("common name is empty")
}

caKey, err := certutil.NewRSAPrivateKey()
if err != nil {
return nil, nil, nil, nil, err
Expand Down
File renamed without changes.
2 changes: 2 additions & 0 deletions webhook/cert/error.go → cert/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ func newNotFound(name string, err error) error {
return fmt.Errorf("%s %w: %v", name, errNotFound, err)
}

// IsNotFound returns true if certificate not found.
func IsNotFound(err error) bool {
return apierrors.IsNotFound(err) || errors.Is(err, errNotFound)
}

// IsConflict returns true if certificate is already exist.
func IsConflict(err error) bool {
return apierrors.IsAlreadyExists(err) || apierrors.IsConflict(err)
}
File renamed without changes.
43 changes: 23 additions & 20 deletions webhook/cert/fs.go → cert/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,7 @@ import (
"k8s.io/klog/v2"
)

type FSProvider struct {
FSOptions
path string
}

// FSOptions is the options for FSCertProvider.
type FSOptions struct {
FS afero.Fs
CertName string
Expand All @@ -58,49 +54,56 @@ func (o *FSOptions) setDefaults() {
}
}

func NewFSProvider(path string, opts FSOptions) (*FSProvider, error) {
// FSCertProvider is a CertProvider that stores certificates on the local filesystem.
type FSCertProvider struct {
FSOptions
path string
}

// NewFSCertProvider creates a new FSCertProvider.
func NewFSCertProvider(path string, opts FSOptions) (*FSCertProvider, error) {
opts.setDefaults()

if len(path) == 0 {
return nil, fmt.Errorf("cert path is required")
}

return &FSProvider{
return &FSCertProvider{
path: path,
FSOptions: opts,
}, nil
}

func (p *FSProvider) Ensure(_ context.Context, cfg Config) error {
certs, err := p.Load()
func (p *FSCertProvider) Ensure(ctx context.Context, cfg Config) (*ServingCerts, error) {
certs, err := p.Load(ctx)
if err != nil && !IsNotFound(err) {
return err
return nil, err
}

if IsNotFound(err) {
certs, err = GenerateSelfSignedCerts(cfg)
if err != nil {
return err
return nil, err
}
_, err = p.Overwrite(certs)
return err
return certs, err
}

err = certs.Validate(cfg.CommonName)
if err != nil {
// re-generate if expired or invalid
klog.Info("certificates are invalid, regenerating...")
certs, err := GenerateSelfSignedCerts(cfg)
certs, err = GenerateSelfSignedCerts(cfg)
if err != nil {
return err
return nil, err
}
_, err = p.Overwrite(certs)
return err
return certs, err
}
return nil
return certs, nil
}

func (p *FSProvider) checkIfExist() error {
func (p *FSCertProvider) checkIfExist() error {
files := []string{
path.Join(p.path, p.KeyName),
path.Join(p.path, p.CertName),
Expand All @@ -121,7 +124,7 @@ func (p *FSProvider) checkIfExist() error {
return nil
}

func (p *FSProvider) Load() (*ServingCerts, error) {
func (p *FSCertProvider) Load(_ context.Context) (*ServingCerts, error) {
err := p.checkIfExist()
if err != nil {
return nil, err
Expand Down Expand Up @@ -154,7 +157,7 @@ func (p *FSProvider) Load() (*ServingCerts, error) {
return certs, nil
}

func (p *FSProvider) Overwrite(certs *ServingCerts) (bool, error) {
func (p *FSCertProvider) Overwrite(certs *ServingCerts) (bool, error) {
if certs == nil {
return false, fmt.Errorf("certs are required")
}
Expand Down Expand Up @@ -206,7 +209,7 @@ func (p *FSProvider) Overwrite(certs *ServingCerts) (bool, error) {
return updated, nil
}

func (p *FSProvider) writeFile(path string, data []byte) (bool, error) {
func (p *FSCertProvider) writeFile(path string, data []byte) (bool, error) {
_, err := p.FS.Stat(path)
if err != nil && !os.IsNotExist(err) {
return false, err
Expand Down
45 changes: 45 additions & 0 deletions cert/fs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/**
* Copyright 2024 The KusionStack Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cert

import (
"context"
"testing"

"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
)

func TestFSProvider_Ensure(t *testing.T) {
dir := "/serving/cert"
fs := afero.NewMemMapFs()
provider, _ := NewFSCertProvider(dir, FSOptions{
FS: fs,
})

domains := []string{"one.kusionstack.io", "two.kusionstack.io"}
for _, domain := range domains {
certs, err := provider.Ensure(context.Background(), Config{CommonName: domain})
assert.NoError(t, err)
certs.Validate(domain)
assert.NotNil(t, certs)
certs, err = provider.Load(context.Background())
assert.NoError(t, err)
certs.Validate(domain)
assert.NotNil(t, certs)
}
}
73 changes: 60 additions & 13 deletions webhook/cert/secret.go → cert/secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ import (
"context"
"fmt"

"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"
)
Expand All @@ -34,30 +36,76 @@ const (
TLSCAPrivateKeyKey = "ca.key"
)

type SecretProvider struct {
client SecretClient
namespace string
name string
}

// SecretClient is a client wrapper for secret operations.
type SecretClient interface {
Get(ctx context.Context, namespace string, name string) (*corev1.Secret, error)
Create(ctx context.Context, secret *corev1.Secret) error
Update(ctx context.Context, secret *corev1.Secret) error
}

func NewSecretProvider(client SecretClient, namespace, name string) (*SecretProvider, error) {
var _ SecretClient = &secretClient{}

type secretClient struct {
reader client.Reader
writer client.Writer
}

func NewSecretClient(reader client.Reader, writer client.Writer) SecretClient {
return &secretClient{
reader: reader,
writer: writer,
}
}

// Create implements SecretClient.
func (s *secretClient) Create(ctx context.Context, secret *corev1.Secret) error {
err := s.writer.Create(ctx, secret)
if err == nil {
logger := logr.FromContextOrDiscard(ctx)
logger.Info("create secret successfully", "namespace", secret.Namespace, "name", secret.Name)
}
return err
}

// Get implements SecretClient.
func (s *secretClient) Get(ctx context.Context, namespace string, name string) (*corev1.Secret, error) {
var secret corev1.Secret
err := s.reader.Get(ctx, types.NamespacedName{Namespace: namespace, Name: name}, &secret)
if err != nil {
return nil, err
}
return &secret, nil
}

// Update implements SecretClient.
func (s *secretClient) Update(ctx context.Context, secret *corev1.Secret) error {
err := s.writer.Update(ctx, secret)
if err == nil {
logger := logr.FromContextOrDiscard(ctx)
logger.Info("update secret successfully", "namespace", secret.Namespace, "name", secret.Name)
}
return err
}

// SecretCertProvider is a provider for operating certs in k8s secret.
type SecretCertProvider struct {
client SecretClient
namespace string
name string
}

func NewSecretCertProvider(client SecretClient, namespace, name string) (*SecretCertProvider, error) {
if client == nil {
return nil, fmt.Errorf("secret client must not be nil")
}
return &SecretProvider{
return &SecretCertProvider{
client: client,
namespace: namespace,
name: name,
}, nil
}

func (p *SecretProvider) Ensure(ctx context.Context, cfg Config) (*ServingCerts, error) {
func (p *SecretCertProvider) Ensure(ctx context.Context, cfg Config) (*ServingCerts, error) {
certs, err := p.Load(ctx)
if err != nil && !IsNotFound(err) {
return nil, err
Expand Down Expand Up @@ -91,16 +139,15 @@ func (p *SecretProvider) Ensure(ctx context.Context, cfg Config) (*ServingCerts,
return certs, nil
}

func (p *SecretProvider) Load(ctx context.Context) (*ServingCerts, error) {
func (p *SecretCertProvider) Load(ctx context.Context) (*ServingCerts, error) {
secret, err := p.client.Get(ctx, p.namespace, p.name)
if err != nil {
return nil, err
}

return convertSecretToCerts(secret), nil
}

func (p *SecretProvider) create(ctx context.Context, certs *ServingCerts) error {
func (p *SecretCertProvider) create(ctx context.Context, certs *ServingCerts) error {
if certs == nil {
return fmt.Errorf("certs are required")
}
Expand All @@ -118,7 +165,7 @@ func (p *SecretProvider) create(ctx context.Context, certs *ServingCerts) error
return p.client.Create(ctx, secret)
}

func (p *SecretProvider) overwrite(ctx context.Context, certs *ServingCerts) error {
func (p *SecretCertProvider) overwrite(ctx context.Context, certs *ServingCerts) error {
if certs == nil {
return fmt.Errorf("certs are required")
}
Expand Down
Loading

0 comments on commit 8618d81

Please sign in to comment.