From c8ca543886531d1414eaf89d3afba46b18166820 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Tue, 8 Mar 2022 13:10:45 -0500 Subject: [PATCH] Prevent panic caused by nil session recorder (#10792) * Prevent panic caused by nil session recorder In startInteractive the session recorder was being assigned the return value of events.NewAuditWriter, even if it returned an error. This causes problems because the nil *events.AuditWriter that is returned in this case ends up being stored in recorder as a non-nil events.StreamWriter. So when the session tries to close the check on recorder != nil is mistakenly true and recorder.Close is called on a nil *events.AuditWriter - which results in a panic. --- lib/srv/sess.go | 89 +++++++--------- lib/srv/sess_test.go | 247 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 286 insertions(+), 50 deletions(-) diff --git a/lib/srv/sess.go b/lib/srv/sess.go index bcb541ef2f354..89134b42b9ebd 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -659,32 +659,11 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error { // create a new "party" (connected client) p := newParty(s, ch, ctx) - // Nodes discard events in cases when proxies are already recording them. - if s.registry.srv.Component() == teleport.ComponentNode && - services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) { - s.recorder = &events.DiscardStream{} - } else { - streamer, err := s.newStreamer(ctx) - if err != nil { - return trace.Wrap(err) - } - s.recorder, err = events.NewAuditWriter(events.AuditWriterConfig{ - // Audit stream is using server context, not session context, - // to make sure that session is uploaded even after it is closed - Context: ctx.srv.Context(), - Streamer: streamer, - Clock: ctx.srv.GetClock(), - SessionID: s.id, - Namespace: ctx.srv.GetNamespace(), - ServerID: ctx.srv.HostUUID(), - RecordOutput: ctx.SessionRecordingConfig.GetMode() != types.RecordOff, - Component: teleport.Component(teleport.ComponentSession, ctx.srv.Component()), - ClusterName: ctx.ClusterName, - }) - if err != nil { - return trace.Wrap(err) - } + rec, err := newRecorder(s, ctx) + if err != nil { + return trace.Wrap(err) } + s.recorder = rec s.writer.addWriter("session-recorder", utils.WriteCloserWithContext(ctx.srv.Context(), s.recorder), true) // allocate a terminal or take the one previously allocated via a @@ -849,35 +828,45 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error { return nil } -func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error { - var err error - +// newRecorder creates a new events.StreamWriter to be used as the recorder +// of the passed in session. +func newRecorder(s *session, ctx *ServerContext) (events.StreamWriter, error) { // Nodes discard events in cases when proxies are already recording them. if s.registry.srv.Component() == teleport.ComponentNode && services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) { - s.recorder = &events.DiscardStream{} - } else { - streamer, err := s.newStreamer(ctx) - if err != nil { - return trace.Wrap(err) - } - s.recorder, err = events.NewAuditWriter(events.AuditWriterConfig{ - // Audit stream is using server context, not session context, - // to make sure that session is uploaded even after it is closed - Context: ctx.srv.Context(), - Streamer: streamer, - SessionID: s.id, - Clock: ctx.srv.GetClock(), - Namespace: ctx.srv.GetNamespace(), - ServerID: ctx.srv.HostUUID(), - RecordOutput: ctx.SessionRecordingConfig.GetMode() != types.RecordOff, - Component: teleport.Component(teleport.ComponentSession, ctx.srv.Component()), - ClusterName: ctx.ClusterName, - }) - if err != nil { - return trace.Wrap(err) - } + return &events.DiscardStream{}, nil + } + + streamer, err := s.newStreamer(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + rec, err := events.NewAuditWriter(events.AuditWriterConfig{ + // Audit stream is using server context, not session context, + // to make sure that session is uploaded even after it is closed + Context: ctx.srv.Context(), + Streamer: streamer, + SessionID: s.id, + Clock: ctx.srv.GetClock(), + Namespace: ctx.srv.GetNamespace(), + ServerID: ctx.srv.HostUUID(), + RecordOutput: ctx.SessionRecordingConfig.GetMode() != types.RecordOff, + Component: teleport.Component(teleport.ComponentSession, ctx.srv.Component()), + ClusterName: ctx.ClusterName, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + return rec, nil +} + +func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error { + rec, err := newRecorder(s, ctx) + if err != nil { + return trace.Wrap(err) } + s.recorder = rec // Emit a session.start event for the exec session. sessionStartEvent := &apievents.SessionStart{ diff --git a/lib/srv/sess_test.go b/lib/srv/sess_test.go index 88b433559cda8..18c359ac9cd99 100644 --- a/lib/srv/sess_test.go +++ b/lib/srv/sess_test.go @@ -17,8 +17,21 @@ limitations under the License. package srv import ( + "context" "testing" + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/bpf" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/pam" + restricted "github.com/gravitational/teleport/lib/restrictedsession" + "github.com/gravitational/teleport/lib/services" + rsession "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" ) @@ -63,3 +76,237 @@ func TestParseAccessRequestIDs(t *testing.T) { } } + +type mockServer struct { + events.StreamEmitter +} + +// ID is the unique ID of the server. +func (m *mockServer) ID() string { + return "test" +} + +// HostUUID is the UUID of the underlying host. For the forwarding +// server this is the proxy the forwarding server is running in. +func (m *mockServer) HostUUID() string { + return "test" +} + +// GetNamespace returns the namespace the server was created in. +func (m *mockServer) GetNamespace() string { + return "test" +} + +// AdvertiseAddr is the publicly addressable address of this server. +func (m *mockServer) AdvertiseAddr() string { + return "test" +} + +// Component is the type of server, forwarding or regular. +func (m *mockServer) Component() string { + return teleport.ComponentNode +} + +// PermitUserEnvironment returns if reading environment variables upon +// startup is allowed. +func (m *mockServer) PermitUserEnvironment() bool { + return false +} + +// GetAccessPoint returns an AccessPoint for this cluster. +func (m *mockServer) GetAccessPoint() auth.AccessPoint { + return nil +} + +// GetSessionServer returns a session server. +func (m *mockServer) GetSessionServer() rsession.Service { + return nil +} + +// GetDataDir returns data directory of the server +func (m *mockServer) GetDataDir() string { + return "test" +} + +// GetPAM returns PAM configuration for this server. +func (m *mockServer) GetPAM() (*pam.Config, error) { + return nil, nil +} + +// GetClock returns a clock setup for the server +func (m *mockServer) GetClock() clockwork.Clock { + return clockwork.NewRealClock() +} + +// GetInfo returns a services.Server that represents this server. +func (m *mockServer) GetInfo() types.Server { + return nil +} + +// UseTunnel used to determine if this node has connected to this cluster +// using reverse tunnel. +func (m *mockServer) UseTunnel() bool { + return false +} + +// GetBPF returns the BPF service used for enhanced session recording. +func (m *mockServer) GetBPF() bpf.BPF { + return nil +} + +// GetRestrictedSessionManager returns the manager for restricting user activity +func (m *mockServer) GetRestrictedSessionManager() restricted.Manager { + return nil +} + +// Context returns server shutdown context +func (m *mockServer) Context() context.Context { + return context.Background() +} + +// GetUtmpPath returns the path of the user accounting database and log. Returns empty for system defaults. +func (m *mockServer) GetUtmpPath() (utmp, wtmp string) { + return "test", "test" +} + +// GetLockWatcher gets the server's lock watcher. +func (m *mockServer) GetLockWatcher() *services.LockWatcher { + return nil +} + +func TestSession_newRecorder(t *testing.T) { + proxyRecording, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ + Mode: types.RecordAtProxy, + }) + require.NoError(t, err) + + proxyRecordingSync, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ + Mode: types.RecordAtProxySync, + }) + require.NoError(t, err) + + nodeRecording, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ + Mode: types.RecordAtNode, + }) + require.NoError(t, err) + + nodeRecordingSync, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{ + Mode: types.RecordAtNodeSync, + }) + require.NoError(t, err) + + logger := logrus.WithFields(logrus.Fields{ + trace.Component: teleport.ComponentAuth, + }) + + cases := []struct { + desc string + sess *session + sctx *ServerContext + errAssertion require.ErrorAssertionFunc + recAssertion require.ValueAssertionFunc + }{ + { + desc: "discard-stream-when-proxy-recording", + sess: &session{ + id: "test", + log: logger, + registry: &SessionRegistry{ + srv: &mockServer{}, + }, + }, + sctx: &ServerContext{ + SessionRecordingConfig: proxyRecording, + }, + errAssertion: require.NoError, + recAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.NotNil(t, i) + _, ok := i.(*events.DiscardStream) + require.True(t, ok) + }, + }, + { + desc: "discard-stream--when-proxy-sync-recording", + sess: &session{ + id: "test", + log: logger, + registry: &SessionRegistry{ + srv: &mockServer{}, + }, + }, + sctx: &ServerContext{ + SessionRecordingConfig: proxyRecordingSync, + }, + errAssertion: require.NoError, + recAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.NotNil(t, i) + _, ok := i.(*events.DiscardStream) + require.True(t, ok) + }, + }, + { + desc: "err-new-streamer-fails", + sess: &session{ + id: "test", + log: logger, + registry: &SessionRegistry{ + srv: &mockServer{}, + }, + }, + sctx: &ServerContext{ + SessionRecordingConfig: nodeRecording, + srv: &mockServer{}, + }, + errAssertion: require.Error, + recAssertion: require.Nil, + }, + { + desc: "err-new-audit-writer-fails", + sess: &session{ + id: "test", + log: logger, + registry: &SessionRegistry{ + srv: &mockServer{}, + }, + }, + sctx: &ServerContext{ + SessionRecordingConfig: nodeRecordingSync, + srv: &mockServer{}, + }, + errAssertion: require.Error, + recAssertion: require.Nil, + }, + { + desc: "audit-writer", + sess: &session{ + id: "test", + log: logger, + registry: &SessionRegistry{ + srv: &mockServer{}, + }, + }, + sctx: &ServerContext{ + ClusterName: "test", + SessionRecordingConfig: nodeRecordingSync, + srv: &mockServer{ + StreamEmitter: &events.DiscardEmitter{}, + }, + }, + errAssertion: require.NoError, + recAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) { + require.NotNil(t, i) + aw, ok := i.(*events.AuditWriter) + require.True(t, ok) + require.NoError(t, aw.Close(context.Background())) + }, + }, + } + + for _, tt := range cases { + t.Run(tt.desc, func(t *testing.T) { + rec, err := newRecorder(tt.sess, tt.sctx) + tt.errAssertion(t, err) + tt.recAssertion(t, rec) + }) + } +}