From 07fd46bb1dcebb2da4c10974fbaca102cbc8eab3 Mon Sep 17 00:00:00 2001 From: klizhentas Date: Thu, 3 Mar 2016 12:28:10 -0800 Subject: [PATCH] fix orphane processes, fixes #191 --- lib/auth/api_with_roles.go | 3 ++ lib/auth/tun.go | 60 ++++++++++++++++++++++++++----------- lib/client/client.go | 14 +++++---- lib/hangout/hangout_test.go | 33 ++++++++++---------- lib/reversetunnel/agent.go | 2 +- lib/reversetunnel/srv.go | 10 +++++-- lib/srv/sess.go | 8 +++++ lib/srv/term.go | 4 +-- lib/sshutils/server.go | 4 +++ lib/utils/utils.go | 25 ++++++++++++---- modules.go | 2 ++ 11 files changed, 115 insertions(+), 50 deletions(-) diff --git a/lib/auth/api_with_roles.go b/lib/auth/api_with_roles.go index 91973142daf6d..db443edef80da 100644 --- a/lib/auth/api_with_roles.go +++ b/lib/auth/api_with_roles.go @@ -161,6 +161,9 @@ func (conn *FakeSSHConnection) SetWriteDeadline(t time.Time) error { // and waits for that connection to be closed either by the client or by the server func (socket *fakeSocket) CreateBridge(remoteAddr net.Addr, sshChan ssh.Channel) error { log.Debugf("SocketOverSSH.Handle(from=%v) is called", remoteAddr) + if sshChan == nil { + return trace.Wrap(teleport.BadParameter("sshChan", "supply ssh channel")) + } // wrap sshChan into a 'fake connection' which allows us to // a) preserve the original address of the connected client // b) sit and wait until client closes the ssh channel, so we'll close this fake socket diff --git a/lib/auth/tun.go b/lib/auth/tun.go index b3da168b6a876..a09a671a22a47 100644 --- a/lib/auth/tun.go +++ b/lib/auth/tun.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package auth import ( @@ -22,6 +23,7 @@ import ( "net/http" "os" "sync" + "time" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/limiter" @@ -55,8 +57,10 @@ type AuthTunnel struct { limiter *limiter.Limiter } +// ServerOption is the functional argument passed to the server type ServerOption func(s *AuthTunnel) error +// SetLimiter sets rate and connection limiter for auth tunnel server func SetLimiter(limiter *limiter.Limiter) ServerOption { return func(s *AuthTunnel) error { s.limiter = limiter @@ -135,10 +139,12 @@ func (s *AuthTunnel) HandleNewChan(_ net.Conn, sconn *ssh.ServerConn, nch ssh.Ne string(nch.ExtraData()), err) nch.Reject(ssh.UnknownChannelType, "failed to parse direct-tcpip request") + return } sshCh, _, err := nch.Accept() if err != nil { log.Infof("[AUTH] could not accept channel (%s)", err) + return } go s.handleDirectTCPIPRequest(sconn, sshCh, req) case ReqWebSessionAgent: @@ -153,6 +159,7 @@ func (s *AuthTunnel) HandleNewChan(_ net.Conn, sconn *ssh.ServerConn, nch ssh.Ne ch, _, err := nch.Accept() if err != nil { log.Infof("[AUTH] could not accept channel (%s)", err) + return } go s.handleWebAgentRequest(sconn, ch) default: @@ -560,40 +567,57 @@ func (t *TunDialer) getClient(reset bool) (*ssh.Client, error) { if t.tun != nil { if !reset { return t.tun, nil - } else { - go t.tun.Close() - t.tun = nil } + go t.tun.Close() + t.tun = nil } - config := &ssh.ClientConfig{ User: t.user, Auth: t.auth, } client, err := ssh.Dial(t.addr.AddrNetwork, t.addr.Addr, config) if err != nil { - return nil, err + log.Infof("TunDialer could not ssh.Dial: %v", err) + return nil, trace.Wrap(err) } t.tun = client return t.tun, nil } +const ( + // DialerRetryAttempts is the amount of attempts for dialer to try and + // connect to the remote destination + DialerRetryAttempts = 3 + // DialerPeriodBetweenAttempts is the period between retry attempts + DialerPeriodBetweenAttempts = time.Second +) + func (t *TunDialer) Dial(network, address string) (net.Conn, error) { - c, err := t.getClient(false) - if err != nil { - return nil, err - } - conn, err := c.Dial(network, address) - if err == nil { - return conn, err - } else { - // reconnecting and trying again - c, err = t.getClient(true) - if err != nil { - return nil, err + var client *ssh.Client + var err error + for i := 0; i < DialerRetryAttempts; i++ { + if i == 0 { + client, err = t.getClient(false) + if err != nil { + log.Infof("TunDialer failed to get client: %v", err) + continue + } + } else { + time.Sleep(DialerPeriodBetweenAttempts) + client, err = t.getClient(true) + if err != nil { + log.Infof("TunDialer failed to get client: %v", err) + continue + } + } + conn, err := client.Dial(network, address) + if err == nil { + return conn, nil } - return c.Dial(network, address) + log.Infof("TunDialer connection issue (%v), reconnect", err) } + return nil, trace.Wrap( + teleport.ConnectionProblem("failed to connect to remote API", err)) } func NewClientFromSSHClient(sshClient *ssh.Client) (*Client, error) { diff --git a/lib/client/client.go b/lib/client/client.go index 0819950873d85..378ea6ee2e8c9 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -40,6 +40,7 @@ import ( "github.com/gravitational/teleport/lib/utils" log "github.com/Sirupsen/logrus" + "github.com/gravitational/teleport" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" ) @@ -247,7 +248,8 @@ func (proxy *ProxyClient) ConnectToHangout(nodeAddress string, authMethods []ssh.AuthMethod) (*NodeClient, error) { if len(authMethods) == 0 { - return nil, trace.Errorf("no authMethods were provided") + return nil, trace.Wrap( + teleport.BadParameter("authMethods", "no authMethods were provided")) } e := trace.Errorf("unknown Error") @@ -296,7 +298,8 @@ func (proxy *ProxyClient) ConnectToHangout(nodeAddress string, buf := make([]byte, 20000) n, err := pipeNetConn.Read(buf) if err != nil { - return nil, trace.Errorf("Readed %v, err: %v", n, err) + return nil, trace.Wrap( + teleport.ConnectionProblem("failed to read target host certificate", err)) } var endpointInfo srv.HangoutEndpointInfo err = json.Unmarshal(buf[:n], &endpointInfo) @@ -307,7 +310,8 @@ func (proxy *ProxyClient) ConnectToHangout(nodeAddress string, hostKeyCallback := func(hostname string, remote net.Addr, key ssh.PublicKey) error { cert, ok := key.(*ssh.Certificate) if !ok { - return trace.Errorf("expected certificate") + return trace.Wrap( + teleport.BadParameter("certificate", "expected certificate")) } checkers, err := endpointInfo.HostKey.Checkers() if err != nil { @@ -318,8 +322,8 @@ func (proxy *ProxyClient) ConnectToHangout(nodeAddress string, return nil } } - - return trace.Errorf("remote host key is not valid") + return trace.Wrap( + teleport.AccessDenied("remote host key is not valid")) } sshConfig := &ssh.ClientConfig{ diff --git a/lib/hangout/hangout_test.go b/lib/hangout/hangout_test.go index 2801a3c9380e5..289a8af8431fa 100644 --- a/lib/hangout/hangout_test.go +++ b/lib/hangout/hangout_test.go @@ -13,11 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package hangout import ( "bufio" - "net/http" + "fmt" "net/http/httptest" "os/user" "path/filepath" @@ -47,8 +48,6 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" . "gopkg.in/check.v1" - - log "github.com/Sirupsen/logrus" ) func TestHangouts(t *testing.T) { TestingT(t) } @@ -72,7 +71,7 @@ type HangoutsSuite struct { var _ = Suite(&HangoutsSuite{}) func (s *HangoutsSuite) SetUpSuite(c *C) { - utils.InitLoggerCLI() + utils.InitLoggerDebug() client.KeysDir = c.MkDir() s.dir = c.MkDir() @@ -113,8 +112,11 @@ func (s *HangoutsSuite) SetUpSuite(c *C) { teleport.RoleAdmin, nil) + freePorts, err := utils.GetFreeTCPPorts(10) + c.Assert(err, IsNil) + // Starting proxy - reverseTunnelAddress := utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:34057"} + reverseTunnelAddress := utils.NetAddr{AddrNetwork: "tcp", Addr: fmt.Sprintf("localhost:%v", freePorts.Pop())} s.reverseTunnelAddress = reverseTunnelAddress.Addr reverseTunnelServer, err := reversetunnel.NewServer( reverseTunnelAddress, @@ -123,7 +125,7 @@ func (s *HangoutsSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) c.Assert(reverseTunnelServer.Start(), IsNil) - s.proxyAddress = "localhost:35783" + s.proxyAddress = fmt.Sprintf("localhost:%v", freePorts.Pop()) s.proxy, err = srv.New( utils.NetAddr{AddrNetwork: "tcp", Addr: s.proxyAddress}, @@ -148,7 +150,10 @@ func (s *HangoutsSuite) SetUpSuite(c *C) { go apiSrv.Serve() tsrv, err := auth.NewTunnel( - utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:32498"}, + utils.NetAddr{ + AddrNetwork: "tcp", + Addr: fmt.Sprintf("localhost:%v", freePorts.Pop()), + }, []ssh.Signer{s.signer}, apiSrv, s.a) c.Assert(err, IsNil) @@ -196,13 +201,6 @@ func (s *HangoutsSuite) SetUpSuite(c *C) { s.webAddress = webServer.URL - go func() { - err := http.ListenAndServe(s.webAddress, webHandler) - if err != nil { - log.Errorf(err.Error()) - } - }() - s.teleagent = teleagent.NewTeleAgent(true) err = s.teleagent.Login(s.webAddress, s.user, string(s.pass), s.otp.OTP(), time.Minute) c.Assert(err, IsNil) @@ -235,11 +233,13 @@ func (s *HangoutsSuite) SetUpSuite(c *C) { } func (s *HangoutsSuite) TestHangout(c *C) { + ports, err := utils.GetFreeTCPPorts(2) + c.Assert(err, IsNil) u, err := user.Current() c.Assert(err, IsNil) osUser := u.Username - nodeListeningAddress := "localhost:41003" - authListeningAddress := "localhost:41004" + nodeListeningAddress := fmt.Sprintf("localhost:%v", ports.Pop()) + authListeningAddress := fmt.Sprintf("localhost:%v", ports.Pop()) DefaultSSHShell = "/bin/sh" // Initializing tsh share @@ -272,7 +272,6 @@ func (s *HangoutsSuite) TestHangout(c *C) { c.Assert(out, Equals, " expr 11 + 22\r\n33\r\n$") for i := 0; i < 3; i++ { - // Initializing tsh join proxy, err := client.ConnectToProxy(s.proxyAddress, []ssh.AuthMethod{s.teleagent.AuthMethod()}, client.CheckHostSignerFromCache, "anyuser") diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 2aa8bf932164f..8b6edb23b9c2b 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -306,7 +306,7 @@ func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) { conn, err := a.clt.GetDialer()() if err != nil { - a.log.Errorf("%v error dialing: %v", a, err) + a.log.Errorf("error dialing: %v", err) return } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index b1acfe08f0dda..b7b894ad34a35 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -685,7 +685,6 @@ func (s *tunnelSite) handleHeartbeat(ch ssh.Channel, reqC <-chan *ssh.Request) { s.log.Infof("agent disconnected") return } - //s.log.Debugf("ping") s.lastActive = time.Now() } }() @@ -824,7 +823,14 @@ func (s *tunnelSite) DialServer(addr string) (net.Conn, error) { if err != nil { return nil, trace.Wrap(err) } - knownServers, err := clt.GetServers() + var knownServers []services.Server + for i := 0; i < 10; i++ { + knownServers, err = clt.GetServers() + if err != nil { + log.Infof("failed to get servers: %v", err) + time.Sleep(time.Second) + } + } if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 96e35df79e185..ea6f53e888ec9 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -336,6 +336,14 @@ func (s *session) start(sconn *ssh.ServerConn, ch ssh.Channel, ctx *ctx) error { } }() + go func() { + <-s.closeC + if cmd.Process != nil { + err := cmd.Process.Kill() + p.ctx.Infof("killed process: %v", err) + } + }() + return nil } diff --git a/lib/srv/term.go b/lib/srv/term.go index 02a6d542982df..7f74e4d8ccc84 100644 --- a/lib/srv/term.go +++ b/lib/srv/term.go @@ -105,7 +105,7 @@ func (t *terminal) run(c *exec.Cmd) error { c.Stderr = t.tty c.SysProcAttr.Setctty = true c.SysProcAttr.Setsid = true - return c.Start() + return trace.Wrap(c.Start()) } func (t *terminal) Close() error { @@ -118,7 +118,7 @@ func (t *terminal) Close() error { err = e } } - return err + return trace.Wrap(err) } func parseWinChange(req *ssh.Request) (*rsession.TerminalParams, error) { diff --git a/lib/sshutils/server.go b/lib/sshutils/server.go index 1187230261920..6ddaf8b251e55 100644 --- a/lib/sshutils/server.go +++ b/lib/sshutils/server.go @@ -216,6 +216,10 @@ func (s *Server) handleRequests(reqs <-chan *ssh.Request) { func (s *Server) handleChannels(conn net.Conn, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel) { for nch := range chans { + if nch == nil { + log.Warningf("nil channel: %v", nch) + continue + } s.newChanHandler.HandleNewChan(conn, sconn, nch) } } diff --git a/lib/utils/utils.go b/lib/utils/utils.go index 4a0c473ab54a2..0dbbc887b69d6 100644 --- a/lib/utils/utils.go +++ b/lib/utils/utils.go @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + package utils import ( @@ -114,26 +115,40 @@ func IsHandshakeFailedError(err error) bool { return strings.Contains(err.Error(), "handshake failed") } -func GetFreeTCPPorts(n int) (ports []string, e error) { +// PortList is a list of TCP port +type PortList []string + +// Pop returns a value from the list, it panics if the value is not there +func (p *PortList) Pop() string { + if len(*p) == 0 { + panic("list is empty") + } + val := (*p)[len(*p)-1] + *p = (*p)[:len(*p)-1] + return val +} + +// GetFreeTCPPorts returns a lit of available ports on localhost +// used for testing +func GetFreeTCPPorts(n int) (PortList, error) { + list := make(PortList, 0, n) for i := 0; i < n; i++ { addr, err := net.ResolveTCPAddr("tcp", "localhost:0") if err != nil { return nil, trace.Wrap(err) } - listener, err := net.ListenTCP("tcp", addr) if err != nil { return nil, trace.Wrap(err) } defer listener.Close() - tcpAddr, ok := listener.Addr().(*net.TCPAddr) if !ok { return nil, trace.Errorf("Can't get tcp address") } - ports = append(ports, strconv.Itoa(tcpAddr.Port)) + list = append(list, strconv.Itoa(tcpAddr.Port)) } - return ports, nil + return list, nil } // ReadOrMakeHostUUID looks for a hostid file in the data dir. If present, diff --git a/modules.go b/modules.go index c9becb15bc0e8..84dd5cdf8bc42 100644 --- a/modules.go +++ b/modules.go @@ -17,6 +17,8 @@ const ( ComponentNode = "node" // ComponentProxy is SSH proxy (SSH server forwarding connections) ComponentProxy = "proxy" + // ComponentTunClient is a tunnel client + ComponentTunClient = "tunclient" ) const (