From 4266933c4f65aed83871785927efe1eb1a0edd31 Mon Sep 17 00:00:00 2001 From: Alex Collins Date: Tue, 6 Sep 2022 10:29:56 -0700 Subject: [PATCH] feat: Add new lambda expr. Fixes #9529 Signed-off-by: Alex Collins --- util/mapper/map.go | 56 ++++++++++++++++++++++++++++++ util/mapper/map_test.go | 53 ++++++++++++++++++++++++++++ util/template/eval.go | 35 +++++++++++++++++++ util/template/eval_test.go | 13 +++++++ workflow/controller/workflowpod.go | 11 ++++++ 5 files changed, 168 insertions(+) create mode 100644 util/mapper/map.go create mode 100644 util/mapper/map_test.go create mode 100644 util/template/eval.go create mode 100644 util/template/eval_test.go diff --git a/util/mapper/map.go b/util/mapper/map.go new file mode 100644 index 000000000000..72dff354f5d3 --- /dev/null +++ b/util/mapper/map.go @@ -0,0 +1,56 @@ +package mapper + +import ( + "reflect" +) + +// Map translate (maps) any object by recursively applying the mapper func to each field, array element, and map value. +// Among intended use cases is translating a data structure (e.g. from English to Spanish). +func Map(x any, m func(any) (any, error)) (any, error) { + value, err := _map(reflect.ValueOf(x), m) + return value.Interface(), err +} + +func _map(x reflect.Value, m func(any) (any, error)) (reflect.Value, error) { + if x.IsZero() { + return x, nil + } + switch x.Kind() { + case reflect.Ptr: + y, err := _map(x.Elem(), m) + return y.Addr(), err + case reflect.Struct: + y := reflect.Indirect(reflect.New(x.Type())) + for i := 0; i < x.NumField(); i++ { + g, err := _map(x.Field(i), m) + if err != nil { + return y, err + } + y.Field(i).Set(g) + } + return y, nil + case reflect.Array, reflect.Slice: + y := reflect.Indirect(reflect.MakeSlice(x.Type(), x.Len(), x.Len())) + for i := 0; i < x.Len(); i++ { + g, err := _map(x.Index(i), m) + if err != nil { + return y, err + } + y.Index(i).Set(g) + } + return y, nil + case reflect.Map: + y := reflect.Indirect(reflect.MakeMap(x.Type())) + for _, key := range x.MapKeys() { + g, err := _map(x.MapIndex(key), m) + if err != nil { + return y, err + } + y.SetMapIndex(key, g) + } + return y, nil + default: + y, err := m(x.Interface()) + return reflect.ValueOf(y), err + } +} diff --git a/util/mapper/map_test.go b/util/mapper/map_test.go new file mode 100644 index 000000000000..e8c05ad050f8 --- /dev/null +++ b/util/mapper/map_test.go @@ -0,0 +1,53 @@ +package mapper + +import ( + wfv1 "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + "github.com/stretchr/testify/assert" + "testing" +) + +type s struct { + String string +} + +func TestVisit(t *testing.T) { + v := func(x any) (any, error) { + s, ok := x.(string) + if ok && s == "foo" { + return "bar", nil + } + return x, nil + } + t.Run("string", func(t *testing.T) { + x, err := Map("foo", v) + assert.NoError(t, err) + assert.Equal(t, "bar", x) + }) + t.Run("Struct", func(t *testing.T) { + x, err := Map(s{String: "foo"}, v) + assert.NoError(t, err) + assert.Equal(t, "bar", x.(s).String) + }) + t.Run("array", func(t *testing.T) { + x, err := Map([]string{"foo"}, v) + assert.NoError(t, err) + assert.Equal(t, []string{"bar"}, x) + }) + t.Run("map", func(t *testing.T) { + x, err := Map(map[string]string{"x": "foo"}, v) + assert.NoError(t, err) + assert.Equal(t, map[string]string{"x": "bar"}, x) + }) + t.Run("WorkflowSpec", func(t *testing.T) { + y, err := Map(wfv1.WorkflowSpec{}, v) + assert.NoError(t, err) + assert.Equal(t, wfv1.WorkflowSpec{}, y) + }) + t.Run("*WorkflowSpec", func(t *testing.T) { + y, err := Map(&wfv1.WorkflowSpec{ + Entrypoint: "foo", + }, v) + assert.NoError(t, err) + assert.Equal(t, &wfv1.WorkflowSpec{Entrypoint: "bar"}, y) + }) +} diff --git a/util/template/eval.go b/util/template/eval.go new file mode 100644 index 000000000000..683e8ecfab3c --- /dev/null +++ b/util/template/eval.go @@ -0,0 +1,35 @@ +package template + +import ( + "fmt" + "github.com/antonmedv/expr" + "github.com/argoproj/argo-workflows/v3/util/mapper" + "strings" +) + +func Eval(x any, env any) (any, error) { + return mapper.Map(x, func(g any) (any, error) { + s, ok := g.(string) + if ok { + return eval(s, env) + } + return g, nil + }) +} + +func eval(s string, env any) (string, error) { + const prefix = "ƛ" + if !strings.HasPrefix(s, prefix) { + return s, nil + } + input := strings.TrimPrefix(s, prefix) + output, err := expr.Eval(input, env) + if err != nil { + return "", fmt.Errorf("failed to evaluate %s: %w", s, err) + } + result, ok := output.(string) + if !ok { + return "", fmt.Errorf("failed to evaluate %s: %w", s, fmt.Errorf("expected result to be a string, but got %T", output)) + } + return result, nil +} diff --git a/util/template/eval_test.go b/util/template/eval_test.go new file mode 100644 index 000000000000..3f002c584ff2 --- /dev/null +++ b/util/template/eval_test.go @@ -0,0 +1,13 @@ +package template + +import ( + wfv1 "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestEval(t *testing.T) { + y, err := Eval(&wfv1.WorkflowSpec{Entrypoint: `ƛx == "foo" ? "bar": "x"`}, map[string]string{"x": "foo"}) + assert.NoError(t, err) + assert.Equal(t, &wfv1.WorkflowSpec{Entrypoint: "bar"}, y) +} diff --git a/workflow/controller/workflowpod.go b/workflow/controller/workflowpod.go index fed3b697cffb..30067f174eb4 100644 --- a/workflow/controller/workflowpod.go +++ b/workflow/controller/workflowpod.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + exprenv "github.com/argoproj/argo-workflows/v3/util/expr/env" "path/filepath" "strconv" "time" @@ -428,6 +429,8 @@ func (woc *wfOperationCtx) createWorkflowPod(ctx context.Context, nodeName strin return nil, ErrResourceRateLimitReached } + pod = woc.mustEval(pod).(*apiv1.Pod) + woc.log.Debugf("Creating Pod: %s (%s)", nodeName, pod.Name) created, err := woc.controller.kubeclientset.CoreV1().Pods(woc.wf.ObjectMeta.Namespace).Create(ctx, pod, metav1.CreateOptions{}) @@ -449,6 +452,14 @@ func (woc *wfOperationCtx) createWorkflowPod(ctx context.Context, nodeName strin return created, nil } +func (woc *wfOperationCtx) mustEval(x any) any { + y, err := template.Eval(x, exprenv.GetFuncMap(template.EnvMap(woc.globalParams))) + if err != nil { + panic(err) + } + return y +} + func (woc *wfOperationCtx) podExists(nodeID string) (existing *apiv1.Pod, exists bool, err error) { objs, err := woc.controller.podInformer.GetIndexer().ByIndex(indexes.NodeIDIndex, woc.wf.Namespace+"/"+nodeID) if err != nil {