Skip to content

Commit

Permalink
AppRole/Identity: Fix for race when creating an entity during login (#…
Browse files Browse the repository at this point in the history
…3932)

* possible fix for race in approle login while creating entity

* Add a test that hits the login request concurrently

* address review comments
  • Loading branch information
vishalnayak authored Feb 9, 2018
1 parent e47c7e8 commit 5bb8fa2
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 24 deletions.
86 changes: 86 additions & 0 deletions command/approle_concurrency_integ_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package command

import (
"sync"
"testing"

"github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
vaulthttp "github.com/hashicorp/vault/http"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/vault"
logxi "github.com/mgutz/logxi/v1"
)

func TestAppRole_Integ_ConcurrentLogins(t *testing.T) {
var err error
coreConfig := &vault.CoreConfig{
DisableMlock: true,
DisableCache: true,
Logger: logxi.NullLog,
CredentialBackends: map[string]logical.Factory{
"approle": credAppRole.Factory,
},
}

cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})

cluster.Start()
defer cluster.Cleanup()

cores := cluster.Cores

vault.TestWaitActive(t, cores[0].Core)

client := cores[0].Client

err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{
Type: "approle",
})
if err != nil {
t.Fatal(err)
}

_, err = client.Logical().Write("auth/approle/role/role1", map[string]interface{}{
"bind_secret_id": "true",
"period": "300",
})
if err != nil {
t.Fatal(err)
}

secret, err := client.Logical().Write("auth/approle/role/role1/secret-id", nil)
if err != nil {
t.Fatal(err)
}
secretID := secret.Data["secret_id"].(string)

secret, err = client.Logical().Read("auth/approle/role/role1/role-id")
if err != nil {
t.Fatal(err)
}
roleID := secret.Data["role_id"].(string)

wg := &sync.WaitGroup{}

for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
secret, err = client.Logical().Write("auth/approle/login", map[string]interface{}{
"role_id": roleID,
"secret_id": secretID,
})
if err != nil {
t.Fatal(err)
}
if secret.Auth.ClientToken == "" {
t.Fatalf("expected a successful login")
}
}()

}
wg.Wait()
}
49 changes: 43 additions & 6 deletions vault/identity_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,27 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl
return nil, fmt.Errorf("missing alias name")
}

alias, err := i.MemDBAliasByFactors(mountAccessor, aliasName, false, false)
txn := i.db.Txn(false)

return i.entityByAliasFactorsInTxn(txn, mountAccessor, aliasName, clone)
}

// entityByAlaisFactorsInTxn fetches the entity based on factors of alias, i.e
// mount accessor and the alias name.
func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool) (*identity.Entity, error) {
if txn == nil {
return nil, fmt.Errorf("nil txn")
}

if mountAccessor == "" {
return nil, fmt.Errorf("missing mount accessor")
}

if aliasName == "" {
return nil, fmt.Errorf("missing alias name")
}

alias, err := i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, false, false)
if err != nil {
return nil, err
}
Expand All @@ -258,12 +278,12 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl
return nil, nil
}

return i.MemDBEntityByAliasID(alias.ID, clone)
return i.MemDBEntityByAliasIDInTxn(txn, alias.ID, clone)
}

// CreateEntity creates a new entity. This is used by core to
// CreateOrFetchEntity creates a new entity. This is used by core to
// associate each login attempt by an alias to a unified entity in Vault.
func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, error) {
func (i *IdentityStore) CreateOrFetchEntity(alias *logical.Alias) (*identity.Entity, error) {
var entity *identity.Entity
var err error

Expand All @@ -290,9 +310,24 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er
return nil, err
}
if entity != nil {
return nil, fmt.Errorf("alias already belongs to a different entity")
return entity, nil
}

// Create a MemDB transaction to update both alias and entity
txn := i.db.Txn(true)
defer txn.Abort()

// Check if an entity was created before acquiring the lock
entity, err = i.entityByAliasFactorsInTxn(txn, alias.MountAccessor, alias.Name, false)
if err != nil {
return nil, err
}
if entity != nil {
return entity, nil
}

i.logger.Debug("identity: creating a new entity", "alias", alias)

entity = &identity.Entity{}

err = i.sanitizeEntity(entity)
Expand Down Expand Up @@ -320,10 +355,12 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er
}

// Update MemDB and persist entity object
err = i.upsertEntity(entity, nil, true)
err = i.upsertEntityInTxn(txn, entity, nil, true, false)
if err != nil {
return nil, err
}

txn.Commit()

return entity, nil
}
22 changes: 16 additions & 6 deletions vault/identity_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (
"github.com/hashicorp/vault/logical"
)

func TestIdentityStore_CreateEntity(t *testing.T) {
func TestIdentityStore_CreateOrFetchEntity(t *testing.T) {
is, ghAccessor, _ := testIdentityStoreWithGithubAuth(t)
alias := &logical.Alias{
MountType: "github",
MountAccessor: ghAccessor,
Name: "githubuser",
}

entity, err := is.CreateEntity(alias)
entity, err := is.CreateOrFetchEntity(alias)
if err != nil {
t.Fatal(err)
}
Expand All @@ -33,10 +33,20 @@ func TestIdentityStore_CreateEntity(t *testing.T) {
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
}

// Try recreating an entity with the same alias details. It should fail.
entity, err = is.CreateEntity(alias)
if err == nil {
t.Fatalf("expected an error")
entity, err = is.CreateOrFetchEntity(alias)
if err != nil {
t.Fatal(err)
}
if entity == nil {
t.Fatalf("expected a non-nil entity")
}

if len(entity.Aliases) != 1 {
t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(entity.Aliases))
}

if entity.Aliases[0].Name != alias.Name {
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
}
}

Expand Down
19 changes: 18 additions & 1 deletion vault/identity_store_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -666,12 +666,29 @@ func (i *IdentityStore) MemDBAliasByFactors(mountAccessor, aliasName string, clo
return nil, fmt.Errorf("missing mount accessor")
}

txn := i.db.Txn(false)

return i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, clone, groupAlias)
}

func (i *IdentityStore) MemDBAliasByFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) {
if txn == nil {
return nil, fmt.Errorf("nil txn")
}

if aliasName == "" {
return nil, fmt.Errorf("missing alias name")
}

if mountAccessor == "" {
return nil, fmt.Errorf("missing mount accessor")
}

tableName := entityAliasesTable
if groupAlias {
tableName = groupAliasesTable
}

txn := i.db.Txn(false)
aliasRaw, err := txn.First(tableName, "factors", mountAccessor, aliasName)
if err != nil {
return nil, fmt.Errorf("failed to fetch alias from memdb using factors: %v", err)
Expand Down
15 changes: 4 additions & 11 deletions vault/request_handling.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,22 +436,15 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re

var err error

// Check if an entity already exists for the given alias
entity, err = c.identityStore.entityByAliasFactors(auth.Alias.MountAccessor, auth.Alias.Name, false)
// Fetch the entity for the alias, or create an entity if one
// doesn't exist.
entity, err = c.identityStore.CreateOrFetchEntity(auth.Alias)
if err != nil {
return nil, nil, err
}

// If not, create one.
if entity == nil {
c.logger.Debug("core: creating a new entity", "alias", auth.Alias)
entity, err = c.identityStore.CreateEntity(auth.Alias)
if err != nil {
return nil, nil, err
}
if entity == nil {
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
}
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
}

auth.EntityID = entity.ID
Expand Down

0 comments on commit 5bb8fa2

Please sign in to comment.