From 24220485e9ada8d9d71132d6a495eb8a533ce8bf Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Tue, 20 Aug 2024 10:04:59 -0700 Subject: [PATCH] GODRIVER-3215 Fix default auth source for auth specified via ClientOptions. --- mongo/client.go | 49 +------ mongo/client_test.go | 76 ----------- x/mongo/driver/auth/mongodbaws.go | 5 +- x/mongo/driver/auth/mongodbcr.go | 6 +- x/mongo/driver/auth/oidc.go | 3 + x/mongo/driver/auth/plain.go | 8 +- x/mongo/driver/auth/scram.go | 12 +- x/mongo/driver/auth/x509.go | 3 + x/mongo/driver/connstring/connstring.go | 9 +- x/mongo/driver/topology/topology_options.go | 125 +++++++++++------- .../driver/topology/topology_options_test.go | 76 +++++++++++ 11 files changed, 189 insertions(+), 183 deletions(-) diff --git a/mongo/client.go b/mongo/client.go index 00f4f363aec..086b2459a01 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -26,7 +26,6 @@ import ( "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" - "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" @@ -211,43 +210,16 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { clientOpt.SetMaxPoolSize(defaultMaxPoolSize) } - if clientOpt.Auth != nil { - var oidcMachineCallback auth.OIDCCallback - if clientOpt.Auth.OIDCMachineCallback != nil { - oidcMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { - cred, err := clientOpt.Auth.OIDCMachineCallback(ctx, convertOIDCArgs(args)) - return (*driver.OIDCCredential)(cred), err - } - } - - var oidcHumanCallback auth.OIDCCallback - if clientOpt.Auth.OIDCHumanCallback != nil { - oidcHumanCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { - cred, err := clientOpt.Auth.OIDCHumanCallback(ctx, convertOIDCArgs(args)) - return (*driver.OIDCCredential)(cred), err - } - } - - // Create an authenticator for the client - client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{ - Source: clientOpt.Auth.AuthSource, - Username: clientOpt.Auth.Username, - Password: clientOpt.Auth.Password, - PasswordSet: clientOpt.Auth.PasswordSet, - Props: clientOpt.Auth.AuthMechanismProperties, - OIDCMachineCallback: oidcMachineCallback, - OIDCHumanCallback: oidcHumanCallback, - }, clientOpt.HTTPClient) - if err != nil { - return nil, err - } + client.authenticator, err = topology.NewAuthenticator(clientOpt.Auth, clientOpt.HTTPClient) + if err != nil { + return nil, fmt.Errorf("error creating authenticator: %w", err) } cfg, err := topology.NewConfigWithAuthenticator(clientOpt, client.clock, client.authenticator) - if err != nil { return nil, err } + client.serverAPI = topology.ServerAPIFromServerOptions(cfg.ServerOpts) if client.deployment == nil { @@ -266,19 +238,6 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { return client, nil } -// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent -// public type *options.OIDCArgs. -func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs { - if args == nil { - return nil - } - return &options.OIDCArgs{ - Version: args.Version, - IDPInfo: (*options.IDPInfo)(args.IDPInfo), - RefreshToken: args.RefreshToken, - } -} - // Connect initializes the Client by starting background monitoring goroutines. // If the Client was created using the NewClient function, this method must be called before a Client can be used. // diff --git a/mongo/client_test.go b/mongo/client_test.go index 0a96e545011..013c1ae6bba 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -11,7 +11,6 @@ import ( "errors" "math" "os" - "reflect" "testing" "time" @@ -19,13 +18,11 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/integtest" - "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/tag" - "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" @@ -505,76 +502,3 @@ func TestClient(t *testing.T) { } }) } - -// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs -// into an options.OIDCArgs. -func TestConvertOIDCArgs(t *testing.T) { - refreshToken := "test refresh token" - - testCases := []struct { - desc string - args *driver.OIDCArgs - }{ - { - desc: "populated args", - args: &driver.OIDCArgs{ - Version: 9, - IDPInfo: &driver.IDPInfo{ - Issuer: "test issuer", - ClientID: "test client ID", - RequestScopes: []string{"test scope 1", "test scope 2"}, - }, - RefreshToken: &refreshToken, - }, - }, - { - desc: "nil", - args: nil, - }, - { - desc: "nil IDPInfo and RefreshToken", - args: &driver.OIDCArgs{ - Version: 9, - IDPInfo: nil, - RefreshToken: nil, - }, - }, - } - - for _, tc := range testCases { - tc := tc // Capture range variable. - - t.Run(tc.desc, func(t *testing.T) { - t.Parallel() - - got := convertOIDCArgs(tc.args) - - if tc.args == nil { - assert.Nil(t, got, "expected nil when input is nil") - return - } - - require.Equal(t, - 3, - reflect.ValueOf(*tc.args).NumField(), - "expected the driver.OIDCArgs struct to have exactly 3 fields") - require.Equal(t, - 3, - reflect.ValueOf(*got).NumField(), - "expected the options.OIDCArgs struct to have exactly 3 fields") - - assert.Equal(t, - tc.args.Version, - got.Version, - "expected Version field to be equal") - assert.EqualValues(t, - tc.args.IDPInfo, - got.IDPInfo, - "expected IDPInfo field to be convertible to equal values") - assert.Equal(t, - tc.args.RefreshToken, - got.RefreshToken, - "expected RefreshToken field to be equal") - }) - } -} diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index c5cebaa27f1..fdb3f2020cd 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -28,10 +28,8 @@ func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authentica return nil, errors.New("httpClient must not be nil") } return &MongoDBAWSAuthenticator{ - source: cred.Source, credentials: &credproviders.StaticProvider{ Value: credentials.Value{ - ProviderName: cred.Source, AccessKeyID: cred.Username, SecretAccessKey: cred.Password, SessionToken: cred.Props["AWS_SESSION_TOKEN"], @@ -43,7 +41,6 @@ func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authentica // MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection. type MongoDBAWSAuthenticator struct { - source string credentials *credproviders.StaticProvider httpClient *http.Client } @@ -56,7 +53,7 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { credentials: providers.Cred, }, } - err := ConductSaslConversation(ctx, cfg, a.source, adapter) + err := ConductSaslConversation(ctx, cfg, "$external", adapter) if err != nil { return newAuthError("sasl conversation error", err) } diff --git a/x/mongo/driver/auth/mongodbcr.go b/x/mongo/driver/auth/mongodbcr.go index a988011b36e..1861956b749 100644 --- a/x/mongo/driver/auth/mongodbcr.go +++ b/x/mongo/driver/auth/mongodbcr.go @@ -30,8 +30,12 @@ import ( const MONGODBCR = "MONGODB-CR" func newMongoDBCRAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { + source := cred.Source + if source == "" { + source = "admin" + } return &MongoDBCRAuthenticator{ - DB: cred.Source, + DB: source, Username: cred.Username, Password: cred.Password, }, nil diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 0b71533b738..09ce8d62ca4 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -109,6 +109,9 @@ func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) { } func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { + if cred.Source != "" && cred.Source != "$external" { + return nil, newAuthError("MONGODB-OIDC source must be empty or $external", nil) + } if cred.Password != "" { return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC) } diff --git a/x/mongo/driver/auth/plain.go b/x/mongo/driver/auth/plain.go index 9fce7ec3837..8f0e96c5b62 100644 --- a/x/mongo/driver/auth/plain.go +++ b/x/mongo/driver/auth/plain.go @@ -17,9 +17,14 @@ import ( const PLAIN = "PLAIN" func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { + source := cred.Source + if source == "" { + source = "$external" + } return &PlainAuthenticator{ Username: cred.Username, Password: cred.Password, + Source: source, }, nil } @@ -27,11 +32,12 @@ func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { type PlainAuthenticator struct { Username string Password string + Source string } // Auth authenticates the connection. func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error { - return ConductSaslConversation(ctx, cfg, "$external", &plainSaslClient{ + return ConductSaslConversation(ctx, cfg, a.Source, &plainSaslClient{ username: a.Username, password: a.Password, }) diff --git a/x/mongo/driver/auth/scram.go b/x/mongo/driver/auth/scram.go index 8c04ce32cc9..0d7deaee0e1 100644 --- a/x/mongo/driver/auth/scram.go +++ b/x/mongo/driver/auth/scram.go @@ -38,6 +38,10 @@ var ( ) func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { + source := cred.Source + if source == "" { + source = "admin" + } passdigest := mongoPasswordDigest(cred.Username, cred.Password) client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "") if err != nil { @@ -46,12 +50,16 @@ func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error client.WithMinIterations(4096) return &ScramAuthenticator{ mechanism: SCRAMSHA1, - source: cred.Source, + source: source, client: client, }, nil } func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { + source := cred.Source + if source == "" { + source = "admin" + } passprep, err := stringprep.SASLprep.Prepare(cred.Password) if err != nil { return nil, newAuthError("error SASLprepping password", err) @@ -63,7 +71,7 @@ func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, err client.WithMinIterations(4096) return &ScramAuthenticator{ mechanism: SCRAMSHA256, - source: cred.Source, + source: source, client: client, }, nil } diff --git a/x/mongo/driver/auth/x509.go b/x/mongo/driver/auth/x509.go index 3e84f516f87..0aa603ce5dc 100644 --- a/x/mongo/driver/auth/x509.go +++ b/x/mongo/driver/auth/x509.go @@ -19,6 +19,9 @@ import ( const MongoDBX509 = "MONGODB-X509" func newMongoDBX509Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { + // TODO(GODRIVER-3309): Validate that cred.Source is either empty or + // "$external" to make validation uniform with other auth mechanisms that + // require Source to be "$external" (e.g. MONGODB-AWS, MONGODB-OIDC, etc). return &MongoDBX509Authenticator{User: cred.Username}, nil } diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 4a7a01f4fbc..081b57843c2 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -296,7 +296,7 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error { u.AuthMechanismProperties["SERVICE_NAME"] = "mongodb" } fallthrough - case "mongodb-aws", "mongodb-x509": + case "mongodb-aws", "mongodb-x509", "mongodb-oidc": if u.AuthSource == "" { u.AuthSource = "$external" } else if u.AuthSource != "$external" { @@ -313,13 +313,6 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error { u.AuthSource = "admin" } } - case "mongodb-oidc": - if u.AuthSource == "" { - u.AuthSource = dbName - if u.AuthSource == "" { - u.AuthSource = "$external" - } - } case "": // Only set auth source if there is a request for authentication via non-empty credentials. if u.AuthSource == "" && (u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet) { diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index 0563e5524e7..e172aa4cad7 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -7,10 +7,10 @@ package topology import ( + "context" "crypto/tls" "fmt" "net/http" - "strings" "time" "go.mongodb.org/mongo-driver/event" @@ -71,31 +71,76 @@ func newLogger(opts *options.LoggerOptions) (*logger.Logger, error) { return log, nil } -// NewConfig will translate data from client options into a topology config for building non-default deployments. -func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, error) { - // Auth & Database & Password & Username - if co.Auth != nil { - cred := &auth.Cred{ - Username: co.Auth.Username, - Password: co.Auth.Password, - PasswordSet: co.Auth.PasswordSet, - Props: co.Auth.AuthMechanismProperties, - Source: co.Auth.AuthSource, +// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent +// public type *options.OIDCArgs. +func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs { + if args == nil { + return nil + } + return &options.OIDCArgs{ + Version: args.Version, + IDPInfo: (*options.IDPInfo)(args.IDPInfo), + RefreshToken: args.RefreshToken, + } +} + +// NewAuthenticator returns a [driver.Authenticator] configured with the given +// credential and HTTP client. It returns nil if cred is nil. +func NewAuthenticator(cred *options.Credential, httpClient *http.Client) (driver.Authenticator, error) { + if cred == nil { + return nil, nil + } + + var oidcMachineCallback auth.OIDCCallback + if cred.OIDCMachineCallback != nil { + oidcMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := cred.OIDCMachineCallback(ctx, convertOIDCArgs(args)) + return (*driver.OIDCCredential)(cred), err } - mechanism := co.Auth.AuthMechanism - authenticator, err := auth.CreateAuthenticator(mechanism, cred, co.HTTPClient) - if err != nil { - return nil, err + } + + var oidcHumanCallback auth.OIDCCallback + if cred.OIDCHumanCallback != nil { + oidcHumanCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := cred.OIDCHumanCallback(ctx, convertOIDCArgs(args)) + return (*driver.OIDCCredential)(cred), err } - return NewConfigWithAuthenticator(co, clock, authenticator) } - return NewConfigWithAuthenticator(co, clock, nil) + + // Create an authenticator for the client + return auth.CreateAuthenticator( + cred.AuthMechanism, + &auth.Cred{ + Source: cred.AuthSource, + Username: cred.Username, + Password: cred.Password, + PasswordSet: cred.PasswordSet, + Props: cred.AuthMechanismProperties, + OIDCMachineCallback: oidcMachineCallback, + OIDCHumanCallback: oidcHumanCallback, + }, + httpClient) +} + +// NewConfig will translate data from client options into a topology config for +// building non-default deployments. +func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, error) { + authenticator, err := NewAuthenticator(co.Auth, co.HTTPClient) + if err != nil { + return nil, fmt.Errorf("error creating authenticator: %w", err) + } + return NewConfigWithAuthenticator(co, clock, authenticator) } -// NewConfigWithAuthenticator will translate data from client options into a topology config for building non-default deployments. -// Server and topology options are not honored if a custom deployment is used. It uses a passed in +// NewConfigWithAuthenticator will translate data from client options into a +// topology config for building non-default deployments. Server and topology +// options are not honored if a custom deployment is used. It uses a passed in // authenticator to authenticate the connection. -func NewConfigWithAuthenticator(co *options.ClientOptions, clock *session.ClusterClock, authenticator driver.Authenticator) (*Config, error) { +func NewConfigWithAuthenticator( + co *options.ClientOptions, + clock *session.ClusterClock, + authenticator driver.Authenticator, +) (*Config, error) { var serverAPI *driver.ServerAPIOptions if err := co.Validate(); err != nil { @@ -178,30 +223,8 @@ func NewConfigWithAuthenticator(co *options.ClientOptions, clock *session.Cluste } // Handshaker - var handshaker = func(driver.Handshaker) driver.Handshaker { - return operation.NewHello().AppName(appName).Compressors(comps).ClusterClock(clock). - ServerAPI(serverAPI).LoadBalanced(loadBalanced) - } - // Auth & Database & Password & Username - if co.Auth != nil { - cred := &auth.Cred{ - Username: co.Auth.Username, - Password: co.Auth.Password, - PasswordSet: co.Auth.PasswordSet, - Props: co.Auth.AuthMechanismProperties, - Source: co.Auth.AuthSource, - } - mechanism := co.Auth.AuthMechanism - - if len(cred.Source) == 0 { - switch strings.ToUpper(mechanism) { - case auth.MongoDBX509, auth.GSSAPI, auth.PLAIN: - cred.Source = "$external" - default: - cred.Source = "admin" - } - } - + var handshaker func(driver.Handshaker) driver.Handshaker + if authenticator != nil { handshakeOpts := &auth.HandshakeOptions{ AppName: appName, Authenticator: authenticator, @@ -211,9 +234,9 @@ func NewConfigWithAuthenticator(co *options.ClientOptions, clock *session.Cluste ClusterClock: clock, } - if mechanism == "" { + if co.Auth.AuthMechanism == "" { // Required for SASL mechanism negotiation during handshake - handshakeOpts.DBUser = cred.Source + "." + cred.Username + handshakeOpts.DBUser = co.Auth.AuthSource + "." + co.Auth.Username } if co.AuthenticateToAnything != nil && *co.AuthenticateToAnything { // Authenticate arbiters @@ -225,7 +248,17 @@ func NewConfigWithAuthenticator(co *options.ClientOptions, clock *session.Cluste handshaker = func(driver.Handshaker) driver.Handshaker { return auth.Handshaker(nil, handshakeOpts) } + } else { + handshaker = func(driver.Handshaker) driver.Handshaker { + return operation.NewHello(). + AppName(appName). + Compressors(comps). + ClusterClock(clock). + ServerAPI(serverAPI). + LoadBalanced(loadBalanced) + } } + connOpts = append(connOpts, WithHandshaker(handshaker)) // ConnectTimeout if co.ConnectTimeout != nil { diff --git a/x/mongo/driver/topology/topology_options_test.go b/x/mongo/driver/topology/topology_options_test.go index e57c75bcb00..1bd6a472769 100644 --- a/x/mongo/driver/topology/topology_options_test.go +++ b/x/mongo/driver/topology/topology_options_test.go @@ -9,11 +9,14 @@ package topology import ( "fmt" "net/url" + "reflect" "testing" "time" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/x/mongo/driver" ) func TestDirectConnectionFromConnString(t *testing.T) { @@ -104,3 +107,76 @@ func TestTopologyNewConfig(t *testing.T) { assert.Equal(t, []string{"localhost:27018"}, cfg.SeedList) }) } + +// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs +// into an options.OIDCArgs. +func TestConvertOIDCArgs(t *testing.T) { + refreshToken := "test refresh token" + + testCases := []struct { + desc string + args *driver.OIDCArgs + }{ + { + desc: "populated args", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: &driver.IDPInfo{ + Issuer: "test issuer", + ClientID: "test client ID", + RequestScopes: []string{"test scope 1", "test scope 2"}, + }, + RefreshToken: &refreshToken, + }, + }, + { + desc: "nil", + args: nil, + }, + { + desc: "nil IDPInfo and RefreshToken", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: nil, + RefreshToken: nil, + }, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + got := convertOIDCArgs(tc.args) + + if tc.args == nil { + assert.Nil(t, got, "expected nil when input is nil") + return + } + + require.Equal(t, + 3, + reflect.ValueOf(*tc.args).NumField(), + "expected the driver.OIDCArgs struct to have exactly 3 fields") + require.Equal(t, + 3, + reflect.ValueOf(*got).NumField(), + "expected the options.OIDCArgs struct to have exactly 3 fields") + + assert.Equal(t, + tc.args.Version, + got.Version, + "expected Version field to be equal") + assert.EqualValues(t, + tc.args.IDPInfo, + got.IDPInfo, + "expected IDPInfo field to be convertible to equal values") + assert.Equal(t, + tc.args.RefreshToken, + got.RefreshToken, + "expected RefreshToken field to be equal") + }) + } +}