Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport of alloc exec: fix panics after stream close into release/1.7.x #19951

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/19932.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
exec: Fixed a bug in `alloc exec` where closing websocket streams could cause a panic
```
104 changes: 65 additions & 39 deletions command/agent/alloc_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package agent
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -516,7 +517,7 @@ func (s *HTTPServer) allocExec(allocID string, resp http.ResponseWriter, req *ht
return nil, err
}

return s.execStreamImpl(conn, &args)
return s.execStream(conn, &args)
}

// readWsHandshake reads the websocket handshake message and sets
Expand Down Expand Up @@ -552,7 +553,9 @@ type wsHandshakeMessage struct {
AuthToken string `json:"auth_token"`
}

func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest) (interface{}, error) {
// execStream finds the appropriate RPC handler and then runs the bidirectional
// websocket-to-RPC stream
func (s *HTTPServer) execStream(ws *websocket.Conn, args *cstructs.AllocExecRequest) (any, error) {
allocID := args.AllocID
method := "Allocations.Exec"

Expand All @@ -572,6 +575,13 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
return nil, CodedError(500, handlerErr.Error())
}

return s.execStreamImpl(ws, args, handler)
}

// execStreamImpl is called by execStream with the appropriate RPC handler and
// then runs the bidirectional websocket-to-RPC stream.
func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest, handler structs.StreamingRpcHandler) (any, error) {

// Create a pipe connecting the (possibly remote) handler to the http response
httpPipe, handlerPipe := net.Pipe()
decoder := codec.NewDecoder(httpPipe, structs.MsgpackHandle)
Expand All @@ -586,33 +596,37 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
// don't close ws - wait to drain messages
}()

// Create a channel that decodes the results
errCh := make(chan HTTPCodedError, 2)
// Create a channel for the final result
resultCh := make(chan HTTPCodedError, 1)

// stream response
// stream response back to the websocket: this should be the only goroutine
// that writes to this websocket connection
go func() {
defer cancel()
errCh := make(chan HTTPCodedError, 2)

// Send the request
if err := encoder.Encode(args); err != nil {
errCh <- CodedError(500, err.Error())
resultCh <- s.execStreamHandleError(ws, CodedError(500, err.Error()))
return
}

go forwardExecInput(encoder, ws, errCh)
// only start this after we've tried to send the initial args
go forwardExecInput(ctx, encoder, ws, errCh)

for {
var res cstructs.StreamErrWrapper
err := decoder.Decode(&res)
if isClosedError(err) {
ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
errCh <- nil
select {
case codedErr := <-errCh:
resultCh <- s.execStreamHandleError(ws, codedErr)
return
default:
}

var res cstructs.StreamErrWrapper
err := decoder.Decode(&res)
if err != nil {
errCh <- CodedError(500, err.Error())
return
continue
}
decoder.Reset(httpPipe)

Expand All @@ -622,39 +636,47 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
code = int(*err.Code)
}
errCh <- CodedError(code, err.Error())
return
continue
}

if err := ws.WriteMessage(websocket.TextMessage, res.Payload); err != nil {
errCh <- CodedError(500, err.Error())
return
continue
}
}
}()

// start streaming request to streaming RPC - returns when streaming completes or errors
// start streaming request to streaming RPC - returns when streaming
// completes or errors
handler(handlerPipe)
// stop streaming background goroutines for streaming - but not websocket activity

// stop streaming background goroutines for streaming - but not websocket
// activity
cancel()
// retrieve any error and/or wait until goroutine stop and close errCh connection before
// closing websocket connection
codedErr := <-errCh

// wait for the streaming goroutine to stop and hand back its final result
// (nil or an error) before closing the websocket connection
result := <-resultCh
ws.Close()
return nil, result
}

// execStreamHandleError writes a CloseMessage to the websocket if we get an
// error that isn't a "close error" caused by the RPC pipe finishing up. Note
// that this should *only* ever be called in the same goroutine as we're
// streaming the responses
func (s *HTTPServer) execStreamHandleError(ws *websocket.Conn, codedErr HTTPCodedError) HTTPCodedError {
// we won't return an error on ws close, but at least make it available in
// the logs so we can trace spurious disconnects
if codedErr != nil {
s.logger.Debug("alloc exec channel closed with error", "error", codedErr)
}
s.logger.Trace("alloc exec channel closed with error", "error", codedErr)

if isClosedError(codedErr) {
codedErr = nil
return nil // we're intentionally throwing this error away
} else if codedErr != nil {
ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(toWsCode(codedErr.Code()), codedErr.Error()))
return codedErr
}
ws.Close()

return nil, codedErr
return nil
}

func toWsCode(httpCode int) int {
Expand All @@ -667,30 +689,34 @@ func toWsCode(httpCode int) int {
}
}

// isClosedError checks if the websocket "error" is one of the benign "close" status codes
func isClosedError(err error) bool {
if err == nil {
return false
}

// check if the websocket "error" is one of the benign "close" status codes
if codedErr, ok := err.(HTTPCodedError); ok {
return slices.ContainsFunc([]string{
return errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrClosedPipe) ||
err == io.ErrClosedPipe ||
slices.ContainsFunc([]string{
"closed", // msgpack decode error [pos 0]: io: read/write on closed pipe"
"EOF",
"close 1000", // CLOSE_NORMAL
"close 1001", // CLOSE_GOING_AWAY
"close 1005", // CLOSED_NO_STATUS
}, func(s string) bool { return strings.Contains(codedErr.Error(), s) })
}

return err == io.EOF ||
err == io.ErrClosedPipe ||
strings.Contains(err.Error(), "closed") ||
strings.Contains(err.Error(), "EOF")
}, func(s string) bool { return strings.Contains(err.Error(), s) })
}

// forwardExecInput forwards exec input (e.g. stdin) from websocket connection
// to the streaming RPC connection to client
func forwardExecInput(encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) {
func forwardExecInput(ctx context.Context, encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) {
for {
select {
case <-ctx.Done():
return
default:
}

sf := &drivers.ExecTaskStreamingRequestMsg{}
err := ws.ReadJSON(sf)
if err == io.EOF {
Expand Down
112 changes: 112 additions & 0 deletions command/agent/alloc_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package agent

import (
"archive/tar"
"context"
"fmt"
"io"
"net/http"
Expand All @@ -14,16 +15,22 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/golang/snappy"
"github.com/gorilla/websocket"
"github.com/hashicorp/go-msgpack/codec"
"github.com/hashicorp/nomad/acl"
"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocdir"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/testutil"
"github.com/shoenig/test"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -1123,3 +1130,108 @@ func TestHTTP_ReadWsHandshake(t *testing.T) {
})
}
}

// TestHTTP_AllocsExecStream_SafeClose verifies that we are safely closing the
// AllocExec stream when we're done without making concurrent writes to the
// websocket that can cause a panic.
//
// It serves execStreamImpl behind a throwaway websocket endpoint, backed by a
// mock streaming RPC handler that emits a fixed list of payloads, and asserts
// that a websocket client reads back exactly those payloads before the
// connection is closed.
func TestHTTP_AllocsExecStream_SafeClose(t *testing.T) {
	httpTest(t,
		func(c *Config) { c.Server.NumSchedulers = pointer.Of(0) },
		func(s *TestAgent) {

			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			t.Cleanup(cancel)

			rpcHandler := mockStreamingRpcHandler(t, [][]byte{
				[]byte("one"), []byte("two"), []byte("done!")})

			// This replaces the top-level HTTP handler, which is not under test
			// here. It will call execStreamImpl using the mock streaming RPC
			// handler defined above.
			//
			// The handler runs on the HTTP server's goroutine, not the test
			// goroutine, so it only uses the non-fatal "test" assertions;
			// fatal assertions (FailNow) are only valid on the test goroutine.
			wsHandler := func(w http.ResponseWriter, r *http.Request) {
				var upgrader = websocket.Upgrader{}
				conn, err := upgrader.Upgrade(w, r, nil)
				if err != nil {
					test.NoError(t, err, test.Sprint("during ws upgrade"))
					return
				}
				defer conn.Close()

				args := cstructs.AllocExecRequest{
					AllocID: uuid.Generate(),
					Task:    "foo",
					Cmd:     []string{"bar"},
				}

				_, err = s.Server.execStreamImpl(conn, &args, rpcHandler)
				test.NoError(t, err)
			}

			// Spin up a HTTP server that only handles our websocket
			srv := httptest.NewServer(http.HandlerFunc(wsHandler))
			t.Cleanup(srv.Close)
			u := strings.Replace(srv.URL, "http://", "ws://", 1)
			conn, _, err := websocket.DefaultDialer.Dial(u, nil)
			must.NoError(t, err, must.Sprint("failed to dial"))
			defer conn.Close()

			// Bound the blocking reads below by the test deadline: polling
			// ctx between reads cannot interrupt a ReadMessage call that is
			// already waiting on the network.
			if deadline, ok := ctx.Deadline(); ok {
				must.NoError(t, conn.SetReadDeadline(deadline))
			}

			// drainResp collects every message the server sends until the
			// connection is closed. A benign close ends the read loop
			// silently; any other read error is appended to the result so
			// that the equality check below fails with a useful diff.
			drainResp := func() []string {
				resp := []string{}
				for {
					select {
					case <-ctx.Done():
						return resp
					default:
					}

					_, message, err := conn.ReadMessage()
					if err != nil {
						if !isClosedError(err) {
							resp = append(resp, err.Error())
						}
						return resp
					}
					resp = append(resp, string(message))
				}
			}

			must.Eq(t, []string{"one", "two", "done!"}, drainResp())
		})
}

// mockStreamingRpcHandler returns a function that can stand in for any
// structs.StreamingRpcHandler and streams the slice of payloads before
// closing. It marks a test failure if we get a non-close error.
//
// Each payload is sent as its own cstructs.StreamErrWrapper, in slice order.
// Anything the peer sends is decoded and discarded in a background goroutine.
func mockStreamingRpcHandler(t *testing.T, payloads [][]byte) func(io.ReadWriteCloser) {

	return func(conn io.ReadWriteCloser) {

		decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
		encoder := codec.NewEncoder(conn, structs.MsgpackHandle)

		// cancelling on return stops the drain goroutine the next time it
		// checks the context
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		// drain any incoming requests
		go func() {
			for {
				select {
				case <-ctx.Done():
					return
				default:
				}
				var res cstructs.StreamErrWrapper
				if err := decoder.Decode(&res); err != nil {
					if !isClosedError(err) {
						test.NoError(t, err, test.Sprint("unexpected non-close error"))
					}
					// Stop after the first decode error instead of looping:
					// once the connection is closed every further Decode
					// fails immediately, which would spin this goroutine
					// (and repeat the failure report) until cancel runs.
					return
				}
			}
		}()

		for _, payload := range payloads {
			err := encoder.Encode(cstructs.StreamErrWrapper{Payload: payload})
			test.NoError(t, err, test.Sprint("could not send RPC payload"))
		}
		test.NoError(t, conn.Close())
	}
}
2 changes: 1 addition & 1 deletion command/agent/job_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ func (s *HTTPServer) jobRunAction(resp http.ResponseWriter, req *http.Request, j
return nil, err
}

return s.execStreamImpl(conn, &args)
return s.execStream(conn, &args)
}

func (s *HTTPServer) jobSubmissionCRUD(resp http.ResponseWriter, req *http.Request, jobID string) (*structs.JobSubmission, error) {
Expand Down
Loading