Skip to content

Commit

Permalink
improvements
Browse files Browse the repository at this point in the history
- add sso login support
- add okta login support
- add private link support
- code restructure
- add debug messages
  • Loading branch information
Unravel-Andy committed Jun 24, 2023
1 parent 0c4a583 commit 6482847
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 123 deletions.
16 changes: 11 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,23 @@ The following arguments are required:
* `--target_role`: The name of the target role.

The following arguments are optional:
* `--source_login_method`: The login method for the source account. Possible options are password (default), oauth, or keypair.
* `--target_login_method`: The login method for the target account. Possible options are password (default), oauth, or keypair.
* `--source_login_method`: The login method for the source account. Possible options are password (default), oauth, sso, or keypair.
* `--target_login_method`: The login method for the target account. Possible options are password (default), oauth, sso, or keypair.
* `--source_private_link`: The private link for the source account e.g testaccount.us-east-1.privatelink.snowflakecomputing.com.
* `--target_private_link`: The private link for the target account e.g testaccount.us-east-1.privatelink.snowflakecomputing.com.
* `--source_okta_url`: The okta url for the source account e.g https://testaccount.okta.com.
* `--target_okta_url`: The okta url for the target account e.g https://testaccount.okta.com.
* `--source_passcode`: Your source Snowflake account MFA password.
* `--target_passcode`: Your target Snowflake account MFA password.
* `--stage`: The name of the stage. Default is `unravel_stage`.
* `--out`: The directory to save output files. Default is current directory.
* `--file_format`: The name of the file format. Default is `unravel_file_format`.
* `--debug`: This flag adds debug messages when set.
* `--debug`: Prints debug messages when set.
* `--save-sql`: This flag saves all queries as SQL files instead of running them.
* `--disable-cleanup`: This will skip the local temporary file cleanup process
* `--look-back-days`: The number of days to look back for account usage information. Default is 15 days.

If any of the required arguments are missing, you will be prompted to enter them.
**If any of the required arguments are missing, you will be prompted to enter them.**

The script will also replace `-` with `_` for the value of `--stage` argument.

Expand All @@ -67,7 +72,7 @@ xattr -d com.apple.quarantine <path_to_the_binary_directory>/snowflake-data-loa
# Linux login with private keypair
./snowflake-data-loader \
--source_login_method keypair \
--target_login_method keypair \
--target_login_method password \
--source_user <source_user> \
--private_key_path <private_key_path> \
--source_account <source_account> \
Expand All @@ -76,6 +81,7 @@ xattr -d com.apple.quarantine <path_to_the_binary_directory>/snowflake-data-loa
--source_schema <source_schema> \
--source_role <source_role> \
--target_user <target_user> \
--target_password <target_password> \
--target_account <target_account> \
--target_warehouse <target_warehouse> \
--target_database <target_database> \
Expand Down
209 changes: 118 additions & 91 deletions args.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

var (
allowedLoginMethods = []string{"password", "oauth", "keypair"}
allowedLoginMethods = []string{"password", "oauth", "keypair", "sso", "okta"}
)

type Args struct {
Expand All @@ -18,24 +18,28 @@ type Args struct {
SrcUser string
SrcPassword string
// MFA passcode
SrcPasscode string
SrcAccount string
SrcWarehouse string
SrcDatabase string
SrcSchema string
SrcRole string
SrcPasscode string
SrcAccount string
SrcWarehouse string
SrcDatabase string
SrcSchema string
SrcRole string
SrcPrivateLink string
SrcOktaURL string

// arguments for target snowflake account
TgtLoginMethod string
TgtUser string
TgtPassword string
// MFA passcode
TgtPasscode string
TgtAccount string
TgtWarehouse string
TgtDatabase string
TgtSchema string
TgtRole string
TgtPasscode string
TgtAccount string
TgtWarehouse string
TgtDatabase string
TgtSchema string
TgtRole string
TgtPrivateLink string
TgtOktaURL string

// Other arguments
Stage string
Expand All @@ -51,27 +55,31 @@ type Args struct {

func getArgs() Args {
// login method
srcLoginMethod := flag.String("source_login_method", "password", "source login method: password, oauth, or keypair")
tgtLoginMethod := flag.String("target_login_method", "password", "target login method: password, oauth, or keypair")
srcLoginMethod := flag.String("source_login_method", "password", "source login method: password, oauth, keypair, or sso")
tgtLoginMethod := flag.String("target_login_method", "password", "target login method: password, oauth, keypair, or sso")
// arguments for source snowflake account
srcUser := flag.String("source_user", "", "source Snowflake account username")
srcPassword := flag.String("source_password", "", "source Snowflake account password")
srcPassword := flag.String("source_password", "", "source Snowflake account password/oauth token")
srcPasscode := flag.String("source_passcode", "", "source Snowflake account MFA passcode")
srcAccount := flag.String("source_account", "", "source Snowflake account id")
srcWarehouse := flag.String("source_warehouse", "", "source warehouse")
srcDatabase := flag.String("source_database", "", "source database")
srcSchema := flag.String("source_schema", "", "source schema")
srcRole := flag.String("source_role", "", "source role")
srcPrivateLink := flag.String("source_private_link", "", "source account private link")
srcOktaURL := flag.String("source_okta_url", "", "source account okta url")

// arguments for target snowflake account
tgtUser := flag.String("target_user", "", "target Snowflake account username")
tgtPassword := flag.String("target_password", "", "target Snowflake account password")
tgtPassword := flag.String("target_password", "", "target Snowflake account password/oauth token")
tgtPasscode := flag.String("target_passcode", "", "target Snowflake account MFA passcode")
tgtAccount := flag.String("target_account", "", "target Snowflake account id")
tgtWarehouse := flag.String("target_warehouse", "", "target warehouse")
tgtDatabase := flag.String("target_database", "", "target database")
tgtSchema := flag.String("target_schema", "", "target schema")
tgtRole := flag.String("target_role", "", "target role")
tgtPrivateLink := flag.String("target_private_link", "", "target account private link")
tgtOktaURL := flag.String("target_okta_url", "", "target account okta url")

// Other arguments
actions := flag.String("actions", "download,upload", "actions to perform: download, upload")
Expand All @@ -85,107 +93,126 @@ func getArgs() Args {
lookBackDays := flag.Uint("look-back-days", 15, "number of days to look back for data to download to download all data set it to 0")
flag.Parse()

// prompt for missing args
if *srcLoginMethod == "" {
promptInput("Source login method: ", srcLoginMethod)
} else if !contains(allowedLoginMethods, *srcLoginMethod) {
log.Fatalf("Invalid source login method: %s must be %v", *srcLoginMethod, allowedLoginMethods)
args := Args{
SrcLoginMethod: *srcLoginMethod,
TgtLoginMethod: *tgtLoginMethod,
SrcUser: *srcUser,
SrcPassword: *srcPassword,
SrcPasscode: *srcPasscode,
SrcAccount: *srcAccount,
SrcWarehouse: *srcWarehouse,
SrcDatabase: *srcDatabase,
SrcSchema: *srcSchema,
SrcRole: *srcRole,
SrcPrivateLink: *srcPrivateLink,
SrcOktaURL: *srcOktaURL,
TgtUser: *tgtUser,
TgtPassword: *tgtPassword,
TgtPasscode: *tgtPasscode,
TgtAccount: *tgtAccount,
TgtWarehouse: *tgtWarehouse,
TgtDatabase: *tgtDatabase,
TgtSchema: *tgtSchema,
TgtRole: *tgtRole,
TgtPrivateLink: *tgtPrivateLink,
TgtOktaURL: *tgtOktaURL,
Stage: *stage,
Out: *out,
FileFormat: *fileFormat,
Debug: *debug,
SaveSql: *saveSql,
DisableCleanup: *disableCleanup,
Actions: strings.Split(*actions, ","),
PrivateKeyPath: *privateKeyPath,
LookBackDays: *lookBackDays,
}
if *actions == "" {
promptInput("Actions to perform: ", actions)
argsCheck(&args)
return args
}

func argsCheck(args *Args) {
// prompt for missing args
if args.SrcLoginMethod == "" {
args.SrcLoginMethod = promptInput("Source login method: ")
} else if !contains(allowedLoginMethods, args.SrcLoginMethod) {
log.Fatalf("Invalid source login method: %s must be %v", args.SrcLoginMethod, allowedLoginMethods)
} else if args.SrcLoginMethod == "okta" && args.SrcOktaURL == "" {
args.SrcOktaURL = promptInput("Source account okta url: ")
}
if *srcAccount == "" && *saveSql == false {
promptInput("Source Snowflake account ID: ", srcAccount)

if args.TgtLoginMethod == "" {
args.TgtLoginMethod = promptInput("Target login method: ")
} else if !contains(allowedLoginMethods, args.TgtLoginMethod) {
log.Fatalf("Invalid target login method: %s must be %v", args.TgtLoginMethod, allowedLoginMethods)
} else if args.TgtLoginMethod == "okta" && args.TgtOktaURL == "" {
args.TgtOktaURL = promptInput("Target account okta url: ")
}
if len(args.Actions) == 0 {
args.Actions = strings.Split(promptInput("Actions to perform: "), ",")
}
if args.SrcAccount == "" && args.SaveSql == false {
args.SrcAccount = promptInput("Source Snowflake account ID: ")

}
if *srcUser == "" && *saveSql == false {
promptInput("Source Snowflake account username: ", srcUser)
if args.SrcUser == "" && args.SaveSql == false {
args.SrcUser = promptInput("Source Snowflake account username: ")
}
if *srcPassword == "" && *saveSql == false && *srcLoginMethod == "password" {
promptSecureInput("Source password: ", srcPassword)
if args.SrcPassword == "" && args.SaveSql == false && (args.SrcLoginMethod == "password" || args.SrcLoginMethod == "oauth") {
args.SrcPassword = promptSecureInput("Source password/oauth token: ")
}
if *privateKeyPath == "" && (*srcLoginMethod == "keypair" || *tgtLoginMethod == "keypair") && *saveSql == false {
promptInput("Private key path: ", privateKeyPath)
if args.PrivateKeyPath == "" && (args.SrcLoginMethod == "keypair" || args.TgtLoginMethod == "keypair") && args.SaveSql == false {
args.PrivateKeyPath = promptInput("Private key path: ")
}
if *srcDatabase == "" && *saveSql == false {
promptInput("Source database: ", srcDatabase)
if args.SrcDatabase == "" && args.SaveSql == false {
args.SrcDatabase = promptInput("Source database: ")
}
if *srcSchema == "" && *saveSql == false {
promptInput("Source schema: ", srcSchema)
if args.SrcSchema == "" && args.SaveSql == false {
args.SrcSchema = promptInput("Source schema: ")
}
if *srcWarehouse == "" {
promptInput("Source warehouse: ", srcWarehouse)
if args.SrcWarehouse == "" {
args.SrcWarehouse = promptInput("Source warehouse: ")
}
if *srcRole == "" && *saveSql == false {
promptInput("Source role: ", srcRole)
if args.SrcRole == "" && args.SaveSql == false {
args.SrcRole = promptInput("Source role: ")
}

if *tgtAccount == "" && *saveSql == false {
promptInput("Target Snowflake account ID: ", tgtAccount)
if args.TgtAccount == "" && args.SaveSql == false {
args.TgtAccount = promptInput("Target Snowflake account ID: ")
}
if *tgtUser == "" && *saveSql == false {
promptInput("Target Snowflake account username: ", tgtUser)
if args.TgtUser == "" && args.SaveSql == false {
args.TgtUser = promptInput("Target Snowflake account username: ")
}
if *tgtPassword == "" && *saveSql == false && *tgtLoginMethod == "password" {
promptSecureInput("Target password: ", tgtPassword)
if args.TgtPassword == "" && args.SaveSql == false && (args.TgtLoginMethod == "password" || args.TgtLoginMethod == "oauth") {
args.TgtPassword = promptSecureInput("Target password/oauth token: ")
}
if *tgtDatabase == "" && *saveSql == false {
promptInput("Target database: ", tgtDatabase)
if args.TgtDatabase == "" && args.SaveSql == false {
args.TgtDatabase = promptInput("Target database: ")
}
if *tgtSchema == "" && *saveSql == false {
promptInput("Target schema: ", tgtSchema)
if args.TgtSchema == "" && args.SaveSql == false {
args.TgtSchema = promptInput("Target schema: ")
}
if *tgtWarehouse == "" {
promptInput("Target warehouse: ", tgtWarehouse)
if args.TgtWarehouse == "" {
args.TgtWarehouse = promptInput("Target warehouse: ")
}
if *tgtRole == "" && *saveSql == false {
promptInput("Target role: ", tgtRole)
if args.TgtRole == "" && args.SaveSql == false {
args.TgtRole = promptInput("Target role: ")
}
if *out == "" {
*out, _ = os.Getwd()
if args.Out == "" {
args.Out, _ = os.Getwd()
} else {
// ensure output directory exists and is directory
stat, err := os.Stat(*out)
stat, err := os.Stat(args.Out)
if os.IsNotExist(err) {
log.Fatalf("Output directory %s does not exist", *out)
log.Fatalf("Output directory %s does not exist", args.Out)
}
if !stat.IsDir() {
log.Fatalf("%s is not a directory", *out)
log.Fatalf("%s is not a directory", args.Out)
}
}
if *lookBackDays == 0 {
*lookBackDays = 365
if args.LookBackDays == 0 {
args.LookBackDays = 365
}
if *debug {
if args.Debug {
log.SetLevel(log.DebugLevel)
}
return Args{
SrcLoginMethod: *srcLoginMethod,
TgtLoginMethod: *tgtLoginMethod,
SrcUser: *srcUser,
SrcPassword: *srcPassword,
SrcPasscode: *srcPasscode,
SrcAccount: *srcAccount,
SrcWarehouse: *srcWarehouse,
SrcDatabase: *srcDatabase,
SrcSchema: *srcSchema,
SrcRole: *srcRole,
TgtUser: *tgtUser,
TgtPassword: *tgtPassword,
TgtPasscode: *tgtPasscode,
TgtAccount: *tgtAccount,
TgtWarehouse: *tgtWarehouse,
TgtDatabase: *tgtDatabase,
TgtSchema: *tgtSchema,
TgtRole: *tgtRole,
Stage: *stage,
Out: *out,
FileFormat: *fileFormat,
Debug: *debug,
SaveSql: *saveSql,
DisableCleanup: *disableCleanup,
Actions: strings.Split(*actions, ","),
PrivateKeyPath: *privateKeyPath,
LookBackDays: *lookBackDays,
}
}
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/unraveldata-org/snowflake-data-loader
go 1.20

require (
github.com/jedib0t/go-pretty/v6 v6.4.6
github.com/sirupsen/logrus v1.9.3
github.com/snowflakedb/gosnowflake v1.6.22
golang.org/x/term v0.9.0
Expand Down Expand Up @@ -44,11 +45,13 @@ require (
github.com/klauspost/asmfmt v1.3.2 // indirect
github.com/klauspost/compress v1.16.6 // indirect
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
github.com/mattn/go-runewidth v0.0.13 // indirect
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect
github.com/mtibben/percent v0.2.1 // indirect
github.com/pierrec/lz4/v4 v4.1.18 // indirect
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
golang.org/x/crypto v0.10.0 // indirect
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect
Expand Down
Loading

0 comments on commit 6482847

Please sign in to comment.