From a3ceda39d52a3d3fd03e39a7d9ffb6ae95ce7e04 Mon Sep 17 00:00:00 2001 From: Rob Bruce Date: Wed, 11 Aug 2021 21:34:37 +0100 Subject: [PATCH] feat: Table Column Defaults (#631) --- docs/resources/sequence.md | 1 + docs/resources/table.md | 21 +++ .../resources/snowflake_table/resource.tf | 10 ++ pkg/resources/sequence.go | 13 +- pkg/resources/sequence_acceptance_test.go | 3 + pkg/resources/sequence_test.go | 1 + pkg/resources/table.go | 143 +++++++++++++++- pkg/resources/table_acceptance_test.go | 158 ++++++++++++++++++ pkg/resources/table_test.go | 45 ++++- pkg/snowflake/escaping.go | 36 +++- pkg/snowflake/escaping_test.go | 51 ++++++ pkg/snowflake/sequence.go | 4 + pkg/snowflake/table.go | 118 ++++++++++++- pkg/snowflake/table_test.go | 46 ++++- 14 files changed, 629 insertions(+), 21 deletions(-) diff --git a/docs/resources/sequence.md b/docs/resources/sequence.md index 95e020c51a..037b4d66aa 100644 --- a/docs/resources/sequence.md +++ b/docs/resources/sequence.md @@ -46,6 +46,7 @@ resource "snowflake_sequence" "test_sequence" { ### Read-Only +- **fully_qualified_name** (String) The fully qualified name of the sequence. - **next_value** (Number) The next value the sequence will provide. diff --git a/docs/resources/table.md b/docs/resources/table.md index 729c19222a..0971ced189 100644 --- a/docs/resources/table.md +++ b/docs/resources/table.md @@ -19,6 +19,12 @@ resource "snowflake_schema" "schema" { data_retention_days = 1 } +resource "snowflake_sequence" "sequence" { + database = snowflake_schema.schema.database + schema = snowflake_schema.schema.name + name = "sequence" +} + resource "snowflake_table" "table" { database = snowflake_schema.schema.database schema = snowflake_schema.schema.name @@ -32,6 +38,10 @@ resource "snowflake_table" "table" { name = "id" type = "int" nullable = true + + default { + sequence = snowflake_sequence.sequence.fully_qualified_name + } } column { @@ -92,8 +102,19 @@ Required: Optional: - **comment** (String) Column comment +- **default** (Block List, Max: 1) Defines the column default value; note due to limitations of Snowflake's ALTER TABLE ADD/MODIFY COLUMN updates to default will not be applied (see [below for nested schema](#nestedblock--column--default)) - **nullable** (Boolean) Whether this column can contain null values. **Note**: Depending on your Snowflake version, the default value will not suffice if this column is used in a primary key constraint. + +### Nested Schema for `column.default` + +Optional: + +- **constant** (String) The default constant value for the column +- **expression** (String) The default expression value for the column +- **sequence** (String) The default sequence to use for the column + + ### Nested Schema for `primary_key` diff --git a/examples/resources/snowflake_table/resource.tf b/examples/resources/snowflake_table/resource.tf index 653eb59019..bdae8d099d 100644 --- a/examples/resources/snowflake_table/resource.tf +++ b/examples/resources/snowflake_table/resource.tf @@ -4,6 +4,12 @@ resource "snowflake_schema" "schema" { data_retention_days = 1 } +resource "snowflake_sequence" "sequence" { + database = snowflake_schema.schema.database + schema = snowflake_schema.schema.name + name = "sequence" +} + resource "snowflake_table" "table" { database = snowflake_schema.schema.database schema = snowflake_schema.schema.name @@ -17,6 +23,10 @@ resource "snowflake_table" "table" { name = "id" type = "int" nullable = true + + default { + sequence = snowflake_sequence.sequence.fully_qualified_name + } } column { diff --git a/pkg/resources/sequence.go b/pkg/resources/sequence.go index 9ec2e44c3e..8a9c5d1d61 100644 --- a/pkg/resources/sequence.go +++ b/pkg/resources/sequence.go @@ -44,6 +44,11 @@ var sequenceSchema = map[string]*schema.Schema{ Description: "The next value the sequence will provide.", Computed: true, }, + "fully_qualified_name": { + Type: schema.TypeString, + Description: "The fully qualified name of the sequence.", + Computed: true, + }, } var sequenceProperties = []string{"comment", "data_retention_time_in_days"} @@ -95,7 +100,8 @@ func ReadSequence(d *schema.ResourceData, meta interface{}) error { schema := d.Get("schema").(string) name := d.Get("name").(string) - stmt := snowflake.Sequence(name, database, schema).Show() + seq := snowflake.Sequence(name, database, schema) + stmt := seq.Show() row := snowflake.QueryRow(db, stmt) sequence, err := snowflake.ScanSequence(row) @@ -145,6 +151,11 @@ func ReadSequence(d *schema.ResourceData, meta interface{}) error { return err } + err = d.Set("fully_qualified_name", seq.Address()) + if err != nil { + return err + } + d.SetId(fmt.Sprintf(`%v|%v|%v`, sequence.DBName.String, sequence.SchemaName.String, sequence.Name.String)) if err != nil { return err diff --git a/pkg/resources/sequence_acceptance_test.go b/pkg/resources/sequence_acceptance_test.go index a65bf7d1f0..763063b66e 100644 --- a/pkg/resources/sequence_acceptance_test.go +++ b/pkg/resources/sequence_acceptance_test.go @@ -22,6 +22,7 @@ func TestAcc_Sequence(t *testing.T) { Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "name", accName), resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "next_value", "1"), + resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "fully_qualified_name", fmt.Sprintf(`%v.%v.%v`, accName, accName, accName)), ), }, // Set comment and rename @@ -31,6 +32,7 @@ func TestAcc_Sequence(t *testing.T) { resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "name", accRename), resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "comment", "look at me I am a comment"), resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "next_value", "1"), + resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "fully_qualified_name", fmt.Sprintf(`%v.%v.%v`, accName, accName, accRename)), ), }, // Unset comment and set increment @@ -41,6 +43,7 @@ func TestAcc_Sequence(t *testing.T) { resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "comment", ""), resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "next_value", "1"), resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "increment", "32"), + resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "fully_qualified_name", fmt.Sprintf(`%v.%v.%v`, accName, accName, accName)), ), }, }, diff --git a/pkg/resources/sequence_test.go b/pkg/resources/sequence_test.go index c0bcf84893..697b37a461 100644 --- a/pkg/resources/sequence_test.go +++ b/pkg/resources/sequence_test.go @@ -99,6 +99,7 @@ func TestSequenceRead(t *testing.T) { r.Equal(25, d.Get("increment").(int)) r.Equal(5, d.Get("next_value").(int)) r.Equal("database|schema|good_name", d.Id()) + r.Equal(`"database"."schema"."good_name"`, d.Get("fully_qualified_name").(string)) }) } diff --git a/pkg/resources/table.go b/pkg/resources/table.go index 5d51aa8d6a..abc63f72fd 100644 --- a/pkg/resources/table.go +++ b/pkg/resources/table.go @@ -66,6 +66,35 @@ var tableSchema = map[string]*schema.Schema{ Default: true, Description: "Whether this column can contain null values. **Note**: Depending on your Snowflake version, the default value will not suffice if this column is used in a primary key constraint.", }, + "default": { + Type: schema.TypeList, + Optional: true, + Description: "Defines the column default value; note due to limitations of Snowflake's ALTER TABLE ADD/MODIFY COLUMN updates to default will not be applied", + MinItems: 1, + MaxItems: 1, + Elem: &schema.Resource{ + Schema: map[string]*schema.Schema{ + "constant": { + Type: schema.TypeString, + Optional: true, + Description: "The default constant value for the column", + // ConflictsWith: []string{".expression", ".sequence"}, - can't use, nor ExactlyOneOf due to column type being TypeList + }, + "expression": { + Type: schema.TypeString, + Optional: true, + Description: "The default expression value for the column", + // ConflictsWith: []string{".constant", ".sequence"}, - can't use, nor ExactlyOneOf due to column type being TypeList + }, + "sequence": { + Type: schema.TypeString, + Optional: true, + Description: "The default sequence to use for the column", + // ConflictsWith: []string{".constant", ".expression"}, - can't use, nor ExactlyOneOf due to column type being TypeList + }, + }, + }, + }, "comment": { Type: schema.TypeString, Optional: true, @@ -183,16 +212,63 @@ func tableIDFromString(stringID string) (*tableID, error) { return tableResult, nil } +type columnDefault struct { + constant *string + expression *string + sequence *string +} + +func (cd *columnDefault) toSnowflakeColumnDefault() *snowflake.ColumnDefault { + if cd.constant != nil { + return snowflake.NewColumnDefaultWithConstant(*cd.constant) + } + + if cd.expression != nil { + return snowflake.NewColumnDefaultWithExpression(*cd.expression) + } + + if cd.sequence != nil { + return snowflake.NewColumnDefaultWithSequence(*cd.sequence) + } + + return nil +} + +func (cd *columnDefault) _type() string { + if cd.constant != nil { + return "constant" + } + + if cd.expression != nil { + return "expression" + } + + if cd.sequence != nil { + return "sequence" + } + + return "unknown" +} + type column struct { name string dataType string nullable bool + _default *columnDefault comment string } func (c column) toSnowflakeColumn() snowflake.Column { - sC := snowflake.Column{} - return *sC.WithName(c.name).WithType(c.dataType).WithNullable(c.nullable).WithComment(c.comment) + sC := &snowflake.Column{} + + if c._default != nil { + sC = sC.WithDefault(c._default.toSnowflakeColumnDefault()) + } + + return *sC.WithName(c.name). + WithType(c.dataType). + WithNullable(c.nullable). + WithComment(c.comment) } type columns []column @@ -228,6 +304,7 @@ type changedColumn struct { newColumn column //our new column changedDataType bool changedNullConstraint bool + dropedDefault bool changedComment bool } @@ -235,13 +312,17 @@ func (old columns) getChangedColumnProperties(new columns) (changed changedColum changed = changedColumns{} for _, cO := range old { for _, cN := range new { - changeColumn := changedColumn{cN, false, false, false} + changeColumn := changedColumn{cN, false, false, false, false} if cO.name == cN.name && cO.dataType != cN.dataType { changeColumn.changedDataType = true } if cO.name == cN.name && cO.nullable != cN.nullable { changeColumn.changedNullConstraint = true } + if cO.name == cN.name && cO._default != nil && cN._default == nil { + changeColumn.dropedDefault = true + } + if cO.name == cN.name && cO.comment != cN.comment { changeColumn.changedComment = true } @@ -256,12 +337,48 @@ func (old columns) diffs(new columns) (removed columns, added columns, changed c return old.getNewIn(new), new.getNewIn(old), old.getChangedColumnProperties(new) } +func getColumnDefault(def map[string]interface{}) *columnDefault { + if c, ok := def["constant"]; ok { + if constant, ok := c.(string); ok && len(constant) > 0 { + return &columnDefault{ + constant: &constant, + } + } + } + + if e, ok := def["expression"]; ok { + if expr, ok := e.(string); ok && len(expr) > 0 { + return &columnDefault{ + expression: &expr, + } + } + } + + if s, ok := def["sequence"]; ok { + if seq := s.(string); ok && len(seq) > 0 { + return &columnDefault{ + sequence: &seq, + } + } + } + + return nil +} + func getColumn(from interface{}) (to column) { c := from.(map[string]interface{}) + var cd *columnDefault + + _default := c["default"].([]interface{}) + if len(_default) == 1 { + cd = getColumnDefault(_default[0].(map[string]interface{})) + } + return column{ name: c["name"].(string), dataType: c["type"].(string), nullable: c["nullable"].(bool), + _default: cd, comment: c["comment"].(string), } } @@ -466,7 +583,17 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { } } for _, cA := range added { - q := builder.AddColumn(cA.name, cA.dataType, cA.nullable, cA.comment) + var q string + if cA._default == nil { + q = builder.AddColumn(cA.name, cA.dataType, cA.nullable, nil, cA.comment) + } else { + if cA._default._type() != "constant" { + return fmt.Errorf("Failed to add column %v => Only adding a column as a constant is supported by Snowflake", cA.name) + } + + q = builder.AddColumn(cA.name, cA.dataType, cA.nullable, cA._default.toSnowflakeColumnDefault(), cA.comment) + } + err := snowflake.Exec(db, q) if err != nil { return errors.Wrapf(err, "error adding column on %v", d.Id()) @@ -492,6 +619,14 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { } } + if cA.dropedDefault { + q := builder.DropColumnDefault(cA.newColumn.name) + err := snowflake.Exec(db, q) + if err != nil { + return errors.Wrapf(err, "error changing property on %v", d.Id()) + + } + } if cA.changedComment { q := builder.ChangeColumnComment(cA.newColumn.name, cA.newColumn.comment) err := snowflake.Exec(db, q) diff --git a/pkg/resources/table_acceptance_test.go b/pkg/resources/table_acceptance_test.go index dcb056d4dc..769e83f8be 100644 --- a/pkg/resources/table_acceptance_test.go +++ b/pkg/resources/table_acceptance_test.go @@ -811,3 +811,161 @@ resource "snowflake_table" "test_table" { ` return fmt.Sprintf(s, name, name, name) } + +func TestAcc_TableDefaults(t *testing.T) { + accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + + resource.ParallelTest(t, resource.TestCase{ + Providers: providers(), + Steps: []resource.TestStep{ + { + Config: tableColumnWithDefaults(accName), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("snowflake_table.test_table", "name", accName), + resource.TestCheckResourceAttr("snowflake_table.test_table", "database", accName), + resource.TestCheckResourceAttr("snowflake_table.test_table", "schema", accName), + resource.TestCheckResourceAttr("snowflake_table.test_table", "data_retention_days", "1"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "change_tracking", "false"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "comment", "Terraform acceptance test"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.#", "3"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.name", "column1"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.type", "VARCHAR(16)"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.default.0.constant", "hello"), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.0.type.default.0.expression"), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.0.type.default.0.sequence"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.name", "column2"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.type", "TIMESTAMP_NTZ(9)"), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.1.type.default.0.constant"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.default.0.expression", "CURRENT_TIMESTAMP()"), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.1.type.default.0.sequence"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.name", "column3"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.type", "NUMBER(38,0)"), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.constant"), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.expression"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.default.0.sequence", fmt.Sprintf("%v.%v.%v", accName, accName, accName)), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "primary_key"), + ), + }, + { + Config: tableColumnWithoutDefaults(accName), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("snowflake_table.test_table", "name", accName), + resource.TestCheckResourceAttr("snowflake_table.test_table", "database", accName), + resource.TestCheckResourceAttr("snowflake_table.test_table", "schema", accName), + resource.TestCheckResourceAttr("snowflake_table.test_table", "data_retention_days", "1"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "change_tracking", "false"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "comment", "Terraform acceptance test"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.#", "3"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.name", "column1"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.type", "VARCHAR(16)"), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.0.default"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.name", "column2"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.type", "TIMESTAMP_NTZ(9)"), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.1.type.default"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.name", "column3"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.type", "NUMBER(38,0)"), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.constant"), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.expression"), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.default.0.sequence", fmt.Sprintf("%v.%v.%v", accName, accName, accName)), + resource.TestCheckNoResourceAttr("snowflake_table.test_table", "primary_key"), + ), + }, + }, + }) +} + +func tableColumnWithDefaults(name string) string { + s := ` +resource "snowflake_database" "test_database" { + name = "%s" + comment = "Terraform acceptance test" +} + +resource "snowflake_schema" "test_schema" { + name = "%s" + database = snowflake_database.test_database.name + comment = "Terraform acceptance test" +} + +resource "snowflake_sequence" "test_seq" { + database = snowflake_database.test_database.name + schema = snowflake_schema.test_schema.name + name = "%s" +} + +resource "snowflake_table" "test_table" { + database = snowflake_database.test_database.name + schema = snowflake_schema.test_schema.name + name = "%s" + comment = "Terraform acceptance test" + + column { + name = "column1" + type = "VARCHAR(16)" + default { + constant = "hello" + } + } + column { + name = "column2" + type = "TIMESTAMP_NTZ(9)" + default { + expression = "CURRENT_TIMESTAMP()" + } + } + column { + name = "column3" + type = "NUMBER(38,0)" + default { + sequence = snowflake_sequence.test_seq.fully_qualified_name + } + } +} +` + return fmt.Sprintf(s, name, name, name, name) +} + +func tableColumnWithoutDefaults(name string) string { + s := ` +resource "snowflake_database" "test_database" { + name = "%s" + comment = "Terraform acceptance test" +} + +resource "snowflake_schema" "test_schema" { + name = "%s" + database = snowflake_database.test_database.name + comment = "Terraform acceptance test" +} + +resource "snowflake_sequence" "test_seq" { + database = snowflake_database.test_database.name + schema = snowflake_schema.test_schema.name + name = "%s" +} + +resource "snowflake_table" "test_table" { + database = snowflake_database.test_database.name + schema = snowflake_schema.test_schema.name + name = "%s" + comment = "Terraform acceptance test" + + column { + name = "column1" + type = "VARCHAR(16)" + } + column { + name = "column2" + type = "TIMESTAMP_NTZ(9)" + } + column { + name = "column3" + type = "NUMBER(38,0)" + default { + sequence = snowflake_sequence.test_seq.fully_qualified_name + } + } +} +` + return fmt.Sprintf(s, name, name, name, name) +} diff --git a/pkg/resources/table_test.go b/pkg/resources/table_test.go index dfd961e794..a06b16e500 100644 --- a/pkg/resources/table_test.go +++ b/pkg/resources/table_test.go @@ -38,20 +38,53 @@ func TestTableCreate(t *testing.T) { }, map[string]interface{}{ "name": "column3", - "type": "INT", + "type": "NUMBER(38,0)", "comment": "some comment", }, + map[string]interface{}{ + "name": "column4", + "type": "VARCHAR", + "nullable": false, + "default": []interface{}{ + map[string]interface{}{ + "constant": "hello", + }, + }, + }, }, "primary_key": []interface{}{map[string]interface{}{"name": "MY_KEY", "keys": []interface{}{"column1"}}}, } d := table(t, "database_name|schema_name|good_name", in) WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) { - mock.ExpectExec(`CREATE TABLE "database_name"."schema_name"."good_name" \("column1" OBJECT COMMENT '', "column2" VARCHAR NOT NULL COMMENT '', "column3" INT COMMENT 'some comment' ,CONSTRAINT "MY_KEY" PRIMARY KEY\("column1"\)\) COMMENT = 'great comment'`).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec(`CREATE TABLE "database_name"."schema_name"."good_name" \("column1" OBJECT COMMENT '', "column2" VARCHAR NOT NULL COMMENT '', "column3" NUMBER\(38,0\) COMMENT 'some comment', "column4" VARCHAR NOT NULL DEFAULT 'hello' COMMENT '' ,CONSTRAINT "MY_KEY" PRIMARY KEY\("column1"\)\) COMMENT = 'great comment' DATA_RETENTION_TIME_IN_DAYS = 1 CHANGE_TRACKING = false`).WillReturnResult(sqlmock.NewResult(1, 1)) expectTableRead(mock) err := resources.CreateTable(d, db) r.NoError(err) r.Equal("good_name", d.Get("name").(string)) + columns := d.Get("column").([]interface{}) + r.Equal(4, len(columns)) + col1 := columns[0].(map[string]interface{}) + r.Equal("column1", col1["name"].(string)) + r.Equal("OBJECT", col1["type"].(string)) + r.Equal(true, col1["nullable"].(bool)) + col2 := columns[1].(map[string]interface{}) + r.Equal("column2", col2["name"].(string)) + r.Equal("VARCHAR", col2["type"].(string)) + r.Equal(false, col2["nullable"].(bool)) + col3 := columns[2].(map[string]interface{}) + r.Equal("column3", col3["name"].(string)) + r.Equal("NUMBER(38,0)", col3["type"].(string)) + r.Equal(true, col3["nullable"].(bool)) + r.Equal("some comment", col3["comment"].(string)) + col4 := columns[3].(map[string]interface{}) + r.Equal("column4", col4["name"].(string)) + r.Equal("VARCHAR", col4["type"].(string)) + r.NotNil(col4["default"]) + col4Default := col4["default"].([]interface{}) + r.Equal(1, len(col4Default)) + col4DefaultParams := col4Default[0].(map[string]interface{}) + r.Equal("hello", col4DefaultParams["constant"].(string)) }) } @@ -59,9 +92,11 @@ func expectTableRead(mock sqlmock.Sqlmock) { rows := sqlmock.NewRows([]string{"name", "type", "kind", "null?", "default", "primary key", "unique key", "check", "expression", "comment"}).AddRow("good_name", "VARCHAR()", "COLUMN", "Y", "NULL", "NULL", "N", "N", "NULL", "mock comment") mock.ExpectQuery(`SHOW TABLES LIKE 'good_name' IN SCHEMA "database_name"."schema_name"`).WillReturnRows(rows) - describeRows := sqlmock.NewRows([]string{"name", "type", "kind", "null?"}). - AddRow("column1", "OBJECT", "COLUMN", "Y"). - AddRow("column2", "VARCHAR", "COLUMN", "N") + describeRows := sqlmock.NewRows([]string{"name", "type", "kind", "null?", "default", "comment"}). + AddRow("column1", "OBJECT", "COLUMN", "Y", nil, nil). + AddRow("column2", "VARCHAR", "COLUMN", "N", nil, nil). + AddRow("column3", "NUMBER(38,0)", "COLUMN", "Y", nil, "some comment"). + AddRow("column4", "VARCHAR", "COLUMN", "N", "'hello'", nil) mock.ExpectQuery(`DESC TABLE "database_name"."schema_name"."good_name"`).WillReturnRows(describeRows) diff --git a/pkg/snowflake/escaping.go b/pkg/snowflake/escaping.go index 5d4a93e0b0..e3dd1f8960 100644 --- a/pkg/snowflake/escaping.go +++ b/pkg/snowflake/escaping.go @@ -1,6 +1,10 @@ package snowflake -import "strings" +import ( + "fmt" + "regexp" + "strings" +) // EscapeString will escape only the ' character. Would prefer a more robust OSS solution, but this should // prevent some dumb errors for now. @@ -16,3 +20,33 @@ func UnescapeString(in string) string { out = strings.Replace(out, `\'`, `'`, -1) return out } + +// EscapeSnowflakeString will escape single quotes with the SQL native double single quote +func EscapeSnowflakeString(in string) string { + out := strings.Replace(in, `'`, `''`, -1) + return fmt.Sprintf(`'%v'`, out) +} + +// UnescapeSnowflakeString reverses EscapeSnowflakeString +func UnescapeSnowflakeString(in string) string { + out := strings.TrimPrefix(in, `'`) + out = strings.TrimSuffix(out, `'`) + out = strings.Replace(out, `''`, `'`, -1) + return out +} + +// AddressEscape wraps a name inside double quotes only if required by Snowflake +func AddressEscape(in ...string) string { + quoteCheck := regexp.MustCompile(`[^A-Z0-9_]`) + address := make([]string, len(in)) + + for i, n := range in { + if quoteCheck.MatchString(n) { + address[i] = fmt.Sprintf(`"%s"`, strings.Replace(n, `"`, `\"`, -1)) + } else { + address[i] = n + } + } + + return strings.Join(address, ".") +} diff --git a/pkg/snowflake/escaping_test.go b/pkg/snowflake/escaping_test.go index a04421a9dd..d764edb69a 100644 --- a/pkg/snowflake/escaping_test.go +++ b/pkg/snowflake/escaping_test.go @@ -13,3 +13,54 @@ func TestEscapeString(t *testing.T) { r.Equal(`\'`, snowflake.EscapeString(`'`)) r.Equal(`\\\'`, snowflake.EscapeString(`\'`)) } + +func TestEscapeSnowflakeString(t *testing.T) { + r := require.New(t) + r.Equal(`'table''s quoted'`, snowflake.EscapeSnowflakeString(`table's quoted`)) +} + +func TestUnescapeSnowflakeString(t *testing.T) { + r := require.New(t) + r.Equal(`table's quoted`, snowflake.UnescapeSnowflakeString(`'table''s quoted'`)) +} + +func TestAddressEscape(t *testing.T) { + testCases := []struct { + id string + name []string + expected string + }{ + { + id: "single no escape", + name: []string{"HELLO"}, + expected: "HELLO", + }, + { + id: "multiple no escape", + name: []string{"HELLO", "WORLD"}, + expected: "HELLO.WORLD", + }, + { + id: "single escape", + name: []string{"hello"}, + expected: `"hello"`, + }, + { + id: "multiple escape", + name: []string{"hello", "world"}, + expected: `"hello"."world"`, + }, + { + id: "mixed escape", + name: []string{"hello", "world", "NOTHERE"}, + expected: `"hello"."world".NOTHERE`, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.id, func(t *testing.T) { + r := require.New(t) + r.Equal(testCase.expected, snowflake.AddressEscape(testCase.name...)) + }) + } +} diff --git a/pkg/snowflake/sequence.go b/pkg/snowflake/sequence.go index 4990ab49b0..a6852b2a50 100644 --- a/pkg/snowflake/sequence.go +++ b/pkg/snowflake/sequence.go @@ -85,6 +85,10 @@ func (sb *SequenceBuilder) QualifiedName() string { return fmt.Sprintf(`"%v"."%v"."%v"`, sb.db, sb.schema, sb.name) } +func (sb *SequenceBuilder) Address() string { + return AddressEscape(sb.db, sb.schema, sb.name) +} + func ScanSequence(row *sqlx.Row) (*sequence, error) { d := &sequence{} e := row.StructScan(d) diff --git a/pkg/snowflake/table.go b/pkg/snowflake/table.go index 90e9748d30..58802f33fe 100644 --- a/pkg/snowflake/table.go +++ b/pkg/snowflake/table.go @@ -30,12 +30,75 @@ func (pk *PrimaryKey) WithKeys(keys []string) *PrimaryKey { return pk } +type ColumnDefaultType int + +const ( + columnDefaultTypeConstant = iota + columnDefaultTypeSequence + columnDefaultTypeExpression +) + +type ColumnDefault struct { + _type ColumnDefaultType + expression string +} + +func NewColumnDefaultWithConstant(constant string) *ColumnDefault { + return &ColumnDefault{ + _type: columnDefaultTypeConstant, + expression: constant, + } +} + +func NewColumnDefaultWithExpression(expression string) *ColumnDefault { + return &ColumnDefault{ + _type: columnDefaultTypeExpression, + expression: expression, + } +} + +func NewColumnDefaultWithSequence(sequence string) *ColumnDefault { + return &ColumnDefault{ + _type: columnDefaultTypeSequence, + expression: sequence, + } +} + +func (d *ColumnDefault) String(columnType string) string { + columnType = strings.ToUpper(columnType) + + switch { + case d._type == columnDefaultTypeExpression: + return d.expression + + case d._type == columnDefaultTypeSequence: + return fmt.Sprintf(`%v.NEXTVAL`, d.expression) + + case d._type == columnDefaultTypeConstant && (strings.Contains(columnType, "CHAR") || columnType == "STRING" || columnType == "TEXT"): + return EscapeSnowflakeString(d.expression) + + default: + return d.expression + } +} + +func (d *ColumnDefault) UnescapeConstantSnowflakeString(columnType string) string { + columnType = strings.ToUpper(columnType) + + if d._type == columnDefaultTypeConstant && (strings.Contains(columnType, "CHAR") || columnType == "STRING" || columnType == "TEXT") { + return UnescapeSnowflakeString(d.expression) + } + + return d.expression +} + // Column structure that represents a table column type Column struct { name string _type string // type is reserved nullable bool - comment string // pointer as value is nullable + _default *ColumnDefault // default is reserved + comment string // pointer as value is nullable } // WithName set the column name @@ -56,6 +119,11 @@ func (c *Column) WithNullable(nullable bool) *Column { return c } +func (c *Column) WithDefault(cd *ColumnDefault) *Column { + c._default = cd + return c +} + // WithComment set the column comment func (c *Column) WithComment(comment string) *Column { c.comment = comment @@ -76,6 +144,10 @@ func (c *Column) getColumnDefinition(withInlineConstraints bool, withComment boo } } + if c._default != nil { + colDef.WriteString(fmt.Sprintf(` DEFAULT %v`, c._default.String(c._type))) + } + if withComment { colDef.WriteString(fmt.Sprintf(` COMMENT '%v'`, EscapeString(c.comment))) } @@ -133,10 +205,12 @@ func NewColumns(tds []tableDescription) Columns { if td.Kind.String != "COLUMN" { continue } + cs = append(cs, Column{ name: td.Name.String, _type: td.Type.String, nullable: td.IsNullable(), + _default: td.ColumnDefault(), comment: td.Comment.String, }) } @@ -152,6 +226,20 @@ func (c Columns) Flatten() []interface{} { flat["nullable"] = col.nullable flat["comment"] = col.comment + if col._default != nil { + def := map[string]interface{}{} + switch col._default._type { + case columnDefaultTypeConstant: + def["constant"] = col._default.UnescapeConstantSnowflakeString(col._type) + case columnDefaultTypeExpression: + def["expression"] = col._default.expression + case columnDefaultTypeSequence: + def["sequence"] = col._default.expression + } + + flat["default"] = []interface{}{def} + } + flattened = append(flattened, flat) } return flattened @@ -378,11 +466,12 @@ func (tb *TableBuilder) ChangeChangeTracking(changeTracking bool) string { } // AddColumn returns the SQL query that will add a new column to the table. -func (tb *TableBuilder) AddColumn(name string, dataType string, nullable bool, comment string) string { +func (tb *TableBuilder) AddColumn(name string, dataType string, nullable bool, _default *ColumnDefault, comment string) string { col := Column{ name: name, _type: dataType, nullable: nullable, + _default: _default, comment: comment, } return fmt.Sprintf(`ALTER TABLE %s ADD COLUMN %s`, tb.QualifiedName(), col.getColumnDefinition(true, true)) @@ -407,6 +496,10 @@ func (tb *TableBuilder) ChangeColumnComment(name string, comment string) string return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%v" COMMENT '%v'`, tb.QualifiedName(), EscapeString(name), EscapeString(comment)) } +func (tb *TableBuilder) DropColumnDefault(name string) string { + return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%v" DROP DEFAULT`, tb.QualifiedName(), EscapeString(name)) +} + // RemoveComment returns the SQL query that will remove the comment on the table. func (tb *TableBuilder) RemoveComment() string { return fmt.Sprintf(`ALTER TABLE %v UNSET COMMENT`, tb.QualifiedName()) @@ -484,6 +577,7 @@ type tableDescription struct { Type sql.NullString `db:"type"` Kind sql.NullString `db:"kind"` Nullable sql.NullString `db:"null?"` + Default sql.NullString `db:"default"` Comment sql.NullString `db:"comment"` } @@ -495,6 +589,26 @@ func (td *tableDescription) IsNullable() bool { } } +func (td *tableDescription) ColumnDefault() *ColumnDefault { + if !td.Default.Valid { + return nil + } + + if strings.HasSuffix(td.Default.String, ".NEXTVAL") { + return NewColumnDefaultWithSequence(strings.TrimSuffix(td.Default.String, ".NEXTVAL")) + } + + if strings.Contains(td.Default.String, "(") && strings.Contains(td.Default.String, ")") { + return NewColumnDefaultWithExpression(td.Default.String) + } + + if strings.Contains(td.Type.String, "CHAR") || td.Type.String == "STRING" || td.Type.String == "TEXT" { + return NewColumnDefaultWithConstant(UnescapeSnowflakeString(td.Default.String)) + } + + return NewColumnDefaultWithConstant(td.Default.String) +} + type primaryKeyDescription struct { ColumnName sql.NullString `db:"column_name"` KeySequence sql.NullString `db:"key_sequence"` diff --git a/pkg/snowflake/table_test.go b/pkg/snowflake/table_test.go index 5e9c0c7045..cd687ba6fd 100644 --- a/pkg/snowflake/table_test.go +++ b/pkg/snowflake/table_test.go @@ -21,27 +21,45 @@ func TestTableCreate(t *testing.T) { nullable: true, comment: "only populated when data is available", }, + { + name: "column3", + _type: "NUMBER(38,0)", + nullable: false, + _default: NewColumnDefaultWithSequence(`"test_db"."test_schema"."test_seq"`), + }, + { + name: "column4", + _type: "VARCHAR", + nullable: false, + _default: NewColumnDefaultWithConstant("test default's"), + }, + { + name: "column5", + _type: "TIMESTAMP_NTZ", + nullable: false, + _default: NewColumnDefaultWithExpression("CURRENT_TIMESTAMP()"), + }, } s.WithColumns(Columns(cols)) r.Equal(s.QualifiedName(), `"test_db"."test_schema"."test_table"`) - r.Equal(`CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available') DATA_RETENTION_TIME_IN_DAYS = 0 CHANGE_TRACKING = false`, s.Create()) + r.Equal(`CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available', "column3" NUMBER(38,0) NOT NULL DEFAULT "test_db"."test_schema"."test_seq".NEXTVAL COMMENT '', "column4" VARCHAR NOT NULL DEFAULT 'test default''s' COMMENT '', "column5" TIMESTAMP_NTZ NOT NULL DEFAULT CURRENT_TIMESTAMP() COMMENT '') DATA_RETENTION_TIME_IN_DAYS = 0 CHANGE_TRACKING = false`, s.Create()) s.WithComment("Test Comment") - r.Equal(s.Create(), `CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available') COMMENT = 'Test Comment' DATA_RETENTION_TIME_IN_DAYS = 0 CHANGE_TRACKING = false`) + r.Equal(s.Create(), `CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available', "column3" NUMBER(38,0) NOT NULL DEFAULT "test_db"."test_schema"."test_seq".NEXTVAL COMMENT '', "column4" VARCHAR NOT NULL DEFAULT 'test default''s' COMMENT '', "column5" TIMESTAMP_NTZ NOT NULL DEFAULT CURRENT_TIMESTAMP() COMMENT '') COMMENT = 'Test Comment' DATA_RETENTION_TIME_IN_DAYS = 0 CHANGE_TRACKING = false`) s.WithClustering([]string{"column1"}) - r.Equal(s.Create(), `CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available') COMMENT = 'Test Comment' CLUSTER BY LINEAR(column1) DATA_RETENTION_TIME_IN_DAYS = 0 CHANGE_TRACKING = false`) + r.Equal(s.Create(), `CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available', "column3" NUMBER(38,0) NOT NULL DEFAULT "test_db"."test_schema"."test_seq".NEXTVAL COMMENT '', "column4" VARCHAR NOT NULL DEFAULT 'test default''s' COMMENT '', "column5" TIMESTAMP_NTZ NOT NULL DEFAULT CURRENT_TIMESTAMP() COMMENT '') COMMENT = 'Test Comment' CLUSTER BY LINEAR(column1) DATA_RETENTION_TIME_IN_DAYS = 0 CHANGE_TRACKING = false`) s.WithPrimaryKey(PrimaryKey{name: "MY_KEY", keys: []string{"column1"}}) - r.Equal(s.Create(), `CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available' ,CONSTRAINT "MY_KEY" PRIMARY KEY("column1")) COMMENT = 'Test Comment' CLUSTER BY LINEAR(column1) DATA_RETENTION_TIME_IN_DAYS = 0 CHANGE_TRACKING = false`) + r.Equal(s.Create(), `CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available', "column3" NUMBER(38,0) NOT NULL DEFAULT "test_db"."test_schema"."test_seq".NEXTVAL COMMENT '', "column4" VARCHAR NOT NULL DEFAULT 'test default''s' COMMENT '', "column5" TIMESTAMP_NTZ NOT NULL DEFAULT CURRENT_TIMESTAMP() COMMENT '' ,CONSTRAINT "MY_KEY" PRIMARY KEY("column1")) COMMENT = 'Test Comment' CLUSTER BY LINEAR(column1) DATA_RETENTION_TIME_IN_DAYS = 0 CHANGE_TRACKING = false`) s.WithDataRetentionTimeInDays(10) - r.Equal(s.Create(), `CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available' ,CONSTRAINT "MY_KEY" PRIMARY KEY("column1")) COMMENT = 'Test Comment' CLUSTER BY LINEAR(column1) DATA_RETENTION_TIME_IN_DAYS = 10 CHANGE_TRACKING = false`) + r.Equal(s.Create(), `CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available', "column3" NUMBER(38,0) NOT NULL DEFAULT "test_db"."test_schema"."test_seq".NEXTVAL COMMENT '', "column4" VARCHAR NOT NULL DEFAULT 'test default''s' COMMENT '', "column5" TIMESTAMP_NTZ NOT NULL DEFAULT CURRENT_TIMESTAMP() COMMENT '' ,CONSTRAINT "MY_KEY" PRIMARY KEY("column1")) COMMENT = 'Test Comment' CLUSTER BY LINEAR(column1) DATA_RETENTION_TIME_IN_DAYS = 10 CHANGE_TRACKING = false`) s.WithChangeTracking(true) - r.Equal(s.Create(), `CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available' ,CONSTRAINT "MY_KEY" PRIMARY KEY("column1")) COMMENT = 'Test Comment' CLUSTER BY LINEAR(column1) DATA_RETENTION_TIME_IN_DAYS = 10 CHANGE_TRACKING = true`) + r.Equal(s.Create(), `CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT COMMENT '', "column2" VARCHAR COMMENT 'only populated when data is available', "column3" NUMBER(38,0) NOT NULL DEFAULT "test_db"."test_schema"."test_seq".NEXTVAL COMMENT '', "column4" VARCHAR NOT NULL DEFAULT 'test default''s' COMMENT '', "column5" TIMESTAMP_NTZ NOT NULL DEFAULT CURRENT_TIMESTAMP() COMMENT '' ,CONSTRAINT "MY_KEY" PRIMARY KEY("column1")) COMMENT = 'Test Comment' CLUSTER BY LINEAR(column1) DATA_RETENTION_TIME_IN_DAYS = 10 CHANGE_TRACKING = true`) } func TestTableChangeComment(t *testing.T) { @@ -59,13 +77,19 @@ func TestTableRemoveComment(t *testing.T) { func TestTableAddColumn(t *testing.T) { r := require.New(t) s := Table("test_table", "test_db", "test_schema") - r.Equal(s.AddColumn("new_column", "VARIANT", true, ""), `ALTER TABLE "test_db"."test_schema"."test_table" ADD COLUMN "new_column" VARIANT COMMENT ''`) + r.Equal(s.AddColumn("new_column", "VARIANT", true, nil, ""), `ALTER TABLE "test_db"."test_schema"."test_table" ADD COLUMN "new_column" VARIANT COMMENT ''`) } func TestTableAddColumnWithComment(t *testing.T) { r := require.New(t) s := Table("test_table", "test_db", "test_schema") - r.Equal(s.AddColumn("new_column", "VARIANT", true, "some comment"), `ALTER TABLE "test_db"."test_schema"."test_table" ADD COLUMN "new_column" VARIANT COMMENT 'some comment'`) + r.Equal(s.AddColumn("new_column", "VARIANT", true, nil, "some comment"), `ALTER TABLE "test_db"."test_schema"."test_table" ADD COLUMN "new_column" VARIANT COMMENT 'some comment'`) +} + +func TestTableAddColumnWithDefault(t *testing.T) { + r := require.New(t) + s := Table("test_table", "test_db", "test_schema") + r.Equal(s.AddColumn("new_column", "NUMBER(38,0)", true, NewColumnDefaultWithConstant("1"), ""), `ALTER TABLE "test_db"."test_schema"."test_table" ADD COLUMN "new_column" NUMBER(38,0) DEFAULT 1 COMMENT ''`) } func TestTableDropColumn(t *testing.T) { @@ -86,6 +110,12 @@ func TestTableChangeColumnComment(t *testing.T) { r.Equal(s.ChangeColumnComment("old_column", "some comment"), `ALTER TABLE "test_db"."test_schema"."test_table" MODIFY COLUMN "old_column" COMMENT 'some comment'`) } +func TestTableDropColumnDefault(t *testing.T) { + r := require.New(t) + s := Table("test_table", "test_db", "test_schema") + r.Equal(s.DropColumnDefault("old_column"), `ALTER TABLE "test_db"."test_schema"."test_table" MODIFY COLUMN "old_column" DROP DEFAULT`) +} + func TestTableChangeClusterBy(t *testing.T) { r := require.New(t) s := Table("test_table", "test_db", "test_schema")