Skip to content

Commit

Permalink
backport of commit e986c29
Browse files Browse the repository at this point in the history
  • Loading branch information
tgross authored Feb 12, 2024
1 parent ddb389e commit 455b6b1
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 40 deletions.
3 changes: 3 additions & 0 deletions .changelog/19932.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
exec: Fixed a bug in `alloc exec` where closing websocket streams could cause a panic
```
104 changes: 65 additions & 39 deletions command/agent/alloc_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package agent
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -516,7 +517,7 @@ func (s *HTTPServer) allocExec(allocID string, resp http.ResponseWriter, req *ht
return nil, err
}

return s.execStreamImpl(conn, &args)
return s.execStream(conn, &args)
}

// readWsHandshake reads the websocket handshake message and sets
Expand Down Expand Up @@ -552,7 +553,9 @@ type wsHandshakeMessage struct {
AuthToken string `json:"auth_token"`
}

func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest) (interface{}, error) {
// execStream finds the appropriate RPC handler and then runs the bidirectional
// websocket-to-RPC stream
func (s *HTTPServer) execStream(ws *websocket.Conn, args *cstructs.AllocExecRequest) (any, error) {
allocID := args.AllocID
method := "Allocations.Exec"

Expand All @@ -572,6 +575,13 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
return nil, CodedError(500, handlerErr.Error())
}

return s.execStreamImpl(ws, args, handler)
}

// execStreamImpl is called by execStream with the appropriate RPC handler and
// then runs the bidirectional websocket-to-RPC stream.
func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest, handler structs.StreamingRpcHandler) (any, error) {

// Create a pipe connecting the (possibly remote) handler to the http response
httpPipe, handlerPipe := net.Pipe()
decoder := codec.NewDecoder(httpPipe, structs.MsgpackHandle)
Expand All @@ -586,33 +596,37 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
// don't close ws - wait to drain messages
}()

// Create a channel that decodes the results
errCh := make(chan HTTPCodedError, 2)
// Create a channel for the final result
resultCh := make(chan HTTPCodedError, 1)

// stream response
// stream response back to the websocket: this should be the only goroutine
// that writes to this websocket connection
go func() {
defer cancel()
errCh := make(chan HTTPCodedError, 2)

// Send the request
if err := encoder.Encode(args); err != nil {
errCh <- CodedError(500, err.Error())
resultCh <- s.execStreamHandleError(ws, CodedError(500, err.Error()))
return
}

go forwardExecInput(encoder, ws, errCh)
// only start this after we've tried to send the initial args
go forwardExecInput(ctx, encoder, ws, errCh)

for {
var res cstructs.StreamErrWrapper
err := decoder.Decode(&res)
if isClosedError(err) {
ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
errCh <- nil
select {
case codedErr := <-errCh:
resultCh <- s.execStreamHandleError(ws, codedErr)
return
default:
}

var res cstructs.StreamErrWrapper
err := decoder.Decode(&res)
if err != nil {
errCh <- CodedError(500, err.Error())
return
continue
}
decoder.Reset(httpPipe)

Expand All @@ -622,39 +636,47 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
code = int(*err.Code)
}
errCh <- CodedError(code, err.Error())
return
continue
}

if err := ws.WriteMessage(websocket.TextMessage, res.Payload); err != nil {
errCh <- CodedError(500, err.Error())
return
continue
}
}
}()

// start streaming request to streaming RPC - returns when streaming completes or errors
// start streaming request to streaming RPC - returns when streaming
// completes or errors
handler(handlerPipe)
// stop streaming background goroutines for streaming - but not websocket activity

// stop streaming background goroutines for streaming - but not websocket
// activity
cancel()
// retrieve any error and/or wait until goroutine stop and close errCh connection before
// closing websocket connection
codedErr := <-errCh

// retrieve any error and/or wait until goroutine stop and close errCh
// connection before closing websocket connection
result := <-resultCh
ws.Close()
return nil, result
}

// execStreamHandleError writes a CloseMessage to the websocket if we get an
// error that isn't a "close error" caused by the RPC pipe finishing up. Note
// that this should *only* ever be called in the same goroutine as we're
// streaming the responses
func (s *HTTPServer) execStreamHandleError(ws *websocket.Conn, codedErr HTTPCodedError) HTTPCodedError {
// we won't return an error on ws close, but at least make it available in
// the logs so we can trace spurious disconnects
if codedErr != nil {
s.logger.Debug("alloc exec channel closed with error", "error", codedErr)
}
s.logger.Trace("alloc exec channel closed with error", "error", codedErr)

if isClosedError(codedErr) {
codedErr = nil
return nil // we're intentionally throwing this error away
} else if codedErr != nil {
ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(toWsCode(codedErr.Code()), codedErr.Error()))
return codedErr
}
ws.Close()

return nil, codedErr
return nil
}

func toWsCode(httpCode int) int {
Expand All @@ -667,30 +689,34 @@ func toWsCode(httpCode int) int {
}
}

// isClosedError checks if the websocket "error" is one of the benign "close" status codes
func isClosedError(err error) bool {
if err == nil {
return false
}

// check if the websocket "error" is one of the benign "close" status codes
if codedErr, ok := err.(HTTPCodedError); ok {
return slices.ContainsFunc([]string{
return errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrClosedPipe) ||
err == io.ErrClosedPipe ||
slices.ContainsFunc([]string{
"closed", // msgpack decode error [pos 0]: io: read/write on closed pipe"
"EOF",
"close 1000", // CLOSE_NORMAL
"close 1001", // CLOSE_GOING_AWAY
"close 1005", // CLOSED_NO_STATUS
}, func(s string) bool { return strings.Contains(codedErr.Error(), s) })
}

return err == io.EOF ||
err == io.ErrClosedPipe ||
strings.Contains(err.Error(), "closed") ||
strings.Contains(err.Error(), "EOF")
}, func(s string) bool { return strings.Contains(err.Error(), s) })
}

// forwardExecInput forwards exec input (e.g. stdin) from websocket connection
// to the streaming RPC connection to client
func forwardExecInput(encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) {
func forwardExecInput(ctx context.Context, encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) {
for {
select {
case <-ctx.Done():
return
default:
}

sf := &drivers.ExecTaskStreamingRequestMsg{}
err := ws.ReadJSON(sf)
if err == io.EOF {
Expand Down
112 changes: 112 additions & 0 deletions command/agent/alloc_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package agent

import (
"archive/tar"
"context"
"fmt"
"io"
"net/http"
Expand All @@ -14,16 +15,22 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/golang/snappy"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-msgpack/codec"
"github.com/hashicorp/nomad/acl"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocdir"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/shoenig/test"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -1123,3 +1130,108 @@ func TestHTTP_ReadWsHandshake(t *testing.T) {
})
}
}

// TestHTTP_AllocsExecStream_SafeClose verifies that we are safely closing the
// AllocExec stream when we're done without making concurrent writes to the
// websocket that can cause a panic
func TestHTTP_AllocsExecStream_SafeClose(t *testing.T) {
httpTest(t,
func(c *Config) { c.Server.NumSchedulers = pointer.Of(0) },
func(s *TestAgent) {

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)

rpcHandler := mockStreamingRpcHandler(t, [][]byte{
[]byte("one"), []byte("two"), []byte("done!")})

// This replaces the top-level HTTP handler, which is not under test
// here. It will call execStreamImpl using the mock streaming RPC
// handler defined above.
wsHandler := func(w http.ResponseWriter, r *http.Request) {
var upgrader = websocket.Upgrader{}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
must.NoError(t, err, must.Sprint("during ws upgrade"))
return
}
defer conn.Close()

args := cstructs.AllocExecRequest{
AllocID: uuid.Generate(),
Task: "foo",
Cmd: []string{"bar"},
}

_, err = s.Server.execStreamImpl(conn, &args, rpcHandler)
must.NoError(t, err)
}

// Spin up a HTTP server that only handles our websocket
srv := httptest.NewServer(http.HandlerFunc(wsHandler))
t.Cleanup(srv.Close)
u := strings.Replace(srv.URL, "http://", "ws://", 1)
conn, _, err := websocket.DefaultDialer.Dial(u, nil)
must.NoError(t, err, must.Sprint("failed to dial"))
defer conn.Close()

drainResp := func() []string {
resp := []string{}
for {
select {
case <-ctx.Done():
return resp
default:
_, message, err := conn.ReadMessage()
if err != nil {
if !isClosedError(err) {
resp = append(resp, err.Error())
return resp
}
return resp
}
resp = append(resp, string(message))
}
}
}

must.Eq(t, []string{"one", "two", "done!"}, drainResp())
})
}

// mockStreamingRpcHandler returns a function that can stand in for any
// structs.StreamingRpcHandler and streams the slice of payloads before
// closing. It marks a test failure if we get a non-close error.
func mockStreamingRpcHandler(t *testing.T, payloads [][]byte) func(io.ReadWriteCloser) {

return func(conn io.ReadWriteCloser) {

decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// drain any incoming requests
go func() {
for {
select {
case <-ctx.Done():
return
default:
}
var res cstructs.StreamErrWrapper
err := decoder.Decode(&res)
if !isClosedError(err) {
test.NoError(t, err, test.Sprint("unexpected non-close error"))
}
}
}()

for _, payload := range payloads {
err := encoder.Encode(cstructs.StreamErrWrapper{Payload: payload})
test.NoError(t, err, test.Sprint("could not send RPC payload"))
}
test.NoError(t, conn.Close())
}
}
2 changes: 1 addition & 1 deletion command/agent/job_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ func (s *HTTPServer) jobRunAction(resp http.ResponseWriter, req *http.Request, j
return nil, err
}

return s.execStreamImpl(conn, &args)
return s.execStream(conn, &args)
}

func (s *HTTPServer) jobSubmissionCRUD(resp http.ResponseWriter, req *http.Request, jobID string) (*structs.JobSubmission, error) {
Expand Down

0 comments on commit 455b6b1

Please sign in to comment.