Skip to content

Commit

Permalink
GODRIVER-3215 Fix default auth source for auth specified via ClientOp…
Browse files Browse the repository at this point in the history
…tions.
  • Loading branch information
matthewdale committed Aug 21, 2024
1 parent a766876 commit 2422048
Show file tree
Hide file tree
Showing 11 changed files with 189 additions and 183 deletions.
49 changes: 4 additions & 45 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
Expand Down Expand Up @@ -211,43 +210,16 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) {
clientOpt.SetMaxPoolSize(defaultMaxPoolSize)
}

if clientOpt.Auth != nil {
var oidcMachineCallback auth.OIDCCallback
if clientOpt.Auth.OIDCMachineCallback != nil {
oidcMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
cred, err := clientOpt.Auth.OIDCMachineCallback(ctx, convertOIDCArgs(args))
return (*driver.OIDCCredential)(cred), err
}
}

var oidcHumanCallback auth.OIDCCallback
if clientOpt.Auth.OIDCHumanCallback != nil {
oidcHumanCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) {
cred, err := clientOpt.Auth.OIDCHumanCallback(ctx, convertOIDCArgs(args))
return (*driver.OIDCCredential)(cred), err
}
}

// Create an authenticator for the client
client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{
Source: clientOpt.Auth.AuthSource,
Username: clientOpt.Auth.Username,
Password: clientOpt.Auth.Password,
PasswordSet: clientOpt.Auth.PasswordSet,
Props: clientOpt.Auth.AuthMechanismProperties,
OIDCMachineCallback: oidcMachineCallback,
OIDCHumanCallback: oidcHumanCallback,
}, clientOpt.HTTPClient)
if err != nil {
return nil, err
}
client.authenticator, err = topology.NewAuthenticator(clientOpt.Auth, clientOpt.HTTPClient)
if err != nil {
return nil, fmt.Errorf("error creating authenticator: %w", err)
}

cfg, err := topology.NewConfigWithAuthenticator(clientOpt, client.clock, client.authenticator)

if err != nil {
return nil, err
}

client.serverAPI = topology.ServerAPIFromServerOptions(cfg.ServerOpts)

if client.deployment == nil {
Expand All @@ -266,19 +238,6 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) {
return client, nil
}

// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent
// public type *options.OIDCArgs.
func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs {
if args == nil {
return nil
}
return &options.OIDCArgs{
Version: args.Version,
IDPInfo: (*options.IDPInfo)(args.IDPInfo),
RefreshToken: args.RefreshToken,
}
}

// Connect initializes the Client by starting background monitoring goroutines.
// If the Client was created using the NewClient function, this method must be called before a Client can be used.
//
Expand Down
76 changes: 0 additions & 76 deletions mongo/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,18 @@ import (
"errors"
"math"
"os"
"reflect"
"testing"
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal/assert"
"go.mongodb.org/mongo-driver/internal/integtest"
"go.mongodb.org/mongo-driver/internal/require"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/tag"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
Expand Down Expand Up @@ -505,76 +502,3 @@ func TestClient(t *testing.T) {
}
})
}

// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs
// into an options.OIDCArgs.
func TestConvertOIDCArgs(t *testing.T) {
refreshToken := "test refresh token"

testCases := []struct {
desc string
args *driver.OIDCArgs
}{
{
desc: "populated args",
args: &driver.OIDCArgs{
Version: 9,
IDPInfo: &driver.IDPInfo{
Issuer: "test issuer",
ClientID: "test client ID",
RequestScopes: []string{"test scope 1", "test scope 2"},
},
RefreshToken: &refreshToken,
},
},
{
desc: "nil",
args: nil,
},
{
desc: "nil IDPInfo and RefreshToken",
args: &driver.OIDCArgs{
Version: 9,
IDPInfo: nil,
RefreshToken: nil,
},
},
}

for _, tc := range testCases {
tc := tc // Capture range variable.

t.Run(tc.desc, func(t *testing.T) {
t.Parallel()

got := convertOIDCArgs(tc.args)

if tc.args == nil {
assert.Nil(t, got, "expected nil when input is nil")
return
}

require.Equal(t,
3,
reflect.ValueOf(*tc.args).NumField(),
"expected the driver.OIDCArgs struct to have exactly 3 fields")
require.Equal(t,
3,
reflect.ValueOf(*got).NumField(),
"expected the options.OIDCArgs struct to have exactly 3 fields")

assert.Equal(t,
tc.args.Version,
got.Version,
"expected Version field to be equal")
assert.EqualValues(t,
tc.args.IDPInfo,
got.IDPInfo,
"expected IDPInfo field to be convertible to equal values")
assert.Equal(t,
tc.args.RefreshToken,
got.RefreshToken,
"expected RefreshToken field to be equal")
})
}
}
5 changes: 1 addition & 4 deletions x/mongo/driver/auth/mongodbaws.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authentica
return nil, errors.New("httpClient must not be nil")
}
return &MongoDBAWSAuthenticator{
source: cred.Source,
credentials: &credproviders.StaticProvider{
Value: credentials.Value{
ProviderName: cred.Source,
AccessKeyID: cred.Username,
SecretAccessKey: cred.Password,
SessionToken: cred.Props["AWS_SESSION_TOKEN"],
Expand All @@ -43,7 +41,6 @@ func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authentica

// MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection.
type MongoDBAWSAuthenticator struct {
source string
credentials *credproviders.StaticProvider
httpClient *http.Client
}
Expand All @@ -56,7 +53,7 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error {
credentials: providers.Cred,
},
}
err := ConductSaslConversation(ctx, cfg, a.source, adapter)
err := ConductSaslConversation(ctx, cfg, "$external", adapter)
if err != nil {
return newAuthError("sasl conversation error", err)
}
Expand Down
6 changes: 5 additions & 1 deletion x/mongo/driver/auth/mongodbcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ import (
const MONGODBCR = "MONGODB-CR"

func newMongoDBCRAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
return &MongoDBCRAuthenticator{
DB: cred.Source,
DB: source,
Username: cred.Username,
Password: cred.Password,
}, nil
Expand Down
3 changes: 3 additions & 0 deletions x/mongo/driver/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) {
}

func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
return nil, newAuthError("MONGODB-OIDC source must be empty or $external", nil)
}
if cred.Password != "" {
return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC)
}
Expand Down
8 changes: 7 additions & 1 deletion x/mongo/driver/auth/plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,27 @@ import (
const PLAIN = "PLAIN"

func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "$external"
}
return &PlainAuthenticator{
Username: cred.Username,
Password: cred.Password,
Source: source,
}, nil
}

// PlainAuthenticator uses the PLAIN algorithm over SASL to authenticate a connection.
type PlainAuthenticator struct {
Username string
Password string
Source string
}

// Auth authenticates the connection.
func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error {
return ConductSaslConversation(ctx, cfg, "$external", &plainSaslClient{
return ConductSaslConversation(ctx, cfg, a.Source, &plainSaslClient{
username: a.Username,
password: a.Password,
})
Expand Down
12 changes: 10 additions & 2 deletions x/mongo/driver/auth/scram.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ var (
)

func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
passdigest := mongoPasswordDigest(cred.Username, cred.Password)
client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "")
if err != nil {
Expand All @@ -46,12 +50,16 @@ func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA1,
source: cred.Source,
source: source,
client: client,
}, nil
}

func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
source := cred.Source
if source == "" {
source = "admin"
}
passprep, err := stringprep.SASLprep.Prepare(cred.Password)
if err != nil {
return nil, newAuthError("error SASLprepping password", err)
Expand All @@ -63,7 +71,7 @@ func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, err
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA256,
source: cred.Source,
source: source,
client: client,
}, nil
}
Expand Down
3 changes: 3 additions & 0 deletions x/mongo/driver/auth/x509.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ import (
const MongoDBX509 = "MONGODB-X509"

func newMongoDBX509Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) {
// TODO(GODRIVER-3309): Validate that cred.Source is either empty or
// "$external" to make validation uniform with other auth mechanisms that
// require Source to be "$external" (e.g. MONGODB-AWS, MONGODB-OIDC, etc).
return &MongoDBX509Authenticator{User: cred.Username}, nil
}

Expand Down
9 changes: 1 addition & 8 deletions x/mongo/driver/connstring/connstring.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error {
u.AuthMechanismProperties["SERVICE_NAME"] = "mongodb"
}
fallthrough
case "mongodb-aws", "mongodb-x509":
case "mongodb-aws", "mongodb-x509", "mongodb-oidc":
if u.AuthSource == "" {
u.AuthSource = "$external"
} else if u.AuthSource != "$external" {
Expand All @@ -313,13 +313,6 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error {
u.AuthSource = "admin"
}
}
case "mongodb-oidc":
if u.AuthSource == "" {
u.AuthSource = dbName
if u.AuthSource == "" {
u.AuthSource = "$external"
}
}
case "":
// Only set auth source if there is a request for authentication via non-empty credentials.
if u.AuthSource == "" && (u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet) {
Expand Down
Loading

0 comments on commit 2422048

Please sign in to comment.