diff --git a/contextutils/context.go b/contextutils/context.go index e5be11e2..fd0b5fac 100644 --- a/contextutils/context.go +++ b/contextutils/context.go @@ -22,6 +22,7 @@ const ( JobIDKey Key = "job_id" PhaseKey Key = "phase" RoutineLabelKey Key = "routine" + LaunchPlanIDKey Key = "lp" ) func (k Key) String() string { @@ -38,6 +39,7 @@ var logKeys = []Key{ TaskTypeKey, PhaseKey, RoutineLabelKey, + LaunchPlanIDKey, } // Gets a new context with namespace set. @@ -85,6 +87,11 @@ func WithWorkflowID(ctx context.Context, workflow string) context.Context { return context.WithValue(ctx, WorkflowIDKey, workflow) } +// Gets a new context with a launch plan ID set. +func WithLaunchPlanID(ctx context.Context, launchPlan string) context.Context { + return context.WithValue(ctx, LaunchPlanIDKey, launchPlan) +} + // Get new context with Project and Domain values set func WithProjectDomain(ctx context.Context, project, domain string) context.Context { c := context.WithValue(ctx, ProjectKey, project) diff --git a/contextutils/context_test.go b/contextutils/context_test.go index d65c2a2b..e2effe3a 100644 --- a/contextutils/context_test.go +++ b/contextutils/context_test.go @@ -69,6 +69,13 @@ func TestWithWorkflowID(t *testing.T) { assert.Equal(t, "flyte", ctx.Value(WorkflowIDKey)) } +func TestWithLaunchPlanID(t *testing.T) { + ctx := context.Background() + assert.Nil(t, ctx.Value(LaunchPlanIDKey)) + ctx = WithLaunchPlanID(ctx, "flytelp") + assert.Equal(t, "flytelp", ctx.Value(LaunchPlanIDKey)) +} + func TestWithNodeID(t *testing.T) { ctx := context.Background() assert.Nil(t, ctx.Value(NodeIDKey)) diff --git a/promutils/labeled/counter_test.go b/promutils/labeled/counter_test.go index 130b8217..e427026b 100644 --- a/promutils/labeled/counter_test.go +++ b/promutils/labeled/counter_test.go @@ -10,8 +10,9 @@ import ( ) func TestLabeledCounter(t *testing.T) { + UnsetMetricKeys() assert.NotPanics(t, func() { - SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) + SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey, contextutils.LaunchPlanIDKey) }) scope := promutils.NewTestScope() @@ -28,4 +29,8 @@ func TestLabeledCounter(t *testing.T) { ctx = contextutils.WithTaskID(ctx, "task") c.Inc(ctx) c.Add(ctx, 1.0) + + ctx = contextutils.WithLaunchPlanID(ctx, "lp") + c.Inc(ctx) + c.Add(ctx, 1.0) } diff --git a/promutils/labeled/keys.go b/promutils/labeled/keys.go index d8c86837..7727a0df 100644 --- a/promutils/labeled/keys.go +++ b/promutils/labeled/keys.go @@ -15,11 +15,11 @@ var ( // Metric Keys to label metrics with. These keys get pulled from context if they are present. Use contextutils to fill // them in. - metricKeys = make([]contextutils.Key, 0) + metricKeys []contextutils.Key // :(, we have to create a separate list to satisfy the MustNewCounterVec API as it accepts string only - metricStringKeys = make([]string, 0) - metricKeysAreSet = sync.Once{} + metricStringKeys []string + metricKeysAreSet sync.Once ) // Sets keys to use with labeled metrics. The values of these keys will be pulled from context at runtime. @@ -45,3 +45,14 @@ func SetMetricKeys(keys ...contextutils.Key) { func GetUnlabeledMetricName(metricName string) string { return metricName + "_unlabeled" } + +// Warning: This function is not thread safe and should be used for testing only outside of this package. +func UnsetMetricKeys() { + metricKeys = make([]contextutils.Key, 0) + metricStringKeys = make([]string, 0) + metricKeysAreSet = sync.Once{} +} + +func init() { + UnsetMetricKeys() +} diff --git a/promutils/labeled/keys_test.go b/promutils/labeled/keys_test.go index 4a8600ae..6699ab2a 100644 --- a/promutils/labeled/keys_test.go +++ b/promutils/labeled/keys_test.go @@ -8,8 +8,9 @@ import ( ) func TestMetricKeys(t *testing.T) { + UnsetMetricKeys() input := []contextutils.Key{ - contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey, + contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey, contextutils.LaunchPlanIDKey, } assert.NotPanics(t, func() { SetMetricKeys(input...) }) diff --git a/promutils/labeled/stopwatch_test.go b/promutils/labeled/stopwatch_test.go index d5adf0ea..1d8a69d5 100644 --- a/promutils/labeled/stopwatch_test.go +++ b/promutils/labeled/stopwatch_test.go @@ -11,6 +11,7 @@ import ( ) func TestLabeledStopWatch(t *testing.T) { + UnsetMetricKeys() assert.NotPanics(t, func() { SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) }) diff --git a/storage/cached_rawstore_test.go b/storage/cached_rawstore_test.go index 316f999b..c5225aa7 100644 --- a/storage/cached_rawstore_test.go +++ b/storage/cached_rawstore_test.go @@ -18,11 +18,8 @@ import ( "github.com/stretchr/testify/assert" ) -func init() { - labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) -} - func TestNewCachedStore(t *testing.T) { + resetMetricKeys() t.Run("CachingDisabled", func(t *testing.T) { testScope := promutils.NewTestScope() @@ -50,6 +47,11 @@ func TestNewCachedStore(t *testing.T) { }) } +func resetMetricKeys() { + labeled.UnsetMetricKeys() + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) +} + func dummyCacheStore(t *testing.T, store RawStore, scope promutils.Scope) *cachedRawStore { cfg := &Config{ Cache: CachingConfig{ @@ -86,6 +88,7 @@ func (d *dummyStore) WriteRaw(ctx context.Context, reference DataReference, size } func TestCachedRawStore(t *testing.T) { + resetMetricKeys() ctx := context.TODO() k1 := DataReference("k1") k2 := DataReference("k2") diff --git a/utils/marshal_utils_test.go b/utils/marshal_utils_test.go index 4ac0fc13..4295f5a1 100644 --- a/utils/marshal_utils_test.go +++ b/utils/marshal_utils_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes/struct" + structpb "github.com/golang/protobuf/ptypes/struct" ) type SimpleType struct {