Skip to content

Commit

Permalink
feat: toolbox exec session management (#1732)
Browse files Browse the repository at this point in the history
Signed-off-by: Toma Puljak <[email protected]>
  • Loading branch information
Tpuljak authored Jan 20, 2025
1 parent 009c100 commit 2237af5
Show file tree
Hide file tree
Showing 32 changed files with 4,102 additions and 110 deletions.
35 changes: 2 additions & 33 deletions pkg/agent/ssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import (
"io"
"os"
"os/exec"
"strings"
"syscall"
"unsafe"

"github.com/creack/pty"
"github.com/daytonaio/daytona/pkg/agent/ssh/config"
"github.com/daytonaio/daytona/pkg/common"
"github.com/gliderlabs/ssh"
"github.com/pkg/sftp"
"golang.org/x/sys/unix"
Expand Down Expand Up @@ -81,7 +81,7 @@ func (s *Server) Start() error {
}

func (s *Server) handlePty(session ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
shell := s.getShell()
shell := common.GetShell()
cmd := exec.Command(shell)

cmd.Dir = s.ProjectDir
Expand Down Expand Up @@ -233,37 +233,6 @@ func (s *Server) osSignalFrom(sig ssh.Signal) os.Signal {
}
}

func (s *Server) getShell() string {
out, err := exec.Command("sh", "-c", "grep '^[^#]' /etc/shells").Output()
if err != nil {
return "sh"
}

if strings.Contains(string(out), "/usr/bin/zsh") {
return "/usr/bin/zsh"
}

if strings.Contains(string(out), "/bin/zsh") {
return "/bin/zsh"
}

if strings.Contains(string(out), "/usr/bin/bash") {
return "/usr/bin/bash"
}

if strings.Contains(string(out), "/bin/bash") {
return "/bin/bash"
}

shellEnv, shellSet := os.LookupEnv("SHELL")

if shellSet {
return shellEnv
}

return "sh"
}

func (s *Server) sftpHandler(session ssh.Session) {
debugStream := io.Discard
serverOptions := []sftp.ServerOption{
Expand Down
87 changes: 45 additions & 42 deletions pkg/agent/toolbox/process/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,63 +14,66 @@ import (
"github.com/gin-gonic/gin"
)

func ExecuteCommand(c *gin.Context) {
var request ExecuteRequest
if err := c.ShouldBindJSON(&request); err != nil {
c.AbortWithError(400, errors.New("command is required"))
return
}
func ExecuteCommand(projectDir string) func(c *gin.Context) {
return func(c *gin.Context) {
var request ExecuteRequest
if err := c.ShouldBindJSON(&request); err != nil {
c.AbortWithError(400, errors.New("command is required"))
return
}

cmdParts := parseCommand(request.Command)
if len(cmdParts) == 0 {
c.AbortWithError(400, errors.New("empty command"))
return
}
cmdParts := parseCommand(request.Command)
if len(cmdParts) == 0 {
c.AbortWithError(400, errors.New("empty command"))
return
}

cmd := exec.Command(cmdParts[0], cmdParts[1:]...)
cmd := exec.Command(cmdParts[0], cmdParts[1:]...)
cmd.Dir = projectDir

// set maximum execution time
timeout := 10 * time.Second
if request.Timeout != nil && *request.Timeout > 0 {
timeout = time.Duration(*request.Timeout) * time.Second
}
// set maximum execution time
timeout := 10 * time.Second
if request.Timeout != nil && *request.Timeout > 0 {
timeout = time.Duration(*request.Timeout) * time.Second
}

timeoutReached := false
timer := time.AfterFunc(timeout, func() {
timeoutReached = true
if cmd.Process != nil {
// kill the process group
err := cmd.Process.Kill()
if err != nil {
log.Error(err)
return
}
}
})
defer timer.Stop()

timeoutReached := false
timer := time.AfterFunc(timeout, func() {
timeoutReached = true
if cmd.Process != nil {
// kill the process group
err := cmd.Process.Kill()
if err != nil {
log.Error(err)
output, err := cmd.CombinedOutput()
if err != nil {
if timeoutReached {
c.AbortWithError(408, errors.New("command execution timeout"))
return
}
c.AbortWithError(400, err)
return
}
})
defer timer.Stop()

output, err := cmd.CombinedOutput()
if err != nil {
if timeoutReached {
c.AbortWithError(408, errors.New("command execution timeout"))
if cmd.ProcessState == nil {
c.JSON(200, ExecuteResponse{
Code: -1,
Result: string(output),
})
return
}
c.AbortWithError(400, err)
return
}

if cmd.ProcessState == nil {
c.JSON(200, ExecuteResponse{
Code: -1,
Code: cmd.ProcessState.ExitCode(),
Result: string(output),
})
return
}

c.JSON(200, ExecuteResponse{
Code: cmd.ProcessState.ExitCode(),
Result: string(output),
})
}

// parseCommand splits a command string properly handling quotes
Expand Down
153 changes: 153 additions & 0 deletions pkg/agent/toolbox/process/session/execute.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Copyright 2024 Daytona Platforms Inc.
// SPDX-License-Identifier: Apache-2.0

package session

import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"

"github.com/daytonaio/daytona/internal/util"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)

func SessionExecuteCommand(configDir string) func(c *gin.Context) {
return func(c *gin.Context) {
sessionId := c.Param("sessionId")

var request SessionExecuteRequest
if err := c.ShouldBindJSON(&request); err != nil {
c.AbortWithError(http.StatusBadRequest, err)
return
}

session, ok := sessions[sessionId]
if !ok {
c.AbortWithError(http.StatusNotFound, errors.New("session not found"))
return
}

var cmdId *string
var logFile *os.File

cmdId = util.Pointer(uuid.NewString())

command := &Command{
Id: *cmdId,
Command: request.Command,
}
session.commands[*cmdId] = command

logFilePath := command.LogFilePath(session.Dir(configDir))

err := os.MkdirAll(filepath.Dir(logFilePath), 0755)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}

logFile, err = os.Create(logFilePath)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}

cmdToExec := fmt.Sprintf("%s > %s 2>&1 ; echo \"DTN_EXIT: $?\" >> %s\n", request.Command, logFile.Name(), logFile.Name())

type execResult struct {
out string
err error
exitCode *int
}
resultChan := make(chan execResult)

go func() {
out := ""
defer close(resultChan)

logChan := make(chan []byte)
errChan := make(chan error)

go util.ReadLog(context.Background(), logFile, true, logChan, errChan)

defer logFile.Close()

for {
select {
case logEntry := <-logChan:
logEntry = bytes.Trim(logEntry, "\x00")
if len(logEntry) == 0 {
continue
}
exitCode, line := extractExitCode(string(logEntry))
out += line

if exitCode != nil {
sessions[sessionId].commands[*cmdId].ExitCode = exitCode
resultChan <- execResult{out: out, exitCode: exitCode, err: nil}
return
}
case err := <-errChan:
if err != nil {
resultChan <- execResult{out: out, exitCode: nil, err: err}
return
}
}
}
}()

_, err = session.stdinWriter.Write([]byte(cmdToExec))
if err != nil {
c.AbortWithError(http.StatusBadRequest, err)
return
}

if request.Async {
c.JSON(http.StatusAccepted, SessionExecuteResponse{
CommandId: cmdId,
})
return
}

result := <-resultChan
if result.err != nil {
c.AbortWithError(http.StatusBadRequest, result.err)
return
}

c.JSON(http.StatusOK, SessionExecuteResponse{
CommandId: cmdId,
Output: &result.out,
ExitCode: result.exitCode,
})
}
}

func extractExitCode(output string) (*int, string) {
var exitCode *int

regex := regexp.MustCompile(`DTN_EXIT: (\d+)\n`)
matches := regex.FindStringSubmatch(output)
if len(matches) > 1 {
code, err := strconv.Atoi(matches[1])
if err != nil {
return nil, output
}
exitCode = &code
}

if exitCode != nil {
output = strings.Replace(output, fmt.Sprintf("DTN_EXIT: %d\n", *exitCode), "", 1)
}

return exitCode, output
}
74 changes: 74 additions & 0 deletions pkg/agent/toolbox/process/session/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2024 Daytona Platforms Inc.
// SPDX-License-Identifier: Apache-2.0

package session

import (
"errors"
"net/http"
"os"

"github.com/daytonaio/daytona/internal/util"
"github.com/daytonaio/daytona/pkg/api/controllers/log"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)

func GetSessionCommandLogs(configDir string) func(c *gin.Context) {
return func(c *gin.Context) {
sessionId := c.Param("sessionId")
cmdId := c.Param("commandId")

session, ok := sessions[sessionId]
if !ok {
c.AbortWithError(http.StatusNotFound, errors.New("session not found"))
return
}

command, ok := sessions[sessionId].commands[cmdId]
if !ok {
c.AbortWithError(http.StatusNotFound, errors.New("command not found"))
return
}

path := command.LogFilePath(session.Dir(configDir))

if c.Request.Header.Get("Upgrade") == "websocket" {
logFile, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
c.AbortWithError(http.StatusNotFound, errors.New("log file not found"))
return
}
c.AbortWithError(http.StatusInternalServerError, err)
return
}
defer logFile.Close()
log.ReadLog(c, logFile, util.ReadLog, func(conn *websocket.Conn, messages chan []byte, errors chan error) {
for {
msg := <-messages
_, output := extractExitCode(string(msg))
err := conn.WriteMessage(websocket.TextMessage, []byte(output))
if err != nil {
errors <- err
break
}
}
})
return
}

content, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
c.AbortWithError(http.StatusNotFound, errors.New("log file not found"))
return
}
c.AbortWithError(http.StatusInternalServerError, err)
return
}

_, output := extractExitCode(string(content))
c.String(http.StatusOK, output)
}
}
Loading

0 comments on commit 2237af5

Please sign in to comment.