diff --git a/nomad/job_endpoint.go b/nomad/job_endpoint.go index 849e6e42698..9900bd7cde3 100644 --- a/nomad/job_endpoint.go +++ b/nomad/job_endpoint.go @@ -2,6 +2,7 @@ package nomad import ( "fmt" + "strings" "time" "github.com/armon/go-metrics" @@ -67,6 +68,42 @@ func (j *Job) Register(args *structs.JobRegisterRequest, reply *structs.JobRegis } } + // Ensure that the job has permissions for the requested Vault tokens + desiredPolicies := structs.VaultPoliciesSet(args.Job.VaultPolicies()) + if len(desiredPolicies) != 0 { + vconf := j.srv.config.VaultConfig + if !vconf.Enabled { + return fmt.Errorf("Vault not enabled and Vault policies requested") + } + + // Have to check if the user has permissions + if !vconf.AllowUnauthenticated { + if args.Job.VaultToken == "" { + return fmt.Errorf("Vault policies requested but missing Vault Token") + } + + vault := j.srv.vault + s, err := vault.LookupToken(args.Job.VaultToken) + if err != nil { + return err + } + + allowedPolicies, err := PoliciesFrom(s) + if err != nil { + return err + } + + subset, offending := structs.SliceStringIsSubset(allowedPolicies, desiredPolicies) + if !subset { + return fmt.Errorf("Passed Vault Token doesn't allow access to the following policies: %s", + strings.Join(offending, ", ")) + } + } + } + + // Clear the Vault token + args.Job.VaultToken = "" + // Commit this update via Raft _, index, err := j.srv.raftApply(structs.JobRegisterRequestType, args) if err != nil { diff --git a/nomad/job_endpoint_test.go b/nomad/job_endpoint_test.go index d05785438ab..470c02b939a 100644 --- a/nomad/job_endpoint_test.go +++ b/nomad/job_endpoint_test.go @@ -1,6 +1,7 @@ package nomad import ( + "fmt" "reflect" "strings" "testing" @@ -360,6 +361,189 @@ func TestJobEndpoint_Register_EnforceIndex(t *testing.T) { } } +func TestJobEndpoint_Register_Vault_Disabled(t *testing.T) { + s1 := testServer(t, func(c *Config) { + c.NumSchedulers = 0 // Prevent automatic dequeue + c.VaultConfig.Enabled = false + }) + defer s1.Shutdown() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Create the register request with a job asking for a vault policy + job := mock.Job() + job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{Policies: []string{"foo"}} + req := &structs.JobRegisterRequest{ + Job: job, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + + // Fetch the response + var resp structs.JobRegisterResponse + err := msgpackrpc.CallWithCodec(codec, "Job.Register", req, &resp) + if err == nil || !strings.Contains(err.Error(), "Vault not enabled") { + t.Fatalf("expected Vault not enabled error: %v", err) + } +} + +func TestJobEndpoint_Register_Vault_AllowUnauthenticated(t *testing.T) { + s1 := testServer(t, func(c *Config) { + c.NumSchedulers = 0 // Prevent automatic dequeue + }) + defer s1.Shutdown() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Enable vault and allow authenticated + s1.config.VaultConfig.Enabled = true + s1.config.VaultConfig.AllowUnauthenticated = true + + // Replace the Vault Client on the server + s1.vault = &TestVaultClient{} + + // Create the register request with a job asking for a vault policy + job := mock.Job() + job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{Policies: []string{"foo"}} + req := &structs.JobRegisterRequest{ + Job: job, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + + // Fetch the response + var resp structs.JobRegisterResponse + err := msgpackrpc.CallWithCodec(codec, "Job.Register", req, &resp) + if err != nil { + t.Fatalf("bad: %v", err) + } + + // Check for the job in the FSM + state := s1.fsm.State() + out, err := state.JobByID(job.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if out == nil { + t.Fatalf("expected job") + } + if out.CreateIndex != resp.JobModifyIndex { + t.Fatalf("index mis-match") + } +} + +func TestJobEndpoint_Register_Vault_NoToken(t *testing.T) { + s1 := testServer(t, func(c *Config) { + c.NumSchedulers = 0 // Prevent automatic dequeue + }) + defer s1.Shutdown() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Enable vault + s1.config.VaultConfig.Enabled = true + s1.config.VaultConfig.AllowUnauthenticated = false + + // Replace the Vault Client on the server + s1.vault = &TestVaultClient{} + + // Create the register request with a job asking for a vault policy but + // don't send a Vault token + job := mock.Job() + job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{Policies: []string{"foo"}} + req := &structs.JobRegisterRequest{ + Job: job, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + + // Fetch the response + var resp structs.JobRegisterResponse + err := msgpackrpc.CallWithCodec(codec, "Job.Register", req, &resp) + if err == nil || !strings.Contains(err.Error(), "missing Vault Token") { + t.Fatalf("expected Vault not enabled error: %v", err) + } +} + +func TestJobEndpoint_Register_Vault_Policies(t *testing.T) { + s1 := testServer(t, func(c *Config) { + c.NumSchedulers = 0 // Prevent automatic dequeue + }) + defer s1.Shutdown() + codec := rpcClient(t, s1) + testutil.WaitForLeader(t, s1.RPC) + + // Enable vault + s1.config.VaultConfig.Enabled = true + s1.config.VaultConfig.AllowUnauthenticated = false + + // Replace the Vault Client on the server + tvc := &TestVaultClient{} + s1.vault = tvc + + // Add three tokens: one that allows the requesting policy, one that does + // not and one that returns an error + policy := "foo" + + badToken := structs.GenerateUUID() + badPolicies := []string{"a", "b", "c"} + tvc.SetLookupTokenAllowedPolicies(badToken, badPolicies) + + goodToken := structs.GenerateUUID() + goodPolicies := []string{"foo", "bar", "baz"} + tvc.SetLookupTokenAllowedPolicies(goodToken, goodPolicies) + + errToken := structs.GenerateUUID() + expectedErr := fmt.Errorf("return errors from vault") + tvc.SetLookupTokenError(errToken, expectedErr) + + // Create the register request with a job asking for a vault policy but + // send the bad Vault token + job := mock.Job() + job.VaultToken = badToken + job.TaskGroups[0].Tasks[0].Vault = &structs.Vault{Policies: []string{policy}} + req := &structs.JobRegisterRequest{ + Job: job, + WriteRequest: structs.WriteRequest{Region: "global"}, + } + + // Fetch the response + var resp structs.JobRegisterResponse + err := msgpackrpc.CallWithCodec(codec, "Job.Register", req, &resp) + if err == nil || !strings.Contains(err.Error(), + "doesn't allow access to the following policies: "+policy) { + t.Fatalf("expected permission denied error: %v", err) + } + + // Use the err token + job.VaultToken = errToken + err = msgpackrpc.CallWithCodec(codec, "Job.Register", req, &resp) + if err == nil || !strings.Contains(err.Error(), expectedErr.Error()) { + t.Fatalf("expected permission denied error: %v", err) + } + + // Use the good token + job.VaultToken = goodToken + + // Fetch the response + if err := msgpackrpc.CallWithCodec(codec, "Job.Register", req, &resp); err != nil { + t.Fatalf("bad: %v", err) + } + + // Check for the job in the FSM + state := s1.fsm.State() + out, err := state.JobByID(job.ID) + if err != nil { + t.Fatalf("err: %v", err) + } + if out == nil { + t.Fatalf("expected job") + } + if out.CreateIndex != resp.JobModifyIndex { + t.Fatalf("index mis-match") + } + if out.VaultToken != "" { + t.Fatalf("vault token not cleared") + } +} + func TestJobEndpoint_Evaluate(t *testing.T) { s1 := testServer(t, func(c *Config) { c.NumSchedulers = 0 // Prevent automatic dequeue diff --git a/nomad/structs/funcs.go b/nomad/structs/funcs.go index 836e0ce3d48..68f0af18c6e 100644 --- a/nomad/structs/funcs.go +++ b/nomad/structs/funcs.go @@ -229,3 +229,44 @@ func CopySliceConstraints(s []*Constraint) []*Constraint { } return c } + +// SliceStringIsSubset returns whether the smaller set of strings is a subset of +// the larger. If the smaller slice is not a subset, the offending elements are +// returned. +func SliceStringIsSubset(larger, smaller []string) (bool, []string) { + largerSet := make(map[string]struct{}, len(larger)) + for _, l := range larger { + largerSet[l] = struct{}{} + } + + subset := true + var offending []string + for _, s := range smaller { + if _, ok := largerSet[s]; !ok { + subset = false + offending = append(offending, s) + } + } + + return subset, offending +} + +// VaultPoliciesSet takes the structure returned by VaultPolicies and returns +// the set of required policies +func VaultPoliciesSet(policies map[string]map[string][]string) []string { + set := make(map[string]struct{}) + + for _, tgp := range policies { + for _, tp := range tgp { + for _, p := range tp { + set[p] = struct{}{} + } + } + } + + flattened := make([]string, 0, len(set)) + for p := range set { + flattened = append(flattened, p) + } + return flattened +} diff --git a/nomad/structs/funcs_test.go b/nomad/structs/funcs_test.go index cc0d574d68d..be34a830822 100644 --- a/nomad/structs/funcs_test.go +++ b/nomad/structs/funcs_test.go @@ -235,3 +235,18 @@ func TestGenerateUUID(t *testing.T) { } } } + +func TestSliceStringIsSubset(t *testing.T) { + l := []string{"a", "b", "c"} + s := []string{"d"} + + sub, offending := SliceStringIsSubset(l, l[:1]) + if !sub || len(offending) != 0 { + t.Fatalf("bad %v %v", sub, offending) + } + + sub, offending = SliceStringIsSubset(l, s) + if sub || len(offending) == 0 || offending[0] != "d" { + t.Fatalf("bad %v %v", sub, offending) + } +} diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 6d1d1c52558..02d0fc080b8 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -1222,6 +1222,26 @@ func (j *Job) IsPeriodic() bool { return j.Periodic != nil } +// VaultPolicies returns the set of Vault policies per task group, per task +func (j *Job) VaultPolicies() map[string]map[string][]string { + policies := make(map[string]map[string][]string, len(j.TaskGroups)) + + for _, tg := range j.TaskGroups { + tgPolicies := make(map[string][]string, len(tg.Tasks)) + policies[tg.Name] = tgPolicies + + for _, task := range tg.Tasks { + if task.Vault == nil { + continue + } + + tgPolicies[task.Name] = task.Vault.Policies + } + } + + return policies +} + // JobListStub is used to return a subset of job information // for the job list type JobListStub struct { diff --git a/nomad/structs/structs_test.go b/nomad/structs/structs_test.go index 619702514c3..a1b8aa8f7b9 100644 --- a/nomad/structs/structs_test.go +++ b/nomad/structs/structs_test.go @@ -221,6 +221,86 @@ func TestJob_SystemJob_Validate(t *testing.T) { } } +func TestJob_VaultPolicies(t *testing.T) { + j0 := &Job{} + e0 := make(map[string]map[string][]string, 0) + + j1 := &Job{ + TaskGroups: []*TaskGroup{ + &TaskGroup{ + Name: "foo", + Tasks: []*Task{ + &Task{ + Name: "t1", + }, + &Task{ + Name: "t2", + Vault: &Vault{ + Policies: []string{ + "p1", + "p2", + }, + }, + }, + }, + }, + &TaskGroup{ + Name: "bar", + Tasks: []*Task{ + &Task{ + Name: "t3", + Vault: &Vault{ + Policies: []string{ + "p3", + "p4", + }, + }, + }, + &Task{ + Name: "t4", + Vault: &Vault{ + Policies: []string{ + "p5", + }, + }, + }, + }, + }, + }, + } + + e1 := map[string]map[string][]string{ + "foo": map[string][]string{ + "t2": []string{"p1", "p2"}, + }, + "bar": map[string][]string{ + "t3": []string{"p3", "p4"}, + "t4": []string{"p5"}, + }, + } + + cases := []struct { + Job *Job + Expected map[string]map[string][]string + }{ + { + Job: j0, + Expected: e0, + }, + { + Job: j1, + Expected: e1, + }, + } + + for i, c := range cases { + got := c.Job.VaultPolicies() + if !reflect.DeepEqual(got, c.Expected) { + t.Fatalf("case %d: got %#v; want %#v", i+1, got, c.Expected) + } + } +} + func TestTaskGroup_Validate(t *testing.T) { tg := &TaskGroup{ Count: -1, diff --git a/nomad/vault.go b/nomad/vault.go index 04d0cd939a7..a7c734562bc 100644 --- a/nomad/vault.go +++ b/nomad/vault.go @@ -418,6 +418,31 @@ func (v *vaultClient) CreateToken(a *structs.Allocation, task string) (*vapi.Sec return nil, nil } +// LookupToken takes a Vault token and does a lookup against Vault func (v *vaultClient) LookupToken(token string) (*vapi.Secret, error) { - return nil, nil + // Nothing to do + if !v.enabled { + return nil, fmt.Errorf("Vault integration disabled") + } + + // Check if we have established a connection with Vault + if !v.ConnectionEstablished() { + return nil, fmt.Errorf("Connection to Vault has not been established. Retry") + } + + // Lookup the token + return v.auth.Lookup(token) +} + +// PoliciesFrom parses the set of policies returned by a token lookup. +func PoliciesFrom(s *vapi.Secret) ([]string, error) { + if s == nil { + return nil, fmt.Errorf("cannot parse nil Vault secret") + } + var data tokenData + if err := mapstructure.WeakDecode(s.Data, &data); err != nil { + return nil, fmt.Errorf("failed to parse Vault token's data block: %v", err) + } + + return data.Policies, nil } diff --git a/nomad/vault_test.go b/nomad/vault_test.go index cc50b02a158..b007ca22e1c 100644 --- a/nomad/vault_test.go +++ b/nomad/vault_test.go @@ -4,10 +4,12 @@ import ( "encoding/json" "log" "os" + "reflect" "strings" "testing" "time" + "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" "github.com/hashicorp/nomad/testutil" vapi "github.com/hashicorp/vault/api" @@ -64,11 +66,7 @@ func TestVaultClient_EstablishConnection(t *testing.T) { // Start Vault v.Start() - testutil.WaitForResult(func() (bool, error) { - return client.ConnectionEstablished(), nil - }, func(err error) { - t.Fatalf("Connection not established") - }) + waitForConnection(client, t) // Ensure that since we are using a root token that we haven started the // renewal loop. @@ -151,3 +149,101 @@ func parseTTLFromLookup(s *vapi.Secret, t *testing.T) int64 { return ttl } + +func TestVaultClient_LookupToken_Invalid(t *testing.T) { + conf := &config.VaultConfig{ + Enabled: false, + } + + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(conf, logger) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + _, err = client.LookupToken("foo") + if err == nil || !strings.Contains(err.Error(), "disabled") { + t.Fatalf("Expected error because Vault is disabled: %v", err) + } + + // Enable vault but use a bad address so it never establishes a conn + conf.Enabled = true + conf.Addr = "http://foobar:12345" + conf.Token = structs.GenerateUUID() + client, err = NewVaultClient(conf, logger) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + _, err = client.LookupToken("foo") + if err == nil || !strings.Contains(err.Error(), "established") { + t.Fatalf("Expected error because connection to Vault hasn't been made: %v", err) + } +} + +func waitForConnection(v *vaultClient, t *testing.T) { + testutil.WaitForResult(func() (bool, error) { + return v.ConnectionEstablished(), nil + }, func(err error) { + t.Fatalf("Connection not established") + }) +} + +func TestVaultClient_LookupToken(t *testing.T) { + v := testutil.NewTestVault(t).Start() + defer v.Stop() + + logger := log.New(os.Stderr, "", log.LstdFlags) + client, err := NewVaultClient(v.Config, logger) + if err != nil { + t.Fatalf("failed to build vault client: %v", err) + } + + waitForConnection(client, t) + + // Lookup ourselves + s, err := client.LookupToken(v.Config.Token) + if err != nil { + t.Fatalf("self lookup failed: %v", err) + } + + policies, err := PoliciesFrom(s) + if err != nil { + t.Fatalf("failed to parse policies: %v", err) + } + + expected := []string{"root"} + if !reflect.DeepEqual(policies, expected) { + t.Fatalf("Unexpected policies; got %v; want %v", policies, expected) + } + + // Create a token with a different set of policies + expected = []string{"default"} + req := vapi.TokenCreateRequest{ + Policies: expected, + } + s, err = v.Client.Auth().Token().Create(&req) + if err != nil { + t.Fatalf("failed to create child token: %v", err) + } + + // Get the client token + if s == nil || s.Auth == nil { + t.Fatalf("bad secret response: %+v", s) + } + + // Lookup new child + s, err = client.LookupToken(s.Auth.ClientToken) + if err != nil { + t.Fatalf("self lookup failed: %v", err) + } + + policies, err = PoliciesFrom(s) + if err != nil { + t.Fatalf("failed to parse policies: %v", err) + } + + if !reflect.DeepEqual(policies, expected) { + t.Fatalf("Unexpected policies; got %v; want %v", policies, expected) + } +} diff --git a/nomad/vault_testing.go b/nomad/vault_testing.go new file mode 100644 index 00000000000..f8558ea0c0d --- /dev/null +++ b/nomad/vault_testing.go @@ -0,0 +1,71 @@ +package nomad + +import ( + "github.com/hashicorp/nomad/nomad/structs" + vapi "github.com/hashicorp/vault/api" +) + +// TestVaultClient is a Vault client appropriate for use during testing. Its +// behavior is programmable such that endpoints can be tested under various +// circumstances. +type TestVaultClient struct { + // LookupTokenErrors maps a token to an error that will be returned by the + // LookupToken call + LookupTokenErrors map[string]error + + // LookupTokenSecret maps a token to the Vault secret that will be returned + // by the LookupToken call + LookupTokenSecret map[string]*vapi.Secret +} + +func (v *TestVaultClient) LookupToken(token string) (*vapi.Secret, error) { + var secret *vapi.Secret + var err error + + if v.LookupTokenSecret != nil { + secret = v.LookupTokenSecret[token] + } + if v.LookupTokenErrors != nil { + err = v.LookupTokenErrors[token] + } + + return secret, err +} + +// SetLookupTokenSecret sets the error that will be returned by the token +// lookup +func (v *TestVaultClient) SetLookupTokenError(token string, err error) { + if v.LookupTokenErrors == nil { + v.LookupTokenErrors = make(map[string]error) + } + + v.LookupTokenErrors[token] = err +} + +// SetLookupTokenSecret sets the secret that will be returned by the token +// lookup +func (v *TestVaultClient) SetLookupTokenSecret(token string, secret *vapi.Secret) { + if v.LookupTokenSecret == nil { + v.LookupTokenSecret = make(map[string]*vapi.Secret) + } + + v.LookupTokenSecret[token] = secret +} + +// SetLookupTokenAllowedPolicies is a helper that adds a secret that allows the +// given policies +func (v *TestVaultClient) SetLookupTokenAllowedPolicies(token string, policies []string) { + s := &vapi.Secret{ + Data: map[string]interface{}{ + "policies": policies, + }, + } + + v.SetLookupTokenSecret(token, s) +} + +func (v *TestVaultClient) CreateToken(a *structs.Allocation, task string) (*vapi.Secret, error) { + return nil, nil +} + +func (v *TestVaultClient) Stop() {}