Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle nomad exec termination events in order #10657

Merged
merged 4 commits into from
May 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 12 additions & 191 deletions api/allocations.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@ package api

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"strconv"
"sync"
"time"

"github.com/gorilla/websocket"
)

var (
Expand Down Expand Up @@ -87,195 +81,22 @@ func (a *Allocations) Exec(ctx context.Context,
stdin io.Reader, stdout, stderr io.Writer,
terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) {

ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()

errCh := make(chan error, 4)

sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q)

select {
case err := <-errCh:
return -2, err
default:
}

// Errors resulting from sending input (in goroutines) are silently dropped.
// To mitigate this, extra care is needed to distinguish actual send errors
// from send errors due to the command terminating and our race to detect failures.
// If we have an actual network failure or send bad input, we'd get an
// error on the reading side of the websocket.

go func() {

bytes := make([]byte, 2048)
for {
if ctx.Err() != nil {
return
}

input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}

n, err := stdin.Read(bytes)

// always send data if we read some
if n != 0 {
input.Stdin.Data = bytes[:n]
sender(&input)
}

// then handle error
if err == io.EOF {
// if n != 0, send data and we'll get n = 0 on next read
if n == 0 {
input.Stdin.Close = true
sender(&input)
return
}
} else if err != nil {
errCh <- err
return
}
}
}()

// forwarding terminal size
go func() {
for {
resizeInput := ExecStreamingInput{}

select {
case <-ctx.Done():
return
case size, ok := <-terminalSizeCh:
if !ok {
return
}
resizeInput.TTYSize = &size
sender(&resizeInput)
}

}
}()

// send a heartbeat every 10 seconds
go func() {
for {
select {
case <-ctx.Done():
return
// heartbeat message
case <-time.After(10 * time.Second):
sender(&execStreamingInputHeartbeat)
}

}
}()

for {
select {
case err := <-errCh:
// drop websocket code, not relevant to user
if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" {
return -2, errors.New(wsErr.Text)
}
return -2, err
case <-ctx.Done():
return -2, ctx.Err()
case frame, ok := <-output:
if !ok {
return -2, errors.New("disconnected without receiving the exit code")
}

switch {
case frame.Stdout != nil:
if len(frame.Stdout.Data) != 0 {
stdout.Write(frame.Stdout.Data)
}
// don't really do anything if stdout is closing
case frame.Stderr != nil:
if len(frame.Stderr.Data) != 0 {
stderr.Write(frame.Stderr.Data)
}
// don't really do anything if stderr is closing
case frame.Exited && frame.Result != nil:
return frame.Result.ExitCode, nil
default:
// noop - heartbeat
}
}
}
}
s := &execSession{
client: a.client,
alloc: alloc,
task: task,
tty: tty,
command: command,

func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string,
errCh chan<- error, q *QueryOptions) (sendFn func(*ExecStreamingInput) error, output <-chan *ExecStreamingOutput) {
nodeClient, _ := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q)
stdin: stdin,
stdout: stdout,
stderr: stderr,

if q == nil {
q = &QueryOptions{}
}
if q.Params == nil {
q.Params = make(map[string]string)
terminalSizeCh: terminalSizeCh,
q: q,
}

commandBytes, err := json.Marshal(command)
if err != nil {
errCh <- fmt.Errorf("failed to marshal command: %s", err)
return nil, nil
}

q.Params["tty"] = strconv.FormatBool(tty)
q.Params["task"] = task
q.Params["command"] = string(commandBytes)

reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID)

var conn *websocket.Conn

if nodeClient != nil {
conn, _, _ = nodeClient.websocket(reqPath, q)
}

if conn == nil {
conn, _, err = a.client.websocket(reqPath, q)
if err != nil {
errCh <- err
return nil, nil
}
}

// Create the output channel
frames := make(chan *ExecStreamingOutput, 10)

go func() {
defer conn.Close()
for ctx.Err() == nil {

// Decode the next frame
var frame ExecStreamingOutput
err := conn.ReadJSON(&frame)
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
close(frames)
return
} else if err != nil {
errCh <- err
return
}

frames <- &frame
}
}()

var sendLock sync.Mutex
send := func(v *ExecStreamingInput) error {
sendLock.Lock()
defer sendLock.Unlock()

return conn.WriteJSON(v)
}

return send, frames

return s.run(ctx)
}

func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) {
Expand Down
Loading