Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: task after dag support #1342

Merged
merged 21 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 134 additions & 142 deletions pkg/resources/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/pkg/errors"
"golang.org/x/exp/slices"
)

const (
Expand Down Expand Up @@ -75,9 +76,10 @@ var taskSchema = map[string]*schema.Schema{
Description: "Specifies a comment for the task.",
},
"after": {
Type: schema.TypeString,
Type: schema.TypeList,
Elem: &schema.Schema{Type: schema.TypeString},
Optional: true,
Description: "Specifies the predecessor task in the same database and schema of the current task. When a run of the predecessor task finishes successfully, it triggers this task (after a brief lag). (Conflict with schedule)",
Description: "Specifies one or more predecessor tasks for the current task. Use this option to create a DAG of tasks or add this task to an existing DAG. A DAG is a series of tasks that starts with a scheduled root task and is linked together by dependencies.",
ConflictsWith: []string{"schedule"},
},
"when": {
Expand Down Expand Up @@ -146,88 +148,6 @@ func difference(a, b map[string]interface{}) map[string]interface{} {
return diff
}

// getActiveRootTask walks the predecessor chain of the task described by data
// and returns a builder for the root of that chain, but only when the root is
// currently enabled.
//
// The walk starts at the "after" attribute when it is set, otherwise at the
// task's own "name". Each step runs the SHOW query for the current name, scans
// the result, and moves to the value of GetPredecessorName until that value is
// empty.
//
// Returns:
//   - (nil, nil) when "name" is empty, or when the root that was found is not
//     enabled;
//   - (builder, nil) for an enabled root (which may be the task itself when it
//     has no predecessor);
//   - (nil, err) when scanning any task other than the resource's own name
//     fails.
//
// meta must hold a *sql.DB; the type assertion panics otherwise.
func getActiveRootTask(data *schema.ResourceData, meta interface{}) (*snowflake.TaskBuilder, error) {
	log.Println("[DEBUG] retrieving root task")

	conn := meta.(*sql.DB)
	database := data.Get("database").(string)
	dbSchema := data.Get("schema").(string)
	taskName := data.Get("name").(string)
	predecessor := data.Get("after").(string)

	if taskName == "" {
		return nil, nil
	}

	// Begin at the first predecessor when there is one; a standalone task
	// is its own starting point.
	current := taskName
	if predecessor != "" {
		current = predecessor
	}

	for {
		showStmt := snowflake.Task(current, database, dbSchema).Show()
		task, err := snowflake.ScanTask(snowflake.QueryRow(conn, showStmt))

		// A scan failure is only fatal for ancestors. A failure for the
		// resource's own name is tolerated here.
		// NOTE(review): presumably because the task may not exist yet — confirm.
		if err != nil && current != taskName {
			return nil, errors.Wrapf(err, "failed to locate the root node of: %v", current)
		}

		if parent := task.GetPredecessorName(); parent != "" {
			current = parent
			continue
		}

		log.Printf("[DEBUG] found root task: %v", current)

		// Only an enabled (started) root is of interest to callers.
		if !task.IsEnabled() {
			return nil, nil
		}
		return snowflake.Task(current, database, dbSchema), nil
	}
}

// getActiveRootTaskAndSuspend looks up the enabled root task via
// getActiveRootTask and, when one is found, executes its Suspend statement.
//
// It returns the root builder so the caller can resume it later. The result is
// (nil, nil) when getActiveRootTask reports no enabled root. Lookup and suspend
// failures are wrapped with the name of the resource's own task.
//
// meta must hold a *sql.DB; the type assertion panics otherwise.
func getActiveRootTaskAndSuspend(data *schema.ResourceData, meta interface{}) (*snowflake.TaskBuilder, error) {
	conn := meta.(*sql.DB)
	taskName := data.Get("name").(string)

	rootTask, err := getActiveRootTask(data, meta)
	if err != nil {
		return nil, errors.Wrapf(err, "error retrieving root task %v", taskName)
	}

	// Nothing to suspend when there is no enabled root.
	if rootTask == nil {
		return nil, nil
	}

	if err := snowflake.Exec(conn, rootTask.Suspend()); err != nil {
		return nil, errors.Wrapf(err, "error suspending root task %v", taskName)
	}

	return rootTask, nil
}

// resumeTask executes the Resume statement for root.
//
// It is a no-op when root is nil or when root.IsDisabled() reports true.
// meta must hold a *sql.DB; the type assertion panics otherwise.
//
// A failed resume is logged as a warning rather than passed to log.Fatal.
// This function is invoked through defer and at the tail of resource
// operations; log.Fatal calls os.Exit(1), which would terminate the whole
// process and skip any remaining deferred calls. The function has no error
// return, so a log entry is the only way to surface the failure without
// changing the signature.
func resumeTask(root *snowflake.TaskBuilder, meta interface{}) {
	if root == nil || root.IsDisabled() {
		return
	}

	db := meta.(*sql.DB)
	if err := snowflake.Exec(db, root.Resume()); err != nil {
		log.Printf("[WARN] %v", errors.Wrapf(err, "error resuming root task %v", root.QualifiedName()))
	}
}

// taskIDFromString() takes in a pipe-delimited string: DatabaseName|SchemaName|TaskName
// and returns a taskID object.
func taskIDFromString(stringID string) (*taskID, error) {
Expand Down Expand Up @@ -360,12 +280,13 @@ func ReadTask(d *schema.ResourceData, meta interface{}) error {
return err
}

predecessorName := t.GetPredecessorName()
if predecessorName != "" {
err = d.Set("after", predecessorName)
if err != nil {
return err
}
predecessors, err := t.GetPredecessors()
if err != nil {
return err
}
err = d.Set("predecessors", predecessors)
if err != nil {
return err
}

err = d.Set("when", t.Condition)
Expand Down Expand Up @@ -479,13 +400,38 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error {
}

if v, ok := d.GetOk("after"); ok {
root, err := getActiveRootTaskAndSuspend(d, meta)
a := v.([]interface{})
after := make([]string, len(a))
for i, v := range a {
after[i] = v.(string)
}
rootTasks, err := snowflake.GetRootTasks(name, database, dbSchema, db)
if err != nil {
return err
}
defer resumeTask(root, meta)
for _, rootTask := range rootTasks {
// if a root task is enabled, then it needs to be suspended before the child tasks can be created
if rootTask.IsEnabled() {
q := rootTask.Suspend()
err = snowflake.Exec(db, q)
if err != nil {
return err
}

// resume the task after modifications are complete as long as it is not a standalone task
if !(rootTask.Name == name){
sfc-gh-swinkler marked this conversation as resolved.
Show resolved Hide resolved
defer func() {
q = rootTask.Resume()
err = snowflake.Exec(db, q)
if err != nil {
log.Printf("[WARN] failed to resume task %s", rootTask.Name)
}
}()
}
}

builder.WithDependency(v.(string))
builder.WithAfter(after)
}
}

if v, ok := d.GetOk("when"); ok {
Expand All @@ -498,14 +444,6 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error {
return errors.Wrapf(err, "error creating task %v", name)
}

if enabled {
q = builder.Resume()
err = snowflake.Exec(db, q)
if err != nil {
return errors.Wrapf(err, "error starting task %v", name)
}
}

taskID := &taskID{
DatabaseName: database,
SchemaName: dbSchema,
Expand All @@ -517,6 +455,14 @@ func CreateTask(d *schema.ResourceData, meta interface{}) error {
}
d.SetId(dataIDInput)

if enabled {
q = builder.Resume()
err = snowflake.Exec(db, q)
if err != nil {
return errors.Wrapf(err, "error starting task %v", name)
}
}

return ReadTask(d, meta)
}

Expand All @@ -533,10 +479,32 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {
dbSchema := taskID.SchemaName
name := taskID.TaskName
builder := snowflake.Task(name, database, dbSchema)
root, err := getActiveRootTaskAndSuspend(d, meta)

rootTasks, err := snowflake.GetRootTasks(name, database, dbSchema, db)
if err != nil {
return err
}
for _, rootTask := range rootTasks {
// if a root task is enabled, then it needs to be suspended before this task can be modified
if rootTask.IsEnabled() {
q := rootTask.Suspend()
err = snowflake.Exec(db, q)
if err != nil {
return err
}

if !(rootTask.Name == name) {
// resume the task after modifications are complete, as long as it is not a standalone task
defer func() {
q = rootTask.Resume()
err = snowflake.Exec(db, q)
if err != nil {
log.Printf("[WARN] failed to resume task %s", rootTask.Name)
}
}()
}
}
}

if d.HasChange("warehouse") {
var q string
Expand Down Expand Up @@ -580,27 +548,53 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {
}
}

// Need to remove dependency before adding schedule if needed
if d.HasChange("after") {
var (
q string
err error
)

old, _ := d.GetChange("after")

q = builder.Suspend()
// making changes to after require suspending the current task
q := builder.Suspend()
err = snowflake.Exec(db, q)
if err != nil {
return errors.Wrapf(err, "error suspending task %v", d.Id())
}
needResumeCurrentTask = d.Get("enabled").(bool)

if old != "" {
q = builder.RemoveDependency(old.(string))
err = snowflake.Exec(db, q)
old, new := d.GetChange("after")
var oldAfter []string
if len(old.([]interface{})) > 0 {
oldAfter = expandStringList(old.([]interface{}))
}

var newAfter []string
if len(new.([]interface{})) > 0 {
newAfter = expandStringList(new.([]interface{}))
}

// Remove old dependencies that are not in new dependencies
var toRemove []string
for _, dep := range oldAfter {
if !slices.Contains(newAfter, dep) {
toRemove = append(toRemove, dep)
}
}
if len(toRemove) > 0 {
q := builder.RemoveAfter(toRemove)
err := snowflake.Exec(db, q)
if err != nil {
return errors.Wrapf(err, "error removing after dependencies from task %v", d.Id())
}
}

// Add new dependencies that are not in old dependencies
var toAdd []string
for _, dep := range newAfter {
if !slices.Contains(oldAfter, dep) {
toAdd = append(toAdd, dep)
}
}
if len(toAdd) > 0 {
q := builder.AddAfter(toAdd)
err := snowflake.Exec(db, q)
if err != nil {
return errors.Wrapf(err, "error removing old after dependency from task %v", d.Id())
return errors.Wrapf(err, "error adding after dependencies to task %v", d.Id())
}
}
}
Expand Down Expand Up @@ -662,21 +656,6 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {
}
}

if d.HasChange("after") {
var (
q string
)
new := d.Get("after")

if new != "" {
q = builder.AddDependency(new.(string))
err := snowflake.Exec(db, q)
if err != nil {
return errors.Wrapf(err, "error adding after dependency on task %v", d.Id())
}
}
}

if d.HasChange("session_parameters") {
var q string
o, n := d.GetChange("session_parameters")
Expand Down Expand Up @@ -737,12 +716,6 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {
q = builder.Resume()
} else {
q = builder.Suspend()
// make sure defer doesn't enable task again
// when standalone or root task and status is suspended
needResumeCurrentTask = false
if root != nil && builder.QualifiedName() == root.QualifiedName() {
root = root.SetDisabled() //nolint
}
}

err := snowflake.Exec(db, q)
Expand All @@ -752,10 +725,13 @@ func UpdateTask(d *schema.ResourceData, meta interface{}) error {
}

if needResumeCurrentTask {
resumeTask(builder, meta)
q := builder.Resume()
err := snowflake.Exec(db, q)
if err != nil {
return errors.Wrapf(err, "error resuming task %v", d.Id())
}
}

resumeTask(root, meta)
return ReadTask(d, meta)
}

Expand All @@ -771,14 +747,30 @@ func DeleteTask(d *schema.ResourceData, meta interface{}) error {
schema := taskID.SchemaName
name := taskID.TaskName

root, err := getActiveRootTaskAndSuspend(d, meta)
rootTasks, err := snowflake.GetRootTasks(name, database, schema, db)
if err != nil {
return err
}
for _, rootTask := range rootTasks {
// if a root task is enabled, then it needs to be suspended before this task can be dropped
if rootTask.IsEnabled() {
q := rootTask.Suspend()
err = snowflake.Exec(db, q)
if err != nil {
return err
}

// only resume the root when not a standalone task
if root != nil && name != root.Name() {
defer resumeTask(root, meta)
if !(rootTask.Name == name) {
// resume the task after modifications are complete, as long as it is not a standalone task
defer func() {
q = rootTask.Resume()
err = snowflake.Exec(db, q)
if err != nil {
log.Printf("[WARN] failed to resume task %s", rootTask.Name)
}
}()
}
}
}

q := snowflake.Task(name, database, schema).Drop()
Expand Down
Loading