From 3c63d6a12186373f4ef382c629ba733f10ba0576 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Thu, 12 Dec 2024 13:17:42 +0100 Subject: [PATCH 1/2] support table data type --- .../objectassert/function_snowflake_ext.go | 11 ++ .../objectassert/procedure_snowflake_ext.go | 11 ++ pkg/sdk/datatypes/data_types.go | 5 + pkg/sdk/datatypes/data_types_test.go | 90 +++++++++++++++ pkg/sdk/datatypes/legacy.go | 1 + pkg/sdk/datatypes/table.go | 107 ++++++++++++++++++ pkg/sdk/datatypes/vector.go | 1 + pkg/sdk/testint/functions_integration_test.go | 30 +++++ .../testint/procedures_integration_test.go | 25 ++-- 9 files changed, 269 insertions(+), 12 deletions(-) create mode 100644 pkg/sdk/datatypes/table.go diff --git a/pkg/acceptance/bettertestspoc/assert/objectassert/function_snowflake_ext.go b/pkg/acceptance/bettertestspoc/assert/objectassert/function_snowflake_ext.go index aa8d17a022..8836ff49d5 100644 --- a/pkg/acceptance/bettertestspoc/assert/objectassert/function_snowflake_ext.go +++ b/pkg/acceptance/bettertestspoc/assert/objectassert/function_snowflake_ext.go @@ -65,3 +65,14 @@ func (f *FunctionAssert) HasExactlySecrets(expectedSecrets map[string]sdk.Schema }) return f } + +func (f *FunctionAssert) HasArgumentsRawContains(substring string) *FunctionAssert { + f.AddAssertion(func(t *testing.T, o *sdk.Function) error { + t.Helper() + if !strings.Contains(o.ArgumentsRaw, substring) { + return fmt.Errorf("expected arguments raw contain: %v, to contain: %v", o.ArgumentsRaw, substring) + } + return nil + }) + return f +} diff --git a/pkg/acceptance/bettertestspoc/assert/objectassert/procedure_snowflake_ext.go b/pkg/acceptance/bettertestspoc/assert/objectassert/procedure_snowflake_ext.go index 12d5a384cf..4ce244f856 100644 --- a/pkg/acceptance/bettertestspoc/assert/objectassert/procedure_snowflake_ext.go +++ b/pkg/acceptance/bettertestspoc/assert/objectassert/procedure_snowflake_ext.go @@ -57,3 +57,14 @@ func (f *ProcedureAssert) HasExactlyExternalAccessIntegrations(integrations ...s }) return f } + +func (p *ProcedureAssert) HasArgumentsRawContains(substring string) *ProcedureAssert { + p.AddAssertion(func(t *testing.T, o *sdk.Procedure) error { + t.Helper() + if !strings.Contains(o.ArgumentsRaw, substring) { + return fmt.Errorf("expected arguments raw contain: %v, to contain: %v", o.ArgumentsRaw, substring) + } + return nil + }) + return p +} diff --git a/pkg/sdk/datatypes/data_types.go b/pkg/sdk/datatypes/data_types.go index be58f978f2..2371770a94 100644 --- a/pkg/sdk/datatypes/data_types.go +++ b/pkg/sdk/datatypes/data_types.go @@ -80,6 +80,9 @@ func ParseDataType(raw string) (DataType, error) { if idx := slices.IndexFunc(VectorDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { return parseVectorDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, VectorDataTypeSynonyms[idx]}) } + if idx := slices.IndexFunc(TableDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTableDataTypeRaw(sanitizedDataTypeRaw{strings.TrimSpace(raw), TableDataTypeSynonyms[idx]}) + } return nil, fmt.Errorf("invalid data type: %s", raw) } @@ -118,6 +121,8 @@ func AreTheSame(a DataType, b DataType) bool { return castSuccessfully(v, b, areNumberDataTypesTheSame) case *ObjectDataType: return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *TableDataType: + return castSuccessfully(v, b, areTableDataTypesTheSame) case *TextDataType: return castSuccessfully(v, b, areTextDataTypesTheSame) case *TimeDataType: diff --git a/pkg/sdk/datatypes/data_types_test.go b/pkg/sdk/datatypes/data_types_test.go index cfb3845ef1..33004c6997 100644 --- a/pkg/sdk/datatypes/data_types_test.go +++ b/pkg/sdk/datatypes/data_types_test.go @@ -6,6 +6,8 @@ import ( "strings" "testing" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1095,6 +1097,88 @@ func Test_ParseDataType_Vector(t *testing.T) { } } +func Test_ParseDataType_Table(t *testing.T) { + type column struct { + name string + legacyType string + } + type test struct { + input string + expectedColumns []column + } + + positiveTestCases := []test{ + {input: "TABLE()", expectedColumns: []column{}}, + {input: "TABLE ()", expectedColumns: []column{}}, + {input: "TABLE(arg_name NUMBER)", expectedColumns: []column{{"arg_name", NumberLegacyDataType}}}, + {input: "TABLE(arg_name number, second float, third GEOGRAPHY)", expectedColumns: []column{{"arg_name", NumberLegacyDataType}, {"second", FloatLegacyDataType}, {"third", GeographyLegacyDataType}}}, + {input: "TABLE ( arg_name varchar, second date, third TIME )", expectedColumns: []column{{"arg_name", VarcharLegacyDataType}, {"second", DateLegacyDataType}, {"third", TimeLegacyDataType}}}, + // TODO: Support types with parameters (for now, only legacy types are supported because Snowflake returns only with this output), e.g. TABLE(ARG NUMBER(38, 0)) + // TODO: Support nested tables, e.g. TABLE(ARG NUMBER, NESTED TABLE(A VARCHAR, B GEOMETRY)) + } + + negativeTestCases := []test{ + {input: "TABLE(1, 2)"}, + {input: "TABLE(INT, INT)"}, + {input: "TABLE(a b)"}, + {input: "TABLE(1)"}, + {input: "TABLE(2, INT)"}, + {input: "TABLE"}, + {input: "TABLE(INT, 2, 3)"}, + {input: "TABLE(INT)"}, + {input: "TABLE(x, 2)"}, + {input: "TABLE("}, + {input: "TABLE)"}, + {input: "TA BLE"}, + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &TableDataType{}, parsed) + + assert.Equal(t, "TABLE", parsed.(*TableDataType).underlyingType) + assert.Equal(t, len(tc.expectedColumns), len(parsed.(*TableDataType).columns)) + for i, column := range tc.expectedColumns { + assert.Equal(t, column.name, parsed.(*TableDataType).columns[i].name) + assert.Equal(t, column.legacyType, parsed.(*TableDataType).columns[i].dataType.ToLegacyDataTypeSql()) + } + + legacyColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { + return fmt.Sprintf("%s %s", col.name, col.legacyType) + }), ", ") + assert.Equal(t, fmt.Sprintf("TABLE(%s)", legacyColumns), parsed.ToLegacyDataTypeSql()) + + canonicalColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { + parsedType, err := ParseDataType(col.legacyType) + require.NoError(t, err) + return fmt.Sprintf("%s %s", col.name, parsedType.Canonical()) + }), ", ") + assert.Equal(t, fmt.Sprintf("TABLE(%s)", canonicalColumns), parsed.Canonical()) + + columns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { + parsedType, err := ParseDataType(col.legacyType) + require.NoError(t, err) + return fmt.Sprintf("%s %s", col.name, parsedType.ToSql()) + }), ", ") + assert.Equal(t, fmt.Sprintf("TABLE(%s)", columns), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run("negative: "+tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + func Test_AreTheSame(t *testing.T) { // empty d1/d2 means nil DataType input type test struct { @@ -1145,6 +1229,12 @@ func Test_AreTheSame(t *testing.T) { {d1: "TIME", d2: "TIME", expectedOutcome: true}, {d1: "TIME", d2: "TIME(5)", expectedOutcome: false}, {d1: "TIME", d2: fmt.Sprintf("TIME(%d)", DefaultTimePrecision), expectedOutcome: true}, + {d1: "TABLE()", d2: "TABLE()", expectedOutcome: true}, + {d1: "TABLE(A NUMBER)", d2: "TABLE(B NUMBER)", expectedOutcome: false}, + {d1: "TABLE(A NUMBER)", d2: "TABLE(a NUMBER)", expectedOutcome: false}, + {d1: "TABLE(A NUMBER)", d2: "TABLE(A VARCHAR)", expectedOutcome: false}, + {d1: "TABLE(A NUMBER, B VARCHAR)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: true}, + {d1: "TABLE(A NUMBER, B NUMBER)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: false}, } for _, tc := range testCases { diff --git a/pkg/sdk/datatypes/legacy.go b/pkg/sdk/datatypes/legacy.go index 5a0e249cd7..b8bd63040f 100644 --- a/pkg/sdk/datatypes/legacy.go +++ b/pkg/sdk/datatypes/legacy.go @@ -16,4 +16,5 @@ const ( TimestampNtzLegacyDataType = "TIMESTAMP_NTZ" TimestampTzLegacyDataType = "TIMESTAMP_TZ" VariantLegacyDataType = "VARIANT" + TableLegacyDataType = "TABLE" ) diff --git a/pkg/sdk/datatypes/table.go b/pkg/sdk/datatypes/table.go new file mode 100644 index 0000000000..dbe87307fe --- /dev/null +++ b/pkg/sdk/datatypes/table.go @@ -0,0 +1,107 @@ +package datatypes + +import ( + "fmt" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +// TableDataType does not have synonyms. +// It consists of a list of column name + column type; may be empty. +type TableDataType struct { + columns []TableDataTypeColumn + underlyingType string +} + +type TableDataTypeColumn struct { + name string + dataType DataType +} + +func (c *TableDataTypeColumn) ColumnName() string { + return c.name +} + +func (c *TableDataTypeColumn) ColumnType() DataType { + return c.dataType +} + +func (t *TableDataType) ToSql() string { + columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { + return fmt.Sprintf("%s %s", col.name, col.dataType.ToSql()) + }), ", ") + return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns) +} + +func (t *TableDataType) ToLegacyDataTypeSql() string { + columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { + return fmt.Sprintf("%s %s", col.name, col.dataType.ToLegacyDataTypeSql()) + }), ", ") + return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns) +} + +func (t *TableDataType) Canonical() string { + columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { + return fmt.Sprintf("%s %s", col.name, col.dataType.Canonical()) + }), ", ") + return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns) +} + +func (t *TableDataType) Columns() []TableDataTypeColumn { + return t.columns +} + +func parseTableDataTypeRaw(raw sanitizedDataTypeRaw) (*TableDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "()" { + return &TableDataType{ + columns: make([]TableDataTypeColumn, 0), + underlyingType: raw.matchedByType, + }, nil + } + if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) { + logging.DebugLogger.Printf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) + } + onlyArgs := r[1 : len(r)-1] + columns, err := collections.MapErr(strings.Split(onlyArgs, ","), func(arg string) (TableDataTypeColumn, error) { + argParts := strings.Split(strings.TrimSpace(arg), " ") + if len(argParts) != 2 { + return TableDataTypeColumn{}, fmt.Errorf("could not parse table column: %s, it should contain the following format ` `; parser failure may be connected to the complex argument names", arg) + } + argDataType, err := ParseDataType(argParts[1]) + if err != nil { + return TableDataTypeColumn{}, err + } + return TableDataTypeColumn{ + name: argParts[0], + dataType: argDataType, + }, nil + }) + if err != nil { + return nil, err + } + return &TableDataType{ + columns: columns, + underlyingType: raw.matchedByType, + }, nil +} + +func areTableDataTypesTheSame(a, b *TableDataType) bool { + if len(a.columns) != len(b.columns) { + return false + } + + for i := range a.columns { + aColumn := a.columns[i] + bColumn := b.columns[i] + + if aColumn.name != bColumn.name || !AreTheSame(aColumn.dataType, bColumn.dataType) { + return false + } + } + + return true +} diff --git a/pkg/sdk/datatypes/vector.go b/pkg/sdk/datatypes/vector.go index 035249af64..d4fa9e9050 100644 --- a/pkg/sdk/datatypes/vector.go +++ b/pkg/sdk/datatypes/vector.go @@ -32,6 +32,7 @@ func (t *VectorDataType) Canonical() string { var ( VectorDataTypeSynonyms = []string{"VECTOR"} + TableDataTypeSynonyms = []string{"TABLE"} VectorAllowedInnerTypes = []string{"INT", "FLOAT"} ) diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index bb292cd627..4c9794ad72 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -1820,4 +1820,34 @@ func TestInt_Functions(t *testing.T) { assert.Equal(t, dataType.Canonical(), pairs["returns"]) }) } + + t.Run("create function for SQL - return table data type", func(t *testing.T) { + argName := "x" + + returnDataType, err := datatypes.ParseDataType(fmt.Sprintf("TABLE(ID %s, PRICE %s, THIRD %s)", datatypes.NumberLegacyDataType, datatypes.FloatLegacyDataType, datatypes.VarcharLegacyDataType)) + require.NoError(t, err) + + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(datatypes.VarcharLegacyDataType) + + definition := ` SELECT 1, 2.2::float, 'abc');` + dt := sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType) + returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) + argument := sdk.NewFunctionArgumentRequest(argName, nil).WithArgDataTypeOld(datatypes.VarcharLegacyDataType) + request := sdk.NewCreateForSQLFunctionRequestDefinitionWrapped(id.SchemaObjectId(), *returns, definition). + WithArguments([]sdk.FunctionArgumentRequest{*argument}) + + err = client.Functions.CreateForSQL(ctx, request) + require.NoError(t, err) + t.Cleanup(testClientHelper().Function.DropFunctionFunc(t, id)) + + function, err := client.Functions.ShowByID(ctx, id) + require.NoError(t, err) + + assertions.AssertThatObject(t, objectassert.FunctionFromObject(t, function). + HasCreatedOnNotEmpty(). + HasName(id.Name()). + HasSchemaName(id.SchemaName()). + HasArgumentsRawContains(fmt.Sprintf(`RETURN %s`, returnDataType.ToLegacyDataTypeSql())), + ) + }) } diff --git a/pkg/sdk/testint/procedures_integration_test.go b/pkg/sdk/testint/procedures_integration_test.go index 2a69ef42c2..d7b139dded 100644 --- a/pkg/sdk/testint/procedures_integration_test.go +++ b/pkg/sdk/testint/procedures_integration_test.go @@ -1483,22 +1483,21 @@ def filter_by_role(session, table_name, role): require.GreaterOrEqual(t, len(procedures), 1) }) - // TODO [SNOW-1348103]: adjust or remove t.Run("create procedure for SQL: returns table", func(t *testing.T) { - t.Skipf("Skipped for now; left as inspiration for resource rework as part of SNOW-1348103") - - name := "find_invoice_by_id" - id := testClientHelper().Ids.NewSchemaObjectIdentifierWithArguments(name, sdk.DataTypeVARCHAR) + id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(sdk.DataTypeVARCHAR) + column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld("INTEGER") + column2 := sdk.NewProcedureColumnRequest("price", nil).WithColumnDataTypeOld("double") + column3 := sdk.NewProcedureColumnRequest("third", nil).WithColumnDataTypeOld("Geometry") + returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2, *column3}) + expectedReturnDataType, err := datatypes.ParseDataType(fmt.Sprintf("TABLE(id %s, price %s, third %s)", datatypes.NumberLegacyDataType, datatypes.FloatLegacyDataType, datatypes.GeometryLegacyDataType)) + require.NoError(t, err) definition := ` DECLARE res RESULTSET DEFAULT (SELECT * FROM invoices WHERE id = :id); BEGIN RETURN TABLE(res); END;` - column1 := sdk.NewProcedureColumnRequest("id", nil).WithColumnDataTypeOld("INTEGER") - column2 := sdk.NewProcedureColumnRequest("price", nil).WithColumnDataTypeOld("NUMBER(12,2)") - returnsTable := sdk.NewProcedureReturnsTableRequest().WithColumns([]sdk.ProcedureColumnRequest{*column1, *column2}) returns := sdk.NewProcedureSQLReturnsRequest().WithTable(*returnsTable) argument := sdk.NewProcedureArgumentRequest("id", nil).WithArgDataTypeOld(sdk.DataTypeVARCHAR) request := sdk.NewCreateForSQLProcedureRequestDefinitionWrapped(id.SchemaObjectId(), *returns, definition). @@ -1506,13 +1505,15 @@ def filter_by_role(session, table_name, role): // SNOW-1051627 todo: uncomment once null input behavior working again // WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorReturnsNullInput)). WithArguments([]sdk.ProcedureArgumentRequest{*argument}) - err := client.Procedures.CreateForSQL(ctx, request) + err = client.Procedures.CreateForSQL(ctx, request) require.NoError(t, err) t.Cleanup(testClientHelper().Procedure.DropProcedureFunc(t, id)) - procedures, err := client.Procedures.Show(ctx, sdk.NewShowProcedureRequest()) - require.NoError(t, err) - require.GreaterOrEqual(t, len(procedures), 1) + assertions.AssertThatObject(t, objectassert.Procedure(t, id). + HasCreatedOnNotEmpty(). + HasName(id.Name()). + HasSchemaName(id.SchemaName()). + HasArgumentsRawContains(fmt.Sprintf(`RETURN %s`, expectedReturnDataType.ToLegacyDataTypeSql()))) }) t.Run("show parameters", func(t *testing.T) { From 9526e410ea2341b992b949469c9258caacc4c947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Thu, 12 Dec 2024 15:32:52 +0100 Subject: [PATCH 2/2] changes after review --- pkg/sdk/datatypes/data_types_test.go | 35 ++++++++++++------- pkg/sdk/datatypes/legacy.go | 3 +- pkg/sdk/datatypes/table.go | 25 +++++++------ pkg/sdk/datatypes/vector.go | 1 - pkg/sdk/testint/functions_integration_test.go | 8 +++-- .../testint/procedures_integration_test.go | 7 +++- 6 files changed, 51 insertions(+), 28 deletions(-) diff --git a/pkg/sdk/datatypes/data_types_test.go b/pkg/sdk/datatypes/data_types_test.go index 33004c6997..7e6382e63f 100644 --- a/pkg/sdk/datatypes/data_types_test.go +++ b/pkg/sdk/datatypes/data_types_test.go @@ -1099,8 +1099,8 @@ func Test_ParseDataType_Vector(t *testing.T) { func Test_ParseDataType_Table(t *testing.T) { type column struct { - name string - legacyType string + Name string + Type string } type test struct { input string @@ -1110,14 +1110,20 @@ func Test_ParseDataType_Table(t *testing.T) { positiveTestCases := []test{ {input: "TABLE()", expectedColumns: []column{}}, {input: "TABLE ()", expectedColumns: []column{}}, - {input: "TABLE(arg_name NUMBER)", expectedColumns: []column{{"arg_name", NumberLegacyDataType}}}, - {input: "TABLE(arg_name number, second float, third GEOGRAPHY)", expectedColumns: []column{{"arg_name", NumberLegacyDataType}, {"second", FloatLegacyDataType}, {"third", GeographyLegacyDataType}}}, - {input: "TABLE ( arg_name varchar, second date, third TIME )", expectedColumns: []column{{"arg_name", VarcharLegacyDataType}, {"second", DateLegacyDataType}, {"third", TimeLegacyDataType}}}, + {input: "TABLE ( )", expectedColumns: []column{}}, + {input: "TABLE(arg_name NUMBER)", expectedColumns: []column{{"arg_name", "NUMBER"}}}, + {input: "TABLE(arg_name double precision, arg_name_2 NUMBER)", expectedColumns: []column{{"arg_name", "double precision"}, {"arg_name_2", "NUMBER"}}}, + {input: "TABLE(arg_name NUMBER(38))", expectedColumns: []column{{"arg_name", "NUMBER(38)"}}}, + {input: "TABLE(arg_name NUMBER(38), arg_name_2 VARCHAR)", expectedColumns: []column{{"arg_name", "NUMBER(38)"}, {"arg_name_2", "VARCHAR"}}}, + {input: "TABLE(arg_name number, second float, third GEOGRAPHY)", expectedColumns: []column{{"arg_name", "number"}, {"second", "float"}, {"third", "GEOGRAPHY"}}}, + {input: "TABLE ( arg_name varchar, second date, third TIME )", expectedColumns: []column{{"arg_name", "varchar"}, {"second", "date"}, {"third", "time"}}}, // TODO: Support types with parameters (for now, only legacy types are supported because Snowflake returns only with this output), e.g. TABLE(ARG NUMBER(38, 0)) // TODO: Support nested tables, e.g. TABLE(ARG NUMBER, NESTED TABLE(A VARCHAR, B GEOMETRY)) + // TODO: Support complex argument names (with quotes / spaces / special characters / etc) } negativeTestCases := []test{ + {input: "TABLE())"}, {input: "TABLE(1, 2)"}, {input: "TABLE(INT, INT)"}, {input: "TABLE(a b)"}, @@ -1143,26 +1149,30 @@ func Test_ParseDataType_Table(t *testing.T) { assert.Equal(t, "TABLE", parsed.(*TableDataType).underlyingType) assert.Equal(t, len(tc.expectedColumns), len(parsed.(*TableDataType).columns)) for i, column := range tc.expectedColumns { - assert.Equal(t, column.name, parsed.(*TableDataType).columns[i].name) - assert.Equal(t, column.legacyType, parsed.(*TableDataType).columns[i].dataType.ToLegacyDataTypeSql()) + assert.Equal(t, column.Name, parsed.(*TableDataType).columns[i].name) + parsedType, err := ParseDataType(column.Type) + require.NoError(t, err) + assert.Equal(t, parsedType.ToLegacyDataTypeSql(), parsed.(*TableDataType).columns[i].dataType.ToLegacyDataTypeSql()) } legacyColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { - return fmt.Sprintf("%s %s", col.name, col.legacyType) + parsedType, err := ParseDataType(col.Type) + require.NoError(t, err) + return fmt.Sprintf("%s %s", col.Name, parsedType.ToLegacyDataTypeSql()) }), ", ") assert.Equal(t, fmt.Sprintf("TABLE(%s)", legacyColumns), parsed.ToLegacyDataTypeSql()) canonicalColumns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { - parsedType, err := ParseDataType(col.legacyType) + parsedType, err := ParseDataType(col.Type) require.NoError(t, err) - return fmt.Sprintf("%s %s", col.name, parsedType.Canonical()) + return fmt.Sprintf("%s %s", col.Name, parsedType.Canonical()) }), ", ") assert.Equal(t, fmt.Sprintf("TABLE(%s)", canonicalColumns), parsed.Canonical()) columns := strings.Join(collections.Map(tc.expectedColumns, func(col column) string { - parsedType, err := ParseDataType(col.legacyType) + parsedType, err := ParseDataType(col.Type) require.NoError(t, err) - return fmt.Sprintf("%s %s", col.name, parsedType.ToSql()) + return fmt.Sprintf("%s %s", col.Name, parsedType.ToSql()) }), ", ") assert.Equal(t, fmt.Sprintf("TABLE(%s)", columns), parsed.ToSql()) }) @@ -1235,6 +1245,7 @@ func Test_AreTheSame(t *testing.T) { {d1: "TABLE(A NUMBER)", d2: "TABLE(A VARCHAR)", expectedOutcome: false}, {d1: "TABLE(A NUMBER, B VARCHAR)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: true}, {d1: "TABLE(A NUMBER, B NUMBER)", d2: "TABLE(A NUMBER, B VARCHAR)", expectedOutcome: false}, + {d1: "TABLE()", d2: "TABLE(A NUMBER)", expectedOutcome: false}, } for _, tc := range testCases { diff --git a/pkg/sdk/datatypes/legacy.go b/pkg/sdk/datatypes/legacy.go index b8bd63040f..63f523779e 100644 --- a/pkg/sdk/datatypes/legacy.go +++ b/pkg/sdk/datatypes/legacy.go @@ -16,5 +16,6 @@ const ( TimestampNtzLegacyDataType = "TIMESTAMP_NTZ" TimestampTzLegacyDataType = "TIMESTAMP_TZ" VariantLegacyDataType = "VARIANT" - TableLegacyDataType = "TABLE" + // TableLegacyDataType was not a value of legacy data type in the old implementation. Left for now for an easier implementation. + TableLegacyDataType = "TABLE" ) diff --git a/pkg/sdk/datatypes/table.go b/pkg/sdk/datatypes/table.go index dbe87307fe..a05298c992 100644 --- a/pkg/sdk/datatypes/table.go +++ b/pkg/sdk/datatypes/table.go @@ -8,7 +8,8 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" ) -// TableDataType does not have synonyms. +// TableDataType is based on https://docs.snowflake.com/en/developer-guide/stored-procedure/stored-procedures-java#returning-tabular-data. +// It does not have synonyms. // It consists of a list of column name + column type; may be empty. type TableDataType struct { columns []TableDataTypeColumn @@ -20,6 +21,8 @@ type TableDataTypeColumn struct { dataType DataType } +var TableDataTypeSynonyms = []string{"TABLE"} + func (c *TableDataTypeColumn) ColumnName() string { return c.name } @@ -32,21 +35,21 @@ func (t *TableDataType) ToSql() string { columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { return fmt.Sprintf("%s %s", col.name, col.dataType.ToSql()) }), ", ") - return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns) + return fmt.Sprintf("%s(%s)", t.underlyingType, columns) } func (t *TableDataType) ToLegacyDataTypeSql() string { columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { return fmt.Sprintf("%s %s", col.name, col.dataType.ToLegacyDataTypeSql()) }), ", ") - return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns) + return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns) } func (t *TableDataType) Canonical() string { columns := strings.Join(collections.Map(t.columns, func(col TableDataTypeColumn) string { return fmt.Sprintf("%s %s", col.name, col.dataType.Canonical()) }), ", ") - return fmt.Sprintf("%s (%s)", TableLegacyDataType, columns) + return fmt.Sprintf("%s(%s)", TableLegacyDataType, columns) } func (t *TableDataType) Columns() []TableDataTypeColumn { @@ -55,19 +58,19 @@ func (t *TableDataType) Columns() []TableDataTypeColumn { func parseTableDataTypeRaw(raw sanitizedDataTypeRaw) (*TableDataType, error) { r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) - if r == "()" { + if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) { + logging.DebugLogger.Printf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) + } + onlyArgs := strings.TrimSpace(r[1 : len(r)-1]) + if onlyArgs == "" { return &TableDataType{ columns: make([]TableDataTypeColumn, 0), underlyingType: raw.matchedByType, }, nil } - if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) { - logging.DebugLogger.Printf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) - return nil, fmt.Errorf(`table %s could not be parsed, use "%s(argName argType, ...)" format`, raw.raw, raw.matchedByType) - } - onlyArgs := r[1 : len(r)-1] columns, err := collections.MapErr(strings.Split(onlyArgs, ","), func(arg string) (TableDataTypeColumn, error) { - argParts := strings.Split(strings.TrimSpace(arg), " ") + argParts := strings.SplitN(strings.TrimSpace(arg), " ", 2) if len(argParts) != 2 { return TableDataTypeColumn{}, fmt.Errorf("could not parse table column: %s, it should contain the following format ` `; parser failure may be connected to the complex argument names", arg) } diff --git a/pkg/sdk/datatypes/vector.go b/pkg/sdk/datatypes/vector.go index d4fa9e9050..035249af64 100644 --- a/pkg/sdk/datatypes/vector.go +++ b/pkg/sdk/datatypes/vector.go @@ -32,7 +32,6 @@ func (t *VectorDataType) Canonical() string { var ( VectorDataTypeSynonyms = []string{"VECTOR"} - TableDataTypeSynonyms = []string{"TABLE"} VectorAllowedInnerTypes = []string{"INT", "FLOAT"} ) diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index d7a0e386b4..349c17de83 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -2016,7 +2016,7 @@ func TestInt_Functions(t *testing.T) { id := testClientHelper().Ids.RandomSchemaObjectIdentifierWithArguments(datatypes.VarcharLegacyDataType) - definition := ` SELECT 1, 2.2::float, 'abc');` + definition := ` SELECT 1, 2.2::float, 'abc';` dt := sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType) returns := sdk.NewFunctionReturnsRequest().WithResultDataType(*dt) argument := sdk.NewFunctionArgumentRequest(argName, nil).WithArgDataTypeOld(datatypes.VarcharLegacyDataType) @@ -2034,7 +2034,11 @@ func TestInt_Functions(t *testing.T) { HasCreatedOnNotEmpty(). HasName(id.Name()). HasSchemaName(id.SchemaName()). - HasArgumentsRawContains(fmt.Sprintf(`RETURN %s`, returnDataType.ToLegacyDataTypeSql())), + HasArgumentsRawContains(returnDataType.ToLegacyDataTypeSql()), + ) + + assertions.AssertThatObject(t, objectassert.FunctionDetails(t, id). + HasReturnDataType(returnDataType), ) }) } diff --git a/pkg/sdk/testint/procedures_integration_test.go b/pkg/sdk/testint/procedures_integration_test.go index dd357e11d8..3b4e4f041d 100644 --- a/pkg/sdk/testint/procedures_integration_test.go +++ b/pkg/sdk/testint/procedures_integration_test.go @@ -1764,7 +1764,12 @@ def filter_by_role(session, table_name, role): HasCreatedOnNotEmpty(). HasName(id.Name()). HasSchemaName(id.SchemaName()). - HasArgumentsRawContains(fmt.Sprintf(`RETURN %s`, expectedReturnDataType.ToLegacyDataTypeSql()))) + HasArgumentsRawContains(expectedReturnDataType.ToLegacyDataTypeSql()), + ) + + assertions.AssertThatObject(t, objectassert.ProcedureDetails(t, id). + HasReturnDataType(expectedReturnDataType), + ) }) t.Run("show parameters", func(t *testing.T) {