diff --git a/app/api_topology.go b/app/api_topology.go index 65aba207b4..e4da6c3bf3 100644 --- a/app/api_topology.go +++ b/app/api_topology.go @@ -77,15 +77,16 @@ func handleWebsocket( renderer render.Renderer, loop time.Duration, ) { - conn, err := upgrader.Upgrade(w, r, nil) + wsConn, err := upgrader.Upgrade(w, r, nil) if err != nil { // log.Info("Upgrade:", err) return } + conn := xfer.Ping(wsConn) defer conn.Close() quit := make(chan struct{}) - go func(c *websocket.Conn) { + go func(c xfer.Websocket) { for { // just discard everything the browser sends if _, _, err := c.NextReader(); err != nil { close(quit) diff --git a/app/controls.go b/app/controls.go index 0f5c38f6b1..855d92a545 100644 --- a/app/controls.go +++ b/app/controls.go @@ -1,6 +1,7 @@ package app import ( + "io" "math/rand" "net/http" "net/rpc" @@ -115,7 +116,9 @@ func (cr *controlRouter) handleProbeWS(w http.ResponseWriter, r *http.Request) { cr.set(probeID, handler) - codec.WaitForReadError() + if err := codec.WaitForReadError(); err != nil && err != io.EOF { + log.Errorf("Error on websocket: %v", err) + } cr.rm(probeID, handler) client.Close() diff --git a/app/pipes.go b/app/pipes.go index bed35309b7..f54842bc28 100644 --- a/app/pipes.go +++ b/app/pipes.go @@ -176,11 +176,12 @@ func (pr *PipeRouter) handleWs(endSelector func(*pipe) (*end, io.ReadWriter)) fu } defer pr.release(pipeID, pipe, endRef) - conn, err := upgrader.Upgrade(w, r, nil) + wsConn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Errorf("Error upgrading to websocket: %v", err) return } + conn := xfer.Ping(wsConn) defer conn.Close() pipe.CopyToWebsocket(endIO, conn) diff --git a/common/xfer/controls.go b/common/xfer/controls.go index aba3c71d4e..6217f6f969 100644 --- a/common/xfer/controls.go +++ b/common/xfer/controls.go @@ -4,8 +4,6 @@ import ( "fmt" "net/rpc" "sync" - - "github.com/gorilla/websocket" ) // ErrInvalidMessage is the error returned when the on-wire message is unexpected. @@ -70,22 +68,22 @@ func ResponseError(err error) Response { // that transmits and receives RPC messages over a websocker, as JSON. type JSONWebsocketCodec struct { sync.Mutex - conn *websocket.Conn - err chan struct{} + conn Websocket + err chan error } // NewJSONWebsocketCodec makes a new JSONWebsocketCodec -func NewJSONWebsocketCodec(conn *websocket.Conn) *JSONWebsocketCodec { +func NewJSONWebsocketCodec(conn Websocket) *JSONWebsocketCodec { return &JSONWebsocketCodec{ conn: conn, - err: make(chan struct{}), + err: make(chan error, 1), } } // WaitForReadError blocks until any read on this codec returns an error. // This is useful to know when the server has disconnected from the client. -func (j *JSONWebsocketCodec) WaitForReadError() { - <-j.err +func (j *JSONWebsocketCodec) WaitForReadError() error { + return <-j.err } // WriteRequest implements rpc.ClientCodec @@ -113,6 +111,7 @@ func (j *JSONWebsocketCodec) WriteResponse(r *rpc.Response, v interface{}) error func (j *JSONWebsocketCodec) readMessage(v interface{}) (*Message, error) { m := Message{Value: v} if err := ReadJSONfromWS(j.conn, &m); err != nil { + j.err <- err close(j.err) return nil, err } diff --git a/common/xfer/pipes.go b/common/xfer/pipes.go index 8e85515ff0..8172644b64 100644 --- a/common/xfer/pipes.go +++ b/common/xfer/pipes.go @@ -11,7 +11,7 @@ import ( // to the UI. type Pipe interface { Ends() (io.ReadWriter, io.ReadWriter) - CopyToWebsocket(io.ReadWriter, *websocket.Conn) error + CopyToWebsocket(io.ReadWriter, Websocket) error Close() error Closed() bool @@ -83,7 +83,7 @@ func (p *pipe) OnClose(f func()) { } // CopyToWebsocket copies pipe data to/from a websocket. It blocks. -func (p *pipe) CopyToWebsocket(end io.ReadWriter, conn *websocket.Conn) error { +func (p *pipe) CopyToWebsocket(end io.ReadWriter, conn Websocket) error { p.mtx.Lock() if p.closed { p.mtx.Unlock() diff --git a/common/xfer/websocket.go b/common/xfer/websocket.go index 54d83883c4..099d861532 100644 --- a/common/xfer/websocket.go +++ b/common/xfer/websocket.go @@ -2,13 +2,56 @@ package xfer import ( "io" + "time" "github.com/gorilla/websocket" "github.com/ugorji/go/codec" + + "github.com/weaveworks/scope/common/mtime" +) + +const ( + pingInterval = 5 * time.Second ) +// Websocket exposes the bits of *websocket.Conn we actually use. +type Websocket interface { + NextReader() (messageType int, r io.Reader, err error) + NextWriter(messageType int) (io.WriteCloser, error) + ReadMessage() (messageType int, p []byte, err error) + WriteMessage(messageType int, data []byte) error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error + Close() error +} + +type pingingWebsocket struct { + pinger *time.Timer + *websocket.Conn +} + +// Ping adds a periodic ping to a websocket connection. +func Ping(c *websocket.Conn) Websocket { + w := &pingingWebsocket{Conn: c} + w.pinger = time.AfterFunc(pingInterval, w.ping) + return w +} + +func (p *pingingWebsocket) Close() error { + p.pinger.Stop() + return p.Conn.Close() +} + +func (p *pingingWebsocket) ping() { + if err := p.Conn.WriteControl(websocket.PingMessage, nil, mtime.Now().Add(pingInterval)); err != nil { + p.Close() + return + } + p.pinger.Reset(pingInterval) +} + // WriteJSONtoWS writes the JSON encoding of v to the connection. -func WriteJSONtoWS(c *websocket.Conn, v interface{}) error { +func WriteJSONtoWS(c Websocket, v interface{}) error { w, err := c.NextWriter(websocket.TextMessage) if err != nil { return err @@ -23,7 +66,7 @@ func WriteJSONtoWS(c *websocket.Conn, v interface{}) error { // ReadJSONfromWS reads the next JSON-encoded message from the connection and stores // it in the value pointed to by v. -func ReadJSONfromWS(c *websocket.Conn, v interface{}) error { +func ReadJSONfromWS(c Websocket, v interface{}) error { _, r, err := c.NextReader() if err != nil { return err diff --git a/probe/appclient/app_client.go b/probe/appclient/app_client.go index 1424d9962c..b23ffa5a34 100644 --- a/probe/appclient/app_client.go +++ b/probe/appclient/app_client.go @@ -46,7 +46,7 @@ type appClient struct { backgroundWait sync.WaitGroup // Track ongoing websocket connections - conns map[string]*websocket.Conn + conns map[string]xfer.Websocket // For publish publishLoop sync.Once @@ -73,7 +73,7 @@ func NewAppClient(pc ProbeConfig, hostname, target string, control xfer.ControlH wsDialer: websocket.Dialer{ TLSClientConfig: httpTransport.TLSClientConfig, }, - conns: map[string]*websocket.Conn{}, + conns: map[string]xfer.Websocket{}, readers: make(chan io.Reader), control: control, }, nil @@ -88,7 +88,7 @@ func (c *appClient) hasQuit() bool { } } -func (c *appClient) registerConn(id string, conn *websocket.Conn) bool { +func (c *appClient) registerConn(id string, conn xfer.Websocket) bool { c.mtx.Lock() defer c.mtx.Unlock() if c.hasQuit() { @@ -130,7 +130,7 @@ func (c *appClient) Stop() { for _, conn := range c.conns { conn.Close() } - c.conns = map[string]*websocket.Conn{} + c.conns = map[string]xfer.Websocket{} c.mtx.Unlock() c.backgroundWait.Wait() @@ -188,10 +188,11 @@ func (c *appClient) controlConnection() (bool, error) { headers := http.Header{} c.ProbeConfig.authorizeHeaders(headers) url := sanitize.URL("ws://", 0, "/api/control/ws")(c.target) - conn, _, err := c.wsDialer.Dial(url, headers) + wsConn, _, err := c.wsDialer.Dial(url, headers) if err != nil { return false, err } + conn := xfer.Ping(wsConn) defer func() { conn.Close() }() @@ -271,7 +272,7 @@ func (c *appClient) pipeConnection(id string, pipe xfer.Pipe) (bool, error) { headers := http.Header{} c.ProbeConfig.authorizeHeaders(headers) url := sanitize.URL("ws://", 0, fmt.Sprintf("/api/pipe/%s/probe", id))(c.target) - conn, resp, err := c.wsDialer.Dial(url, headers) + wsConn, resp, err := c.wsDialer.Dial(url, headers) if resp != nil && resp.StatusCode == http.StatusNotFound { // Special handling - 404 means the app/user has closed the pipe pipe.Close() @@ -280,6 +281,7 @@ func (c *appClient) pipeConnection(id string, pipe xfer.Pipe) (bool, error) { if err != nil { return false, err } + conn := xfer.Ping(wsConn) // Will return false if we are exiting if !c.registerConn(id, conn) { diff --git a/probe/docker/controls_test.go b/probe/docker/controls_test.go index bf5093fd16..e7f17aef79 100644 --- a/probe/docker/controls_test.go +++ b/probe/docker/controls_test.go @@ -6,8 +6,6 @@ import ( "testing" "time" - "github.com/gorilla/websocket" - "github.com/weaveworks/scope/common/xfer" "github.com/weaveworks/scope/probe/controls" "github.com/weaveworks/scope/probe/docker" @@ -43,11 +41,11 @@ func TestControls(t *testing.T) { type mockPipe struct{} -func (mockPipe) Ends() (io.ReadWriter, io.ReadWriter) { return nil, nil } -func (mockPipe) CopyToWebsocket(io.ReadWriter, *websocket.Conn) error { return nil } -func (mockPipe) Close() error { return nil } -func (mockPipe) Closed() bool { return false } -func (mockPipe) OnClose(func()) {} +func (mockPipe) Ends() (io.ReadWriter, io.ReadWriter) { return nil, nil } +func (mockPipe) CopyToWebsocket(io.ReadWriter, xfer.Websocket) error { return nil } +func (mockPipe) Close() error { return nil } +func (mockPipe) Closed() bool { return false } +func (mockPipe) OnClose(func()) {} func TestPipes(t *testing.T) { oldNewPipe := controls.NewPipe