diff --git a/main.go b/main.go index 46830542f..143b8c850 100644 --- a/main.go +++ b/main.go @@ -17,7 +17,7 @@ func main() { if err != nil { panic(err) } - s, err := NewServer(cfg.Port, cfg.KeyPath, tui.SessionHandler) + s, err := NewServer(cfg.Port, cfg.KeyPath, LoggingMiddleware(), BubbleTeaMiddleware(tui.SessionHandler)) if err != nil { panic(err) } diff --git a/server.go b/server.go index efbaed003..ea735c41c 100644 --- a/server.go +++ b/server.go @@ -12,22 +12,44 @@ import ( gossh "golang.org/x/crypto/ssh" ) -type SessionHandler func(ssh.Session) (tea.Model, error) +type Middleware func(ssh.Handler) ssh.Handler -type Server struct { - server *ssh.Server - key gossh.PublicKey - handler SessionHandler +func LoggingMiddleware() Middleware { + return func(sh ssh.Handler) ssh.Handler { + return func(s ssh.Session) { + hpk := s.PublicKey() != nil + log.Printf("%s connect %v %v\n", s.RemoteAddr().String(), hpk, s.Command()) + sh(s) + log.Printf("%s disconnect %v %v\n", s.RemoteAddr().String(), hpk, s.Command()) + } + } } -func NewServer(port int, keyPath string, handler SessionHandler) (*Server, error) { - s := &Server{ - server: &ssh.Server{}, - handler: handler, +func BubbleTeaMiddleware(bth func(ssh.Session) tea.Model) Middleware { + return func(sh ssh.Handler) ssh.Handler { + return func(s ssh.Session) { + m := bth(s) + if m != nil { + p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithInput(s), tea.WithOutput(s)) + err := p.Start() + if err != nil { + log.Printf("%s error %v: %s\n", s.RemoteAddr().String(), s.Command(), err) + } + } + sh(s) + } } +} + +type Server struct { + server *ssh.Server + key gossh.PublicKey +} + +func NewServer(port int, keyPath string, mw ...Middleware) (*Server, error) { + s := &Server{server: &ssh.Server{}} s.server.Version = "OpenSSH_7.6p1" s.server.Addr = fmt.Sprintf(":%d", port) - s.server.Handler = s.sessionHandler s.server.PasswordHandler = s.passHandler s.server.PublicKeyHandler = s.authHandler kps := strings.Split(keyPath, string(filepath.Separator)) @@ -42,28 +64,15 @@ func NewServer(port int, keyPath string, handler SessionHandler) (*Server, error if err != nil { return nil, err } + h := func(s ssh.Session) {} + for _, m := range mw { + h = m(h) + } + s.server.Handler = h return s, nil } func (srv *Server) sessionHandler(s ssh.Session) { - hpk := s.PublicKey() != nil - log.Printf("%s connect %v %v\n", s.RemoteAddr().String(), hpk, s.Command()) - m, err := srv.handler(s) - if err != nil { - log.Printf("%s error %v %s\n", s.RemoteAddr().String(), hpk, err) - s.Exit(1) - return - } - if m != nil { - p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithInput(s), tea.WithOutput(s)) - err = p.Start() - if err != nil { - log.Printf("%s error %v %s\n", s.RemoteAddr().String(), hpk, err) - s.Exit(1) - return - } - } - log.Printf("%s disconnect %v %v\n", s.RemoteAddr().String(), hpk, s.Command()) } func (srv *Server) authHandler(ctx ssh.Context, key ssh.PublicKey) bool { diff --git a/tui/model.go b/tui/model.go index 4a0e1d236..e29f72f55 100644 --- a/tui/model.go +++ b/tui/model.go @@ -24,12 +24,12 @@ func (e errMsg) Error() string { return e.err.Error() } -func SessionHandler(s ssh.Session) (tea.Model, error) { +func SessionHandler(s ssh.Session) tea.Model { pty, changes, active := s.Pty() if !active { - return nil, fmt.Errorf("you need to do this from a terminal with PTY support") + return nil } - return NewModel(pty.Window.Width, pty.Window.Height, changes), nil + return NewModel(pty.Window.Width, pty.Window.Height, changes) } type Model struct {