Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AppRole/Identity: Fix for race when creating an entity during login #3932

Merged
merged 3 commits into from
Feb 9, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions command/approle_concurrency_integ_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package command

import (
"sync"
"testing"

"github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
logxi "github.com/mgutz/logxi/v1"
)

func TestAppRole_Integ_ConcurrentLogins(t *testing.T) {
var err error
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: logxi.NullLog,
CredentialBackends: map[string]logical.Factory{
"approle": credAppRole.Factory,
},
}

cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})

cluster.Start()
defer cluster.Cleanup()

cores := cluster.Cores

vault.TestWaitActive(t, cores[0].Core)

client := cores[0].Client

err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{
Type: "approle",
})
if err != nil {
t.Fatal(err)
}

_, err = client.Logical().Write("auth/approle/role/role1", map[string]interface{}{
"bind_secret_id": "true",
"period": "300",
})
if err != nil {
t.Fatal(err)
}

secret, err := client.Logical().Write("auth/approle/role/role1/secret-id", nil)
if err != nil {
t.Fatal(err)
}
secretID := secret.Data["secret_id"].(string)

secret, err = client.Logical().Read("auth/approle/role/role1/role-id")
if err != nil {
t.Fatal(err)
}
roleID := secret.Data["role_id"].(string)

wg := &sync.WaitGroup{}

for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
secret, err = client.Logical().Write("auth/approle/login", map[string]interface{}{
"role_id": roleID,
"secret_id": secretID,
})
if err != nil {
t.Fatal(err)
}
if secret.Auth.ClientToken == "" {
t.Fatalf("expected a successful login")
}
}()

}
wg.Wait()
}
47 changes: 42 additions & 5 deletions vault/identity_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,27 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl
return nil, fmt.Errorf("missing alias name")
}

alias, err := i.MemDBAliasByFactors(mountAccessor, aliasName, false, false)
txn := i.db.Txn(false)

return i.entityByAliasFactorsInTxn(txn, mountAccessor, aliasName, clone)
}

// entityByAlaisFactorsInTxn fetches the entity based on factors of alias, i.e
// mount accessor and the alias name.
func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool) (*identity.Entity, error) {
if txn == nil {
return nil, fmt.Errorf("nil txn")
}

if mountAccessor == "" {
return nil, fmt.Errorf("missing mount accessor")
}

if aliasName == "" {
return nil, fmt.Errorf("missing alias name")
}

alias, err := i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, false, false)
if err != nil {
return nil, err
}
Expand All @@ -258,12 +278,12 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl
return nil, nil
}

return i.MemDBEntityByAliasID(alias.ID, clone)
return i.MemDBEntityByAliasIDInTxn(txn, alias.ID, clone)
}

// CreateEntity creates a new entity. This is used by core to
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget to update this comment :) Small nit, but I prefer CreateOrFetchEntity.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// associate each login attempt by an alias to a unified entity in Vault.
func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, error) {
func (i *IdentityStore) FetchOrCreateEntity(alias *logical.Alias) (*identity.Entity, error) {
var entity *identity.Entity
var err error

Expand All @@ -290,9 +310,24 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er
return nil, err
}
if entity != nil {
return nil, fmt.Errorf("alias already belongs to a different entity")
return entity, nil
}

// Create a MemDB transaction to update both alias and entity
txn := i.db.Txn(true)
defer txn.Abort()

// Check if an entity was created before acquiring the lock
entity, err = i.entityByAliasFactorsInTxn(txn, alias.MountAccessor, alias.Name, false)
if err != nil {
return nil, err
}
if entity != nil {
return entity, nil
}

i.logger.Debug("identity: creating a new entity", "alias", alias)

entity = &identity.Entity{}

err = i.sanitizeEntity(entity)
Expand Down Expand Up @@ -320,10 +355,12 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er
}

// Update MemDB and persist entity object
err = i.upsertEntity(entity, nil, true)
err = i.upsertEntityInTxn(txn, entity, nil, true, false)
if err != nil {
return nil, err
}

txn.Commit()

return entity, nil
}
22 changes: 16 additions & 6 deletions vault/identity_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (
"github.com/hashicorp/vault/logical"
)

func TestIdentityStore_CreateEntity(t *testing.T) {
func TestIdentityStore_FetchOrCreateEntity(t *testing.T) {
is, ghAccessor, _ := testIdentityStoreWithGithubAuth(t)
alias := &logical.Alias{
MountType: "github",
MountAccessor: ghAccessor,
Name: "githubuser",
}

entity, err := is.CreateEntity(alias)
entity, err := is.FetchOrCreateEntity(alias)
if err != nil {
t.Fatal(err)
}
Expand All @@ -33,10 +33,20 @@ func TestIdentityStore_CreateEntity(t *testing.T) {
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
}

// Try recreating an entity with the same alias details. It should fail.
entity, err = is.CreateEntity(alias)
if err == nil {
t.Fatalf("expected an error")
entity, err = is.FetchOrCreateEntity(alias)
if err != nil {
t.Fatal(err)
}
if entity == nil {
t.Fatalf("expected a non-nil entity")
}

if len(entity.Aliases) != 1 {
t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(entity.Aliases))
}

if entity.Aliases[0].Name != alias.Name {
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
}
}

Expand Down
19 changes: 18 additions & 1 deletion vault/identity_store_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -666,12 +666,29 @@ func (i *IdentityStore) MemDBAliasByFactors(mountAccessor, aliasName string, clo
return nil, fmt.Errorf("missing mount accessor")
}

txn := i.db.Txn(false)

return i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, clone, groupAlias)
}

func (i *IdentityStore) MemDBAliasByFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) {
if txn == nil {
return nil, fmt.Errorf("nil txn")
}

if aliasName == "" {
return nil, fmt.Errorf("missing alias name")
}

if mountAccessor == "" {
return nil, fmt.Errorf("missing mount accessor")
}

tableName := entityAliasesTable
if groupAlias {
tableName = groupAliasesTable
}

txn := i.db.Txn(false)
aliasRaw, err := txn.First(tableName, "factors", mountAccessor, aliasName)
if err != nil {
return nil, fmt.Errorf("failed to fetch alias from memdb using factors: %v", err)
Expand Down
15 changes: 4 additions & 11 deletions vault/request_handling.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,22 +436,15 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re

var err error

// Check if an entity already exists for the given alias
entity, err = c.identityStore.entityByAliasFactors(auth.Alias.MountAccessor, auth.Alias.Name, false)
// Fetch the entity for the alias, or create an entity if one
// doesn't exist.
entity, err = c.identityStore.FetchOrCreateEntity(auth.Alias)
if err != nil {
return nil, nil, err
}

// If not, create one.
if entity == nil {
c.logger.Debug("core: creating a new entity", "alias", auth.Alias)
entity, err = c.identityStore.CreateEntity(auth.Alias)
if err != nil {
return nil, nil, err
}
if entity == nil {
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
}
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
}

auth.EntityID = entity.ID
Expand Down