From c8cf09b2af605dc373a138e6ca6863b5546303d5 Mon Sep 17 00:00:00 2001 From: Scott Winkler Date: Thu, 1 Feb 2024 04:48:25 -0800 Subject: [PATCH 1/5] feat: add external funcs to sdk (#2440) --- pkg/sdk/client.go | 2 + pkg/sdk/external_functions_def.go | 167 ++++++++++ .../external_functions_dto_builders_gen.go | 239 +++++++++++++ pkg/sdk/external_functions_dto_gen.go | 82 +++++ pkg/sdk/external_functions_gen.go | 150 +++++++++ pkg/sdk/external_functions_gen_test.go | 315 ++++++++++++++++++ pkg/sdk/external_functions_impl_gen.go | 211 ++++++++++++ pkg/sdk/external_functions_validations_gen.go | 74 ++++ pkg/sdk/poc/main.go | 1 + .../external_functions_integration_test.go | 278 ++++++++++++++++ pkg/sdk/testint/functions_integration_test.go | 1 - 11 files changed, 1519 insertions(+), 1 deletion(-) create mode 100644 pkg/sdk/external_functions_def.go create mode 100644 pkg/sdk/external_functions_dto_builders_gen.go create mode 100644 pkg/sdk/external_functions_dto_gen.go create mode 100644 pkg/sdk/external_functions_gen.go create mode 100644 pkg/sdk/external_functions_gen_test.go create mode 100644 pkg/sdk/external_functions_impl_gen.go create mode 100644 pkg/sdk/external_functions_validations_gen.go create mode 100644 pkg/sdk/testint/external_functions_integration_test.go diff --git a/pkg/sdk/client.go b/pkg/sdk/client.go index 213c898ebd..42c9738a60 100644 --- a/pkg/sdk/client.go +++ b/pkg/sdk/client.go @@ -48,6 +48,7 @@ type Client struct { DatabaseRoles DatabaseRoles Databases Databases DynamicTables DynamicTables + ExternalFunctions ExternalFunctions ExternalTables ExternalTables EventTables EventTables FailoverGroups FailoverGroups @@ -199,6 +200,7 @@ func (c *Client) initialize() { c.DatabaseRoles = &databaseRoles{client: c} c.Databases = &databases{client: c} c.DynamicTables = &dynamicTables{client: c} + c.ExternalFunctions = &externalFunctions{client: c} c.ExternalTables = &externalTables{client: c} c.EventTables = &eventTables{client: c} c.FailoverGroups = &failoverGroups{client: c} diff --git a/pkg/sdk/external_functions_def.go b/pkg/sdk/external_functions_def.go new file mode 100644 index 0000000000..093800ec8a --- /dev/null +++ b/pkg/sdk/external_functions_def.go @@ -0,0 +1,167 @@ +package sdk + +import g "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/generator" + +//go:generate go run ./poc/main.go + +var externalFunctionArgument = g.NewQueryStruct("ExternalFunctionArgument"). + Text("ArgName", g.KeywordOptions().NoQuotes().Required()). + PredefinedQueryStructField("ArgDataType", g.KindOfT[DataType](), g.KeywordOptions().NoQuotes().Required()) + +var externalFunctionHeader = g.NewQueryStruct("ExternalFunctionHeader"). + Text("Name", g.KeywordOptions().SingleQuotes().Required()). + PredefinedQueryStructField("Value", g.KindOfT[string](), g.ParameterOptions().SingleQuotes().Required()) + +var externalFunctionContextHeader = g.NewQueryStruct("ExternalFunctionContextHeader").Text("ContextFunction", g.KeywordOptions().NoQuotes().Required()) + +var externalFunctionSet = g.NewQueryStruct("ExternalFunctionSet"). + OptionalIdentifier("ApiIntegration", g.KindOfTPointer[AccountObjectIdentifier](), g.IdentifierOptions().SQL("API_INTEGRATION =")). + ListQueryStructField( + "Headers", + externalFunctionHeader, + g.ParameterOptions().Parentheses().SQL("HEADERS"), + ). + ListQueryStructField( + "ContextHeaders", + externalFunctionContextHeader, + g.ParameterOptions().Parentheses().SQL("CONTEXT_HEADERS"), + ). + OptionalNumberAssignment("MAX_BATCH_ROWS", g.ParameterOptions()). + OptionalTextAssignment("COMPRESSION", g.ParameterOptions()). + OptionalIdentifier("RequestTranslator", g.KindOfTPointer[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("REQUEST_TRANSLATOR =")). + OptionalIdentifier("ResponseTranslator", g.KindOfTPointer[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("RESPONSE_TRANSLATOR =")). + WithValidation(g.ExactlyOneValueSet, "ApiIntegration", "Headers", "ContextHeaders", "MaxBatchRows", "Compression", "RequestTranslator", "ResponseTranslator") + +var externalFunctionUnset = g.NewQueryStruct("ExternalFunctionUnset"). + OptionalSQL("COMMENT"). + OptionalSQL("HEADERS"). + OptionalSQL("CONTEXT_HEADERS"). + OptionalSQL("MAX_BATCH_ROWS"). + OptionalSQL("COMPRESSION"). + OptionalSQL("SECURE"). + OptionalSQL("REQUEST_TRANSLATOR"). + OptionalSQL("RESPONSE_TRANSLATOR"). + WithValidation(g.AtLeastOneValueSet, "Comment", "Headers", "ContextHeaders", "MaxBatchRows", "Compression", "Secure", "RequestTranslator", "ResponseTranslator") + +var ExternalFunctionsDef = g.NewInterface( + "ExternalFunctions", + "ExternalFunction", + g.KindOfT[SchemaObjectIdentifier](), +).CreateOperation( + "https://docs.snowflake.com/en/sql-reference/sql/create-external-function", + g.NewQueryStruct("CreateExternalFunction"). + Create(). + OrReplace(). + OptionalSQL("SECURE"). + SQL("EXTERNAL FUNCTION"). + Name(). + ListQueryStructField( + "Arguments", + externalFunctionArgument, + g.ListOptions().MustParentheses()). + PredefinedQueryStructField("ResultDataType", g.KindOfT[DataType](), g.ParameterOptions().NoEquals().SQL("RETURNS").Required()). + PredefinedQueryStructField("ReturnNullValues", g.KindOfTPointer[ReturnNullValues](), g.KeywordOptions()). + PredefinedQueryStructField("NullInputBehavior", g.KindOfTPointer[NullInputBehavior](), g.KeywordOptions()). + PredefinedQueryStructField("ReturnResultsBehavior", g.KindOfTPointer[ReturnResultsBehavior](), g.KeywordOptions()). + OptionalTextAssignment("COMMENT", g.ParameterOptions().SingleQuotes()). + Identifier("ApiIntegration", g.KindOfTPointer[AccountObjectIdentifier](), g.IdentifierOptions().SQL("API_INTEGRATION =").Required()). + ListQueryStructField( + "Headers", + externalFunctionHeader, + g.ParameterOptions().Parentheses().SQL("HEADERS"), + ). + ListQueryStructField( + "ContextHeaders", + externalFunctionContextHeader, + g.ParameterOptions().Parentheses().SQL("CONTEXT_HEADERS"), + ). + OptionalNumberAssignment("MAX_BATCH_ROWS", g.ParameterOptions()). + OptionalTextAssignment("COMPRESSION", g.ParameterOptions()). + OptionalIdentifier("RequestTranslator", g.KindOfTPointer[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("REQUEST_TRANSLATOR =")). + OptionalIdentifier("ResponseTranslator", g.KindOfTPointer[SchemaObjectIdentifier](), g.IdentifierOptions().SQL("RESPONSE_TRANSLATOR =")). + TextAssignment("AS", g.ParameterOptions().NoEquals().SingleQuotes().Required()). + WithValidation(g.ValidIdentifier, "name"). + WithValidation(g.ValidateValueSet, "ApiIntegration"). + WithValidation(g.ValidIdentifierIfSet, "RequestTranslator"). + WithValidation(g.ValidateValueSet, "As"). + WithValidation(g.ValidIdentifierIfSet, "ResponseTranslator"), +).AlterOperation( + "https://docs.snowflake.com/en/sql-reference/sql/alter-function", + g.NewQueryStruct("AlterExternalFunction"). + Alter(). + SQL("FUNCTION"). + IfExists(). + Name(). + PredefinedQueryStructField("ArgumentDataTypes", g.KindOfTSlice[DataType](), g.KeywordOptions().MustParentheses().Required()). + OptionalQueryStructField( + "Set", + externalFunctionSet, + g.KeywordOptions().SQL("SET"), + ). + OptionalQueryStructField( + "Unset", + externalFunctionUnset, + g.ListOptions().NoParentheses().SQL("UNSET"), + ). + WithValidation(g.ExactlyOneValueSet, "Set", "Unset"). + WithValidation(g.ValidIdentifier, "name"), +).ShowOperation( + "https://docs.snowflake.com/en/sql-reference/sql/show-external-functions", + g.DbStruct("externalFunctionRow"). + Field("created_on", "string"). + Field("name", "string"). + Field("schema_name", "sql.NullString"). + Field("is_builtin", "string"). + Field("is_aggregate", "string"). + Field("is_ansi", "string"). + Field("min_num_arguments", "int"). + Field("max_num_arguments", "int"). + Field("arguments", "string"). + Field("description", "string"). + Field("schema_name", "sql.NullString"). + Field("is_table_function", "string"). + Field("valid_for_clustering", "string"). + Field("is_secure", "sql.NullString"). + Field("is_external_function", "string"). + Field("language", "string"). + Field("is_memoizable", "sql.NullString"). + Field("is_data_metric", "sql.NullString"), + g.PlainStruct("ExternalFunction"). + Field("CreatedOn", "string"). + Field("Name", "string"). + Field("SchemaName", "string"). + Field("IsBuiltin", "bool"). + Field("IsAggregate", "bool"). + Field("IsAnsi", "bool"). + Field("MinNumArguments", "int"). + Field("MaxNumArguments", "int"). + Field("Arguments", "string"). + Field("Description", "string"). + Field("CatalogName", "string"). + Field("IsTableFunction", "bool"). + Field("ValidForClustering", "bool"). + Field("IsSecure", "bool"). + Field("IsExternalFunction", "bool"). + Field("Language", "string"). + Field("IsMemoizable", "bool"). + Field("IsDataMetric", "bool"), + g.NewQueryStruct("ShowFunctions"). + Show(). + SQL("EXTERNAL FUNCTIONS"). + OptionalLike(), +).ShowByIdOperation().DescribeOperation( + g.DescriptionMappingKindSlice, + "https://docs.snowflake.com/en/sql-reference/sql/desc-function", + g.DbStruct("externalFunctionPropertyRow"). + Field("property", "string"). + Field("value", "string"), + g.PlainStruct("ExternalFunctionProperty"). + Field("Property", "string"). + Field("Value", "string"), + g.NewQueryStruct("DescribeExternalFunction"). + Describe(). + SQL("FUNCTION"). + Name(). + PredefinedQueryStructField("ArgumentDataTypes", g.KindOfTSlice[DataType](), g.KeywordOptions().MustParentheses().Required()). + WithValidation(g.ValidIdentifier, "name"), +) diff --git a/pkg/sdk/external_functions_dto_builders_gen.go b/pkg/sdk/external_functions_dto_builders_gen.go new file mode 100644 index 0000000000..86c38b0bf2 --- /dev/null +++ b/pkg/sdk/external_functions_dto_builders_gen.go @@ -0,0 +1,239 @@ +// Code generated by dto builder generator; DO NOT EDIT. + +package sdk + +import () + +func NewCreateExternalFunctionRequest( + name SchemaObjectIdentifier, + ResultDataType DataType, + ApiIntegration *AccountObjectIdentifier, + As string, +) *CreateExternalFunctionRequest { + s := CreateExternalFunctionRequest{} + s.name = name + s.ResultDataType = ResultDataType + s.ApiIntegration = ApiIntegration + s.As = As + return &s +} + +func (s *CreateExternalFunctionRequest) WithOrReplace(OrReplace *bool) *CreateExternalFunctionRequest { + s.OrReplace = OrReplace + return s +} + +func (s *CreateExternalFunctionRequest) WithSecure(Secure *bool) *CreateExternalFunctionRequest { + s.Secure = Secure + return s +} + +func (s *CreateExternalFunctionRequest) WithArguments(Arguments []ExternalFunctionArgumentRequest) *CreateExternalFunctionRequest { + s.Arguments = Arguments + return s +} + +func (s *CreateExternalFunctionRequest) WithReturnNullValues(ReturnNullValues *ReturnNullValues) *CreateExternalFunctionRequest { + s.ReturnNullValues = ReturnNullValues + return s +} + +func (s *CreateExternalFunctionRequest) WithNullInputBehavior(NullInputBehavior *NullInputBehavior) *CreateExternalFunctionRequest { + s.NullInputBehavior = NullInputBehavior + return s +} + +func (s *CreateExternalFunctionRequest) WithReturnResultsBehavior(ReturnResultsBehavior *ReturnResultsBehavior) *CreateExternalFunctionRequest { + s.ReturnResultsBehavior = ReturnResultsBehavior + return s +} + +func (s *CreateExternalFunctionRequest) WithComment(Comment *string) *CreateExternalFunctionRequest { + s.Comment = Comment + return s +} + +func (s *CreateExternalFunctionRequest) WithHeaders(Headers []ExternalFunctionHeaderRequest) *CreateExternalFunctionRequest { + s.Headers = Headers + return s +} + +func (s *CreateExternalFunctionRequest) WithContextHeaders(ContextHeaders []ExternalFunctionContextHeaderRequest) *CreateExternalFunctionRequest { + s.ContextHeaders = ContextHeaders + return s +} + +func (s *CreateExternalFunctionRequest) WithMaxBatchRows(MaxBatchRows *int) *CreateExternalFunctionRequest { + s.MaxBatchRows = MaxBatchRows + return s +} + +func (s *CreateExternalFunctionRequest) WithCompression(Compression *string) *CreateExternalFunctionRequest { + s.Compression = Compression + return s +} + +func (s *CreateExternalFunctionRequest) WithRequestTranslator(RequestTranslator *SchemaObjectIdentifier) *CreateExternalFunctionRequest { + s.RequestTranslator = RequestTranslator + return s +} + +func (s *CreateExternalFunctionRequest) WithResponseTranslator(ResponseTranslator *SchemaObjectIdentifier) *CreateExternalFunctionRequest { + s.ResponseTranslator = ResponseTranslator + return s +} + +func NewExternalFunctionArgumentRequest( + ArgName string, + ArgDataType DataType, +) *ExternalFunctionArgumentRequest { + s := ExternalFunctionArgumentRequest{} + s.ArgName = ArgName + s.ArgDataType = ArgDataType + return &s +} + +func NewExternalFunctionHeaderRequest( + Name string, + Value string, +) *ExternalFunctionHeaderRequest { + s := ExternalFunctionHeaderRequest{} + s.Name = Name + s.Value = Value + return &s +} + +func NewExternalFunctionContextHeaderRequest( + ContextFunction string, +) *ExternalFunctionContextHeaderRequest { + s := ExternalFunctionContextHeaderRequest{} + s.ContextFunction = ContextFunction + return &s +} + +func NewAlterExternalFunctionRequest( + name SchemaObjectIdentifier, + ArgumentDataTypes []DataType, +) *AlterExternalFunctionRequest { + s := AlterExternalFunctionRequest{} + s.name = name + s.ArgumentDataTypes = ArgumentDataTypes + return &s +} + +func (s *AlterExternalFunctionRequest) WithIfExists(IfExists *bool) *AlterExternalFunctionRequest { + s.IfExists = IfExists + return s +} + +func (s *AlterExternalFunctionRequest) WithSet(Set *ExternalFunctionSetRequest) *AlterExternalFunctionRequest { + s.Set = Set + return s +} + +func (s *AlterExternalFunctionRequest) WithUnset(Unset *ExternalFunctionUnsetRequest) *AlterExternalFunctionRequest { + s.Unset = Unset + return s +} + +func NewExternalFunctionSetRequest() *ExternalFunctionSetRequest { + return &ExternalFunctionSetRequest{} +} + +func (s *ExternalFunctionSetRequest) WithApiIntegration(ApiIntegration *AccountObjectIdentifier) *ExternalFunctionSetRequest { + s.ApiIntegration = ApiIntegration + return s +} + +func (s *ExternalFunctionSetRequest) WithHeaders(Headers []ExternalFunctionHeaderRequest) *ExternalFunctionSetRequest { + s.Headers = Headers + return s +} + +func (s *ExternalFunctionSetRequest) WithContextHeaders(ContextHeaders []ExternalFunctionContextHeaderRequest) *ExternalFunctionSetRequest { + s.ContextHeaders = ContextHeaders + return s +} + +func (s *ExternalFunctionSetRequest) WithMaxBatchRows(MaxBatchRows *int) *ExternalFunctionSetRequest { + s.MaxBatchRows = MaxBatchRows + return s +} + +func (s *ExternalFunctionSetRequest) WithCompression(Compression *string) *ExternalFunctionSetRequest { + s.Compression = Compression + return s +} + +func (s *ExternalFunctionSetRequest) WithRequestTranslator(RequestTranslator *SchemaObjectIdentifier) *ExternalFunctionSetRequest { + s.RequestTranslator = RequestTranslator + return s +} + +func (s *ExternalFunctionSetRequest) WithResponseTranslator(ResponseTranslator *SchemaObjectIdentifier) *ExternalFunctionSetRequest { + s.ResponseTranslator = ResponseTranslator + return s +} + +func NewExternalFunctionUnsetRequest() *ExternalFunctionUnsetRequest { + return &ExternalFunctionUnsetRequest{} +} + +func (s *ExternalFunctionUnsetRequest) WithComment(Comment *bool) *ExternalFunctionUnsetRequest { + s.Comment = Comment + return s +} + +func (s *ExternalFunctionUnsetRequest) WithHeaders(Headers *bool) *ExternalFunctionUnsetRequest { + s.Headers = Headers + return s +} + +func (s *ExternalFunctionUnsetRequest) WithContextHeaders(ContextHeaders *bool) *ExternalFunctionUnsetRequest { + s.ContextHeaders = ContextHeaders + return s +} + +func (s *ExternalFunctionUnsetRequest) WithMaxBatchRows(MaxBatchRows *bool) *ExternalFunctionUnsetRequest { + s.MaxBatchRows = MaxBatchRows + return s +} + +func (s *ExternalFunctionUnsetRequest) WithCompression(Compression *bool) *ExternalFunctionUnsetRequest { + s.Compression = Compression + return s +} + +func (s *ExternalFunctionUnsetRequest) WithSecure(Secure *bool) *ExternalFunctionUnsetRequest { + s.Secure = Secure + return s +} + +func (s *ExternalFunctionUnsetRequest) WithRequestTranslator(RequestTranslator *bool) *ExternalFunctionUnsetRequest { + s.RequestTranslator = RequestTranslator + return s +} + +func (s *ExternalFunctionUnsetRequest) WithResponseTranslator(ResponseTranslator *bool) *ExternalFunctionUnsetRequest { + s.ResponseTranslator = ResponseTranslator + return s +} + +func NewShowExternalFunctionRequest() *ShowExternalFunctionRequest { + return &ShowExternalFunctionRequest{} +} + +func (s *ShowExternalFunctionRequest) WithLike(Like *Like) *ShowExternalFunctionRequest { + s.Like = Like + return s +} + +func NewDescribeExternalFunctionRequest( + name SchemaObjectIdentifier, + ArgumentDataTypes []DataType, +) *DescribeExternalFunctionRequest { + s := DescribeExternalFunctionRequest{} + s.name = name + s.ArgumentDataTypes = ArgumentDataTypes + return &s +} diff --git a/pkg/sdk/external_functions_dto_gen.go b/pkg/sdk/external_functions_dto_gen.go new file mode 100644 index 0000000000..1e8acd91bc --- /dev/null +++ b/pkg/sdk/external_functions_dto_gen.go @@ -0,0 +1,82 @@ +package sdk + +//go:generate go run ./dto-builder-generator/main.go + +var ( + _ optionsProvider[CreateExternalFunctionOptions] = new(CreateExternalFunctionRequest) + _ optionsProvider[AlterExternalFunctionOptions] = new(AlterExternalFunctionRequest) + _ optionsProvider[ShowExternalFunctionOptions] = new(ShowExternalFunctionRequest) + _ optionsProvider[DescribeExternalFunctionOptions] = new(DescribeExternalFunctionRequest) +) + +type CreateExternalFunctionRequest struct { + OrReplace *bool + Secure *bool + name SchemaObjectIdentifier // required + Arguments []ExternalFunctionArgumentRequest + ResultDataType DataType // required + ReturnNullValues *ReturnNullValues + NullInputBehavior *NullInputBehavior + ReturnResultsBehavior *ReturnResultsBehavior + Comment *string + ApiIntegration *AccountObjectIdentifier // required + Headers []ExternalFunctionHeaderRequest + ContextHeaders []ExternalFunctionContextHeaderRequest + MaxBatchRows *int + Compression *string + RequestTranslator *SchemaObjectIdentifier + ResponseTranslator *SchemaObjectIdentifier + As string // required +} + +type ExternalFunctionArgumentRequest struct { + ArgName string // required + ArgDataType DataType // required +} + +type ExternalFunctionHeaderRequest struct { + Name string // required + Value string // required +} + +type ExternalFunctionContextHeaderRequest struct { + ContextFunction string // required +} + +type AlterExternalFunctionRequest struct { + IfExists *bool + name SchemaObjectIdentifier // required + ArgumentDataTypes []DataType // required + Set *ExternalFunctionSetRequest + Unset *ExternalFunctionUnsetRequest +} + +type ExternalFunctionSetRequest struct { + ApiIntegration *AccountObjectIdentifier + Headers []ExternalFunctionHeaderRequest + ContextHeaders []ExternalFunctionContextHeaderRequest + MaxBatchRows *int + Compression *string + RequestTranslator *SchemaObjectIdentifier + ResponseTranslator *SchemaObjectIdentifier +} + +type ExternalFunctionUnsetRequest struct { + Comment *bool + Headers *bool + ContextHeaders *bool + MaxBatchRows *bool + Compression *bool + Secure *bool + RequestTranslator *bool + ResponseTranslator *bool +} + +type ShowExternalFunctionRequest struct { + Like *Like +} + +type DescribeExternalFunctionRequest struct { + name SchemaObjectIdentifier // required + ArgumentDataTypes []DataType // required +} diff --git a/pkg/sdk/external_functions_gen.go b/pkg/sdk/external_functions_gen.go new file mode 100644 index 0000000000..2e02563514 --- /dev/null +++ b/pkg/sdk/external_functions_gen.go @@ -0,0 +1,150 @@ +package sdk + +import ( + "context" + "database/sql" +) + +type ExternalFunctions interface { + Create(ctx context.Context, request *CreateExternalFunctionRequest) error + Alter(ctx context.Context, request *AlterExternalFunctionRequest) error + Show(ctx context.Context, request *ShowExternalFunctionRequest) ([]ExternalFunction, error) + ShowByID(ctx context.Context, id SchemaObjectIdentifier, arguments []DataType) (*ExternalFunction, error) + Describe(ctx context.Context, request *DescribeExternalFunctionRequest) ([]ExternalFunctionProperty, error) +} + +// CreateExternalFunctionOptions is based on https://docs.snowflake.com/en/sql-reference/sql/create-external-function. +type CreateExternalFunctionOptions struct { + create bool `ddl:"static" sql:"CREATE"` + OrReplace *bool `ddl:"keyword" sql:"OR REPLACE"` + Secure *bool `ddl:"keyword" sql:"SECURE"` + externalFunction bool `ddl:"static" sql:"EXTERNAL FUNCTION"` + name SchemaObjectIdentifier `ddl:"identifier"` + Arguments []ExternalFunctionArgument `ddl:"list,must_parentheses"` + ResultDataType DataType `ddl:"parameter,no_equals" sql:"RETURNS"` + ReturnNullValues *ReturnNullValues `ddl:"keyword"` + NullInputBehavior *NullInputBehavior `ddl:"keyword"` + ReturnResultsBehavior *ReturnResultsBehavior `ddl:"keyword"` + Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` + ApiIntegration *AccountObjectIdentifier `ddl:"identifier" sql:"API_INTEGRATION ="` + Headers []ExternalFunctionHeader `ddl:"parameter,parentheses" sql:"HEADERS"` + ContextHeaders []ExternalFunctionContextHeader `ddl:"parameter,parentheses" sql:"CONTEXT_HEADERS"` + MaxBatchRows *int `ddl:"parameter" sql:"MAX_BATCH_ROWS"` + Compression *string `ddl:"parameter" sql:"COMPRESSION"` + RequestTranslator *SchemaObjectIdentifier `ddl:"identifier" sql:"REQUEST_TRANSLATOR ="` + ResponseTranslator *SchemaObjectIdentifier `ddl:"identifier" sql:"RESPONSE_TRANSLATOR ="` + As string `ddl:"parameter,single_quotes,no_equals" sql:"AS"` +} + +type ExternalFunctionArgument struct { + ArgName string `ddl:"keyword,no_quotes"` + ArgDataType DataType `ddl:"keyword,no_quotes"` +} + +type ExternalFunctionHeader struct { + Name string `ddl:"keyword,single_quotes"` + Value string `ddl:"parameter,single_quotes"` +} + +type ExternalFunctionContextHeader struct { + ContextFunction string `ddl:"keyword,no_quotes"` +} + +// AlterExternalFunctionOptions is based on https://docs.snowflake.com/en/sql-reference/sql/alter-function. +type AlterExternalFunctionOptions struct { + alter bool `ddl:"static" sql:"ALTER"` + function bool `ddl:"static" sql:"FUNCTION"` + IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` + name SchemaObjectIdentifier `ddl:"identifier"` + ArgumentDataTypes []DataType `ddl:"keyword,must_parentheses"` + Set *ExternalFunctionSet `ddl:"keyword" sql:"SET"` + Unset *ExternalFunctionUnset `ddl:"list,no_parentheses" sql:"UNSET"` +} + +type ExternalFunctionSet struct { + ApiIntegration *AccountObjectIdentifier `ddl:"identifier" sql:"API_INTEGRATION ="` + Headers []ExternalFunctionHeader `ddl:"parameter,parentheses" sql:"HEADERS"` + ContextHeaders []ExternalFunctionContextHeader `ddl:"parameter,parentheses" sql:"CONTEXT_HEADERS"` + MaxBatchRows *int `ddl:"parameter" sql:"MAX_BATCH_ROWS"` + Compression *string `ddl:"parameter" sql:"COMPRESSION"` + RequestTranslator *SchemaObjectIdentifier `ddl:"identifier" sql:"REQUEST_TRANSLATOR ="` + ResponseTranslator *SchemaObjectIdentifier `ddl:"identifier" sql:"RESPONSE_TRANSLATOR ="` +} + +type ExternalFunctionUnset struct { + Comment *bool `ddl:"keyword" sql:"COMMENT"` + Headers *bool `ddl:"keyword" sql:"HEADERS"` + ContextHeaders *bool `ddl:"keyword" sql:"CONTEXT_HEADERS"` + MaxBatchRows *bool `ddl:"keyword" sql:"MAX_BATCH_ROWS"` + Compression *bool `ddl:"keyword" sql:"COMPRESSION"` + Secure *bool `ddl:"keyword" sql:"SECURE"` + RequestTranslator *bool `ddl:"keyword" sql:"REQUEST_TRANSLATOR"` + ResponseTranslator *bool `ddl:"keyword" sql:"RESPONSE_TRANSLATOR"` +} + +// ShowExternalFunctionOptions is based on https://docs.snowflake.com/en/sql-reference/sql/show-external-functions. +type ShowExternalFunctionOptions struct { + show bool `ddl:"static" sql:"SHOW"` + externalFunctions bool `ddl:"static" sql:"EXTERNAL FUNCTIONS"` + Like *Like `ddl:"keyword" sql:"LIKE"` +} + +type externalFunctionRow struct { + CreatedOn string `db:"created_on"` + Name string `db:"name"` + SchemaName sql.NullString `db:"schema_name"` + IsBuiltin string `db:"is_builtin"` + IsAggregate string `db:"is_aggregate"` + IsAnsi string `db:"is_ansi"` + MinNumArguments int `db:"min_num_arguments"` + MaxNumArguments int `db:"max_num_arguments"` + Arguments string `db:"arguments"` + Description string `db:"description"` + CatalogName sql.NullString `db:"catalog_name"` + IsTableFunction string `db:"is_table_function"` + ValidForClustering string `db:"valid_for_clustering"` + IsSecure sql.NullString `db:"is_secure"` + IsExternalFunction string `db:"is_external_function"` + Language string `db:"language"` + IsMemoizable sql.NullString `db:"is_memoizable"` + IsDataMetric sql.NullString `db:"is_data_metric"` +} + +type ExternalFunction struct { + CreatedOn string + Name string + SchemaName string + IsBuiltin bool + IsAggregate bool + IsAnsi bool + MinNumArguments int + MaxNumArguments int + Arguments string + Description string + CatalogName string + IsTableFunction bool + ValidForClustering bool + IsSecure bool + IsExternalFunction bool + Language string + IsMemoizable bool + IsDataMetric bool +} + +// DescribeExternalFunctionOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-function. +type DescribeExternalFunctionOptions struct { + describe bool `ddl:"static" sql:"DESCRIBE"` + function bool `ddl:"static" sql:"FUNCTION"` + name SchemaObjectIdentifier `ddl:"identifier"` + ArgumentDataTypes []DataType `ddl:"keyword,must_parentheses"` +} + +type externalFunctionPropertyRow struct { + Property string `db:"property"` + Value string `db:"value"` +} + +type ExternalFunctionProperty struct { + Property string + Value string +} diff --git a/pkg/sdk/external_functions_gen_test.go b/pkg/sdk/external_functions_gen_test.go new file mode 100644 index 0000000000..bd9380045c --- /dev/null +++ b/pkg/sdk/external_functions_gen_test.go @@ -0,0 +1,315 @@ +package sdk + +import ( + "testing" +) + +func TestExternalFunctions_Create(t *testing.T) { + id := RandomSchemaObjectIdentifier() + + defaultOpts := func() *CreateExternalFunctionOptions { + return &CreateExternalFunctionOptions{ + name: id, + } + } + + t.Run("validation: nil options", func(t *testing.T) { + var opts *CreateExternalFunctionOptions = nil + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("validation: incorrect identifier", func(t *testing.T) { + opts := defaultOpts() + opts.name = NewSchemaObjectIdentifier("", "", "") + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("validation: must options", func(t *testing.T) { + opts := defaultOpts() + + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateExternalFunctionOptions", "ApiIntegration")) + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateExternalFunctionOptions", "As")) + + opts = defaultOpts() + opts.As = "as" + integration := NewAccountObjectIdentifier("") + opts.ApiIntegration = &integration + rt := NewSchemaObjectIdentifier("", "", "") + opts.RequestTranslator = &rt + st := NewSchemaObjectIdentifier("", "", "") + opts.ResponseTranslator = &st + assertOptsInvalidJoinedErrors(t, opts, errNotSet("CreateExternalFunctionOptions", "ApiIntegration")) + assertOptsInvalidJoinedErrors(t, opts, errInvalidIdentifier("CreateExternalFunctionOptions", "RequestTranslator")) + assertOptsInvalidJoinedErrors(t, opts, errInvalidIdentifier("CreateExternalFunctionOptions", "ResponseTranslator")) + }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.OrReplace = Bool(true) + opts.Secure = Bool(true) + opts.Arguments = []ExternalFunctionArgument{ + { + ArgName: "id", + ArgDataType: DataTypeNumber, + }, + { + ArgName: "name", + ArgDataType: DataTypeVARCHAR, + }, + } + opts.ResultDataType = DataTypeVARCHAR + opts.ReturnNullValues = ReturnNullValuesPointer(ReturnNullValuesNotNull) + opts.NullInputBehavior = NullInputBehaviorPointer(NullInputBehaviorCalledOnNullInput) + opts.ReturnResultsBehavior = ReturnResultsBehaviorPointer(ReturnResultsBehaviorImmutable) + opts.Comment = String("comment") + integration := NewAccountObjectIdentifier("api_integration") + opts.ApiIntegration = &integration + opts.Headers = []ExternalFunctionHeader{ + { + Name: "header1", + Value: "value1", + }, + { + Name: "header2", + Value: "value2", + }, + } + opts.ContextHeaders = []ExternalFunctionContextHeader{ + { + ContextFunction: "CURRENT_ACCOUNT", + }, + { + ContextFunction: "CURRENT_USER", + }, + } + opts.MaxBatchRows = Int(100) + opts.Compression = String("GZIP") + rt := NewSchemaObjectIdentifier("db", "schema", "request_translator") + opts.RequestTranslator = &rt + rs := NewSchemaObjectIdentifier("db", "schema", "response_translator") + opts.ResponseTranslator = &rs + opts.As = "https://xyz.execute-api.us-west-2.amazonaws.com/prod/remote_echo" + assertOptsValidAndSQLEquals(t, opts, `CREATE OR REPLACE SECURE EXTERNAL FUNCTION %s (id NUMBER, name VARCHAR) RETURNS VARCHAR NOT NULL CALLED ON NULL INPUT IMMUTABLE COMMENT = 'comment' API_INTEGRATION = "api_integration" HEADERS = ('header1' = 'value1', 'header2' = 'value2') CONTEXT_HEADERS = (CURRENT_ACCOUNT, CURRENT_USER) MAX_BATCH_ROWS = 100 COMPRESSION = GZIP REQUEST_TRANSLATOR = %s RESPONSE_TRANSLATOR = %s AS 'https://xyz.execute-api.us-west-2.amazonaws.com/prod/remote_echo'`, id.FullyQualifiedName(), rt.FullyQualifiedName(), rs.FullyQualifiedName()) + }) +} + +func TestExternalFunctions_Alter(t *testing.T) { + id := RandomSchemaObjectIdentifier() + + defaultOpts := func() *AlterExternalFunctionOptions { + return &AlterExternalFunctionOptions{ + name: id, + IfExists: Bool(true), + ArgumentDataTypes: []DataType{DataTypeVARCHAR, DataTypeNumber}, + } + } + + t.Run("validation: nil options", func(t *testing.T) { + opts := (*AlterExternalFunctionOptions)(nil) + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("validation: incorrect identifier", func(t *testing.T) { + opts := defaultOpts() + opts.name = NewSchemaObjectIdentifier("", "", "") + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("validation: at least one of the fields [opts.Unset.Comment opts.Unset.Headers opts.Unset.ContextHeaders opts.Unset.MaxBatchRows opts.Unset.Compression opts.Unset.Secure opts.Unset.RequestTranslator opts.Unset.ResponseTranslator] should be set", func(t *testing.T) { + opts := defaultOpts() + opts.Unset = &ExternalFunctionUnset{} + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterExternalFunctionOptions.Unset", "Comment", "Headers", "ContextHeaders", "MaxBatchRows", "Compression", "Secure", "RequestTranslator", "ResponseTranslator")) + }) + + t.Run("validation: exactly one field from [opts.Set opts.Unset] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Set = &ExternalFunctionSet{ + MaxBatchRows: Int(100), + } + opts.Unset = &ExternalFunctionUnset{ + MaxBatchRows: Bool(true), + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterExternalFunctionOptions", "Set", "Unset")) + }) + + t.Run("validation: exactly one field from [opts.Set.ApiIntegration opts.Set.Headers opts.Set.ContextHeaders opts.Set.MaxBatchRows opts.Set.Compression opts.Set.RequestTranslator opts.Set.ResponseTranslator] should be present", func(t *testing.T) { + opts := defaultOpts() + opts.Set = &ExternalFunctionSet{ + MaxBatchRows: Int(100), + Headers: []ExternalFunctionHeader{ + { + Name: "header1", + Value: "value1", + }, + { + Name: "header2", + Value: "value2", + }, + }, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterExternalFunctionOptions.Set", "ApiIntegration", "Headers", "ContextHeaders", "MaxBatchRows", "Compression", "RequestTranslator", "ResponseTranslator")) + }) + + t.Run("alter: set api integration", func(t *testing.T) { + opts := defaultOpts() + integration := NewAccountObjectIdentifier("api_integration") + opts.Set = &ExternalFunctionSet{ + ApiIntegration: &integration, + } + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET API_INTEGRATION = "api_integration"`, id.FullyQualifiedName()) + }) + + t.Run("alter: set headers", func(t *testing.T) { + opts := defaultOpts() + opts.Set = &ExternalFunctionSet{ + Headers: []ExternalFunctionHeader{ + { + Name: "header1", + Value: "value1", + }, + { + Name: "header2", + Value: "value2", + }, + }, + } + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET HEADERS = ('header1' = 'value1', 'header2' = 'value2')`, id.FullyQualifiedName()) + }) + + t.Run("alter: set max batch rows", func(t *testing.T) { + opts := defaultOpts() + opts.Set = &ExternalFunctionSet{ + MaxBatchRows: Int(100), + } + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET MAX_BATCH_ROWS = 100`, id.FullyQualifiedName()) + }) + + t.Run("alter: set compression", func(t *testing.T) { + opts := defaultOpts() + opts.Set = &ExternalFunctionSet{ + Compression: String("GZIP"), + } + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET COMPRESSION = GZIP`, id.FullyQualifiedName()) + }) + + t.Run("alter: set context headers", func(t *testing.T) { + opts := defaultOpts() + opts.Set = &ExternalFunctionSet{ + ContextHeaders: []ExternalFunctionContextHeader{ + { + ContextFunction: "CURRENT_ACCOUNT", + }, + { + ContextFunction: "CURRENT_USER", + }, + }, + } + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET CONTEXT_HEADERS = (CURRENT_ACCOUNT, CURRENT_USER)`, id.FullyQualifiedName()) + }) + + t.Run("alter: set request translator", func(t *testing.T) { + opts := defaultOpts() + rt := RandomSchemaObjectIdentifier() + opts.Set = &ExternalFunctionSet{ + RequestTranslator: &rt, + } + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET REQUEST_TRANSLATOR = %s`, id.FullyQualifiedName(), rt.FullyQualifiedName()) + }) + + t.Run("alter: set response translator", func(t *testing.T) { + opts := defaultOpts() + st := RandomSchemaObjectIdentifier() + opts.Set = &ExternalFunctionSet{ + ResponseTranslator: &st, + } + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) SET RESPONSE_TRANSLATOR = %s`, id.FullyQualifiedName(), st.FullyQualifiedName()) + }) + + t.Run("alter: unset", func(t *testing.T) { + opts := defaultOpts() + opts.ArgumentDataTypes = []DataType{DataTypeVARCHAR, DataTypeNumber} + opts.Unset = &ExternalFunctionUnset{ + Comment: Bool(true), + Headers: Bool(true), + ContextHeaders: Bool(true), + MaxBatchRows: Bool(true), + Compression: Bool(true), + Secure: Bool(true), + RequestTranslator: Bool(true), + ResponseTranslator: Bool(true), + } + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s (VARCHAR, NUMBER) UNSET COMMENT, HEADERS, CONTEXT_HEADERS, MAX_BATCH_ROWS, COMPRESSION, SECURE, REQUEST_TRANSLATOR, RESPONSE_TRANSLATOR`, id.FullyQualifiedName()) + }) + + t.Run("alter: unset with no arguments", func(t *testing.T) { + opts := defaultOpts() + opts.ArgumentDataTypes = nil + opts.Unset = &ExternalFunctionUnset{ + Comment: Bool(true), + Headers: Bool(true), + ContextHeaders: Bool(true), + MaxBatchRows: Bool(true), + Compression: Bool(true), + Secure: Bool(true), + RequestTranslator: Bool(true), + ResponseTranslator: Bool(true), + } + assertOptsValidAndSQLEquals(t, opts, `ALTER FUNCTION IF EXISTS %s () UNSET COMMENT, HEADERS, CONTEXT_HEADERS, MAX_BATCH_ROWS, COMPRESSION, SECURE, REQUEST_TRANSLATOR, RESPONSE_TRANSLATOR`, id.FullyQualifiedName()) + }) +} + +func TestExternalFunctions_Show(t *testing.T) { + defaultOpts := func() *ShowExternalFunctionOptions { + return &ShowExternalFunctionOptions{} + } + + t.Run("validation: nil options", func(t *testing.T) { + opts := (*ShowExternalFunctionOptions)(nil) + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("show with empty options", func(t *testing.T) { + opts := defaultOpts() + assertOptsValidAndSQLEquals(t, opts, `SHOW EXTERNAL FUNCTIONS`) + }) + + t.Run("show with like", func(t *testing.T) { + opts := defaultOpts() + opts.Like = &Like{ + Pattern: String("pattern"), + } + assertOptsValidAndSQLEquals(t, opts, `SHOW EXTERNAL FUNCTIONS LIKE 'pattern'`) + }) +} + +func TestExternalFunctions_Describe(t *testing.T) { + id := RandomSchemaObjectIdentifier() + + defaultOpts := func() *DescribeExternalFunctionOptions { + return &DescribeExternalFunctionOptions{ + name: id, + } + } + + t.Run("validation: nil options", func(t *testing.T) { + opts := (*DescribeExternalFunctionOptions)(nil) + assertOptsInvalidJoinedErrors(t, opts, ErrNilOptions) + }) + + t.Run("validation: incorrect identifier", func(t *testing.T) { + opts := defaultOpts() + opts.name = NewSchemaObjectIdentifier("", "", "") + assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) + }) + + t.Run("no arguments", func(t *testing.T) { + opts := defaultOpts() + assertOptsValidAndSQLEquals(t, opts, `DESCRIBE FUNCTION %s ()`, id.FullyQualifiedName()) + }) + + t.Run("all options", func(t *testing.T) { + opts := defaultOpts() + opts.ArgumentDataTypes = []DataType{DataTypeVARCHAR, DataTypeNumber} + assertOptsValidAndSQLEquals(t, opts, `DESCRIBE FUNCTION %s (VARCHAR, NUMBER)`, id.FullyQualifiedName()) + }) +} diff --git a/pkg/sdk/external_functions_impl_gen.go b/pkg/sdk/external_functions_impl_gen.go new file mode 100644 index 0000000000..a360be4191 --- /dev/null +++ b/pkg/sdk/external_functions_impl_gen.go @@ -0,0 +1,211 @@ +package sdk + +import ( + "context" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" +) + +var _ ExternalFunctions = (*externalFunctions)(nil) + +type externalFunctions struct { + client *Client +} + +func (v *externalFunctions) Create(ctx context.Context, request *CreateExternalFunctionRequest) error { + opts := request.toOpts() + return validateAndExec(v.client, ctx, opts) +} + +func (v *externalFunctions) Alter(ctx context.Context, request *AlterExternalFunctionRequest) error { + opts := request.toOpts() + return validateAndExec(v.client, ctx, opts) +} + +func (v *externalFunctions) Show(ctx context.Context, request *ShowExternalFunctionRequest) ([]ExternalFunction, error) { + opts := request.toOpts() + dbRows, err := validateAndQuery[externalFunctionRow](v.client, ctx, opts) + if err != nil { + return nil, err + } + resultList := convertRows[externalFunctionRow, ExternalFunction](dbRows) + return resultList, nil +} + +func (v *externalFunctions) ShowByID(ctx context.Context, id SchemaObjectIdentifier, arguments []DataType) (*ExternalFunction, error) { + externalFunctions, err := v.Show(ctx, NewShowExternalFunctionRequest().WithLike(&Like{Pattern: String(id.Name())})) + if err != nil { + return nil, err + } + return collections.FindOne(externalFunctions, func(r ExternalFunction) bool { + database := strings.Trim(r.CatalogName, `"`) + schema := strings.Trim(r.SchemaName, `"`) + if r.Name != id.Name() || database != id.DatabaseName() || schema != id.SchemaName() { + return false + } + var sb strings.Builder + sb.WriteString("(") + for i, argument := range arguments { + sb.WriteString(string(argument)) + if i < len(arguments)-1 { + sb.WriteString(", ") + } + } + sb.WriteString(")") + return strings.Contains(r.Arguments, sb.String()) + }) +} + +func (v *externalFunctions) Describe(ctx context.Context, request *DescribeExternalFunctionRequest) ([]ExternalFunctionProperty, error) { + opts := request.toOpts() + rows, err := validateAndQuery[externalFunctionPropertyRow](v.client, ctx, opts) + if err != nil { + return nil, err + } + return convertRows[externalFunctionPropertyRow, ExternalFunctionProperty](rows), nil +} + +func (r *CreateExternalFunctionRequest) toOpts() *CreateExternalFunctionOptions { + opts := &CreateExternalFunctionOptions{ + OrReplace: r.OrReplace, + Secure: r.Secure, + name: r.name, + + ResultDataType: r.ResultDataType, + ReturnNullValues: r.ReturnNullValues, + NullInputBehavior: r.NullInputBehavior, + ReturnResultsBehavior: r.ReturnResultsBehavior, + Comment: r.Comment, + ApiIntegration: r.ApiIntegration, + + MaxBatchRows: r.MaxBatchRows, + Compression: r.Compression, + RequestTranslator: r.RequestTranslator, + ResponseTranslator: r.ResponseTranslator, + As: r.As, + } + if r.Arguments != nil { + s := make([]ExternalFunctionArgument, len(r.Arguments)) + for i, v := range r.Arguments { + s[i] = ExternalFunctionArgument(v) + } + opts.Arguments = s + } + if r.Headers != nil { + s := make([]ExternalFunctionHeader, len(r.Headers)) + for i, v := range r.Headers { + s[i] = ExternalFunctionHeader(v) + } + opts.Headers = s + } + if r.ContextHeaders != nil { + s := make([]ExternalFunctionContextHeader, len(r.ContextHeaders)) + for i, v := range r.ContextHeaders { + s[i] = ExternalFunctionContextHeader(v) + } + opts.ContextHeaders = s + } + return opts +} + +func (r *AlterExternalFunctionRequest) toOpts() *AlterExternalFunctionOptions { + opts := &AlterExternalFunctionOptions{ + IfExists: r.IfExists, + name: r.name, + ArgumentDataTypes: r.ArgumentDataTypes, + } + if r.Set != nil { + opts.Set = &ExternalFunctionSet{ + ApiIntegration: r.Set.ApiIntegration, + + MaxBatchRows: r.Set.MaxBatchRows, + Compression: r.Set.Compression, + RequestTranslator: r.Set.RequestTranslator, + ResponseTranslator: r.Set.ResponseTranslator, + } + if r.Set.Headers != nil { + s := make([]ExternalFunctionHeader, len(r.Set.Headers)) + for i, v := range r.Set.Headers { + s[i] = ExternalFunctionHeader(v) + } + opts.Set.Headers = s + } + if r.Set.ContextHeaders != nil { + s := make([]ExternalFunctionContextHeader, len(r.Set.ContextHeaders)) + for i, v := range r.Set.ContextHeaders { + s[i] = ExternalFunctionContextHeader(v) + } + opts.Set.ContextHeaders = s + } + } + if r.Unset != nil { + opts.Unset = &ExternalFunctionUnset{ + Comment: r.Unset.Comment, + Headers: r.Unset.Headers, + ContextHeaders: r.Unset.ContextHeaders, + MaxBatchRows: r.Unset.MaxBatchRows, + Compression: r.Unset.Compression, + Secure: r.Unset.Secure, + RequestTranslator: r.Unset.RequestTranslator, + ResponseTranslator: r.Unset.ResponseTranslator, + } + } + return opts +} + +func (r *ShowExternalFunctionRequest) toOpts() *ShowExternalFunctionOptions { + opts := &ShowExternalFunctionOptions{ + Like: r.Like, + } + return opts +} + +func (r externalFunctionRow) convert() *ExternalFunction { + e := &ExternalFunction{ + CreatedOn: r.CreatedOn, + Name: r.Name, + IsBuiltin: r.IsBuiltin == "Y", + IsAggregate: r.IsAggregate == "Y", + IsAnsi: r.IsAnsi == "Y", + MinNumArguments: r.MinNumArguments, + MaxNumArguments: r.MaxNumArguments, + Arguments: r.Arguments, + Description: r.Description, + IsTableFunction: r.IsTableFunction == "Y", + ValidForClustering: r.ValidForClustering == "Y", + IsExternalFunction: r.IsExternalFunction == "Y", + Language: r.Language, + } + if r.SchemaName.Valid { + e.SchemaName = r.SchemaName.String + } + if r.CatalogName.Valid { + e.CatalogName = r.CatalogName.String + } + if r.IsSecure.Valid { + e.IsSecure = r.IsSecure.String == "Y" + } + if r.IsMemoizable.Valid { + e.IsMemoizable = r.IsMemoizable.String == "Y" + } + if r.IsDataMetric.Valid { + e.IsDataMetric = r.IsDataMetric.String == "Y" + } + return e +} + +func (r *DescribeExternalFunctionRequest) toOpts() *DescribeExternalFunctionOptions { + opts := &DescribeExternalFunctionOptions{ + name: r.name, + ArgumentDataTypes: r.ArgumentDataTypes, + } + return opts +} + +func (r externalFunctionPropertyRow) convert() *ExternalFunctionProperty { + return &ExternalFunctionProperty{ + Property: r.Property, + Value: r.Value, + } +} diff --git a/pkg/sdk/external_functions_validations_gen.go b/pkg/sdk/external_functions_validations_gen.go new file mode 100644 index 0000000000..433da9ddd3 --- /dev/null +++ b/pkg/sdk/external_functions_validations_gen.go @@ -0,0 +1,74 @@ +package sdk + +var ( + _ validatable = new(CreateExternalFunctionOptions) + _ validatable = new(AlterExternalFunctionOptions) + _ validatable = new(ShowExternalFunctionOptions) + _ validatable = new(DescribeExternalFunctionOptions) +) + +func (opts *CreateExternalFunctionOptions) validate() error { + if opts == nil { + return ErrNilOptions + } + var errs []error + if !ValidObjectIdentifier(opts.name) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + if !valueSet(opts.ApiIntegration) { + errs = append(errs, errNotSet("CreateExternalFunctionOptions", "ApiIntegration")) + } + if opts.RequestTranslator != nil && !ValidObjectIdentifier(opts.RequestTranslator) { + errs = append(errs, errInvalidIdentifier("CreateExternalFunctionOptions", "RequestTranslator")) + } + if !valueSet(opts.As) { + errs = append(errs, errNotSet("CreateExternalFunctionOptions", "As")) + } + if opts.ResponseTranslator != nil && !ValidObjectIdentifier(opts.ResponseTranslator) { + errs = append(errs, errInvalidIdentifier("CreateExternalFunctionOptions", "ResponseTranslator")) + } + return JoinErrors(errs...) +} + +func (opts *AlterExternalFunctionOptions) validate() error { + if opts == nil { + return ErrNilOptions + } + var errs []error + if !exactlyOneValueSet(opts.Set, opts.Unset) { + errs = append(errs, errExactlyOneOf("AlterExternalFunctionOptions", "Set", "Unset")) + } + if !ValidObjectIdentifier(opts.name) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + if valueSet(opts.Set) { + if !exactlyOneValueSet(opts.Set.ApiIntegration, opts.Set.Headers, opts.Set.ContextHeaders, opts.Set.MaxBatchRows, opts.Set.Compression, opts.Set.RequestTranslator, opts.Set.ResponseTranslator) { + errs = append(errs, errExactlyOneOf("AlterExternalFunctionOptions.Set", "ApiIntegration", "Headers", "ContextHeaders", "MaxBatchRows", "Compression", "RequestTranslator", "ResponseTranslator")) + } + } + if valueSet(opts.Unset) { + if !anyValueSet(opts.Unset.Comment, opts.Unset.Headers, opts.Unset.ContextHeaders, opts.Unset.MaxBatchRows, opts.Unset.Compression, opts.Unset.Secure, opts.Unset.RequestTranslator, opts.Unset.ResponseTranslator) { + errs = append(errs, errAtLeastOneOf("AlterExternalFunctionOptions.Unset", "Comment", "Headers", "ContextHeaders", "MaxBatchRows", "Compression", "Secure", "RequestTranslator", "ResponseTranslator")) + } + } + return JoinErrors(errs...) +} + +func (opts *ShowExternalFunctionOptions) validate() error { + if opts == nil { + return ErrNilOptions + } + var errs []error + return JoinErrors(errs...) +} + +func (opts *DescribeExternalFunctionOptions) validate() error { + if opts == nil { + return ErrNilOptions + } + var errs []error + if !ValidObjectIdentifier(opts.name) { + errs = append(errs, ErrInvalidObjectIdentifier) + } + return JoinErrors(errs...) +} diff --git a/pkg/sdk/poc/main.go b/pkg/sdk/poc/main.go index 132bccd330..4e3cf2a88c 100644 --- a/pkg/sdk/poc/main.go +++ b/pkg/sdk/poc/main.go @@ -36,6 +36,7 @@ var definitionMapping = map[string]*generator.Interface{ "materialized_views_def.go": sdk.MaterializedViewsDef, "api_integrations_def.go": sdk.ApiIntegrationsDef, "notification_integrations_def.go": sdk.NotificationIntegrationsDef, + "external_functions_def.go": sdk.ExternalFunctionsDef, "streamlits_def.go": sdk.StreamlitsDef, } diff --git a/pkg/sdk/testint/external_functions_integration_test.go b/pkg/sdk/testint/external_functions_integration_test.go new file mode 100644 index 0000000000..4d363fc57b --- /dev/null +++ b/pkg/sdk/testint/external_functions_integration_test.go @@ -0,0 +1,278 @@ +package testint + +import ( + "context" + "fmt" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/random" + "github.com/stretchr/testify/require" +) + +func TestInt_ExternalFunctions(t *testing.T) { + client := testClient(t) + ctx := context.Background() + + defaultDataTypes := []sdk.DataType{sdk.DataTypeVARCHAR} + + databaseTest, schemaTest := testDb(t), testSchema(t) + + cleanupExternalFuncionHandle := func(id sdk.SchemaObjectIdentifier, dts []sdk.DataType) func() { + return func() { + err := client.Functions.Drop(ctx, sdk.NewDropFunctionRequest(id, dts).WithIfExists(sdk.Bool(true))) + require.NoError(t, err) + } + } + + assertExternalFunction := func(t *testing.T, id sdk.SchemaObjectIdentifier, secure bool, dts []sdk.DataType) { + t.Helper() + + e, err := client.ExternalFunctions.ShowByID(ctx, id, dts) + require.NoError(t, err) + + require.NotEmpty(t, e.CreatedOn) + require.Equal(t, id.Name(), e.Name) + require.Equal(t, fmt.Sprintf(`"%v"`, id.SchemaName()), e.SchemaName) + require.Equal(t, false, e.IsBuiltin) + require.Equal(t, false, e.IsAggregate) + require.Equal(t, false, e.IsAnsi) + if len(dts) > 0 { + require.Equal(t, 1, e.MinNumArguments) + require.Equal(t, 1, e.MaxNumArguments) + } else { + require.Equal(t, 0, e.MinNumArguments) + require.Equal(t, 0, e.MaxNumArguments) + } + require.NotEmpty(t, e.Arguments) + require.NotEmpty(t, e.Description) + require.NotEmpty(t, e.CatalogName) + require.Equal(t, false, e.IsTableFunction) + require.Equal(t, false, e.ValidForClustering) + require.Equal(t, secure, e.IsSecure) + require.Equal(t, true, e.IsExternalFunction) + require.Equal(t, "EXTERNAL", e.Language) + require.Equal(t, false, e.IsMemoizable) + require.Equal(t, false, e.IsDataMetric) + } + + createApiIntegrationHandle := func(t *testing.T, id sdk.AccountObjectIdentifier) { + t.Helper() + + _, err := client.ExecForTests(ctx, fmt.Sprintf(`CREATE API INTEGRATION %s API_PROVIDER = aws_api_gateway API_AWS_ROLE_ARN = 'arn:aws:iam::123456789012:role/hello_cloud_account_role' API_ALLOWED_PREFIXES = ('https://xyz.execute-api.us-west-2.amazonaws.com/production') ENABLED = true`, id.FullyQualifiedName())) + require.NoError(t, err) + t.Cleanup(func() { + _, err = client.ExecForTests(ctx, fmt.Sprintf(`DROP API INTEGRATION %s`, id.FullyQualifiedName())) + require.NoError(t, err) + }) + } + + createExternalFunction := func(t *testing.T, dt sdk.DataType) *sdk.ExternalFunction { + t.Helper() + + integration := sdk.NewAccountObjectIdentifier(random.AlphaN(4)) + createApiIntegrationHandle(t, integration) + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, random.StringN(4)) + argument := sdk.NewExternalFunctionArgumentRequest("x", dt) + as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" + request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, &integration, as). + WithOrReplace(sdk.Bool(true)). + WithSecure(sdk.Bool(true)). + WithArguments([]sdk.ExternalFunctionArgumentRequest{*argument}) + err := client.ExternalFunctions.Create(ctx, request) + require.NoError(t, err) + t.Cleanup(cleanupExternalFuncionHandle(id, []sdk.DataType{sdk.DataTypeVariant})) + + e, err := client.ExternalFunctions.ShowByID(ctx, id, defaultDataTypes) + require.NoError(t, err) + return e + } + + t.Run("create external function", func(t *testing.T) { + integration := sdk.NewAccountObjectIdentifier(random.AlphaN(4)) + createApiIntegrationHandle(t, integration) + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, random.StringN(4)) + argument := sdk.NewExternalFunctionArgumentRequest("x", sdk.DataTypeVARCHAR) + headers := []sdk.ExternalFunctionHeaderRequest{ + { + Name: "measure", + Value: "kilometers", + }, + } + ch := []sdk.ExternalFunctionContextHeaderRequest{ + { + ContextFunction: "CURRENT_DATE", + }, + { + ContextFunction: "CURRENT_TIMESTAMP", + }, + } + as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" + request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, &integration, as). + WithOrReplace(sdk.Bool(true)). + WithSecure(sdk.Bool(true)). + WithArguments([]sdk.ExternalFunctionArgumentRequest{*argument}). + WithNullInputBehavior(sdk.NullInputBehaviorPointer(sdk.NullInputBehaviorCalledOnNullInput)). + WithHeaders(headers). + WithContextHeaders(ch). + WithMaxBatchRows(sdk.Int(10)). + WithCompression(sdk.String("GZIP")) + err := client.ExternalFunctions.Create(ctx, request) + require.NoError(t, err) + t.Cleanup(cleanupExternalFuncionHandle(id, []sdk.DataType{sdk.DataTypeVariant})) + + assertExternalFunction(t, id, true, defaultDataTypes) + }) + + t.Run("create external function without arguments", func(t *testing.T) { + integration := sdk.NewAccountObjectIdentifier(random.AlphaN(4)) + createApiIntegrationHandle(t, integration) + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, random.StringN(4)) + as := "https://xyz.execute-api.us-west-2.amazonaws.com/production/remote_echo" + request := sdk.NewCreateExternalFunctionRequest(id, sdk.DataTypeVariant, &integration, as) + err := client.ExternalFunctions.Create(ctx, request) + require.NoError(t, err) + t.Cleanup(cleanupExternalFuncionHandle(id, nil)) + + assertExternalFunction(t, id, false, nil) + }) + + t.Run("alter external function: set api integration", func(t *testing.T) { + e := createExternalFunction(t, sdk.DataTypeVARCHAR) + + integration := sdk.NewAccountObjectIdentifier(random.AlphaN(5)) + createApiIntegrationHandle(t, integration) + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, e.Name) + set := sdk.NewExternalFunctionSetRequest(). + WithApiIntegration(&integration) + request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) + err := client.ExternalFunctions.Alter(ctx, request) + require.NoError(t, err) + + assertExternalFunction(t, id, true, defaultDataTypes) + }) + + t.Run("alter external function: set headers", func(t *testing.T) { + e := createExternalFunction(t, sdk.DataTypeVARCHAR) + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, e.Name) + headers := []sdk.ExternalFunctionHeaderRequest{ + { + Name: "measure", + Value: "kilometers", + }, + } + set := sdk.NewExternalFunctionSetRequest().WithHeaders(headers) + request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) + err := client.ExternalFunctions.Alter(ctx, request) + require.NoError(t, err) + assertExternalFunction(t, id, true, defaultDataTypes) + }) + + t.Run("alter external function: set context headers", func(t *testing.T) { + e := createExternalFunction(t, sdk.DataTypeVARCHAR) + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, e.Name) + ch := []sdk.ExternalFunctionContextHeaderRequest{ + { + ContextFunction: "CURRENT_DATE", + }, + { + ContextFunction: "CURRENT_TIMESTAMP", + }, + } + set := sdk.NewExternalFunctionSetRequest().WithContextHeaders(ch) + request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) + err := client.ExternalFunctions.Alter(ctx, request) + require.NoError(t, err) + assertExternalFunction(t, id, true, defaultDataTypes) + }) + + t.Run("alter external function: set compression", func(t *testing.T) { + e := createExternalFunction(t, sdk.DataTypeVARCHAR) + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, e.Name) + set := sdk.NewExternalFunctionSetRequest().WithCompression(sdk.String("AUTO")) + request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) + err := client.ExternalFunctions.Alter(ctx, request) + require.NoError(t, err) + assertExternalFunction(t, id, true, defaultDataTypes) + }) + + t.Run("alter external function: set max batch rows", func(t *testing.T) { + e := createExternalFunction(t, sdk.DataTypeVARCHAR) + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, e.Name) + set := sdk.NewExternalFunctionSetRequest().WithMaxBatchRows(sdk.Int(20)) + request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithSet(set) + err := client.ExternalFunctions.Alter(ctx, request) + require.NoError(t, err) + assertExternalFunction(t, id, true, defaultDataTypes) + }) + + t.Run("alter external function: unset", func(t *testing.T) { + e := createExternalFunction(t, sdk.DataTypeVARCHAR) + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, e.Name) + unset := sdk.NewExternalFunctionUnsetRequest(). + WithComment(sdk.Bool(true)). + WithHeaders(sdk.Bool(true)) + request := sdk.NewAlterExternalFunctionRequest(id, defaultDataTypes).WithUnset(unset) + err := client.ExternalFunctions.Alter(ctx, request) + require.NoError(t, err) + + assertExternalFunction(t, id, true, defaultDataTypes) + }) + + t.Run("show external function: with like", func(t *testing.T) { + e1 := createExternalFunction(t, sdk.DataTypeVARCHAR) + e2 := createExternalFunction(t, sdk.DataTypeVARCHAR) + + es, err := client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithLike(&sdk.Like{Pattern: sdk.String(e1.Name)})) + require.NoError(t, err) + + require.Equal(t, 1, len(es)) + require.Contains(t, es, *e1) + require.NotContains(t, es, *e2) + }) + + t.Run("show external function: no matches", func(t *testing.T) { + es, err := client.ExternalFunctions.Show(ctx, sdk.NewShowExternalFunctionRequest().WithLike(&sdk.Like{Pattern: sdk.String(random.String())})) + require.NoError(t, err) + require.Equal(t, 0, len(es)) + }) + + t.Run("show external function by id", func(t *testing.T) { + e := createExternalFunction(t, sdk.DataTypeVARCHAR) + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, e.Name) + es, err := client.ExternalFunctions.ShowByID(ctx, id, []sdk.DataType{sdk.DataTypeVARCHAR}) + require.NoError(t, err) + require.Equal(t, *e, *es) + + _, err = client.ExternalFunctions.ShowByID(ctx, id, nil) + require.Error(t, err, sdk.ErrObjectNotExistOrAuthorized) + }) + + t.Run("describe external function", func(t *testing.T) { + e := createExternalFunction(t, sdk.DataTypeVARCHAR) + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, e.Name) + + request := sdk.NewDescribeExternalFunctionRequest(id, []sdk.DataType{sdk.DataTypeVARCHAR}) + details, err := client.ExternalFunctions.Describe(ctx, request) + require.NoError(t, err) + pairs := make(map[string]string) + for _, detail := range details { + pairs[detail.Property] = detail.Value + } + require.Equal(t, "EXTERNAL", pairs["language"]) + require.Equal(t, "VARIANT", pairs["returns"]) + require.Equal(t, "VOLATILE", pairs["volatility"]) + require.Equal(t, "AUTO", pairs["compression"]) + require.Equal(t, "(X VARCHAR)", pairs["signature"]) + }) +} diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index e4645b6661..d8861f2976 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -403,7 +403,6 @@ func TestInt_OtherFunctions(t *testing.T) { functions, err := client.Functions.Show(ctx, sdk.NewShowFunctionRequest()) require.NoError(t, err) - require.Equal(t, 2, len(functions)) require.Contains(t, functions, *f1) require.Contains(t, functions, *f2) }) From 77de5694f73e5ad1443bb99407d2e8aec9a87320 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Cie=C5=9Blak?= Date: Thu, 1 Feb 2024 14:04:30 +0100 Subject: [PATCH 2/5] chore: add missing deprecation message (#2451) --- pkg/resources/database_grant.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/resources/database_grant.go b/pkg/resources/database_grant.go index 06efe42b7f..cf80147dc3 100644 --- a/pkg/resources/database_grant.go +++ b/pkg/resources/database_grant.go @@ -86,7 +86,8 @@ func DatabaseGrant() *TerraformGrantResource { Delete: DeleteDatabaseGrant, Update: UpdateDatabaseGrant, - Schema: databaseGrantSchema, + DeprecationMessage: "This resource is deprecated and will be removed in a future major version release. Please use snowflake_grant_privileges_to_account_role instead.", + Schema: databaseGrantSchema, Importer: &schema.ResourceImporter{ StateContext: func(ctx context.Context, d *schema.ResourceData, m interface{}) ([]*schema.ResourceData, error) { parts := strings.Split(d.Id(), helpers.IDDelimiter) From ca8ca635c4c31a43028c6258960239657a3b5875 Mon Sep 17 00:00:00 2001 From: "snowflake-release-please[bot]" <105954990+snowflake-release-please[bot]@users.noreply.github.com> Date: Fri, 2 Feb 2024 11:01:19 +0100 Subject: [PATCH 3/5] chore(main): release 0.85.0 (#2401) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit :robot: I have created a release *beep* *boop* --- ## [0.85.0](https://github.com/Snowflake-Labs/terraform-provider-snowflake/compare/v0.84.1...v0.85.0) (2024-02-01) ### 🎉 **What's new:** * Add API integration to the SDK ([#2409](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2409)) ([23acda5](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/23acda5dba9c8378f3b5631446d380a27cf1732c)) * add application to sdk ([#2350](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2350)) ([de97ad8](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/de97ad84db925b62ab10046e0893a5c285a26d67)) * add external funcs to sdk ([#2440](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2440)) ([c8cf09b](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/c8cf09b2af605dc373a138e6ca6863b5546303d5)) * Add grant privileges to share resource ([#2447](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2447)) ([d8241a5](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/d8241a5cc76ea7b929abdada81cf6929b5f6ad9e)) * Add materialized view to the SDK ([#2403](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2403)) ([a5ce699](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/a5ce69920328cce899260249d319ff7726ae3911)) * Add notification integration to the SDK ([#2412](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2412)) ([d84240c](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/d84240cda369ed9106c7cb3e3eedf85b8d1fa944)) * add sequences to sdk ([#2351](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2351)) ([d2e5ffd](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/d2e5ffd5405f10ff30c5ad9f7cd58bd54a5cc028)) * add snowflake grant privileges to account role ([#2365](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2365)) ([e3d086e](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/e3d086eddc05e0d4963234f82e09e174a018bb08)) * add streamlits to sdk ([#2400](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2400)) ([129d24c](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/129d24c00fa244d1401cb2169b5b7fb0ba6c465c)) * add-call-with to sdk ([#2337](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2337)) ([ebcd1bc](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/ebcd1bc40d554abe6863b67d2ab76f2d992dfb32)) * stages migration follow-up ([#2372](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2372)) ([3939dbe](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/3939dbe2f9189968c087a883ed97dd3b7350787f)) * Use API integration from SDK ([#2429](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2429)) ([1ccc864](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/1ccc8641106a3ceb4de813ce7c0e5077ead5272e)) * Use managed account from the SDK ([#2420](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2420)) ([3aaa080](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/3aaa08071a14f820e08751cc7b1e8bef5db16e30)) * Use materialized views and views from SDK ([#2448](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2448)) ([dc66d02](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/dc66d02304a99a7cb152e91a8e942587cab7e60f)) * Use notification integration from sdk ([#2445](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2445)) ([e8915cc](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/e8915ccb99eeec1f0ac5777fe80be7ef443d8f5c)) * use roles from the SDK ([#2405](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2405)) ([c645b4d](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/c645b4d0e2036d932766480e9c1e0334ef79c16e)) * Use row access policy from SDK ([#2428](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2428)) ([119af5e](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/119af5ea74bb219ae822962096e6220ed00f5910)) * Use SDK in the storage integration ([#2380](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2380)) ([ce0741c](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/ce0741ce226be9464407b549e90cb179b0fe5880)) * use sequence from sdk and add ordering attr ([#2419](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2419)) ([973b8f7](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/973b8f76a8ed1540bfd948ba8cb57c212c0d4abc)), closes [#2387](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2387) * Use stage from sdk ([#2427](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2427)) ([c17effd](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/c17effd16ccd77ba4c5d45f43dcc53a9f11601c6)) ### 🔧 **Misc** * add missing deprecation message ([#2451](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2451)) ([77de569](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/77de5694f73e5ad1443bb99407d2e8aec9a87320)) ### 🐛 **Bug fixes:** * account role test ([#2422](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2422)) ([c1b47d1](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/c1b47d1ade4b198b5bf14dc32162d34797a3b344)) * Adjust tests after Snowflake behavior change ([#2404](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2404)) ([8c03ffb](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/8c03ffb0430445c903168da9706e1ce2630675da)) * app-pkg unset ([#2399](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2399)) ([fedb1df](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/fedb1df2a731d139d68d2284bf3be47fcc4d0115)) * Fix some bugs ([#2421](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2421)) ([dec7cd9](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/dec7cd9e199ac8658f5c939f811686ba9f5e2e21)), closes [#2358](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2358) [#2369](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2369) [#2329](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2329) * snowflake_grant_privileges_to_role read ([#2424](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2424)) ([5385cec](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/5385cec3e5c03d2dbff762b63523bdddee8632d3)) --- This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please). Co-authored-by: snowflake-release-please[bot] <105954990+snowflake-release-please[bot]@users.noreply.github.com> --- CHANGELOG.md | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae30fb6ac9..14096e1a2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,45 @@ # Changelog +## [0.85.0](https://github.com/Snowflake-Labs/terraform-provider-snowflake/compare/v0.84.1...v0.85.0) (2024-02-01) + + +### 🎉 **What's new:** + +* Add API integration to the SDK ([#2409](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2409)) ([23acda5](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/23acda5dba9c8378f3b5631446d380a27cf1732c)) +* add application to sdk ([#2350](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2350)) ([de97ad8](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/de97ad84db925b62ab10046e0893a5c285a26d67)) +* add external funcs to sdk ([#2440](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2440)) ([c8cf09b](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/c8cf09b2af605dc373a138e6ca6863b5546303d5)) +* Add grant privileges to share resource ([#2447](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2447)) ([d8241a5](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/d8241a5cc76ea7b929abdada81cf6929b5f6ad9e)) +* Add materialized view to the SDK ([#2403](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2403)) ([a5ce699](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/a5ce69920328cce899260249d319ff7726ae3911)) +* Add notification integration to the SDK ([#2412](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2412)) ([d84240c](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/d84240cda369ed9106c7cb3e3eedf85b8d1fa944)) +* add sequences to sdk ([#2351](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2351)) ([d2e5ffd](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/d2e5ffd5405f10ff30c5ad9f7cd58bd54a5cc028)) +* add snowflake grant privileges to account role ([#2365](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2365)) ([e3d086e](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/e3d086eddc05e0d4963234f82e09e174a018bb08)) +* add streamlits to sdk ([#2400](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2400)) ([129d24c](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/129d24c00fa244d1401cb2169b5b7fb0ba6c465c)) +* add-call-with to sdk ([#2337](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2337)) ([ebcd1bc](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/ebcd1bc40d554abe6863b67d2ab76f2d992dfb32)) +* stages migration follow-up ([#2372](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2372)) ([3939dbe](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/3939dbe2f9189968c087a883ed97dd3b7350787f)) +* Use API integration from SDK ([#2429](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2429)) ([1ccc864](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/1ccc8641106a3ceb4de813ce7c0e5077ead5272e)) +* Use managed account from the SDK ([#2420](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2420)) ([3aaa080](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/3aaa08071a14f820e08751cc7b1e8bef5db16e30)) +* Use materialized views and views from SDK ([#2448](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2448)) ([dc66d02](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/dc66d02304a99a7cb152e91a8e942587cab7e60f)) +* Use notification integration from sdk ([#2445](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2445)) ([e8915cc](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/e8915ccb99eeec1f0ac5777fe80be7ef443d8f5c)) +* use roles from the SDK ([#2405](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2405)) ([c645b4d](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/c645b4d0e2036d932766480e9c1e0334ef79c16e)) +* Use row access policy from SDK ([#2428](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2428)) ([119af5e](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/119af5ea74bb219ae822962096e6220ed00f5910)) +* Use SDK in the storage integration ([#2380](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2380)) ([ce0741c](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/ce0741ce226be9464407b549e90cb179b0fe5880)) +* use sequence from sdk and add ordering attr ([#2419](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2419)) ([973b8f7](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/973b8f76a8ed1540bfd948ba8cb57c212c0d4abc)), closes [#2387](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2387) +* Use stage from sdk ([#2427](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2427)) ([c17effd](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/c17effd16ccd77ba4c5d45f43dcc53a9f11601c6)) + + +### 🔧 **Misc** + +* add missing deprecation message ([#2451](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2451)) ([77de569](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/77de5694f73e5ad1443bb99407d2e8aec9a87320)) + + +### 🐛 **Bug fixes:** + +* account role test ([#2422](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2422)) ([c1b47d1](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/c1b47d1ade4b198b5bf14dc32162d34797a3b344)) +* Adjust tests after Snowflake behavior change ([#2404](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2404)) ([8c03ffb](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/8c03ffb0430445c903168da9706e1ce2630675da)) +* app-pkg unset ([#2399](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2399)) ([fedb1df](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/fedb1df2a731d139d68d2284bf3be47fcc4d0115)) +* Fix some bugs ([#2421](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2421)) ([dec7cd9](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/dec7cd9e199ac8658f5c939f811686ba9f5e2e21)), closes [#2358](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2358) [#2369](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2369) [#2329](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2329) +* snowflake_grant_privileges_to_role read ([#2424](https://github.com/Snowflake-Labs/terraform-provider-snowflake/issues/2424)) ([5385cec](https://github.com/Snowflake-Labs/terraform-provider-snowflake/commit/5385cec3e5c03d2dbff762b63523bdddee8632d3)) + ## [0.85.0](https://github.com/Snowflake-Labs/terraform-provider-snowflake/compare/v0.84.1...v0.85.0) (2024-01-22) From fdb4f88ade09c9a1d63029b1937f1ef87528db8d Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Fri, 2 Feb 2024 14:45:12 +0100 Subject: [PATCH 4/5] feat: Use tables from SDK (#2453) - Migrate tables resource and datasource - Remove old implementation - Fix issue with masking policies update (check #2186) --- docs/resources/table.md | 2 +- pkg/datasources/tables.go | 44 +- pkg/datasources/tables_acceptance_test.go | 9 +- pkg/resources/materialized_view.go | 6 +- pkg/resources/stream.go | 8 +- pkg/resources/table.go | 676 ++++++++++-------- pkg/resources/table_acceptance_test.go | 194 ++++-- pkg/resources/table_internal_test.go | 72 -- pkg/resources/tag.go | 27 - pkg/sdk/tables.go | 45 +- pkg/sdk/tables_dto.go | 14 +- pkg/sdk/tables_dto_generated.go | 43 +- pkg/sdk/tables_impl.go | 2 + pkg/sdk/tables_test.go | 55 +- pkg/sdk/tables_validations.go | 6 - pkg/sdk/testint/tables_integration_test.go | 48 +- pkg/snowflake/table.go | 760 +-------------------- 17 files changed, 714 insertions(+), 1297 deletions(-) delete mode 100644 pkg/resources/table_internal_test.go diff --git a/docs/resources/table.md b/docs/resources/table.md index fee4fc3856..3f24c433cf 100644 --- a/docs/resources/table.md +++ b/docs/resources/table.md @@ -118,7 +118,7 @@ Optional: - `comment` (String) Column comment - `default` (Block List, Max: 1) Defines the column default value; note due to limitations of Snowflake's ALTER TABLE ADD/MODIFY COLUMN updates to default will not be applied (see [below for nested schema](#nestedblock--column--default)) - `identity` (Block List, Max: 1) Defines the identity start/step values for a column. **Note** Identity/default are mutually exclusive. (see [below for nested schema](#nestedblock--column--identity)) -- `masking_policy` (String) Masking policy to apply on column +- `masking_policy` (String) Masking policy to apply on column. It has to be a fully qualified name. - `nullable` (Boolean) Whether this column can contain null values. **Note**: Depending on your Snowflake version, the default value will not suffice if this column is used in a primary key constraint. diff --git a/pkg/datasources/tables.go b/pkg/datasources/tables.go index 44b9992dc4..a04d804527 100644 --- a/pkg/datasources/tables.go +++ b/pkg/datasources/tables.go @@ -1,12 +1,12 @@ package datasources import ( + "context" "database/sql" - "errors" - "fmt" "log" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -58,38 +58,38 @@ func Tables() *schema.Resource { func ReadTables(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) + ctx := context.Background() + client := sdk.NewClientFromDB(db) databaseName := d.Get("database").(string) schemaName := d.Get("schema").(string) - currentTables, err := snowflake.ListTables(databaseName, schemaName, db) - if errors.Is(err, sql.ErrNoRows) { - // If not found, mark resource to be removed from state file during apply or refresh - log.Printf("[DEBUG] tables in schema (%s) not found", d.Id()) - d.SetId("") - return nil - } else if err != nil { - log.Printf("[DEBUG] unable to parse tables in schema (%s)", d.Id()) + schemaId := sdk.NewDatabaseObjectIdentifier(databaseName, schemaName) + extractedTables, err := client.Tables.Show(ctx, sdk.NewShowTableRequest().WithIn( + &sdk.In{Schema: schemaId}, + )) + if err != nil { + log.Printf("[DEBUG] failed when searching tables in schema (%s), err = %s", schemaId.FullyQualifiedName(), err.Error()) d.SetId("") return nil } - tables := []map[string]interface{}{} + tables := make([]map[string]any, 0) - for _, table := range currentTables { - tableMap := map[string]interface{}{} - - if table.IsExternal.String == "Y" { + for _, extractedTable := range extractedTables { + if extractedTable.IsExternal { continue } - tableMap["name"] = table.TableName.String - tableMap["database"] = table.DatabaseName.String - tableMap["schema"] = table.SchemaName.String - tableMap["comment"] = table.Comment.String + table := map[string]any{ + "name": extractedTable.Name, + "database": extractedTable.DatabaseName, + "schema": extractedTable.SchemaName, + "comment": extractedTable.Comment, + } - tables = append(tables, tableMap) + tables = append(tables, table) } - d.SetId(fmt.Sprintf(`%v|%v`, databaseName, schemaName)) + d.SetId(helpers.EncodeSnowflakeID(databaseName, schemaName)) return d.Set("tables", tables) } diff --git a/pkg/datasources/tables_acceptance_test.go b/pkg/datasources/tables_acceptance_test.go index 62aa32daae..9205e3a033 100644 --- a/pkg/datasources/tables_acceptance_test.go +++ b/pkg/datasources/tables_acceptance_test.go @@ -5,8 +5,11 @@ import ( "strings" "testing" + acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" + "github.com/hashicorp/terraform-plugin-testing/helper/acctest" "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/tfversion" ) func TestAcc_Tables(t *testing.T) { @@ -16,7 +19,11 @@ func TestAcc_Tables(t *testing.T) { stageName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) externalTableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) resource.ParallelTest(t, resource.TestCase{ - Providers: providers(), + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, CheckDestroy: nil, Steps: []resource.TestStep{ { diff --git a/pkg/resources/materialized_view.go b/pkg/resources/materialized_view.go index fe7060393c..91ce9bd080 100644 --- a/pkg/resources/materialized_view.go +++ b/pkg/resources/materialized_view.go @@ -206,13 +206,13 @@ func UpdateMaterializedView(d *schema.ResourceData, meta interface{}) error { unsetRequest := sdk.NewMaterializedViewUnsetRequest() if d.HasChange("comment") { - comment := d.Get("comment") - if c := comment.(string); c == "" { + comment := d.Get("comment").(string) + if comment == "" { runUnsetStatement = true unsetRequest.WithComment(sdk.Bool(true)) } else { runSetStatement = true - setRequest.WithComment(sdk.String(d.Get("comment").(string))) + setRequest.WithComment(sdk.String(comment)) } } if d.HasChange("is_secure") { diff --git a/pkg/resources/stream.go b/pkg/resources/stream.go index c1ffb59d0f..7cd22ef87f 100644 --- a/pkg/resources/stream.go +++ b/pkg/resources/stream.go @@ -9,8 +9,6 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -131,14 +129,12 @@ func CreateStream(d *schema.ResourceData, meta interface{}) error { } tableId := tableObjectIdentifier.(sdk.SchemaObjectIdentifier) - tq := snowflake.NewTableBuilder(tableId.Name(), tableId.DatabaseName(), tableId.SchemaName()).Show() - tableRow := snowflake.QueryRow(db, tq) - t, err := snowflake.ScanTable(tableRow) + table, err := client.Tables.ShowByID(ctx, tableId) if err != nil { return err } - if t.IsExternal.String == "Y" { + if table.IsExternal { req := sdk.NewCreateStreamOnExternalTableRequest(id, tableId) if insertOnly { req.WithInsertOnly(sdk.Bool(true)) diff --git a/pkg/resources/table.go b/pkg/resources/table.go index d71dfafe52..00a2c3fe5b 100644 --- a/pkg/resources/table.go +++ b/pkg/resources/table.go @@ -1,24 +1,23 @@ package resources import ( - "bytes" + "context" "database/sql" - "encoding/csv" - "errors" "fmt" "log" "slices" + "strconv" "strings" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" ) -const ( - tableIDDelimiter = '|' -) - +// TODO [SNOW-867235]: old implementation was quoting every column, SDK is not quoting them, therefore they are quoted here: decide if we quote columns or not +// TODO [SNOW-1031688]: move data manipulation logic to the SDK - SQL generation or builders part (e.g. different default types/identity) var tableSchema = map[string]*schema.Schema{ "name": { Type: schema.TypeString, @@ -79,6 +78,7 @@ var tableSchema = map[string]*schema.Schema{ MinItems: 1, MaxItems: 1, Elem: &schema.Resource{ + // TODO [SNOW-867235]: there is no such separation on SDK level. Should we keep it in V1? Schema: map[string]*schema.Schema{ "constant": { Type: schema.TypeString, @@ -137,7 +137,7 @@ var tableSchema = map[string]*schema.Schema{ Type: schema.TypeString, Optional: true, Default: "", - Description: "Masking policy to apply on column", + Description: "Masking policy to apply on column. It has to be a fully qualified name.", }, }, }, @@ -220,73 +220,12 @@ func Table() *schema.Resource { } } -type tableID struct { - DatabaseName string - SchemaName string - TableName string -} - -// String() takes in a tableID object and returns a pipe-delimited string: -// DatabaseName|SchemaName|TableName. -func (si *tableID) String() (string, error) { - var buf bytes.Buffer - csvWriter := csv.NewWriter(&buf) - csvWriter.Comma = tableIDDelimiter - dataIdentifiers := [][]string{{si.DatabaseName, si.SchemaName, si.TableName}} - if err := csvWriter.WriteAll(dataIdentifiers); err != nil { - return "", err - } - strTableID := strings.TrimSpace(buf.String()) - return strTableID, nil -} - -// tableIDFromString() takes in a pipe-delimited string: DatabaseName|SchemaName|TableName -// and returns a tableID object. -func tableIDFromString(stringID string) (*tableID, error) { - reader := csv.NewReader(strings.NewReader(stringID)) - reader.Comma = tableIDDelimiter - lines, err := reader.ReadAll() - if err != nil { - return nil, fmt.Errorf("not CSV compatible") - } - - if len(lines) != 1 { - return nil, fmt.Errorf("1 line at a time") - } - if len(lines[0]) != 3 { - return nil, fmt.Errorf("3 fields allowed") - } - - tableResult := &tableID{ - DatabaseName: lines[0][0], - SchemaName: lines[0][1], - TableName: lines[0][2], - } - return tableResult, nil -} - type columnDefault struct { constant *string expression *string sequence *string } -func (cd *columnDefault) toSnowflakeColumnDefault() *snowflake.ColumnDefault { - if cd.constant != nil { - return snowflake.NewColumnDefaultWithConstant(*cd.constant) - } - - if cd.expression != nil { - return snowflake.NewColumnDefaultWithExpression(*cd.expression) - } - - if cd.sequence != nil { - return snowflake.NewColumnDefaultWithSequence(*cd.sequence) - } - - return nil -} - func (cd *columnDefault) _type() string { if cd.constant != nil { return "constant" @@ -308,11 +247,6 @@ type columnIdentity struct { stepNum int } -func (identity *columnIdentity) toSnowflakeColumnIdentity() *snowflake.ColumnIdentity { - snowIdentity := snowflake.ColumnIdentity{} - return snowIdentity.WithStartNum(identity.startNum).WithStep(identity.stepNum) -} - type column struct { name string dataType string @@ -323,34 +257,8 @@ type column struct { maskingPolicy string } -func (c column) toSnowflakeColumn() snowflake.Column { - sC := &snowflake.Column{} - - if c._default != nil { - sC = sC.WithDefault(c._default.toSnowflakeColumnDefault()) - } - - if c.identity != nil { - sC = sC.WithIdentity(c.identity.toSnowflakeColumnIdentity()) - } - - return *sC.WithName(c.name). - WithType(c.dataType). - WithNullable(c.nullable). - WithComment(c.comment). - WithMaskingPolicy(c.maskingPolicy) -} - type columns []column -func (c columns) toSnowflakeColumns() []snowflake.Column { - sC := make([]snowflake.Column, len(c)) - for i, col := range c { - sC[i] = col.toSnowflakeColumn() - } - return sC -} - type changedColumns []changedColumn type changedColumn struct { @@ -468,6 +376,67 @@ func getColumns(from interface{}) (to columns) { return to } +func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest { + c := from.(map[string]interface{}) + _type := c["type"].(string) + + nameInQuotes := fmt.Sprintf(`"%v"`, snowflake.EscapeString(c["name"].(string))) + request := sdk.NewTableColumnRequest(nameInQuotes, sdk.DataType(_type)) + + _default := c["default"].([]interface{}) + var expression string + if len(_default) == 1 { + if c, ok := _default[0].(map[string]interface{})["constant"]; ok { + if constant, ok := c.(string); ok && len(constant) > 0 { + if strings.Contains(_type, "CHAR") || _type == "STRING" || _type == "TEXT" { + expression = snowflake.EscapeSnowflakeString(constant) + } else { + expression = constant + } + } + } + + if e, ok := _default[0].(map[string]interface{})["expression"]; ok { + if expr, ok := e.(string); ok && len(expr) > 0 { + expression = expr + } + } + + if s, ok := _default[0].(map[string]interface{})["sequence"]; ok { + if seq := s.(string); ok && len(seq) > 0 { + expression = fmt.Sprintf(`%v.NEXTVAL`, seq) + } + } + request.WithDefaultValue(sdk.NewColumnDefaultValueRequest().WithExpression(sdk.String(expression))) + } + + identity := c["identity"].([]interface{}) + if len(identity) == 1 { + identityProp := identity[0].(map[string]interface{}) + startNum := identityProp["start_num"].(int) + stepNum := identityProp["step_num"].(int) + request.WithDefaultValue(sdk.NewColumnDefaultValueRequest().WithIdentity(sdk.NewColumnIdentityRequest(startNum, stepNum))) + } + + maskingPolicy := c["masking_policy"].(string) + if maskingPolicy != "" { + request.WithMaskingPolicy(sdk.NewColumnMaskingPolicyRequest(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(maskingPolicy))) + } + + return request. + WithNotNull(sdk.Bool(!c["nullable"].(bool))). + WithComment(sdk.String(c["comment"].(string))) +} + +func getTableColumnRequests(from interface{}) []sdk.TableColumnRequest { + cols := from.([]interface{}) + to := make([]sdk.TableColumnRequest, len(cols)) + for i, c := range cols { + to[i] = *getTableColumnRequest(c) + } + return to +} + type primarykey struct { name string keys []string @@ -485,66 +454,158 @@ func getPrimaryKey(from interface{}) (to primarykey) { return to } -func (pk primarykey) toSnowflakePrimaryKey() snowflake.PrimaryKey { - snowPk := snowflake.PrimaryKey{} - return *snowPk.WithName(pk.name).WithKeys(pk.keys) +func toColumnConfig(descriptions []sdk.TableColumnDetails) []any { + flattened := make([]any, 0) + for _, td := range descriptions { + if td.Kind != "COLUMN" { + continue + } + + flat := map[string]any{} + flat["name"] = td.Name + flat["type"] = string(td.Type) + flat["nullable"] = td.IsNullable + + if td.Comment != nil { + flat["comment"] = *td.Comment + } + + if td.PolicyName != nil { + // TODO [SNOW-867240]: SHOW TABLE returns last part of id without double quotes... we have to quote it again. Move it to SDK. + flat["masking_policy"] = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(*td.PolicyName).FullyQualifiedName() + } + + identity := toColumnIdentityConfig(td) + if identity != nil { + flat["identity"] = []any{identity} + } else { + def := toColumnDefaultConfig(td) + if def != nil { + flat["default"] = []any{def} + } + } + flattened = append(flattened, flat) + } + return flattened +} + +func toColumnDefaultConfig(td sdk.TableColumnDetails) map[string]any { + if td.Default == nil { + return nil + } + + defaultRaw := *td.Default + def := map[string]any{} + if strings.HasSuffix(defaultRaw, ".NEXTVAL") { + // TODO [SNOW-867240]: SHOW TABLE returns last part of id without double quotes... we have to quote it again. Move it to SDK. + sequenceIdRaw := strings.TrimSuffix(defaultRaw, ".NEXTVAL") + def["sequence"] = sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(sequenceIdRaw).FullyQualifiedName() + return def + } + + if strings.Contains(defaultRaw, "(") && strings.Contains(defaultRaw, ")") { + def["expression"] = defaultRaw + return def + } + + columnType := strings.ToUpper(string(td.Type)) + if strings.Contains(columnType, "CHAR") || columnType == "STRING" || columnType == "TEXT" { + def["constant"] = snowflake.UnescapeSnowflakeString(defaultRaw) + return def + } + + def["constant"] = defaultRaw + return def +} + +func toColumnIdentityConfig(td sdk.TableColumnDetails) map[string]any { + // if autoincrement is used this is reflected back IDENTITY START 1 INCREMENT 1 + if td.Default == nil { + return nil + } + + defaultRaw := *td.Default + + if strings.Contains(defaultRaw, "IDENTITY") { + identity := map[string]any{} + + split := strings.Split(defaultRaw, " ") + start, err := strconv.Atoi(split[2]) + if err == nil { + identity["start_num"] = start + } + step, err := strconv.Atoi(split[4]) + if err == nil { + identity["step_num"] = step + } + + return identity + } + return nil } // CreateTable implements schema.CreateFunc. func CreateTable(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - database := d.Get("database").(string) - schema := d.Get("schema").(string) + ctx := context.Background() + client := sdk.NewClientFromDB(db) + + databaseName := d.Get("database").(string) + schemaName := d.Get("schema").(string) name := d.Get("name").(string) + id := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name) - columns := getColumns(d.Get("column").([]interface{})) + tableColumnRequests := getTableColumnRequests(d.Get("column").([]interface{})) - builder := snowflake.NewTableWithColumnDefinitionsBuilder(name, database, schema, columns.toSnowflakeColumns()) + createRequest := sdk.NewCreateTableRequest(id, tableColumnRequests) - // Set optionals if v, ok := d.GetOk("comment"); ok { - builder.WithComment(v.(string)) + createRequest.WithComment(sdk.String(v.(string))) } if v, ok := d.GetOk("cluster_by"); ok { - builder.WithClustering(expandStringList(v.([]interface{}))) + createRequest.WithClusterBy(expandStringList(v.([]interface{}))) } if v, ok := d.GetOk("primary_key"); ok { - pk := getPrimaryKey(v.([]interface{})) - builder.WithPrimaryKey(pk.toSnowflakePrimaryKey()) + keysList := v.([]any) + if len(keysList) > 0 { + keys := expandStringList(keysList[0].(map[string]any)["keys"].([]interface{})) + constraintRequest := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithColumns(snowflake.QuoteStringList(keys)) + + keyName, isPresent := keysList[0].(map[string]any)["name"] + if isPresent && keyName != "" { + constraintRequest.WithName(sdk.String(keyName.(string))) + } + } } if v, ok := d.GetOk("data_retention_days"); ok { - builder.WithDataRetentionTimeInDays(v.(int)) + createRequest.WithDataRetentionTimeInDays(sdk.Int(v.(int))) } else if v, ok := d.GetOk("data_retention_time_in_days"); ok { - builder.WithDataRetentionTimeInDays(v.(int)) + createRequest.WithDataRetentionTimeInDays(sdk.Int(v.(int))) } if v, ok := d.GetOk("change_tracking"); ok { - builder.WithChangeTracking(v.(bool)) - } - - if v, ok := d.GetOk("tag"); ok { - tags := getTags(v) - builder.WithTags(tags.toSnowflakeTagValues()) + createRequest.WithChangeTracking(sdk.Bool(v.(bool))) } - stmt := builder.Create() - if err := snowflake.Exec(db, stmt); err != nil { - return fmt.Errorf("error creating table %v", name) + var tagAssociationRequests []sdk.TagAssociationRequest + if _, ok := d.GetOk("tag"); ok { + tagAssociations := getPropertyTags(d, "tag") + tagAssociationRequests = make([]sdk.TagAssociationRequest, len(tagAssociations)) + for i, t := range tagAssociations { + tagAssociationRequests[i] = *sdk.NewTagAssociationRequest(t.Name, t.Value) + } + createRequest.WithTags(tagAssociationRequests) } - tableID := &tableID{ - DatabaseName: database, - SchemaName: schema, - TableName: name, - } - dataIDInput, err := tableID.String() + err := client.Tables.Create(ctx, createRequest) if err != nil { - return err + return fmt.Errorf("error creating table %v err = %w", name, err) } - d.SetId(dataIDInput) + + d.SetId(helpers.EncodeSnowflakeID(id)) return ReadTable(d, meta) } @@ -552,59 +613,33 @@ func CreateTable(d *schema.ResourceData, meta interface{}) error { // ReadTable implements schema.ReadFunc. func ReadTable(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - tableID, err := tableIDFromString(d.Id()) - if err != nil { - return err - } - builder := snowflake.NewTableBuilder(tableID.TableName, tableID.DatabaseName, tableID.SchemaName) + ctx := context.Background() + client := sdk.NewClientFromDB(db) + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) - row := snowflake.QueryRow(db, builder.Show()) - table, err := snowflake.ScanTable(row) - if errors.Is(err, sql.ErrNoRows) { - // If not found, mark resource to be removed from state file during apply or refresh + table, err := client.Tables.ShowByID(ctx, id) + if err != nil { log.Printf("[DEBUG] table (%s) not found", d.Id()) d.SetId("") return nil } - if err != nil { - return err - } - // Describe the table to read the cols - tableDescriptionRows, err := snowflake.Query(db, builder.ShowColumns()) + tableDescription, err := client.Tables.DescribeColumns(ctx, sdk.NewDescribeTableColumnsRequest(id)) if err != nil { return err } - tableDescription, err := snowflake.ScanTableDescription(tableDescriptionRows) - if err != nil { - return err - } - - /* - deprecated as it conflicts with the new table_constraint resource - showPkrows, err := snowflake.Query(db, builder.ShowPrimaryKeys()) - if err != nil { - return err - } - - pkDescription, err := snowflake.ScanPrimaryKeyDescription(showPkrows) - if err != nil { - return err - }*/ - // Set the relevant data in the state toSet := map[string]interface{}{ - "name": table.TableName.String, - "owner": table.Owner.String, - "database": tableID.DatabaseName, - "schema": tableID.SchemaName, - "comment": table.Comment.String, - "column": snowflake.NewColumns(tableDescription).Flatten(), - "cluster_by": snowflake.ClusterStatementToList(table.ClusterBy.String), - // "primary_key": snowflake.FlattenTablePrimaryKey(pkDescription), - "change_tracking": (table.ChangeTracking.String == "ON"), - "qualified_name": fmt.Sprintf(`"%s"."%s"."%s"`, tableID.DatabaseName, tableID.SchemaName, table.TableName.String), + "name": table.Name, + "owner": table.Owner, + "database": table.DatabaseName, + "schema": table.SchemaName, + "comment": table.Comment, + "column": toColumnConfig(tableDescription), + "cluster_by": table.GetClusterByKeys(), + "change_tracking": table.ChangeTracking, + "qualified_name": id.FullyQualifiedName(), } var dataRetentionKey string if _, ok := d.GetOk("data_retention_time_in_days"); ok { @@ -613,7 +648,7 @@ func ReadTable(d *schema.ResourceData, meta interface{}) error { dataRetentionKey = "data_retention_days" } if dataRetentionKey != "" { - toSet[dataRetentionKey] = table.RetentionTime.Int32 + toSet[dataRetentionKey] = table.RetentionTime } for key, val := range toSet { @@ -626,170 +661,241 @@ func ReadTable(d *schema.ResourceData, meta interface{}) error { // UpdateTable implements schema.UpdateFunc. func UpdateTable(d *schema.ResourceData, meta interface{}) error { - tid, err := tableIDFromString(d.Id()) - if err != nil { - return err + db := meta.(*sql.DB) + ctx := context.Background() + client := sdk.NewClientFromDB(db) + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) + + if d.HasChange("name") { + newName := d.Get("name").(string) + + newId := sdk.NewSchemaObjectIdentifier(id.DatabaseName(), id.SchemaName(), newName) + + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithNewName(&newId)) + if err != nil { + return fmt.Errorf("error renaming table %v err = %w", d.Id(), err) + } + + d.SetId(helpers.EncodeSnowflakeID(newId)) + id = newId } - dbName := tid.DatabaseName - schema := tid.SchemaName - tableName := tid.TableName + var runSetStatement bool + var runUnsetStatement bool + setRequest := sdk.NewTableSetRequest() + unsetRequest := sdk.NewTableUnsetRequest() + + if d.HasChange("comment") { + comment := d.Get("comment").(string) + if comment == "" { + runUnsetStatement = true + unsetRequest.WithComment(true) + } else { + runSetStatement = true + setRequest.WithComment(sdk.String(comment)) + } + } - builder := snowflake.NewTableBuilder(tableName, dbName, schema) + if d.HasChange("change_tracking") { + changeTracking := d.Get("change_tracking").(bool) + runSetStatement = true + setRequest.WithChangeTracking(sdk.Bool(changeTracking)) + } - db := meta.(*sql.DB) - if d.HasChange("name") { - name := d.Get("name") - q := builder.Rename(name.(string)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating table name on %v", d.Id()) + checkChangeForDataRetention := func(key string) { + if d.HasChange(key) { + dataRetentionDays := d.Get(key).(int) + runSetStatement = true + setRequest.WithDataRetentionTimeInDays(sdk.Int(dataRetentionDays)) } - tableID := &tableID{ - DatabaseName: dbName, - SchemaName: schema, - TableName: name.(string), + } + checkChangeForDataRetention("data_retention_days") + checkChangeForDataRetention("data_retention_time_in_days") + + if runSetStatement { + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithSet(setRequest)) + if err != nil { + return fmt.Errorf("error updating table: %w", err) } - dataIDInput, err := tableID.String() + } + + if runUnsetStatement { + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithUnset(unsetRequest)) if err != nil { - return err + return fmt.Errorf("error updating table: %w", err) } - d.SetId(dataIDInput) } - if d.HasChange("comment") { - comment := d.Get("comment") - q := builder.ChangeComment(comment.(string)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating table comment on %v", d.Id()) + + if d.HasChange("cluster_by") { + cb := expandStringList(d.Get("cluster_by").([]interface{})) + + if len(cb) != 0 { + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithClusteringAction(sdk.NewTableClusteringActionRequest().WithClusterBy(cb))) + if err != nil { + return fmt.Errorf("error updating table: %w", err) + } + } else { + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithClusteringAction(sdk.NewTableClusteringActionRequest().WithDropClusteringKey(sdk.Bool(true)))) + if err != nil { + return fmt.Errorf("error updating table: %w", err) + } } } + if d.HasChange("column") { t, n := d.GetChange("column") removed, added, changed := getColumns(t).diffs(getColumns(n)) - for _, cA := range removed { - q := builder.DropColumn(cA.name) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error dropping column on %v", d.Id()) + + if len(removed) > 0 { + removedColumnNames := make([]string, len(removed)) + for i, r := range removed { + removedColumnNames[i] = r.name + } + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithDropColumns(snowflake.QuoteStringList(removedColumnNames)))) + if err != nil { + return fmt.Errorf("error updating table: %w", err) } } + for _, cA := range added { - var q string + addRequest := sdk.NewTableColumnAddActionRequest(fmt.Sprintf("\"%s\"", cA.name), sdk.DataType(cA.dataType)). + WithInlineConstraint(sdk.NewTableColumnAddInlineConstraintRequest().WithNotNull(sdk.Bool(!cA.nullable))) - if cA.identity == nil && cA._default == nil { //nolint:gocritic // todo: please fix this to pass gocritic - q = builder.AddColumn(cA.name, cA.dataType, cA.nullable, nil, nil, cA.comment, cA.maskingPolicy) - } else if cA.identity != nil { - q = builder.AddColumn(cA.name, cA.dataType, cA.nullable, nil, cA.identity.toSnowflakeColumnIdentity(), cA.comment, cA.maskingPolicy) - } else { + if cA._default != nil { if cA._default._type() != "constant" { return fmt.Errorf("failed to add column %v => Only adding a column as a constant is supported by Snowflake", cA.name) } + var expression string + if strings.Contains(cA.dataType, "CHAR") || cA.dataType == "STRING" || cA.dataType == "TEXT" { + expression = snowflake.EscapeSnowflakeString(*cA._default.constant) + } else { + expression = *cA._default.constant + } + addRequest.WithDefaultValue(sdk.NewColumnDefaultValueRequest().WithExpression(sdk.String(expression))) + } + + if cA.identity != nil { + addRequest.WithDefaultValue(sdk.NewColumnDefaultValueRequest().WithIdentity(sdk.NewColumnIdentityRequest(cA.identity.startNum, cA.identity.stepNum))) + } - q = builder.AddColumn(cA.name, cA.dataType, cA.nullable, cA._default.toSnowflakeColumnDefault(), nil, cA.comment, cA.maskingPolicy) + if cA.maskingPolicy != "" { + addRequest.WithMaskingPolicy(sdk.NewColumnMaskingPolicyRequest(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(cA.maskingPolicy))) } - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error adding column on %v", d.Id()) + if cA.comment != "" { + addRequest.WithComment(sdk.String(cA.comment)) + } + + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAdd(addRequest))) + if err != nil { + return fmt.Errorf("error adding column: %w", err) } } for _, cA := range changed { if cA.changedDataType { - q := builder.ChangeColumnType(cA.newColumn.name, cA.newColumn.dataType) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithType(sdk.Pointer(sdk.DataType(cA.newColumn.dataType)))}))) + if err != nil { + return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) } } if cA.changedNullConstraint { - q := builder.ChangeNullConstraint(cA.newColumn.name, cA.newColumn.nullable) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + nullabilityRequest := sdk.NewTableColumnNotNullConstraintRequest() + if !cA.newColumn.nullable { + nullabilityRequest.WithSet(sdk.Bool(true)) + } else { + nullabilityRequest.WithDrop(sdk.Bool(true)) + } + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithNotNullConstraint(nullabilityRequest)}))) + if err != nil { + return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) } } if cA.dropedDefault { - q := builder.DropColumnDefault(cA.newColumn.name) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)).WithDropDefault(sdk.Bool(true))}))) + if err != nil { + return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) } } if cA.changedComment { - q := builder.ChangeColumnComment(cA.newColumn.name, cA.newColumn.comment) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + columnAlterActionRequest := sdk.NewTableColumnAlterActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name)) + if cA.newColumn.comment == "" { + columnAlterActionRequest.WithUnsetComment(sdk.Bool(true)) + } else { + columnAlterActionRequest.WithComment(sdk.String(cA.newColumn.comment)) + } + + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(sdk.NewTableColumnActionRequest().WithAlter([]sdk.TableColumnAlterActionRequest{*columnAlterActionRequest}))) + if err != nil { + return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) } } if cA.changedMaskingPolicy { - q := builder.ChangeColumnMaskingPolicy(cA.newColumn.name, cA.newColumn.maskingPolicy) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + columnAction := sdk.NewTableColumnActionRequest() + if strings.TrimSpace(cA.newColumn.maskingPolicy) == "" { + columnAction.WithUnsetMaskingPolicy(sdk.NewTableColumnAlterUnsetMaskingPolicyActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name))) + } else { + columnAction.WithSetMaskingPolicy(sdk.NewTableColumnAlterSetMaskingPolicyActionRequest(fmt.Sprintf("\"%s\"", cA.newColumn.name), sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(cA.newColumn.maskingPolicy), []string{}).WithForce(sdk.Bool(true))) + } + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithColumnAction(columnAction)) + if err != nil { + return fmt.Errorf("error changing property on %v: err %w", d.Id(), err) } } } } - if d.HasChange("cluster_by") { - cb := expandStringList(d.Get("cluster_by").([]interface{})) - - var q string - if len(cb) != 0 { - builder.WithClustering(cb) - q = builder.ChangeClusterBy(builder.GetClusterKeyString()) - } else { - q = builder.DropClustering() - } - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error updating table clustering on %v", d.Id()) - } - } if d.HasChange("primary_key") { - opk, npk := d.GetChange("primary_key") + o, n := d.GetChange("primary_key") - newpk := getPrimaryKey(npk) - oldpk := getPrimaryKey(opk) + newKey := getPrimaryKey(n) + oldKey := getPrimaryKey(o) - if len(oldpk.keys) > 0 || len(newpk.keys) == 0 { + if len(oldKey.keys) > 0 || len(newKey.keys) == 0 { // drop our pk if there was an old primary key, or pk has been removed - q := builder.DropPrimaryKey() - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing primary key first on %v", d.Id()) + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithConstraintAction( + sdk.NewTableConstraintActionRequest(). + WithDrop(sdk.NewTableConstraintDropActionRequest().WithPrimaryKey(sdk.Bool(true))), + )) + if err != nil { + return fmt.Errorf("error updating table: %w", err) } } - if len(newpk.keys) > 0 { - // add our new pk - q := builder.ChangePrimaryKey(newpk.toSnowflakePrimaryKey()) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + if len(newKey.keys) > 0 { + constraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithColumns(snowflake.QuoteStringList(newKey.keys)) + if newKey.name != "" { + constraint.WithName(sdk.String(newKey.name)) + } + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithConstraintAction( + sdk.NewTableConstraintActionRequest().WithAdd(constraint), + )) + if err != nil { + return fmt.Errorf("error updating table: %w", err) } } } - updateDataRetention := func(key string) error { - if d.HasChange(key) { - ndr := d.Get(key) - q := builder.ChangeDataRetention(ndr.(int)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + + if d.HasChange("tag") { + unsetTags, setTags := GetTagsDiff(d, "tag") + + if len(unsetTags) > 0 { + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithUnsetTags(unsetTags)) + if err != nil { + return fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err) } } - return nil - } - err = updateDataRetention("data_retention_days") - if err != nil { - return err - } - err = updateDataRetention("data_retention_time_in_days") - if err != nil { - return err - } - if d.HasChange("change_tracking") { - nct := d.Get("change_tracking") - q := builder.ChangeChangeTracking(nct.(bool)) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) + if len(setTags) > 0 { + tagAssociationRequests := make([]sdk.TagAssociationRequest, len(setTags)) + for i, t := range setTags { + tagAssociationRequests[i] = *sdk.NewTagAssociationRequest(t.Name, t.Value) + } + err := client.Tables.Alter(ctx, sdk.NewAlterTableRequest(id).WithSetTags(tagAssociationRequests)) + if err != nil { + return fmt.Errorf("error setting tags on %v, err = %w", d.Id(), err) + } } } - tagChangeErr := handleTagChanges(db, d, builder) - if tagChangeErr != nil { - return tagChangeErr - } return ReadTable(d, meta) } @@ -797,21 +903,15 @@ func UpdateTable(d *schema.ResourceData, meta interface{}) error { // DeleteTable implements schema.DeleteFunc. func DeleteTable(d *schema.ResourceData, meta interface{}) error { db := meta.(*sql.DB) - tableID, err := tableIDFromString(d.Id()) + ctx := context.Background() + client := sdk.NewClientFromDB(db) + id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier) + + err := client.Tables.Drop(ctx, sdk.NewDropTableRequest(id)) if err != nil { return err } - dbName := tableID.DatabaseName - schemaName := tableID.SchemaName - tableName := tableID.TableName - - q := snowflake.NewTableBuilder(tableName, dbName, schemaName).Drop() - - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error deleting pipe %v err = %w", d.Id(), err) - } - d.SetId("") return nil diff --git a/pkg/resources/table_acceptance_test.go b/pkg/resources/table_acceptance_test.go index c0c5d72c96..cddc2b1dee 100644 --- a/pkg/resources/table_acceptance_test.go +++ b/pkg/resources/table_acceptance_test.go @@ -1,26 +1,31 @@ package resources_test import ( + "context" + "database/sql" "fmt" - "os" "strings" "testing" acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/hashicorp/terraform-plugin-testing/helper/acctest" "github.com/hashicorp/terraform-plugin-testing/helper/resource" + "github.com/hashicorp/terraform-plugin-testing/terraform" + "github.com/hashicorp/terraform-plugin-testing/tfversion" ) func TestAcc_TableWithSeparateDataRetentionObjectParameterWithoutLifecycle(t *testing.T) { - if _, ok := os.LookupEnv("SKIP_TABLE_DATA_RETENTION_TESTS"); ok { - t.Skip("Skipping TestAcc_TableWithSeparateDataRetentionObjectParameterWithoutLifecycle") - } - accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableConfig(accName, acc.TestDatabaseName, acc.TestSchemaName), @@ -61,15 +66,15 @@ func TestAcc_TableWithSeparateDataRetentionObjectParameterWithoutLifecycle(t *te } func TestAcc_TableWithSeparateDataRetentionObjectParameterWithLifecycle(t *testing.T) { - if _, ok := os.LookupEnv("SKIP_TABLE_DATA_RETENTION_TESTS"); ok { - t.Skip("Skipping TestAcc_TableWithSeparateDataRetentionObjectParameterWithLifecycle") - } - accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableConfig(accName, acc.TestDatabaseName, acc.TestSchemaName), @@ -131,10 +136,13 @@ func TestAcc_Table(t *testing.T) { table2Name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) table3Name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableConfig(accName, acc.TestDatabaseName, acc.TestSchemaName), @@ -670,7 +678,6 @@ resource "snowflake_table" "test_table" { nullable = false } primary_key { - name = "" keys = ["column2"] } } @@ -868,10 +875,13 @@ resource "snowflake_table" "test_table" { func TestAcc_TableDefaults(t *testing.T) { accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableColumnWithDefaults(accName, acc.TestDatabaseName, acc.TestSchemaName), @@ -896,7 +906,7 @@ func TestAcc_TableDefaults(t *testing.T) { resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.type", "NUMBER(38,0)"), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.constant"), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.expression"), - resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.default.0.sequence", fmt.Sprintf(`"%v"."%v".%v`, acc.TestDatabaseName, acc.TestSchemaName, accName)), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.default.0.sequence", fmt.Sprintf(`"%v"."%v"."%v"`, acc.TestDatabaseName, acc.TestSchemaName, accName)), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "primary_key.0"), ), }, @@ -919,7 +929,7 @@ func TestAcc_TableDefaults(t *testing.T) { resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.type", "NUMBER(38,0)"), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.constant"), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "column.2.type.default.0.expression"), - resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.default.0.sequence", fmt.Sprintf(`"%v"."%v".%v`, acc.TestDatabaseName, acc.TestSchemaName, accName)), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.2.default.0.sequence", fmt.Sprintf(`"%v"."%v"."%v"`, acc.TestDatabaseName, acc.TestSchemaName, accName)), resource.TestCheckNoResourceAttr("snowflake_table.test_table", "primary_key.0"), ), }, @@ -1005,10 +1015,14 @@ func TestAcc_TableTags(t *testing.T) { accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) tagName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) tag2Name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableWithTags(accName, tagName, tag2Name, acc.TestDatabaseName, acc.TestSchemaName), @@ -1080,10 +1094,13 @@ resource "snowflake_table" "test_table" { func TestAcc_TableIdentity(t *testing.T) { accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableColumnWithIdentityDefault(accName, acc.TestDatabaseName, acc.TestSchemaName), @@ -1214,10 +1231,14 @@ resource "snowflake_table" "test_table" { func TestAcc_TableRename(t *testing.T) { oldTableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) newTableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) - resource.ParallelTest(t, resource.TestCase{ - Providers: acc.TestAccProviders(), - PreCheck: func() { acc.TestAccPreCheck(t) }, - CheckDestroy: nil, + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, Steps: []resource.TestStep{ { Config: tableConfigWithName(oldTableName, acc.TestDatabaseName, acc.TestSchemaName), @@ -1263,3 +1284,96 @@ resource "snowflake_table" "test_table" { ` return fmt.Sprintf(s, tableName, databaseName, schemaName) } + +func TestAcc_Table_MaskingPolicy(t *testing.T) { + accName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + + resource.Test(t, resource.TestCase{ + ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories, + PreCheck: func() { acc.TestAccPreCheck(t) }, + TerraformVersionChecks: []tfversion.TerraformVersionCheck{ + tfversion.RequireAbove(tfversion.Version1_5_0), + }, + CheckDestroy: testAccCheckTableDestroy, + Steps: []resource.TestStep{ + { + Config: tableWithMaskingPolicy(accName, acc.TestDatabaseName, acc.TestSchemaName, "policy1"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("snowflake_table.test_table", "name", accName), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.masking_policy", sdk.NewSchemaObjectIdentifier(acc.TestDatabaseName, acc.TestSchemaName, fmt.Sprintf("%s1", accName)).FullyQualifiedName()), + ), + }, + // this step proves https://github.com/Snowflake-Labs/terraform-provider-snowflake/pull/2186 + { + Config: tableWithMaskingPolicy(accName, acc.TestDatabaseName, acc.TestSchemaName, "policy2"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr("snowflake_table.test_table", "name", accName), + resource.TestCheckResourceAttr("snowflake_table.test_table", "column.0.masking_policy", sdk.NewSchemaObjectIdentifier(acc.TestDatabaseName, acc.TestSchemaName, fmt.Sprintf("%s2", accName)).FullyQualifiedName()), + ), + }, + }, + }) +} + +func tableWithMaskingPolicy(name string, databaseName string, schemaName string, policy string) string { + s := ` +resource "snowflake_masking_policy" "policy1" { + name = "%[1]s1" + database = "%[2]s" + schema = "%[3]s" + signature { + column { + name = "val" + type = "VARCHAR" + } + } + masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" + return_data_type = "VARCHAR(16777216)" +} + +resource "snowflake_masking_policy" "policy2" { + name = "%[1]s2" + database = "%[2]s" + schema = "%[3]s" + signature { + column { + name = "val" + type = "VARCHAR" + } + } + masking_expression = "case when current_role() in ('ANALYST') then val else sha2(val, 512) end" + return_data_type = "VARCHAR(16777216)" +} + +resource "snowflake_table" "test_table" { + name = "%[1]s" + database = "%[2]s" + schema = "%[3]s" + comment = "Terraform acceptance test" + + column { + name = "column1" + type = "VARCHAR(16)" + masking_policy = snowflake_masking_policy.%[4]s.qualified_name + } +} +` + return fmt.Sprintf(s, name, databaseName, schemaName, policy) +} + +func testAccCheckTableDestroy(s *terraform.State) error { + db := acc.TestAccProvider.Meta().(*sql.DB) + client := sdk.NewClientFromDB(db) + for _, rs := range s.RootModule().Resources { + if rs.Type != "snowflake_table" { + continue + } + ctx := context.Background() + id := sdk.NewSchemaObjectIdentifier(rs.Primary.Attributes["database"], rs.Primary.Attributes["schema"], rs.Primary.Attributes["name"]) + existingTable, err := client.Tables.ShowByID(ctx, id) + if err == nil { + return fmt.Errorf("table %v still exists", existingTable.ID().FullyQualifiedName()) + } + } + return nil +} diff --git a/pkg/resources/table_internal_test.go b/pkg/resources/table_internal_test.go deleted file mode 100644 index f35c953025..0000000000 --- a/pkg/resources/table_internal_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package resources - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestTableIDFromString(t *testing.T) { - r := require.New(t) - // Vanilla - id := "database_name|schema_name|table" - table, err := tableIDFromString(id) - r.NoError(err) - r.Equal("database_name", table.DatabaseName) - r.Equal("schema_name", table.SchemaName) - r.Equal("table", table.TableName) - - // Bad ID -- not enough fields - id = "database" - _, err = tableIDFromString(id) - r.Equal(fmt.Errorf("3 fields allowed"), err) - - // Bad ID - id = "||" - _, err = tableIDFromString(id) - r.NoError(err) - - // 0 lines - id = "" - _, err = tableIDFromString(id) - r.Equal(fmt.Errorf("1 line at a time"), err) - - // 2 lines - id = `database_name|schema_name|table - database_name|schema_name|table` - _, err = tableIDFromString(id) - r.Equal(fmt.Errorf("1 line at a time"), err) -} - -func TestTableStruct(t *testing.T) { - r := require.New(t) - - // Vanilla - table := &tableID{ - DatabaseName: "database_name", - SchemaName: "schema_name", - TableName: "table", - } - sID, err := table.String() - r.NoError(err) - r.Equal("database_name|schema_name|table", sID) - - // Empty grant - table = &tableID{} - sID, err = table.String() - r.NoError(err) - r.Equal("||", sID) - - // Grant with extra delimiters - table = &tableID{ - DatabaseName: "database|name", - TableName: "table|name", - } - sID, err = table.String() - r.NoError(err) - newTable, err := tableIDFromString(sID) - r.NoError(err) - r.Equal("database|name", newTable.DatabaseName) - r.Equal("table|name", newTable.TableName) -} diff --git a/pkg/resources/tag.go b/pkg/resources/tag.go index 01359edf34..9f8276da52 100644 --- a/pkg/resources/tag.go +++ b/pkg/resources/tag.go @@ -95,33 +95,6 @@ type TagBuilder interface { ChangeTag(snowflake.TagValue) string } -func handleTagChanges(db *sql.DB, d *schema.ResourceData, builder TagBuilder) error { - if d.HasChange("tag") { - o, n := d.GetChange("tag") - removed, added, changed := getTags(o).diffs(getTags(n)) - for _, tA := range removed { - q := builder.UnsetTag(tA.toSnowflakeTagValue()) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error dropping tag on %v", d.Id()) - } - } - for _, tA := range added { - q := builder.AddTag(tA.toSnowflakeTagValue()) - - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error adding column on %v", d.Id()) - } - } - for _, tA := range changed { - q := builder.ChangeTag(tA.toSnowflakeTagValue()) - if err := snowflake.Exec(db, q); err != nil { - return fmt.Errorf("error changing property on %v", d.Id()) - } - } - } - return nil -} - // String() takes in a schemaID object and returns a pipe-delimited string: // DatabaseName|SchemaName|TagName. func (ti *TagID) String() (string, error) { diff --git a/pkg/sdk/tables.go b/pkg/sdk/tables.go index 7712b7a4ba..2e3b124786 100644 --- a/pkg/sdk/tables.go +++ b/pkg/sdk/tables.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "strings" ) var _ convertibleRow[Table] = new(tableDBRow) @@ -171,7 +172,7 @@ type ColumnMaskingPolicy struct { // OutOfLineConstraint is based on https://docs.snowflake.com/en/sql-reference/sql/create-table-constraint#out-of-line-unique-primary-foreign-key. type OutOfLineConstraint struct { - Name string `ddl:"parameter,no_equals" sql:"CONSTRAINT"` + Name *string `ddl:"parameter,no_equals" sql:"CONSTRAINT"` Type ColumnConstraintType `ddl:"keyword"` Columns []string `ddl:"keyword,parentheses"` ForeignKey *OutOfLineForeignKey `ddl:"keyword"` @@ -269,11 +270,12 @@ type TableColumnAddAction struct { InlineConstraint *TableColumnAddInlineConstraint `ddl:"keyword"` MaskingPolicy *ColumnMaskingPolicy `ddl:"keyword"` Tags []TagAssociation `ddl:"keyword,parentheses" sql:"TAG"` + Comment *string `ddl:"parameter,no_equals,single_quotes" sql:"COMMENT"` } type TableColumnAddInlineConstraint struct { NotNull *bool `ddl:"keyword" sql:"NOT NULL"` - Name string `ddl:"parameter,no_equals" sql:"CONSTRAINT"` + Name *string `ddl:"parameter,no_equals" sql:"CONSTRAINT"` Type ColumnConstraintType `ddl:"keyword"` ForeignKey *ColumnAddForeignKey `ddl:"keyword"` } @@ -370,15 +372,14 @@ type TableConstraintAlterAction struct { Unique *bool `ddl:"keyword" sql:"UNIQUE"` ForeignKey *bool `ddl:"keyword" sql:"FOREIGN KEY"` - Columns []string `ddl:"keyword,parentheses"` - // Optional - Enforced *bool `ddl:"keyword" sql:"ENFORCED"` - NotEnforced *bool `ddl:"keyword" sql:"NOT ENFORCED"` - Validate *bool `ddl:"keyword" sql:"VALIDATE"` - NoValidate *bool `ddl:"keyword" sql:"NOVALIDATE"` - Rely *bool `ddl:"keyword" sql:"RELY"` - NoRely *bool `ddl:"keyword" sql:"NORELY"` + Columns []string `ddl:"keyword,parentheses"` + Enforced *bool `ddl:"keyword" sql:"ENFORCED"` + NotEnforced *bool `ddl:"keyword" sql:"NOT ENFORCED"` + Validate *bool `ddl:"keyword" sql:"VALIDATE"` + NoValidate *bool `ddl:"keyword" sql:"NOVALIDATE"` + Rely *bool `ddl:"keyword" sql:"RELY"` + NoRely *bool `ddl:"keyword" sql:"NORELY"` } type TableConstraintDropAction struct { @@ -388,11 +389,10 @@ type TableConstraintDropAction struct { Unique *bool `ddl:"keyword" sql:"UNIQUE"` ForeignKey *bool `ddl:"keyword" sql:"FOREIGN KEY"` - Columns []string `ddl:"keyword,parentheses"` - // Optional - Cascade *bool `ddl:"keyword" sql:"CASCADE"` - Restrict *bool `ddl:"keyword" sql:"RESTRICT"` + Columns []string `ddl:"keyword,parentheses"` + Cascade *bool `ddl:"keyword" sql:"CASCADE"` + Restrict *bool `ddl:"keyword" sql:"RESTRICT"` } type TableUnsetTags struct { @@ -412,6 +412,7 @@ type TableExternalTableColumnAddAction struct { Name string `ddl:"keyword"` Type DataType `ddl:"keyword"` Expression []string `ddl:"parameter,no_equals,parentheses" sql:"AS"` + Comment *string `ddl:"parameter,no_equals,single_quotes" sql:"COMMENT"` } type TableExternalTableColumnRenameAction struct { @@ -553,6 +554,22 @@ type Table struct { Budget *string } +// GetClusterByKeys converts the SHOW TABLES result for ClusterBy and converts it to list of keys. +func (v *Table) GetClusterByKeys() []string { + if v.ClusterBy == "" { + return nil + } + + statementWithoutLinear := strings.TrimSuffix(strings.Replace(v.ClusterBy, "LINEAR(", "", 1), ")") + keysRaw := strings.Split(statementWithoutLinear, ",") + keysClean := make([]string, 0, len(keysRaw)) + for _, key := range keysRaw { + keysClean = append(keysClean, strings.TrimSpace(key)) + } + + return keysClean +} + func (row tableDBRow) convert() *Table { table := Table{ CreatedOn: row.CreatedOn, diff --git a/pkg/sdk/tables_dto.go b/pkg/sdk/tables_dto.go index 49fb03da58..504cab4fe5 100644 --- a/pkg/sdk/tables_dto.go +++ b/pkg/sdk/tables_dto.go @@ -132,7 +132,7 @@ type ColumnInlineConstraintRequest struct { } type OutOfLineConstraintRequest struct { - Name string // required + Name *string Type ColumnConstraintType // required Columns []string ForeignKey *OutOfLineForeignKeyRequest @@ -370,11 +370,12 @@ type TableColumnAddActionRequest struct { MaskingPolicy *ColumnMaskingPolicyRequest With *bool Tags []TagAssociation + Comment *string } type TableColumnAddInlineConstraintRequest struct { NotNull *bool - Name string + Name *string Type ColumnConstraintType ForeignKey *ColumnAddForeignKey } @@ -390,8 +391,7 @@ type TableColumnRenameActionRequest struct { } type TableColumnAlterActionRequest struct { - Column bool // required - Name string // required + Name string // required // One of DropDefault *bool @@ -447,8 +447,8 @@ type TableConstraintAlterActionRequest struct { Unique *bool ForeignKey *bool - Columns []string // required // Optional + Columns []string Enforced *bool NotEnforced *bool Validate *bool @@ -464,9 +464,8 @@ type TableConstraintDropActionRequest struct { Unique *bool ForeignKey *bool - Columns []string // required - // Optional + Columns []string Cascade *bool Restrict *bool } @@ -500,6 +499,7 @@ type TableExternalTableColumnAddActionRequest struct { Name string Type DataType Expression string + Comment *string } type TableExternalTableColumnRenameActionRequest struct { diff --git a/pkg/sdk/tables_dto_generated.go b/pkg/sdk/tables_dto_generated.go index 3573a85be8..17ec7218a7 100644 --- a/pkg/sdk/tables_dto_generated.go +++ b/pkg/sdk/tables_dto_generated.go @@ -425,15 +425,18 @@ func (s *ColumnInlineConstraintRequest) WithNoRely(noRely *bool) *ColumnInlineCo } func NewOutOfLineConstraintRequest( - name string, constraintType ColumnConstraintType, ) *OutOfLineConstraintRequest { s := OutOfLineConstraintRequest{} - s.Name = name s.Type = constraintType return &s } +func (s *OutOfLineConstraintRequest) WithName(name *string) *OutOfLineConstraintRequest { + s.Name = name + return s +} + func (s *OutOfLineConstraintRequest) WithColumns(columns []string) *OutOfLineConstraintRequest { s.Columns = columns return s @@ -1184,6 +1187,11 @@ func (s *TableColumnAddActionRequest) WithTags(tags []TagAssociation) *TableColu return s } +func (s *TableColumnAddActionRequest) WithComment(comment *string) *TableColumnAddActionRequest { + s.Comment = comment + return s +} + func NewTableColumnAddInlineConstraintRequest() *TableColumnAddInlineConstraintRequest { return &TableColumnAddInlineConstraintRequest{} } @@ -1193,7 +1201,7 @@ func (s *TableColumnAddInlineConstraintRequest) WithNotNull(notNull *bool) *Tabl return s } -func (s *TableColumnAddInlineConstraintRequest) WithName(name string) *TableColumnAddInlineConstraintRequest { +func (s *TableColumnAddInlineConstraintRequest) WithName(name *string) *TableColumnAddInlineConstraintRequest { s.Name = name return s } @@ -1233,11 +1241,9 @@ func NewTableColumnRenameActionRequest( } func NewTableColumnAlterActionRequest( - column bool, name string, ) *TableColumnAlterActionRequest { s := TableColumnAlterActionRequest{} - s.Column = column s.Name = name return &s } @@ -1369,10 +1375,8 @@ func (s *TableConstraintRenameActionRequest) WithNewName(newName string) *TableC return s } -func NewTableConstraintAlterActionRequest(columns []string) *TableConstraintAlterActionRequest { - return &TableConstraintAlterActionRequest{ - Columns: columns, - } +func NewTableConstraintAlterActionRequest() *TableConstraintAlterActionRequest { + return &TableConstraintAlterActionRequest{} } func (s *TableConstraintAlterActionRequest) WithConstraintName(constraintName *string) *TableConstraintAlterActionRequest { @@ -1395,6 +1399,11 @@ func (s *TableConstraintAlterActionRequest) WithForeignKey(foreignKey *bool) *Ta return s } +func (s *TableConstraintAlterActionRequest) WithColumns(columns []string) *TableConstraintAlterActionRequest { + s.Columns = columns + return s +} + func (s *TableConstraintAlterActionRequest) WithEnforced(enforced *bool) *TableConstraintAlterActionRequest { s.Enforced = enforced return s @@ -1425,10 +1434,8 @@ func (s *TableConstraintAlterActionRequest) WithNoRely(noRely *bool) *TableConst return s } -func NewTableConstraintDropActionRequest(columns []string) *TableConstraintDropActionRequest { - return &TableConstraintDropActionRequest{ - Columns: columns, - } +func NewTableConstraintDropActionRequest() *TableConstraintDropActionRequest { + return &TableConstraintDropActionRequest{} } func (s *TableConstraintDropActionRequest) WithConstraintName(constraintName *string) *TableConstraintDropActionRequest { @@ -1451,6 +1458,11 @@ func (s *TableConstraintDropActionRequest) WithForeignKey(foreignKey *bool) *Tab return s } +func (s *TableConstraintDropActionRequest) WithColumns(columns []string) *TableConstraintDropActionRequest { + s.Columns = columns + return s +} + func (s *TableConstraintDropActionRequest) WithCascade(cascade *bool) *TableConstraintDropActionRequest { s.Cascade = cascade return s @@ -1562,6 +1574,11 @@ func (s *TableExternalTableColumnAddActionRequest) WithExpression(expression str return s } +func (s *TableExternalTableColumnAddActionRequest) WithComment(comment *string) *TableExternalTableColumnAddActionRequest { + s.Comment = comment + return s +} + func NewTableExternalTableColumnRenameActionRequest() *TableExternalTableColumnRenameActionRequest { return &TableExternalTableColumnRenameActionRequest{} } diff --git a/pkg/sdk/tables_impl.go b/pkg/sdk/tables_impl.go index b3ca029285..5a34f0ac8e 100644 --- a/pkg/sdk/tables_impl.go +++ b/pkg/sdk/tables_impl.go @@ -254,6 +254,7 @@ func (r *TableExternalTableActionRequest) toOpts() *TableExternalTableAction { Name: r.Add.Name, Type: r.Add.Type, Expression: []string{r.Add.Expression}, + Comment: r.Add.Comment, }, } } @@ -392,6 +393,7 @@ func (r *TableColumnActionRequest) toOpts() *TableColumnAction { Type: r.Add.Type, DefaultValue: defaultValue, InlineConstraint: inlineConstraint, + Comment: r.Add.Comment, }, } } diff --git a/pkg/sdk/tables_test.go b/pkg/sdk/tables_test.go index 27574b9ffe..644c9594a7 100644 --- a/pkg/sdk/tables_test.go +++ b/pkg/sdk/tables_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/random" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -406,7 +407,7 @@ func TestTableCreate(t *testing.T) { } require.NoError(t, err) outOfLineConstraint1 := OutOfLineConstraint{ - Name: "OUT_OF_LINE_CONSTRAINT", + Name: String("OUT_OF_LINE_CONSTRAINT"), Type: ColumnConstraintTypeForeignKey, Columns: []string{"COLUMN_1", "COLUMN_2"}, ForeignKey: &OutOfLineForeignKey{ @@ -475,7 +476,7 @@ func TestTableCreate(t *testing.T) { Comment: &tableComment, } assertOptsValidAndSQLEquals(t, opts, - `CREATE TABLE %s (%s %s CONSTRAINT INLINE_CONSTRAINT PRIMARY KEY NOT NULL COLLATE 'de' IDENTITY START 10 INCREMENT 1 ORDER MASKING POLICY %s USING (FOO, BAR) TAG ("db"."schema"."column_tag1" = 'v1', "db"."schema"."column_tag2" = 'v2') COMMENT '%s', CONSTRAINT OUT_OF_LINE_CONSTRAINT FOREIGN KEY (COLUMN_1, COLUMN_2) REFERENCES %s (COLUMN_3, COLUMN_4) MATCH FULL ON UPDATE SET NULL ON DELETE RESTRICT, CONSTRAINT UNIQUE (COLUMN_1) ENFORCED DEFERRABLE INITIALLY DEFERRED ENABLE RELY) CLUSTER BY (COLUMN_1, COLUMN_2) ENABLE_SCHEMA_EVOLUTION = true STAGE_FILE_FORMAT = (TYPE = CSV COMPRESSION = AUTO) STAGE_COPY_OPTIONS = (ON_ERROR = SKIP_FILE) DATA_RETENTION_TIME_IN_DAYS = 10 MAX_DATA_EXTENSION_TIME_IN_DAYS = 100 CHANGE_TRACKING = true DEFAULT_DDL_COLLATION = 'en' COPY GRANTS ROW ACCESS POLICY %s ON (COLUMN_1, COLUMN_2) TAG ("db"."schema"."table_tag1" = 'v1', "db"."schema"."table_tag2" = 'v2') COMMENT = '%s'`, + `CREATE TABLE %s (%s %s CONSTRAINT INLINE_CONSTRAINT PRIMARY KEY NOT NULL COLLATE 'de' IDENTITY START 10 INCREMENT 1 ORDER MASKING POLICY %s USING (FOO, BAR) TAG ("db"."schema"."column_tag1" = 'v1', "db"."schema"."column_tag2" = 'v2') COMMENT '%s', CONSTRAINT OUT_OF_LINE_CONSTRAINT FOREIGN KEY (COLUMN_1, COLUMN_2) REFERENCES %s (COLUMN_3, COLUMN_4) MATCH FULL ON UPDATE SET NULL ON DELETE RESTRICT, UNIQUE (COLUMN_1) ENFORCED DEFERRABLE INITIALLY DEFERRED ENABLE RELY) CLUSTER BY (COLUMN_1, COLUMN_2) ENABLE_SCHEMA_EVOLUTION = true STAGE_FILE_FORMAT = (TYPE = CSV COMPRESSION = AUTO) STAGE_COPY_OPTIONS = (ON_ERROR = SKIP_FILE) DATA_RETENTION_TIME_IN_DAYS = 10 MAX_DATA_EXTENSION_TIME_IN_DAYS = 100 CHANGE_TRACKING = true DEFAULT_DDL_COLLATION = 'en' COPY GRANTS ROW ACCESS POLICY %s ON (COLUMN_1, COLUMN_2) TAG ("db"."schema"."table_tag1" = 'v1', "db"."schema"."table_tag2" = 'v2') COMMENT = '%s'`, id.FullyQualifiedName(), columnName, columnType, @@ -813,17 +814,6 @@ func TestTableAlter(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("TableConstraintAlterAction", "ConstraintName", "PrimaryKey", "Unique", "ForeignKey", "Columns")) }) - t.Run("validation: constraint alter action - no columns", func(t *testing.T) { - opts := defaultOpts() - opts.ConstraintAction = &TableConstraintAction{ - Alter: &TableConstraintAlterAction{ - ConstraintName: String("constraint"), - Columns: []string{}, - }, - } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("TableConstraintAlterAction", "Columns")) - }) - t.Run("validation: constraint alter action - two options present", func(t *testing.T) { opts := defaultOpts() opts.ConstraintAction = &TableConstraintAction{ @@ -843,17 +833,6 @@ func TestTableAlter(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("TableConstraintDropAction", "ConstraintName", "PrimaryKey", "Unique", "ForeignKey", "Columns")) }) - t.Run("validation: constraint drop action - no columns", func(t *testing.T) { - opts := defaultOpts() - opts.ConstraintAction = &TableConstraintAction{ - Drop: &TableConstraintDropAction{ - ConstraintName: String("constraint"), - Columns: []string{}, - }, - } - assertOptsInvalidJoinedErrors(t, opts, errNotSet("TableConstraintDropAction", "Columns")) - }) - t.Run("validation: constraint drop action - two options present", func(t *testing.T) { opts := defaultOpts() opts.ConstraintAction = &TableConstraintAction{ @@ -1191,7 +1170,7 @@ func TestTableAlter(t *testing.T) { t.Run("alter constraint: add", func(t *testing.T) { outOfLineConstraint := OutOfLineConstraint{ - Name: "OUT_OF_LINE_CONSTRAINT", + Name: String("OUT_OF_LINE_CONSTRAINT"), Type: ColumnConstraintTypeForeignKey, Columns: []string{"COLUMN_1", "COLUMN_2"}, ForeignKey: &OutOfLineForeignKey{ @@ -1571,3 +1550,29 @@ func TestTableDescribeStage(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, `DESCRIBE TABLE %s TYPE = STAGE`, id.FullyQualifiedName()) }) } + +func TestTable_GetClusterByKeys(t *testing.T) { + t.Run("empty", func(t *testing.T) { + table := Table{ClusterBy: ""} + + assert.Nil(t, table.GetClusterByKeys()) + }) + + t.Run("one param", func(t *testing.T) { + table := Table{ClusterBy: "LINEAR(abc)"} + + assert.Equal(t, []string{"abc"}, table.GetClusterByKeys()) + }) + + t.Run("more params", func(t *testing.T) { + table := Table{ClusterBy: "LINEAR(abc,def)"} + + assert.Equal(t, []string{"abc", "def"}, table.GetClusterByKeys()) + }) + + t.Run("white space", func(t *testing.T) { + table := Table{ClusterBy: " LINEAR( abc , def )"} + + assert.Equal(t, []string{"abc", "def"}, table.GetClusterByKeys()) + }) +} diff --git a/pkg/sdk/tables_validations.go b/pkg/sdk/tables_validations.go index 6c6211bf9d..e1b2467753 100644 --- a/pkg/sdk/tables_validations.go +++ b/pkg/sdk/tables_validations.go @@ -226,9 +226,6 @@ func (opts *alterTableOptions) validate() error { ); !ok { errs = append(errs, errExactlyOneOf("TableConstraintAlterAction", "ConstraintName", "PrimaryKey", "Unique", "ForeignKey", "Columns")) } - if len(alterAction.Columns) == 0 { - errs = append(errs, errNotSet("TableConstraintAlterAction", "Columns")) - } } if dropAction := constraintAction.Drop; valueSet(dropAction) { if ok := exactlyOneValueSet( @@ -239,9 +236,6 @@ func (opts *alterTableOptions) validate() error { ); !ok { errs = append(errs, errExactlyOneOf("TableConstraintDropAction", "ConstraintName", "PrimaryKey", "Unique", "ForeignKey", "Columns")) } - if len(dropAction.Columns) == 0 { - errs = append(errs, errNotSet("TableConstraintDropAction", "Columns")) - } } if addAction := constraintAction.Add; valueSet(addAction) { if err := addAction.validate(); err != nil { diff --git a/pkg/sdk/testint/tables_integration_test.go b/pkg/sdk/testint/tables_integration_test.go index 2b4b850792..0228c3e732 100644 --- a/pkg/sdk/testint/tables_integration_test.go +++ b/pkg/sdk/testint/tables_integration_test.go @@ -124,7 +124,8 @@ func TestInt_Table(t *testing.T) { WithNotNull(sdk.Bool(true)), *sdk.NewTableColumnRequest("COLUMN_2", sdk.DataTypeNumber).WithDefaultValue(sdk.NewColumnDefaultValueRequest().WithIdentity(sdk.NewColumnIdentityRequest(1, 1))), } - outOfLineConstraint := sdk.NewOutOfLineConstraintRequest("OUT_OF_LINE_CONSTRAINT", sdk.ColumnConstraintTypeForeignKey). + outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypeForeignKey). + WithName(sdk.String("OUT_OF_LINE_CONSTRAINT")). WithColumns([]string{"COLUMN_1"}). WithForeignKey(sdk.NewOutOfLineForeignKeyRequest(sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, table2.Name), []string{"id"}). WithMatch(sdk.Pointer(sdk.FullMatchType)). @@ -476,7 +477,7 @@ func TestInt_Table(t *testing.T) { alterRequest := sdk.NewAlterTableRequest(id). WithColumnAction(sdk.NewTableColumnActionRequest(). - WithAdd(sdk.NewTableColumnAddActionRequest("COLUMN_3", sdk.DataTypeVARCHAR))) + WithAdd(sdk.NewTableColumnAddActionRequest("COLUMN_3", sdk.DataTypeVARCHAR).WithComment(sdk.String("some comment")))) err = client.Tables.Alter(ctx, alterRequest) require.NoError(t, err) @@ -670,7 +671,7 @@ func TestInt_Table(t *testing.T) { alterRequest := sdk.NewAlterTableRequest(id). WithConstraintAction(sdk.NewTableConstraintActionRequest(). - WithAdd(sdk.NewOutOfLineConstraintRequest("OUT_OF_LINE_CONSTRAINT", sdk.ColumnConstraintTypeForeignKey).WithColumns([]string{"COLUMN_1"}). + WithAdd(sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypeForeignKey).WithName(sdk.String("OUT_OF_LINE_CONSTRAINT")).WithColumns([]string{"COLUMN_1"}). WithForeignKey(sdk.NewOutOfLineForeignKeyRequest(sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, secondTableName), []string{"COLUMN_3"})))) err = client.Tables.Alter(ctx, alterRequest) require.NoError(t, err) @@ -685,7 +686,7 @@ func TestInt_Table(t *testing.T) { *sdk.NewTableColumnRequest("COLUMN_2", sdk.DataTypeVARCHAR), } oldConstraintName := "OUT_OF_LINE_CONSTRAINT" - outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(oldConstraintName, sdk.ColumnConstraintTypePrimaryKey).WithColumns([]string{"COLUMN_1"}) + outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithName(sdk.String(oldConstraintName)).WithColumns([]string{"COLUMN_1"}) err := client.Tables.Create(ctx, sdk.NewCreateTableRequest(id, columns).WithOutOfLineConstraint(*outOfLineConstraint)) require.NoError(t, err) @@ -703,7 +704,6 @@ func TestInt_Table(t *testing.T) { // TODO [SNOW-1007542]: check altered constraint t.Run("alter constraint: alter", func(t *testing.T) { - t.Skip("Test is failing: generated statement is not compiling but it is aligned with Snowflake docs https://docs.snowflake.com/en/sql-reference/sql/alter-table#syntax. Requires further investigation.") name := random.String() id := sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, name) columns := []sdk.TableColumnRequest{ @@ -711,21 +711,20 @@ func TestInt_Table(t *testing.T) { *sdk.NewTableColumnRequest("COLUMN_2", sdk.DataTypeVARCHAR), } constraintName := "OUT_OF_LINE_CONSTRAINT" - outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(constraintName, sdk.ColumnConstraintTypePrimaryKey).WithColumns([]string{"COLUMN_1"}) + outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithName(sdk.String(constraintName)).WithColumns([]string{"COLUMN_1"}) err := client.Tables.Create(ctx, sdk.NewCreateTableRequest(id, columns).WithOutOfLineConstraint(*outOfLineConstraint)) require.NoError(t, err) t.Cleanup(cleanupTableProvider(id)) alterRequest := sdk.NewAlterTableRequest(id). - WithConstraintAction(sdk.NewTableConstraintActionRequest().WithAlter(sdk.NewTableConstraintAlterActionRequest([]string{"COLUMN_1"}).WithConstraintName(sdk.String(constraintName)).WithEnforced(sdk.Bool(true)))) + WithConstraintAction(sdk.NewTableConstraintActionRequest().WithAlter(sdk.NewTableConstraintAlterActionRequest().WithConstraintName(sdk.String(constraintName)).WithEnforced(sdk.Bool(true)))) err = client.Tables.Alter(ctx, alterRequest) require.NoError(t, err) }) // TODO [SNOW-1007542]: check dropped constraint - t.Run("alter constraint: drop", func(t *testing.T) { - t.Skip("Test is failing: generated statement is not compiling but it is aligned with Snowflake docs https://docs.snowflake.com/en/sql-reference/sql/alter-table#syntax. Requires further investigation.") + t.Run("alter constraint: drop constraint with name", func(t *testing.T) { name := random.String() id := sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, name) columns := []sdk.TableColumnRequest{ @@ -733,19 +732,37 @@ func TestInt_Table(t *testing.T) { *sdk.NewTableColumnRequest("COLUMN_2", sdk.DataTypeVARCHAR), } constraintName := "OUT_OF_LINE_CONSTRAINT" - outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(constraintName, sdk.ColumnConstraintTypePrimaryKey).WithColumns([]string{"COLUMN_1"}) + outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithName(sdk.String(constraintName)).WithColumns([]string{"COLUMN_1"}) err := client.Tables.Create(ctx, sdk.NewCreateTableRequest(id, columns).WithOutOfLineConstraint(*outOfLineConstraint)) require.NoError(t, err) t.Cleanup(cleanupTableProvider(id)) alterRequest := sdk.NewAlterTableRequest(id). - WithConstraintAction(sdk.NewTableConstraintActionRequest().WithDrop(sdk.NewTableConstraintDropActionRequest([]string{"COLUMN_1"}).WithConstraintName(sdk.String(constraintName)))) + WithConstraintAction(sdk.NewTableConstraintActionRequest().WithDrop(sdk.NewTableConstraintDropActionRequest().WithConstraintName(sdk.String(constraintName)))) err = client.Tables.Alter(ctx, alterRequest) require.NoError(t, err) }) - t.Run("external table: add", func(t *testing.T) { + t.Run("alter constraint: drop primary key without constraint name", func(t *testing.T) { + name := random.String() + id := sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, name) + columns := []sdk.TableColumnRequest{ + *sdk.NewTableColumnRequest("COLUMN_1", sdk.DataTypeVARCHAR), + } + outOfLineConstraint := sdk.NewOutOfLineConstraintRequest(sdk.ColumnConstraintTypePrimaryKey).WithColumns([]string{"COLUMN_1"}) + + err := client.Tables.Create(ctx, sdk.NewCreateTableRequest(id, columns).WithOutOfLineConstraint(*outOfLineConstraint)) + require.NoError(t, err) + t.Cleanup(cleanupTableProvider(id)) + + alterRequest := sdk.NewAlterTableRequest(id). + WithConstraintAction(sdk.NewTableConstraintActionRequest().WithDrop(sdk.NewTableConstraintDropActionRequest().WithPrimaryKey(sdk.Bool(true)))) + err = client.Tables.Alter(ctx, alterRequest) + require.NoError(t, err) + }) + + t.Run("external table: add column", func(t *testing.T) { name := random.String() id := sdk.NewSchemaObjectIdentifier(database.Name, schema.Name, name) columns := []sdk.TableColumnRequest{ @@ -758,7 +775,12 @@ func TestInt_Table(t *testing.T) { t.Cleanup(cleanupTableProvider(id)) alterRequest := sdk.NewAlterTableRequest(id). - WithExternalTableAction(sdk.NewTableExternalTableActionRequest().WithAdd(sdk.NewTableExternalTableColumnAddActionRequest().WithName("COLUMN_3").WithType(sdk.DataTypeNumber).WithExpression("1 + 1"))) + WithExternalTableAction(sdk.NewTableExternalTableActionRequest().WithAdd(sdk.NewTableExternalTableColumnAddActionRequest(). + WithName("COLUMN_3"). + WithType(sdk.DataTypeNumber). + WithExpression("1 + 1"). + WithComment(sdk.String("some comment")), + )) err = client.Tables.Alter(ctx, alterRequest) require.NoError(t, err) diff --git a/pkg/snowflake/table.go b/pkg/snowflake/table.go index 7aa6f4c855..58ba303894 100644 --- a/pkg/snowflake/table.go +++ b/pkg/snowflake/table.go @@ -1,420 +1,10 @@ package snowflake import ( - "database/sql" - "errors" "fmt" - "log" - "sort" - "strconv" - "strings" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" - - "github.com/jmoiron/sqlx" -) - -// PrimaryKey structure that represents a tables primary key. -type PrimaryKey struct { - name string - keys []string -} - -// WithName set the primary key name. -func (pk *PrimaryKey) WithName(name string) *PrimaryKey { - pk.name = name - return pk -} - -// WithKeys set the primary key keys. -func (pk *PrimaryKey) WithKeys(keys []string) *PrimaryKey { - pk.keys = keys - return pk -} - -type ColumnDefaultType int - -const ( - columnDefaultTypeConstant = iota - columnDefaultTypeSequence - columnDefaultTypeExpression ) -type ColumnDefault struct { - _type ColumnDefaultType - expression string -} - -type ColumnIdentity struct { - startNum int - stepNum int -} - -func (id *ColumnIdentity) WithStartNum(start int) *ColumnIdentity { - id.startNum = start - return id -} - -func (id *ColumnIdentity) WithStep(step int) *ColumnIdentity { - id.stepNum = step - return id -} - -func NewColumnDefaultWithConstant(constant string) *ColumnDefault { - return &ColumnDefault{ - _type: columnDefaultTypeConstant, - expression: constant, - } -} - -func NewColumnDefaultWithExpression(expression string) *ColumnDefault { - return &ColumnDefault{ - _type: columnDefaultTypeExpression, - expression: expression, - } -} - -func NewColumnDefaultWithSequence(sequence string) *ColumnDefault { - return &ColumnDefault{ - _type: columnDefaultTypeSequence, - expression: sequence, - } -} - -func (d *ColumnDefault) String(columnType string) string { - columnType = strings.ToUpper(columnType) - - switch { - case d._type == columnDefaultTypeExpression: - return d.expression - - case d._type == columnDefaultTypeSequence: - return fmt.Sprintf(`%v.NEXTVAL`, d.expression) - - case d._type == columnDefaultTypeConstant && (strings.Contains(columnType, "CHAR") || columnType == "STRING" || columnType == "TEXT"): - return EscapeSnowflakeString(d.expression) - - default: - return d.expression - } -} - -func (d *ColumnDefault) UnescapeConstantSnowflakeString(columnType string) string { - columnType = strings.ToUpper(columnType) - - if d._type == columnDefaultTypeConstant && (strings.Contains(columnType, "CHAR") || columnType == "STRING" || columnType == "TEXT") { - return UnescapeSnowflakeString(d.expression) - } - - return d.expression -} - -// Column structure that represents a table column. -type Column struct { - name string - _type string // type is reserved - nullable bool - _default *ColumnDefault // default is reserved - identity *ColumnIdentity - comment string // pointer as value is nullable - maskingPolicy string -} - -// WithName set the column name. -func (c *Column) WithName(name string) *Column { - c.name = name - return c -} - -// WithType set the column type. -func (c *Column) WithType(t string) *Column { - c._type = t - return c -} - -// WithNullable set if the column is nullable. -func (c *Column) WithNullable(nullable bool) *Column { - c.nullable = nullable - return c -} - -func (c *Column) WithDefault(cd *ColumnDefault) *Column { - c._default = cd - return c -} - -// WithComment set the column comment. -func (c *Column) WithComment(comment string) *Column { - c.comment = comment - return c -} - -func (c *Column) WithMaskingPolicy(maskingPolicy string) *Column { - c.maskingPolicy = maskingPolicy - return c -} - -func (c *Column) WithIdentity(id *ColumnIdentity) *Column { - c.identity = id - return c -} - -func (c *Column) getColumnDefinition(withInlineConstraints bool, withComment bool) string { - if c == nil { - return "" - } - var colDef strings.Builder - colDef.WriteString(fmt.Sprintf(`"%v" %v`, EscapeString(c.name), EscapeString(c._type))) - - if withInlineConstraints { - if !c.nullable { - colDef.WriteString(` NOT NULL`) - } - } - - if c._default != nil { - colDef.WriteString(fmt.Sprintf(` DEFAULT %v`, c._default.String(c._type))) - } - - if c.identity != nil { - colDef.WriteString(fmt.Sprintf(` IDENTITY(%v, %v)`, c.identity.startNum, c.identity.stepNum)) - } - - if strings.TrimSpace(c.maskingPolicy) != "" { - colDef.WriteString(fmt.Sprintf(` WITH MASKING POLICY %v`, EscapeString(c.maskingPolicy))) - } - - if withComment { - colDef.WriteString(fmt.Sprintf(` COMMENT '%v'`, EscapeString(c.comment))) - } - - return colDef.String() -} - -func FlattenTablePrimaryKey(pkds []PrimaryKeyDescription) []interface{} { - flattened := []interface{}{} - if len(pkds) == 0 { - return flattened - } - - sort.SliceStable(pkds, func(i, j int) bool { - num1, _ := strconv.Atoi(pkds[i].KeySequence.String) - num2, _ := strconv.Atoi(pkds[j].KeySequence.String) - return num1 < num2 - }) - // sort our keys on the key sequence - - flat := map[string]interface{}{} - keys := make([]string, 0, len(pkds)) - var name string - var nameSet bool - - for _, pk := range pkds { - // set as empty string, sys_constraint means it was an unnnamed constraint - if strings.Contains(pk.ConstraintName.String, "SYS_CONSTRAINT") && !nameSet { - name = "" - nameSet = true - } - if !nameSet { - name = pk.ConstraintName.String - nameSet = true - } - - keys = append(keys, pk.ColumnName.String) - } - - flat["name"] = name - flat["keys"] = keys - flattened = append(flattened, flat) - return flattened -} - -type Columns []Column - -// NewColumns generates columns from a table description. -func NewColumns(tds []TableDescription) Columns { - cs := []Column{} - for _, td := range tds { - if td.Kind.String != "COLUMN" { - continue - } - - cs = append(cs, Column{ - name: td.Name.String, - _type: td.Type.String, - nullable: td.IsNullable(), - _default: td.ColumnDefault(), - identity: td.ColumnIdentity(), - comment: td.Comment.String, - maskingPolicy: td.MaskingPolicy.String, - }) - } - return Columns(cs) -} - -func (c Columns) Flatten() []interface{} { - flattened := []interface{}{} - for _, col := range c { - flat := map[string]interface{}{} - flat["name"] = col.name - flat["type"] = col._type - flat["nullable"] = col.nullable - flat["comment"] = col.comment - flat["masking_policy"] = col.maskingPolicy - - if col._default != nil { - def := map[string]interface{}{} - switch col._default._type { - case columnDefaultTypeConstant: - def["constant"] = col._default.UnescapeConstantSnowflakeString(col._type) - case columnDefaultTypeExpression: - def["expression"] = col._default.expression - case columnDefaultTypeSequence: - def["sequence"] = col._default.expression - } - - flat["default"] = []interface{}{def} - } - - if col.identity != nil { - id := map[string]interface{}{} - id["start_num"] = col.identity.startNum - id["step_num"] = col.identity.stepNum - flat["identity"] = []interface{}{id} - } - flattened = append(flattened, flat) - } - return flattened -} - -func (c Columns) getColumnDefinitions(withInlineConstraints bool, withComments bool) string { - // TODO(el): verify Snowflake reflects column order back in desc table calls - columnDefinitions := []string{} - for _, column := range c { - columnDefinitions = append(columnDefinitions, column.getColumnDefinition(withInlineConstraints, withComments)) - } - - // NOTE: intentionally blank leading space - return fmt.Sprintf(" (%s)", strings.Join(columnDefinitions, ", ")) -} - -// TableBuilder abstracts the creation of SQL queries for a Snowflake schema. -type TableBuilder struct { - name string - db string - schema string - columns Columns - comment string - clusterBy []string - primaryKey PrimaryKey - dataRetentionTimeInDays *int - changeTracking bool - tags []TagValue -} - -// QualifiedName prepends the db and schema if set and escapes everything nicely. -func (tb *TableBuilder) QualifiedName() string { - var n strings.Builder - - if tb.db != "" && tb.schema != "" { - n.WriteString(fmt.Sprintf(`"%v"."%v".`, tb.db, tb.schema)) - } - - if tb.db != "" && tb.schema == "" { - n.WriteString(fmt.Sprintf(`"%v"..`, tb.db)) - } - - if tb.db == "" && tb.schema != "" { - n.WriteString(fmt.Sprintf(`"%v".`, tb.schema)) - } - - n.WriteString(fmt.Sprintf(`"%v"`, tb.name)) - - return n.String() -} - -// WithComment adds a comment to the TableBuilder. -func (tb *TableBuilder) WithComment(c string) *TableBuilder { - tb.comment = c - return tb -} - -// WithColumns sets the column definitions on the TableBuilder. -func (tb *TableBuilder) WithColumns(c Columns) *TableBuilder { - tb.columns = c - return tb -} - -// WithClustering adds cluster keys/expressions to TableBuilder. -func (tb *TableBuilder) WithClustering(c []string) *TableBuilder { - tb.clusterBy = c - return tb -} - -// WithPrimaryKey sets the primary key on the TableBuilder. -func (tb *TableBuilder) WithPrimaryKey(pk PrimaryKey) *TableBuilder { - tb.primaryKey = pk - return tb -} - -// WithDataRetentionTimeInDays sets the data retention time on the TableBuilder. -func (tb *TableBuilder) WithDataRetentionTimeInDays(days int) *TableBuilder { - tb.dataRetentionTimeInDays = sdk.Int(days) - return tb -} - -// WithChangeTracking sets the change tracking on the TableBuilder. -func (tb *TableBuilder) WithChangeTracking(changeTracking bool) *TableBuilder { - tb.changeTracking = changeTracking - return tb -} - -// WithTags sets the tags on the TableBuilder. -func (tb *TableBuilder) WithTags(tags []TagValue) *TableBuilder { - tb.tags = tags - return tb -} - -// AddTag returns the SQL query that will add a new tag to the table. -func (tb *TableBuilder) AddTag(tag TagValue) string { - return fmt.Sprintf(`ALTER TABLE %s SET TAG "%v"."%v"."%v" = "%v"`, tb.QualifiedName(), tag.Database, tag.Schema, tag.Name, tag.Value) -} - -// ChangeTag returns the SQL query that will alter a tag on the table. -func (tb *TableBuilder) ChangeTag(tag TagValue) string { - return fmt.Sprintf(`ALTER TABLE %s SET TAG "%v"."%v"."%v" = "%v"`, tb.QualifiedName(), tag.Database, tag.Schema, tag.Name, tag.Value) -} - -// UnsetTag returns the SQL query that will unset a tag on the table. -func (tb *TableBuilder) UnsetTag(tag TagValue) string { - return fmt.Sprintf(`ALTER TABLE %s UNSET TAG "%v"."%v"."%v"`, tb.QualifiedName(), tag.Database, tag.Schema, tag.Name) -} - -// Function to get clustering definition. -func (tb *TableBuilder) GetClusterKeyString() string { - return JoinStringList(tb.clusterBy, ", ") -} - -func (tb *TableBuilder) GetTagValueString() string { - var q strings.Builder - for _, v := range tb.tags { - fmt.Println(v) - if v.Schema != "" { - if v.Database != "" { - q.WriteString(fmt.Sprintf(`"%v".`, v.Database)) - } - q.WriteString(fmt.Sprintf(`"%v".`, v.Schema)) - } - q.WriteString(fmt.Sprintf(`"%v" = "%v", `, v.Name, v.Value)) - } - return strings.TrimSuffix(q.String(), ", ") -} - -func JoinStringList(instrings []string, delimiter string) string { - return fmt.Sprint(strings.Join(instrings, delimiter)) -} - -func quoteStringList(instrings []string) []string { +func QuoteStringList(instrings []string) []string { clean := make([]string, 0, len(instrings)) for _, word := range instrings { quoted := fmt.Sprintf(`"%s"`, word) @@ -422,351 +12,3 @@ func quoteStringList(instrings []string) []string { } return clean } - -func (tb *TableBuilder) getCreateStatementBody() string { - var q strings.Builder - - colDef := tb.columns.getColumnDefinitions(true, true) - - if len(tb.primaryKey.keys) > 0 { - colDef = strings.TrimSuffix(colDef, ")") // strip trailing - q.WriteString(colDef) - if tb.primaryKey.name != "" { - q.WriteString(fmt.Sprintf(` ,CONSTRAINT "%v" PRIMARY KEY(%v)`, tb.primaryKey.name, JoinStringList(quoteStringList(tb.primaryKey.keys), ","))) - } else { - q.WriteString(fmt.Sprintf(` ,PRIMARY KEY(%v)`, JoinStringList(quoteStringList(tb.primaryKey.keys), ","))) - } - - q.WriteString(")") // add closing - } else { - q.WriteString(colDef) - } - - return q.String() -} - -// function to take the literal snowflake cluster statement returned from SHOW TABLES and convert it to a list of keys. -func ClusterStatementToList(clusterStatement string) []string { - if clusterStatement == "" { - return nil - } - - cleanStatement := strings.TrimSuffix(strings.Replace(clusterStatement, "LINEAR(", "", 1), ")") - // remove cluster statement and trailing parenthesis - - spCleanStatement := strings.Split(cleanStatement, ",") - clean := make([]string, 0, len(spCleanStatement)) - for _, s := range spCleanStatement { - clean = append(clean, strings.TrimSpace(s)) - } - - return clean -} - -// Table returns a pointer to a Builder that abstracts the DDL operations for a table. -// -// Supported DDL operations are: -// - ALTER TABLE -// - DROP TABLE -// - SHOW TABLES -// -// [Snowflake Reference](https://docs.snowflake.com/en/sql-reference/ddl-table.html) -func NewTableBuilder(name, db, schema string) *TableBuilder { - return &TableBuilder{ - name: name, - db: db, - schema: schema, - } -} - -// Table returns a pointer to a Builder that abstracts the DDL operations for a table. -// -// Supported DDL operations are: -// - CREATE TABLE -// -// [Snowflake Reference](https://docs.snowflake.com/en/sql-reference/ddl-table.html) -func NewTableWithColumnDefinitionsBuilder(name, db, schema string, columns Columns) *TableBuilder { - return &TableBuilder{ - name: name, - db: db, - schema: schema, - columns: columns, - } -} - -// Create returns the SQL statement required to create a table. -func (tb *TableBuilder) Create() string { - q := strings.Builder{} - q.WriteString(fmt.Sprintf(`CREATE TABLE %v`, tb.QualifiedName())) - q.WriteString(tb.getCreateStatementBody()) - - if tb.comment != "" { - q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, EscapeString(tb.comment))) - } - - if tb.clusterBy != nil { - // add optional clustering statement - q.WriteString(fmt.Sprintf(` CLUSTER BY LINEAR(%v)`, tb.GetClusterKeyString())) - } - - if tb.dataRetentionTimeInDays != nil { - q.WriteString(fmt.Sprintf(` DATA_RETENTION_TIME_IN_DAYS = %d`, *tb.dataRetentionTimeInDays)) - } - q.WriteString(fmt.Sprintf(` CHANGE_TRACKING = %t`, tb.changeTracking)) - - if tb.tags != nil { - q.WriteString(fmt.Sprintf(` WITH TAG (%v)`, tb.GetTagValueString())) - } - - return q.String() -} - -// ChangeClusterBy returns the SQL query to change cluastering on table. -func (tb *TableBuilder) ChangeClusterBy(cb string) string { - return fmt.Sprintf(`ALTER TABLE %v CLUSTER BY LINEAR(%v)`, tb.QualifiedName(), cb) -} - -// ChangeComment returns the SQL query that will update the comment on the table. -func (tb *TableBuilder) ChangeComment(c string) string { - return fmt.Sprintf(`ALTER TABLE %v SET COMMENT = '%v'`, tb.QualifiedName(), EscapeString(c)) -} - -// ChangeDataRetention returns the SQL query that will update the DATA_RETENTION_TIME_IN_DAYS on the table. -func (tb *TableBuilder) ChangeDataRetention(days int) string { - return fmt.Sprintf(`ALTER TABLE %v SET DATA_RETENTION_TIME_IN_DAYS = %d`, tb.QualifiedName(), days) -} - -// ChangeChangeTracking returns the SQL query that will update the CHANGE_TRACKING on the table. -func (tb *TableBuilder) ChangeChangeTracking(changeTracking bool) string { - return fmt.Sprintf(`ALTER TABLE %v SET CHANGE_TRACKING = %t`, tb.QualifiedName(), changeTracking) -} - -// AddColumn returns the SQL query that will add a new column to the table. -func (tb *TableBuilder) AddColumn(name string, dataType string, nullable bool, _default *ColumnDefault, identity *ColumnIdentity, comment string, maskingPolicy string) string { - col := Column{ - name: name, - _type: dataType, - nullable: nullable, - _default: _default, - identity: identity, - comment: comment, - maskingPolicy: maskingPolicy, - } - return fmt.Sprintf(`ALTER TABLE %s ADD COLUMN %s`, tb.QualifiedName(), col.getColumnDefinition(true, true)) -} - -// DropColumn returns the SQL query that will add a new column to the table. -func (tb *TableBuilder) DropColumn(name string) string { - return fmt.Sprintf(`ALTER TABLE %s DROP COLUMN "%s"`, tb.QualifiedName(), name) -} - -// ChangeColumnType returns the SQL query that will change the type of the named column to the given type. -func (tb *TableBuilder) ChangeColumnType(name string, dataType string) string { - col := Column{ - name: name, - _type: dataType, - } - - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN %s`, tb.QualifiedName(), col.getColumnDefinition(false, false)) -} - -func (tb *TableBuilder) ChangeColumnComment(name string, comment string) string { - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%v" COMMENT '%v'`, tb.QualifiedName(), EscapeString(name), EscapeString(comment)) -} - -func (tb *TableBuilder) ChangeColumnMaskingPolicy(name string, maskingPolicy string) string { - if strings.TrimSpace(maskingPolicy) == "" { - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%v" UNSET MASKING POLICY`, tb.QualifiedName(), EscapeString(name)) - } - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%v" SET MASKING POLICY %v`, tb.QualifiedName(), EscapeString(name), EscapeString(maskingPolicy)) -} - -func (tb *TableBuilder) DropColumnDefault(name string) string { - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%v" DROP DEFAULT`, tb.QualifiedName(), EscapeString(name)) -} - -// RemoveComment returns the SQL query that will remove the comment on the table. -func (tb *TableBuilder) RemoveComment() string { - return fmt.Sprintf(`ALTER TABLE %v UNSET COMMENT`, tb.QualifiedName()) -} - -// Return sql to set/unset null constraint on column. -func (tb *TableBuilder) ChangeNullConstraint(name string, nullable bool) string { - if nullable { - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%s" DROP NOT NULL`, tb.QualifiedName(), name) - } - return fmt.Sprintf(`ALTER TABLE %s MODIFY COLUMN "%s" SET NOT NULL`, tb.QualifiedName(), name) -} - -func (tb *TableBuilder) ChangePrimaryKey(newPk PrimaryKey) string { - tb.WithPrimaryKey(newPk) - pks := JoinStringList(quoteStringList(newPk.keys), ", ") - if tb.primaryKey.name != "" { - return fmt.Sprintf(`ALTER TABLE %s ADD CONSTRAINT "%v" PRIMARY KEY(%v)`, tb.QualifiedName(), tb.primaryKey.name, pks) - } - return fmt.Sprintf(`ALTER TABLE %s ADD PRIMARY KEY(%v)`, tb.QualifiedName(), pks) -} - -func (tb *TableBuilder) DropPrimaryKey() string { - return fmt.Sprintf(`ALTER TABLE %s DROP PRIMARY KEY`, tb.QualifiedName()) -} - -// RemoveClustering returns the SQL query that will remove data clustering from the table. -func (tb *TableBuilder) DropClustering() string { - return fmt.Sprintf(`ALTER TABLE %v DROP CLUSTERING KEY`, tb.QualifiedName()) -} - -// Drop returns the SQL query that will drop a table. -func (tb *TableBuilder) Drop() string { - return fmt.Sprintf(`DROP TABLE %v`, tb.QualifiedName()) -} - -// Show returns the SQL query that will show a table. -func (tb *TableBuilder) Show() string { - return fmt.Sprintf(`SHOW TABLES LIKE '%v' IN SCHEMA "%v"."%v"`, tb.name, tb.db, tb.schema) -} - -func (tb *TableBuilder) ShowColumns() string { - return fmt.Sprintf(`DESC TABLE %s`, tb.QualifiedName()) -} - -func (tb *TableBuilder) ShowPrimaryKeys() string { - return fmt.Sprintf(`SHOW PRIMARY KEYS IN TABLE %s`, tb.QualifiedName()) -} - -func (tb *TableBuilder) Rename(newName string) string { - oldName := tb.QualifiedName() - tb.name = newName - return fmt.Sprintf(`ALTER TABLE %s RENAME TO %s`, oldName, tb.QualifiedName()) -} - -type Table struct { - CreatedOn sql.NullString `db:"created_on"` - TableName sql.NullString `db:"name"` - DatabaseName sql.NullString `db:"database_name"` - SchemaName sql.NullString `db:"schema_name"` - Kind sql.NullString `db:"kind"` - Comment sql.NullString `db:"comment"` - ClusterBy sql.NullString `db:"cluster_by"` - Rows sql.NullString `db:"row"` - Bytes sql.NullString `db:"bytes"` - Owner sql.NullString `db:"owner"` - RetentionTime sql.NullInt32 `db:"retention_time"` - AutomaticClustering sql.NullString `db:"automatic_clustering"` - ChangeTracking sql.NullString `db:"change_tracking"` - IsExternal sql.NullString `db:"is_external"` -} - -func ScanTable(row *sqlx.Row) (*Table, error) { - t := &Table{} - e := row.StructScan(t) - return t, e -} - -type TableDescription struct { - Name sql.NullString `db:"name"` - Type sql.NullString `db:"type"` - Kind sql.NullString `db:"kind"` - Nullable sql.NullString `db:"null?"` - Default sql.NullString `db:"default"` - Comment sql.NullString `db:"comment"` - MaskingPolicy sql.NullString `db:"policy name"` -} - -func (td *TableDescription) IsNullable() bool { - return td.Nullable.String == "Y" -} - -func (td *TableDescription) ColumnDefault() *ColumnDefault { - if !td.Default.Valid { - return nil - } - - if strings.HasSuffix(td.Default.String, ".NEXTVAL") { - return NewColumnDefaultWithSequence(strings.TrimSuffix(td.Default.String, ".NEXTVAL")) - } - - if strings.Contains(td.Default.String, "(") && strings.Contains(td.Default.String, ")") { - return NewColumnDefaultWithExpression(td.Default.String) - } - - if strings.Contains(td.Type.String, "CHAR") || td.Type.String == "STRING" || td.Type.String == "TEXT" { - return NewColumnDefaultWithConstant(UnescapeSnowflakeString(td.Default.String)) - } - - if td.ColumnIdentity() != nil { - /* - Identity/autoincrement information is stored in the same column as default information. We want to handle the identity separate so will return nil - here if identity information is present. Default/identity are mutually exclusive - */ - return nil - } - - return NewColumnDefaultWithConstant(td.Default.String) -} - -func (td *TableDescription) ColumnIdentity() *ColumnIdentity { - // if autoincrement is used this is reflected back IDENTITY START 1 INCREMENT 1 - if !td.Default.Valid { - return nil - } - if strings.Contains(td.Default.String, "IDENTITY") { - split := strings.Split(td.Default.String, " ") - start, _ := strconv.Atoi(split[2]) - step, _ := strconv.Atoi(split[4]) - - return &ColumnIdentity{start, step} - } - return nil -} - -type PrimaryKeyDescription struct { - ColumnName sql.NullString `db:"column_name"` - KeySequence sql.NullString `db:"key_sequence"` - ConstraintName sql.NullString `db:"constraint_name"` -} - -func ScanTableDescription(rows *sqlx.Rows) ([]TableDescription, error) { - tds := []TableDescription{} - for rows.Next() { - td := TableDescription{} - err := rows.StructScan(&td) - if err != nil { - return nil, err - } - tds = append(tds, td) - } - return tds, rows.Err() -} - -func ScanPrimaryKeyDescription(rows *sqlx.Rows) ([]PrimaryKeyDescription, error) { - pkds := []PrimaryKeyDescription{} - for rows.Next() { - pk := PrimaryKeyDescription{} - err := rows.StructScan(&pk) - if err != nil { - return nil, err - } - pkds = append(pkds, pk) - } - return pkds, rows.Err() -} - -func ListTables(databaseName string, schemaName string, db *sql.DB) ([]Table, error) { - stmt := fmt.Sprintf(`SHOW TABLES IN SCHEMA "%s"."%v"`, databaseName, schemaName) - rows, err := Query(db, stmt) - if err != nil { - return nil, err - } - defer rows.Close() - - dbs := []Table{} - if err := sqlx.StructScan(rows, &dbs); err != nil { - if errors.Is(err, sql.ErrNoRows) { - log.Println("[DEBUG] no tables found") - return nil, nil - } - return nil, fmt.Errorf("unable to scan row for %s err = %w", stmt, err) - } - return dbs, nil -} From 2de942acb308505256b2a7830c2af48106c18c4f Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Fri, 2 Feb 2024 16:38:11 +0100 Subject: [PATCH 5/5] fix: Fix tag tests in view and in materialized view (#2457) Added generated tag names because sometimes tests were failing because the name was already occupied. --- .../materialized_view_acceptance_test.go | 52 +++++++++++-------- .../testdata/TestAcc_View_Tags/1/test.tf | 4 +- .../testdata/TestAcc_View_Tags/1/variables.tf | 8 +++ .../testdata/TestAcc_View_Tags/2/test.tf | 4 +- .../testdata/TestAcc_View_Tags/2/variables.tf | 8 +++ pkg/resources/view_acceptance_test.go | 8 ++- 6 files changed, 55 insertions(+), 29 deletions(-) diff --git a/pkg/resources/materialized_view_acceptance_test.go b/pkg/resources/materialized_view_acceptance_test.go index 9b5aedfb07..a08a341351 100644 --- a/pkg/resources/materialized_view_acceptance_test.go +++ b/pkg/resources/materialized_view_acceptance_test.go @@ -102,6 +102,8 @@ func TestAcc_MaterializedView(t *testing.T) { func TestAcc_MaterializedView_Tags(t *testing.T) { tableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) viewName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + tag1Name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + tag2Name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) queryEscaped := fmt.Sprintf("SELECT ID FROM \\\"%s\\\"", tableName) @@ -115,20 +117,20 @@ func TestAcc_MaterializedView_Tags(t *testing.T) { Steps: []resource.TestStep{ // create tags { - Config: materializedViewConfigWithTags(acc.TestWarehouseName, tableName, viewName, queryEscaped, acc.TestDatabaseName, acc.TestSchemaName, "test_tag"), + Config: materializedViewConfigWithTags(acc.TestWarehouseName, tableName, viewName, queryEscaped, acc.TestDatabaseName, acc.TestSchemaName, "test_tag", tag1Name, tag2Name), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_materialized_view.test", "name", viewName), resource.TestCheckResourceAttr("snowflake_materialized_view.test", "tag.#", "1"), - resource.TestCheckResourceAttr("snowflake_materialized_view.test", "tag.0.name", "tag1"), + resource.TestCheckResourceAttr("snowflake_materialized_view.test", "tag.0.name", tag1Name), ), }, // update tags { - Config: materializedViewConfigWithTags(acc.TestWarehouseName, tableName, viewName, queryEscaped, acc.TestDatabaseName, acc.TestSchemaName, "test_tag_2"), + Config: materializedViewConfigWithTags(acc.TestWarehouseName, tableName, viewName, queryEscaped, acc.TestDatabaseName, acc.TestSchemaName, "test_tag_2", tag1Name, tag2Name), Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_materialized_view.test", "name", viewName), resource.TestCheckResourceAttr("snowflake_materialized_view.test", "tag.#", "1"), - resource.TestCheckResourceAttr("snowflake_materialized_view.test", "tag.0.name", "tag2"), + resource.TestCheckResourceAttr("snowflake_materialized_view.test", "tag.0.name", tag2Name), ), }, // IMPORT @@ -209,12 +211,12 @@ resource "snowflake_materialized_view" "test" { `, tableName, databaseName, schemaName, viewName, comment, databaseName, schemaName, warehouseName, isSecure, orReplace, q) } -func materializedViewConfigWithTags(warehouseName string, tableName string, viewName string, q string, databaseName string, schemaName string, tag string) string { +func materializedViewConfigWithTags(warehouseName string, tableName string, viewName string, q string, databaseName string, schemaName string, tag string, tag1Name string, tag2Name string) string { return fmt.Sprintf(` resource "snowflake_table" "test" { - name = "%s" - database = "%s" - schema = "%s" + name = "%[1]s" + database = "%[2]s" + schema = "%[3]s" column { name = "ID" @@ -223,32 +225,36 @@ resource "snowflake_table" "test" { } resource "snowflake_tag" "test_tag" { - name = "tag1" - database = "%s" - schema = "%s" + name = "%[8]s" + database = "%[2]s" + schema = "%[3]s" } resource "snowflake_tag" "test_tag_2" { - name = "tag2" - database = "%s" - schema = "%s" + name = "%[9]s" + database = "%[2]s" + schema = "%[3]s" } resource "snowflake_materialized_view" "test" { - name = "%s" - database = "%s" - schema = "%s" - warehouse = "%s" - statement = "%s" + name = "%[4]s" + database = "%[2]s" + schema = "%[3]s" + warehouse = "%[5]s" + statement = "%[6]s" tag { - name = snowflake_tag.%s.name - schema = snowflake_tag.%s.schema - database = snowflake_tag.%s.database + name = snowflake_tag.%[7]s.name + schema = snowflake_tag.%[7]s.schema + database = snowflake_tag.%[7]s.database value = "some_value" } + + depends_on = [ + snowflake_table.test + ] } -`, tableName, databaseName, schemaName, databaseName, schemaName, databaseName, schemaName, viewName, databaseName, schemaName, warehouseName, q, tag, tag, tag) +`, tableName, databaseName, schemaName, viewName, warehouseName, q, tag, tag1Name, tag2Name) } func testAccCheckMaterializedViewDestroy(s *terraform.State) error { diff --git a/pkg/resources/testdata/TestAcc_View_Tags/1/test.tf b/pkg/resources/testdata/TestAcc_View_Tags/1/test.tf index 5de130cf64..b27dd3db42 100644 --- a/pkg/resources/testdata/TestAcc_View_Tags/1/test.tf +++ b/pkg/resources/testdata/TestAcc_View_Tags/1/test.tf @@ -1,11 +1,11 @@ resource "snowflake_tag" "test_tag" { - name = "tag1" + name = var.tag1Name database = var.database schema = var.schema } resource "snowflake_tag" "test_tag_2" { - name = "tag2" + name = var.tag2Name database = var.database schema = var.schema } diff --git a/pkg/resources/testdata/TestAcc_View_Tags/1/variables.tf b/pkg/resources/testdata/TestAcc_View_Tags/1/variables.tf index 5b5810d23d..7d7074a5a9 100644 --- a/pkg/resources/testdata/TestAcc_View_Tags/1/variables.tf +++ b/pkg/resources/testdata/TestAcc_View_Tags/1/variables.tf @@ -13,3 +13,11 @@ variable "schema" { variable "statement" { type = string } + +variable "tag1Name" { + type = string +} + +variable "tag2Name" { + type = string +} diff --git a/pkg/resources/testdata/TestAcc_View_Tags/2/test.tf b/pkg/resources/testdata/TestAcc_View_Tags/2/test.tf index d8c47b7fcb..933fe9613e 100644 --- a/pkg/resources/testdata/TestAcc_View_Tags/2/test.tf +++ b/pkg/resources/testdata/TestAcc_View_Tags/2/test.tf @@ -1,11 +1,11 @@ resource "snowflake_tag" "test_tag" { - name = "tag1" + name = var.tag1Name database = var.database schema = var.schema } resource "snowflake_tag" "test_tag_2" { - name = "tag2" + name = var.tag2Name database = var.database schema = var.schema } diff --git a/pkg/resources/testdata/TestAcc_View_Tags/2/variables.tf b/pkg/resources/testdata/TestAcc_View_Tags/2/variables.tf index 5b5810d23d..7d7074a5a9 100644 --- a/pkg/resources/testdata/TestAcc_View_Tags/2/variables.tf +++ b/pkg/resources/testdata/TestAcc_View_Tags/2/variables.tf @@ -13,3 +13,11 @@ variable "schema" { variable "statement" { type = string } + +variable "tag1Name" { + type = string +} + +variable "tag2Name" { + type = string +} diff --git a/pkg/resources/view_acceptance_test.go b/pkg/resources/view_acceptance_test.go index 69885dbb3a..ded8f0aabb 100644 --- a/pkg/resources/view_acceptance_test.go +++ b/pkg/resources/view_acceptance_test.go @@ -124,6 +124,8 @@ func TestAcc_View(t *testing.T) { func TestAcc_View_Tags(t *testing.T) { viewName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + tag1Name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) + tag2Name := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha)) query := "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES" @@ -133,6 +135,8 @@ func TestAcc_View_Tags(t *testing.T) { "database": config.StringVariable(acc.TestDatabaseName), "schema": config.StringVariable(acc.TestSchemaName), "statement": config.StringVariable(query), + "tag1Name": config.StringVariable(tag1Name), + "tag2Name": config.StringVariable(tag2Name), } } @@ -151,7 +155,7 @@ func TestAcc_View_Tags(t *testing.T) { Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view.test", "name", viewName), resource.TestCheckResourceAttr("snowflake_view.test", "tag.#", "1"), - resource.TestCheckResourceAttr("snowflake_view.test", "tag.0.name", "tag1"), + resource.TestCheckResourceAttr("snowflake_view.test", "tag.0.name", tag1Name), resource.TestCheckResourceAttr("snowflake_view.test", "tag.0.value", "some_value"), ), }, @@ -162,7 +166,7 @@ func TestAcc_View_Tags(t *testing.T) { Check: resource.ComposeTestCheckFunc( resource.TestCheckResourceAttr("snowflake_view.test", "name", viewName), resource.TestCheckResourceAttr("snowflake_view.test", "tag.#", "1"), - resource.TestCheckResourceAttr("snowflake_view.test", "tag.0.name", "tag2"), + resource.TestCheckResourceAttr("snowflake_view.test", "tag.0.name", tag2Name), resource.TestCheckResourceAttr("snowflake_view.test", "tag.0.value", "some_value"), ), },