diff --git a/cmd/soft/main.go b/cmd/soft/main.go index f7ea656fe..d15a23bd9 100644 --- a/cmd/soft/main.go +++ b/cmd/soft/main.go @@ -1,10 +1,14 @@ package main import ( + "context" "flag" "fmt" "log" "os" + "os/signal" + "syscall" + "time" "github.com/charmbracelet/soft-serve/config" "github.com/charmbracelet/soft-serve/server" @@ -49,9 +53,23 @@ func main() { cfg := config.DefaultConfig() s := server.NewServer(cfg) - log.Printf("Starting SSH server on %s:%d\n", cfg.Host, cfg.Port) - err := s.Start() - if err != nil { + + done := make(chan os.Signal, 1) + signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + log.Printf("Starting SSH server on %s:%d", cfg.Host, cfg.Port) + go func() { + if err := s.Start(); err != nil { + log.Fatalln(err) + } + }() + + <-done + + log.Printf("Stopping SSH server on %s:%d", cfg.Host, cfg.Port) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer func() { cancel() }() + if err := s.Shutdown(ctx); err != nil { log.Fatalln(err) } } diff --git a/server/server.go b/server/server.go index e7ab0cf23..17a0bd7e0 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "log" @@ -63,3 +64,8 @@ func (srv *Server) Reload() error { func (srv *Server) Start() error { return srv.SSHServer.ListenAndServe() } + +// Shutdown lets the server gracefully shutdown. +func (srv *Server) Shutdown(ctx context.Context) error { + return srv.SSHServer.Shutdown(ctx) +}