From 0f4ddfe12c8cebc481ad9a7229f539bdce1602f6 Mon Sep 17 00:00:00 2001 From: Chris Hoffman Date: Tue, 4 Sep 2018 14:18:59 -0400 Subject: [PATCH] Fixing capabilities check for templated policies (#5250) * fixing capabilities check for templated policies * remove unnecessary change * formatting --- vault/capabilities.go | 29 +++----------- vault/capabilities_test.go | 77 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 23 deletions(-) diff --git a/vault/capabilities.go b/vault/capabilities.go index c3ffa9ee9af5..84212397956e 100644 --- a/vault/capabilities.go +++ b/vault/capabilities.go @@ -25,46 +25,29 @@ func (c *Core) Capabilities(ctx context.Context, token, path string) ([]string, return nil, &logical.StatusBadRequest{Err: "invalid token"} } - if te.Policies == nil { - return []string{DenyCapability}, nil - } - - var policies []*Policy - for _, tePolicy := range te.Policies { - policy, err := c.policyStore.GetPolicy(ctx, tePolicy, PolicyTypeToken) - if err != nil { - return nil, err - } - policies = append(policies, policy) - } + // Start with token entry policies + policies := te.Policies + // Fetch entity and entity group policies entity, derivedPolicies, err := c.fetchEntityAndDerivedPolicies(te.EntityID) if err != nil { return nil, err } - if entity != nil && entity.Disabled { c.logger.Warn("permission denied as the entity on the token is disabled") return nil, logical.ErrPermissionDenied } - if te != nil && te.EntityID != "" && entity == nil { + if te.EntityID != "" && entity == nil { c.logger.Warn("permission denied as the entity on the token is invalid") return nil, logical.ErrPermissionDenied } - - for _, item := range derivedPolicies { - policy, err := c.policyStore.GetPolicy(ctx, item, PolicyTypeToken) - if err != nil { - return nil, err - } - policies = append(policies, policy) - } + policies = append(policies, derivedPolicies...) if len(policies) == 0 { return []string{DenyCapability}, nil } - acl, err := NewACL(policies) + acl, err := c.policyStore.ACL(ctx, entity, policies...) if err != nil { return nil, err } diff --git a/vault/capabilities_test.go b/vault/capabilities_test.go index 368478838381..b03c545f170a 100644 --- a/vault/capabilities_test.go +++ b/vault/capabilities_test.go @@ -2,6 +2,7 @@ package vault import ( "context" + "fmt" "reflect" "sort" "testing" @@ -115,6 +116,82 @@ path "secret/sample" { } } +func TestCapabilities_TemplatedPolicies(t *testing.T) { + var resp *logical.Response + var err error + + i, _, c := testIdentityStoreWithGithubAuth(t) + + // Create an entity and assign policy1 to it + entityReq := &logical.Request{ + Path: "entity", + Operation: logical.UpdateOperation, + } + resp, err = i.HandleRequest(context.Background(), entityReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: resp: %#v\nerr: %#v\n", resp, err) + } + entityID := resp.Data["id"].(string) + + // Create a token for the entity and assign policy2 on the token + ent := &logical.TokenEntry{ + ID: "capabilitiestoken", + Path: "auth/token/create", + Policies: []string{"testpolicy"}, + EntityID: entityID, + TTL: time.Hour, + } + testMakeTokenDirectly(t, c.tokenStore, ent) + + tCases := []struct { + policy string + path string + expected []string + }{ + { + `name = "testpolicy" + path "secret/{{identity.entity.id}}/sample" { + capabilities = ["update", "create"] + } + `, + fmt.Sprintf("secret/%s/sample", entityID), + []string{"update", "create"}, + }, + { + `{"name": "testpolicy", "path": {"secret/{{identity.entity.id}}/sample": {"capabilities": ["read", "create"]}}}`, + fmt.Sprintf("secret/%s/sample", entityID), + []string{"read", "create"}, + }, + { + `{"name": "testpolicy", "path": {"secret/sample": {"capabilities": ["read"]}}}`, + "secret/sample", + []string{"read"}, + }, + } + + for _, tCase := range tCases { + // Create the above policies + policy, err := ParseACLPolicy(tCase.policy) + if err != nil { + t.Fatalf("err: %v", err) + } + err = c.policyStore.SetPolicy(context.Background(), policy) + if err != nil { + t.Fatalf("err: %v", err) + } + + actual, err := c.Capabilities(context.Background(), "capabilitiestoken", tCase.path) + if err != nil { + t.Fatalf("err: %v", err) + } + sort.Strings(actual) + sort.Strings(tCase.expected) + if !reflect.DeepEqual(actual, tCase.expected) { + t.Fatalf("bad: got\n%#v\nexpected\n%#v\n", actual, tCase.expected) + } + } +} + func TestCapabilities(t *testing.T) { c, _, token := TestCoreUnsealed(t)