Skip to content

Commit

Permalink
fix: race condition with parallel + enable_sharing
Browse files Browse the repository at this point in the history
  • Loading branch information
snakster committed Dec 30, 2024
1 parent 78174eb commit 8bb52e6
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 16 deletions.
36 changes: 20 additions & 16 deletions cmd/terramate/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/terramate-io/terramate/hcl/ast"
"github.com/terramate-io/terramate/printer"
prj "github.com/terramate-io/terramate/project"
"github.com/terramate-io/terramate/run"
runutil "github.com/terramate-io/terramate/run"
"github.com/terramate-io/terramate/run/dag"
"github.com/terramate-io/terramate/scheduler"
Expand Down Expand Up @@ -380,9 +381,7 @@ func (c *cli) runAll(
}()

// map of stackName -> map of backendName -> outputs
type stackOutputs map[prj.Path]map[string]cty.Value

allOutputs := stackOutputs{}
allOutputs := run.NewOnceMap[string, *run.OnceMap[string, cty.Value]]()

err = sched.Run(func(run stackRun) error {
errs := errors.L()
Expand Down Expand Up @@ -476,12 +475,12 @@ func (c *cli) runAll(
}
break tasksLoop
}
_, ok = allOutputs[otherStack.Dir]
if !ok {
allOutputs[otherStack.Dir] = make(map[string]cty.Value)
}
_, ok = allOutputs[otherStack.Dir][backend.Name]
if !ok {

stackOutputs, _ := allOutputs.GetOrInit(otherStack.Dir.String(), func() (*run.OnceMap[string, cty.Value], error) {

Check failure on line 479 in cmd/terramate/cli/run.go

View workflow job for this annotation

GitHub Actions / Release Dry Run

run.OnceMap is not a type
return run.NewOnceMap[string, cty.Value](), nil

Check failure on line 480 in cmd/terramate/cli/run.go

View workflow job for this annotation

GitHub Actions / Release Dry Run

run.NewOnceMap undefined (type stackRun has no field or method NewOnceMap)
})

outputsVal, err := stackOutputs.GetOrInit(backend.Name, func() (cty.Value, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer
cmd := exec.Command(backend.Command[0], backend.Command[1:]...)
Expand All @@ -492,14 +491,15 @@ func (c *cli) runAll(
err := cmd.Run()
if err != nil {
if !task.MockOnFail {
errs.Append(errors.E(err, "failed to execute: (cmd: %s) (stdout: %s) (stderr: %s)", cmd.String(), stdout.String(), stderr.String()))
err := errors.E(err, "failed to execute: (cmd: %s) (stdout: %s) (stderr: %s)", cmd.String(), stdout.String(), stderr.String())
errs.Append(err)
c.cloudSyncAfter(cloudRun, runResult{ExitCode: -1}, errors.E(ErrRunCommandNotExecuted, err))
releaseResource()
failedTaskIndex = taskIndex
if !continueOnError {
cancel()
}
break tasksLoop
return cty.Value{}, err
}

printer.Stderr.WarnWithDetails(
Expand All @@ -518,7 +518,8 @@ func (c *cli) runAll(
if !continueOnError {
cancel()
}
break tasksLoop
return cty.Value{}, err

}
inputVal, err = json.Unmarshal(stdoutBytes, typ)
if err != nil {
Expand All @@ -530,13 +531,16 @@ func (c *cli) runAll(
if !continueOnError {
cancel()
}
break tasksLoop
return cty.Value{}, err
}
}
allOutputs[otherStack.Dir][backend.Name] = inputVal
return inputVal, nil
})
if err != nil {
break tasksLoop
}
stackOutputs := allOutputs[otherStack.Dir][backend.Name]
evalctx.SetNamespaceRaw("outputs", stackOutputs)

evalctx.SetNamespaceRaw("outputs", outputsVal)
inputVal, inputErr := input.Value(evalctx)
mockVal, mockFound, mockErr := input.Mock(evalctx)

Expand Down
38 changes: 38 additions & 0 deletions run/oncemap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright 2023 Terramate GmbH
// SPDX-License-Identifier: MPL-2.0

package run

import "sync"

type OnceMap[K ~string, V any] struct {
mtx sync.RWMutex
data map[K]V
}

func NewOnceMap[K ~string, V any]() *OnceMap[K, V] {
return &OnceMap[K, V]{data: make(map[K]V)}
}

func (m *OnceMap[K, V]) GetOrInit(k K, init func() (V, error)) (V, error) {
m.mtx.RLock()
v, found := m.data[k]
m.mtx.RUnlock()

if !found {
m.mtx.Lock()
v, found = m.data[k]
if !found {
var err error
v, err = init()
if err != nil {
m.mtx.Unlock()
return v, err
}
m.data[k] = v
}
m.mtx.Unlock()
}

return v, nil
}
120 changes: 120 additions & 0 deletions run/oncemap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright 2023 Terramate GmbH
// SPDX-License-Identifier: MPL-2.0

package run_test

import (
"errors"
"fmt"
"strconv"
"sync"
"sync/atomic"
"testing"

"github.com/madlambda/spells/assert"
"github.com/terramate-io/terramate/run"
)

func TestOnceMap(t *testing.T) {
t.Parallel()

t.Run("ok", func(t *testing.T) {
t.Parallel()

outerCount := 0
innerCount := 0

data := run.NewOnceMap[string, *run.OnceMap[string, string]]()

for outer := 0; outer < 10; outer++ {
k1 := strconv.Itoa(outer)
v1, err1 := data.GetOrInit(k1, func() (*run.OnceMap[string, string], error) {
outerCount++
return run.NewOnceMap[string, string](), nil
})
assert.NoError(t, err1)

for inner := 0; inner < 10; inner++ {
k2 := strconv.Itoa(inner)
v2, err2 := v1.GetOrInit(k2, func() (string, error) {
innerCount++
return fmt.Sprintf("%d_%d", outer, inner), nil
})
assert.NoError(t, err2)
assert.EqualStrings(t, fmt.Sprintf("%d_%d", outer, inner), v2)
}
}

assert.EqualInts(t, 10, outerCount, "outer count")
assert.EqualInts(t, 100, innerCount, "inner count")
})

t.Run("error", func(t *testing.T) {
t.Parallel()

count := 0

m := run.NewOnceMap[string, string]()

_, err := m.GetOrInit("k", func() (string, error) {
count++
return "", errors.New("failed")
})

assert.EqualErrs(t, err, errors.New("failed"))

v, err := m.GetOrInit("k", func() (string, error) {
count++
return "success", nil
})

assert.EqualStrings(t, "success", v)
assert.NoError(t, err)
})

t.Run("concurrent", func(t *testing.T) {
t.Parallel()

var outerCount atomic.Int32
var innerCount atomic.Int32

data := run.NewOnceMap[string, *run.OnceMap[string, string]]()

var wg sync.WaitGroup

for outer := 0; outer < 10; outer++ {
outer := outer

wg.Add(1 + 10)

go func() {
defer wg.Done()

k1 := strconv.Itoa(outer)
v1, _ := data.GetOrInit(k1, func() (*run.OnceMap[string, string], error) {
outerCount.Add(1)
return run.NewOnceMap[string, string](), nil
})

for inner := 0; inner < 10; inner++ {
inner := inner

go func() {
defer wg.Done()

k2 := strconv.Itoa(inner)
_, _ = v1.GetOrInit(k2, func() (string, error) {
innerCount.Add(1)
return "blah", nil
})
}()
}
}()
}

wg.Wait()

assert.EqualInts(t, 10, int(outerCount.Load()), "outer count")
assert.EqualInts(t, 100, int(innerCount.Load()), "inner inner")
})
}

0 comments on commit 8bb52e6

Please sign in to comment.