diff --git a/auth/store.go b/auth/store.go index 6b2c6c92312..6d218d3131e 100644 --- a/auth/store.go +++ b/auth/store.go @@ -182,6 +182,17 @@ type authStore struct { indexWaiter func(uint64) <-chan struct{} } +func newDeleterFunc(as *authStore) func(string) { + return func(t string) { + as.simpleTokensMu.Lock() + defer as.simpleTokensMu.Unlock() + if username, ok := as.simpleTokens[t]; ok { + plog.Infof("deleting token %s for user %s", t, username) + delete(as.simpleTokens, t) + } + } +} + func (as *authStore) AuthEnable() error { as.enabledMu.Lock() defer as.enabledMu.Unlock() @@ -210,15 +221,7 @@ func (as *authStore) AuthEnable() error { as.enabled = true - tokenDeleteFunc := func(t string) { - as.simpleTokensMu.Lock() - defer as.simpleTokensMu.Unlock() - if username, ok := as.simpleTokens[t]; ok { - plog.Infof("deleting token %s for user %s", t, username) - delete(as.simpleTokens, t) - } - } - as.simpleTokenKeeper = NewSimpleTokenTTLKeeper(tokenDeleteFunc) + as.simpleTokenKeeper = NewSimpleTokenTTLKeeper(newDeleterFunc(as)) as.rangePermCache = make(map[string]*unifiedRangePermissions) @@ -892,11 +895,25 @@ func NewAuthStore(be backend.Backend, indexWaiter func(uint64) <-chan struct{}) tx.UnsafeCreateBucket(authUsersBucketName) tx.UnsafeCreateBucket(authRolesBucketName) + enabled := false + _, vs := tx.UnsafeRange(authBucketName, enableFlagKey, nil, 0) + if len(vs) == 1 { + if bytes.Equal(vs[0], authEnabled) { + enabled = true + } + } + as := &authStore{ - be: be, - simpleTokens: make(map[string]string), - revision: 0, - indexWaiter: indexWaiter, + be: be, + simpleTokens: make(map[string]string), + revision: getRevision(tx), + indexWaiter: indexWaiter, + enabled: enabled, + rangePermCache: make(map[string]*unifiedRangePermissions), + } + + if enabled { + as.simpleTokenKeeper = NewSimpleTokenTTLKeeper(newDeleterFunc(as)) } as.commitRevision(tx) @@ -926,7 +943,8 @@ func (as *authStore) commitRevision(tx backend.BatchTx) { func getRevision(tx backend.BatchTx) uint64 { _, vs := tx.UnsafeRange(authBucketName, []byte(revisionKey), nil, 0) if len(vs) != 1 { - plog.Panicf("failed to get the key of auth store revision") + // this can happen in the initialization phase + return 0 } return binary.BigEndian.Uint64(vs[0]) diff --git a/auth/store_test.go b/auth/store_test.go index 72ea66b0758..3bc3a001456 100644 --- a/auth/store_test.go +++ b/auth/store_test.go @@ -496,6 +496,44 @@ func TestIsAdminPermitted(t *testing.T) { } } +func TestRecoverFromSnapshot(t *testing.T) { + as, _ := setupAuthStore(t) + + ua := &pb.AuthUserAddRequest{Name: "foo"} + _, err := as.UserAdd(ua) // add an existing user + if err == nil { + t.Fatalf("expected %v, got %v", ErrUserAlreadyExist, err) + } + if err != ErrUserAlreadyExist { + t.Fatalf("expected %v, got %v", ErrUserAlreadyExist, err) + } + + ua = &pb.AuthUserAddRequest{Name: ""} + _, err = as.UserAdd(ua) // add a user with empty name + if err != ErrUserEmpty { + t.Fatal(err) + } + + as.Close() + + as2 := NewAuthStore(as.be, dummyIndexWaiter) + defer func(a *authStore) { + a.Close() + }(as2) + + if !as2.isAuthEnabled() { + t.Fatal("recovering authStore from existing backend failed") + } + + ul, err := as.UserList(&pb.AuthUserListRequest{}) + if err != nil { + t.Fatal(err) + } + if !contains(ul.Users, "root") { + t.Errorf("expected %v in %v", "root", ul.Users) + } +} + func contains(array []string, str string) bool { for _, s := range array { if s == str {