Skip to content

Commit

Permalink
Merge pull request #10657 from hashicorp/b-alloc-exec-closing
Browse files Browse the repository at this point in the history
Handle `nomad exec` termination events in order
  • Loading branch information
Mahmood Ali authored May 25, 2021
2 parents 99b8e31 + b294f63 commit e694000
Show file tree
Hide file tree
Showing 6 changed files with 497 additions and 397 deletions.
203 changes: 12 additions & 191 deletions api/allocations.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@ package api

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"strconv"
"sync"
"time"

"github.com/gorilla/websocket"
)

var (
Expand Down Expand Up @@ -87,195 +81,22 @@ func (a *Allocations) Exec(ctx context.Context,
stdin io.Reader, stdout, stderr io.Writer,
terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) {

ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()

errCh := make(chan error, 4)

sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q)

select {
case err := <-errCh:
return -2, err
default:
}

// Errors resulting from sending input (in goroutines) are silently dropped.
// To mitigate this, extra care is needed to distinguish between actual send errors
// and from send errors due to command terminating and our race to detect failures.
// If we have an actual network failure or send a bad input, we'd get an
// error in the reading side of websocket.

go func() {

bytes := make([]byte, 2048)
for {
if ctx.Err() != nil {
return
}

input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}

n, err := stdin.Read(bytes)

// always send data if we read some
if n != 0 {
input.Stdin.Data = bytes[:n]
sender(&input)
}

// then handle error
if err == io.EOF {
// if n != 0, send data and we'll get n = 0 on next read
if n == 0 {
input.Stdin.Close = true
sender(&input)
return
}
} else if err != nil {
errCh <- err
return
}
}
}()

// forwarding terminal size
go func() {
for {
resizeInput := ExecStreamingInput{}

select {
case <-ctx.Done():
return
case size, ok := <-terminalSizeCh:
if !ok {
return
}
resizeInput.TTYSize = &size
sender(&resizeInput)
}

}
}()

// send a heartbeat every 10 seconds
go func() {
for {
select {
case <-ctx.Done():
return
// heartbeat message
case <-time.After(10 * time.Second):
sender(&execStreamingInputHeartbeat)
}

}
}()

for {
select {
case err := <-errCh:
// drop websocket code, not relevant to user
if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" {
return -2, errors.New(wsErr.Text)
}
return -2, err
case <-ctx.Done():
return -2, ctx.Err()
case frame, ok := <-output:
if !ok {
return -2, errors.New("disconnected without receiving the exit code")
}

switch {
case frame.Stdout != nil:
if len(frame.Stdout.Data) != 0 {
stdout.Write(frame.Stdout.Data)
}
// don't really do anything if stdout is closing
case frame.Stderr != nil:
if len(frame.Stderr.Data) != 0 {
stderr.Write(frame.Stderr.Data)
}
// don't really do anything if stderr is closing
case frame.Exited && frame.Result != nil:
return frame.Result.ExitCode, nil
default:
// noop - heartbeat
}
}
}
}
s := &execSession{
client: a.client,
alloc: alloc,
task: task,
tty: tty,
command: command,

func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string,
errCh chan<- error, q *QueryOptions) (sendFn func(*ExecStreamingInput) error, output <-chan *ExecStreamingOutput) {
nodeClient, _ := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q)
stdin: stdin,
stdout: stdout,
stderr: stderr,

if q == nil {
q = &QueryOptions{}
}
if q.Params == nil {
q.Params = make(map[string]string)
terminalSizeCh: terminalSizeCh,
q: q,
}

commandBytes, err := json.Marshal(command)
if err != nil {
errCh <- fmt.Errorf("failed to marshal command: %s", err)
return nil, nil
}

q.Params["tty"] = strconv.FormatBool(tty)
q.Params["task"] = task
q.Params["command"] = string(commandBytes)

reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID)

var conn *websocket.Conn

if nodeClient != nil {
conn, _, _ = nodeClient.websocket(reqPath, q)
}

if conn == nil {
conn, _, err = a.client.websocket(reqPath, q)
if err != nil {
errCh <- err
return nil, nil
}
}

// Create the output channel
frames := make(chan *ExecStreamingOutput, 10)

go func() {
defer conn.Close()
for ctx.Err() == nil {

// Decode the next frame
var frame ExecStreamingOutput
err := conn.ReadJSON(&frame)
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
close(frames)
return
} else if err != nil {
errCh <- err
return
}

frames <- &frame
}
}()

var sendLock sync.Mutex
send := func(v *ExecStreamingInput) error {
sendLock.Lock()
defer sendLock.Unlock()

return conn.WriteJSON(v)
}

return send, frames

return s.run(ctx)
}

func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) {
Expand Down
Loading

0 comments on commit e694000

Please sign in to comment.