forked from kolide/launcher
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add remote restart consumer to handle remote restart actions (kolide#…
- Loading branch information
1 parent
bc44bcb
commit c6fe8b7
Showing
4 changed files
with
298 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
130 changes: 130 additions & 0 deletions
130
ee/control/consumers/remoterestartconsumer/remoterestartconsumer.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
package remoterestartconsumer | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"log/slog" | ||
"time" | ||
|
||
"github.com/kolide/launcher/ee/agent/types" | ||
) | ||
|
||
const ( | ||
// RemoteRestartActorType identifies this action/actor type, which performs | ||
// a launcher restart when requested by the control server. This actor type | ||
// belongs to the action subsystem. | ||
RemoteRestartActorType = "remote_restart" | ||
|
||
// restartDelay is the delay after receiving action before triggering the restart. | ||
// We have a delay to allow the actionqueue. | ||
restartDelay = 15 * time.Second | ||
) | ||
|
||
var ( | ||
ErrRemoteRestartRequested = errors.New("need to reload launcher: remote restart requested") | ||
) | ||
|
||
type RemoteRestartConsumer struct { | ||
knapsack types.Knapsack | ||
slogger *slog.Logger | ||
signalRestart chan error | ||
interrupt chan struct{} | ||
interrupted bool | ||
} | ||
|
||
type remoteRestartAction struct { | ||
RunID string `json:"run_id"` // the run ID for the launcher run to restart | ||
} | ||
|
||
func New(knapsack types.Knapsack) *RemoteRestartConsumer { | ||
return &RemoteRestartConsumer{ | ||
knapsack: knapsack, | ||
slogger: knapsack.Slogger().With("component", "remote_restart_consumer"), | ||
signalRestart: make(chan error, 1), | ||
interrupt: make(chan struct{}, 1), | ||
} | ||
} | ||
|
||
// Do implements the `actionqueue.actor` interface, and allows the actionqueue | ||
// to pass `remote_restart` type actions to this consumer. The actionqueue validates | ||
// that this action has not already been performed and that this action is still | ||
// valid (i.e. not expired). `Do` additionally validates that the `run_id` given in | ||
// the action matches the current launcher run ID. | ||
func (r *RemoteRestartConsumer) Do(data io.Reader) error { | ||
var restartAction remoteRestartAction | ||
|
||
if err := json.NewDecoder(data).Decode(&restartAction); err != nil { | ||
return fmt.Errorf("decoding restart action: %w", err) | ||
} | ||
|
||
// The action's run ID indicates the current `runLauncher` that should be restarted. | ||
// If the action's run ID does not match the current run ID, we assume the restart | ||
// has already happened and does not need to happen again. | ||
if restartAction.RunID == "" { | ||
r.slogger.Log(context.TODO(), slog.LevelInfo, | ||
"received remote restart action with blank launcher run ID -- discarding", | ||
) | ||
return nil | ||
} | ||
if restartAction.RunID != r.knapsack.GetRunID() { | ||
r.slogger.Log(context.TODO(), slog.LevelInfo, | ||
"received remote restart action for incorrect (assuming past) launcher run ID -- discarding", | ||
"action_run_id", restartAction.RunID, | ||
) | ||
return nil | ||
} | ||
|
||
// Perform the restart by signaling actor shutdown, but delay slightly to give | ||
// the actionqueue a chance to process all actions and store their statuses. | ||
go func() { | ||
r.slogger.Log(context.TODO(), slog.LevelInfo, | ||
"received remote restart action for current launcher run ID -- signaling for restart shortly", | ||
"action_run_id", restartAction.RunID, | ||
"restart_delay", restartDelay.String(), | ||
) | ||
|
||
select { | ||
case <-r.interrupt: | ||
r.slogger.Log(context.TODO(), slog.LevelDebug, | ||
"received external interrupt before remote restart could be performed", | ||
) | ||
return | ||
case <-time.After(restartDelay): | ||
r.signalRestart <- ErrRemoteRestartRequested | ||
r.slogger.Log(context.TODO(), slog.LevelInfo, | ||
"signaled for restart after delay", | ||
"action_run_id", restartAction.RunID, | ||
) | ||
return | ||
} | ||
}() | ||
|
||
return nil | ||
} | ||
|
||
// Execute allows the remote restart consumer to run in the main launcher rungroup. | ||
// It waits until it receives a remote restart action from `Do`, or until it receives | ||
// a `Interrupt` request. | ||
func (r *RemoteRestartConsumer) Execute() (err error) { | ||
select { | ||
case <-r.interrupt: | ||
return nil | ||
case signalRestartErr := <-r.signalRestart: | ||
return signalRestartErr | ||
} | ||
} | ||
|
||
// Interrupt allows the remote restart consumer to run in the main launcher rungroup | ||
// and be shut down when the rungroup shuts down. | ||
func (r *RemoteRestartConsumer) Interrupt(_ error) { | ||
// Only perform shutdown tasks on first call to interrupt -- no need to repeat on potential extra calls. | ||
if r.interrupted { | ||
return | ||
} | ||
r.interrupted = true | ||
|
||
r.interrupt <- struct{}{} | ||
} |
160 changes: 160 additions & 0 deletions
160
ee/control/consumers/remoterestartconsumer/remoterestartconsumer_test.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
package remoterestartconsumer | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"errors" | ||
"testing" | ||
"time" | ||
|
||
"github.com/kolide/kit/ulid" | ||
typesmocks "github.com/kolide/launcher/ee/agent/types/mocks" | ||
"github.com/kolide/launcher/pkg/log/multislogger" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestDo(t *testing.T) { | ||
t.Parallel() | ||
|
||
currentRunId := ulid.New() | ||
|
||
mockKnapsack := typesmocks.NewKnapsack(t) | ||
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) | ||
mockKnapsack.On("GetRunID").Return(currentRunId) | ||
|
||
remoteRestarter := New(mockKnapsack) | ||
|
||
testAction := remoteRestartAction{ | ||
RunID: currentRunId, | ||
} | ||
testActionRaw, err := json.Marshal(testAction) | ||
require.NoError(t, err) | ||
|
||
// We don't expect an error because we should process the action | ||
require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "expected no error processing valid remote restart action") | ||
|
||
// The restarter should delay before sending an error to `signalRestart` | ||
require.Len(t, remoteRestarter.signalRestart, 0, "expected restarter to delay before signal for restart but channel is already has item in it") | ||
time.Sleep(restartDelay + 2*time.Second) | ||
require.Len(t, remoteRestarter.signalRestart, 1, "expected restarter to signal for restart but channel is empty after delay") | ||
} | ||
|
||
func TestDo_DoesNotSignalRestartWhenRunIDDoesNotMatch(t *testing.T) { | ||
t.Parallel() | ||
|
||
currentRunId := ulid.New() | ||
|
||
mockKnapsack := typesmocks.NewKnapsack(t) | ||
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) | ||
mockKnapsack.On("GetRunID").Return(currentRunId) | ||
|
||
remoteRestarter := New(mockKnapsack) | ||
|
||
testAction := remoteRestartAction{ | ||
RunID: ulid.New(), // run ID will not match `currentRunId` | ||
} | ||
testActionRaw, err := json.Marshal(testAction) | ||
require.NoError(t, err) | ||
|
||
// We don't expect an error because we want to discard this action | ||
require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "should not return error for old run ID") | ||
|
||
// The restarter should not send an error to `signalRestart` | ||
time.Sleep(restartDelay + 2*time.Second) | ||
require.Len(t, remoteRestarter.signalRestart, 0, "restarter should not have signaled for a restart, but channel is not empty") | ||
} | ||
|
||
func TestDo_DoesNotSignalRestartWhenRunIDIsEmpty(t *testing.T) { | ||
t.Parallel() | ||
|
||
mockKnapsack := typesmocks.NewKnapsack(t) | ||
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) | ||
|
||
remoteRestarter := New(mockKnapsack) | ||
|
||
testAction := remoteRestartAction{ | ||
RunID: "", // run ID is empty | ||
} | ||
testActionRaw, err := json.Marshal(testAction) | ||
require.NoError(t, err) | ||
|
||
// We don't expect an error because we want to discard this action | ||
require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "should not return error for empty run ID") | ||
|
||
// The restarter should not send an error to `signalRestart` | ||
time.Sleep(restartDelay + 2*time.Second) | ||
require.Len(t, remoteRestarter.signalRestart, 0, "restarter should not have signaled for a restart, but channel is not empty") | ||
} | ||
|
||
func TestDo_DoesNotRestartIfInterruptedDuringDelay(t *testing.T) { | ||
t.Parallel() | ||
|
||
currentRunId := ulid.New() | ||
|
||
mockKnapsack := typesmocks.NewKnapsack(t) | ||
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) | ||
mockKnapsack.On("GetRunID").Return(currentRunId) | ||
|
||
remoteRestarter := New(mockKnapsack) | ||
|
||
testAction := remoteRestartAction{ | ||
RunID: currentRunId, | ||
} | ||
testActionRaw, err := json.Marshal(testAction) | ||
require.NoError(t, err) | ||
|
||
// We don't expect an error because the run ID is correct | ||
require.NoError(t, remoteRestarter.Do(bytes.NewReader(testActionRaw)), "expected no error processing valid remote restart action") | ||
|
||
// The restarter should delay before sending an error to `signalRestart` | ||
require.Len(t, remoteRestarter.signalRestart, 0, "expected restarter to delay before signal for restart but channel is already has item in it") | ||
|
||
// Now, send an interrupt | ||
remoteRestarter.Interrupt(errors.New("test error")) | ||
|
||
// Sleep beyond the interrupt delay, and confirm we don't try to do a restart when we're already shutting down | ||
time.Sleep(restartDelay + 2*time.Second) | ||
require.Len(t, remoteRestarter.signalRestart, 0, "restarter should not have tried to signal for restart when interrupted during restart delay") | ||
} | ||
|
||
func TestInterrupt_Multiple(t *testing.T) { | ||
t.Parallel() | ||
|
||
mockKnapsack := typesmocks.NewKnapsack(t) | ||
mockKnapsack.On("Slogger").Return(multislogger.NewNopLogger()) | ||
|
||
remoteRestarter := New(mockKnapsack) | ||
|
||
// Let the remote restarter run for a bit | ||
go remoteRestarter.Execute() | ||
time.Sleep(3 * time.Second) | ||
remoteRestarter.Interrupt(errors.New("test error")) | ||
|
||
// Confirm we can call Interrupt multiple times without blocking | ||
interruptComplete := make(chan struct{}) | ||
expectedInterrupts := 3 | ||
for i := 0; i < expectedInterrupts; i += 1 { | ||
go func() { | ||
remoteRestarter.Interrupt(nil) | ||
interruptComplete <- struct{}{} | ||
}() | ||
} | ||
|
||
receivedInterrupts := 0 | ||
for { | ||
if receivedInterrupts >= expectedInterrupts { | ||
break | ||
} | ||
|
||
select { | ||
case <-interruptComplete: | ||
receivedInterrupts += 1 | ||
continue | ||
case <-time.After(5 * time.Second): | ||
t.Errorf("could not call interrupt multiple times and return within 5 seconds -- received %d interrupts before timeout", receivedInterrupts) | ||
t.FailNow() | ||
} | ||
} | ||
|
||
require.Equal(t, expectedInterrupts, receivedInterrupts) | ||
} |