Skip to content

Commit

Permalink
Ensure stateOK is reported only when all components have sent updates (
Browse files Browse the repository at this point in the history
…#11249)

Fixes #11065.

This commit:
- ensures that `TeleportReadyEvent` is only produced when all components that send heartbeats (i.e. call [`process.onHeartbeat`](https://github.com/gravitational/teleport/blob/16bf416556f337b045b66dc9c3f5a3e16f8cc988/lib/service/service.go#L358-L366)) are ready
- changes `TeleportProcess.registerTeleportReadyEvent` so that it returns a count of these components (let's call it `componentCount`)
- uses `componentCount` to also ensure that `stateOK` is only reported when all the components have sent their heartbeat, thus fixing #11065

Since it seems difficult to know when `TeleportProcess.registerTeleportReadyEvent` should be updated, with the goal of quickly detecting a bug when it's introduced we have that:
1. if `componentCount` is lower than it should be, then the service fails to start (due to #11725)
2. if `componentCount` is higher than it should be, then an error is logged in function `processState.getStateLocked`.
  • Loading branch information
Vitor Enes authored Apr 7, 2022
1 parent 32cb76e commit b749302
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 82 deletions.
33 changes: 23 additions & 10 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ const (
// and is ready to start accepting connections.
ProxySSHReady = "ProxySSHReady"

// ProxyKubeReady is generated when the kubernetes proxy service has been initialized.
ProxyKubeReady = "ProxyKubeReady"

// NodeSSHReady is generated when the Teleport node has initialized a SSH server
// and is ready to start accepting SSH connections.
NodeSSHReady = "NodeReady"
Expand Down Expand Up @@ -185,7 +188,7 @@ const (
// in a graceful way.
TeleportReloadEvent = "TeleportReload"

// TeleportPhaseChangeEvent is generated to indidate that teleport
// TeleportPhaseChangeEvent is generated to indicate that teleport
// CA rotation phase has been updated, used in tests
TeleportPhaseChangeEvent = "TeleportPhaseChange"

Expand Down Expand Up @@ -747,14 +750,17 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) {

process.registerAppDepend()

// Produce global TeleportReadyEvent when all components have started
componentCount := process.registerTeleportReadyEvent(cfg)

process.log = cfg.Log.WithFields(logrus.Fields{
trace.Component: teleport.Component(teleport.ComponentProcess, process.id),
})

serviceStarted := false

if !cfg.DiagnosticAddr.IsEmpty() {
if err := process.initDiagnosticService(); err != nil {
if err := process.initDiagnosticService(componentCount); err != nil {
return nil, trace.Wrap(err)
}
} else {
Expand All @@ -773,9 +779,6 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) {
cfg.Keygen = native.New(process.ExitContext(), native.PrecomputeKeys(precomputeCount))
}

// Produce global TeleportReadyEvent when all components have started
process.registerTeleportReadyEvent(cfg)

if cfg.Auth.Enabled {
if err := process.initAuthService(); err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -2200,7 +2203,7 @@ func (process *TeleportProcess) initMetricsService() error {

// initDiagnosticService starts diagnostic service currently serving healthz
// and prometheus endpoints
func (process *TeleportProcess) initDiagnosticService() error {
func (process *TeleportProcess) initDiagnosticService(componentCount int) error {
mux := http.NewServeMux()

// support legacy metrics collection in the diagnostic service.
Expand Down Expand Up @@ -2231,7 +2234,7 @@ func (process *TeleportProcess) initDiagnosticService() error {
// Create a state machine that will process and update the internal state of
// Teleport based off Events. Use this state machine to return the
// status from the /readyz endpoint.
ps, err := newProcessState(process)
ps, err := newProcessState(process, componentCount)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -3048,6 +3051,9 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
})

log.Infof("Starting Kube proxy on %v.", cfg.Proxy.Kube.ListenAddr.Addr)
// since kubeServer.Serve is a blocking call, we emit this event right before
// the service has started
process.BroadcastEvent(Event{Name: ProxyKubeReady, Payload: nil})
err := kubeServer.Serve(listeners.kube)
if err != nil && err != http.ErrServerClosed {
log.Warningf("Kube TLS server exited with error: %v.", err)
Expand Down Expand Up @@ -3439,8 +3445,9 @@ func (process *TeleportProcess) waitForAppDepend() {
}

// registerTeleportReadyEvent ensures that a TeleportReadyEvent is produced
// when all components have started.
func (process *TeleportProcess) registerTeleportReadyEvent(cfg *Config) {
// when all components enabled (based on the configuration) have started.
// It returns the number of components enabled.
func (process *TeleportProcess) registerTeleportReadyEvent(cfg *Config) int {
eventMapping := EventMapping{
Out: TeleportReadyEvent,
}
Expand All @@ -3453,9 +3460,13 @@ func (process *TeleportProcess) registerTeleportReadyEvent(cfg *Config) {
eventMapping.In = append(eventMapping.In, NodeSSHReady)
}

if cfg.Proxy.Enabled {
proxyConfig := cfg.Proxy
if proxyConfig.Enabled {
eventMapping.In = append(eventMapping.In, ProxySSHReady)
}
if proxyConfig.Kube.Enabled && !proxyConfig.Kube.ListenAddr.IsEmpty() && !proxyConfig.DisableReverseTunnel {
eventMapping.In = append(eventMapping.In, ProxyKubeReady)
}

if cfg.Kube.Enabled {
eventMapping.In = append(eventMapping.In, KubernetesReady)
Expand All @@ -3473,7 +3484,9 @@ func (process *TeleportProcess) registerTeleportReadyEvent(cfg *Config) {
eventMapping.In = append(eventMapping.In, WindowsDesktopReady)
}

componentCount := len(eventMapping.In)
process.RegisterEventMapping(eventMapping)
return componentCount
}

// appDependEvents is a list of events that the application service depends on.
Expand Down
134 changes: 85 additions & 49 deletions lib/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package service
import (
"context"
"fmt"
"net"
"net/http"
"os"
"strings"
Expand All @@ -44,6 +45,16 @@ import (
"github.com/stretchr/testify/require"
)

var ports utils.PortList

func init() {
var err error
ports, err = utils.GetFreeTCPPorts(5, utils.PortStartingNumber)
if err != nil {
panic(fmt.Sprintf("failed to allocate tcp ports for tests: %v", err))
}
}

func TestMain(m *testing.M) {
utils.InitLoggerForTests()
os.Exit(m.Run())
Expand Down Expand Up @@ -80,21 +91,29 @@ func TestServiceSelfSignedHTTPS(t *testing.T) {
require.FileExists(t, cfg.Proxy.KeyPairs[0].PrivateKey)
}

func TestMonitor(t *testing.T) {
t.Parallel()
fakeClock := clockwork.NewFakeClock()
type monitorTest struct {
desc string
event *Event
advanceClock time.Duration
wantStatus int
}

func testMonitor(t *testing.T, sshEnabled bool, tests []monitorTest) {
fakeClock := clockwork.NewFakeClock()
cfg := MakeDefaultConfig()
cfg.Clock = fakeClock
var err error
cfg.DataDir = t.TempDir()
cfg.DiagnosticAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}
cfg.AuthServers = []utils.NetAddr{{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}}
cfg.DiagnosticAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
cfg.Auth.Enabled = true
cfg.Auth.SSHAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
cfg.AuthServers = []utils.NetAddr{cfg.Auth.SSHAddr}
cfg.Auth.StorageConfig.Params["path"] = t.TempDir()
cfg.Auth.SSHAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}
if sshEnabled {
cfg.SSH.Enabled = true
cfg.SSH.Addr = utils.NetAddr{AddrNetwork: "tcp", Addr: net.JoinHostPort("127.0.0.1", ports.Pop())}
}
cfg.Proxy.Enabled = false
cfg.SSH.Enabled = false

process, err := NewTeleport(cfg)
require.NoError(t, err)
Expand All @@ -111,65 +130,84 @@ func TestMonitor(t *testing.T) {
err = waitForStatus(endpoint, http.StatusOK)
require.NoError(t, err)

tests := []struct {
desc string
event Event
advanceClock time.Duration
wantStatus []int
}{
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
fakeClock.Advance(tt.advanceClock)
if tt.event != nil {
process.BroadcastEvent(*tt.event)
}
err := waitForStatus(endpoint, tt.wantStatus)
require.NoError(t, err)
})
}
}

func TestMonitorOneComponent(t *testing.T) {
t.Parallel()
sshEnabled := false
tests := []monitorTest{
{
desc: "it starts with OK state",
event: nil,
wantStatus: http.StatusOK,
},
{
desc: "degraded event causes degraded state",
event: Event{Name: TeleportDegradedEvent, Payload: teleport.ComponentAuth},
wantStatus: []int{http.StatusServiceUnavailable, http.StatusBadRequest},
event: &Event{Name: TeleportDegradedEvent, Payload: teleport.ComponentAuth},
wantStatus: http.StatusServiceUnavailable,
},
{
desc: "ok event causes recovering state",
event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth},
wantStatus: []int{http.StatusBadRequest},
event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth},
wantStatus: http.StatusBadRequest,
},
{
desc: "ok event remains in recovering state because not enough time passed",
event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth},
wantStatus: []int{http.StatusBadRequest},
event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth},
wantStatus: http.StatusBadRequest,
},
{
desc: "ok event after enough time causes OK state",
event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth},
event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth},
advanceClock: defaults.HeartbeatCheckPeriod*2 + 1,
wantStatus: []int{http.StatusOK},
wantStatus: http.StatusOK,
},
}
testMonitor(t, sshEnabled, tests)
}

func TestMonitorTwoComponents(t *testing.T) {
t.Parallel()
sshEnabled := true
tests := []monitorTest{
{
desc: "it starts with OK state",
event: nil,
wantStatus: http.StatusOK,
},
{
desc: "degraded event in a new component causes degraded state",
event: Event{Name: TeleportDegradedEvent, Payload: teleport.ComponentNode},
wantStatus: []int{http.StatusServiceUnavailable, http.StatusBadRequest},
desc: "degraded event in one component causes degraded state",
event: &Event{Name: TeleportDegradedEvent, Payload: teleport.ComponentNode},
wantStatus: http.StatusServiceUnavailable,
},
{
desc: "ok event in one component keeps overall status degraded due to other component",
advanceClock: defaults.HeartbeatCheckPeriod*2 + 1,
event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth},
wantStatus: []int{http.StatusServiceUnavailable, http.StatusBadRequest},
desc: "ok event in ok component keeps overall status degraded due to degraded component",
event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentAuth},
wantStatus: http.StatusServiceUnavailable,
},
{
desc: "ok event in new component causes overall recovering state",
advanceClock: defaults.HeartbeatCheckPeriod*2 + 1,
event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentNode},
wantStatus: []int{http.StatusBadRequest},
desc: "ok event in degraded component causes overall recovering state",
event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentNode},
wantStatus: http.StatusBadRequest,
},
{
desc: "ok event in new component causes overall OK state",
desc: "ok event after enough time causes overall OK state",
advanceClock: defaults.HeartbeatCheckPeriod*2 + 1,
event: Event{Name: TeleportOKEvent, Payload: teleport.ComponentNode},
wantStatus: []int{http.StatusOK},
event: &Event{Name: TeleportOKEvent, Payload: teleport.ComponentNode},
wantStatus: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
fakeClock.Advance(tt.advanceClock)
process.BroadcastEvent(tt.event)
err = waitForStatus(endpoint, tt.wantStatus...)
require.NoError(t, err)
})
}
testMonitor(t, sshEnabled, tests)
}

// TestServiceCheckPrincipals checks certificates regeneration only requests
Expand Down Expand Up @@ -452,7 +490,7 @@ func TestDesktopAccessFIPS(t *testing.T) {
require.Error(t, err)
}

func waitForStatus(diagAddr string, statusCodes ...int) error {
func waitForStatus(diagAddr string, statusCode int) error {
tickCh := time.Tick(100 * time.Millisecond)
timeoutCh := time.After(10 * time.Second)
var lastStatus int
Expand All @@ -465,13 +503,11 @@ func waitForStatus(diagAddr string, statusCodes ...int) error {
}
resp.Body.Close()
lastStatus = resp.StatusCode
for _, statusCode := range statusCodes {
if resp.StatusCode == statusCode {
return nil
}
if resp.StatusCode == statusCode {
return nil
}
case <-timeoutCh:
return trace.BadParameter("timeout waiting for status: %v; last status: %v", statusCodes, lastStatus)
return trace.BadParameter("timeout waiting for status: %v; last status: %v", statusCode, lastStatus)
}
}
}
Expand Down
29 changes: 19 additions & 10 deletions lib/service/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ func init() {

// processState tracks the state of the Teleport process.
type processState struct {
process *TeleportProcess
mu sync.Mutex
states map[string]*componentState
process *TeleportProcess
mu sync.Mutex
states map[string]*componentState
totalComponentCount int // number of components that will send updates
}

type componentState struct {
Expand All @@ -68,15 +69,16 @@ type componentState struct {
}

// newProcessState returns a new FSM that tracks the state of the Teleport process.
func newProcessState(process *TeleportProcess) (*processState, error) {
func newProcessState(process *TeleportProcess, componentCount int) (*processState, error) {
err := utils.RegisterPrometheusCollectors(stateGauge)
if err != nil {
return nil, trace.Wrap(err)
}

return &processState{
process: process,
states: make(map[string]*componentState),
process: process,
states: make(map[string]*componentState),
totalComponentCount: componentCount,
}, nil
}

Expand Down Expand Up @@ -127,7 +129,7 @@ func (f *processState) update(event Event) {
}

// getStateLocked returns the overall process state based on the state of
// individual components. If no components sent updates yet, returns
// individual components. If not all components have sent updates yet, returns
// stateStarting.
//
// Order of importance:
Expand All @@ -138,21 +140,28 @@ func (f *processState) update(event Event) {
//
// Note: f.mu must be locked by the caller!
func (f *processState) getStateLocked() componentStateEnum {
// Return stateStarting if not all components have sent updates yet.
if len(f.states) < f.totalComponentCount {
return stateStarting
}

state := stateStarting
numNotOK := len(f.states)
numOK := 0
for _, s := range f.states {
switch s.state {
case stateDegraded:
return stateDegraded
case stateRecovering:
state = stateRecovering
case stateOK:
numNotOK--
numOK++
}
}
// Only return stateOK if *all* components are in stateOK.
if numNotOK == 0 && len(f.states) > 0 {
if numOK == f.totalComponentCount {
state = stateOK
} else if numOK > f.totalComponentCount {
f.process.log.Errorf("incorrect count of components (found: %d; expected: %d), this is a bug!", numOK, f.totalComponentCount)
}
return state
}
Expand Down
Loading

0 comments on commit b749302

Please sign in to comment.