Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: SHOW USERS output with insufficient privileges #2815

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@ dist/

# Test environment variables
test.env

# deps
vendor/
98 changes: 98 additions & 0 deletions pkg/sdk/snowflakesql/bool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package snowflakesql

import (
"database/sql/driver"
"fmt"
"reflect"
)

// Bool is inspired by sql.NullBool, but it will handle `"null"` passed as value, too
type Bool struct {
Bool bool
Valid bool // Valid is true if Bool is not NULL
}

// Scan implements the [Scanner] interface.
func (n *Bool) Scan(value any) error {
switch value := value.(type) {
case nil: // untyped nil
n.Bool, n.Valid = false, false
return nil
case bool:
return n.fromBool(&value)
case *bool:
return n.fromBool(value)
case string:
return n.fromString(&value)
case *string:
return n.fromString(value)
default:
return n.convertAny(value)
}
}

func (n *Bool) fromBool(value *bool) error {
if n.Valid = value != nil; n.Valid {
n.Bool = *value
} else {
n.Bool = false
}
return nil
}

func (n *Bool) fromString(value *string) error {
if value == nil {
n.Bool, n.Valid = false, false
return nil
}

str := *value
if str == "null" {
// Sadly, we have to do this, as Snowflake can return `"null"` for boolean fields.
// E.g., `disabled` field in `SHOW USERS` output.
n.Bool, n.Valid = false, false
return nil
}

return n.convertAny(str)
}

func (n *Bool) convertAny(value any) error {
v := reflect.ValueOf(value)
for v.Kind() == reflect.Pointer {
if v.IsNil() {
// nil pointer to some value
n.Bool, n.Valid = false, false
return nil
}
v = v.Elem()
}

if !v.CanInterface() {
// shouldn't be here, but fail without panic
n.Bool, n.Valid = false, false
return fmt.Errorf("can't convert %v (%T) into bool", value, value)
}

res, err := driver.Bool.ConvertValue(v.Interface())
if err != nil {
n.Bool, n.Valid = false, false
return err
}

n.Bool, n.Valid = res.(bool)
return nil
}

// Value implements the [driver.Valuer] interface.
func (n Bool) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Bool, nil
}

// BoolValue returns either the default bool (false) if the Bool.Valid != true, of the underlying Bool.Value.
func (n Bool) BoolValue() bool {
return n.Valid && n.Bool
}
97 changes: 97 additions & 0 deletions pkg/sdk/snowflakesql/bool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package snowflakesql

import (
"errors"
"fmt"
"io"
"testing"

"github.com/stretchr/testify/assert"
)

func TestBool_Scan(t *testing.T) {
type testCase struct {
from any
expected Bool
err error
}

for _, tc := range []testCase{
{
// passing nil will result in invalid Bool without errors
from: nil,
expected: Bool{},
err: nil,
},
{
from: "1",
expected: Bool{Valid: true, Bool: true},
},
{
from: &[]string{"1"}[0], // pointer to string
expected: Bool{Valid: true, Bool: true},
},
{
from: "2",
expected: Bool{},
err: errors.New("sql/driver: couldn't convert \"2\" into type bool"),
},
{
from: &[]string{"2"}[0], // pointer to string
expected: Bool{},
err: errors.New("sql/driver: couldn't convert \"2\" into type bool"),
},
{
from: "",
expected: Bool{},
err: errors.New("sql/driver: couldn't convert \"\" into type bool"),
},
{
from: &[]string{""}[0], // pointer to string
expected: Bool{},
err: errors.New("sql/driver: couldn't convert \"\" into type bool"),
},
{
from: true,
expected: Bool{Valid: true, Bool: true},
},
{
from: &[]bool{true}[0], // pointer to bool
expected: Bool{Valid: true, Bool: true},
},
{
from: false,
expected: Bool{Valid: true, Bool: false},
},
{
from: &[]bool{false}[0], // pointer to bool
expected: Bool{Valid: true, Bool: false},
},
{
from: int64(123),
expected: Bool{},
err: errors.New("sql/driver: couldn't convert 123 into type bool"),
},
{
from: &[]int64{123}[0], // pointer to int64
expected: Bool{},
err: errors.New("sql/driver: couldn't convert 123 into type bool"),
},
{
from: io.Copy,
expected: Bool{},
err: errors.New("(func(io.Writer, io.Reader) (int64, error)) into type bool"),
},
} {
t.Run(fmt.Sprint(tc.from), func(t *testing.T) {
var res Bool
err := res.Scan(tc.from)
if tc.err == nil {
assert.NoError(t, err)
} else {
assert.ErrorContains(t, err, tc.err.Error())
}
assert.Exactly(t, tc.expected, res)
})
}
}
15 changes: 15 additions & 0 deletions pkg/sdk/testint/users_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

func TestInt_UsersShow(t *testing.T) {
client := testClient(t)
secondaryClient := testSecondaryClient(t)
ctx := testContext(t)

userTest, userCleanup := testClientHelper().User.CreateUserWithName(t, "USER_FOO")
Expand Down Expand Up @@ -74,6 +75,20 @@ func TestInt_UsersShow(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, 1, len(users))
})

t.Run("with like options", func(t *testing.T) {
users, err := secondaryClient.Users.Show(ctx, nil)
require.NoError(t, err)
found := 0
// we can't compare via assert.Contains as not all the fields will be filled int
for _, u := range users {
if u.Name == userTest.Name || u.Name == userTest2.Name {
found++
}
}
assert.Equal(t, 2, found)
assert.Equal(t, 2, len(users))
})
}

func TestInt_UserCreate(t *testing.T) {
Expand Down
66 changes: 34 additions & 32 deletions pkg/sdk/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"time"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/snowflakesql"
)

var (
Expand Down Expand Up @@ -60,51 +62,51 @@ type User struct {
HasRsaPublicKey bool
}
type userDBRow struct {
Name string `db:"name"`
CreatedOn time.Time `db:"created_on"`
LoginName string `db:"login_name"`
DisplayName sql.NullString `db:"display_name"`
FirstName sql.NullString `db:"first_name"`
LastName sql.NullString `db:"last_name"`
Email sql.NullString `db:"email"`
MinsToUnlock sql.NullString `db:"mins_to_unlock"`
DaysToExpiry sql.NullString `db:"days_to_expiry"`
Comment sql.NullString `db:"comment"`
Disabled bool `db:"disabled"`
MustChangePassword bool `db:"must_change_password"`
SnowflakeLock bool `db:"snowflake_lock"`
DefaultWarehouse sql.NullString `db:"default_warehouse"`
DefaultNamespace string `db:"default_namespace"`
DefaultRole string `db:"default_role"`
DefaultSecondaryRoles string `db:"default_secondary_roles"`
ExtAuthnDuo bool `db:"ext_authn_duo"`
ExtAuthnUid string `db:"ext_authn_uid"`
MinsToBypassMfa string `db:"mins_to_bypass_mfa"`
Owner string `db:"owner"`
LastSuccessLogin sql.NullTime `db:"last_success_login"`
ExpiresAtTime sql.NullTime `db:"expires_at_time"`
LockedUntilTime sql.NullTime `db:"locked_until_time"`
HasPassword bool `db:"has_password"`
HasRsaPublicKey bool `db:"has_rsa_public_key"`
Name string `db:"name"`
CreatedOn time.Time `db:"created_on"`
LoginName string `db:"login_name"`
DisplayName sql.NullString `db:"display_name"`
FirstName sql.NullString `db:"first_name"`
LastName sql.NullString `db:"last_name"`
Email sql.NullString `db:"email"`
MinsToUnlock sql.NullString `db:"mins_to_unlock"`
DaysToExpiry sql.NullString `db:"days_to_expiry"`
Comment sql.NullString `db:"comment"`
Disabled snowflakesql.Bool `db:"disabled"`
MustChangePassword snowflakesql.Bool `db:"must_change_password"`
SnowflakeLock snowflakesql.Bool `db:"snowflake_lock"`
DefaultWarehouse sql.NullString `db:"default_warehouse"`
DefaultNamespace string `db:"default_namespace"`
DefaultRole string `db:"default_role"`
DefaultSecondaryRoles string `db:"default_secondary_roles"`
ExtAuthnDuo snowflakesql.Bool `db:"ext_authn_duo"`
ExtAuthnUid string `db:"ext_authn_uid"`
MinsToBypassMfa string `db:"mins_to_bypass_mfa"`
Owner string `db:"owner"`
LastSuccessLogin sql.NullTime `db:"last_success_login"`
ExpiresAtTime sql.NullTime `db:"expires_at_time"`
LockedUntilTime sql.NullTime `db:"locked_until_time"`
HasPassword snowflakesql.Bool `db:"has_password"`
HasRsaPublicKey snowflakesql.Bool `db:"has_rsa_public_key"`
}

func (row userDBRow) convert() *User {
user := &User{
Name: row.Name,
CreatedOn: row.CreatedOn,
LoginName: row.LoginName,
Disabled: row.Disabled,
MustChangePassword: row.MustChangePassword,
SnowflakeLock: row.SnowflakeLock,
Disabled: row.Disabled.BoolValue(),
MustChangePassword: row.MustChangePassword.BoolValue(),
SnowflakeLock: row.SnowflakeLock.BoolValue(),
DefaultNamespace: row.DefaultNamespace,
DefaultRole: row.DefaultRole,
DefaultSecondaryRoles: row.DefaultSecondaryRoles,
ExtAuthnDuo: row.ExtAuthnDuo,
ExtAuthnDuo: row.ExtAuthnDuo.BoolValue(),
ExtAuthnUid: row.ExtAuthnUid,
MinsToBypassMfa: row.MinsToBypassMfa,
Owner: row.Owner,
HasPassword: row.HasPassword,
HasRsaPublicKey: row.HasRsaPublicKey,
HasPassword: row.HasPassword.BoolValue(),
HasRsaPublicKey: row.HasRsaPublicKey.BoolValue(),
}
if row.DisplayName.Valid {
user.DisplayName = row.DisplayName.String
Expand Down