Skip to content

Commit

Permalink
Prevent panic caused by nil session recorder (#10792)
Browse files Browse the repository at this point in the history
* Prevent panic caused by nil session recorder

In startInteractive the session recorder was being assigned the
return value of events.NewAuditWriter, even if it returned an
error. This causes problems because the nil *events.AuditWriter
that is returned in this case ends up being stored in recorder as
a non-nil events.StreamWriter. So when the session tries to close
the check on recorder != nil is mistakenly true and recorder.Close
is called on a nil *events.AuditWriter - which results in a panic.
  • Loading branch information
rosstimothy committed Mar 4, 2022
1 parent dc7619a commit d1fdecc
Show file tree
Hide file tree
Showing 2 changed files with 286 additions and 28 deletions.
68 changes: 40 additions & 28 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,32 +659,11 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
// create a new "party" (connected client)
p := newParty(s, ch, ctx)

// Nodes discard events in cases when proxies are already recording them.
if s.registry.srv.Component() == teleport.ComponentNode &&
services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) {
s.recorder = &events.DiscardStream{}
} else {
streamer, err := s.newStreamer(ctx)
if err != nil {
return trace.Wrap(err)
}
s.recorder, err = events.NewAuditWriter(events.AuditWriterConfig{
// Audit stream is using server context, not session context,
// to make sure that session is uploaded even after it is closed
Context: ctx.srv.Context(),
Streamer: streamer,
Clock: ctx.srv.GetClock(),
SessionID: s.id,
Namespace: ctx.srv.GetNamespace(),
ServerID: ctx.srv.HostUUID(),
RecordOutput: ctx.SessionRecordingConfig.GetMode() != types.RecordOff,
Component: teleport.Component(teleport.ComponentSession, ctx.srv.Component()),
ClusterName: ctx.ClusterName,
})
if err != nil {
return trace.Wrap(err)
}
rec, err := newRecorder(s, ctx)
if err != nil {
return trace.Wrap(err)
}
s.recorder = rec
s.writer.addWriter("session-recorder", utils.WriteCloserWithContext(ctx.srv.Context(), s.recorder), true)

// allocate a terminal or take the one previously allocated via a
Expand Down Expand Up @@ -849,12 +828,45 @@ func (s *session) startInteractive(ch ssh.Channel, ctx *ServerContext) error {
return nil
}

func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error {
var err error

// newRecorder creates a new events.StreamWriter to be used as the recorder
// of the passed in session.
func newRecorder(s *session, ctx *ServerContext) (events.StreamWriter, error) {
// Nodes discard events in cases when proxies are already recording them.
if s.registry.srv.Component() == teleport.ComponentNode &&
services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) {
return &events.DiscardStream{}, nil
}

streamer, err := s.newStreamer(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
rec, err := events.NewAuditWriter(events.AuditWriterConfig{
// Audit stream is using server context, not session context,
// to make sure that session is uploaded even after it is closed
Context: ctx.srv.Context(),
Streamer: streamer,
SessionID: s.id,
Clock: ctx.srv.GetClock(),
Namespace: ctx.srv.GetNamespace(),
ServerID: ctx.srv.HostUUID(),
RecordOutput: ctx.SessionRecordingConfig.GetMode() != types.RecordOff,
Component: teleport.Component(teleport.ComponentSession, ctx.srv.Component()),
ClusterName: ctx.ClusterName,
})
if err != nil {
return nil, trace.Wrap(err)
}

return rec, nil
}

func (s *session) startExec(channel ssh.Channel, ctx *ServerContext) error {
rec, err := newRecorder(s, ctx)
if err != nil {
return trace.Wrap(err)
}
s.recorder = rec
s.recorder = &events.DiscardStream{}
} else {
streamer, err := s.newStreamer(ctx)
Expand Down
246 changes: 246 additions & 0 deletions lib/srv/sess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,20 @@ limitations under the License.
package srv

import (
"context"
"testing"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/bpf"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/pam"
restricted "github.com/gravitational/teleport/lib/restrictedsession"
"github.com/gravitational/teleport/lib/services"
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -63,3 +75,237 @@ func TestParseAccessRequestIDs(t *testing.T) {
}

}

type mockServer struct {
events.StreamEmitter
}

// ID is the unique ID of the server.
func (m *mockServer) ID() string {
return "test"
}

// HostUUID is the UUID of the underlying host. For the forwarding
// server this is the proxy the forwarding server is running in.
func (m *mockServer) HostUUID() string {
return "test"
}

// GetNamespace returns the namespace the server was created in.
func (m *mockServer) GetNamespace() string {
return "test"
}

// AdvertiseAddr is the publicly addressable address of this server.
func (m *mockServer) AdvertiseAddr() string {
return "test"
}

// Component is the type of server, forwarding or regular.
func (m *mockServer) Component() string {
return teleport.ComponentNode
}

// PermitUserEnvironment returns if reading environment variables upon
// startup is allowed.
func (m *mockServer) PermitUserEnvironment() bool {
return false
}

// GetAccessPoint returns an AccessPoint for this cluster.
func (m *mockServer) GetAccessPoint() AccessPoint {
return nil
}

// GetSessionServer returns a session server.
func (m *mockServer) GetSessionServer() rsession.Service {
return nil
}

// GetDataDir returns data directory of the server
func (m *mockServer) GetDataDir() string {
return "test"
}

// GetPAM returns PAM configuration for this server.
func (m *mockServer) GetPAM() (*pam.Config, error) {
return nil, nil
}

// GetClock returns a clock setup for the server
func (m *mockServer) GetClock() clockwork.Clock {
return clockwork.NewRealClock()
}

// GetInfo returns a services.Server that represents this server.
func (m *mockServer) GetInfo() types.Server {
return nil
}

// UseTunnel used to determine if this node has connected to this cluster
// using reverse tunnel.
func (m *mockServer) UseTunnel() bool {
return false
}

// GetBPF returns the BPF service used for enhanced session recording.
func (m *mockServer) GetBPF() bpf.BPF {
return nil
}

// GetRestrictedSessionManager returns the manager for restricting user activity
func (m *mockServer) GetRestrictedSessionManager() restricted.Manager {
return nil
}

// Context returns server shutdown context
func (m *mockServer) Context() context.Context {
return context.Background()
}

// GetUtmpPath returns the path of the user accounting database and log. Returns empty for system defaults.
func (m *mockServer) GetUtmpPath() (utmp, wtmp string) {
return "test", "test"
}

// GetLockWatcher gets the server's lock watcher.
func (m *mockServer) GetLockWatcher() *services.LockWatcher {
return nil
}

func TestSession_newRecorder(t *testing.T) {
proxyRecording, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{
Mode: types.RecordAtProxy,
})
require.NoError(t, err)

proxyRecordingSync, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{
Mode: types.RecordAtProxySync,
})
require.NoError(t, err)

nodeRecording, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{
Mode: types.RecordAtNode,
})
require.NoError(t, err)

nodeRecordingSync, err := types.NewSessionRecordingConfigFromConfigFile(types.SessionRecordingConfigSpecV2{
Mode: types.RecordAtNodeSync,
})
require.NoError(t, err)

logger := logrus.WithFields(logrus.Fields{
trace.Component: teleport.ComponentAuth,
})

cases := []struct {
desc string
sess *session
sctx *ServerContext
errAssertion require.ErrorAssertionFunc
recAssertion require.ValueAssertionFunc
}{
{
desc: "discard-stream-when-proxy-recording",
sess: &session{
id: "test",
log: logger,
registry: &SessionRegistry{
srv: &mockServer{},
},
},
sctx: &ServerContext{
SessionRecordingConfig: proxyRecording,
},
errAssertion: require.NoError,
recAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
require.NotNil(t, i)
_, ok := i.(*events.DiscardStream)
require.True(t, ok)
},
},
{
desc: "discard-stream--when-proxy-sync-recording",
sess: &session{
id: "test",
log: logger,
registry: &SessionRegistry{
srv: &mockServer{},
},
},
sctx: &ServerContext{
SessionRecordingConfig: proxyRecordingSync,
},
errAssertion: require.NoError,
recAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
require.NotNil(t, i)
_, ok := i.(*events.DiscardStream)
require.True(t, ok)
},
},
{
desc: "err-new-streamer-fails",
sess: &session{
id: "test",
log: logger,
registry: &SessionRegistry{
srv: &mockServer{},
},
},
sctx: &ServerContext{
SessionRecordingConfig: nodeRecording,
srv: &mockServer{},
},
errAssertion: require.Error,
recAssertion: require.Nil,
},
{
desc: "err-new-audit-writer-fails",
sess: &session{
id: "test",
log: logger,
registry: &SessionRegistry{
srv: &mockServer{},
},
},
sctx: &ServerContext{
SessionRecordingConfig: nodeRecordingSync,
srv: &mockServer{},
},
errAssertion: require.Error,
recAssertion: require.Nil,
},
{
desc: "audit-writer",
sess: &session{
id: "test",
log: logger,
registry: &SessionRegistry{
srv: &mockServer{},
},
},
sctx: &ServerContext{
ClusterName: "test",
SessionRecordingConfig: nodeRecordingSync,
srv: &mockServer{
StreamEmitter: &events.DiscardEmitter{},
},
},
errAssertion: require.NoError,
recAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
require.NotNil(t, i)
aw, ok := i.(*events.AuditWriter)
require.True(t, ok)
require.NoError(t, aw.Close(context.Background()))
},
},
}

for _, tt := range cases {
t.Run(tt.desc, func(t *testing.T) {
rec, err := newRecorder(tt.sess, tt.sctx)
tt.errAssertion(t, err)
tt.recAssertion(t, rec)
})
}
}

0 comments on commit d1fdecc

Please sign in to comment.