diff --git a/pkg/resources/stream_acceptance_test.go b/pkg/resources/stream_acceptance_test.go index c9b05c1b14f..1f1d9085408 100644 --- a/pkg/resources/stream_acceptance_test.go +++ b/pkg/resources/stream_acceptance_test.go @@ -53,7 +53,7 @@ resource "snowflake_table" "test_stream_on_table" { } column { name = "column2" - type = "VARCHAR" + type = "VARCHAR(16777216)" } } diff --git a/pkg/resources/table.go b/pkg/resources/table.go index fe266e7910b..7fd92d8986d 100644 --- a/pkg/resources/table.go +++ b/pkg/resources/table.go @@ -138,12 +138,12 @@ func CreateTable(data *schema.ResourceData, meta interface{}) error { // This type conversion is due to the test framework in the terraform-plugin-sdk having limited support // for data types in the HCL2ValueFromConfigValue method. - columns := []map[string]string{} + columns := []snowflake.Column{} + for _, column := range data.Get("column").([]interface{}) { - columnDef := map[string]string{} - for key, val := range column.(map[string]interface{}) { - columnDef[key] = val.(string) - } + typed := column.(map[string]interface{}) + columnDef := snowflake.Column{} + columnDef.WithName(typed["name"].(string)).WithType(typed["type"].(string)) columns = append(columns, columnDef) } builder := snowflake.TableWithColumnDefinitions(name, database, schema, columns) @@ -180,28 +180,47 @@ func ReadTable(data *schema.ResourceData, meta interface{}) error { if err != nil { return err } + builder := snowflake.Table(tableID.TableName, tableID.DatabaseName, tableID.SchemaName) - dbName := tableID.DatabaseName - schema := tableID.SchemaName - name := tableID.TableName - - stmt := snowflake.Table(name, dbName, schema).Show() - row := snowflake.QueryRow(db, stmt) + row := snowflake.QueryRow(db, builder.Show()) table, err := snowflake.ScanTable(row) + // No rows then no table. Delete from state and end read + if err == sql.ErrNoRows { + data.SetId("") + return nil + } + // Check for other errors if err != nil { return err } - err = data.Set("name", table.TableName.String) + // Describe the table to read the cols + tableDescriptionRows, err := snowflake.Query(db, builder.ShowColumns()) if err != nil { return err } - err = data.Set("owner", table.Owner.String) + tableDescription, err := snowflake.ScanTableDescription(tableDescriptionRows) if err != nil { return err } + // Set the relevant data in the state + toSet := map[string]interface{}{ + "name": table.TableName.String, + "owner": table.Owner.String, + "database": tableID.DatabaseName, + "schema": tableID.SchemaName, + "comment": table.Comment.String, + "column": snowflake.NewColumns(tableDescription).Flatten(), + } + + for key, val := range toSet { + err = data.Set(key, val) + if err != nil { + return err + } + } return nil } diff --git a/pkg/resources/table_acceptance_test.go b/pkg/resources/table_acceptance_test.go index 6c85a3f4fea..f52c1d48406 100644 --- a/pkg/resources/table_acceptance_test.go +++ b/pkg/resources/table_acceptance_test.go @@ -51,7 +51,7 @@ resource "snowflake_table" "test_table" { } column { name = "column2" - type = "VARCHAR" + type = "VARCHAR(16777216)" } } ` diff --git a/pkg/resources/table_test.go b/pkg/resources/table_test.go index 880349f42aa..c8b31470e1b 100644 --- a/pkg/resources/table_test.go +++ b/pkg/resources/table_test.go @@ -41,6 +41,12 @@ func TestTableCreate(t *testing.T) { func expectTableRead(mock sqlmock.Sqlmock) { rows := sqlmock.NewRows([]string{"name", "type", "kind", "null?", "default", "primary key", "unique key", "check", "expression", "comment"}).AddRow("good_name", "VARCHAR()", "COLUMN", "Y", "NULL", "NULL", "N", "N", "NULL", "mock comment") mock.ExpectQuery(`SHOW TABLES LIKE 'good_name' IN SCHEMA "database_name"."schema_name"`).WillReturnRows(rows) + + describeRows := sqlmock.NewRows([]string{"name", "type", "kind"}). + AddRow("column1", "OBJECT", "COLUMN"). + AddRow("column2", "VARCHAR", "COLUMN") + + mock.ExpectQuery(`DESC TABLE "database_name"."schema_name"."good_name"`).WillReturnRows(describeRows) } func TestTableRead(t *testing.T) { diff --git a/pkg/snowflake/table.go b/pkg/snowflake/table.go index 15f0ae3e9a7..a9da0c744ec 100644 --- a/pkg/snowflake/table.go +++ b/pkg/snowflake/table.go @@ -8,12 +8,73 @@ import ( "github.com/jmoiron/sqlx" ) +type Column struct { + name string + _type string // type is reserved +} + +func (c *Column) WithName(name string) *Column { + c.name = name + return c +} +func (c *Column) WithType(t string) *Column { + c._type = t + return c +} + +func (c *Column) getColumnDefinition() string { + if c == nil { + return "" + } + return fmt.Sprintf(`"%v" %v`, EscapeString(c.name), EscapeString(c._type)) +} + +type Columns []Column + +// NewColumns generates columns from a table description +func NewColumns(tds []tableDescription) Columns { + cs := []Column{} + for _, td := range tds { + if td.Kind.String != "COLUMN" { + continue + } + cs = append(cs, Column{ + name: td.Name.String, + _type: td.Type.String, + }) + } + return Columns(cs) +} + +func (c Columns) Flatten() []interface{} { + flattened := []interface{}{} + for _, col := range c { + flat := map[string]interface{}{} + flat["name"] = col.name + flat["type"] = col._type + + flattened = append(flattened, flat) + } + return flattened +} + +func (c Columns) getColumnDefinitions() string { + // TODO(el): verify Snowflake reflects column order back in desc table calls + columnDefinitions := []string{} + for _, column := range c { + columnDefinitions = append(columnDefinitions, column.getColumnDefinition()) + } + + // NOTE: intentionally blank leading space + return fmt.Sprintf(" (%s)", strings.Join(columnDefinitions, ", ")) +} + // TableBuilder abstracts the creation of SQL queries for a Snowflake schema type TableBuilder struct { name string db string schema string - columns []map[string]string + columns Columns comment string } @@ -45,7 +106,7 @@ func (tb *TableBuilder) WithComment(c string) *TableBuilder { } // WithColumns sets the column definitions on the TableBuilder -func (tb *TableBuilder) WithColumns(c []map[string]string) *TableBuilder { +func (tb *TableBuilder) WithColumns(c Columns) *TableBuilder { tb.columns = c return tb } @@ -72,7 +133,7 @@ func Table(name, db, schema string) *TableBuilder { // - CREATE TABLE // // [Snowflake Reference](https://docs.snowflake.com/en/sql-reference/ddl-table.html) -func TableWithColumnDefinitions(name, db, schema string, columns []map[string]string) *TableBuilder { +func TableWithColumnDefinitions(name, db, schema string, columns Columns) *TableBuilder { return &TableBuilder{ name: name, db: db, @@ -85,14 +146,7 @@ func TableWithColumnDefinitions(name, db, schema string, columns []map[string]st func (tb *TableBuilder) Create() string { q := strings.Builder{} q.WriteString(fmt.Sprintf(`CREATE TABLE %v`, tb.QualifiedName())) - - q.WriteString(fmt.Sprintf(` (`)) - columnDefinitions := []string{} - for _, columnDefinition := range tb.columns { - columnDefinitions = append(columnDefinitions, fmt.Sprintf(`"%v" %v`, EscapeString(columnDefinition["name"]), EscapeString(columnDefinition["type"]))) - } - q.WriteString(strings.Join(columnDefinitions, ", ")) - q.WriteString(fmt.Sprintf(`)`)) + q.WriteString(tb.columns.getColumnDefinitions()) if tb.comment != "" { q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, EscapeString(tb.comment))) @@ -121,6 +175,10 @@ func (tb *TableBuilder) Show() string { return fmt.Sprintf(`SHOW TABLES LIKE '%v' IN SCHEMA "%v"."%v"`, tb.name, tb.db, tb.schema) } +func (tb *TableBuilder) ShowColumns() string { + return fmt.Sprintf(`DESC TABLE %s`, tb.QualifiedName()) +} + type table struct { CreatedOn sql.NullString `db:"created_on"` TableName sql.NullString `db:"name"` @@ -142,3 +200,22 @@ func ScanTable(row *sqlx.Row) (*table, error) { e := row.StructScan(t) return t, e } + +type tableDescription struct { + Name sql.NullString `db:"name"` + Type sql.NullString `db:"type"` + Kind sql.NullString `db:"kind"` +} + +func ScanTableDescription(rows *sqlx.Rows) ([]tableDescription, error) { + tds := []tableDescription{} + for rows.Next() { + td := tableDescription{} + err := rows.StructScan(&td) + if err != nil { + return nil, err + } + tds = append(tds, td) + } + return tds, rows.Err() +} diff --git a/pkg/snowflake/table_test.go b/pkg/snowflake/table_test.go index 720248fb39f..010562e615e 100644 --- a/pkg/snowflake/table_test.go +++ b/pkg/snowflake/table_test.go @@ -9,7 +9,18 @@ import ( func TestTableCreate(t *testing.T) { r := require.New(t) s := Table("test_table", "test_db", "test_schema") - s.WithColumns([]map[string]string{{"name": "column1", "type": "OBJECT"}, {"name": "column2", "type": "VARCHAR"}}) + cols := []Column{ + { + name: "column1", + _type: "OBJECT", + }, + { + name: "column2", + _type: "VARCHAR", + }, + } + + s.WithColumns(Columns(cols)) r.Equal(s.QualifiedName(), `"test_db"."test_schema"."test_table"`) r.Equal(`CREATE TABLE "test_db"."test_schema"."test_table" ("column1" OBJECT, "column2" VARCHAR)`, s.Create())