From 3dcad43bf4aa0c48620e78bc3594cd2e8763c601 Mon Sep 17 00:00:00 2001 From: "Kennelly, Martin" Date: Tue, 19 Jan 2021 12:22:05 +0000 Subject: [PATCH] Catch termination signal In order to catch termination signals, changes to how we handle cert/key watcher & HTTP server are needed. Refactor code to use goroutines for TLS cert/key watcher and HTTP server. Add Channel to safely manage signals from goroutines. Add interfaces to aid testing with mockery. Add tests to cover new code changes. Signed-off-by: Kennelly, Martin --- cmd/webhook/main.go | 119 ++++++------------------------ pkg/types/mocks/ClientCAPool.go | 44 +++++++++++ pkg/types/mocks/KeyReloader.go | 72 ++++++++++++++++++ pkg/types/mocks/Server.go | 42 +++++++++++ pkg/types/types.go | 30 ++++++++ pkg/webhook/channel.go | 92 +++++++++++++++++++++++ pkg/webhook/channel_test.go | 75 +++++++++++++++++++ pkg/webhook/server.go | 116 +++++++++++++++++++++++++++++ pkg/webhook/server_test.go | 80 ++++++++++++++++++++ pkg/webhook/tlsutils.go | 15 +++- pkg/webhook/watcher.go | 125 ++++++++++++++++++++++++++++++++ pkg/webhook/watcher_test.go | 107 +++++++++++++++++++++++++++ pkg/webhook/webhook.go | 63 +++++++++++++--- 13 files changed, 870 insertions(+), 110 deletions(-) create mode 100644 pkg/types/mocks/ClientCAPool.go create mode 100644 pkg/types/mocks/KeyReloader.go create mode 100644 pkg/types/mocks/Server.go create mode 100644 pkg/webhook/channel.go create mode 100644 pkg/webhook/channel_test.go create mode 100644 pkg/webhook/server.go create mode 100644 pkg/webhook/server_test.go create mode 100644 pkg/webhook/watcher.go create mode 100644 pkg/webhook/watcher_test.go diff --git a/cmd/webhook/main.go b/cmd/webhook/main.go index 695c7703..e6cf92e9 100644 --- a/cmd/webhook/main.go +++ b/cmd/webhook/main.go @@ -15,18 +15,21 @@ package main import ( - "crypto/tls" "flag" - "fmt" - "net/http" + "os" "time" - "github.com/fsnotify/fsnotify" "github.com/golang/glog" "github.com/intel/network-resources-injector/pkg/webhook" ) -const defaultClientCa = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" +const ( + defaultClientCa = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" + readTo = 5 * time.Second + writeTo = 10 * time.Second + readHeaderTo = 1 * time.Second + serviceTo = 2 * time.Second +) func main() { var clientCAPaths webhook.ClientCAFlags @@ -56,7 +59,7 @@ func main() { glog.Infof("starting mutating admission controller for network resources injection") - keyPair, err := webhook.NewTlsKeypairReloader(*cert, *key) + keyPair, err := webhook.NewTlsKeyPairReloader(*cert, *key) if err != nil { glog.Fatalf("error load certificate: %s", err.Error()) } @@ -78,97 +81,19 @@ func main() { glog.Fatalf("error in setting resource name keys: %s", err.Error()) } - go func() { - /* register handlers */ - var httpServer *http.Server - - http.HandleFunc("/mutate", func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/mutate" { - http.NotFound(w, r) - return - } - if r.Method != http.MethodPost { - http.Error(w, "Invalid HTTP verb requested", 405) - return - } - webhook.MutateHandler(w, r) - }) - - /* start serving */ - httpServer = &http.Server{ - Addr: fmt.Sprintf("%s:%d", *address, *port), - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - MaxHeaderBytes: 1 << 20, - ReadHeaderTimeout: 1 * time.Second, - TLSConfig: &tls.Config{ - ClientAuth: webhook.GetClientAuth(*insecure), - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384}, - ClientCAs: clientCaPool.GetCertPool(), - PreferServerCipherSuites: true, - InsecureSkipVerify: false, - CipherSuites: []uint16{ - // tls 1.2 - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - // tls 1.3 configuration not supported - }, - GetCertificate: keyPair.GetCertificateFunc(), - }, - } - - err := httpServer.ListenAndServeTLS("", "") - if err != nil { - glog.Fatalf("error starting web server: %v", err) - } - }() - - /* watch the cert file and restart http sever if the file updated. */ - watcher, err := fsnotify.NewWatcher() - if err != nil { - glog.Fatalf("error starting fsnotify watcher: %v", err) + watcher := webhook.NewKeyPairWatcher(keyPair, serviceTo) + if err = watcher.Run(); err != nil { + glog.Fatalf("starting TLS key & cert file watcher failed: '%s'", err.Error()) } - defer watcher.Close() - - certUpdated := false - keyUpdated := false - - for { - watcher.Add(*cert) - watcher.Add(*key) - - select { - case event, ok := <-watcher.Events: - if !ok { - continue - } - glog.Infof("watcher event: %v", event) - mask := fsnotify.Create | fsnotify.Rename | fsnotify.Remove | - fsnotify.Write | fsnotify.Chmod - if (event.Op & mask) != 0 { - glog.Infof("modified file: %v", event.Name) - if event.Name == *cert { - certUpdated = true - } - if event.Name == *key { - keyUpdated = true - } - if keyUpdated && certUpdated { - if err := keyPair.Reload(); err != nil { - glog.Fatalf("Failed to reload certificate: %v", err) - } - certUpdated = false - keyUpdated = false - } - } - case err, ok := <-watcher.Errors: - if !ok { - continue - } - glog.Infof("watcher error: %v", err) - } + + server := webhook.NewMutateServer(*address, *port, *insecure, readTo, writeTo, readHeaderTo, serviceTo, clientCaPool, keyPair) + if err = server.Run(); err != nil { + watcher.Quit() + glog.Fatalf("starting HTTP server failed: '%s'", err) + } + + /* Blocks until termination or TLS key/cert file watcher or HTTP server signal occurs and stops HTTP server/file watcher */ + if err := webhook.Watch(server, watcher, make(chan os.Signal, 1)); err != nil { + glog.Error(err.Error()) } } diff --git a/pkg/types/mocks/ClientCAPool.go b/pkg/types/mocks/ClientCAPool.go new file mode 100644 index 00000000..3cdc2342 --- /dev/null +++ b/pkg/types/mocks/ClientCAPool.go @@ -0,0 +1,44 @@ +// Code generated by mockery v2.5.1. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + + x509 "crypto/x509" +) + +// ClientCAPool is an autogenerated mock type for the ClientCAPool type +type ClientCAPool struct { + mock.Mock +} + +// GetCertPool provides a mock function with given fields: +func (_m *ClientCAPool) GetCertPool() *x509.CertPool { + ret := _m.Called() + + var r0 *x509.CertPool + if rf, ok := ret.Get(0).(func() *x509.CertPool); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*x509.CertPool) + } + } + + return r0 +} + +// Load provides a mock function with given fields: +func (_m *ClientCAPool) Load() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/types/mocks/KeyReloader.go b/pkg/types/mocks/KeyReloader.go new file mode 100644 index 00000000..2c35cc70 --- /dev/null +++ b/pkg/types/mocks/KeyReloader.go @@ -0,0 +1,72 @@ +// Code generated by mockery v2.5.1. DO NOT EDIT. + +package mocks + +import ( + tls "crypto/tls" + + mock "github.com/stretchr/testify/mock" +) + +// KeyReloader is an autogenerated mock type for the KeyReloader type +type KeyReloader struct { + mock.Mock +} + +// GetCertPath provides a mock function with given fields: +func (_m *KeyReloader) GetCertPath() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetCertificateFunc provides a mock function with given fields: +func (_m *KeyReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + ret := _m.Called() + + var r0 func(*tls.ClientHelloInfo) (*tls.Certificate, error) + if rf, ok := ret.Get(0).(func() func(*tls.ClientHelloInfo) (*tls.Certificate, error)); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(func(*tls.ClientHelloInfo) (*tls.Certificate, error)) + } + } + + return r0 +} + +// GetKeyPath provides a mock function with given fields: +func (_m *KeyReloader) GetKeyPath() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Reload provides a mock function with given fields: +func (_m *KeyReloader) Reload() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/types/mocks/Server.go b/pkg/types/mocks/Server.go new file mode 100644 index 00000000..3fdf81c3 --- /dev/null +++ b/pkg/types/mocks/Server.go @@ -0,0 +1,42 @@ +// Code generated by mockery v2.5.1. DO NOT EDIT. + +package mocks + +import ( + time "time" + + mock "github.com/stretchr/testify/mock" +) + +// Server is an autogenerated mock type for the Server type +type Server struct { + mock.Mock +} + +// Start provides a mock function with given fields: +func (_m *Server) Start() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Stop provides a mock function with given fields: to +func (_m *Server) Stop(to time.Duration) error { + ret := _m.Called(to) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Duration) error); ok { + r0 = rf(to) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/types/types.go b/pkg/types/types.go index a085c802..c665ec7f 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -14,6 +14,12 @@ package types +import ( + "crypto/tls" + "crypto/x509" + "time" +) + const ( DownwardAPIMountPath = "/etc/podnetinfo" AnnotationsPath = "annotations" @@ -24,3 +30,27 @@ const ( Hugepages1GLimitPath = "hugepages_1G_limit" Hugepages2MLimitPath = "hugepages_2M_limit" ) + +type ClientCAPool interface { + Load() error + GetCertPool() *x509.CertPool +} + +type KeyReloader interface { + Reload() error + GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) + GetKeyPath() string + GetCertPath() string +} + +//start and stop HTTP server - helps unit tests mocking of HTTP server +type Server interface { + Start() error + Stop(timeout time.Duration) error +} + +type Service interface { + Run() error + Quit() error + StatusSignal() chan struct{} +} diff --git a/pkg/webhook/channel.go b/pkg/webhook/channel.go new file mode 100644 index 00000000..0dec2a64 --- /dev/null +++ b/pkg/webhook/channel.go @@ -0,0 +1,92 @@ +package webhook + +import ( + "errors" + "fmt" + "sync" + "time" +) + +const ( + interval = time.Millisecond * 2 + chBuffer = 1 +) + +//Channel contains fields to safely manage a channel +type Channel struct { + ch chan struct{} + isOpen bool + mutex sync.Mutex +} + +//NewChannel returns an instance of type Channel +func NewChannel() *Channel { + return &Channel{} +} + +//Close checks if channel closed before trying to close +func (c *Channel) Close() { + c.mutex.Lock() + defer c.mutex.Unlock() + if c.isOpen { + close(c.ch) + c.isOpen = false + } +} + +//Open opens a channel +func (c *Channel) Open() { + c.mutex.Lock() + defer c.mutex.Unlock() + c.ch = make(chan struct{}, chBuffer) + c.isOpen = true +} + +//GetCh returns channel +func (c *Channel) GetCh() chan struct{} { + return c.ch +} + +//IsOpen checks if channel is open +func (c *Channel) IsOpen() bool { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.isOpen +} + +//IsClosed checks if channel closed +func (c *Channel) IsClosed() bool { + c.mutex.Lock() + defer c.mutex.Unlock() + return !c.isOpen +} + +//WaitUntilClosed will block until time limit is reached or channel is closed +func (c *Channel) WaitUntilClosed(limit time.Duration) error { + if interval > limit { + return errors.New("limit arg value too low") + } + tEnd := time.Now().Add(limit) + for tEnd.After(time.Now()) { + if c.IsClosed() { + return nil + } + time.Sleep(interval) + } + return fmt.Errorf("timed out after waiting '%s'", limit.String()) +} + +//WaitUntilOpened will block until time limit is reached or channel is opened +func (c *Channel) WaitUntilOpened(limit time.Duration) error { + if interval > limit { + return errors.New("limit arg value too low") + } + tEnd := time.Now().Add(limit) + for tEnd.After(time.Now()) { + if c.IsOpen() { + return nil + } + time.Sleep(interval) + } + return fmt.Errorf("timed out after waiting '%s'", limit.String()) +} diff --git a/pkg/webhook/channel_test.go b/pkg/webhook/channel_test.go new file mode 100644 index 00000000..f72aea57 --- /dev/null +++ b/pkg/webhook/channel_test.go @@ -0,0 +1,75 @@ +package webhook + +import ( + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Status", func() { + var ch *Channel + BeforeEach(func() { + ch = NewChannel() + }) + Context("NewChannel()", func() { + It("should be down by default", func() { + Expect(ch.IsClosed()).To(BeTrue()) + }) + }) + Context("Open()", func() { + It("should set channel open", func() { + ch.Open() + Expect(ch.IsOpen()).To(BeTrue()) + }) + }) + Context("Close()", func() { + It("should close channel", func() { + ch.Open() + ch.Close() + Expect(ch.IsOpen()).To(BeFalse()) + }) + }) + Context("IsOpen()", func() { + It("should return closed by default", func() { + Expect(ch.IsOpen()).To(BeFalse()) + }) + }) + Context("IsClosed()", func() { + It("should return true by default", func() { + ch.Open() + Expect(ch.IsClosed()).To(BeFalse()) + }) + }) + Context("WaitUntilClosed()", func() { + It("should return nil if channel close under limit", func() { + ch.Open() + ch.Close() + Expect(ch.WaitUntilClosed(interval)).To(BeNil()) + }) + It("should return error if channel is open after limit", func() { + ch.Open() + go func() { + time.Sleep(interval * 10) + ch.Close() + }() + Expect(ch.WaitUntilClosed(interval * 5).Error()).To(ContainSubstring("timed out")) + }) + }) + Context("WaitUntilOpened()", func() { + It("should return error if limit below min sleep interval time", func() { + Expect(ch.WaitUntilOpened(time.Nanosecond).Error()).To(ContainSubstring("limit arg value too low")) + }) + It("should return nil if channel open under limit", func() { + ch.Open() + Expect(ch.WaitUntilOpened(interval)).To(BeNil()) + }) + It("should return error if channel didnt open before interval limit", func() { + go func() { + time.Sleep(interval * 10) + ch.Open() + }() + Expect(ch.WaitUntilOpened(interval * 5).Error()).To(ContainSubstring("timed out")) + }) + }) +}) diff --git a/pkg/webhook/server.go b/pkg/webhook/server.go new file mode 100644 index 00000000..f04ccfb1 --- /dev/null +++ b/pkg/webhook/server.go @@ -0,0 +1,116 @@ +package webhook + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "time" + + "github.com/golang/glog" + nri "github.com/intel/network-resources-injector/pkg/types" +) + +const ( + mServerStartupInterval = time.Millisecond * 50 + mServerEndpoint = "/mutate" +) + +type mutateServer struct { + instance nri.Server + timeout time.Duration + status *Channel +} + +//NewMutateServer generate a new server to serve endpoint /mutate. Server will only serve /mutate endpoint and POST +//HTTP verb. When arg insecure is false, it forces client certificate validation based on CAs in argument pool +//otherwise no client certificate validation is required. Various timeout args exist to prevent DOS. Arg keypair contains +//server TLS key/cert +func NewMutateServer(address string, port int, insecure bool, readT, writeT, readHT, to time.Duration, pool nri.ClientCAPool, + keyPair nri.KeyReloader) nri.Service { + if insecure { + glog.Warning("HTTP server is configured not to require client certificate") + } + srvAddr := fmt.Sprintf("%s:%d", address, port) + glog.Infof("HTTP server address and port: '%s'", srvAddr) + mux := http.NewServeMux() + mux.HandleFunc("/mutate", httpServerHandler) + + httpServer := &http.Server{ + Addr: srvAddr, + Handler: mux, + ReadTimeout: readT, + WriteTimeout: writeT, + MaxHeaderBytes: 1 << 20, + ReadHeaderTimeout: readHT, + TLSConfig: &tls.Config{ + ClientAuth: GetClientAuth(insecure), + MinVersion: tls.VersionTLS12, + CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384}, + ClientCAs: pool.GetCertPool(), + PreferServerCipherSuites: true, + InsecureSkipVerify: false, + CipherSuites: []uint16{ + // tls 1.2 + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + // tls 1.3 configuration not supported + }, + GetCertificate: keyPair.GetCertificateFunc(), + }, + } + return &mutateServer{&server{httpServer}, to, nil} +} + +//Run starts HTTP server in goroutine, waits a period of time and returns any potential errors from server start +func (mSrv *mutateServer) Run() error { + var httpSrvMsg error + glog.Info("starting HTTP server") + mSrv.status = NewChannel() + + go func() { + mSrv.status.Open() + defer mSrv.status.Close() + if httpSrvMsg = mSrv.instance.Start(); httpSrvMsg != nil && + httpSrvMsg != http.ErrServerClosed { + glog.Errorf("HTTP server message: '%s'", httpSrvMsg.Error()) + } + glog.Info("HTTP server finished") + }() + //give server time to start and return error if startup failed + time.Sleep(mServerStartupInterval) + return httpSrvMsg +} + +//Quit attempts to shutdown HTTP server and waits for HTTP server status channel to close +func (mSrv *mutateServer) Quit() error { + glog.Info("terminating HTTP server") + if err := mSrv.instance.Stop(mSrv.timeout); err != nil && err != http.ErrServerClosed { + return err + } + return mSrv.status.WaitUntilClosed(mSrv.timeout) +} + +//StatusSignal returns a channel which indicates whether mutate server has ended when channel closes +func (mSrv *mutateServer) StatusSignal() chan struct{} { + return mSrv.status.GetCh() +} + +type server struct { + httpServer *http.Server +} + +//Start wraps around package http ListenAndServeTLS and returns any error. Helps unit testing +func (srv *server) Start() error { + return srv.httpServer.ListenAndServeTLS("", "") +} + +//Stop wraps around package http Shutdown limited in time by timeout arg to and returns any error. Helps unit testing +func (srv *server) Stop(to time.Duration) error { + srv.httpServer.SetKeepAlivesEnabled(false) + serverCtx, cancel := context.WithTimeout(context.Background(), to) + defer cancel() + return srv.httpServer.Shutdown(serverCtx) +} diff --git a/pkg/webhook/server_test.go b/pkg/webhook/server_test.go new file mode 100644 index 00000000..0ea0d9f4 --- /dev/null +++ b/pkg/webhook/server_test.go @@ -0,0 +1,80 @@ +package webhook + +import ( + "errors" + "time" + + nriMocks "github.com/intel/network-resources-injector/pkg/types/mocks" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("mutate HTTP server", func() { + t := GinkgoT() + Describe("Service interface implementation for HTTP server", func() { + const to = time.Millisecond * 50 + var ( + mutateSrv *mutateServer + srvMock *nriMocks.Server + ) + BeforeEach(func() { + mutateSrv = &mutateServer{&nriMocks.Server{}, to, NewChannel()} + srvMock = &nriMocks.Server{} + mutateSrv.instance = srvMock + }) + Context("Run()", func() { + It("should start server", func() { + srvMock.On("Start").Return(nil) + Expect(mutateSrv.Run()).To(BeNil()) + srvMock.AssertCalled(t, "Start") + }) + It("should return error from server if startup generated an error", func() { + expErr := errors.New("bad start of server") + srvMock.On("Start").Return(expErr) + Expect(mutateSrv.Run()).To(Equal(expErr)) + }) + }) + Context("Quit()", func() { + It("should stop server", func() { + srvMock.On("Stop", to).Return(nil) + mutateSrv.status.Close() // Close in advance of call to ensure we do not get timeout error + Expect(mutateSrv.Quit()).To(BeNil()) + srvMock.AssertCalled(t, "Stop", to) + }) + It("should return error if shutdown generated an error", func() { + expErr := errors.New("bad stop of server") + srvMock.On("Stop", to).Return(expErr) + mutateSrv.status.Close() // Close in advance of call to ensure we do not get timeout error + Expect(mutateSrv.Quit()).To(Equal(expErr)) + }) + }) + }) + Describe("creation of new mutate server", func() { + const ( + address = "127.0.0.1" + port = 12345 + to = time.Millisecond * 2 + insecure = false + ) + var ( + pool *nriMocks.ClientCAPool + keyPair *nriMocks.KeyReloader + ) + BeforeEach(func() { + pool = &nriMocks.ClientCAPool{} + pool.On("GetCertPool").Return(nil) + keyPair = &nriMocks.KeyReloader{} + keyPair.On("GetCertificateFunc").Return(nil) + }) + Context("NewMutateServer()", func() { + It("should retrieve cert pool", func() { + NewMutateServer(address, port, insecure, to, to, to, to, pool, keyPair) + pool.AssertCalled(t, "GetCertPool") + }) + It("should retrieve certificate function", func() { + NewMutateServer(address, port, insecure, to, to, to, to, pool, keyPair) + keyPair.AssertCalled(t, "GetCertificateFunc") + }) + }) + }) +}) diff --git a/pkg/webhook/tlsutils.go b/pkg/webhook/tlsutils.go index 274dd3c0..74f67f83 100644 --- a/pkg/webhook/tlsutils.go +++ b/pkg/webhook/tlsutils.go @@ -22,6 +22,7 @@ import ( "sync" "github.com/golang/glog" + "github.com/intel/network-resources-injector/pkg/types" ) type tlsKeypairReloader struct { @@ -68,8 +69,16 @@ func (keyPair *tlsKeypairReloader) GetCertificateFunc() func(*tls.ClientHelloInf } } -// NewTlsKeypairReloader reload tlsKeypairReloader struct -func NewTlsKeypairReloader(certPath, keyPath string) (*tlsKeypairReloader, error) { +func (keyPair *tlsKeypairReloader) GetKeyPath() string { + return keyPair.keyPath +} + +func (keyPair *tlsKeypairReloader) GetCertPath() string { + return keyPair.certPath +} + +//NewTlsKeyPairReloader loads a cert and key +func NewTlsKeyPairReloader(certPath, keyPath string) (types.KeyReloader, error) { result := &tlsKeypairReloader{ certPath: certPath, keyPath: keyPath, @@ -84,7 +93,7 @@ func NewTlsKeypairReloader(certPath, keyPath string) (*tlsKeypairReloader, error } //NewClientCertPool will load a single client CA -func NewClientCertPool(clientCaPaths *ClientCAFlags, insecure bool) (*clientCertPool, error) { +func NewClientCertPool(clientCaPaths *ClientCAFlags, insecure bool) (types.ClientCAPool, error) { pool := &clientCertPool{ certPaths: clientCaPaths, insecure: insecure, diff --git a/pkg/webhook/watcher.go b/pkg/webhook/watcher.go new file mode 100644 index 00000000..dab6fc4a --- /dev/null +++ b/pkg/webhook/watcher.go @@ -0,0 +1,125 @@ +package webhook + +import ( + "errors" + "fmt" + "os" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/golang/glog" + nri "github.com/intel/network-resources-injector/pkg/types" +) + +type keyPairWatcher struct { + status *Channel + quit *Channel + timeout time.Duration + keyPair nri.KeyReloader +} + +//NewKeyPairWatcher will create a new cert & key file watcher +func NewKeyPairWatcher(keyCert nri.KeyReloader, to time.Duration) nri.Service { + return &keyPairWatcher{nil, nil, to, keyCert} +} + +//Run checks if key & cert exist and start go routine to monitor these files. Quit must be called after Run. +func (kcw *keyPairWatcher) Run() error { + if kcw.status != nil && kcw.status.IsOpen() { + return errors.New("watcher must have exited before attempting to run again") + } + kcw.status = NewChannel() + kcw.quit = NewChannel() + cert := kcw.keyPair.GetCertPath() + key := kcw.keyPair.GetKeyPath() + + if cert == "" || key == "" { + return errors.New("cert and/or key path are not set") + } + if _, errStat := os.Stat(cert); os.IsNotExist(errStat) { + return fmt.Errorf("cert file does not exist at path '%s'", cert) + } + if _, errStat := os.Stat(key); os.IsNotExist(errStat) { + return fmt.Errorf("key file does not exist at path '%s'", key) + } + + go kcw.monitor() + + return kcw.status.WaitUntilOpened(kcw.timeout) +} + +//monitor key & cert files. Finish when quit signal received +func (kcw *keyPairWatcher) monitor() (err error) { + defer func() { + if err != nil { + glog.Error(err) + } + }() + glog.Info("starting TLS key and cert file watcher") + watcher, err := fsnotify.NewWatcher() + if err != nil { + return + } + defer watcher.Close() + + certUpdated := false + keyUpdated := false + watcher.Add(kcw.keyPair.GetCertPath()) + watcher.Add(kcw.keyPair.GetKeyPath()) + kcw.quit.Open() + kcw.status.Open() + defer kcw.status.Close() + + for { + select { + case event, ok := <-watcher.Events: + if !ok { + glog.Error("watcher event received but not OK") + continue + } + glog.Infof("watcher event: '%v'", event) + mask := fsnotify.Create | fsnotify.Rename | fsnotify.Remove | + fsnotify.Write | fsnotify.Chmod + if (event.Op & mask) != 0 { + glog.Infof("modified file: '%v'", event.Name) + if event.Name == kcw.keyPair.GetCertPath() { + certUpdated = true + } + if event.Name == kcw.keyPair.GetKeyPath() { + keyUpdated = true + } + if keyUpdated && certUpdated { + if errReload := kcw.keyPair.Reload(); errReload != nil { + err = fmt.Errorf("failed to reload certificate: '%v'", errReload) + return + } + certUpdated = false + keyUpdated = false + } + } + case watchErr, ok := <-watcher.Errors: + if !ok { + glog.Errorf("watcher error received but got error: '%s'", watchErr.Error()) + continue + } + err = fmt.Errorf("watcher error: '%s'", watchErr) + return + case <-kcw.quit.GetCh(): + glog.Info("TLS cert and key file watcher finished") + return + } + } +} + +//Quit attempts to terminate key/cert watcher go routine and blocks until it ends. Quit call follows Run call. Error +//only when timeout occurs while waiting for watcher to close +func (kcw *keyPairWatcher) Quit() error { + glog.Info("terminating TLS cert & key watcher") + kcw.quit.Close() + return kcw.status.WaitUntilClosed(kcw.timeout) +} + +//StatusSignal returns channel that indicates when key/cert watcher has ended. Channel will be closed if watcher ends +func (kcw *keyPairWatcher) StatusSignal() chan struct{} { + return kcw.status.GetCh() +} diff --git a/pkg/webhook/watcher_test.go b/pkg/webhook/watcher_test.go new file mode 100644 index 00000000..cb8266bf --- /dev/null +++ b/pkg/webhook/watcher_test.go @@ -0,0 +1,107 @@ +package webhook + +import ( + "errors" + "io/ioutil" + "os" + "time" + + nriMocks "github.com/intel/network-resources-injector/pkg/types/mocks" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("cert & key watcher", func() { + t := GinkgoT() + const ( + to = time.Millisecond * 10 + interval = to + keyFName = "nri-watcher-test-key" + certFName = "nri-watcher-test-cert" + TempDir = "/tmp" + ) + var ( + keyPair *nriMocks.KeyReloader + kcw *keyPairWatcher + certF *os.File + keyF *os.File + ) + BeforeEach(func() { + keyPair = &nriMocks.KeyReloader{} + certF, _ = ioutil.TempFile(TempDir, certFName) + keyF, _ = ioutil.TempFile(TempDir, keyFName) + kcw = &keyPairWatcher{nil, nil, to, keyPair} + keyPair.On("GetCertPath").Return(certF.Name()) + keyPair.On("GetKeyPath").Return(keyF.Name()) + }) + + AfterEach(func() { + kcw.Quit() + os.Remove(certF.Name()) + os.Remove(keyF.Name()) + }) + + Context("Run()", func() { + It("should retrieve cert and key path", func() { + keyPair.On("Reload").Return(nil) + kcw.Run() + keyPair.AssertCalled(t, "GetCertPath") + keyPair.AssertCalled(t, "GetKeyPath") + }) + It("should return error if cert doesn't exist", func() { + os.Remove(certF.Name()) + Expect(kcw.Run().Error()).To(ContainSubstring("cert file does not exist")) + }) + It("should return error if key doesn't exist", func() { + os.Remove(keyF.Name()) + Expect(kcw.Run().Error()).To(ContainSubstring("key file does not exist")) + }) + It("should not reload cert/key if only key is altered", func() { + kcw.Run() + os.Chtimes(certF.Name(), time.Now(), time.Now()) // touch file + time.Sleep(interval) // wait for Reload function to be possibly called + keyPair.AssertNotCalled(t, "Reload") + }) + It("should not reload cert/key if only cert is altered", func() { + kcw.Run() + os.Chtimes(keyF.Name(), time.Now(), time.Now()) // touch file + time.Sleep(interval) // wait for Reload function to be possibly called + keyPair.AssertNotCalled(t, "Reload") + }) + It("should reload cert/key if cert and key are altered", func() { + keyPair.On("Reload").Return(nil) + kcw.Run() + os.Chtimes(certF.Name(), time.Now(), time.Now()) // touch file + os.Chtimes(keyF.Name(), time.Now(), time.Now()) + time.Sleep(interval) // wait for Reload function to be called + keyPair.AssertExpectations(t) + }) + It("should terminate watcher when reload fails", func() { + keyPair.On("Reload").Return(errors.New("failed to reload keys")) + kcw.Run() + os.Chtimes(certF.Name(), time.Now(), time.Now()) // touch file + os.Chtimes(keyF.Name(), time.Now(), time.Now()) + time.Sleep(interval) // wait for Reload function to be called + Expect(kcw.status.IsOpen()).To(BeFalse()) + }) + It("should tolerate restart", func() { + kcw.Run() + kcw.Quit() + Expect(kcw.status.IsOpen()).To(BeFalse()) + kcw.Run() // restart + Expect(kcw.status.IsOpen()).To(BeTrue()) + kcw.Quit() + Expect(kcw.status.IsOpen()).To(BeFalse()) + }) + }) + + Context("Quit()", func() { + It("should terminate watcher", func() { + kcw.Run() + time.Sleep(interval) + Expect(kcw.status.IsOpen()).To(BeTrue()) // ensure it is running before test + Expect(kcw.Quit()).To(BeNil()) + Expect(kcw.status.IsClosed()).To(BeTrue()) + }) + }) +}) diff --git a/pkg/webhook/webhook.go b/pkg/webhook/webhook.go index db43d593..bb0b2ded 100644 --- a/pkg/webhook/webhook.go +++ b/pkg/webhook/webhook.go @@ -20,15 +20,17 @@ import ( "fmt" "io/ioutil" "net/http" + "os" + "os/signal" "regexp" "strings" + "syscall" "github.com/golang/glog" + nri "github.com/intel/network-resources-injector/pkg/types" cniv1 "github.com/k8snetworkplumbingwg/network-attachment-definition-client/pkg/apis/k8s.cni.cncf.io/v1" "github.com/pkg/errors" multus "gopkg.in/intel/multus-cni.v3/types" - - "github.com/intel/network-resources-injector/pkg/types" "k8s.io/api/admission/v1beta1" v1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" @@ -404,14 +406,14 @@ func addVolDownwardAPI(patch []jsonPatchOperation, hugepageResourceList []hugepa FieldPath: "metadata.labels", } dAPILabels := corev1.DownwardAPIVolumeFile{ - Path: types.LabelsPath, + Path: nri.LabelsPath, FieldRef: &labels, } annotations := corev1.ObjectFieldSelector{ FieldPath: "metadata.annotations", } dAPIAnnotations := corev1.DownwardAPIVolumeFile{ - Path: types.AnnotationsPath, + Path: nri.AnnotationsPath, FieldRef: &annotations, } dAPIItems := []corev1.DownwardAPIVolumeFile{dAPILabels, dAPIAnnotations} @@ -454,7 +456,7 @@ func addVolumeMount(patch []jsonPatchOperation) []jsonPatchOperation { vm := corev1.VolumeMount{ Name: "podnetinfo", ReadOnly: false, - MountPath: types.DownwardAPIMountPath, + MountPath: nri.DownwardAPIMountPath, } patch = append(patch, jsonPatchOperation{ @@ -724,7 +726,7 @@ func MutateHandler(w http.ResponseWriter, req *http.Request) { hugepageResource := hugepageResourceData{ ResourceName: "requests.hugepages-1Gi", ContainerName: container.Name, - Path: types.Hugepages1GRequestPath + "_" + container.Name, + Path: nri.Hugepages1GRequestPath + "_" + container.Name, } hugepageResourceList = append(hugepageResourceList, hugepageResource) found = true @@ -733,7 +735,7 @@ func MutateHandler(w http.ResponseWriter, req *http.Request) { hugepageResource := hugepageResourceData{ ResourceName: "requests.hugepages-2Mi", ContainerName: container.Name, - Path: types.Hugepages2MRequestPath + "_" + container.Name, + Path: nri.Hugepages2MRequestPath + "_" + container.Name, } hugepageResourceList = append(hugepageResourceList, hugepageResource) found = true @@ -744,7 +746,7 @@ func MutateHandler(w http.ResponseWriter, req *http.Request) { hugepageResource := hugepageResourceData{ ResourceName: "limits.hugepages-1Gi", ContainerName: container.Name, - Path: types.Hugepages1GLimitPath + "_" + container.Name, + Path: nri.Hugepages1GLimitPath + "_" + container.Name, } hugepageResourceList = append(hugepageResourceList, hugepageResource) found = true @@ -753,7 +755,7 @@ func MutateHandler(w http.ResponseWriter, req *http.Request) { hugepageResource := hugepageResourceData{ ResourceName: "limits.hugepages-2Mi", ContainerName: container.Name, - Path: types.Hugepages2MLimitPath + "_" + container.Name, + Path: nri.Hugepages2MLimitPath + "_" + container.Name, } hugepageResourceList = append(hugepageResourceList, hugepageResource) found = true @@ -765,7 +767,7 @@ func MutateHandler(w http.ResponseWriter, req *http.Request) { // so container knows its name and can process hugepages properly. if found { patch = createEnvPatch(patch, &container, containerIndex, - types.EnvNameContainerName, container.Name) + nri.EnvNameContainerName, container.Name) } } } @@ -832,3 +834,44 @@ func SetInjectHugepageDownApi(hugepageFlag bool) { func SetHonorExistingResources(resourcesHonorFlag bool) { honorExistingResources = resourcesHonorFlag } + +//Watch blocks until either TLS cert & key watcher or HTTP server or termination signal generated +func Watch(server nri.Service, watcher nri.Service, term chan os.Signal) (err error) { + signal.Notify(term, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + select { + case <-watcher.StatusSignal(): // when TLS cert & key watcher finishes unexpectedly + glog.Error("TLS key & cert watcher ended") + err = server.Quit() + case <-server.StatusSignal(): // when HTTP server finishes unexpectedly + glog.Error("HTTP server ended") + err = watcher.Quit() + case <-term: // when termination signal received + glog.Info("termination signal received") + err = combineError(server.Quit(), watcher.Quit()) + } + return +} + +//httpServerHandler limits HTTP server endpoint to /mutate and HTTP verb to POST only +func httpServerHandler(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != mServerEndpoint { + http.NotFound(w, r) + return + } + if r.Method != http.MethodPost { + http.Error(w, "Invalid HTTP verb requested", 405) + return + } + MutateHandler(w, r) +} + +//combineError combines two errors into one error message +func combineError(err1, err2 error) error { + if err1 != nil && err2 != nil { + return fmt.Errorf("two errors occured: 1) '%s' - 2) '%s'", err1.Error(), err2.Error()) + } + if err1 != nil { + return err1 + } + return err2 +}