Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADR 019 password rotation #523

Merged
merged 6 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 70 additions & 22 deletions neo4j/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ package auth

import (
"context"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/collections"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/racing"
"reflect"
"time"
Expand Down Expand Up @@ -51,19 +53,23 @@ type TokenManager interface {
// The token returned must always belong to the same identity.
// Switching identities using the `TokenManager` is undefined behavior.
GetAuthToken(ctx context.Context) (auth.Token, error)
// OnTokenExpired is called by the driver when the provided token expires
// OnTokenExpired should invalidate the current token if it matches the provided one
OnTokenExpired(context.Context, auth.Token) error

// HandleSecurityException is called when the server returns any `Neo.ClientError.Security.*` error.
// It should return true if the error was handled, in which case the driver will mark the error as retryable.
HandleSecurityException(context.Context, auth.Token, *db.Neo4jError) (bool, error)
}

type authTokenProvider = func(context.Context) (auth.Token, error)

type authTokenWithExpirationProvider = func(context.Context) (auth.Token, *time.Time, error)

type expirationBasedTokenManager struct {
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
provider authTokenWithExpirationProvider
token *auth.Token
expiration *time.Time
mutex racing.Mutex
now *func() time.Time
provider authTokenWithExpirationProvider
token *auth.Token
expiration *time.Time
mutex racing.Mutex
now *func() time.Time
handledSecurityCodes collections.Set[string]
}

func (m *expirationBasedTokenManager) GetAuthToken(ctx context.Context) (auth.Token, error) {
Expand All @@ -83,34 +89,76 @@ func (m *expirationBasedTokenManager) GetAuthToken(ctx context.Context) (auth.To
return *m.token, nil
}

func (m *expirationBasedTokenManager) OnTokenExpired(ctx context.Context, token auth.Token) error {
func (m *expirationBasedTokenManager) HandleSecurityException(ctx context.Context, token auth.Token, securityException *db.Neo4jError) (bool, error) {
if !m.handledSecurityCodes.Contains(securityException.Code) {
return false, nil
}
if !m.mutex.TryLock(ctx) {
return racing.LockTimeoutError(
"could not acquire lock in time when handling token expiration in ExpirationBasedTokenManager")
return false, racing.LockTimeoutError(
"could not acquire lock in time when handling token expiration in expirationBasedTokenManager")
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}
defer m.mutex.Unlock()
if m.token != nil && reflect.DeepEqual(token.Tokens, m.token.Tokens) {
m.token = nil
}
return nil
return true, nil
}

// ExpirationBasedTokenManager creates a token manager for potentially expiring auth info.
// Basic creates a TokenManager handling basic auth password rotation.
// The provider function returns basic auth information and is assumed to never expire.
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
//
// WARNING:
//
// The provider function *must not* interact with the driver in any way as this can cause deadlocks and undefined
// behaviour.
//
// The first and only argument is a provider function that returns auth information and an optional expiration time.
// If the expiration time is nil, the auth info is assumed to never expire.
// The provider function must only ever return auth information belonging to the same identity.
// Switching identities is undefined behavior.
//
// Basic is part of the re-authentication preview feature
// (see README on what it means in terms of support and compatibility guarantees)
func Basic(provider authTokenProvider) TokenManager {
now := time.Now
return &expirationBasedTokenManager{
provider: wrapWithNilExpiration(provider),
mutex: racing.NewMutex(),
now: &now,
handledSecurityCodes: collections.NewSet([]string{
"Neo.ClientError.Security.Unauthorized",
}),
}
}

// Bearer creates a TokenManager handling potentially expiring auth information.
// The provider function returns auth information and an optional expiration time.
// If the expiration time is nil, the auth information is assumed to never expire.
//
// WARNING:
//
// The provider function *must not* interact with the driver in any way as this can cause deadlocks and undefined
// behaviour.
// The provider function *must not* interact with the driver in any way as this can cause deadlocks and undefined
// behaviour.
//
// The provider function only ever return auth information belonging to the same identity.
// Switching identities is undefined behavior.
// The provider function must only ever return auth information belonging to the same identity.
// Switching identities is undefined behavior.
//
// ExpirationBasedTokenManager is part of the re-authentication preview feature
// Bearer is part of the re-authentication preview feature
// (see README on what it means in terms of support and compatibility guarantees)
func ExpirationBasedTokenManager(provider authTokenWithExpirationProvider) TokenManager {
func Bearer(provider authTokenWithExpirationProvider) TokenManager {
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
now := time.Now
return &expirationBasedTokenManager{provider: provider, mutex: racing.NewMutex(), now: &now}
return &expirationBasedTokenManager{
provider: provider,
mutex: racing.NewMutex(),
now: &now,
handledSecurityCodes: collections.NewSet([]string{
"Neo.ClientError.Security.TokenExpired",
"Neo.ClientError.Security.Unauthorized",
}),
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}
}

func wrapWithNilExpiration(provider authTokenProvider) authTokenWithExpirationProvider {
return func(ctx context.Context) (auth.Token, *time.Time, error) {
token, err := provider(ctx)
return token, nil, err
}
}
24 changes: 20 additions & 4 deletions neo4j/auth/auth_examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,39 @@ import (
"time"
)

func ExampleExpirationBasedTokenManager() {
myProvider := func(ctx context.Context) (neo4j.AuthToken, *time.Time, error) {
func ExampleBasic() {
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
fetchBasicAuthToken := func(ctx context.Context) (neo4j.AuthToken, error) {
username, password, realm, err := getBasicAuth()
if err != nil {
return neo4j.AuthToken{}, err
}
return neo4j.BasicAuth(username, password, realm), nil
}

_, _ = neo4j.NewDriverWithContext(getUrl(), auth.Basic(fetchBasicAuthToken))
}

func ExampleBearer() {
fetchAuthTokenFromMyProvider := func(ctx context.Context) (neo4j.AuthToken, *time.Time, error) {
// some way to getting a token
token, err := getSsoToken(ctx)
if err != nil {
return neo4j.AuthToken{}, nil, err
}
// assume we know our tokens expire every 60 seconds

expiresIn := time.Now().Add(60 * time.Second)
// Include a little buffer so that we fetch a new token *before* the old one expires
expiresIn = expiresIn.Add(-10 * time.Second)
// or return nil instead of `&expiresIn` if we don't expect it to expire
return token, &expiresIn, nil
}

_, _ = neo4j.NewDriverWithContext(getUrl(), auth.ExpirationBasedTokenManager(myProvider))
_, _ = neo4j.NewDriverWithContext(getUrl(), auth.Bearer(fetchAuthTokenFromMyProvider))
}

func getBasicAuth() (username, password, realm string, error error) {
username, password, realm = "username", "password", "realm"
return
}

func getSsoToken(context.Context) (neo4j.AuthToken, error) {
Expand Down
4 changes: 4 additions & 0 deletions neo4j/db/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ func (e *Neo4jError) reclassify() {
}
}

func (e *Neo4jError) HasSecurityCode() bool {
return strings.HasPrefix(e.Code, "Neo.ClientError.Security.")
}

func (e *Neo4jError) IsAuthenticationFailed() bool {
return e.Code == "Neo.ClientError.Security.Unauthorized"
}
Expand Down
9 changes: 7 additions & 2 deletions neo4j/internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

package auth

import "context"
import (
"context"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
)

type Token struct {
Tokens map[string]any
Expand All @@ -29,4 +32,6 @@ func (a Token) GetAuthToken(context.Context) (Token, error) {
return a, nil
}

func (a Token) OnTokenExpired(context.Context, Token) error { return nil }
func (a Token) HandleSecurityException(context.Context, Token, *db.Neo4jError) (bool, error) {
return false, nil
}
9 changes: 9 additions & 0 deletions neo4j/internal/collections/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,12 @@ func (set Set[T]) Copy() Set[T] {
}
return result
}

func (set Set[T]) Contains(value T) bool {
for _, a := range set.Values() {
if a == value {
return true
}
}
return false
}
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 11 additions & 0 deletions neo4j/internal/collections/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,17 @@ func TestSet(outer *testing.T) {
t.Error(err)
}
})

outer.Run("contains", func(t *testing.T) {
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
strings := collections.NewSet([]string{
"golang",
"neo4j",
})
expected := "golang"
if found := strings.Contains(expected); !found {
t.Errorf("Set does not contain %v", expected)
}
})
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}

func containsExactlyOnce[T comparable](values collections.Set[T], search T) bool {
Expand Down
28 changes: 14 additions & 14 deletions neo4j/internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"context"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/config"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/auth"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/bolt"
idb "github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/db"
"github.com/neo4j/neo4j-go-driver/v5/neo4j/internal/errorutil"
Expand Down Expand Up @@ -453,35 +452,36 @@ func (p *Pool) Return(ctx context.Context, c idb.Connection) {
}

func (p *Pool) OnNeo4jError(ctx context.Context, connection idb.Connection, error *db.Neo4jError) error {
switch error.Code {
case "Neo.ClientError.Security.AuthorizationExpired":
if error.Code == "Neo.ClientError.Security.AuthorizationExpired" {
serverName := connection.ServerName()
p.serversMut.Lock()
defer p.serversMut.Unlock()
server := p.servers[serverName]
server.executeForAllConnections(func(c idb.Connection) {
c.ResetAuth()
})
case "Neo.ClientError.Security.TokenExpired":
}
if error.HasSecurityCode() {
manager, token := connection.GetCurrentAuth()
if manager != nil {
if err := manager.OnTokenExpired(ctx, token); err != nil {
handled, err := manager.HandleSecurityException(ctx, token, error)
if err != nil {
return err
}
if _, isStaticToken := manager.(auth.Token); !isStaticToken {
if handled {
error.MarkRetriable()
}
}
case "Neo.TransientError.General.DatabaseUnavailable":
}
if error.Code == "Neo.TransientError.General.DatabaseUnavailable" {
p.deactivate(ctx, connection.ServerName())
default:
if error.IsRetriableCluster() {
var database string
if dbSelector, ok := connection.(idb.DatabaseSelector); ok {
database = dbSelector.Database()
}
p.deactivateWriter(connection.ServerName(), database)
}
if error.IsRetriableCluster() {
var database string
if dbSelector, ok := connection.(idb.DatabaseSelector); ok {
database = dbSelector.Database()
}
p.deactivateWriter(connection.ServerName(), database)
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}

return nil
Expand Down
38 changes: 20 additions & 18 deletions testkit-backend/backend.go
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ type sessionState struct {
}

type GenericTokenManager struct {
GetAuthTokenFunc func() neo4j.AuthToken
OnTokenExpiredFunc func(neo4j.AuthToken)
GetAuthTokenFunc func() neo4j.AuthToken
HandleSecurityExceptionFunc func(neo4j.AuthToken, *db.Neo4jError)
}

type AuthTokenAndExpiration struct {
Expand All @@ -98,9 +98,9 @@ func (g GenericTokenManager) GetAuthToken(_ context.Context) (neo4j.AuthToken, e
return g.GetAuthTokenFunc(), nil
}

func (g GenericTokenManager) OnTokenExpired(_ context.Context, token neo4j.AuthToken) error {
g.OnTokenExpiredFunc(token)
return nil
func (g GenericTokenManager) HandleSecurityException(_ context.Context, token neo4j.AuthToken, securityException *db.Neo4jError) (bool, error) {
g.HandleSecurityExceptionFunc(token, securityException)
return false, nil
StephenCathcart marked this conversation as resolved.
Show resolved Hide resolved
}

const (
Expand Down Expand Up @@ -1004,14 +1004,15 @@ func (b *backend) handleRequest(req map[string]any) {
}
}
},
OnTokenExpiredFunc: func(token neo4j.AuthToken) {
HandleSecurityExceptionFunc: func(token neo4j.AuthToken, error *db.Neo4jError) {
id := b.nextId()
b.writeResponse(
"AuthTokenManagerOnAuthExpiredRequest",
"AuthTokenManagerHandleSecurityExceptionRequest",
map[string]any{
"id": id,
"authTokenManagerId": managerId,
"auth": serializeAuth(token),
"errorCode": error.Code,
})
for {
b.process()
Expand All @@ -1032,35 +1033,36 @@ func (b *backend) handleRequest(req map[string]any) {
return
}
b.resolvedGetAuthTokens[id] = token
case "AuthTokenManagerOnAuthExpiredCompleted":
case "AuthTokenManagerHandleSecurityExceptionCompleted":
handled := data["handled"].(bool)
id := data["requestId"].(string)
b.resolvedOnTokenExpiries[id] = true
case "NewExpirationBasedAuthTokenManager":
b.resolvedOnTokenExpiries[id] = handled
case "NewBasicAuthTokenManager":
managerId := b.nextId()

manager := auth.ExpirationBasedTokenManager(
func(context.Context) (neo4j.AuthToken, *time.Time, error) {
manager := auth.Basic(
func(context.Context) (neo4j.AuthToken, error) {
id := b.nextId()
b.writeResponse(
"ExpirationBasedAuthTokenProviderRequest",
"BasicAuthTokenProviderRequest",
map[string]any{
"id": id,
"expirationBasedAuthTokenManagerId": managerId,
"id": id,
"basicAuthTokenManagerId": managerId,
})
for {
b.process()
if expiringToken, ok := b.resolvedExpiringTokens[id]; ok {
delete(b.resolvedExpiringTokens, id)
return expiringToken.token, expiringToken.expiration, nil
return expiringToken.token, nil
}
}
})
if b.timer != nil {
auth.SetTimer(manager, b.timer.Now)
}
b.authTokenManagers[managerId] = manager
b.writeResponse("ExpirationBasedAuthTokenManager", map[string]any{"id": managerId})
case "ExpirationBasedAuthTokenProviderCompleted":
b.writeResponse("BasicAuthTokenManager", map[string]any{"id": managerId})
case "BearerAuthTokenProviderCompleted":
id := data["requestId"].(string)
expiringToken := data["auth"].(map[string]any)["data"].(map[string]any)
token, err := getAuth(expiringToken["auth"].(map[string]any)["data"].(map[string]any))
Expand Down