Skip to content

Commit

Permalink
feat: Use materialized views and views from SDK (#2448)
Browse files Browse the repository at this point in the history
- views and materialized views are similar
- set/unset tags was tested and fixed for both
- tests were fixed
  • Loading branch information
sfc-gh-asawicki authored Feb 1, 2024
1 parent 973b8f7 commit dc66d02
Show file tree
Hide file tree
Showing 25 changed files with 963 additions and 1,423 deletions.
44 changes: 21 additions & 23 deletions pkg/datasources/materialized_views.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package datasources

import (
"context"
"database/sql"
"errors"
"fmt"
"log"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
)

Expand Down Expand Up @@ -58,34 +58,32 @@ func MaterializedViews() *schema.Resource {

func ReadMaterializedViews(d *schema.ResourceData, meta interface{}) error {
db := meta.(*sql.DB)
ctx := context.Background()
client := sdk.NewClientFromDB(db)
databaseName := d.Get("database").(string)
schemaName := d.Get("schema").(string)

currentViews, err := snowflake.ListMaterializedViews(databaseName, schemaName, db)
if errors.Is(err, sql.ErrNoRows) {
// If not found, mark resource to be removed from state file during apply or refresh
log.Printf("[DEBUG] materialized views in schema (%s) not found", d.Id())
d.SetId("")
return nil
} else if err != nil {
log.Printf("[DEBUG] materialized unable to parse views in schema (%s)", d.Id())
schemaId := sdk.NewDatabaseObjectIdentifier(databaseName, schemaName)
extractedMaterializedViews, err := client.MaterializedViews.Show(ctx, sdk.NewShowMaterializedViewRequest().WithIn(
&sdk.In{Schema: schemaId},
))
if err != nil {
log.Printf("[DEBUG] failed when searching materialized views in schema (%s), err = %s", schemaId.FullyQualifiedName(), err.Error())
d.SetId("")
return nil
}

views := []map[string]interface{}{}

for _, view := range currentViews {
viewMap := map[string]interface{}{}

viewMap["name"] = view.Name.String
viewMap["database"] = view.DatabaseName.String
viewMap["schema"] = view.SchemaName.String
viewMap["comment"] = view.Comment.String
materializedViews := make([]map[string]any, len(extractedMaterializedViews))

views = append(views, viewMap)
for i, materializedView := range extractedMaterializedViews {
materializedViews[i] = map[string]any{
"name": materializedView.Name,
"database": materializedView.DatabaseName,
"schema": materializedView.SchemaName,
"comment": materializedView.Comment,
}
}

d.SetId(fmt.Sprintf(`%v|%v`, databaseName, schemaName))
return d.Set("materialized_views", views)
d.SetId(helpers.EncodeSnowflakeID(databaseName, schemaName))
return d.Set("materialized_views", materializedViews)
}
11 changes: 9 additions & 2 deletions pkg/datasources/materialized_views_acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"strings"
"testing"

acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance"

"github.com/hashicorp/terraform-plugin-testing/helper/acctest"
"github.com/hashicorp/terraform-plugin-testing/helper/resource"
"github.com/hashicorp/terraform-plugin-testing/tfversion"
)

func TestAcc_MaterializedViews(t *testing.T) {
Expand All @@ -15,8 +18,12 @@ func TestAcc_MaterializedViews(t *testing.T) {
schemaName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
tableName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
viewName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
resource.ParallelTest(t, resource.TestCase{
Providers: providers(),
resource.Test(t, resource.TestCase{
ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories,
PreCheck: func() { acc.TestAccPreCheck(t) },
TerraformVersionChecks: []tfversion.TerraformVersionCheck{
tfversion.RequireAbove(tfversion.Version1_5_0),
},
CheckDestroy: nil,
Steps: []resource.TestStep{
{
Expand Down
42 changes: 20 additions & 22 deletions pkg/datasources/views.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package datasources

import (
"context"
"database/sql"
"errors"
"fmt"
"log"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
)

Expand Down Expand Up @@ -58,34 +58,32 @@ func Views() *schema.Resource {

func ReadViews(d *schema.ResourceData, meta interface{}) error {
db := meta.(*sql.DB)
ctx := context.Background()
client := sdk.NewClientFromDB(db)
databaseName := d.Get("database").(string)
schemaName := d.Get("schema").(string)

currentViews, err := snowflake.ListViews(databaseName, schemaName, db)
if errors.Is(err, sql.ErrNoRows) {
// If not found, mark resource to be removed from state file during apply or refresh
log.Printf("[DEBUG] views in schema (%s) not found", d.Id())
d.SetId("")
return nil
} else if err != nil {
log.Printf("[DEBUG] unable to parse views in schema (%s)", d.Id())
schemaId := sdk.NewDatabaseObjectIdentifier(databaseName, schemaName)
extractedViews, err := client.Views.Show(ctx, sdk.NewShowViewRequest().WithIn(
&sdk.In{Schema: schemaId},
))
if err != nil {
log.Printf("[DEBUG] failed when searching views in schema (%s), err = %s", schemaId.FullyQualifiedName(), err.Error())
d.SetId("")
return nil
}

views := []map[string]interface{}{}

for _, view := range currentViews {
viewMap := map[string]interface{}{}

viewMap["name"] = view.Name.String
viewMap["database"] = view.DatabaseName.String
viewMap["schema"] = view.SchemaName.String
viewMap["comment"] = view.Comment.String
views := make([]map[string]any, len(extractedViews))

views = append(views, viewMap)
for i, view := range extractedViews {
views[i] = map[string]any{
"name": view.Name,
"database": view.DatabaseName,
"schema": view.SchemaName,
"comment": view.Comment,
}
}

d.SetId(fmt.Sprintf(`%v|%v`, databaseName, schemaName))
d.SetId(helpers.EncodeSnowflakeID(databaseName, schemaName))
return d.Set("views", views)
}
9 changes: 8 additions & 1 deletion pkg/datasources/views_acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,23 @@ import (
"strings"
"testing"

acc "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance"

"github.com/hashicorp/terraform-plugin-testing/helper/acctest"
"github.com/hashicorp/terraform-plugin-testing/helper/resource"
"github.com/hashicorp/terraform-plugin-testing/tfversion"
)

func TestAcc_Views(t *testing.T) {
databaseName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
schemaName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
viewName := strings.ToUpper(acctest.RandStringFromCharSet(10, acctest.CharSetAlpha))
resource.ParallelTest(t, resource.TestCase{
Providers: providers(),
ProtoV6ProviderFactories: acc.TestAccProtoV6ProviderFactories,
PreCheck: func() { acc.TestAccPreCheck(t) },
TerraformVersionChecks: []tfversion.TerraformVersionCheck{
tfversion.RequireAbove(tfversion.Version1_5_0),
},
CheckDestroy: nil,
Steps: []resource.TestStep{
{
Expand Down
25 changes: 14 additions & 11 deletions pkg/resources/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,27 +74,30 @@ func getTagObjectIdentifier(v map[string]any) sdk.ObjectIdentifier {

func getPropertyTags(d *schema.ResourceData, key string) []sdk.TagAssociation {
if from, ok := d.GetOk(key); ok {
tags := from.([]any)
to := make([]sdk.TagAssociation, len(tags))
for i, t := range tags {
v := t.(map[string]any)
to[i] = sdk.TagAssociation{
Name: getTagObjectIdentifier(v),
Value: v["value"].(string),
}
}
return to
return getTagsFromList(from.([]any))
}
return nil
}

func getTagsFromList(tags []any) []sdk.TagAssociation {
to := make([]sdk.TagAssociation, len(tags))
for i, t := range tags {
v := t.(map[string]any)
to[i] = sdk.TagAssociation{
Name: getTagObjectIdentifier(v),
Value: v["value"].(string),
}
}
return to
}

func GetTagsDiff(d *schema.ResourceData, key string) (unsetTags []sdk.ObjectIdentifier, setTags []sdk.TagAssociation) {
o, n := d.GetChange(key)
removed, added, changed := getTags(o).diffs(getTags(n))

unsetTags = make([]sdk.ObjectIdentifier, len(removed))
for i, t := range removed {
unsetTags[i] = sdk.NewDatabaseObjectIdentifier(t.database, t.name)
unsetTags[i] = sdk.NewSchemaObjectIdentifier(t.database, t.schema, t.name)
}

setTags = make([]sdk.TagAssociation, len(added)+len(changed))
Expand Down
Loading

0 comments on commit dc66d02

Please sign in to comment.