Skip to content

Commit

Permalink
alloc exec: fix panics after stream close
Browse files Browse the repository at this point in the history
In #19172 we added a check on websocket errors to see if they were one of
several benign "close" messages. This change inadvertently assumed that other
messages used for close would not implement `HTTPCodedError`. When errors like
the following are received:

> msgpack decode error [pos 0]: io: read/write on closed pipe

they are sent from the inner loop as though they were a "real" error, but the
channel is already being closed with a "close" message.

This allowed many more attempts to pass through a previously-undiscovered race
condition in the two goroutines that stream RPC responses to the websocket. When
the input stream returns an error for any reason (for example, the command we're
executing has exited), it will unblock the "outer" goroutine and cause a write
to the websocket. If we're concurrently writing the "close error" discussed
above, this results in a panic from the websocket library.

This changeset includes two fixes:
* Catch "closed pipe" error correctly so that we're not sending unnecessary
  error messages.
* Move all writes to the websocket into the same response streaming
  goroutine. The main handler goroutine will block on a results channel, and the
  response streaming goroutine will send on that channel with the final error when
  it's done so it can be reported to the user.
  • Loading branch information
tgross committed Feb 9, 2024
1 parent 81f8686 commit 498c18f
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 40 deletions.
3 changes: 3 additions & 0 deletions .changelog/19932.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
exec: Fixed a bug in `alloc exec` where closing websocket streams could cause a panic
```
104 changes: 65 additions & 39 deletions command/agent/alloc_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package agent
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -516,7 +517,7 @@ func (s *HTTPServer) allocExec(allocID string, resp http.ResponseWriter, req *ht
return nil, err
}

return s.execStreamImpl(conn, &args)
return s.execStream(conn, &args)
}

// readWsHandshake reads the websocket handshake message and sets
Expand Down Expand Up @@ -552,7 +553,9 @@ type wsHandshakeMessage struct {
AuthToken string `json:"auth_token"`
}

func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest) (interface{}, error) {
// execStream finds the appropriate RPC handler and then runs the bidirectional
// websocket-to-RPC stream
func (s *HTTPServer) execStream(ws *websocket.Conn, args *cstructs.AllocExecRequest) (any, error) {
allocID := args.AllocID
method := "Allocations.Exec"

Expand All @@ -572,6 +575,13 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
return nil, CodedError(500, handlerErr.Error())
}

return s.execStreamImpl(ws, args, handler)
}

// execStreamImpl is called by execStream with the appropriate RPC handler and
// then runs the bidirectional websocket-to-RPC stream.
func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest, handler structs.StreamingRpcHandler) (any, error) {

// Create a pipe connecting the (possibly remote) handler to the http response
httpPipe, handlerPipe := net.Pipe()
decoder := codec.NewDecoder(httpPipe, structs.MsgpackHandle)
Expand All @@ -586,33 +596,37 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
// don't close ws - wait to drain messages
}()

// Create a channel that decodes the results
errCh := make(chan HTTPCodedError, 2)
// Create a channel for the final result
resultCh := make(chan HTTPCodedError)

// stream response
// stream response back to the websocket: this should be the only goroutine
// that writes to this websocket connection
go func() {
defer cancel()
errCh := make(chan HTTPCodedError, 2)

// Send the request
if err := encoder.Encode(args); err != nil {
errCh <- CodedError(500, err.Error())
resultCh <- s.execStreamHandleError(ws, CodedError(500, err.Error()))
return
}

go forwardExecInput(encoder, ws, errCh)
// only start this after we've tried to send the initial args
go forwardExecInput(ctx, encoder, ws, errCh)

for {
var res cstructs.StreamErrWrapper
err := decoder.Decode(&res)
if isClosedError(err) {
ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
errCh <- nil
select {
case codedErr := <-errCh:
resultCh <- s.execStreamHandleError(ws, codedErr)
return
default:
}

var res cstructs.StreamErrWrapper
err := decoder.Decode(&res)
if err != nil {
errCh <- CodedError(500, err.Error())
return
continue
}
decoder.Reset(httpPipe)

Expand All @@ -622,39 +636,47 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
code = int(*err.Code)
}
errCh <- CodedError(code, err.Error())
return
continue
}

if err := ws.WriteMessage(websocket.TextMessage, res.Payload); err != nil {
errCh <- CodedError(500, err.Error())
return
continue
}
}
}()

// start streaming request to streaming RPC - returns when streaming completes or errors
// start streaming request to streaming RPC - returns when streaming
// completes or errors
handler(handlerPipe)
// stop streaming background goroutines for streaming - but not websocket activity

// stop streaming background goroutines for streaming - but not websocket
// activity
cancel()
// retrieve any error and/or wait until goroutine stop and close errCh connection before
// closing websocket connection
codedErr := <-errCh

// retrieve any error and/or wait until the response streaming goroutine
// stops and sends its final result on resultCh before closing the websocket
// connection
result := <-resultCh
ws.Close()
return nil, result
}

// execStreamHandleError writes a CloseMessage to the websocket if we get an
// error that isn't a "close error" caused by the RPC pipe finishing up. Note
// that this should *only* ever be called in the same goroutine as we're
// streaming the responses
func (s *HTTPServer) execStreamHandleError(ws *websocket.Conn, codedErr HTTPCodedError) HTTPCodedError {
// we won't return an error on ws close, but at least make it available in
// the logs so we can trace spurious disconnects
if codedErr != nil {
s.logger.Debug("alloc exec channel closed with error", "error", codedErr)
}
s.logger.Trace("alloc exec channel closed with error", "error", codedErr)

if isClosedError(codedErr) {
codedErr = nil
return nil // we're intentionally throwing this error away
} else if codedErr != nil {
ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(toWsCode(codedErr.Code()), codedErr.Error()))
return codedErr
}
ws.Close()

return nil, codedErr
return nil
}

func toWsCode(httpCode int) int {
Expand All @@ -667,30 +689,34 @@ func toWsCode(httpCode int) int {
}
}

// isClosedError checks if the websocket "error" is one of the benign "close" status codes
func isClosedError(err error) bool {
if err == nil {
return false
}

// check if the websocket "error" is one of the benign "close" status codes
if codedErr, ok := err.(HTTPCodedError); ok {
return slices.ContainsFunc([]string{
return errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrClosedPipe) ||
err == io.ErrClosedPipe ||
slices.ContainsFunc([]string{
"closed", // msgpack decode error [pos 0]: io: read/write on closed pipe"
"EOF",
"close 1000", // CLOSE_NORMAL
"close 1001", // CLOSE_GOING_AWAY
"close 1005", // CLOSED_NO_STATUS
}, func(s string) bool { return strings.Contains(codedErr.Error(), s) })
}

return err == io.EOF ||
err == io.ErrClosedPipe ||
strings.Contains(err.Error(), "closed") ||
strings.Contains(err.Error(), "EOF")
}, func(s string) bool { return strings.Contains(err.Error(), s) })
}

// forwardExecInput forwards exec input (e.g. stdin) from websocket connection
// to the streaming RPC connection to client
func forwardExecInput(encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) {
func forwardExecInput(ctx context.Context, encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) {
for {
select {
case <-ctx.Done():
return
default:
}

sf := &drivers.ExecTaskStreamingRequestMsg{}
err := ws.ReadJSON(sf)
if err == io.EOF {
Expand Down
112 changes: 112 additions & 0 deletions command/agent/alloc_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package agent

import (
"archive/tar"
"context"
"fmt"
"io"
"net/http"
Expand All @@ -14,16 +15,22 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/golang/snappy"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-msgpack/codec"
"github.com/hashicorp/nomad/acl"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocdir"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/shoenig/test"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -1123,3 +1130,108 @@ func TestHTTP_ReadWsHandshake(t *testing.T) {
})
}
}

// TestHTTP_AllocsExecStream_SafeClose verifies that we are safely closing the
// AllocExec stream when we're done without making concurrent writes to the
// websocket that can cause a panic.
//
// The test stands up a throwaway HTTP server whose only handler upgrades the
// request to a websocket and calls execStreamImpl with a mock streaming RPC
// handler. It then dials that server and asserts that exactly the mock's
// payloads are received, in order, before the stream closes.
func TestHTTP_AllocsExecStream_SafeClose(t *testing.T) {
	httpTest(t,
		func(c *Config) { c.Server.NumSchedulers = pointer.Of(0) },
		func(s *TestAgent) {

			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			t.Cleanup(cancel)

			rpcHandler := mockStreamingRpcHandler(t, [][]byte{
				[]byte("one"), []byte("two"), []byte("done!")})

			// This replaces the top-level HTTP handler, which is not under test
			// here. It will call execStreamImpl using the mock streaming RPC
			// handler defined above.
			//
			// The handler runs on one of the HTTP server's goroutines, not the
			// test goroutine, so it must only use the non-fatal test.* helpers:
			// must.* calls t.FailNow, which is only valid on the test goroutine.
			wsHandler := func(w http.ResponseWriter, r *http.Request) {
				var upgrader = websocket.Upgrader{}
				conn, err := upgrader.Upgrade(w, r, nil)
				if err != nil {
					test.NoError(t, err, test.Sprint("during ws upgrade"))
					return
				}
				defer conn.Close()

				args := cstructs.AllocExecRequest{
					AllocID: uuid.Generate(),
					Task:    "foo",
					Cmd:     []string{"bar"},
				}

				_, err = s.Server.execStreamImpl(conn, &args, rpcHandler)
				test.NoError(t, err)
			}

			// Spin up a HTTP server that only handles our websocket
			srv := httptest.NewServer(http.HandlerFunc(wsHandler))
			t.Cleanup(srv.Close)
			u := strings.Replace(srv.URL, "http://", "ws://", 1)
			conn, _, err := websocket.DefaultDialer.Dial(u, nil)
			must.NoError(t, err, must.Sprint("failed to dial"))
			defer conn.Close()

			// ReadMessage blocks and does not observe ctx, so bound the reads
			// with the context's deadline; otherwise a stream that never
			// closes would hang the test instead of failing it.
			if deadline, ok := ctx.Deadline(); ok {
				must.NoError(t, conn.SetReadDeadline(deadline))
			}

			// drainResp reads messages until the connection reports an error.
			// Benign close errors end the stream silently; any other error is
			// appended to the results so the equality check below fails with
			// the error text visible.
			drainResp := func() []string {
				resp := []string{}
				for {
					select {
					case <-ctx.Done():
						return resp
					default:
					}

					_, message, err := conn.ReadMessage()
					if err != nil {
						if !isClosedError(err) {
							resp = append(resp, err.Error())
						}
						return resp
					}
					resp = append(resp, string(message))
				}
			}

			must.Eq(t, []string{"one", "two", "done!"}, drainResp())
		})
}

// mockStreamingRpcHandler returns a function that can stand in for any
// structs.StreamingRpcHandler and streams the slice of payloads before
// closing. It marks a test failure if we get a non-close error.
//
// The returned handler writes each payload to conn wrapped in a
// cstructs.StreamErrWrapper and then closes conn. While it does so, a
// background goroutine decodes and discards whatever the other side sends.
func mockStreamingRpcHandler(t *testing.T, payloads [][]byte) func(io.ReadWriteCloser) {

	return func(conn io.ReadWriteCloser) {

		decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
		encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

		// cancelled when the handler returns, so the drain goroutine below
		// does not outlive it
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		// drain any incoming requests
		go func() {
			for {
				select {
				case <-ctx.Done():
					return
				default:
				}

				var res cstructs.StreamErrWrapper
				if err := decoder.Decode(&res); err != nil {
					// Only non-close errors are test failures. Either way,
					// stop draining: retrying on a closed or broken
					// connection would spin in a tight loop and report the
					// same error repeatedly.
					if !isClosedError(err) {
						test.NoError(t, err, test.Sprint("unexpected non-close error"))
					}
					return
				}
			}
		}()

		for _, payload := range payloads {
			err := encoder.Encode(cstructs.StreamErrWrapper{Payload: payload})
			test.NoError(t, err, test.Sprint("could not send RPC payload"))
		}
		test.NoError(t, conn.Close())
	}
}
2 changes: 1 addition & 1 deletion command/agent/job_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ func (s *HTTPServer) jobRunAction(resp http.ResponseWriter, req *http.Request, j
return nil, err
}

return s.execStreamImpl(conn, &args)
return s.execStream(conn, &args)
}

func (s *HTTPServer) jobSubmissionCRUD(resp http.ResponseWriter, req *http.Request, jobID string) (*structs.JobSubmission, error) {
Expand Down

0 comments on commit 498c18f

Please sign in to comment.