Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix connection caching for MFA and external browser authenticators #705

Merged
merged 4 commits into from
Jan 24, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,12 @@ func authenticate(
}

sessionParameters[sessionClientValidateDefaultParameters] = sc.cfg.ValidateDefaultParameters != ConfigBoolFalse
if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue {
sessionParameters[clientRequestMfaToken] = true
}
if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
sessionParameters[clientStoreTemporaryCredential] = true
}
requestMain := authRequestData{
ClientAppID: clientType,
ClientAppVersion: SnowflakeGoDriverVersion,
Expand All @@ -314,6 +320,7 @@ func authenticate(
if sc.cfg.IDToken != "" {
requestMain.Authenticator = idTokenAuthenticator
requestMain.Token = sc.cfg.IDToken
requestMain.LoginName = sc.cfg.User
} else {
requestMain.ProofKey = string(proofKey)
requestMain.Token = string(samlResponse)
Expand Down Expand Up @@ -401,10 +408,10 @@ func authenticate(
if !respd.Success {
logger.Errorln("Authentication FAILED")
sc.rest.TokenAccessor.SetTokens("", "", -1)
if sessionParameters[clientRequestMfaToken] == "true" {
if sessionParameters[clientRequestMfaToken] == true {
deleteCredential(sc, mfaToken)
}
if sessionParameters[clientStoreTemporaryCredential] == "true" {
if sessionParameters[clientStoreTemporaryCredential] == true {
deleteCredential(sc, idToken)
}
code, err := strconv.Atoi(respd.Code)
Expand All @@ -420,11 +427,13 @@ func authenticate(
}
logger.Info("Authentication SUCCESS")
sc.rest.TokenAccessor.SetTokens(respd.Data.Token, respd.Data.MasterToken, respd.Data.SessionID)
if sc.isClientRequestMfaToken() {
setCredential(sc, mfaToken, respd.Data.MfaToken)
if sessionParameters[clientRequestMfaToken] == true {
token := respd.Data.MfaToken
setCredential(sc, mfaToken, token)
}
if sc.isClientStoreTemporaryCredential() {
setCredential(sc, idToken, respd.Data.IDToken)
if sessionParameters[clientStoreTemporaryCredential] == true {
token := respd.Data.IDToken
setCredential(sc, idToken, token)
}
return &respd.Data, nil
}
Expand Down Expand Up @@ -466,21 +475,20 @@ func authenticateWithConfig(sc *snowflakeConn) error {
var err error
//var consentCacheIdToken = true

paramBoolValue := "true"
if sc.cfg.Authenticator == AuthTypeExternalBrowser {
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
sc.cfg.Params[clientStoreTemporaryCredential] = &paramBoolValue
sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
}
if sc.isClientStoreTemporaryCredential() {
if sc.cfg.ClientStoreTemporaryCredential == ConfigBoolTrue {
fillCachedIDToken(sc)
}
}

if sc.cfg.Authenticator == AuthTypeUsernamePasswordMFA {
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
sc.cfg.Params[clientRequestMfaToken] = &paramBoolValue
sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
}
if sc.isClientRequestMfaToken() {
if sc.cfg.ClientRequestMfaToken == ConfigBoolTrue {
fillCachedMfaToken(sc)
}
}
Expand Down
96 changes: 94 additions & 2 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
"testing"
"time"

Expand Down Expand Up @@ -232,6 +234,9 @@ func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _
return nil, err
}

if ar.Data.SessionParameters["CLIENT_REQUEST_MFA_TOKEN"] != true {
return nil, fmt.Errorf("expected client_request_mfa_token to be true but was %v", ar.Data.SessionParameters["CLIENT_REQUEST_MFA_TOKEN"])
}
return &authResponse{
Success: true,
Data: authResponseMain{
Expand All @@ -245,6 +250,28 @@ func postAuthCheckUsernamePasswordMfa(_ context.Context, _ *snowflakeRestful, _
}, nil
}

func postAuthCheckExternalBrowser(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, jsonBody []byte, _ time.Duration) (*authResponse, error) {
var ar authRequest
if err := json.Unmarshal(jsonBody, &ar); err != nil {
return nil, err
}

if ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"] != true {
return nil, fmt.Errorf("expected client_store_temporary_credential to be true but was %v", ar.Data.SessionParameters["CLIENT_STORE_TEMPORARY_CREDENTIAL"])
}
return &authResponse{
Success: true,
Data: authResponseMain{
Token: "t",
MasterToken: "m",
IDToken: "mockedIDToken",
SessionInfo: authResponseSessionInfo{
DatabaseName: "dbn",
},
},
}, nil
}

func getDefaultSnowflakeConn() *snowflakeConn {
cfg := Config{
Account: "a",
Expand Down Expand Up @@ -497,11 +524,76 @@ func TestUnitAuthenticateUsernamePasswordMfa(t *testing.T) {
}
sc := getDefaultSnowflakeConn()
sc.cfg.Authenticator = AuthTypeUsernamePasswordMFA
requestMfaToken := "true"
sc.cfg.Params[clientRequestMfaToken] = &requestMfaToken
sc.cfg.ClientRequestMfaToken = ConfigBoolTrue
sc.rest = sr
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}
}

func TestUnitAuthenticateExternalBrowser(t *testing.T) {
var err error
sr := &snowflakeRestful{
FuncPostAuth: postAuthCheckExternalBrowser,
TokenAccessor: getSimpleTokenAccessor(),
}
sc := getDefaultSnowflakeConn()
sc.cfg.Authenticator = AuthTypeExternalBrowser
sc.cfg.ClientStoreTemporaryCredential = ConfigBoolTrue
sc.rest = sr
_, err = authenticate(context.TODO(), sc, []byte{}, []byte{})
if err != nil {
t.Fatalf("failed to run. err: %v", err)
}
}

// To run this test you need to set environment variables in parameters.json to a user with MFA authentication enabled
// Set any other snowflake_test variables needed for database, schema, role for this user
func TestUsernamePasswordMfaCaching(t *testing.T) {
t.Skip("manual test for MFA token caching")

config, err := ParseDSN(dsn)
if err != nil {
t.Fatal("Failed to parse dsn")
}
// connect with MFA authentication
user := os.Getenv("SNOWFLAKE_TEST_MFA_USER")
password := os.Getenv("SNOWFLAKE_TEST_MFA_PASSWORD")
config.User = user
config.Password = password
config.Authenticator = AuthTypeUsernamePasswordMFA
connector := NewConnector(SnowflakeDriver{}, *config)
db := sql.OpenDB(connector)
for i := 0; i < 3; i++ {
// should only be prompted to authenticate first time around.
_, err := db.Query("select current_user()")
if err != nil {
t.Fatal(err)
}
}
}

// To run this test you need to set SNOWFLAKE_TEST_EXT_BROWSER_USER environment variable to an external browser user
// Set any other snowflake_test variables needed for database, schema, role for this user
func TestExternalBrowserCaching(t *testing.T) {
t.Skip("manual test for external browser token caching")

config, err := ParseDSN(dsn)
if err != nil {
t.Fatal("Failed to parse dsn")
}
// connect with external browser authentication
user := os.Getenv("SNOWFLAKE_TEST_EXT_BROWSER_USER")
config.User = user
config.Authenticator = AuthTypeExternalBrowser
connector := NewConnector(SnowflakeDriver{}, *config)
db := sql.OpenDB(connector)
for i := 0; i < 3; i++ {
// should only be prompted to authenticate first time around.
_, err := db.Query("select current_user()")
if err != nil {
t.Fatal(err)
}
}
}
16 changes: 0 additions & 16 deletions connection_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,6 @@ func (sc *snowflakeConn) isClientSessionKeepAliveEnabled() bool {
return strings.Compare(*v, "true") == 0
}

func (sc *snowflakeConn) isClientStoreTemporaryCredential() bool {
v, ok := sc.cfg.Params[clientStoreTemporaryCredential]
if !ok {
return false
}
return strings.Compare(*v, "true") == 0
}

func (sc *snowflakeConn) isClientRequestMfaToken() bool {
v, ok := sc.cfg.Params[clientRequestMfaToken]
if !ok {
return false
}
return strings.Compare(*v, "true") == 0
}

func (sc *snowflakeConn) startHeartBeat() {
if !sc.isClientSessionKeepAliveEnabled() {
return
Expand Down
6 changes: 4 additions & 2 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ type Config struct {

Tracing string // sets logging level

MfaToken string // Internally used to cache the MFA token
IDToken string // Internally used to cache the Id Token for external browser
MfaToken string // Internally used to cache the MFA token
IDToken string // Internally used to cache the Id Token for external browser
ClientRequestMfaToken ConfigBool // Internally used for MFa connection caching
ClientStoreTemporaryCredential ConfigBool // Internall used for ID token connection caching
}

// ocspMode returns the OCSP mode in string INSECURE, FAIL_OPEN, FAIL_CLOSED
Expand Down