Skip to content

Commit

Permalink
fix: sequence import (#775)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtzero authored Mar 14, 2022
1 parent 3a17e34 commit e728d2e
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 26 deletions.
7 changes: 7 additions & 0 deletions docs/resources/sequence.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,11 @@ resource "snowflake_sequence" "test_sequence" {
- **fully_qualified_name** (String) The fully qualified name of the sequence.
- **next_value** (Number) The next value the sequence will provide.

## Import

Import is supported using the following syntax:

```shell
# format is database name | schema name | sequence name
terraform import snowflake_sequence.example 'dbName|schemaName|sequenceName'
```
2 changes: 2 additions & 0 deletions examples/resources/snowflake_sequence/import.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# format is database name | schema name | sequence name
terraform import snowflake_sequence.example 'dbName|schemaName|sequenceName'
141 changes: 119 additions & 22 deletions pkg/resources/sequence.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
package resources

import (
"bytes"
"database/sql"
"encoding/csv"
"fmt"
"log"
"strconv"
"strings"

"github.com/chanzuckerberg/terraform-provider-snowflake/pkg/snowflake"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/pkg/errors"
)

const (
sequenceIDDelimiter = '|'
)

var sequenceSchema = map[string]*schema.Schema{
"name": {
Type: schema.TypeString,
Required: true,
Description: "Specifies the name for the sequence.",
ForceNew: true,
},
"comment": {
Type: schema.TypeString,
Expand All @@ -33,11 +41,13 @@ var sequenceSchema = map[string]*schema.Schema{
Type: schema.TypeString,
Required: true,
Description: "The database in which to create the sequence. Don't use the | character.",
ForceNew: true,
},
"schema": {
Type: schema.TypeString,
Required: true,
Description: "The schema in which to create the sequence. Don't use the | character.",
ForceNew: true,
},
"next_value": {
Type: schema.TypeInt,
Expand All @@ -53,6 +63,27 @@ var sequenceSchema = map[string]*schema.Schema{

var sequenceProperties = []string{"comment", "data_retention_time_in_days"}

type sequenceID struct {
DatabaseName string
SchemaName string
SequenceName string
}

//String() takes in a sequenceID object and returns a pipe-delimited string:
//DatabaseName|SchemaName|SequenceName
func (si *sequenceID) String() (string, error) {
var buf bytes.Buffer
csvWriter := csv.NewWriter(&buf)
csvWriter.Comma = pipeIDDelimiter
dataIdentifiers := [][]string{{si.DatabaseName, si.SchemaName, si.SequenceName}}
err := csvWriter.WriteAll(dataIdentifiers)
if err != nil {
return "", err
}
strSequenceID := strings.TrimSpace(buf.String())
return strSequenceID, nil
}

// Sequence returns a pointer to the resource representing a sequence
func Sequence() *schema.Resource {
return &schema.Resource{
Expand Down Expand Up @@ -90,15 +121,32 @@ func CreateSequence(d *schema.ResourceData, meta interface{}) error {
return errors.Wrapf(err, "error creating sequence")
}

sequenceID := &sequenceID{
DatabaseName: database,
SchemaName: schema,
SequenceName: name,
}

dataIDInput, err := sequenceID.String()
if err != nil {
return err
}
d.SetId(dataIDInput)

return ReadSequence(d, meta)
}

// ReadSequence implements schema.ReadFunc
func ReadSequence(d *schema.ResourceData, meta interface{}) error {
db := meta.(*sql.DB)
database := d.Get("database").(string)
schema := d.Get("schema").(string)
name := d.Get("name").(string)
sequenceID, err := sequenceIDFromString(d.Id())
if err != nil {
return err
}

database := sequenceID.DatabaseName
schema := sequenceID.SchemaName
name := sequenceID.SequenceName

seq := snowflake.Sequence(name, database, schema)
stmt := seq.Show()
Expand All @@ -116,6 +164,11 @@ func ReadSequence(d *schema.ResourceData, meta interface{}) error {
return errors.Wrap(err, "unable to scan row for SHOW SEQUENCES")
}

err = d.Set("name", sequence.Name.String)
if err != nil {
return err
}

err = d.Set("schema", sequence.SchemaName.String)
if err != nil {
return err
Expand All @@ -141,12 +194,12 @@ func ReadSequence(d *schema.ResourceData, meta interface{}) error {
return err
}

i, err = strconv.ParseInt(sequence.NextValue.String, 10, 64)
n, err := strconv.ParseInt(sequence.NextValue.String, 10, 64)
if err != nil {
return err
}

err = d.Set("next_value", i)
err = d.Set("next_value", n)
if err != nil {
return err
}
Expand All @@ -156,24 +209,26 @@ func ReadSequence(d *schema.ResourceData, meta interface{}) error {
return err
}

d.SetId(fmt.Sprintf(`%v|%v|%v`, sequence.DBName.String, sequence.SchemaName.String, sequence.Name.String))
if err != nil {
return err
}

return err
return nil
}

func UpdateSequence(d *schema.ResourceData, meta interface{}) error {
db := meta.(*sql.DB)
database := d.Get("database").(string)
schema := d.Get("schema").(string)
name := d.Get("name").(string)
next := d.Get("next_value").(int)
sequenceID, err := sequenceIDFromString(d.Id())
if err != nil {
return err
}

DeleteSequence(d, meta)
database := sequenceID.DatabaseName
schema := sequenceID.SchemaName
name := sequenceID.SequenceName

sq := snowflake.Sequence(name, database, schema)
stmt := sq.Show()
row := snowflake.QueryRow(db, stmt)

sequence, err := snowflake.ScanSequence(row)
DeleteSequence(d, meta)

if i, ok := d.GetOk("increment"); ok {
sq.WithIncrement(i.(int))
Expand All @@ -183,9 +238,19 @@ func UpdateSequence(d *schema.ResourceData, meta interface{}) error {
sq.WithComment(v.(string))
}

sq.WithStart(next)
nextValue, err := strconv.Atoi(sequence.NextValue.String)
if err != nil {
return err
}

err := snowflake.Exec(db, sq.Create())
err = d.Set("next_value", nextValue)
if err != nil {
return err
}

sq.WithStart(nextValue)

err = snowflake.Exec(db, sq.Create())
if err != nil {
return errors.Wrapf(err, "error creating sequence")
}
Expand All @@ -195,17 +260,49 @@ func UpdateSequence(d *schema.ResourceData, meta interface{}) error {

func DeleteSequence(d *schema.ResourceData, meta interface{}) error {
db := meta.(*sql.DB)
database := d.Get("database").(string)
schema := d.Get("schema").(string)
name := d.Get("name").(string)
sequenceID, err := sequenceIDFromString(d.Id())
if err != nil {
return err
}

database := sequenceID.DatabaseName
schema := sequenceID.SchemaName
name := sequenceID.SequenceName

stmt := snowflake.Sequence(name, database, schema).Drop()

err := snowflake.Exec(db, stmt)
err = snowflake.Exec(db, stmt)
if err != nil {
return errors.Wrapf(err, "error dropping sequence %s", name)
}

d.SetId("")
return nil
}

// sequenceIDFromString() takes in a pipe-delimited string: DatabaseName|SchemaName|PipeName
// and returns a sequenceID object
func sequenceIDFromString(stringID string) (*sequenceID, error) {
reader := csv.NewReader(strings.NewReader(stringID))
reader.Comma = sequenceIDDelimiter
lines, err := reader.ReadAll()
if err != nil {
return nil, fmt.Errorf("Not CSV compatible")
}

if len(lines) != 1 {
return nil, fmt.Errorf("1 line per sequence")
}

if len(lines[0]) != 3 {
return nil, fmt.Errorf("3 fields allowed")
}

sequenceResult := &sequenceID{
DatabaseName: lines[0][0],
SchemaName: lines[0][1],
SequenceName: lines[0][2],
}

return sequenceResult, nil
}
6 changes: 6 additions & 0 deletions pkg/resources/sequence_acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ func TestAcc_Sequence(t *testing.T) {
resource.TestCheckResourceAttr("snowflake_sequence.test_sequence", "fully_qualified_name", fmt.Sprintf(`%v.%v.%v`, accName, accName, accName)),
),
},
// IMPORT
{
ResourceName: "snowflake_sequence.test_sequence",
ImportState: true,
ImportStateVerify: true,
},
},
})
}
Expand Down
5 changes: 2 additions & 3 deletions pkg/resources/sequence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestSequenceRead(t *testing.T) {
"database": "database",
}

d := sequence(t, "good_name", in)
d := sequence(t, "database|schema|good_name", in)

WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{
Expand Down Expand Up @@ -97,7 +97,6 @@ func TestSequenceRead(t *testing.T) {
r.Equal("database", d.Get("database").(string))
r.Equal("mock comment", d.Get("comment").(string))
r.Equal(25, d.Get("increment").(int))
r.Equal(5, d.Get("next_value").(int))
r.Equal("database|schema|good_name", d.Id())
r.Equal(`"database"."schema"."good_name"`, d.Get("fully_qualified_name").(string))
})
Expand All @@ -111,7 +110,7 @@ func TestSequenceDelete(t *testing.T) {
"database": "database",
}

d := sequence(t, "drop_it", in)
d := sequence(t, "database|schema|drop_it", in)

WithMockDb(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectExec(`DROP SEQUENCE "database"."schema"."drop_it"`).WillReturnResult(sqlmock.NewResult(1, 1))
Expand Down
2 changes: 1 addition & 1 deletion pkg/snowflake/sequence.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (sb *SequenceBuilder) Drop() string {
return fmt.Sprintf(`DROP SEQUENCE %v`, sb.QualifiedName())
}

// Drop returns the SQL query that will drop a sequence.
// Show returns the SQL query that will show a sequence.
func (sb *SequenceBuilder) Show() string {
return fmt.Sprintf(`SHOW SEQUENCES LIKE '%v' IN SCHEMA "%v"."%v"`, sb.name, sb.db, sb.schema)
}
Expand Down

0 comments on commit e728d2e

Please sign in to comment.