Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: supports collation of table column #2496

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions pkg/resources/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,12 @@ var tableSchema = map[string]*schema.Schema{
Default: "",
Description: "Masking policy to apply on column. It has to be a fully qualified name.",
},
"collate": {
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
Type: schema.TypeString,
Optional: true,
Default: "",
Description: "Column collation, e.g. utf8",
},
},
},
},
Expand Down Expand Up @@ -255,6 +261,7 @@ type column struct {
identity *columnIdentity
comment string
maskingPolicy string
collate string
}

type columns []column
Expand All @@ -268,13 +275,14 @@ type changedColumn struct {
dropedDefault bool
changedComment bool
changedMaskingPolicy bool
changedCollate bool
}

func (c columns) getChangedColumnProperties(new columns) (changed changedColumns) {
changed = changedColumns{}
for _, cO := range c {
for _, cN := range new {
changeColumn := changedColumn{cN, false, false, false, false, false}
changeColumn := changedColumn{cN, false, false, false, false, false, false}
if cO.name == cN.name && cO.dataType != cN.dataType {
changeColumn.changedDataType = true
}
Expand All @@ -293,6 +301,10 @@ func (c columns) getChangedColumnProperties(new columns) (changed changedColumns
changeColumn.changedMaskingPolicy = true
}

if cO.name == cN.name && cO.collate != cN.collate {
changeColumn.changedCollate = true
}

changed = append(changed, changeColumn)
}
}
Expand Down Expand Up @@ -363,6 +375,7 @@ func getColumn(from interface{}) (to column) {
_default: cd,
identity: id,
comment: c["comment"].(string),
collate: c["collate"].(string),
maskingPolicy: c["masking_policy"].(string),
}
}
Expand All @@ -388,7 +401,7 @@ func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest {
if len(_default) == 1 {
if c, ok := _default[0].(map[string]interface{})["constant"]; ok {
if constant, ok := c.(string); ok && len(constant) > 0 {
if strings.Contains(_type, "CHAR") || _type == "STRING" || _type == "TEXT" {
if sdk.IsStringType(_type) {
expression = snowflake.EscapeSnowflakeString(constant)
} else {
expression = constant
Expand Down Expand Up @@ -423,6 +436,10 @@ func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest {
request.WithMaskingPolicy(sdk.NewColumnMaskingPolicyRequest(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(maskingPolicy)))
}

if sdk.IsStringType(_type) {
request.WithCollate(sdk.String(c["collate"].(string)))
}

return request.
WithNotNull(sdk.Bool(!c["nullable"].(bool))).
WithComment(sdk.String(c["comment"].(string)))
Expand Down Expand Up @@ -470,6 +487,10 @@ func toColumnConfig(descriptions []sdk.TableColumnDetails) []any {
flat["comment"] = *td.Comment
}

if td.Collation != nil {
flat["collate"] = *td.Collation
}

if td.PolicyName != nil {
// TODO [SNOW-867240]: SHOW TABLE returns last part of id without double quotes... we have to quote it again. Move it to SDK.
flat["masking_policy"] = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(*td.PolicyName).FullyQualifiedName()
Expand Down Expand Up @@ -508,8 +529,7 @@ func toColumnDefaultConfig(td sdk.TableColumnDetails) map[string]any {
return def
}

columnType := strings.ToUpper(string(td.Type))
if strings.Contains(columnType, "CHAR") || columnType == "STRING" || columnType == "TEXT" {
if sdk.IsStringType(string(td.Type)) {
def["constant"] = snowflake.UnescapeSnowflakeString(defaultRaw)
return def
}
Expand Down Expand Up @@ -766,7 +786,7 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("failed to add column %v => Only adding a column as a constant is supported by Snowflake", cA.name)
}
var expression string
if strings.Contains(cA.dataType, "CHAR") || cA.dataType == "STRING" || cA.dataType == "TEXT" {
if sdk.IsStringType(cA.dataType) {
expression = snowflake.EscapeSnowflakeString(*cA._default.constant)
} else {
expression = *cA._default.constant
Expand All @@ -786,14 +806,18 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error {
addRequest.WithComment(sdk.String(cA.comment))
}

if cA.collate != "" && sdk.IsStringType(cA.dataType) {
addRequest.WithCollate(sdk.String(cA.collate))
}

err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAdd(addRequest)))
if err != nil {
return fmt.Errorf("error adding column: %w", err)
}
}
for _, cA := range changed {
if cA.changedDataType {
err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithType(sdk.Pointer(sdk.DataType(cA.newColumn.dataType)))})))
if cA.changedDataType || cA.changedCollate {
err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithType(sdk.Pointer(sdk.DataType(cA.newColumn.dataType))).WithCollate(sdk.String(cA.newColumn.collate))})))
sfc-gh-jcieslak marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return fmt.Errorf("error changing property on %v: err %w", d.Id(), err)
}
Expand Down
160 changes: 160 additions & 0 deletions pkg/resources/table_acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"regexp"
"strings"
"testing"

Expand Down Expand Up @@ -1228,6 +1229,165 @@ resource "snowflake_table" "test_table" {
return fmt.Sprintf(s, name, databaseName, schemaName, name, databaseName, schemaName)
}

func TestAcc_TableCollate(t *testing.T) {
accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))

resource.Test(t, resource.TestCase{
ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories,
PreCheck: func() { acc.TestAccPreCheck(t) },
TerraformVersionChecks: []tfversion.TerraformVersionCheck{
tfversion.RequireAbove(tfversion.Version1_5_0),
},
CheckDestroy: testAccCheckTableDestroy,
Steps: []resource.TestStep{
{
Config: tableColumnWithCollate(accName, acc.TestDatabaseName, acc.TestSchemaName),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr("snowflake_table.test_table", "name", accName),
resource.TestCheckResourceAttr("snowflake_table.test_table", "database", acc.TestDatabaseName),
resource.TestCheckResourceAttr("snowflake_table.test_table", "schema", acc.TestSchemaName),
resource.TestCheckResourceAttr("snowflake_table.test_table", "comment", "Terraform acceptance test"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.#", "3"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.name", "column1"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.collate", "en"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.name", "column2"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.collate", ""),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.name", "column3"),
sfc-gh-jcieslak marked this conversation as resolved.
Show resolved Hide resolved
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.collate", ""),
),
},
{
Config: alterTableColumnWithCollate(accName, acc.TestDatabaseName, acc.TestSchemaName),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr("snowflake_table.test_table", "name", accName),
resource.TestCheckResourceAttr("snowflake_table.test_table", "database", acc.TestDatabaseName),
resource.TestCheckResourceAttr("snowflake_table.test_table", "schema", acc.TestSchemaName),
resource.TestCheckResourceAttr("snowflake_table.test_table", "comment", "Terraform acceptance test"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.#", "4"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.name", "column1"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.collate", "en"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.name", "column2"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.collate", ""),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.name", "column3"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.collate", ""),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.3.name", "column4"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.3.collate", "utf8"),
),
},
{
Config: alterTableColumnWithIncompatibleCollate(accName, acc.TestDatabaseName, acc.TestSchemaName),
ExpectError: regexp.MustCompile("\"VARCHAR\\(200\\) COLLATE 'fr'\" because they have incompatible collations\\."),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr("snowflake_table.test_table", "name", accName),
resource.TestCheckResourceAttr("snowflake_table.test_table", "database", acc.TestDatabaseName),
resource.TestCheckResourceAttr("snowflake_table.test_table", "schema", acc.TestSchemaName),
resource.TestCheckResourceAttr("snowflake_table.test_table", "comment", "Terraform acceptance test"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.#", "4"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.name", "column1"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.collate", "en"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.name", "column2"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.1.collate", ""),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.name", "column3"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.collate", ""),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.3.name", "column4"),
resource.TestCheckResourceAttr("snowflake_table.test_table", "column.3.collate", "utf8"),
),
},
},
})
}

func tableColumnWithCollate(name string, databaseName string, schemaName string) string {
s := `
resource "snowflake_table" "test_table" {
name = "%s"
database = "%s"
schema = "%s"
comment = "Terraform acceptance test"

column {
name = "column1"
type = "VARCHAR(100)"
collate = "en"
}
column {
name = "column2"
type = "VARCHAR(100)"
collate = ""
}
column {
name = "column3"
type = "VARCHAR(100)"
}
}
`
return fmt.Sprintf(s, name, databaseName, schemaName)
}

func alterTableColumnWithCollate(name string, databaseName string, schemaName string) string {
s := `
resource "snowflake_table" "test_table" {
name = "%s"
database = "%s"
schema = "%s"
comment = "Terraform acceptance test"

column {
name = "column1"
type = "VARCHAR(200)"
collate = "en"
}
column {
name = "column2"
type = "VARCHAR(200)"
collate = ""
}
column {
name = "column3"
type = "VARCHAR(200)"
}
column {
name = "column4"
type = "VARCHAR"
collate = "utf8"
}
}
`
return fmt.Sprintf(s, name, databaseName, schemaName)
}

func alterTableColumnWithIncompatibleCollate(name string, databaseName string, schemaName string) string {
s := `
resource "snowflake_table" "test_table" {
name = "%s"
database = "%s"
schema = "%s"
comment = "Terraform acceptance test"

column {
name = "column1"
type = "VARCHAR(200)"
collate = "fr"
}
column {
name = "column2"
type = "VARCHAR(200)"
collate = ""
}
column {
name = "column3"
type = "VARCHAR(200)"
}
column {
name = "column4"
type = "VARCHAR"
collate = "utf8"
}
}
`
return fmt.Sprintf(s, name, databaseName, schemaName)
}

func TestAcc_TableRename(t *testing.T) {
oldTableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
newTableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
Expand Down
10 changes: 10 additions & 0 deletions pkg/sdk/data_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,13 @@ func ToDataType(s string) (DataType, error) {

return "", fmt.Errorf("invalid data type: %s", s)
}

func IsStringType(_type string) bool {
t := strings.ToUpper(_type)
return strings.HasPrefix(t, "STRING") ||
strings.HasPrefix(t, "VARCHAR") ||
strings.HasPrefix(t, "CHAR") ||
strings.HasPrefix(t, "TEXT") ||
strings.HasPrefix(t, "NVARCHAR") ||
strings.HasPrefix(t, "NCHAR")
}
49 changes: 49 additions & 0 deletions pkg/sdk/data_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,52 @@ func TestToDataType(t *testing.T) {
})
}
}

func TestIsStringType(t *testing.T) {
type test struct {
input string
want bool
}

tests := []test{
// case insensitive.
{input: "STRING", want: true},
{input: "string", want: true},
{input: "String", want: true},

// varchar types.
{input: "VARCHAR", want: true},
{input: "NVARCHAR", want: true},
{input: "NVARCHAR2", want: true},
{input: "CHAR", want: true},
{input: "NCHAR", want: true},
{input: "CHAR VARYING", want: true},
{input: "NCHAR VARYING", want: true},
{input: "TEXT", want: true},

// with length
{input: "VARCHAR(100)", want: true},
{input: "NVARCHAR(100)", want: true},
{input: "NVARCHAR2(100)", want: true},
{input: "CHAR(100)", want: true},
{input: "NCHAR(100)", want: true},
{input: "CHAR VARYING(100)", want: true},
{input: "NCHAR VARYING(100)", want: true},
{input: "TEXT(100)", want: true},

// binary is not string types.
{input: "binary", want: false},
{input: "varbinary", want: false},

// other types
{input: "boolean", want: false},
{input: "number", want: false},
}

for _, tc := range tests {
t.Run(tc.input, func(t *testing.T) {
got := IsStringType(tc.input)
require.Equal(t, tc.want, got)
})
}
}
Loading
Loading