Skip to content

Commit

Permalink
feat: add python language support for functions (#1063)
Browse files Browse the repository at this point in the history
* adding python language support and runtime_version attribute

* adding packages attribute for python / java

* updated function_acceptance_test

* testing by adding warehouse attribute

* testing by adding warehouse attribute

* updating null input and return behaviour

* updating test functions

* running go fmt

* changing runtime_version type from float64 to string

* updating function acceptance test

* updating function acceptance test

* adding sql in the list and updating the description of attributes

* updating function acceptance test

* minor changes

* testing

* testing
  • Loading branch information
sfc-gh-kumaurya authored Jun 18, 2022
1 parent d055d4c commit ee4c2c1
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 34 deletions.
11 changes: 7 additions & 4 deletions docs/resources/function.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,21 @@ description: |-
- `name` (String) Specifies the identifier for the function; does not have to be unique for the schema in which the function is created. Don't use the | character.
- `return_type` (String) The return type of the function
- `schema` (String) The schema in which to create the function. Don't use the | character.
- `statement` (String) Specifies the javascript / java / sql code used to create the function.
- `statement` (String) Specifies the javascript / java / sql / python code used to create the function.

### Optional

- `arguments` (Block List) List of the arguments for the function (see [below for nested schema](#nestedblock--arguments))
- `comment` (String) Specifies a comment for the function.
- `handler` (String) the handler method for Java function.
- `imports` (List of String) jar files to import for Java function.
- `handler` (String) The handler method for Java / Python function.
- `imports` (List of String) Imports for Java / Python functions. For Java this a list of jar files, for Python this is a list of Python files.
- `language` (String) The language of the statement
- `null_input_behavior` (String) Specifies the behavior of the function when called with null inputs.
- `packages` (List of String) List of package imports to use for Java / Python functions. For Java, package imports should be of the form: package_name:version_number, where package_name is snowflake_domain:package. For Python use it should be: ('numpy','pandas','xgboost==1.5.0').
- `return_behavior` (String) Specifies the behavior of the function when returning results
- `target_path` (String) the target path for compiled jar file for Java function.
- `runtime_version` (String) Required for Python functions. Specifies Python runtime version.
- `target_path` (String) The target path for the Java / Python functions. For Java, it is the path of compiled jar files and for the Python it is the path of the Python files.
- `warehouse` (String) The warehouse in which to create the function. Only for Python language.

### Read-Only

Expand Down
80 changes: 71 additions & 9 deletions pkg/resources/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@ import (
"github.com/pkg/errors"
)

var languages = []string{"javascript", "java", "sql"}
var languages = []string{"javascript", "java", "sql", "python"}

var functionSchema = map[string]*schema.Schema{
"name": {
Type: schema.TypeString,
Required: true,
Description: "Specifies the identifier for the function; does not have to be unique for the schema in which the function is created. Don't use the | character.",
},
"warehouse": {
Type: schema.TypeString,
Optional: true,
Description: "The warehouse in which to create the function. Only for Python language.",
ForceNew: true,
},
"database": {
Type: schema.TypeString,
Required: true,
Expand Down Expand Up @@ -68,7 +74,7 @@ var functionSchema = map[string]*schema.Schema{
"statement": {
Type: schema.TypeString,
Required: true,
Description: "Specifies the javascript / java / sql code used to create the function.",
Description: "Specifies the javascript / java / sql / python code used to create the function.",
ForceNew: true,
DiffSuppressFunc: DiffSuppressStatement,
},
Expand Down Expand Up @@ -102,26 +108,41 @@ var functionSchema = map[string]*schema.Schema{
Default: "user-defined function",
Description: "Specifies a comment for the function.",
},
"runtime_version": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Description: "Required for Python functions. Specifies Python runtime version.",
},
"packages": {
Type: schema.TypeList,
Elem: &schema.Schema{
Type: schema.TypeString,
},
Optional: true,
ForceNew: true,
Description: "List of package imports to use for Java / Python functions. For Java, package imports should be of the form: package_name:version_number, where package_name is snowflake_domain:package. For Python use it should be: ('numpy','pandas','xgboost==1.5.0').",
},
"imports": {
Type: schema.TypeList,
Elem: &schema.Schema{
Type: schema.TypeString,
},
Optional: true,
ForceNew: true,
Description: "jar files to import for Java function.",
Description: "Imports for Java / Python functions. For Java this a list of jar files, for Python this is a list of Python files.",
},
"handler": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Description: "the handler method for Java function.",
Description: "The handler method for Java / Python function.",
},
"target_path": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
Description: "the target path for compiled jar file for Java function.",
Description: "The target path for the Java / Python functions. For Java, it is the path of compiled jar files and for the Python it is the path of the Python files.",
},
}

Expand Down Expand Up @@ -179,11 +200,38 @@ func CreateFunction(d *schema.ResourceData, meta interface{}) error {
builder.WithLanguage(v.(string))
}

// Set optionals, runtime version for python
if v, ok := d.GetOk("runtime_version"); ok {
builder.WithRuntimeVersion(v.(string))
}

// Set optionals, warehouse in which to create the function
if v, ok := d.GetOk("warehouse"); ok {
builder.WithWarehouse(v.(string))
q, err := builder.UseWarehouse()
if err != nil {
return err
}
err = snowflake.Exec(db, q)
if err != nil {
return errors.Wrapf(err, "error using warehouse %v", v.(string))
}
}

if v, ok := d.GetOk("comment"); ok {
builder.WithComment(v.(string))
}

// Set optionals, imports for Java
// Set optionals, packages for Java / python
if _, ok := d.GetOk("packages"); ok {
packages := []string{}
for _, pack := range d.Get("packages").([]interface{}) {
packages = append(packages, pack.(string))
}
builder.WithPackages(packages)
}

// Set optionals, imports for Java / python
if _, ok := d.GetOk("imports"); ok {
imports := []string{}
for _, imp := range d.Get("imports").([]interface{}) {
Expand All @@ -192,12 +240,12 @@ func CreateFunction(d *schema.ResourceData, meta interface{}) error {
builder.WithImports(imports)
}

// handler for Java
// handler for Java / python
if v, ok := d.GetOk("handler"); ok {
builder.WithHandler(v.(string))
}

// target path for Java
// target path for Java / python
if v, ok := d.GetOk("target_path"); ok {
builder.WithTargetPath(v.(string))
}
Expand Down Expand Up @@ -303,6 +351,14 @@ func ReadFunction(d *schema.ResourceData, meta interface{}) error {
return err
}
}
case "packages":
packagesString := strings.ReplaceAll(strings.ReplaceAll(desc.Value.String, "[", ""), "]", "")
if packagesString != "" { // Do nothing for Java / Python functions without packages
packages := strings.Split(packagesString, ", ")
if err = d.Set("packages", packages); err != nil {
return err
}
}
case "imports":
importsString := strings.ReplaceAll(strings.ReplaceAll(desc.Value.String, "[", ""), "]", "")
if importsString != "" { // Do nothing for Java functions without imports
Expand All @@ -315,12 +371,18 @@ func ReadFunction(d *schema.ResourceData, meta interface{}) error {
if err = d.Set("handler", desc.Value.String); err != nil {
return err
}
case "warehouse":
if err = d.Set("warehouse", desc.Value.String); err != nil {
return err
}
case "target_path":
if err = d.Set("target_path", desc.Value.String); err != nil {
return err
}
case "runtime_version":
// runtime version for Java function. currently not used.
if err = d.Set("runtime_version", desc.Value.String); err != nil {
return err
}
default:
log.Printf("[WARN] unexpected function property %v returned from Snowflake", desc.Property.String)
}
Expand Down
43 changes: 40 additions & 3 deletions pkg/resources/function_acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@ func TestAcc_Function(t *testing.T) {
dbName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
schemaName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
functName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
warehouseName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))

expBody1 := "3.141592654::FLOAT"
expBody2 := "var X=3\nreturn X"
expBody3 := "select 1, 2\nunion all\nselect 3, 4\n"
expBody4 := `class CoolFunc {public static String test(int n) {return "hello!";}}`
expBody5 := "def add_py(i, j): return i+j"

resource.Test(t, resource.TestCase{
Providers: providers(),
CheckDestroy: nil,
Steps: []resource.TestStep{
{
Config: functionConfig(dbName, schemaName, functName),
Config: functionConfig(dbName, schemaName, functName, warehouseName),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr("snowflake_function.test_funct", "name", functName),
resource.TestCheckResourceAttr("snowflake_function.test_funct", "comment", "Terraform acceptance test"),
Expand All @@ -54,13 +57,20 @@ func TestAcc_Function(t *testing.T) {
resource.TestCheckResourceAttr("snowflake_function.test_funct_java", "arguments.#", "1"),
resource.TestCheckResourceAttr("snowflake_function.test_funct_java", "arguments.0.name", "ARG1"),
resource.TestCheckResourceAttr("snowflake_function.test_funct_java", "arguments.0.type", "NUMBER"),

resource.TestCheckResourceAttr("snowflake_function.test_funct_python", "name", functName),
resource.TestCheckResourceAttr("snowflake_function.test_funct_python", "comment", "Terraform acceptance test for python"),
resource.TestCheckResourceAttr("snowflake_function.test_funct_python", "statement", expBody5),
resource.TestCheckResourceAttr("snowflake_function.test_funct_python", "arguments.#", "2"),
resource.TestCheckResourceAttr("snowflake_function.test_funct_python", "arguments.0.name", "ARG1"),
resource.TestCheckResourceAttr("snowflake_function.test_funct_python", "arguments.0.type", "NUMBER"),
),
},
},
})
}

func functionConfig(db, schema, name string) string {
func functionConfig(db, schema, name, warehouse string) string {
return fmt.Sprintf(`
resource "snowflake_database" "test_database" {
name = "%s"
Expand All @@ -73,6 +83,7 @@ func functionConfig(db, schema, name string) string {
comment = "Terraform acceptance test"
}
resource "snowflake_function" "test_funct_simple" {
name = "%s"
database = snowflake_database.test_database.name
Expand Down Expand Up @@ -110,6 +121,32 @@ func functionConfig(db, schema, name string) string {
statement = "class CoolFunc {public static String test(int n) {return \"hello!\";}}"
}
resource "snowflake_warehouse" "test_wh" {
name = "%s"
comment = "Warehouse for terraform acceptance test"
}
resource "snowflake_function" "test_funct_python" {
name = "%s"
database = snowflake_database.test_database.name
schema = snowflake_schema.test_schema.name
warehouse = snowflake_warehouse.test_wh.name
arguments {
name = "ARG1"
type = "NUMBER"
}
arguments {
name = "ARG2"
type = "NUMBER"
}
comment = "Terraform acceptance test for python"
return_type = "NUMBER(38,0)"
language = "python"
runtime_version = "3.8"
handler = "add_py"
statement = "def add_py(i, j): return i+j"
}
resource "snowflake_function" "test_funct_complex" {
name = "%s"
database = snowflake_database.test_database.name
Expand All @@ -130,5 +167,5 @@ union all
select 3, 4
EOT
}
`, db, schema, name, name, name, name)
`, db, schema, name, name, name, warehouse, name, name)
}
39 changes: 27 additions & 12 deletions pkg/resources/function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,24 @@ import (
"github.com/stretchr/testify/require"
)

const functionBody string = "hi"
const functionBody string = "def add_py(i, j): return i+j"

func prepDummyFunctionResource(t *testing.T) *schema.ResourceData {
argument1 := map[string]interface{}{"name": "data", "type": "varchar"}
argument2 := map[string]interface{}{"name": "event_dt", "type": "date"}
in := map[string]interface{}{
"name": "my_funct",
"database": "my_db",
"schema": "my_schema",
"arguments": []interface{}{argument1, argument2},
"return_type": "varchar",
"return_behavior": "IMMUTABLE",
"statement": functionBody, //var message = DATA + DATA;return message
"name": "my_funct",
"database": "my_db",
"schema": "my_schema",
"arguments": []interface{}{argument1, argument2},
"language": "PYTHON",
"null_input_behaviour": "CALLED ON NULL INPUT",
"return_behavior": "VOLATILE",
"runtime_version": "3.8",
"packages": []interface{}{"numpy", "pandas"},
"handler": "add_py",
"return_type": "varchar",
"statement": functionBody, //var message = DATA + DATA;return message
}
d := function(t, "my_db|my_schema|my_funct|VARCHAR-DATE", in)
return d
Expand All @@ -42,7 +47,7 @@ func TestFunctionCreate(t *testing.T) {
d := prepDummyFunctionResource(t)

WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectExec(`CREATE OR REPLACE FUNCTION "my_db"."my_schema"."my_funct"\(data VARCHAR, event_dt DATE\) RETURNS VARCHAR CALLED ON NULL INPUT IMMUTABLE COMMENT = 'user-defined function' AS \$\$hi\$\$`).WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec(`CREATE OR REPLACE FUNCTION "my_db"."my_schema"."my_funct"\(data VARCHAR, event_dt DATE\) RETURNS VARCHAR LANGUAGE PYTHON CALLED ON NULL INPUT VOLATILE RUNTIME_VERSION = '3.8' PACKAGES = \('numpy', 'pandas'\) COMMENT = 'user-defined function' HANDLER = 'add_py' AS \$\$def add_py\(i, j\)\: return i\+j\$\$`).WillReturnResult(sqlmock.NewResult(1, 1))
expectFunctionRead(mock)
err := resources.CreateFunction(d, db)
r.NoError(err)
Expand All @@ -60,9 +65,7 @@ func expectFunctionRead(mock sqlmock.Sqlmock) {
describeRows := sqlmock.NewRows([]string{"property", "value"}).
AddRow("signature", "(data VARCHAR, event_dt DATE)").
AddRow("returns", "VARCHAR(123456789)"). // This is how return type is stored in Snowflake DB
AddRow("language", "SQL").
AddRow("null handling", "CALLED ON NULL INPUT").
AddRow("volatility", "IMMUTABLE").
AddRow("language", "PYTHON").
AddRow("body", functionBody)

mock.ExpectQuery(`DESCRIBE FUNCTION "my_db"."my_schema"."my_funct"\(VARCHAR, DATE\)`).WillReturnRows(describeRows)
Expand All @@ -82,6 +85,18 @@ func TestFunctionRead(t *testing.T) {
r.Equal("user-defined function", d.Get("comment").(string))
r.Equal("VARCHAR", d.Get("return_type").(string))
r.Equal(functionBody, d.Get("statement").(string))
r.Equal("PYTHON", d.Get("language").(string))
r.Equal("3.8", d.Get("runtime_version").(string))
r.Equal("add_py", d.Get("handler").(string))
r.Equal("CALLED ON NULL INPUT", d.Get("null_input_behavior").(string))
r.Equal("VOLATILE", d.Get("return_behavior").(string))

pkgs := d.Get("packages").([]interface{})
r.Len(pkgs, 2)
test_funct_pkg1 := pkgs[0].(string)
test_funct_pkg2 := pkgs[1].(string)
r.Equal("numpy", test_funct_pkg1)
r.Equal("pandas", test_funct_pkg2)

args := d.Get("arguments").([]interface{})
r.Len(args, 2)
Expand Down
Loading

0 comments on commit ee4c2c1

Please sign in to comment.