From 7ec5e14ab1f37ddcc34d1b08515d3d0300c14684 Mon Sep 17 00:00:00 2001 From: Adam Babik Date: Mon, 2 Dec 2024 19:03:15 +0100 Subject: [PATCH 1/5] Research performance of remote exec in runnerv2 --- internal/runnerv2service/convert.go | 30 + internal/runnerv2service/execution.go | 130 +-- internal/runnerv2service/service.go | 1 + internal/runnerv2service/service_execute.go | 16 +- .../runnerv2service/service_execute_test.go | 760 +++++++++--------- internal/runnerv2service/service_sessions.go | 32 +- 6 files changed, 473 insertions(+), 496 deletions(-) create mode 100644 internal/runnerv2service/convert.go diff --git a/internal/runnerv2service/convert.go b/internal/runnerv2service/convert.go new file mode 100644 index 000000000..61cf39142 --- /dev/null +++ b/internal/runnerv2service/convert.go @@ -0,0 +1,30 @@ +package runnerv2service + +import ( + "github.com/stateful/runme/v3/internal/session" + runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" + "github.com/stateful/runme/v3/pkg/project" +) + +func convertSessionToProtoSession(sess *session.Session) *runnerv2.Session { + return &runnerv2.Session{ + Id: sess.ID, + Env: sess.GetAllEnv(), + // Metadata: sess.Metadata, + } +} + +// TODO(adamb): this function should not return nil project and nil error at the same time. +func convertProtoProjectToProject(runnerProj *runnerv2.Project) (*project.Project, error) { + if runnerProj == nil { + return nil, nil + } + + opts := project.DefaultProjectOptions[:] + + if runnerProj.EnvLoadOrder != nil { + opts = append(opts, project.WithEnvFilesReadOrder(runnerProj.EnvLoadOrder)) + } + + return project.NewDirProject(runnerProj.Root, opts...) +} diff --git a/internal/runnerv2service/execution.go b/internal/runnerv2service/execution.go index 38123b10e..663845ace 100644 --- a/internal/runnerv2service/execution.go +++ b/internal/runnerv2service/execution.go @@ -34,66 +34,8 @@ const ( var opininatedEnvVarNamingRegexp = regexp.MustCompile(`^[A-Z_][A-Z0-9_]{1}[A-Z0-9_]*[A-Z][A-Z0-9_]*$`) -type buffer struct { - mu *sync.Mutex - // +checklocks:mu - b *bytes.Buffer - closed *atomic.Bool - close chan struct{} - more chan struct{} -} - -var _ io.WriteCloser = (*buffer)(nil) - -func newBuffer() *buffer { - return &buffer{ - mu: &sync.Mutex{}, - b: bytes.NewBuffer(make([]byte, 0, msgBufferSize)), - closed: &atomic.Bool{}, - close: make(chan struct{}), - more: make(chan struct{}), - } -} - -func (b *buffer) Write(p []byte) (int, error) { - if b.closed.Load() { - return 0, errors.New("closed") - } - - b.mu.Lock() - n, err := b.b.Write(p) - b.mu.Unlock() - - select { - case b.more <- struct{}{}: - default: - } - - return n, err -} - -func (b *buffer) Close() error { - if b.closed.CompareAndSwap(false, true) { - close(b.close) - } - return nil -} - -func (b *buffer) Read(p []byte) (int, error) { - b.mu.Lock() - n, err := b.b.Read(p) - b.mu.Unlock() - - if err != nil && errors.Is(err, io.EOF) && !b.closed.Load() { - select { - case <-b.more: - case <-b.close: - return n, io.EOF - } - return n, nil - } - - return n, err +func matchesOpinionatedEnvVarNaming(knownName string) bool { + return opininatedEnvVarNamingRegexp.MatchString(knownName) } type execution struct { @@ -153,7 +95,7 @@ func newExecution( return exec, nil } -func (e *execution) Wait(ctx context.Context, sender sender) (int, error) { +func (e *execution) Wait(ctx context.Context, sender runnerv2.RunnerService_ExecuteServer) (int, error) { lastStdout := io.Discard if e.storeStdoutInEnv { b := rbuffer.NewRingBuffer(session.MaxEnvSizeInBytes - len(command.StoreStdoutEnvName) - 1) @@ -337,10 +279,6 @@ func (e *execution) storeOutputInEnv(ctx context.Context, r io.Reader) { } } -func matchesOpinionatedEnvVarNaming(knownName string) bool { - return opininatedEnvVarNamingRegexp.MatchString(knownName) -} - type sender interface { Send(*runnerv2.ExecuteResponse) error } @@ -400,3 +338,65 @@ func exitCodeFromErr(err error) int { } return -1 } + +type buffer struct { + mu *sync.Mutex + // +checklocks:mu + b *bytes.Buffer + closed *atomic.Bool + close chan struct{} + more chan struct{} +} + +var _ io.WriteCloser = (*buffer)(nil) + +func newBuffer() *buffer { + return &buffer{ + mu: &sync.Mutex{}, + b: bytes.NewBuffer(make([]byte, 0, msgBufferSize)), + closed: &atomic.Bool{}, + close: make(chan struct{}), + more: make(chan struct{}), + } +} + +func (b *buffer) Write(p []byte) (int, error) { + if b.closed.Load() { + return 0, errors.New("closed") + } + + b.mu.Lock() + n, err := b.b.Write(p) + b.mu.Unlock() + + select { + case b.more <- struct{}{}: + default: + } + + return n, err +} + +func (b *buffer) Close() error { + if b.closed.CompareAndSwap(false, true) { + close(b.close) + } + return nil +} + +func (b *buffer) Read(p []byte) (int, error) { + b.mu.Lock() + n, err := b.b.Read(p) + b.mu.Unlock() + + if err != nil && errors.Is(err, io.EOF) && !b.closed.Load() { + select { + case <-b.more: + case <-b.close: + return n, io.EOF + } + return n, nil + } + + return n, err +} diff --git a/internal/runnerv2service/service.go b/internal/runnerv2service/service.go index e36b8a4eb..8e7824b92 100644 --- a/internal/runnerv2service/service.go +++ b/internal/runnerv2service/service.go @@ -68,6 +68,7 @@ func (r *runnerService) getOrCreateSessionFromRequest(req requestWithSession, pr found bool ) + // TODO(adamb): this should come from the runme.yaml in the future. seedEnv := os.Environ() switch req.GetSessionStrategy() { diff --git a/internal/runnerv2service/service_execute.go b/internal/runnerv2service/service_execute.go index 8f4a17d30..24d44e83c 100644 --- a/internal/runnerv2service/service_execute.go +++ b/internal/runnerv2service/service_execute.go @@ -19,13 +19,11 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error _id := ulid.GenerateID() logger := r.logger.With(zap.String("id", _id)) - logger.Info("running Execute in runnerService") - // Get the initial request. req, err := srv.Recv() if err != nil { if errors.Is(err, io.EOF) { - logger.Info("client closed the connection while getting initial request") + logger.Info("client closed the connection while getting initial request; exiting") return nil } logger.Info("failed to receive a request", zap.Error(err)) @@ -142,18 +140,10 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error } func getExecutionInfoFromExecutionRequest(runID string, req *runnerv2.ExecuteRequest) *rcontext.ExecutionInfo { - knownName, knownID := "", "" - - reqConfig := req.GetConfig() - if reqConfig != nil { - knownName = reqConfig.GetKnownName() - knownID = reqConfig.GetKnownId() - } - return &rcontext.ExecutionInfo{ ExecContext: "Execute", - KnownID: knownID, - KnownName: knownName, + KnownID: req.GetConfig().GetKnownId(), + KnownName: req.GetConfig().GetKnownName(), RunID: runID, } } diff --git a/internal/runnerv2service/service_execute_test.go b/internal/runnerv2service/service_execute_test.go index 28739b611..e403f7cd1 100644 --- a/internal/runnerv2service/service_execute_test.go +++ b/internal/runnerv2service/service_execute_test.go @@ -70,30 +70,30 @@ func TestRunnerServiceServerExecute_Response(t *testing.T) { stream, err := client.Execute(context.Background()) require.NoError(t, err) - req := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "echo test | tee >(cat >&2)", + err = stream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "echo test | tee >(cat >&2)", + }, }, }, + Mode: runnerv2.CommandMode_COMMAND_MODE_INLINE, }, - Mode: runnerv2.CommandMode_COMMAND_MODE_INLINE, }, - } - - err = stream.Send(req) + ) require.NoError(t, err) - // Assert first response. + // Assert first response which contains PID. resp, err := stream.Recv() assert.NoError(t, err) assert.Greater(t, resp.Pid.Value, uint32(1)) assert.Nil(t, resp.ExitCode) - // Assert second and third responses. + // Collect second and third responses. var ( out bytes.Buffer mimeType string @@ -122,7 +122,7 @@ func TestRunnerServiceServerExecute_Response(t *testing.T) { if resp.MimeType != "" { mimeType = resp.MimeType } - + // Assert the second and third responses. assert.Contains(t, mimeType, "text/plain") assert.Equal(t, "test\ntest\n", out.String()) @@ -133,7 +133,7 @@ func TestRunnerServiceServerExecute_Response(t *testing.T) { assert.Nil(t, resp.Pid) } -func TestRunnerServiceServerExecute_MimeType(t *testing.T) { +func TestRunnerServiceServerExecute_StoreLastStdout(t *testing.T) { t.Parallel() lis, stop := startRunnerServiceServer(t) @@ -141,113 +141,112 @@ func TestRunnerServiceServerExecute_MimeType(t *testing.T) { _, client := testutils.NewGRPCClientWithT(t, lis, runnerv2.NewRunnerServiceClient) - stream, err := client.Execute(context.Background()) + sessionResp, err := client.CreateSession(context.Background(), &runnerv2.CreateSessionRequest{}) + require.NoError(t, err) + require.NotNil(t, sessionResp.Session) + + stream1, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) - - req := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - // Echo JSON to stderr and plain text to stdout. - // Only the plain text should be detected. - ">&2 echo '{\"field1\": \"value\", \"field2\": 2}'", - "echo 'some plain text'", + result1C := make(chan executeResult) + go getExecuteResult(stream1, result1C) + + err = stream1.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "echo test | tee >(cat >&2)", + }, }, }, }, + SessionId: sessionResp.GetSession().GetId(), + StoreStdoutInEnv: true, }, - } - - err = stream.Send(req) + ) assert.NoError(t, err) - result := <-execResult - - assert.NoError(t, result.Err) - assert.EqualValues(t, 0, result.ExitCode) - assert.Equal(t, "{\"field1\": \"value\", \"field2\": 2}\n", string(result.Stderr)) - assert.Equal(t, "some plain text\n", string(result.Stdout)) - assert.Contains(t, result.MimeType, "text/plain") -} - -func TestRunnerServiceServerExecute_StoreLastStdout(t *testing.T) { - t.Parallel() - - lis, stop := startRunnerServiceServer(t) - t.Cleanup(stop) - - _, client := testutils.NewGRPCClientWithT(t, lis, runnerv2.NewRunnerServiceClient) - - sessionResp, err := client.CreateSession(context.Background(), &runnerv2.CreateSessionRequest{}) - require.NoError(t, err) - require.NotNil(t, sessionResp.Session) + result1 := <-result1C + assert.NoError(t, result1.Err) + assert.EqualValues(t, 0, result1.ExitCode) + assert.Equal(t, "test\n", string(result1.Stdout)) + assert.Contains(t, result1.MimeType, "text/plain") - stream1, err := client.Execute(context.Background()) + // subsequent req to check last stored value + stream2, err := client.Execute(context.Background()) require.NoError(t, err) - execResult1 := make(chan executeResult) - go getExecuteResult(stream1, execResult1) + result2C := make(chan executeResult) + go getExecuteResult(stream2, result2C) - req1 := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "echo test | tee >(cat >&2)", + err = stream2.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "echo $__", + }, }, }, }, + SessionId: sessionResp.GetSession().GetId(), }, - SessionId: sessionResp.GetSession().GetId(), - StoreStdoutInEnv: true, - } - - err = stream1.Send(req1) + ) assert.NoError(t, err) - result := <-execResult1 + result2 := <-result2C + assert.NoError(t, result2.Err) + assert.EqualValues(t, 0, result2.ExitCode) + assert.Equal(t, "test\n", string(result2.Stdout)) + assert.Contains(t, result2.MimeType, "text/plain") +} - assert.NoError(t, result.Err) - assert.EqualValues(t, 0, result.ExitCode) - assert.Equal(t, "test\n", string(result.Stdout)) - assert.Contains(t, result.MimeType, "text/plain") +func TestRunnerServiceServerExecute_LargeOutput(t *testing.T) { + t.Parallel() - // subsequent req to check last stored value - stream2, err := client.Execute(context.Background()) + temp := t.TempDir() + fileName := filepath.Join(temp, "large_output.json") + _, err := testdata.UngzipToFile(testdata.Users1MGzip, fileName) + require.NoError(t, err) + + lis, stop := startRunnerServiceServer(t) + t.Cleanup(stop) + + _, client := testutils.NewGRPCClientWithT(t, lis, runnerv2.NewRunnerServiceClient) + + stream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult2 := make(chan executeResult) - go getExecuteResult(stream2, execResult2) + resultC := make(chan executeResult) + go getExecuteResult(stream, resultC) - req2 := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "echo $__", + err = stream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "cat " + fileName, + }, }, }, }, }, - SessionId: sessionResp.GetSession().GetId(), - } - - err = stream2.Send(req2) + ) assert.NoError(t, err) - result = <-execResult2 - + result := <-resultC assert.NoError(t, result.Err) assert.EqualValues(t, 0, result.ExitCode) - assert.Equal(t, "test\n", string(result.Stdout)) - assert.Contains(t, result.MimeType, "text/plain") + fileSize, err := os.Stat(fileName) + assert.NoError(t, err) + assert.EqualValues(t, fileSize.Size(), len(result.Stdout)) } func TestRunnerServiceServerExecute_LastStdoutExceedsEnvLimit(t *testing.T) { @@ -270,28 +269,30 @@ func TestRunnerServiceServerExecute_LastStdoutExceedsEnvLimit(t *testing.T) { stream1, err := client.Execute(context.Background()) require.NoError(t, err) - execResult1 := make(chan executeResult) - go getExecuteResult(stream1, execResult1) + result1C := make(chan executeResult) + go getExecuteResult(stream1, result1C) - req1 := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "cat " + fileName, + err = stream1.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "cat " + fileName, + }, }, }, }, + SessionId: sessionResp.GetSession().GetId(), + StoreStdoutInEnv: true, }, - SessionId: sessionResp.GetSession().GetId(), - StoreStdoutInEnv: true, - } - - err = stream1.Send(req1) + ) + assert.NoError(t, err) + err = stream1.CloseSend() assert.NoError(t, err) - result1 := <-execResult1 + result1 := <-result1C assert.NoError(t, result1.Err) assert.EqualValues(t, 0, result1.ExitCode) @@ -299,30 +300,29 @@ func TestRunnerServiceServerExecute_LastStdoutExceedsEnvLimit(t *testing.T) { stream2, err := client.Execute(context.Background()) require.NoError(t, err) - execResult2 := make(chan executeResult) - go getExecuteResult(stream2, execResult2) + result2C := make(chan executeResult) + go getExecuteResult(stream2, result2C) - req2 := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "echo -n $" + command.StoreStdoutEnvName, + err = stream2.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "echo -n $" + command.StoreStdoutEnvName, + }, }, }, }, + SessionId: sessionResp.GetSession().GetId(), }, - SessionId: sessionResp.GetSession().GetId(), - } - - err = stream2.Send(req2) + ) assert.NoError(t, err) - result2 := <-execResult2 + result2 := <-result2C assert.NoError(t, result2.Err) assert.EqualValues(t, 0, result2.ExitCode) - expected, err := os.ReadFile(fileName) require.NoError(t, err) got := result2.Stdout // stdout is trimmed and should be the suffix of the complete output @@ -330,48 +330,6 @@ func TestRunnerServiceServerExecute_LastStdoutExceedsEnvLimit(t *testing.T) { assert.True(t, bytes.HasSuffix(expected, got)) } -func TestRunnerServiceServerExecute_LargeOutput(t *testing.T) { - t.Parallel() - - temp := t.TempDir() - fileName := filepath.Join(temp, "large_output.json") - _, err := testdata.UngzipToFile(testdata.Users1MGzip, fileName) - require.NoError(t, err) - - lis, stop := startRunnerServiceServer(t) - t.Cleanup(stop) - - _, client := testutils.NewGRPCClientWithT(t, lis, runnerv2.NewRunnerServiceClient) - - stream, err := client.Execute(context.Background()) - require.NoError(t, err) - - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) - - req := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "cat " + fileName, - }, - }, - }, - }, - } - err = stream.Send(req) - assert.NoError(t, err) - - result := <-execResult - assert.NoError(t, result.Err) - assert.EqualValues(t, 0, result.ExitCode) - fileSize, err := os.Stat(fileName) - assert.NoError(t, err) - assert.EqualValues(t, fileSize.Size(), len(result.Stdout)) -} - func TestRunnerServiceServerExecute_StoreKnownName(t *testing.T) { t.Parallel() @@ -387,30 +345,29 @@ func TestRunnerServiceServerExecute_StoreKnownName(t *testing.T) { stream1, err := client.Execute(context.Background()) require.NoError(t, err) - execResult1 := make(chan executeResult) - go getExecuteResult(stream1, execResult1) + result1C := make(chan executeResult) + go getExecuteResult(stream1, result1C) - req1 := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "echo test | tee >(cat >&2)", + err = stream1.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "echo test | tee >(cat >&2)", + }, }, }, + KnownName: "TEST_VAR", }, - KnownName: "TEST_VAR", + SessionId: sessionResp.GetSession().GetId(), + StoreStdoutInEnv: true, }, - SessionId: sessionResp.GetSession().GetId(), - StoreStdoutInEnv: true, - } - - err = stream1.Send(req1) + ) assert.NoError(t, err) - result := <-execResult1 - + result := <-result1C assert.NoError(t, result.Err) assert.EqualValues(t, 0, result.ExitCode) assert.Equal(t, "test\n", string(result.Stdout)) @@ -420,28 +377,27 @@ func TestRunnerServiceServerExecute_StoreKnownName(t *testing.T) { stream2, err := client.Execute(context.Background()) require.NoError(t, err) - execResult2 := make(chan executeResult) - go getExecuteResult(stream2, execResult2) + result2C := make(chan executeResult) + go getExecuteResult(stream2, result2C) - req2 := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "echo $TEST_VAR", + err = stream2.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "echo $TEST_VAR", + }, }, }, }, + SessionId: sessionResp.GetSession().GetId(), }, - SessionId: sessionResp.GetSession().GetId(), - } - - err = stream2.Send(req2) + ) assert.NoError(t, err) - result = <-execResult2 - + result = <-result2C assert.NoError(t, result.Err) assert.EqualValues(t, 0, result.ExitCode) assert.Equal(t, "test\n", string(result.Stdout)) @@ -585,13 +541,12 @@ func TestRunnerServiceServerExecute_Configs(t *testing.T) { stream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) + resultC := make(chan executeResult) + go getExecuteResult(stream, resultC) req := &runnerv2.ExecuteRequest{ Config: tc.programConfig, } - if tc.inputData != nil { req.InputData = tc.inputData } @@ -599,8 +554,7 @@ func TestRunnerServiceServerExecute_Configs(t *testing.T) { err = stream.Send(req) assert.NoError(t, err) - result := <-execResult - + result := <-resultC assert.NoError(t, result.Err) assert.Equal(t, tc.expectedOutput, string(result.Stdout)) assert.EqualValues(t, 0, result.ExitCode) @@ -622,69 +576,76 @@ func TestRunnerServiceServerExecute_CommandMode_Terminal(t *testing.T) { // Step 1: execute the first command in the terminal mode with bash, // then write a line that exports an environment variable. { - stream, err := client.Execute(context.Background()) + execStream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) + resultC := make(chan executeResult) + go getExecuteResult(execStream, resultC) - err = stream.Send(&runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "bash", + err = execStream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "bash", + }, }, }, + Mode: runnerv2.CommandMode_COMMAND_MODE_TERMINAL, }, - Mode: runnerv2.CommandMode_COMMAND_MODE_TERMINAL, + SessionId: sessResp.GetSession().GetId(), }, - SessionId: sessResp.GetSession().GetId(), - }) + ) require.NoError(t, err) + // Wait for the bash to start. time.Sleep(time.Second) // Export some variables so that it can be tested if they are collected. - req := &runnerv2.ExecuteRequest{InputData: []byte("export TEST_ENV=TEST_VALUE\n")} - err = stream.Send(req) + err = execStream.Send( + &runnerv2.ExecuteRequest{InputData: []byte("export TEST_ENV=TEST_VALUE\n")}, + ) require.NoError(t, err) // Signal the end of input. - req = &runnerv2.ExecuteRequest{InputData: []byte{0x04}} - err = stream.Send(req) + err = execStream.Send( + &runnerv2.ExecuteRequest{InputData: []byte{0x04}}, + ) require.NoError(t, err) - result := <-execResult + result := <-resultC require.NoError(t, result.Err) } // Step 2: execute the second command which will try to get the value of // the exported environment variable from the step 1. { - stream, err := client.Execute(context.Background()) + execStream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) + resultC := make(chan executeResult) + go getExecuteResult(execStream, resultC) - err = stream.Send(&runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "echo -n $TEST_ENV", + err = execStream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "echo -n $TEST_ENV", + }, }, }, + Mode: runnerv2.CommandMode_COMMAND_MODE_INLINE, }, - Mode: runnerv2.CommandMode_COMMAND_MODE_INLINE, + SessionId: sessResp.GetSession().GetId(), }, - SessionId: sessResp.GetSession().GetId(), - }) + ) require.NoError(t, err) - result := <-execResult + result := <-resultC require.NoError(t, result.Err) require.Equal(t, "TEST_VALUE", string(result.Stdout)) } @@ -703,24 +664,24 @@ func TestRunnerServiceServerExecute_PathEnvInSession(t *testing.T) { // Run the first request with the default PATH. { - stream, err := client.Execute(context.Background()) + execStream, err := client.Execute(context.Background()) require.NoError(t, err) - result := make(chan executeResult) - go getExecuteResult(stream, result) + resultC := make(chan executeResult) + go getExecuteResult(execStream, resultC) - req := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "echo", - Arguments: []string{"-n", "test"}, - Mode: runnerv2.CommandMode_COMMAND_MODE_INLINE, + err = execStream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "echo", + Arguments: []string{"-n", "test"}, + Mode: runnerv2.CommandMode_COMMAND_MODE_INLINE, + }, + SessionId: sessionResp.GetSession().GetId(), }, - SessionId: sessionResp.GetSession().GetId(), - } - - err = stream.Send(req) + ) require.NoError(t, err) - require.Equal(t, "test", string((<-result).Stdout)) + require.Equal(t, "test", string((<-resultC).Stdout)) } // Provide a PATH in the session. It will be an empty dir so @@ -736,21 +697,21 @@ func TestRunnerServiceServerExecute_PathEnvInSession(t *testing.T) { // This time the request will fail because the echo command is not found. { - stream, err := client.Execute(context.Background()) + execStream, err := client.Execute(context.Background()) require.NoError(t, err) result := make(chan executeResult) - go getExecuteResult(stream, result) + go getExecuteResult(execStream, result) - req := &runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "echo", - Arguments: []string{"-n", "test"}, + err = execStream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "echo", + Arguments: []string{"-n", "test"}, + }, + SessionId: sessionResp.GetSession().GetId(), }, - SessionId: sessionResp.GetSession().GetId(), - } - - err = stream.Send(req) + ) require.NoError(t, err) require.ErrorContains(t, (<-result).Err, "failed program lookup \"echo\"") } @@ -767,36 +728,37 @@ func TestRunnerServiceServerExecute_WithInput(t *testing.T) { t.Run("ContinuousInput", func(t *testing.T) { t.Parallel() - stream, err := client.Execute(context.Background()) + execStream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) + resultC := make(chan executeResult) + go getExecuteResult(execStream, resultC) - err = stream.Send(&runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "cat - | tr a-z A-Z", + err = execStream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "cat - | tr a-z A-Z", + }, }, }, + Interactive: true, }, - Interactive: true, + InputData: []byte("a\n"), }, - InputData: []byte("a\n"), - }) + ) require.NoError(t, err) for _, data := range [][]byte{[]byte("b\n"), []byte("c\n"), []byte("d\n"), {0x04}} { req := &runnerv2.ExecuteRequest{InputData: data} - err = stream.Send(req) + err = execStream.Send(req) assert.NoError(t, err) } - result := <-execResult - + result := <-resultC assert.NoError(t, result.Err) assert.EqualValues(t, 0, result.ExitCode) // Validate the output by asserting that lowercase letters precede uppercase letters. @@ -812,42 +774,43 @@ func TestRunnerServiceServerExecute_WithInput(t *testing.T) { t.Run("SimulateCtrlC", func(t *testing.T) { t.Parallel() - stream, err := client.Execute(context.Background()) + execStream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) + resultC := make(chan executeResult) + go getExecuteResult(execStream, resultC) - err = stream.Send(&runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "bash", + err = execStream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "bash", + }, }, }, + Interactive: true, }, - Interactive: true, }, - }) + ) require.NoError(t, err) time.Sleep(time.Millisecond * 500) - err = stream.Send(&runnerv2.ExecuteRequest{InputData: []byte("sleep 30")}) + err = execStream.Send(&runnerv2.ExecuteRequest{InputData: []byte("sleep 30")}) assert.NoError(t, err) // cancel sleep time.Sleep(time.Millisecond * 500) - err = stream.Send(&runnerv2.ExecuteRequest{InputData: []byte{0x03}}) + err = execStream.Send(&runnerv2.ExecuteRequest{InputData: []byte{0x03}}) assert.NoError(t, err) time.Sleep(time.Millisecond * 500) - err = stream.Send(&runnerv2.ExecuteRequest{InputData: []byte{0x04}}) + err = execStream.Send(&runnerv2.ExecuteRequest{InputData: []byte{0x04}}) assert.NoError(t, err) - result := <-execResult - + result := <-resultC // TODO(adamb): This should be a specific gRPC error rather than Unknown. assert.Contains(t, result.Err.Error(), "exit status 130") assert.Equal(t, 130, result.ExitCode) @@ -859,8 +822,8 @@ func TestRunnerServiceServerExecute_WithInput(t *testing.T) { stream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) + resultC := make(chan executeResult) + go getExecuteResult(stream, resultC) err = stream.Send(&runnerv2.ExecuteRequest{ Config: &runnerv2.ProgramConfig{ @@ -872,9 +835,10 @@ func TestRunnerServiceServerExecute_WithInput(t *testing.T) { require.NoError(t, err) // Close the send direction. - assert.NoError(t, stream.CloseSend()) + err = stream.CloseSend() + assert.NoError(t, err) - result := <-execResult + result := <-resultC // TODO(adamb): This should be a specific gRPC error rather than Unknown. require.NotNil(t, result.Err) assert.Contains(t, result.Err.Error(), "signal: interrupt") @@ -897,27 +861,28 @@ func TestRunnerServiceServerExecute_WithSession(t *testing.T) { stream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) - - err = stream.Send(&runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "echo -n \"$TEST_ENV\"", - "export TEST_ENV=hello-2", + resultC := make(chan executeResult) + go getExecuteResult(stream, resultC) + + err = stream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "echo -n \"$TEST_ENV\"", + "export TEST_ENV=hello-2", + }, }, }, + Env: []string{"TEST_ENV=hello"}, }, - Env: []string{"TEST_ENV=hello"}, }, - }) + ) require.NoError(t, err) - result := <-execResult - + result := <-resultC assert.NoError(t, result.Err) assert.Equal(t, "hello", string(result.Stdout)) } @@ -927,26 +892,27 @@ func TestRunnerServiceServerExecute_WithSession(t *testing.T) { stream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) - - err = stream.Send(&runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "echo -n \"$TEST_ENV\"", + resultC := make(chan executeResult) + go getExecuteResult(stream, resultC) + + err = stream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "echo -n \"$TEST_ENV\"", + }, }, }, }, + SessionStrategy: runnerv2.SessionStrategy_SESSION_STRATEGY_MOST_RECENT, }, - SessionStrategy: runnerv2.SessionStrategy_SESSION_STRATEGY_MOST_RECENT, - }) + ) require.NoError(t, err) - result := <-execResult - + result := <-resultC assert.NoError(t, result.Err) assert.Equal(t, "hello-2", string(result.Stdout)) } @@ -964,38 +930,40 @@ func TestRunnerServiceServerExecute_WithStop(t *testing.T) { stream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) - - err = stream.Send(&runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "echo 1", - "sleep 30", + resultC := make(chan executeResult) + go getExecuteResult(stream, resultC) + + err = stream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "echo 1", + "sleep 30", + }, }, }, + Interactive: true, }, - Interactive: true, }, - }) + ) require.NoError(t, err) - errc := make(chan error) + errC := make(chan error) go func() { - defer close(errc) + defer close(errC) time.Sleep(time.Second) err := stream.Send(&runnerv2.ExecuteRequest{ Stop: runnerv2.ExecuteStop_EXECUTE_STOP_INTERRUPT, }) - errc <- err + errC <- err }() - assert.NoError(t, <-errc) + assert.NoError(t, <-errC) select { - case result := <-execResult: + case result := <-resultC: // TODO(adamb): There should be no error. assert.Contains(t, result.Err.Error(), "signal: interrupt") assert.Equal(t, 130, result.ExitCode) @@ -1005,18 +973,21 @@ func TestRunnerServiceServerExecute_WithStop(t *testing.T) { stream, err = client.Execute(context.Background()) require.NoError(t, err) - execResult = make(chan executeResult) - go getExecuteResult(stream, execResult) + resultC := make(chan executeResult) + go getExecuteResult(stream, resultC) - err = stream.Send(&runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "echo", - Arguments: []string{"-n", "1"}, - Mode: runnerv2.CommandMode_COMMAND_MODE_INLINE, + err = stream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "echo", + Arguments: []string{"-n", "1"}, + Mode: runnerv2.CommandMode_COMMAND_MODE_INLINE, + }, }, - }) + ) require.NoError(t, err) - result = <-execResult + + result = <-resultC assert.Equal(t, "1", string(result.Stdout)) case <-time.After(5 * time.Second): t.Fatal("expected the response early as the command got interrupted") @@ -1037,28 +1008,29 @@ func TestRunnerServiceServerExecute_Winsize(t *testing.T) { stream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) + resultC := make(chan executeResult) + go getExecuteResult(stream, resultC) - err = stream.Send(&runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "tput lines", - "tput cols", + err = stream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "tput lines", + "tput cols", + }, }, }, + Env: []string{"TERM=linux"}, + Interactive: true, }, - Env: []string{"TERM=linux"}, - Interactive: true, }, - }) + ) require.NoError(t, err) - result := <-execResult - + result := <-resultC assert.NoError(t, result.Err) assert.Equal(t, "24\r\n80\r\n", string(result.Stdout)) assert.EqualValues(t, 0, result.ExitCode) @@ -1070,33 +1042,34 @@ func TestRunnerServiceServerExecute_Winsize(t *testing.T) { stream, err := client.Execute(context.Background()) require.NoError(t, err) - execResult := make(chan executeResult) - go getExecuteResult(stream, execResult) + resultC := make(chan executeResult) + go getExecuteResult(stream, resultC) - err = stream.Send(&runnerv2.ExecuteRequest{ - Config: &runnerv2.ProgramConfig{ - ProgramName: "bash", - Source: &runnerv2.ProgramConfig_Commands{ - Commands: &runnerv2.ProgramConfig_CommandList{ - Items: []string{ - "sleep 3", // wait for the winsize to be set - "tput lines", - "tput cols", + err = stream.Send( + &runnerv2.ExecuteRequest{ + Config: &runnerv2.ProgramConfig{ + ProgramName: "bash", + Source: &runnerv2.ProgramConfig_Commands{ + Commands: &runnerv2.ProgramConfig_CommandList{ + Items: []string{ + "sleep 3", // wait for the winsize to be set + "tput lines", + "tput cols", + }, }, }, + Interactive: true, + Env: []string{"TERM=linux"}, + }, + Winsize: &runnerv2.Winsize{ + Cols: 200, + Rows: 64, }, - Interactive: true, - Env: []string{"TERM=linux"}, - }, - Winsize: &runnerv2.Winsize{ - Cols: 200, - Rows: 64, }, - }) + ) require.NoError(t, err) - result := <-execResult - + result := <-resultC assert.NoError(t, result.Err) assert.Equal(t, "64\r\n200\r\n", string(result.Stdout)) assert.EqualValues(t, 0, result.ExitCode) @@ -1123,18 +1096,22 @@ func startRunnerServiceServer(t *testing.T) (_ *bufconn.Listener, stop func()) { } type executeResult struct { - Stdout []byte - Stderr []byte - MimeType string - ExitCode int Err error + ExitCode int + MimeType string + Stderr []byte + Stdout []byte } func getExecuteResult( stream runnerv2.RunnerService_ExecuteClient, resultc chan<- executeResult, ) { - result := executeResult{ExitCode: -1} + result := executeResult{ + ExitCode: -1, + } + bufStdout := new(bytes.Buffer) + bufStderr := new(bytes.Buffer) for { r, rerr := stream.Recv() @@ -1145,8 +1122,8 @@ func getExecuteResult( result.Err = rerr break } - result.Stdout = append(result.Stdout, r.StdoutData...) - result.Stderr = append(result.Stderr, r.StderrData...) + _, _ = bufStdout.Write(r.StdoutData) + _, _ = bufStderr.Write(r.StderrData) if r.MimeType != "" { result.MimeType = r.MimeType } @@ -1155,5 +1132,8 @@ func getExecuteResult( } } + result.Stdout = bufStdout.Bytes() + result.Stderr = bufStderr.Bytes() + resultc <- result } diff --git a/internal/runnerv2service/service_sessions.go b/internal/runnerv2service/service_sessions.go index 95ce59bb4..1d3df7847 100644 --- a/internal/runnerv2service/service_sessions.go +++ b/internal/runnerv2service/service_sessions.go @@ -11,32 +11,8 @@ import ( rcontext "github.com/stateful/runme/v3/internal/runner/context" "github.com/stateful/runme/v3/internal/session" runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" - "github.com/stateful/runme/v3/pkg/project" ) -func convertSessionToRunnerv2alpha1Session(sess *session.Session) *runnerv2.Session { - return &runnerv2.Session{ - Id: sess.ID, - Env: sess.GetAllEnv(), - // Metadata: sess.Metadata, - } -} - -// TODO(adamb): this function should not return nil project and nil error at the same time. -func convertProtoProjectToProject(runnerProj *runnerv2.Project) (*project.Project, error) { - if runnerProj == nil { - return nil, nil - } - - opts := project.DefaultProjectOptions[:] - - if runnerProj.EnvLoadOrder != nil { - opts = append(opts, project.WithEnvFilesReadOrder(runnerProj.EnvLoadOrder)) - } - - return project.NewDirProject(runnerProj.Root, opts...) -} - func (r *runnerService) CreateSession(ctx context.Context, req *runnerv2.CreateSessionRequest) (*runnerv2.CreateSessionResponse, error) { r.logger.Info("running CreateSession in runnerService") @@ -75,7 +51,7 @@ func (r *runnerService) CreateSession(ctx context.Context, req *runnerv2.CreateS r.logger.Debug("created session", zap.String("id", sess.ID), zap.Bool("owl", owl), zap.Int("seed_env_len", len(seedEnv))) return &runnerv2.CreateSessionResponse{ - Session: convertSessionToRunnerv2alpha1Session(sess), + Session: convertSessionToProtoSession(sess), }, nil } @@ -88,7 +64,7 @@ func (r *runnerService) GetSession(_ context.Context, req *runnerv2.GetSessionRe } return &runnerv2.GetSessionResponse{ - Session: convertSessionToRunnerv2alpha1Session(sess), + Session: convertSessionToProtoSession(sess), }, nil } @@ -99,7 +75,7 @@ func (r *runnerService) ListSessions(_ context.Context, req *runnerv2.ListSessio runnerSessions := make([]*runnerv2.Session, 0, len(sessions)) for _, s := range sessions { - runnerSessions = append(runnerSessions, convertSessionToRunnerv2alpha1Session(s)) + runnerSessions = append(runnerSessions, convertSessionToProtoSession(s)) } return &runnerv2.ListSessionsResponse{Sessions: runnerSessions}, nil @@ -117,7 +93,7 @@ func (r *runnerService) UpdateSession(ctx context.Context, req *runnerv2.UpdateS return nil, err } - return &runnerv2.UpdateSessionResponse{Session: convertSessionToRunnerv2alpha1Session(sess)}, nil + return &runnerv2.UpdateSessionResponse{Session: convertSessionToProtoSession(sess)}, nil } func (r *runnerService) DeleteSession(_ context.Context, req *runnerv2.DeleteSessionRequest) (*runnerv2.DeleteSessionResponse, error) { From 458c2d6ef9c7779b51490ea988c3a6bd3abc46e3 Mon Sep 17 00:00:00 2001 From: Adam Babik Date: Sun, 8 Dec 2024 20:32:49 +0100 Subject: [PATCH 2/5] Implement execution2, but unclear it is faster --- internal/cmd/beta/run_cmd.go | 2 +- internal/command/config.go | 4 +- internal/config/autoconfig/autoconfig.go | 2 +- internal/owl/store.go | 4 +- internal/runner/context/exec_info.go | 13 +- internal/runner/service.go | 4 +- internal/runnerv2client/client.go | 10 + internal/runnerv2service/execution.go | 28 +- internal/runnerv2service/execution2.go | 304 ++++++++++++++++++ internal/runnerv2service/service_execute.go | 32 +- .../runnerv2service/service_execute_test.go | 9 +- internal/runnerv2service/service_sessions.go | 2 +- internal/server/server.go | 4 +- internal/testutils/grpc.go | 1 + 14 files changed, 371 insertions(+), 48 deletions(-) create mode 100644 internal/runnerv2service/execution2.go diff --git a/internal/cmd/beta/run_cmd.go b/internal/cmd/beta/run_cmd.go index 4789faafe..ede5c94b8 100644 --- a/internal/cmd/beta/run_cmd.go +++ b/internal/cmd/beta/run_cmd.go @@ -181,7 +181,7 @@ func runCodeBlock( KnownName: block.Name(), KnownID: block.ID(), } - ctx = rcontext.ContextWithExecutionInfo(ctx, execInfo) + ctx = rcontext.WithExecutionInfo(ctx, execInfo) cmd, err := factory.Build(cfg, options) if err != nil { diff --git a/internal/command/config.go b/internal/command/config.go index 96ff2ecb7..a2efc8db0 100644 --- a/internal/command/config.go +++ b/internal/command/config.go @@ -25,10 +25,10 @@ func redactConfig(cfg *ProgramConfig) *ProgramConfig { } func isShell(cfg *ProgramConfig) bool { - return IsShellProgram(filepath.Base(cfg.ProgramName)) || IsShellLanguage(cfg.LanguageId) + return isShellProgram(filepath.Base(cfg.ProgramName)) || IsShellLanguage(cfg.LanguageId) } -func IsShellProgram(programName string) bool { +func isShellProgram(programName string) bool { switch strings.ToLower(programName) { case "sh", "bash", "zsh", "ksh", "shell": return true diff --git a/internal/config/autoconfig/autoconfig.go b/internal/config/autoconfig/autoconfig.go index 7d9353e1e..87333966f 100644 --- a/internal/config/autoconfig/autoconfig.go +++ b/internal/config/autoconfig/autoconfig.go @@ -166,7 +166,7 @@ func getLogger(c *config.Config) (*zap.Logger, error) { } if c.Log.Verbose { - zapConfig.Level = zap.NewAtomicLevelAt(zap.DebugLevel) + zapConfig.Level = zap.NewAtomicLevelAt(zap.InfoLevel) zapConfig.Development = true zapConfig.Encoding = "console" zapConfig.EncoderConfig = zap.NewDevelopmentEncoderConfig() diff --git a/internal/owl/store.go b/internal/owl/store.go index 45742fdbd..eac119be2 100644 --- a/internal/owl/store.go +++ b/internal/owl/store.go @@ -698,12 +698,12 @@ func (s *Store) LoadEnvs(source string, envs ...string) error { return nil } -func (s *Store) Update(context context.Context, newOrUpdated, deleted []string) error { +func (s *Store) Update(ctx context.Context, newOrUpdated, deleted []string) error { s.mu.Lock() defer s.mu.Unlock() execRef := "[execution]" - if execInfo, ok := context.Value(rcontext.ExecutionInfoKey).(*rcontext.ExecutionInfo); ok { + if execInfo, ok := rcontext.ExecutionInfoFromContext(ctx); ok { execRef = fmt.Sprintf("#%s", execInfo.KnownID) if execInfo.KnownName != "" { execRef = fmt.Sprintf("#%s", execInfo.KnownName) diff --git a/internal/runner/context/exec_info.go b/internal/runner/context/exec_info.go index 81069257f..de912f14e 100644 --- a/internal/runner/context/exec_info.go +++ b/internal/runner/context/exec_info.go @@ -2,9 +2,9 @@ package runner import "context" -type runnerContextKey struct{} +type contextKey struct{ string } -var ExecutionInfoKey = &runnerContextKey{} +var executionInfoKey = &contextKey{"ExecutionInfo"} type ExecutionInfo struct { ExecContext string @@ -13,6 +13,11 @@ type ExecutionInfo struct { RunID string } -func ContextWithExecutionInfo(ctx context.Context, execInfo *ExecutionInfo) context.Context { - return context.WithValue(ctx, ExecutionInfoKey, execInfo) +func WithExecutionInfo(ctx context.Context, execInfo *ExecutionInfo) context.Context { + return context.WithValue(ctx, executionInfoKey, execInfo) +} + +func ExecutionInfoFromContext(ctx context.Context) (*ExecutionInfo, bool) { + execInfo, ok := ctx.Value(executionInfoKey).(*ExecutionInfo) + return execInfo, ok } diff --git a/internal/runner/service.go b/internal/runner/service.go index a0811b336..fc5f19745 100644 --- a/internal/runner/service.go +++ b/internal/runner/service.go @@ -228,7 +228,7 @@ func (r *runnerService) Execute(srv runnerv1.RunnerService_ExecuteServer) error KnownName: req.GetKnownName(), KnownID: req.GetKnownId(), } - ctx := rcontext.ContextWithExecutionInfo(srv.Context(), execInfo) + ctx := rcontext.WithExecutionInfo(srv.Context(), execInfo) if req.KnownId != "" { logger = logger.With(zap.String("knownID", req.KnownId)) @@ -353,7 +353,7 @@ func (r *runnerService) Execute(srv runnerv1.RunnerService_ExecuteServer) error cmdCtx := ctx if req.Background { - cmdCtx = rcontext.ContextWithExecutionInfo(context.Background(), execInfo) + cmdCtx = rcontext.WithExecutionInfo(context.Background(), execInfo) } if err := cmd.StartWithOpts(cmdCtx, &startOpts{}); err != nil { diff --git a/internal/runnerv2client/client.go b/internal/runnerv2client/client.go index d0e6305e3..e218ef70a 100644 --- a/internal/runnerv2client/client.go +++ b/internal/runnerv2client/client.go @@ -12,6 +12,8 @@ import ( runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" ) +const maxMsgSize = 32 * 1024 * 1024 // 32 MiB + type Client struct { runnerv2.RunnerServiceClient conn *grpc.ClientConn @@ -19,6 +21,14 @@ type Client struct { } func New(target string, logger *zap.Logger, opts ...grpc.DialOption) (*Client, error) { + opts = append( + // default options + []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)), + }, + opts..., + ) + client, err := grpc.NewClient(target, opts...) if err != nil { return nil, errors.WithStack(err) diff --git a/internal/runnerv2service/execution.go b/internal/runnerv2service/execution.go index 663845ace..3b654d7bb 100644 --- a/internal/runnerv2service/execution.go +++ b/internal/runnerv2service/execution.go @@ -22,16 +22,6 @@ import ( "github.com/stateful/runme/v3/pkg/project" ) -const ( - // msgBufferSize limits the size of data chunks - // sent by the handler to clients. It's smaller - // intentionally as typically the messages are - // small. - // In the future, it might be worth to implement - // variable-sized buffers. - msgBufferSize = 2 * 1024 * 1024 // 2 MiB -) - var opininatedEnvVarNamingRegexp = regexp.MustCompile(`^[A-Z_][A-Z0-9_]{1}[A-Z0-9_]*[A-Z][A-Z0-9_]*$`) func matchesOpinionatedEnvVarNaming(knownName string) bool { @@ -63,8 +53,8 @@ func newExecution( ) stdin, stdinWriter := io.Pipe() - stdout := newBuffer() - stderr := newBuffer() + stdout := newBuffer(msgBufferSize) + stderr := newBuffer(msgBufferSize) cmdOptions := command.CommandOptions{ EnableEcho: true, @@ -301,7 +291,7 @@ func readSendLoop( eof = true } - logger.Info("readSendLoop", zap.Int("n", n)) + logger.Debug("readSendLoop", zap.Int("n", n)) if n == 0 && eof { return nil @@ -350,10 +340,10 @@ type buffer struct { var _ io.WriteCloser = (*buffer)(nil) -func newBuffer() *buffer { +func newBuffer(size int) *buffer { return &buffer{ mu: &sync.Mutex{}, - b: bytes.NewBuffer(make([]byte, 0, msgBufferSize)), + b: bytes.NewBuffer(make([]byte, 0, size)), closed: &atomic.Bool{}, close: make(chan struct{}), more: make(chan struct{}), @@ -390,12 +380,16 @@ func (b *buffer) Read(p []byte) (int, error) { b.mu.Unlock() if err != nil && errors.Is(err, io.EOF) && !b.closed.Load() { + if n > 0 { + return n, nil + } + select { case <-b.more: + return b.Read(p) case <-b.close: - return n, io.EOF + return 0, io.EOF } - return n, nil } return n, err diff --git a/internal/runnerv2service/execution2.go b/internal/runnerv2service/execution2.go new file mode 100644 index 000000000..b374d0e89 --- /dev/null +++ b/internal/runnerv2service/execution2.go @@ -0,0 +1,304 @@ +package runnerv2service + +import ( + "bytes" + "context" + "io" + "os" + "time" + + "go.uber.org/zap" + + "github.com/gabriel-vasile/mimetype" + "github.com/pkg/errors" + "github.com/stateful/runme/v3/internal/command" + "github.com/stateful/runme/v3/internal/rbuffer" + "github.com/stateful/runme/v3/internal/session" + runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" + "github.com/stateful/runme/v3/pkg/project" +) + +const ( + // msgBufferSize limits the size of data chunks + // sent by the handler to clients. It's smaller + // intentionally as typically the messages are + // small. + // In the future, it might be worth to implement + // variable-sized buffers. + msgBufferSize = 4 * 1024 * 1024 // 4 MiB +) + +//lint:ignore U1000 Used in A/B testing +type execution2 struct { + Cmd command.Command + + knownName string + logger *zap.Logger + session *session.Session + storeStdoutInEnv bool + + stdinR, stdoutR, stderrR io.Reader + stdinW, stdoutW, stderrW io.WriteCloser +} + +//lint:ignore U1000 Used in A/B testing +func newExecution2( + cfg *command.ProgramConfig, + proj *project.Project, + session *session.Session, + logger *zap.Logger, + storeStdoutInEnv bool, +) (*execution2, error) { + logger = logger.Named("execution2") + + cmdFactory := command.NewFactory( + command.WithProject(proj), + command.WithLogger(logger), + ) + + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + stderrR, stderrW := io.Pipe() + + cmdOptions := command.CommandOptions{ + EnableEcho: true, + Session: session, + StdinWriter: stdinW, + Stdin: stdinR, + Stdout: stdoutW, + Stderr: stderrW, + } + + cmd, err := cmdFactory.Build(cfg, cmdOptions) + if err != nil { + return nil, err + } + + exec := &execution2{ + Cmd: cmd, + + knownName: cfg.GetKnownName(), + logger: logger, + session: session, + storeStdoutInEnv: storeStdoutInEnv, + + stdinR: stdinR, + stdinW: stdinW, + stdoutR: stdoutR, + stdoutW: stdoutW, + stderrR: stderrR, + stderrW: stderrW, + } + return exec, nil +} + +func (e *execution2) closeIO() { + err := e.stdinW.Close() + e.logger.Info("closed stdin writer", zap.Error(err)) + + err = e.stdoutW.Close() + e.logger.Info("closed stdout writer", zap.Error(err)) + + err = e.stderrW.Close() + e.logger.Info("closed stderr writer", zap.Error(err)) +} + +func (e *execution2) storeOutputInEnv(ctx context.Context, r io.Reader) { + b, err := io.ReadAll(r) + if err != nil { + e.logger.Warn("failed to read last output", zap.Error(err)) + return + } + + sanitized := bytes.ReplaceAll(b, []byte{'\000'}, nil) + env := command.CreateEnv(command.StoreStdoutEnvName, string(sanitized)) + if err := e.session.SetEnv(ctx, env); err != nil { + e.logger.Warn("failed to store last output", zap.Error(err)) + } + + if e.knownName != "" && matchesOpinionatedEnvVarNaming(e.knownName) { + if err := e.session.SetEnv(ctx, e.knownName+"="+string(sanitized)); err != nil { + e.logger.Warn("failed to store output under known name", zap.String("known_name", e.knownName), zap.Error(err)) + } + } +} + +func (e *execution2) Wait(ctx context.Context, sender runnerv2.RunnerService_ExecuteServer) (int, error) { + envStdout := io.Discard + if e.storeStdoutInEnv { + b := rbuffer.NewRingBuffer(session.MaxEnvSizeInBytes - len(command.StoreStdoutEnvName) - 1) + defer func() { + _ = b.Close() + e.storeOutputInEnv(ctx, b) + }() + envStdout = b + } + + readSendDone := make(chan error, 2) + go func() { + mimetypeDetected := false + + readSendDone <- e.readSendLoop( + sender, + e.stdoutR, + func(b []byte) *runnerv2.ExecuteResponse { + if _, err := envStdout.Write(b); err != nil { + e.logger.Warn("failed to write to envStdout writer", zap.Error(err)) + envStdout = io.Discard + } + + response := &runnerv2.ExecuteResponse{ + StdoutData: b, + } + + if !mimetypeDetected { + if detected := mimetype.Detect(response.StdoutData); detected != nil { + mimetypeDetected = true + response.MimeType = detected.String() + e.logger.Debug("detected MIME type", zap.String("mime", detected.String())) + } else { + e.logger.Debug("failed to detect MIME type") + } + } + + return response + }, + e.logger.Named("readSendLoop.stdout"), + ) + }() + go func() { + readSendDone <- e.readSendLoop( + sender, + e.stderrR, + func(b []byte) *runnerv2.ExecuteResponse { + return &runnerv2.ExecuteResponse{ + StderrData: b, + } + }, + e.logger.Named("readSendLoop.stderr"), + ) + }() + + waitErr := e.Cmd.Wait(ctx) + exitCode := exitCodeFromErr(waitErr) + e.logger.Info("command finished", zap.Int("exitCode", exitCode), zap.Error(waitErr)) + + e.closeIO() + + if waitErr != nil { + return exitCode, waitErr + } + + readSendLoopsFinished := 0 + +finalWait: + select { + case <-ctx.Done(): + e.logger.Info("context done", zap.Error(ctx.Err())) + return exitCode, ctx.Err() + case err := <-readSendDone: + if err != nil { + e.logger.Info("readSendCtx done", zap.Error(err)) + } + readSendLoopsFinished++ + if readSendLoopsFinished < 2 { + goto finalWait + } + return exitCode, err + } +} + +func (e *execution2) readSendLoop( + sender runnerv2.RunnerService_ExecuteServer, + src io.Reader, + cb func([]byte) *runnerv2.ExecuteResponse, + logger *zap.Logger, +) error { + const sendsPerSecond = 30 + + buf := newBuffer(msgBufferSize) + + // Copy from src to [buffer]. + go func() { + n, err := io.Copy(buf, src) + logger.Debug("copied from source to buffer", zap.Int64("count", n), zap.Error(err)) + _ = buf.Close() // always nil + }() + + data := make([]byte, msgBufferSize) + + for { + eof := false + n, err := buf.Read(data) + if err != nil { + if !errors.Is(err, io.EOF) { + return errors.WithStack(err) + } + eof = true + } + logger.Debug("read", zap.Int("n", n), zap.Bool("eof", eof)) + if n == 0 { + if eof { + return nil + } + continue + } + + readTime := time.Now() + + response := cb(data[:n]) + if err := sender.Send(response); err != nil { + return errors.WithStack(err) + } + + time.Sleep(time.Second/sendsPerSecond - time.Since(readTime)) + } +} + +func (e *execution2) Write(p []byte) (int, error) { + n, err := e.stdinW.Write(p) + + // Close stdin writer for non-interactive commands after handling the initial request. + // Non-interactive commands do not support sending data continuously and require that + // the stdin writer to be closed to finish processing the input. + if ok := e.Cmd.Interactive(); !ok { + if closeErr := e.stdinW.Close(); closeErr != nil { + e.logger.Info("failed to close native command stdin writer", zap.Error(closeErr)) + if err == nil { + err = closeErr + } + } + } + + return n, errors.WithStack(err) +} + +func (e *execution2) SetWinsize(size *runnerv2.Winsize) error { + if size == nil { + return nil + } + + return command.SetWinsize( + e.Cmd, + &command.Winsize{ + Rows: uint16(size.Rows), + Cols: uint16(size.Cols), + X: uint16(size.X), + Y: uint16(size.Y), + }, + ) +} + +func (e *execution2) Stop(stop runnerv2.ExecuteStop) (err error) { + switch stop { + case runnerv2.ExecuteStop_EXECUTE_STOP_UNSPECIFIED: + // continue + case runnerv2.ExecuteStop_EXECUTE_STOP_INTERRUPT: + err = e.Cmd.Signal(os.Interrupt) + case runnerv2.ExecuteStop_EXECUTE_STOP_KILL: + err = e.Cmd.Signal(os.Kill) + default: + err = errors.New("unknown stop signal") + } + return +} diff --git a/internal/runnerv2service/service_execute.go b/internal/runnerv2service/service_execute.go index 24d44e83c..e131ea2f2 100644 --- a/internal/runnerv2service/service_execute.go +++ b/internal/runnerv2service/service_execute.go @@ -16,8 +16,8 @@ import ( ) func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error { - _id := ulid.GenerateID() - logger := r.logger.With(zap.String("id", _id)) + runID := ulid.GenerateID() + logger := r.logger.With(zap.String("id", runID)) // Get the initial request. req, err := srv.Recv() @@ -31,8 +31,10 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error } logger.Info("received initial request", zap.Any("req", req)) - execInfo := getExecutionInfoFromExecutionRequest(_id, req) - ctx := rcontext.ContextWithExecutionInfo(srv.Context(), execInfo) + execInfo := getExecutionInfoFromExecutionRequest(req) + execInfo.RunID = runID + + ctx := rcontext.WithExecutionInfo(srv.Context(), execInfo) // Load the project. // TODO(adamb): this should come from the runme.yaml in the future. @@ -49,7 +51,6 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error if !existed { r.sessions.Add(session) } - if err := session.SetEnv(ctx, req.Config.Env...); err != nil { return err } @@ -65,6 +66,17 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error return err } + // exec, err := newExecution2( + // req.Config, + // proj, + // session, + // logger, + // req.StoreStdoutInEnv, + // ) + // if err != nil { + // return err + // } + // Start the command and send the initial response with PID. if err := exec.Cmd.Start(ctx); err != nil { return err @@ -82,14 +94,11 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error req := initialReq for { - var err error - if err := exec.SetWinsize(req.Winsize); err != nil { logger.Info("failed to set winsize; ignoring", zap.Error(err)) } - _, err = exec.Write(req.InputData) - if err != nil { + if _, err := exec.Write(req.InputData); err != nil { logger.Info("failed to write to stdin; ignoring", zap.Error(err)) } @@ -97,7 +106,7 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error logger.Info("failed to stop program; ignoring", zap.Error(err)) } - req, err = srv.Recv() + req, err := srv.Recv() logger.Info("received request", zap.Any("req", req), zap.Error(err)) switch { case err == nil: @@ -139,11 +148,10 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error return waitErr } -func getExecutionInfoFromExecutionRequest(runID string, req *runnerv2.ExecuteRequest) *rcontext.ExecutionInfo { +func getExecutionInfoFromExecutionRequest(req *runnerv2.ExecuteRequest) *rcontext.ExecutionInfo { return &rcontext.ExecutionInfo{ ExecContext: "Execute", KnownID: req.GetConfig().GetKnownId(), KnownName: req.GetConfig().GetKnownName(), - RunID: runID, } } diff --git a/internal/runnerv2service/service_execute_test.go b/internal/runnerv2service/service_execute_test.go index e403f7cd1..8b11c15a1 100644 --- a/internal/runnerv2service/service_execute_test.go +++ b/internal/runnerv2service/service_execute_test.go @@ -129,8 +129,8 @@ func TestRunnerServiceServerExecute_Response(t *testing.T) { // Assert fourth response. resp, err = stream.Recv() assert.NoError(t, err) - assert.Equal(t, uint32(0), resp.ExitCode.Value) - assert.Nil(t, resp.Pid) + assert.Equal(t, uint32(0), resp.GetExitCode().GetValue()) + assert.Nil(t, resp.GetPid()) } func TestRunnerServiceServerExecute_StoreLastStdout(t *testing.T) { @@ -1086,7 +1086,10 @@ func startRunnerServiceServer(t *testing.T) (_ *bufconn.Listener, stop func()) { runnerService, err := NewRunnerService(factory, logger) require.NoError(t, err) - server := grpc.NewServer() + server := grpc.NewServer( + grpc.MaxRecvMsgSize(msgBufferSize*2), + grpc.MaxSendMsgSize(msgBufferSize*2), + ) runnerv2.RegisterRunnerServiceServer(server, runnerService) lis := bufconn.Listen(1 << 20) // 1 MB diff --git a/internal/runnerv2service/service_sessions.go b/internal/runnerv2service/service_sessions.go index 1d3df7847..3a8336491 100644 --- a/internal/runnerv2service/service_sessions.go +++ b/internal/runnerv2service/service_sessions.go @@ -114,7 +114,7 @@ type updateRequest interface { } func (r *runnerService) updateSession(ctx context.Context, sess *session.Session, req updateRequest) error { - ctx = rcontext.ContextWithExecutionInfo(ctx, &rcontext.ExecutionInfo{ + ctx = rcontext.WithExecutionInfo(ctx, &rcontext.ExecutionInfo{ ExecContext: "request", }) diff --git a/internal/server/server.go b/internal/server/server.go index a22cf6565..c02a4c195 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -23,9 +23,7 @@ import ( "github.com/stateful/runme/v3/pkg/document/editor/editorservice" ) -const ( - maxMsgSize = 4 * 1024 * 1024 // 4 MiB -) +const maxMsgSize = 32 * 1024 * 1024 // 32 MiB type Config struct { Address string diff --git a/internal/testutils/grpc.go b/internal/testutils/grpc.go index d28e92bcd..ac98fd445 100644 --- a/internal/testutils/grpc.go +++ b/internal/testutils/grpc.go @@ -42,6 +42,7 @@ func newGRPCClient[T any]( return lis.Dial() }), grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(8*1024*1024)), ) if err != nil { var result T From f1971ad9debc0d2e7f9fe2720b5c86fa89c894c8 Mon Sep 17 00:00:00 2001 From: Adam Babik Date: Sun, 22 Dec 2024 07:58:30 +0100 Subject: [PATCH 3/5] Optimize buffer size in execution2 --- internal/runnerv2service/execution.go | 2 ++ internal/runnerv2service/execution2.go | 6 ++++-- internal/runnerv2service/service_execute.go | 24 ++++++++++----------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/internal/runnerv2service/execution.go b/internal/runnerv2service/execution.go index 3b654d7bb..9038a6050 100644 --- a/internal/runnerv2service/execution.go +++ b/internal/runnerv2service/execution.go @@ -28,6 +28,7 @@ func matchesOpinionatedEnvVarNaming(knownName string) bool { return opininatedEnvVarNamingRegexp.MatchString(knownName) } +//lint:ignore U1000 Used in A/B testing type execution struct { Cmd command.Command knownName string @@ -40,6 +41,7 @@ type execution struct { storeStdoutInEnv bool } +//lint:ignore U1000 Used in A/B testing func newExecution( cfg *command.ProgramConfig, proj *project.Project, diff --git a/internal/runnerv2service/execution2.go b/internal/runnerv2service/execution2.go index b374d0e89..1f9b57fe5 100644 --- a/internal/runnerv2service/execution2.go +++ b/internal/runnerv2service/execution2.go @@ -25,7 +25,7 @@ const ( // small. // In the future, it might be worth to implement // variable-sized buffers. - msgBufferSize = 4 * 1024 * 1024 // 4 MiB + msgBufferSize = 32 * 1024 * 1024 // 4 MiB ) //lint:ignore U1000 Used in A/B testing @@ -251,7 +251,9 @@ func (e *execution2) readSendLoop( return errors.WithStack(err) } - time.Sleep(time.Second/sendsPerSecond - time.Since(readTime)) + if n < msgBufferSize { + time.Sleep(time.Second/sendsPerSecond - time.Since(readTime)) + } } } diff --git a/internal/runnerv2service/service_execute.go b/internal/runnerv2service/service_execute.go index e131ea2f2..51f129f5e 100644 --- a/internal/runnerv2service/service_execute.go +++ b/internal/runnerv2service/service_execute.go @@ -55,18 +55,7 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error return err } - exec, err := newExecution( - req.Config, - proj, - session, - logger, - req.StoreStdoutInEnv, - ) - if err != nil { - return err - } - - // exec, err := newExecution2( + // exec, err := newExecution( // req.Config, // proj, // session, @@ -77,6 +66,17 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error // return err // } + exec, err := newExecution2( + req.Config, + proj, + session, + logger, + req.StoreStdoutInEnv, + ) + if err != nil { + return err + } + // Start the command and send the initial response with PID. if err := exec.Cmd.Start(ctx); err != nil { return err From d22741df93f6265d47c4ff1de687ebb7a9684803 Mon Sep 17 00:00:00 2001 From: Adam Babik Date: Sat, 4 Jan 2025 16:51:38 +0100 Subject: [PATCH 4/5] Merge execution2 and execution --- internal/runnerv2client/client.go | 4 +- internal/runnerv2service/buffer.go | 84 +++++ internal/runnerv2service/execution.go | 374 ++++++++------------ internal/runnerv2service/execution2.go | 306 ---------------- internal/runnerv2service/service_execute.go | 4 +- internal/server/server.go | 10 +- 6 files changed, 245 insertions(+), 537 deletions(-) create mode 100644 internal/runnerv2service/buffer.go delete mode 100644 internal/runnerv2service/execution2.go diff --git a/internal/runnerv2client/client.go b/internal/runnerv2client/client.go index e218ef70a..bd8c63da3 100644 --- a/internal/runnerv2client/client.go +++ b/internal/runnerv2client/client.go @@ -12,7 +12,7 @@ import ( runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" ) -const maxMsgSize = 32 * 1024 * 1024 // 32 MiB +const MaxMsgSize = 32 * 1024 * 1024 // 32 MiB type Client struct { runnerv2.RunnerServiceClient @@ -24,7 +24,7 @@ func New(target string, logger *zap.Logger, opts ...grpc.DialOption) (*Client, e opts = append( // default options []grpc.DialOption{ - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxMsgSize)), }, opts..., ) diff --git a/internal/runnerv2service/buffer.go b/internal/runnerv2service/buffer.go new file mode 100644 index 000000000..07bcf5d5e --- /dev/null +++ b/internal/runnerv2service/buffer.go @@ -0,0 +1,84 @@ +package runnerv2service + +import ( + "bytes" + "io" + "sync" + "sync/atomic" + + "github.com/pkg/errors" +) + +const ( + // msgBufferSize limits the size of data chunks + // sent by the handler to clients. It's smaller + // intentionally as typically the messages are + // small. + // In the future, it might be worth to implement + // variable-sized buffers. + msgBufferSize = 32 * 1024 * 1024 // 32 MiB +) + +// buffer is a thread-safe buffer that returns EOF +// only when it's closed. +type buffer struct { + mu *sync.Mutex + // +checklocks:mu + b *bytes.Buffer + closed *atomic.Bool + close chan struct{} + more chan struct{} +} + +var _ io.WriteCloser = (*buffer)(nil) + +func newBuffer(size int) *buffer { + return &buffer{ + mu: &sync.Mutex{}, + b: bytes.NewBuffer(make([]byte, 0, size)), + closed: &atomic.Bool{}, + close: make(chan struct{}), + more: make(chan struct{}), + } +} + +func (b *buffer) Write(p []byte) (int, error) { + if b.closed.Load() { + return 0, errors.New("closed") + } + + b.mu.Lock() + n, err := b.b.Write(p) + b.mu.Unlock() + + select { + case b.more <- struct{}{}: + default: + } + + return n, err +} + +func (b *buffer) Close() error { + if b.closed.CompareAndSwap(false, true) { + close(b.close) + } + return nil +} + +func (b *buffer) Read(p []byte) (int, error) { + b.mu.Lock() + n, err := b.b.Read(p) + b.mu.Unlock() + + if err != nil && errors.Is(err, io.EOF) && !b.closed.Load() { + select { + case <-b.more: + case <-b.close: + return n, io.EOF + } + return n, nil + } + + return n, err +} diff --git a/internal/runnerv2service/execution.go b/internal/runnerv2service/execution.go index 9038a6050..ee023e450 100644 --- a/internal/runnerv2service/execution.go +++ b/internal/runnerv2service/execution.go @@ -7,13 +7,13 @@ import ( "os" "os/exec" "regexp" - "sync" - "sync/atomic" "syscall" + "time" + + "go.uber.org/zap" "github.com/gabriel-vasile/mimetype" "github.com/pkg/errors" - "go.uber.org/zap" "github.com/stateful/runme/v3/internal/command" "github.com/stateful/runme/v3/internal/rbuffer" @@ -28,20 +28,18 @@ func matchesOpinionatedEnvVarNaming(knownName string) bool { return opininatedEnvVarNamingRegexp.MatchString(knownName) } -//lint:ignore U1000 Used in A/B testing type execution struct { - Cmd command.Command + Cmd command.Command + knownName string logger *zap.Logger session *session.Session - stdin io.Reader - stdinWriter io.WriteCloser - stdout *buffer - stderr *buffer storeStdoutInEnv bool + + stdinR, stdoutR, stderrR io.Reader + stdinW, stdoutW, stderrW io.WriteCloser } -//lint:ignore U1000 Used in A/B testing func newExecution( cfg *command.ProgramConfig, proj *project.Project, @@ -49,22 +47,24 @@ func newExecution( logger *zap.Logger, storeStdoutInEnv bool, ) (*execution, error) { + logger = logger.Named("execution") + cmdFactory := command.NewFactory( - command.WithLogger(logger), command.WithProject(proj), + command.WithLogger(logger), ) - stdin, stdinWriter := io.Pipe() - stdout := newBuffer(msgBufferSize) - stderr := newBuffer(msgBufferSize) + stdinR, stdinW := io.Pipe() + stdoutR, stdoutW := io.Pipe() + stderrR, stderrW := io.Pipe() cmdOptions := command.CommandOptions{ EnableEcho: true, Session: session, - StdinWriter: stdinWriter, - Stdin: stdin, - Stdout: stdout, - Stderr: stderr, + StdinWriter: stdinW, + Stdin: stdinR, + Stdout: stdoutW, + Stderr: stderrW, } cmd, err := cmdFactory.Build(cfg, cmdOptions) @@ -73,133 +73,198 @@ func newExecution( } exec := &execution{ - Cmd: cmd, + Cmd: cmd, + knownName: cfg.GetKnownName(), logger: logger, session: session, - stdin: stdin, - stdinWriter: stdinWriter, - stdout: stdout, - stderr: stderr, storeStdoutInEnv: storeStdoutInEnv, - } + stdinR: stdinR, + stdinW: stdinW, + stdoutR: stdoutR, + stdoutW: stdoutW, + stderrR: stderrR, + stderrW: stderrW, + } return exec, nil } +func (e *execution) closeIO() { + err := e.stdinW.Close() + e.logger.Info("closed stdin writer", zap.Error(err)) + + err = e.stdoutW.Close() + e.logger.Info("closed stdout writer", zap.Error(err)) + + err = e.stderrW.Close() + e.logger.Info("closed stderr writer", zap.Error(err)) +} + +func (e *execution) storeOutputInEnv(ctx context.Context, r io.Reader) { + b, err := io.ReadAll(r) + if err != nil { + e.logger.Warn("failed to read last output", zap.Error(err)) + return + } + + sanitized := bytes.ReplaceAll(b, []byte{'\000'}, nil) + env := command.CreateEnv(command.StoreStdoutEnvName, string(sanitized)) + if err := e.session.SetEnv(ctx, env); err != nil { + e.logger.Warn("failed to store last output", zap.Error(err)) + } + + if e.knownName != "" && matchesOpinionatedEnvVarNaming(e.knownName) { + if err := e.session.SetEnv(ctx, e.knownName+"="+string(sanitized)); err != nil { + e.logger.Warn("failed to store output under known name", zap.String("known_name", e.knownName), zap.Error(err)) + } + } +} + func (e *execution) Wait(ctx context.Context, sender runnerv2.RunnerService_ExecuteServer) (int, error) { - lastStdout := io.Discard + envStdout := io.Discard if e.storeStdoutInEnv { b := rbuffer.NewRingBuffer(session.MaxEnvSizeInBytes - len(command.StoreStdoutEnvName) - 1) defer func() { _ = b.Close() e.storeOutputInEnv(ctx, b) }() - lastStdout = b + envStdout = b } - firstStdoutSent := false - errc := make(chan error, 2) - + readSendDone := make(chan error, 2) go func() { - errc <- readSendLoop( - e.stdout, + mimetypeDetected := false + + readSendDone <- e.readSendLoop( sender, + e.stdoutR, func(b []byte) *runnerv2.ExecuteResponse { - if len(b) == 0 { - return nil + if _, err := envStdout.Write(b); err != nil { + e.logger.Warn("failed to write to envStdout writer", zap.Error(err)) + envStdout = io.Discard } - _, err := lastStdout.Write(b) - if err != nil { - e.logger.Warn("failed to write last output", zap.Error(err)) + response := &runnerv2.ExecuteResponse{ + StdoutData: b, } - resp := &runnerv2.ExecuteResponse{StdoutData: b} - - if !firstStdoutSent { - if detected := mimetype.Detect(b); detected != nil { - e.logger.Info("detected MIME type", zap.String("mime", detected.String())) - resp.MimeType = detected.String() + if !mimetypeDetected { + if detected := mimetype.Detect(response.StdoutData); detected != nil { + mimetypeDetected = true + response.MimeType = detected.String() + e.logger.Debug("detected MIME type", zap.String("mime", detected.String())) + } else { + e.logger.Debug("failed to detect MIME type") } } - firstStdoutSent = true - - e.logger.Debug("sending stdout data", zap.Int("len", len(resp.StdoutData))) - return resp + return response }, - e.logger.With(zap.String("source", "stdout")), + e.logger.Named("readSendLoop.stdout"), ) }() go func() { - errc <- readSendLoop( - e.stderr, + readSendDone <- e.readSendLoop( sender, + e.stderrR, func(b []byte) *runnerv2.ExecuteResponse { - if len(b) == 0 { - return nil + return &runnerv2.ExecuteResponse{ + StderrData: b, } - resp := &runnerv2.ExecuteResponse{StderrData: b} - e.logger.Debug("sending stderr data", zap.Any("resp", resp)) - return resp }, - e.logger.With(zap.String("source", "stderr")), + e.logger.Named("readSendLoop.stderr"), ) }() waitErr := e.Cmd.Wait(ctx) exitCode := exitCodeFromErr(waitErr) - e.logger.Info("command finished", zap.Int("exitCode", exitCode), zap.Error(waitErr)) e.closeIO() - // If waitErr is not nil, only log the errors but return waitErr. if waitErr != nil { - handlerErrors := 0 - - readSendHandlerForWaitErr: - select { - case err := <-errc: - handlerErrors++ - e.logger.Info("readSendLoop finished; ignoring any errors because there was a wait error", zap.Error(err)) - // Wait for both errors, or nils. - if handlerErrors < 2 { - goto readSendHandlerForWaitErr - } - case <-ctx.Done(): - e.logger.Info("context canceled while waiting for the readSendLoop finish; ignoring any errors because there was a wait error") - } return exitCode, waitErr } - // If waitErr is nil, wait for the readSendLoop to finish, - // or the context being canceled. + readSendLoopsFinished := 0 + +finalWait: select { - case err1 := <-errc: - // Wait for both errors, or nils. - select { - case err2 := <-errc: - if err2 != nil { - e.logger.Info("another error from readSendLoop; won't be returned", zap.Error(err2)) - } - case <-ctx.Done(): - } - return exitCode, err1 case <-ctx.Done(): + e.logger.Info("context done", zap.Error(ctx.Err())) return exitCode, ctx.Err() + case err := <-readSendDone: + if err != nil { + e.logger.Info("readSendCtx done", zap.Error(err)) + } + readSendLoopsFinished++ + if readSendLoopsFinished < 2 { + goto finalWait + } + return exitCode, err + } +} + +func (e *execution) readSendLoop( + sender runnerv2.RunnerService_ExecuteServer, + src io.Reader, + cb func([]byte) *runnerv2.ExecuteResponse, + logger *zap.Logger, +) error { + // Limit to 30 sends per second. This is typically quite enough + // for interactive commands and streaming the output. + const sendsPerSecond = 30 + + buf := newBuffer(msgBufferSize) + + // Copy from src to buffer. + go func() { + n, err := io.Copy(buf, src) + logger.Debug("copied from source to buffer", zap.Int64("count", n), zap.Error(err)) + _ = buf.Close() // always nil + }() + + data := make([]byte, msgBufferSize) + + for { + eof := false + n, err := buf.Read(data) + if err != nil { + if !errors.Is(err, io.EOF) { + return errors.WithStack(err) + } + eof = true + } + logger.Debug("read", zap.Int("n", n), zap.Bool("eof", eof)) + if n == 0 { + if eof { + return nil + } + continue + } + + readTime := time.Now() + + response := cb(data[:n]) + if err := sender.Send(response); err != nil { + return errors.WithStack(err) + } + + if n < msgBufferSize { + time.Sleep(time.Second/sendsPerSecond - time.Since(readTime)) + } } } func (e *execution) Write(p []byte) (int, error) { - n, err := e.stdinWriter.Write(p) + n, err := e.stdinW.Write(p) // Close stdin writer for non-interactive commands after handling the initial request. // Non-interactive commands do not support sending data continuously and require that // the stdin writer to be closed to finish processing the input. if ok := e.Cmd.Interactive(); !ok { - if closeErr := e.stdinWriter.Close(); closeErr != nil { + if closeErr := e.stdinW.Close(); closeErr != nil { e.logger.Info("failed to close native command stdin writer", zap.Error(closeErr)) if err == nil { err = closeErr @@ -227,6 +292,8 @@ func (e *execution) SetWinsize(size *runnerv2.Winsize) error { } func (e *execution) Stop(stop runnerv2.ExecuteStop) (err error) { + e.logger.Info("stopping program", zap.Any("stop", stop)) + switch stop { case runnerv2.ExecuteStop_EXECUTE_STOP_UNSPECIFIED: // continue @@ -240,77 +307,6 @@ func (e *execution) Stop(stop runnerv2.ExecuteStop) (err error) { return } -func (e *execution) closeIO() { - err := e.stdinWriter.Close() - e.logger.Info("closed stdin writer", zap.Error(err)) - - err = e.stdout.Close() - e.logger.Info("closed stdout writer", zap.Error(err)) - - err = e.stderr.Close() - e.logger.Info("closed stderr writer", zap.Error(err)) -} - -func (e *execution) storeOutputInEnv(ctx context.Context, r io.Reader) { - b, err := io.ReadAll(r) - if err != nil { - e.logger.Warn("failed to read last output", zap.Error(err)) - return - } - - sanitized := bytes.ReplaceAll(b, []byte{'\000'}, nil) - env := command.CreateEnv(command.StoreStdoutEnvName, string(sanitized)) - if err := e.session.SetEnv(ctx, env); err != nil { - e.logger.Warn("failed to store last output", zap.Error(err)) - } - - if e.knownName != "" && matchesOpinionatedEnvVarNaming(e.knownName) { - if err := e.session.SetEnv(ctx, e.knownName+"="+string(sanitized)); err != nil { - e.logger.Warn("failed to store output under known name", zap.String("known_name", e.knownName), zap.Error(err)) - } - } -} - -type sender interface { - Send(*runnerv2.ExecuteResponse) error -} - -func readSendLoop( - reader io.Reader, - sender sender, - fn func([]byte) *runnerv2.ExecuteResponse, - logger *zap.Logger, -) error { - buf := make([]byte, msgBufferSize) - - for { - eof := false - n, err := reader.Read(buf) - if err != nil { - if !errors.Is(err, io.EOF) { - return errors.WithStack(err) - } - eof = true - } - - logger.Debug("readSendLoop", zap.Int("n", n)) - - if n == 0 && eof { - return nil - } - - msg := fn(buf[:n]) - if msg == nil { - continue - } - - err = sender.Send(msg) - if err != nil { - return errors.WithStack(err) - } - } -} - func exitCodeFromErr(err error) int { if err == nil { return 0 @@ -330,69 +326,3 @@ func exitCodeFromErr(err error) int { } return -1 } - -type buffer struct { - mu *sync.Mutex - // +checklocks:mu - b *bytes.Buffer - closed *atomic.Bool - close chan struct{} - more chan struct{} -} - -var _ io.WriteCloser = (*buffer)(nil) - -func newBuffer(size int) *buffer { - return &buffer{ - mu: &sync.Mutex{}, - b: bytes.NewBuffer(make([]byte, 0, size)), - closed: &atomic.Bool{}, - close: make(chan struct{}), - more: make(chan struct{}), - } -} - -func (b *buffer) Write(p []byte) (int, error) { - if b.closed.Load() { - return 0, errors.New("closed") - } - - b.mu.Lock() - n, err := b.b.Write(p) - b.mu.Unlock() - - select { - case b.more <- struct{}{}: - default: - } - - return n, err -} - -func (b *buffer) Close() error { - if b.closed.CompareAndSwap(false, true) { - close(b.close) - } - return nil -} - -func (b *buffer) Read(p []byte) (int, error) { - b.mu.Lock() - n, err := b.b.Read(p) - b.mu.Unlock() - - if err != nil && errors.Is(err, io.EOF) && !b.closed.Load() { - if n > 0 { - return n, nil - } - - select { - case <-b.more: - return b.Read(p) - case <-b.close: - return 0, io.EOF - } - } - - return n, err -} diff --git a/internal/runnerv2service/execution2.go b/internal/runnerv2service/execution2.go deleted file mode 100644 index 1f9b57fe5..000000000 --- a/internal/runnerv2service/execution2.go +++ /dev/null @@ -1,306 +0,0 @@ -package runnerv2service - -import ( - "bytes" - "context" - "io" - "os" - "time" - - "go.uber.org/zap" - - "github.com/gabriel-vasile/mimetype" - "github.com/pkg/errors" - "github.com/stateful/runme/v3/internal/command" - "github.com/stateful/runme/v3/internal/rbuffer" - "github.com/stateful/runme/v3/internal/session" - runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" - "github.com/stateful/runme/v3/pkg/project" -) - -const ( - // msgBufferSize limits the size of data chunks - // sent by the handler to clients. It's smaller - // intentionally as typically the messages are - // small. - // In the future, it might be worth to implement - // variable-sized buffers. - msgBufferSize = 32 * 1024 * 1024 // 4 MiB -) - -//lint:ignore U1000 Used in A/B testing -type execution2 struct { - Cmd command.Command - - knownName string - logger *zap.Logger - session *session.Session - storeStdoutInEnv bool - - stdinR, stdoutR, stderrR io.Reader - stdinW, stdoutW, stderrW io.WriteCloser -} - -//lint:ignore U1000 Used in A/B testing -func newExecution2( - cfg *command.ProgramConfig, - proj *project.Project, - session *session.Session, - logger *zap.Logger, - storeStdoutInEnv bool, -) (*execution2, error) { - logger = logger.Named("execution2") - - cmdFactory := command.NewFactory( - command.WithProject(proj), - command.WithLogger(logger), - ) - - stdinR, stdinW := io.Pipe() - stdoutR, stdoutW := io.Pipe() - stderrR, stderrW := io.Pipe() - - cmdOptions := command.CommandOptions{ - EnableEcho: true, - Session: session, - StdinWriter: stdinW, - Stdin: stdinR, - Stdout: stdoutW, - Stderr: stderrW, - } - - cmd, err := cmdFactory.Build(cfg, cmdOptions) - if err != nil { - return nil, err - } - - exec := &execution2{ - Cmd: cmd, - - knownName: cfg.GetKnownName(), - logger: logger, - session: session, - storeStdoutInEnv: storeStdoutInEnv, - - stdinR: stdinR, - stdinW: stdinW, - stdoutR: stdoutR, - stdoutW: stdoutW, - stderrR: stderrR, - stderrW: stderrW, - } - return exec, nil -} - -func (e *execution2) closeIO() { - err := e.stdinW.Close() - e.logger.Info("closed stdin writer", zap.Error(err)) - - err = e.stdoutW.Close() - e.logger.Info("closed stdout writer", zap.Error(err)) - - err = e.stderrW.Close() - e.logger.Info("closed stderr writer", zap.Error(err)) -} - -func (e *execution2) storeOutputInEnv(ctx context.Context, r io.Reader) { - b, err := io.ReadAll(r) - if err != nil { - e.logger.Warn("failed to read last output", zap.Error(err)) - return - } - - sanitized := bytes.ReplaceAll(b, []byte{'\000'}, nil) - env := command.CreateEnv(command.StoreStdoutEnvName, string(sanitized)) - if err := e.session.SetEnv(ctx, env); err != nil { - e.logger.Warn("failed to store last output", zap.Error(err)) - } - - if e.knownName != "" && matchesOpinionatedEnvVarNaming(e.knownName) { - if err := e.session.SetEnv(ctx, e.knownName+"="+string(sanitized)); err != nil { - e.logger.Warn("failed to store output under known name", zap.String("known_name", e.knownName), zap.Error(err)) - } - } -} - -func (e *execution2) Wait(ctx context.Context, sender runnerv2.RunnerService_ExecuteServer) (int, error) { - envStdout := io.Discard - if e.storeStdoutInEnv { - b := rbuffer.NewRingBuffer(session.MaxEnvSizeInBytes - len(command.StoreStdoutEnvName) - 1) - defer func() { - _ = b.Close() - e.storeOutputInEnv(ctx, b) - }() - envStdout = b - } - - readSendDone := make(chan error, 2) - go func() { - mimetypeDetected := false - - readSendDone <- e.readSendLoop( - sender, - e.stdoutR, - func(b []byte) *runnerv2.ExecuteResponse { - if _, err := envStdout.Write(b); err != nil { - e.logger.Warn("failed to write to envStdout writer", zap.Error(err)) - envStdout = io.Discard - } - - response := &runnerv2.ExecuteResponse{ - StdoutData: b, - } - - if !mimetypeDetected { - if detected := mimetype.Detect(response.StdoutData); detected != nil { - mimetypeDetected = true - response.MimeType = detected.String() - e.logger.Debug("detected MIME type", zap.String("mime", detected.String())) - } else { - e.logger.Debug("failed to detect MIME type") - } - } - - return response - }, - e.logger.Named("readSendLoop.stdout"), - ) - }() - go func() { - readSendDone <- e.readSendLoop( - sender, - e.stderrR, - func(b []byte) *runnerv2.ExecuteResponse { - return &runnerv2.ExecuteResponse{ - StderrData: b, - } - }, - e.logger.Named("readSendLoop.stderr"), - ) - }() - - waitErr := e.Cmd.Wait(ctx) - exitCode := exitCodeFromErr(waitErr) - e.logger.Info("command finished", zap.Int("exitCode", exitCode), zap.Error(waitErr)) - - e.closeIO() - - if waitErr != nil { - return exitCode, waitErr - } - - readSendLoopsFinished := 0 - -finalWait: - select { - case <-ctx.Done(): - e.logger.Info("context done", zap.Error(ctx.Err())) - return exitCode, ctx.Err() - case err := <-readSendDone: - if err != nil { - e.logger.Info("readSendCtx done", zap.Error(err)) - } - readSendLoopsFinished++ - if readSendLoopsFinished < 2 { - goto finalWait - } - return exitCode, err - } -} - -func (e *execution2) readSendLoop( - sender runnerv2.RunnerService_ExecuteServer, - src io.Reader, - cb func([]byte) *runnerv2.ExecuteResponse, - logger *zap.Logger, -) error { - const sendsPerSecond = 30 - - buf := newBuffer(msgBufferSize) - - // Copy from src to [buffer]. - go func() { - n, err := io.Copy(buf, src) - logger.Debug("copied from source to buffer", zap.Int64("count", n), zap.Error(err)) - _ = buf.Close() // always nil - }() - - data := make([]byte, msgBufferSize) - - for { - eof := false - n, err := buf.Read(data) - if err != nil { - if !errors.Is(err, io.EOF) { - return errors.WithStack(err) - } - eof = true - } - logger.Debug("read", zap.Int("n", n), zap.Bool("eof", eof)) - if n == 0 { - if eof { - return nil - } - continue - } - - readTime := time.Now() - - response := cb(data[:n]) - if err := sender.Send(response); err != nil { - return errors.WithStack(err) - } - - if n < msgBufferSize { - time.Sleep(time.Second/sendsPerSecond - time.Since(readTime)) - } - } -} - -func (e *execution2) Write(p []byte) (int, error) { - n, err := e.stdinW.Write(p) - - // Close stdin writer for non-interactive commands after handling the initial request. - // Non-interactive commands do not support sending data continuously and require that - // the stdin writer to be closed to finish processing the input. - if ok := e.Cmd.Interactive(); !ok { - if closeErr := e.stdinW.Close(); closeErr != nil { - e.logger.Info("failed to close native command stdin writer", zap.Error(closeErr)) - if err == nil { - err = closeErr - } - } - } - - return n, errors.WithStack(err) -} - -func (e *execution2) SetWinsize(size *runnerv2.Winsize) error { - if size == nil { - return nil - } - - return command.SetWinsize( - e.Cmd, - &command.Winsize{ - Rows: uint16(size.Rows), - Cols: uint16(size.Cols), - X: uint16(size.X), - Y: uint16(size.Y), - }, - ) -} - -func (e *execution2) Stop(stop runnerv2.ExecuteStop) (err error) { - switch stop { - case runnerv2.ExecuteStop_EXECUTE_STOP_UNSPECIFIED: - // continue - case runnerv2.ExecuteStop_EXECUTE_STOP_INTERRUPT: - err = e.Cmd.Signal(os.Interrupt) - case runnerv2.ExecuteStop_EXECUTE_STOP_KILL: - err = e.Cmd.Signal(os.Kill) - default: - err = errors.New("unknown stop signal") - } - return -} diff --git a/internal/runnerv2service/service_execute.go b/internal/runnerv2service/service_execute.go index 51f129f5e..10d76a768 100644 --- a/internal/runnerv2service/service_execute.go +++ b/internal/runnerv2service/service_execute.go @@ -17,7 +17,7 @@ import ( func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error { runID := ulid.GenerateID() - logger := r.logger.With(zap.String("id", runID)) + logger := r.logger.Named("Execute").With(zap.String("id", runID)) // Get the initial request. req, err := srv.Recv() @@ -66,7 +66,7 @@ func (r *runnerService) Execute(srv runnerv2.RunnerService_ExecuteServer) error // return err // } - exec, err := newExecution2( + exec, err := newExecution( req.Config, proj, session, diff --git a/internal/server/server.go b/internal/server/server.go index c02a4c195..afbeb6182 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -23,7 +23,7 @@ import ( "github.com/stateful/runme/v3/pkg/document/editor/editorservice" ) -const maxMsgSize = 32 * 1024 * 1024 // 32 MiB +const MaxMsgSize = 32 * 1024 * 1024 // 32 MiB type Config struct { Address string @@ -53,11 +53,11 @@ func New( } } + var lis net.Listener + addr := c.Address protocol := "tcp" - var lis net.Listener - if strings.HasPrefix(addr, "unix://") { protocol = "unix" addr = strings.TrimPrefix(addr, "unix://") @@ -79,8 +79,8 @@ func New( logger.Info("server listening", zap.String("address", addr)) grpcServer := grpc.NewServer( - grpc.MaxRecvMsgSize(maxMsgSize), - grpc.MaxSendMsgSize(maxMsgSize), + grpc.MaxRecvMsgSize(MaxMsgSize), + grpc.MaxSendMsgSize(MaxMsgSize), ) // Register runme services. From 36a12febbd86b2b87e7e07c1c7d9d53f85f7681d Mon Sep 17 00:00:00 2001 From: Adam Babik Date: Mon, 6 Jan 2025 17:50:17 +0100 Subject: [PATCH 5/5] Refactor client and server --- internal/cmd/beta/server/server_start_cmd.go | 19 +-- internal/config/autoconfig/autoconfig.go | 87 +++++++--- internal/config/autoconfig/autoconfig_test.go | 155 +++++++++++++++++- internal/config/config.go | 4 +- internal/config/config.schema.json | 4 + internal/config/config_schema.go | 6 + internal/config/runme.default.yaml | 1 + internal/runnerv2client/client.go | 22 +-- internal/runnerv2client/client_test.go | 11 +- internal/runnerv2service/buffer.go | 2 +- internal/runnerv2service/service_test.go | 29 ---- internal/server/server.go | 144 +++++++++------- internal/server/server_test.go | 106 ------------ internal/server/server_unix_test.go | 36 ---- internal/testutils/grpc.go | 4 +- 15 files changed, 324 insertions(+), 306 deletions(-) delete mode 100644 internal/runnerv2service/service_test.go delete mode 100644 internal/server/server_test.go delete mode 100644 internal/server/server_unix_test.go diff --git a/internal/cmd/beta/server/server_start_cmd.go b/internal/cmd/beta/server/server_start_cmd.go index d1831a2c1..ca328ab4c 100644 --- a/internal/cmd/beta/server/server_start_cmd.go +++ b/internal/cmd/beta/server/server_start_cmd.go @@ -22,27 +22,14 @@ func serverStartCmd() *cobra.Command { return autoconfig.InvokeForCommand( func( cfg *config.Config, + server *server.Server, cmdFactory command.Factory, logger *zap.Logger, ) error { defer logger.Sync() - serverCfg := &server.Config{ - Address: cfg.Server.Address, - CertFile: *cfg.Server.Tls.CertFile, // guaranteed by autoconfig - KeyFile: *cfg.Server.Tls.KeyFile, // guaranteed by autoconfig - TLSEnabled: cfg.Server.Tls.Enabled, - } - _ = telemetry.ReportUnlessNoTracking(logger) - logger.Debug("server config", zap.Any("config", serverCfg)) - - s, err := server.New(serverCfg, cmdFactory, logger) - if err != nil { - return err - } - // When using a unix socket, we want to create a file with server's PID. if path := pidFileNameFromAddr(cfg.Server.Address); path != "" { logger.Debug("creating PID file", zap.String("path", path)) @@ -52,9 +39,7 @@ func serverStartCmd() *cobra.Command { defer os.Remove(cfg.Server.Address) } - logger.Debug("starting the server") - - return errors.WithStack(s.Serve()) + return errors.WithStack(server.Serve()) }, ) }, diff --git a/internal/config/autoconfig/autoconfig.go b/internal/config/autoconfig/autoconfig.go index 87333966f..2b4400852 100644 --- a/internal/config/autoconfig/autoconfig.go +++ b/internal/config/autoconfig/autoconfig.go @@ -24,8 +24,12 @@ import ( "github.com/stateful/runme/v3/internal/command" "github.com/stateful/runme/v3/internal/config" "github.com/stateful/runme/v3/internal/dockerexec" + "github.com/stateful/runme/v3/internal/project/projectservice" "github.com/stateful/runme/v3/internal/runnerv2client" + "github.com/stateful/runme/v3/internal/runnerv2service" + "github.com/stateful/runme/v3/internal/server" runmetls "github.com/stateful/runme/v3/internal/tls" + "github.com/stateful/runme/v3/pkg/document/editor/editorservice" "github.com/stateful/runme/v3/pkg/project" ) @@ -70,44 +74,27 @@ func init() { mustProvide(container.Provide(getCommandFactory)) mustProvide(container.Provide(getConfigLoader)) mustProvide(container.Provide(getDocker)) + mustProvide(container.Provide(getGRPCClient)) mustProvide(container.Provide(getLogger)) mustProvide(container.Provide(getProject)) mustProvide(container.Provide(getProjectFilters)) mustProvide(container.Provide(getRootConfig)) + mustProvide(container.Provide(getServer)) mustProvide(container.Provide(getUserConfigDir)) } -func getClient(cfg *config.Config, logger *zap.Logger) (*runnerv2client.Client, error) { - if cfg.Server == nil { - return nil, nil - } - - var opts []grpc.DialOption - - if cfg.Server.Tls != nil && cfg.Server.Tls.Enabled { - // It's ok to dereference TLS fields because they are checked in [getRootConfig]. - tlsConfig, err := runmetls.LoadClientConfig(*cfg.Server.Tls.CertFile, *cfg.Server.Tls.KeyFile) - if err != nil { - return nil, errors.WithStack(err) - } - creds := credentials.NewTLS(tlsConfig) - opts = append(opts, grpc.WithTransportCredentials(creds)) - } else { - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) +func getClient(cfg *config.Config, clientConn *grpc.ClientConn, logger *zap.Logger) (*runnerv2client.Client, error) { + if clientConn == nil { + return nil, errors.New("client connection is not configured") } - - return runnerv2client.New( - cfg.Server.Address, - logger, - opts..., - ) + return runnerv2client.New(clientConn, logger), nil } type ClientFactory func() (*runnerv2client.Client, error) -func getClientFactory(cfg *config.Config, logger *zap.Logger) ClientFactory { +func getClientFactory(cfg *config.Config, clientConn *grpc.ClientConn, logger *zap.Logger) ClientFactory { return func() (*runnerv2client.Client, error) { - return getClient(cfg, logger) + return getClient(cfg, clientConn, logger) } } @@ -147,6 +134,35 @@ func getDocker(c *config.Config, logger *zap.Logger) (*dockerexec.Docker, error) return dockerexec.New(options) } +func getGRPCClient( + cfg *config.Config, + server *server.Server, + logger *zap.Logger, +) (*grpc.ClientConn, error) { + if cfg.Server == nil { + return nil, nil + } + + opts := []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.Server.MaxMessageSize)), + } + + if tls := cfg.Server.Tls; tls != nil && tls.Enabled { + // It's ok to dereference TLS fields because they are checked in [getRootConfig]. + tlsConfig, err := runmetls.LoadClientConfig(*cfg.Server.Tls.CertFile, *cfg.Server.Tls.KeyFile) + if err != nil { + return nil, errors.WithStack(err) + } + creds := credentials.NewTLS(tlsConfig) + opts = append(opts, grpc.WithTransportCredentials(creds)) + } else { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + conn, err := grpc.NewClient(server.Addr(), opts...) + return conn, errors.WithStack(err) +} + func getLogger(c *config.Config) (*zap.Logger, error) { if c == nil || c.Log == nil || !c.Log.Enabled { return zap.NewNop(), nil @@ -297,6 +313,27 @@ func getRootConfig(cfgLoader *config.Loader, userCfgDir UserConfigDir) (*config. return cfg, nil } +func getServer(cfg *config.Config, cmdFactory command.Factory, logger *zap.Logger) (*server.Server, error) { + if cfg.Server == nil { + return nil, nil + } + + parserService := editorservice.NewParserServiceServer(logger) + projectService := projectservice.NewProjectServiceServer(logger) + runnerService, err := runnerv2service.NewRunnerService(cmdFactory, logger) + if err != nil { + return nil, err + } + + return server.New( + cfg, + parserService, + projectService, + runnerService, + logger, + ) +} + type UserConfigDir string func getUserConfigDir() (UserConfigDir, error) { diff --git a/internal/config/autoconfig/autoconfig_test.go b/internal/config/autoconfig/autoconfig_test.go index 534d57576..d27ebc9b1 100644 --- a/internal/config/autoconfig/autoconfig_test.go +++ b/internal/config/autoconfig/autoconfig_test.go @@ -1,22 +1,28 @@ package autoconfig import ( + "context" "fmt" + "os" + "path/filepath" "testing" "testing/fstest" + "time" + "github.com/pkg/errors" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + healthv1 "google.golang.org/grpc/health/grpc_health_v1" "github.com/stateful/runme/v3/internal/config" + "github.com/stateful/runme/v3/internal/server" ) func TestInvokeForCommand_Config(t *testing.T) { // Create a fake filesystem and set it in [config.Loader]. err := InvokeForCommand(func(loader *config.Loader) error { fsys := fstest.MapFS{ - "README.md": { - Data: []byte("Hello, World!"), - }, "runme.yaml": { Data: []byte(fmt.Sprintf("version: v1alpha1\nproject:\n filename: %s\n", "README.md")), }, @@ -33,3 +39,146 @@ func TestInvokeForCommand_Config(t *testing.T) { }) require.NoError(t, err) } + +func TestInvokeForCommand_ServerClient(t *testing.T) { + tmp := t.TempDir() + readme := filepath.Join(tmp, "README.md") + err := os.WriteFile(readme, []byte("Hello, World!"), 0644) + require.NoError(t, err) + + t.Run("NoServerInConfig", func(t *testing.T) { + err := InvokeForCommand(func(loader *config.Loader) error { + fsys := fstest.MapFS{ + "runme.yaml": { + Data: []byte(fmt.Sprintf("version: v1alpha1\nproject:\n filename: %s\n", readme)), + }, + } + loader.SetConfigRootPath(fsys) + return nil + }) + require.NoError(t, err) + + err = InvokeForCommand(func( + server *server.Server, + client *grpc.ClientConn, + ) error { + require.Nil(t, server) + require.Nil(t, client) + return nil + }) + require.NoError(t, err) + }) + + t.Run("ServerInConfigWithoutTLS", func(t *testing.T) { + err := InvokeForCommand(func(loader *config.Loader) error { + fsys := fstest.MapFS{ + "runme.yaml": { + Data: []byte(`version: v1alpha1 +project: + filename: ` + readme + ` +server: + address: localhost:0 + tls: + enabled: false +`), + }, + } + loader.SetConfigRootPath(fsys) + return nil + }) + require.NoError(t, err) + + err = InvokeForCommand(func( + server *server.Server, + client *grpc.ClientConn, + ) error { + require.NotNil(t, server) + require.NotNil(t, client) + + var g errgroup.Group + + g.Go(func() error { + return server.Serve() + }) + + g.Go(func() error { + defer server.Shutdown() + return checkHealth(client) + }) + + return g.Wait() + }) + require.NoError(t, err) + }) + + t.Run("ServerInConfigWithTLS", func(t *testing.T) { + // Use a temp dir to store the TLS files. + err = DecorateRoot(func() (UserConfigDir, error) { + return UserConfigDir(tmp), nil + }) + require.NoError(t, err) + + err := InvokeForCommand(func(loader *config.Loader) error { + fsys := fstest.MapFS{ + "runme.yaml": { + Data: []byte(`version: v1alpha1 +project: + filename: ` + readme + ` +server: + address: 127.0.0.1:0 + tls: + enabled: true +`), + }, + } + loader.SetConfigRootPath(fsys) + return nil + }) + require.NoError(t, err) + + err = InvokeForCommand(func( + server *server.Server, + client *grpc.ClientConn, + ) error { + require.NotNil(t, server) + require.NotNil(t, client) + + var g errgroup.Group + + g.Go(func() error { + return server.Serve() + }) + + g.Go(func() error { + defer server.Shutdown() + return errors.WithMessage(checkHealth(client), "failed to check health") + }) + + return g.Wait() + }) + require.NoError(t, err) + }) +} + +func checkHealth(conn *grpc.ClientConn) error { + client := healthv1.NewHealthClient(conn) + + var ( + resp *healthv1.HealthCheckResponse + err error + ) + + for i := 0; i < 5; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + resp, err = client.Check(ctx, &healthv1.HealthCheckRequest{}) + if err != nil || resp.Status != healthv1.HealthCheckResponse_SERVING { + cancel() + time.Sleep(time.Second) + continue + } + cancel() + break + } + + return err +} diff --git a/internal/config/config.go b/internal/config/config.go index 676d67bfa..f4f43914f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -32,10 +32,8 @@ func Default() *Config { } // ParseYAML parses the given YAML items and returns a configuration object. -// Multiple items are merged into a single configuration. It uses a default -// configuration as a base. +// Multiple items are merged into a single configuration. func ParseYAML(items ...[]byte) (*Config, error) { - items = append([][]byte{defaultRunmeYAML}, items...) return parseYAML(items...) } diff --git a/internal/config/config.schema.json b/internal/config/config.schema.json index d8a21f4c0..75da9334f 100644 --- a/internal/config/config.schema.json +++ b/internal/config/config.schema.json @@ -116,6 +116,10 @@ "address": { "type": "string" }, + "max_message_size": { + "type": "integer", + "default": 33554432 + }, "tls": { "type": "object", "properties": { diff --git a/internal/config/config_schema.go b/internal/config/config_schema.go index 3e81b777f..07a6ab94c 100644 --- a/internal/config/config_schema.go +++ b/internal/config/config_schema.go @@ -295,6 +295,9 @@ type ConfigServer struct { // Address corresponds to the JSON schema field "address". Address string `json:"address" yaml:"address"` + // MaxMessageSize corresponds to the JSON schema field "max_message_size". + MaxMessageSize int `json:"max_message_size,omitempty" yaml:"max_message_size,omitempty"` + // Tls corresponds to the JSON schema field "tls". Tls *ConfigServerTls `json:"tls,omitempty" yaml:"tls,omitempty"` } @@ -342,6 +345,9 @@ func (j *ConfigServer) UnmarshalJSON(b []byte) error { if err := json.Unmarshal(b, &plain); err != nil { return err } + if v, ok := raw["max_message_size"]; !ok || v == nil { + plain.MaxMessageSize = 33554432.0 + } *j = ConfigServer(plain) return nil } diff --git a/internal/config/runme.default.yaml b/internal/config/runme.default.yaml index 877375263..59997e3f2 100644 --- a/internal/config/runme.default.yaml +++ b/internal/config/runme.default.yaml @@ -28,6 +28,7 @@ server: # If not specified, default paths will be used. # cert_file: "/path/to/cert.pem" # key_file: "/path/to/key.pem" + max_message_size: 33554432 # 32 MiB log: enabled: false diff --git a/internal/runnerv2client/client.go b/internal/runnerv2client/client.go index bd8c63da3..8ac51516e 100644 --- a/internal/runnerv2client/client.go +++ b/internal/runnerv2client/client.go @@ -20,25 +20,11 @@ type Client struct { logger *zap.Logger } -func New(target string, logger *zap.Logger, opts ...grpc.DialOption) (*Client, error) { - opts = append( - // default options - []grpc.DialOption{ - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxMsgSize)), - }, - opts..., - ) - - client, err := grpc.NewClient(target, opts...) - if err != nil { - return nil, errors.WithStack(err) - } - serviceClient := &Client{ - RunnerServiceClient: runnerv2.NewRunnerServiceClient(client), - conn: client, - logger: logger, +func New(clientConn *grpc.ClientConn, logger *zap.Logger) *Client { + return &Client{ + RunnerServiceClient: runnerv2.NewRunnerServiceClient(clientConn), + logger: logger.Named("runnerv2client.Client"), } - return serviceClient, nil } func (c *Client) Close() error { diff --git a/internal/runnerv2client/client_test.go b/internal/runnerv2client/client_test.go index d2d8cb960..795963462 100644 --- a/internal/runnerv2client/client_test.go +++ b/internal/runnerv2client/client_test.go @@ -108,15 +108,18 @@ func TestClient_ExecuteProgram(t *testing.T) { func createClient(t *testing.T, lis *bufconn.Listener) *Client { t.Helper() - logger := zaptest.NewLogger(t) - client, err := New( + + clientConn, err := grpc.NewClient( "passthrough://bufconn", - logger, grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return lis.Dial() }), grpc.WithTransportCredentials(insecure.NewCredentials()), ) require.NoError(t, err) - return client + + return New( + clientConn, + zaptest.NewLogger(t), + ) } diff --git a/internal/runnerv2service/buffer.go b/internal/runnerv2service/buffer.go index 07bcf5d5e..b3feb2da8 100644 --- a/internal/runnerv2service/buffer.go +++ b/internal/runnerv2service/buffer.go @@ -16,7 +16,7 @@ const ( // small. // In the future, it might be worth to implement // variable-sized buffers. - msgBufferSize = 32 * 1024 * 1024 // 32 MiB + msgBufferSize = 24 * 1024 * 1024 // 24 MiB ) // buffer is a thread-safe buffer that returns EOF diff --git a/internal/runnerv2service/service_test.go b/internal/runnerv2service/service_test.go deleted file mode 100644 index c661dd5ec..000000000 --- a/internal/runnerv2service/service_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package runnerv2service - -import ( - "testing/fstest" - - "github.com/stateful/runme/v3/internal/command" - "github.com/stateful/runme/v3/internal/config" - "github.com/stateful/runme/v3/internal/config/autoconfig" -) - -func init() { - command.SetEnvDumpCommandForTesting() - - // Server uses autoconfig to get necessary dependencies. - // One of them, implicit, is [config.Config]. With the default - // [config.Loader] it won't be found during testing, so - // we need to provide an override. - if err := autoconfig.DecorateRoot(func(loader *config.Loader) *config.Loader { - fsys := fstest.MapFS{ - "runme.yaml": { - Data: []byte("version: v1alpha1\n"), - }, - } - loader.SetConfigRootPath(fsys) - return loader - }); err != nil { - panic(err) - } -} diff --git a/internal/server/server.go b/internal/server/server.go index afbeb6182..355a14d18 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,90 +6,124 @@ import ( "os" "strings" - "github.com/pkg/errors" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/health" healthv1 "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" - "github.com/stateful/runme/v3/internal/command" - "github.com/stateful/runme/v3/internal/project/projectservice" - "github.com/stateful/runme/v3/internal/runnerv2service" + "github.com/pkg/errors" + "github.com/stateful/runme/v3/internal/config" runmetls "github.com/stateful/runme/v3/internal/tls" parserv1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/parser/v1" projectv1 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/project/v1" runnerv2 "github.com/stateful/runme/v3/pkg/api/gen/proto/go/runme/runner/v2" - "github.com/stateful/runme/v3/pkg/document/editor/editorservice" ) -const MaxMsgSize = 32 * 1024 * 1024 // 32 MiB - -type Config struct { - Address string - CertFile string - KeyFile string - TLSEnabled bool -} - type Server struct { - grpcServer *grpc.Server - lis net.Listener - logger *zap.Logger + gs *grpc.Server + lis net.Listener + logger *zap.Logger } func New( - c *Config, - cmdFactory command.Factory, + cfg *config.Config, + parserService parserv1.ParserServiceServer, + projectService projectv1.ProjectServiceServer, + runnerService runnerv2.RunnerServiceServer, logger *zap.Logger, -) (_ *Server, err error) { - var tlsConfig *tls.Config +) (*Server, error) { + tlsCfg, err := createTLSConfig(cfg, logger) + if err != nil { + return nil, err + } - if c.TLSEnabled { - // TODO(adamb): redesign runmetls API. - tlsConfig, err = runmetls.LoadOrGenerateConfig(c.CertFile, c.KeyFile, logger) - if err != nil { - return nil, err - } + lis, err := createListener(cfg, tlsCfg) + if err != nil { + return nil, err } - var lis net.Listener + grpcServer := createGRPCServer( + cfg, + tlsCfg, + parserService, + projectService, + runnerService, + ) + + s := Server{ + gs: grpcServer, + lis: lis, + logger: logger.Named("Server"), + } + + return &s, nil +} + +func (s *Server) Addr() string { + return s.lis.Addr().String() +} + +func (s *Server) Serve() error { + s.logger.Info("starting gRPC server", zap.String("address", s.Addr())) + return s.gs.Serve(s.lis) +} - addr := c.Address +func (s *Server) Shutdown() { + s.logger.Info("stopping gRPC server") + s.gs.GracefulStop() +} + +func createTLSConfig(cfg *config.Config, logger *zap.Logger) (*tls.Config, error) { + if tls := cfg.Server.Tls; tls != nil && tls.Enabled { + // TODO(adamb): redesign runmetls API. + return runmetls.LoadOrGenerateConfig( + *tls.CertFile, // guaranteed in [getRootConfig] + *tls.KeyFile, // guaranteed in [getRootConfig] + logger, + ) + } + return nil, nil +} + +func createListener(cfg *config.Config, tlsCfg *tls.Config) (net.Listener, error) { + addr := cfg.Server.Address protocol := "tcp" if strings.HasPrefix(addr, "unix://") { protocol = "unix" addr = strings.TrimPrefix(addr, "unix://") - if _, err := os.Stat(addr); !os.IsNotExist(err) { return nil, err } } - if tlsConfig == nil { - lis, err = net.Listen(protocol, addr) - } else { - lis, err = tls.Listen(protocol, addr, tlsConfig) - } - if err != nil { - return nil, errors.WithStack(err) + if tlsCfg != nil { + lis, err := tls.Listen(protocol, addr, tlsCfg) + return lis, errors.WithStack(err) } - logger.Info("server listening", zap.String("address", addr)) + lis, err := net.Listen(protocol, addr) + return lis, errors.WithStack(err) +} +func createGRPCServer( + cfg *config.Config, + tlsCfg *tls.Config, + parserService parserv1.ParserServiceServer, + projectService projectv1.ProjectServiceServer, + runnerService runnerv2.RunnerServiceServer, +) *grpc.Server { grpcServer := grpc.NewServer( - grpc.MaxRecvMsgSize(MaxMsgSize), - grpc.MaxSendMsgSize(MaxMsgSize), + grpc.MaxRecvMsgSize(cfg.Server.MaxMessageSize), + grpc.MaxSendMsgSize(cfg.Server.MaxMessageSize), + grpc.Creds(credentials.NewTLS(tlsCfg)), ) // Register runme services. - parserv1.RegisterParserServiceServer(grpcServer, editorservice.NewParserServiceServer(logger)) - projectv1.RegisterProjectServiceServer(grpcServer, projectservice.NewProjectServiceServer(logger)) - runnerService, err := runnerv2service.NewRunnerService(cmdFactory, logger) - if err != nil { - return nil, err - } + parserv1.RegisterParserServiceServer(grpcServer, parserService) + projectv1.RegisterProjectServiceServer(grpcServer, projectService) runnerv2.RegisterRunnerServiceServer(grpcServer, runnerService) // Register health service. @@ -101,21 +135,5 @@ func New( // Register reflection service. reflection.Register(grpcServer) - return &Server{ - lis: lis, - grpcServer: grpcServer, - logger: logger, - }, nil -} - -func (s *Server) Addr() string { - return s.lis.Addr().String() -} - -func (s *Server) Serve() error { - return s.grpcServer.Serve(s.lis) -} - -func (s *Server) Shutdown() { - s.grpcServer.GracefulStop() + return grpcServer } diff --git a/internal/server/server_test.go b/internal/server/server_test.go deleted file mode 100644 index 22f79b0b6..000000000 --- a/internal/server/server_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package server - -import ( - "context" - "path/filepath" - "runtime" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/require" - "go.uber.org/zap/zaptest" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - healthv1 "google.golang.org/grpc/health/grpc_health_v1" - - "github.com/stateful/runme/v3/internal/command" - runmetls "github.com/stateful/runme/v3/internal/tls" -) - -func TestServer(t *testing.T) { - logger := zaptest.NewLogger(t) - factory := command.NewFactory(command.WithLogger(logger)) - - t.Run("tcp", func(t *testing.T) { - cfg := &Config{ - Address: "localhost:0", - } - s, err := New(cfg, factory, logger) - require.NoError(t, err) - errc := make(chan error, 1) - go func() { - errc <- s.Serve() - }() - - testConnectivity(t, s.Addr(), insecure.NewCredentials()) - - s.Shutdown() - require.NoError(t, <-errc) - }) - - t.Run("tcp with tls", func(t *testing.T) { - dir := t.TempDir() - cfg := &Config{ - Address: "localhost:0", - CertFile: filepath.Join(dir, "cert.pem"), - KeyFile: filepath.Join(dir, "key.pem"), - TLSEnabled: true, - } - s, err := New(cfg, factory, logger) - require.NoError(t, err) - errc := make(chan error, 1) - go func() { - errc <- s.Serve() - }() - - tlsConfig, err := runmetls.LoadClientConfig(cfg.CertFile, cfg.KeyFile) - require.NoError(t, err) - - addr := s.Addr() - if runtime.GOOS == "windows" { - addr = strings.TrimPrefix(addr, "unix://") - } - testConnectivity(t, addr, credentials.NewTLS(tlsConfig)) - - s.Shutdown() - require.NoError(t, <-errc) - }) -} - -func testConnectivity(t *testing.T, addr string, creds credentials.TransportCredentials) { - t.Helper() - - var err error - - for i := 0; i < 5; i++ { - var ( - conn *grpc.ClientConn - resp *healthv1.HealthCheckResponse - ) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - conn, err := grpc.NewClient( - addr, - grpc.WithTransportCredentials(creds), - ) - if err != nil { - goto wait - } - - resp, err = healthv1.NewHealthClient(conn).Check(ctx, &healthv1.HealthCheckRequest{}) - if err != nil || resp.Status != healthv1.HealthCheckResponse_SERVING { - goto wait - } - - cancel() - break - - wait: - cancel() - <-time.After(time.Millisecond * 100) - } - - require.NoError(t, err) -} diff --git a/internal/server/server_unix_test.go b/internal/server/server_unix_test.go deleted file mode 100644 index eee7f0865..000000000 --- a/internal/server/server_unix_test.go +++ /dev/null @@ -1,36 +0,0 @@ -//go:build !windows - -package server - -import ( - "path/filepath" - "testing" - - "github.com/stretchr/testify/require" - "go.uber.org/zap/zaptest" - "google.golang.org/grpc/credentials/insecure" - - "github.com/stateful/runme/v3/internal/command" -) - -func TestServerUnixSocket(t *testing.T) { - dir := t.TempDir() - sock := filepath.Join(dir, "runme.sock") - cfg := &Config{ - Address: "unix://" + sock, - } - logger := zaptest.NewLogger(t) - factory := command.NewFactory(command.WithLogger(logger)) - s, err := New(cfg, factory, logger) - require.NoError(t, err) - errc := make(chan error, 1) - go func() { - err := s.Serve() - errc <- err - }() - - testConnectivity(t, cfg.Address, insecure.NewCredentials()) - - s.Shutdown() - require.NoError(t, <-errc) -} diff --git a/internal/testutils/grpc.go b/internal/testutils/grpc.go index ac98fd445..3158c2e0a 100644 --- a/internal/testutils/grpc.go +++ b/internal/testutils/grpc.go @@ -10,6 +10,8 @@ import ( "google.golang.org/grpc/credentials/insecure" ) +const maxMsgSize = 32 * 1024 * 1024 // 32 MiB + func NewGRPCClientWithT[T any]( t *testing.T, lis interface{ Dial() (net.Conn, error) }, @@ -42,7 +44,7 @@ func newGRPCClient[T any]( return lis.Dial() }), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(8*1024*1024)), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)), ) if err != nil { var result T