diff --git a/lib/service/service.go b/lib/service/service.go index 59ab72a2c6911..96f763f8ee46e 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -157,6 +157,9 @@ const ( // and is ready to start accepting connections. ProxySSHReady = "ProxySSHReady" + // ProxyKubeReady is generated when the kubernetes proxy service has been initialized. + ProxyKubeReady = "ProxyKubeReady" + // NodeSSHReady is generated when the Teleport node has initialized a SSH server // and is ready to start accepting SSH connections. NodeSSHReady = "NodeReady" @@ -185,7 +188,7 @@ const ( // in a graceful way. TeleportReloadEvent = "TeleportReload" - // TeleportPhaseChangeEvent is generated to indidate that teleport + // TeleportPhaseChangeEvent is generated to indicate that teleport // CA rotation phase has been updated, used in tests TeleportPhaseChangeEvent = "TeleportPhaseChange" @@ -747,6 +750,9 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { process.registerAppDepend() + // Produce global TeleportReadyEvent when all components have started + componentCount := process.registerTeleportReadyEvent(cfg) + process.log = cfg.Log.WithFields(logrus.Fields{ trace.Component: teleport.Component(teleport.ComponentProcess, process.id), }) @@ -754,7 +760,7 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { serviceStarted := false if !cfg.DiagnosticAddr.IsEmpty() { - if err := process.initDiagnosticService(); err != nil { + if err := process.initDiagnosticService(componentCount); err != nil { return nil, trace.Wrap(err) } } else { @@ -773,9 +779,6 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) { cfg.Keygen = native.New(process.ExitContext(), native.PrecomputeKeys(precomputeCount)) } - // Produce global TeleportReadyEvent when all components have started - process.registerTeleportReadyEvent(cfg) - if cfg.Auth.Enabled { if err := process.initAuthService(); err != nil { return nil, trace.Wrap(err) @@ -2200,7 +2203,7 @@ func (process *TeleportProcess) initMetricsService() error { // initDiagnosticService starts diagnostic service currently serving healthz // and prometheus endpoints -func (process *TeleportProcess) initDiagnosticService() error { +func (process *TeleportProcess) initDiagnosticService(componentCount int) error { mux := http.NewServeMux() // support legacy metrics collection in the diagnostic service. @@ -2231,7 +2234,7 @@ func (process *TeleportProcess) initDiagnosticService() error { // Create a state machine that will process and update the internal state of // Teleport based off Events. Use this state machine to return return the // status from the /readyz endpoint. - ps, err := newProcessState(process) + ps, err := newProcessState(process, componentCount) if err != nil { return trace.Wrap(err) } @@ -3048,6 +3051,9 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { }) log.Infof("Starting Kube proxy on %v.", cfg.Proxy.Kube.ListenAddr.Addr) + // since kubeServer.Serve is a blocking call, we emit this event right before + // the service has started + process.BroadcastEvent(Event{Name: ProxyKubeReady, Payload: nil}) err := kubeServer.Serve(listeners.kube) if err != nil && err != http.ErrServerClosed { log.Warningf("Kube TLS server exited with error: %v.", err) @@ -3439,8 +3445,9 @@ func (process *TeleportProcess) waitForAppDepend() { } // registerTeleportReadyEvent ensures that a TeleportReadyEvent is produced -// when all components have started. -func (process *TeleportProcess) registerTeleportReadyEvent(cfg *Config) { +// when all components enabled (based on the configuration) have started. +// It returns the number of components enabled. +func (process *TeleportProcess) registerTeleportReadyEvent(cfg *Config) int { eventMapping := EventMapping{ Out: TeleportReadyEvent, } @@ -3453,9 +3460,13 @@ func (process *TeleportProcess) registerTeleportReadyEvent(cfg *Config) { eventMapping.In = append(eventMapping.In, NodeSSHReady) } - if cfg.Proxy.Enabled { + proxyConfig := cfg.Proxy + if proxyConfig.Enabled { eventMapping.In = append(eventMapping.In, ProxySSHReady) } + if proxyConfig.Kube.Enabled && !proxyConfig.Kube.ListenAddr.IsEmpty() && !proxyConfig.DisableReverseTunnel { + eventMapping.In = append(eventMapping.In, ProxyKubeReady) + } if cfg.Kube.Enabled { eventMapping.In = append(eventMapping.In, KubernetesReady) @@ -3473,7 +3484,9 @@ func (process *TeleportProcess) registerTeleportReadyEvent(cfg *Config) { eventMapping.In = append(eventMapping.In, WindowsDesktopReady) } + componentCount := len(eventMapping.In) process.RegisterEventMapping(eventMapping) + return componentCount } // appDependEvents is a list of events that the application service depends on. diff --git a/lib/service/service_test.go b/lib/service/service_test.go index 8c233409bb238..2ca3695eab64f 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -18,6 +18,7 @@ package service import ( "context" "fmt" + "net" "net/http" "os" "strings" @@ -44,6 +45,16 @@ import ( "github.com/stretchr/testify/require" ) +var ports utils.PortList + +func init() { + var err error + ports, err = utils.GetFreeTCPPorts(5, utils.PortStartingNumber) + if err != nil { + panic(fmt.Sprintf("failed to allocate tcp ports for tests: %v", err)) + } +} + func TestMain(m *testing.M) { utils.InitLoggerForTests() os.Exit(m.Run()) @@ -80,21 +91,29 @@ func TestServiceSelfSignedHTTPS(t *testing.T) { require.FileExists(t, cfg.Proxy.KeyPairs[0].PrivateKey) } -func TestMonitor(t *testing.T) { - t.Parallel() - fakeClock := clockwork.NewFakeClock() +type monitorTest struct { + desc string + event *Event + advanceClock time.Duration + wantStatus int +} +func testMonitor(t *testing.T, sshEnabled bool, tests []monitorTest) { + fakeClock := clockwork.NewFakeClock() cfg := MakeDefaultConfig() cfg.Clock = fakeClock var err error cfg.DataDir = t.TempDir() - cfg.DiagnosticAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"} - cfg.AuthServers = []utils.NetAddr{{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}} + cfg.DiagnosticAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())} cfg.Auth.Enabled = true + cfg.Auth.SSHAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())} + cfg.AuthServers = []utils.NetAddr{cfg.Auth.SSHAddr} cfg.Auth.StorageConfig.Params["path"] = t.TempDir() - cfg.Auth.SSHAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"} + if sshEnabled { + cfg.SSH.Enabled = true + cfg.SSH.Addr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())} + } cfg.Proxy.Enabled = false - cfg.SSH.Enabled = false process, err := NewTeleport(cfg) require.NoError(t, err) @@ -111,65 +130,84 @@ func TestMonitor(t *testing.T) { err = waitForStatus(endpoint, http.StatusOK) require.NoError(t, err) - tests := []struct { - desc string - event Event - advanceClock time.Duration - wantStatus []int - }{ + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + fakeClock.Advance(tt.advanceClock) + if tt.event != nil { + process.BroadcastEvent(*tt.event) + } + err := waitForStatus(endpoint, tt.wantStatus) + require.NoError(t, err) + }) + } +} + +func TestMonitorOneComponent(t *testing.T) { + t.Parallel() + sshEnabled := false + tests := []monitorTest{ + { + desc: "it starts with OK state", + event: nil, + wantStatus: http.StatusOK, + }, { desc: "degraded event causes degraded state", - event: Event{Name: TeleportDegradedEvent, Payload: teleport.ComponentAuth}, - wantStatus: []int{http.StatusServiceUnavailable, http.StatusBadRequest}, + event: &Event{Name: TeleportDegradedEvent, Payload: teleport.ComponentAuth}, + wantStatus: http.StatusServiceUnavailable, }, { desc: "ok event causes recovering state", - event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth}, - wantStatus: []int{http.StatusBadRequest}, + event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth}, + wantStatus: http.StatusBadRequest, }, { desc: "ok event remains in recovering state because not enough time passed", - event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth}, - wantStatus: []int{http.StatusBadRequest}, + event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth}, + wantStatus: http.StatusBadRequest, }, { desc: "ok event after enough time causes OK state", - event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth}, + event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth}, advanceClock: defaults.HeartbeatCheckPeriod*2 + 1, - wantStatus: []int{http.StatusOK}, + wantStatus: http.StatusOK, + }, + } + testMonitor(t, sshEnabled, tests) +} + +func TestMonitorTwoComponents(t *testing.T) { + t.Parallel() + sshEnabled := true + tests := []monitorTest{ + { + desc: "it starts with OK state", + event: nil, + wantStatus: http.StatusOK, }, { - desc: "degraded event in a new component causes degraded state", - event: Event{Name: TeleportDegradedEvent, Payload: teleport.ComponentNode}, - wantStatus: []int{http.StatusServiceUnavailable, http.StatusBadRequest}, + desc: "degraded event in one component causes degraded state", + event: &Event{Name: TeleportDegradedEvent, Payload: teleport.ComponentNode}, + wantStatus: http.StatusServiceUnavailable, }, { - desc: "ok event in one component keeps overall status degraded due to other component", - advanceClock: defaults.HeartbeatCheckPeriod*2 + 1, - event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth}, - wantStatus: []int{http.StatusServiceUnavailable, http.StatusBadRequest}, + desc: "ok event in ok component keeps overall status degraded due to degraded component", + event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth}, + wantStatus: http.StatusServiceUnavailable, }, { - desc: "ok event in new component causes overall recovering state", - advanceClock: defaults.HeartbeatCheckPeriod*2 + 1, - event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentNode}, - wantStatus: []int{http.StatusBadRequest}, + desc: "ok event in degraded component causes overall recovering state", + event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentNode}, + wantStatus: http.StatusBadRequest, }, { - desc: "ok event in new component causes overall OK state", + desc: "ok event after enough time causes overall OK state", advanceClock: defaults.HeartbeatCheckPeriod*2 + 1, - event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentNode}, - wantStatus: []int{http.StatusOK}, + event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentNode}, + wantStatus: http.StatusOK, }, } - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - fakeClock.Advance(tt.advanceClock) - process.BroadcastEvent(tt.event) - err = waitForStatus(endpoint, tt.wantStatus...) - require.NoError(t, err) - }) - } + testMonitor(t, sshEnabled, tests) } // TestServiceCheckPrincipals checks certificates regeneration only requests @@ -452,7 +490,7 @@ func TestDesktopAccessFIPS(t *testing.T) { require.Error(t, err) } -func waitForStatus(diagAddr string, statusCodes ...int) error { +func waitForStatus(diagAddr string, statusCode int) error { tickCh := time.Tick(100 * time.Millisecond) timeoutCh := time.After(10 * time.Second) var lastStatus int @@ -465,13 +503,11 @@ func waitForStatus(diagAddr string, statusCodes ...int) error { } resp.Body.Close() lastStatus = resp.StatusCode - for _, statusCode := range statusCodes { - if resp.StatusCode == statusCode { - return nil - } + if resp.StatusCode == statusCode { + return nil } case <-timeoutCh: - return trace.BadParameter("timeout waiting for status: %v; last status: %v", statusCodes, lastStatus) + return trace.BadParameter("timeout waiting for status: %v; last status: %v", statusCode, lastStatus) } } } diff --git a/lib/service/state.go b/lib/service/state.go index bc6e18685a87b..d42cf3e730dae 100644 --- a/lib/service/state.go +++ b/lib/service/state.go @@ -57,9 +57,10 @@ func init() { // processState tracks the state of the Teleport process. type processState struct { - process *TeleportProcess - mu sync.Mutex - states map[string]*componentState + process *TeleportProcess + mu sync.Mutex + states map[string]*componentState + totalComponentCount int // number of components that will send updates } type componentState struct { @@ -68,15 +69,16 @@ type componentState struct { } // newProcessState returns a new FSM that tracks the state of the Teleport process. -func newProcessState(process *TeleportProcess) (*processState, error) { +func newProcessState(process *TeleportProcess, componentCount int) (*processState, error) { err := utils.RegisterPrometheusCollectors(stateGauge) if err != nil { return nil, trace.Wrap(err) } return &processState{ - process: process, - states: make(map[string]*componentState), + process: process, + states: make(map[string]*componentState), + totalComponentCount: componentCount, }, nil } @@ -127,7 +129,7 @@ func (f *processState) update(event Event) { } // getStateLocked returns the overall process state based on the state of -// individual components. If no components sent updates yet, returns +// individual components. If not all components have sent updates yet, returns // stateStarting. // // Order of importance: @@ -138,8 +140,13 @@ func (f *processState) update(event Event) { // // Note: f.mu must be locked by the caller! func (f *processState) getStateLocked() componentStateEnum { + // Return stateStarting if not all components have sent updates yet. + if len(f.states) < f.totalComponentCount { + return stateStarting + } + state := stateStarting - numNotOK := len(f.states) + numOK := 0 for _, s := range f.states { switch s.state { case stateDegraded: @@ -147,12 +154,14 @@ func (f *processState) getStateLocked() componentStateEnum { case stateRecovering: state = stateRecovering case stateOK: - numNotOK-- + numOK++ } } // Only return stateOK if *all* components are in stateOK. - if numNotOK == 0 && len(f.states) > 0 { + if numOK == f.totalComponentCount { state = stateOK + } else if numOK > f.totalComponentCount { + f.process.log.Errorf("incorrect count of components (found: %d; expected: %d), this is a bug!", numOK, f.totalComponentCount) } return state } diff --git a/lib/service/state_test.go b/lib/service/state_test.go index 078a8354f0b54..32bbf2511b1cd 100644 --- a/lib/service/state_test.go +++ b/lib/service/state_test.go @@ -24,21 +24,24 @@ func TestProcessStateGetState(t *testing.T) { t.Parallel() tests := []struct { - desc string - states map[string]*componentState - want componentStateEnum + desc string + states map[string]*componentState + totalComponentCount int + want componentStateEnum }{ { - desc: "no components", - states: map[string]*componentState{}, - want: stateStarting, + desc: "no components", + states: map[string]*componentState{}, + totalComponentCount: 1, + want: stateStarting, }, { desc: "one component in stateOK", states: map[string]*componentState{ "one": {state: stateOK}, }, - want: stateOK, + totalComponentCount: 1, + want: stateOK, }, { desc: "multiple components in stateOK", @@ -47,7 +50,8 @@ func TestProcessStateGetState(t *testing.T) { "two": {state: stateOK}, "three": {state: stateOK}, }, - want: stateOK, + totalComponentCount: 3, + want: stateOK, }, { desc: "multiple components, one is degraded", @@ -56,7 +60,8 @@ func TestProcessStateGetState(t *testing.T) { "two": {state: stateDegraded}, "three": {state: stateOK}, }, - want: stateDegraded, + totalComponentCount: 3, + want: stateDegraded, }, { desc: "multiple components, one is recovering", @@ -65,7 +70,8 @@ func TestProcessStateGetState(t *testing.T) { "two": {state: stateRecovering}, "three": {state: stateOK}, }, - want: stateRecovering, + totalComponentCount: 3, + want: stateRecovering, }, { desc: "multiple components, one is starting", @@ -74,13 +80,14 @@ func TestProcessStateGetState(t *testing.T) { "two": {state: stateStarting}, "three": {state: stateOK}, }, - want: stateStarting, + totalComponentCount: 3, + want: stateStarting, }, } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - ps := &processState{states: tt.states} + ps := &processState{states: tt.states, totalComponentCount: tt.totalComponentCount} got := ps.getState() require.Equal(t, got, tt.want) }) diff --git a/lib/srv/heartbeat.go b/lib/srv/heartbeat.go index be87e502ba8ff..97936f526a6e2 100644 --- a/lib/srv/heartbeat.go +++ b/lib/srv/heartbeat.go @@ -152,7 +152,7 @@ func NewHeartbeat(cfg HeartbeatConfig) (*Heartbeat, error) { announceC: make(chan struct{}, 1), sendC: make(chan struct{}, 1), } - h.Debugf("Starting %v heartbeat with announce period: %v, keep-alive period %v, poll period: %v", cfg.Mode, cfg.KeepAlivePeriod, cfg.AnnouncePeriod, cfg.CheckPeriod) + h.Debugf("Starting %v heartbeat with announce period: %v, keep-alive period %v, poll period: %v", cfg.Mode, cfg.AnnouncePeriod, cfg.KeepAlivePeriod, cfg.CheckPeriod) return h, nil }