Skip to content

Commit

Permalink
feat: Python support for functions (#1069)
Browse files Browse the repository at this point in the history
* adding python language support and runtime_version attribute

* adding packages attribute for python / java

* updated function_acceptance_test

* testing by adding warehouse attribute

* testing by adding warehouse attribute

* updating null input and return behaviour

* updating test functions

* running go fmt

* changing runtime_version type from float64 to string

* updating function acceptance test

* updating function acceptance test

* adding sql in the list and updating the description of attributes

* updating function acceptance test

* minor changes

* testing

* testing

* adding warehouse attribute to the provider config

* updating docs and adding examples

* fixing conflicts

* fixing docs

* updating description of warehouse attribute

Co-authored-by: Scott Winkler <[email protected]>
  • Loading branch information
sfc-gh-kumaurya and sfc-gh-swinkler authored Jun 28, 2022
1 parent 88f4d44 commit bab729a
Show file tree
Hide file tree
Showing 13 changed files with 165 additions and 77 deletions.
2 changes: 2 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ provider "snowflake" {
// optional
role = "..."
host = "..."
warehouse = "..."
}
```

Expand Down Expand Up @@ -64,6 +65,7 @@ provider "snowflake" {
- `private_key_path` (String, Sensitive) Path to a private key for using keypair authentication. Cannot be used with `browser_auth`, `oauth_access_token` or `password`. Can be source from `SNOWFLAKE_PRIVATE_KEY_PATH` environment variable.
- `region` (String) [Snowflake region](https://docs.snowflake.com/en/user-guide/intro-regions.html) to use. Can be source from the `SNOWFLAKE_REGION` environment variable.
- `role` (String) Snowflake role to use for operations. If left unset, default role for user will be used. Can come from the `SNOWFLAKE_ROLE` environment variable.
- `warehouse` (String) Sets the default warehouse. Optional. Can be sourced from SNOWFLAKE_WAREHOUSE enviornment variable.

## Authentication

Expand Down
70 changes: 68 additions & 2 deletions docs/resources/function.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,67 @@ description: |-




## Example Usage

```terraform
// Provider configuration
provider "snowflake" {
region = "REGION" // Default is "us-west-2"
username = "USERNAME"
account = "ACCOUNT"
password = "PASSWORD"
role = "MY_ROLE"
warehouse = "MY_WH" // Optional attribute, some resources (e.g. Python UDFs)' require a warehouse to create and can also be set optionally from the `SNOWFLAKE_WAREHOUSE` environment variable
}
// Create database
resource "snowflake_database" "db" {
name = "MY_DB"
data_retention_days = 1
}
// Create schema
resource "snowflake_schema" "schema" {
database = snowflake_database.db.name
name = "MY_SCHEMA"
data_retention_days = 1
}
// Example for Java language
resource "snowflake_function" "test_funct_java" {
name = "my_java_func"
database = "MY_DB"
schema = "MY_SCHEMA"
arguments {
name = "arg1"
type = "number"
}
comment = "Example for java language"
return_type = "varchar"
language = "java"
handler = "CoolFunc.test"
statement = "class CoolFunc {public static String test(int n) {return \"hello!\";}}"
}
// Example for Python language
resource "snowflake_function" "python_test" {
name = "MY_PYTHON_FUNC"
database = "MY_DB"
schema = "MY_SCHEMA"
arguments {
name = "arg1"
type = "number"
}
comment = "Example for Python language"
return_type = "NUMBER(38,0)"
null_input_behavior = "CALLED ON NULL INPUT"
return_behavior = "VOLATILE"
language = "python"
runtime_version = "3.8"
handler = "add_py"
statement = "def add_py(i): return i+1"
}
```

<!-- schema generated by tfplugindocs -->
## Schema
Expand All @@ -35,7 +95,6 @@ description: |-
- `return_behavior` (String) Specifies the behavior of the function when returning results
- `runtime_version` (String) Required for Python functions. Specifies Python runtime version.
- `target_path` (String) The target path for the Java / Python functions. For Java, it is the path of compiled jar files and for the Python it is the path of the Python files.
- `warehouse` (String) The warehouse in which to create the function. Only for Python language.

### Read-Only

Expand All @@ -49,4 +108,11 @@ Required:
- `name` (String) The argument name
- `type` (String) The argument type

## Import

Import is supported using the following syntax:

```shell
# format is database name | schema name | function name | <list of arg types, separated with '-'>
terraform import snowflake_function.example 'dbName|schemaName|functionName|varchar-varchar-varchar'
```
2 changes: 1 addition & 1 deletion docs/resources/procedure.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ description: |-
## Example Usage

```terraform
resource "snowflake_schema" "db" {
resource "snowflake_database" "db" {
name = "MYDB"
data_retention_days = 1
}
Expand Down
1 change: 1 addition & 0 deletions examples/provider/provider.tf
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ provider "snowflake" {
// optional
role = "..."
host = "..."
warehouse = "..."
}
2 changes: 2 additions & 0 deletions examples/resources/snowflake_function/import.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# format is database name | schema name | function name | <list of arg types, separated with '-'>
terraform import snowflake_function.example 'dbName|schemaName|functionName|varchar-varchar-varchar'
57 changes: 57 additions & 0 deletions examples/resources/snowflake_function/resource.tf
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Provider configuration
provider "snowflake" {
region = "REGION" // Default is "us-west-2"
username = "USERNAME"
account = "ACCOUNT"
password = "PASSWORD"
role = "MY_ROLE"
warehouse = "MY_WH" // Optional attribute, some resources (e.g. Python UDFs)' require a warehouse to create and can also be set optionally from the `SNOWFLAKE_WAREHOUSE` environment variable
}

// Create database
resource "snowflake_database" "db" {
name = "MY_DB"
data_retention_days = 1
}

// Create schema
resource "snowflake_schema" "schema" {
database = snowflake_database.db.name
name = "MY_SCHEMA"
data_retention_days = 1
}

// Example for Java language
resource "snowflake_function" "test_funct_java" {
name = "my_java_func"
database = "MY_DB"
schema = "MY_SCHEMA"
arguments {
name = "arg1"
type = "number"
}
comment = "Example for java language"
return_type = "varchar"
language = "java"
handler = "CoolFunc.test"
statement = "class CoolFunc {public static String test(int n) {return \"hello!\";}}"
}

// Example for Python language
resource "snowflake_function" "python_test" {
name = "MY_PYTHON_FUNC"
database = "MY_DB"
schema = "MY_SCHEMA"
arguments {
name = "arg1"
type = "number"
}
comment = "Example for Python language"
return_type = "NUMBER(38,0)"
null_input_behavior = "CALLED ON NULL INPUT"
return_behavior = "VOLATILE"
language = "python"
runtime_version = "3.8"
handler = "add_py"
statement = "def add_py(i): return i+1"
}
2 changes: 1 addition & 1 deletion examples/resources/snowflake_procedure/resource.tf
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
resource "snowflake_schema" "db" {
resource "snowflake_database" "db" {
name = "MYDB"
data_retention_days = 1
}
Expand Down
19 changes: 16 additions & 3 deletions pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ func Provider() *schema.Provider {
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_HOST", nil),
},
"warehouse": {
Type: schema.TypeString,
Description: "Sets the default warehouse. Optional. Can be sourced from SNOWFLAKE_WAREHOUSE enviornment variable.",
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_WAREHOUSE", nil),
},
},
ResourcesMap: getResources(),
DataSourcesMap: getDataSources(),
Expand Down Expand Up @@ -284,6 +290,7 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) {
oauthEndpoint := s.Get("oauth_endpoint").(string)
oauthRedirectURL := s.Get("oauth_redirect_url").(string)
host := s.Get("host").(string)
warehouse := s.Get("warehouse").(string)

if oauthRefreshToken != "" {
accessToken, err := GetOauthAccessToken(oauthEndpoint, oauthClientID, oauthClientSecret, GetOauthData(oauthRefreshToken, oauthRedirectURL))
Expand All @@ -305,6 +312,7 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) {
region,
role,
host,
warehouse,
)
if err != nil {
return nil, errors.Wrap(err, "could not build dsn for snowflake connection")
Expand All @@ -329,9 +337,10 @@ func DSN(
oauthAccessToken,
region,
role,
host string) (string, error) {
host,
warehouse string) (string, error) {

// us-west-2 is their default region, but if you actually specify that it won't trigger their default code
// us-west-2 is Snowflake's default region, but if you actually specify that it won't trigger the default code
// https://github.com/snowflakedb/gosnowflake/blob/52137ce8c32eaf93b0bd22fc5c7297beff339812/dsn.go#L61
if region == "us-west-2" {
region = ""
Expand All @@ -350,6 +359,11 @@ func DSN(
config.Host = host
}

// If warehouse is set
if warehouse != "" {
config.Warehouse = warehouse
}

if privateKeyPath != "" {
privateKeyBytes, err := ReadPrivateKeyFile(privateKeyPath)
if err != nil {
Expand Down Expand Up @@ -454,7 +468,6 @@ func GetOauthRequest(dataContent io.Reader, endPoint, clientId, clientSecret str
request.SetBasicAuth(clientId, clientSecret)
request.Header.Set("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8")
return request, nil

}

func GetOauthAccessToken(
Expand Down
15 changes: 8 additions & 7 deletions pkg/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,27 @@ func TestDSN(t *testing.T) {
browserAuth bool
region,
role,
host string
host,
warehouse string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{"simple", args{"acct", "user", "pass", false, "region", "role", ""},
{"simple", args{"acct", "user", "pass", false, "region", "role", "", ""},
"user:[email protected]:443?ocspFailOpen=true&region=region&role=role&validateDefaultParameters=true", false},
{"us-west-2 special case", args{"acct2", "user2", "pass2", false, "us-west-2", "role2", ""},
{"us-west-2 special case", args{"acct2", "user2", "pass2", false, "us-west-2", "role2", "", ""},
"user2:[email protected]:443?ocspFailOpen=true&role=role2&validateDefaultParameters=true", false},
{"customhostwregion", args{"acct3", "user3", "pass3", false, "", "role3", "zha123.us-east-1.privatelink.snowflakecomputing.com"},
{"customhostwregion", args{"acct3", "user3", "pass3", false, "", "role3", "zha123.us-east-1.privatelink.snowflakecomputing.com", ""},
"user3:[email protected]:443?account=acct3&ocspFailOpen=true&role=role3&validateDefaultParameters=true", false},
{"customhostignoreregion", args{"acct4", "user4", "pass4", false, "fakeregion", "role4", "zha1234.us-east-1.privatelink.snowflakecomputing.com"},
{"customhostignoreregion", args{"acct4", "user4", "pass4", false, "fakeregion", "role4", "zha1234.us-east-1.privatelink.snowflakecomputing.com", ""},
"user4:[email protected]:443?account=acct4&ocspFailOpen=true&role=role4&validateDefaultParameters=true", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := provider.DSN(tt.args.account, tt.args.user, tt.args.password, tt.args.browserAuth, "", "", "", "", tt.args.region, tt.args.role, tt.args.host)
got, err := provider.DSN(tt.args.account, tt.args.user, tt.args.password, tt.args.browserAuth, "", "", "", "", tt.args.region, tt.args.role, tt.args.host, tt.args.warehouse)
if (err != nil) != tt.wantErr {
t.Errorf("DSN() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -89,7 +90,7 @@ func TestOAuthDSN(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := provider.DSN(tt.args.account, tt.args.user, "", false, "", "", "", tt.args.oauthAccessToken, tt.args.region, tt.args.role, "")
got, err := provider.DSN(tt.args.account, tt.args.user, "", false, "", "", "", tt.args.oauthAccessToken, tt.args.region, tt.args.role, "", "")

if (err != nil) != tt.wantErr {
t.Errorf("DSN() error = %v, dsn = %v, wantErr %v", err, got, tt.wantErr)
Expand Down
33 changes: 5 additions & 28 deletions pkg/resources/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ var functionSchema = map[string]*schema.Schema{
Required: true,
Description: "Specifies the identifier for the function; does not have to be unique for the schema in which the function is created. Don't use the | character.",
},
"warehouse": {
Type: schema.TypeString,
Optional: true,
Description: "The warehouse in which to create the function. Only for Python language.",
ForceNew: true,
},
"database": {
Type: schema.TypeString,
Required: true,
Expand Down Expand Up @@ -200,29 +194,16 @@ func CreateFunction(d *schema.ResourceData, meta interface{}) error {
builder.WithLanguage(v.(string))
}

// Set optionals, runtime version for python
// Set optionals, runtime version for Python
if v, ok := d.GetOk("runtime_version"); ok {
builder.WithRuntimeVersion(v.(string))
}

// Set optionals, warehouse in which to create the function
if v, ok := d.GetOk("warehouse"); ok {
builder.WithWarehouse(v.(string))
q, err := builder.UseWarehouse()
if err != nil {
return err
}
err = snowflake.Exec(db, q)
if err != nil {
return errors.Wrapf(err, "error using warehouse %v", v.(string))
}
}

if v, ok := d.GetOk("comment"); ok {
builder.WithComment(v.(string))
}

// Set optionals, packages for Java / python
// Set optionals, packages for Java / Python
if _, ok := d.GetOk("packages"); ok {
packages := []string{}
for _, pack := range d.Get("packages").([]interface{}) {
Expand All @@ -231,7 +212,7 @@ func CreateFunction(d *schema.ResourceData, meta interface{}) error {
builder.WithPackages(packages)
}

// Set optionals, imports for Java / python
// Set optionals, imports for Java / Python
if _, ok := d.GetOk("imports"); ok {
imports := []string{}
for _, imp := range d.Get("imports").([]interface{}) {
Expand All @@ -240,12 +221,12 @@ func CreateFunction(d *schema.ResourceData, meta interface{}) error {
builder.WithImports(imports)
}

// handler for Java / python
// handler for Java / Python
if v, ok := d.GetOk("handler"); ok {
builder.WithHandler(v.(string))
}

// target path for Java / python
// target path for Java / Python
if v, ok := d.GetOk("target_path"); ok {
builder.WithTargetPath(v.(string))
}
Expand Down Expand Up @@ -371,10 +352,6 @@ func ReadFunction(d *schema.ResourceData, meta interface{}) error {
if err = d.Set("handler", desc.Value.String); err != nil {
return err
}
case "warehouse":
if err = d.Set("warehouse", desc.Value.String); err != nil {
return err
}
case "target_path":
if err = d.Set("target_path", desc.Value.String); err != nil {
return err
Expand Down
Loading

0 comments on commit bab729a

Please sign in to comment.