Skip to content

Commit

Permalink
identity: Implement change_mode (#18943)
Browse files Browse the repository at this point in the history
* identity: support change_mode and change_signal

wip - just jobspec portion

* test struct

* clean up some insignificant bugs

* actually implement change mode

* docs tweaks

* add changelog

* test identity.change_mode operations

* use more words in changelog

* job endpoint tests

* address comments from code review

---------

Co-authored-by: Tim Gross <[email protected]>
  • Loading branch information
schmichael and tgross authored Nov 1, 2023
1 parent d62213a commit e49ca3c
Show file tree
Hide file tree
Showing 10 changed files with 342 additions and 60 deletions.
3 changes: 3 additions & 0 deletions .changelog/18943.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:improvement
identity: Implement `change_mode` and `change_signal` for workload identities
```
14 changes: 8 additions & 6 deletions api/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -1162,12 +1162,14 @@ func (t *TaskCSIPluginConfig) Canonicalize() {
// WorkloadIdentity is the jobspec block which determines if and how a workload
// identity is exposed to tasks.
//
// ChangeMode and ChangeSignal describe what should happen to the task when a
// new identity token is obtained; the other fields describe the identity
// itself and how it is exposed.
type WorkloadIdentity struct {
Name string `hcl:"name,optional"`
Audience []string `mapstructure:"aud" hcl:"aud,optional"`
Env bool `hcl:"env,optional"`
File bool `hcl:"file,optional"`
ServiceName string `hcl:"service_name,optional"`
TTL time.Duration `mapstructure:"ttl" hcl:"ttl,optional"`
Name string `hcl:"name,optional"`
Audience []string `mapstructure:"aud" hcl:"aud,optional"`
// ChangeMode is the jobspec "change_mode" value, e.g. "restart" or
// "signal".
ChangeMode string `mapstructure:"change_mode" hcl:"change_mode,optional"`
// ChangeSignal is the jobspec "change_signal" value: the name of the signal
// (e.g. "SIGHUP") to send when ChangeMode is "signal".
ChangeSignal string `mapstructure:"change_signal" hcl:"change_signal,optional"`
Env bool `hcl:"env,optional"`
File bool `hcl:"file,optional"`
ServiceName string `hcl:"service_name,optional"`
TTL time.Duration `mapstructure:"ttl" hcl:"ttl,optional"`
}

type Action struct {
Expand Down
101 changes: 96 additions & 5 deletions client/allocrunner/taskrunner/identity_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ import (
"context"
"fmt"
"path/filepath"
"time"

"github.com/hashicorp/consul-template/signals"
log "github.com/hashicorp/go-hclog"

"github.com/hashicorp/nomad/client/allocrunner/interfaces"
ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
"github.com/hashicorp/nomad/client/taskenv"
"github.com/hashicorp/nomad/client/widmgr"
"github.com/hashicorp/nomad/helper/users"
Expand All @@ -37,6 +40,7 @@ type identityHook struct {
task *structs.Task
tokenDir string
envBuilder *taskenv.Builder
lifecycle ti.TaskLifecycle
ts tokenSetter
widmgr widmgr.IdentityManager
logger log.Logger
Expand All @@ -52,6 +56,7 @@ func newIdentityHook(tr *TaskRunner, logger log.Logger) *identityHook {
task: tr.Task(),
tokenDir: tr.taskDir.SecretsDir,
envBuilder: tr.envBuilder,
lifecycle: tr,
ts: tr,
widmgr: tr.widmgr,
stopCtx: stopCtx,
Expand All @@ -65,52 +70,138 @@ func (*identityHook) Name() string {
return "identity"
}

func (h *identityHook) Prestart(context.Context, *interfaces.TaskPrestartRequest, *interfaces.TaskPrestartResponse) error {
// Prestart sets the default workload identity token, starts one watcher
// goroutine per entry in h.task.Identities, and then waits for each watcher to
// report that its first token has been handled.
//
// The wait is bounded: it gives up after one minute, when ctx is cancelled, or
// when the hook's stop context is done. In all three cases the condition is
// logged and nil is returned, so only a setDefaultToken failure produces an
// error from this hook.
func (h *identityHook) Prestart(ctx context.Context, _ *interfaces.TaskPrestartRequest, _ *interfaces.TaskPrestartResponse) error {

// Handle default workload identity
if err := h.setDefaultToken(); err != nil {
return err
}

// Track first run signals from watchers. The channel is buffered with one
// slot per identity so each watcher can report its first run without
// blocking on this function.
firstRunCh := make(chan struct{}, len(h.task.Identities))

// Start token watcher loops
for _, widspec := range h.task.Identities {
// Per-iteration copy that is handed to the goroutine.
w := widspec
go h.watchIdentity(w)
go h.watchIdentity(w, firstRunCh)
}

// Don't block indefinitely for identities
deadlineTimer := time.NewTimer(time.Minute)
defer deadlineTimer.Stop()

// Wait until every watcher ticks the first run chan. One receive is expected
// per identity; i is the number of identities already reported when a wait
// is abandoned.
for i := range h.task.Identities {
select {
case <-firstRunCh:
// Identity fetched, loop
case <-deadlineTimer.C:
// A single timer covers the whole loop, so the one minute deadline applies
// to all identities together rather than to each one.
h.logger.Warn("timed out waiting for initial identity tokens to be fetched",
"num_fetched", i, "num_total", len(h.task.Identities))
return nil
case <-ctx.Done():
h.logger.Debug("task prestart cancelled before initial identity tokens were fetched",
"num_fetched", i, "num_total", len(h.task.Identities))
return nil
case <-h.stopCtx.Done():
h.logger.Debug("task stopped before initial identity tokens were fetched",
"num_fetched", i, "num_total", len(h.task.Identities))
return nil
}
}

return nil
}

func (h *identityHook) watchIdentity(wid *structs.WorkloadIdentity) {
// watchIdentity consumes signed identities for wid from the identity manager
// until the watch channel is closed, a nil identity is received, a change_mode
// action fails, or the hook's stop context is done.
//
// For every identity received the token is handed to setAltToken. The first
// identity additionally triggers a non-blocking send on runCh so the caller
// knows the initial token has been handled; change_mode is skipped for it.
// Every later identity applies wid.ChangeMode:
//
//   - restart: the task is restarted; if the restart fails the task is killed
//     and marked as failed.
//   - signal: wid.ChangeSignal is sent to the task; if that fails the task is
//     killed and marked as failed.
//   - anything else: no action is taken.
//
// Note that if the watcher exits before the first identity is processed (the
// channel is closed or a nil identity arrives), nothing is sent on runCh.
func (h *identityHook) watchIdentity(wid *structs.WorkloadIdentity, runCh chan struct{}) {
	id := structs.WIHandle{WorkloadIdentifier: h.task.Name, IdentityName: wid.Name}
	signedIdentitiesChan, stopWatching := h.widmgr.Watch(id)
	defer stopWatching()

	firstRun := true

	for {
		select {
		case signedWID, ok := <-signedIdentitiesChan:
			if !ok {
				// Chan was closed, stop watching
				h.logger.Trace("identity watch closed", "identity", wid.Name)
				return
			}

			h.logger.Trace("receiving renewed identity", "identity", wid.Name)

			if signedWID == nil {
				// The only way to hit this should be a bug as it indicates the server
				// did not sign an identity for a task on this alloc.
				h.logger.Error("missing workload identity", "identity", wid.Name)
				return
			}

			// A failure to publish the token is logged but does not stop the
			// watcher; the next renewal gets another chance.
			if err := h.setAltToken(wid, signedWID.JWT); err != nil {
				h.logger.Error("failed to set identity token", "identity", wid.Name, "error", err)
			}

			// Skip ChangeMode on firstRun and notify caller it can proceed
			if firstRun {
				select {
				case runCh <- struct{}{}:
				default:
					// Not great but not necessarily fatal
					h.logger.Warn("task started before identity was fetched", "identity", wid.Name)
				}

				firstRun = false
				continue
			}

			switch wid.ChangeMode {
			case structs.WIChangeModeRestart:
				const noFailure = false
				err := h.lifecycle.Restart(h.stopCtx, structs.NewTaskEvent(structs.TaskRestartSignal).
					SetDisplayMessage(fmt.Sprintf("Identity[%s]: new token acquired", wid.Name)), noFailure)
				if err != nil {
					h.logger.Error("failed to restart task", "identity", wid.Name, "error", err)
					// Ignore error from kill because if that fails there's really
					// nothing to be done.
					_ = h.lifecycle.Kill(h.stopCtx, structs.NewTaskEvent(structs.TaskKilling).
						SetFailsTask().
						SetDisplayMessage(fmt.Sprintf("Identity[%s]: failed to restart: %v", wid.Name, err)))
					return
				}

			case structs.WIChangeModeSignal:
				if err := h.signalTask(wid); err != nil {
					h.logger.Error("failed to send signal",
						"identity", wid.Name, "signal", wid.ChangeSignal, "error", err)
					// Ignore error from kill because if that fails there's really
					// nothing to be done.
					_ = h.lifecycle.Kill(h.stopCtx, structs.NewTaskEvent(structs.TaskKilling).
						SetFailsTask().
						SetDisplayMessage(fmt.Sprintf("Identity[%s]: failed to send signal: %v", wid.Name, err)))
					return
				}
			}

			// Note: any code added here will not run on first run

		case <-h.stopCtx.Done():
			return
		}
	}
}

// signalTask delivers the identity's configured change signal to the task.
//
// The signal name in wid.ChangeSignal is parsed first; a name that cannot be
// parsed yields a wrapped error and nothing is sent. Otherwise a TaskSignaling
// event is built and the result of the lifecycle's Signal call is returned.
func (h *identityHook) signalTask(wid *structs.WorkloadIdentity) error {
	sig, parseErr := signals.Parse(wid.ChangeSignal)
	if parseErr != nil {
		return fmt.Errorf("failed to parse signal: %w", parseErr)
	}

	displayMsg := fmt.Sprintf("Identity[%s]: new Identity token acquired", wid.Name)
	signalEvent := structs.NewTaskEvent(structs.TaskSignaling).
		SetTaskSignal(sig).
		SetDisplayMessage(displayMsg)

	return h.lifecycle.Signal(signalEvent, wid.ChangeSignal)
}

// setDefaultToken adds the Nomad token to the task's environment and writes it to a
// file if requested by the jobspec.
func (h *identityHook) setDefaultToken() error {
Expand Down
39 changes: 29 additions & 10 deletions client/allocrunner/taskrunner/identity_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
trtesting "github.com/hashicorp/nomad/client/allocrunner/taskrunner/testing"
cstate "github.com/hashicorp/nomad/client/state"
"github.com/hashicorp/nomad/client/taskenv"
"github.com/hashicorp/nomad/client/widmgr"
Expand Down Expand Up @@ -53,16 +54,19 @@ func TestIdentityHook_RenewAll(t *testing.T) {
task := alloc.LookupTask("web")
task.Identities = []*structs.WorkloadIdentity{
{
Name: "consul",
Audience: []string{"consul"},
Env: true,
TTL: ttl,
Name: "consul",
Audience: []string{"consul"},
Env: true,
TTL: ttl,
ChangeMode: "restart",
},
{
Name: "vault",
Audience: []string{"vault"},
File: true,
TTL: ttl,
Name: "vault",
Audience: []string{"vault"},
File: true,
TTL: ttl,
ChangeMode: "signal",
ChangeSignal: "SIGHUP",
},
}

Expand All @@ -79,13 +83,15 @@ func TestIdentityHook_RenewAll(t *testing.T) {
mockSigner := widmgr.NewMockWIDSigner(task.Identities)
mockWIDMgr := widmgr.NewWIDMgr(mockSigner, alloc, db, logger)
mockWIDMgr.SetMinWait(time.Second) // fast renewals, because the default is 10s
mockLifecycle := trtesting.NewMockTaskHooks()

h := &identityHook{
alloc: alloc,
task: task,
tokenDir: secretsDir,
envBuilder: taskenv.NewBuilder(node, alloc, task, alloc.Job.Region),
ts: mockTR,
lifecycle: mockLifecycle,
widmgr: mockWIDMgr,
logger: testlog.HCLogger(t),
stopCtx: stopCtx,
Expand All @@ -97,7 +103,6 @@ func TestIdentityHook_RenewAll(t *testing.T) {

start := time.Now()
must.NoError(t, h.Prestart(context.Background(), nil, nil))
time.Sleep(time.Second) // goroutines in the Prestart hook must run first before we Build the EnvMap
env := h.envBuilder.Build().EnvMap

// Assert initial tokens were set in Prestart
Expand All @@ -117,7 +122,21 @@ func TestIdentityHook_RenewAll(t *testing.T) {

// Stop renewal before checking to ensure stopping works
must.NoError(t, h.Stop(context.Background(), nil, nil))
time.Sleep(time.Second) // Stop is async so give renewal time to exit

// Ensure change_mode operations occurred
select {
case <-mockLifecycle.RestartCh:
h.logger.Trace("restart happened")
case <-time.After(10 * time.Second):
t.Fatalf("timed out waiting for restart")
}

select {
case <-mockLifecycle.SignalCh:
h.logger.Trace("signal happened")
case <-time.After(10 * time.Second):
t.Fatalf("timed out waiting for restart")
}

newConsul := h.envBuilder.Build().EnvMap["NOMAD_TOKEN_consul"]
must.StrContains(t, newConsul, ".") // ensure new token is JWTish
Expand Down
4 changes: 2 additions & 2 deletions client/allocrunner/taskrunner/vault_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,9 @@ OUTER:
const noFailure = false
h.lifecycle.Restart(h.ctx,
structs.NewTaskEvent(structs.TaskRestartSignal).
SetDisplayMessage("Vault: new Vault token acquired"), false)
SetDisplayMessage("Vault: new Vault token acquired"), noFailure)
case structs.VaultChangeModeNoop:
fallthrough
// True to its name, this is a noop!
default:
h.logger.Error("invalid Vault change mode", "mode", h.vaultBlock.ChangeMode)
}
Expand Down
33 changes: 11 additions & 22 deletions command/agent/job_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -1274,30 +1274,17 @@ func ApiTaskToStructsTask(job *structs.Job, group *structs.TaskGroup,
// Nomad 1.5 CLIs and JSON jobs may set the default identity parameters in
// the Task.Identity field, so if it is non-nil use it.
if id := apiTask.Identity; id != nil {
structsTask.Identity = &structs.WorkloadIdentity{
Name: id.Name,
Audience: slices.Clone(id.Audience),
Env: id.Env,
File: id.File,
TTL: id.TTL,
}
structsTask.Identity = apiWorkloadIdentityToStructs(id)
}

if ids := apiTask.Identities; len(ids) > 0 {
structsTask.Identities = make([]*structs.WorkloadIdentity, len(ids))
for i, id := range ids {
structsTask.Identities = make([]*structs.WorkloadIdentity, 0, len(ids))
for _, id := range ids {
if id == nil {
continue
}

structsTask.Identities[i] = &structs.WorkloadIdentity{
Name: id.Name,
Audience: slices.Clone(id.Audience),
Env: id.Env,
File: id.File,
TTL: id.TTL,
}

structsTask.Identities = append(structsTask.Identities, apiWorkloadIdentityToStructs(id))
}
}

Expand Down Expand Up @@ -1651,11 +1638,13 @@ func apiWorkloadIdentityToStructs(in *api.WorkloadIdentity) *structs.WorkloadIde
return nil
}
return &structs.WorkloadIdentity{
Name: in.Name,
Audience: in.Audience,
Env: in.Env,
File: in.File,
ServiceName: in.ServiceName,
Name: in.Name,
Audience: slices.Clone(in.Audience),
ChangeMode: in.ChangeMode,
ChangeSignal: in.ChangeSignal,
Env: in.Env,
File: in.File,
ServiceName: in.ServiceName,
}
}

Expand Down
Loading

0 comments on commit e49ca3c

Please sign in to comment.