From 8efeda726a2f3ae8605ceedbe53a05e48acf128d Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Tue, 27 Dec 2022 10:40:41 -0600 Subject: [PATCH] PodSpec NodeSelectors overwrite the default k8s plugin settings (#303) * PodSpec NodeSelectors overwrite the default k8s plugin settings Signed-off-by: Dan Rammer * added unit test Signed-off-by: Dan Rammer * better node selector value Signed-off-by: Dan Rammer * fixed issue with unionmaps where k8s plugin overrides labels Signed-off-by: Dan Rammer Signed-off-by: Dan Rammer --- go/tasks/pluginmachinery/flytek8s/pod_helper.go | 2 +- go/tasks/pluginmachinery/flytek8s/pod_helper_test.go | 2 +- go/tasks/pluginmachinery/flytek8s/testdata/config.yaml | 2 ++ go/tasks/plugins/array/k8s/subtask.go | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 54f9a1f01..e9736ff6a 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -85,7 +85,7 @@ func UpdatePod(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, if len(podSpec.SchedulerName) == 0 { podSpec.SchedulerName = config.GetK8sPluginConfig().SchedulerName } - podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().DefaultNodeSelector) + podSpec.NodeSelector = utils.UnionMaps(config.GetK8sPluginConfig().DefaultNodeSelector, podSpec.NodeSelector) if taskExecutionMetadata.IsInterruptible() { podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().InterruptibleNodeSelector) } diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 01c16051e..716807452 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -330,7 +330,7 @@ func toK8sPodInterruptible(t *testing.T) { assert.Len(t, p.Tolerations, 2) assert.Equal(t, "x/flyte", p.Tolerations[1].Key) assert.Equal(t, "interruptible", p.Tolerations[1].Value) - assert.Equal(t, 1, len(p.NodeSelector)) + assert.Equal(t, 2, len(p.NodeSelector)) assert.Equal(t, "true", p.NodeSelector["x/interruptible"]) assert.EqualValues( t, diff --git a/go/tasks/pluginmachinery/flytek8s/testdata/config.yaml b/go/tasks/pluginmachinery/flytek8s/testdata/config.yaml index 1441a6459..a34968682 100644 --- a/go/tasks/pluginmachinery/flytek8s/testdata/config.yaml +++ b/go/tasks/pluginmachinery/flytek8s/testdata/config.yaml @@ -43,6 +43,8 @@ plugins: - FLYTE_AWS_ENDPOINT: "http://minio.flyte:9000" - FLYTE_AWS_ACCESS_KEY_ID: minio - FLYTE_AWS_SECRET_ACCESS_KEY: miniostorage + default-node-selector: + user: 'default' default-pod-security-context: runAsUser: 1000 runAsGroup: 3000 diff --git a/go/tasks/plugins/array/k8s/subtask.go b/go/tasks/plugins/array/k8s/subtask.go index 67e9d2fe5..5cdc80596 100644 --- a/go/tasks/plugins/array/k8s/subtask.go +++ b/go/tasks/plugins/array/k8s/subtask.go @@ -62,7 +62,7 @@ func addMetadata(stCtx SubTaskExecutionContext, cfg *Config, k8sPluginCfg *confi pod.SetNamespace(namespace) pod.SetAnnotations(utils.UnionMaps(k8sPluginCfg.DefaultAnnotations, pod.GetAnnotations(), utils.CopyMap(taskExecutionMetadata.GetAnnotations()))) - pod.SetLabels(utils.UnionMaps(pod.GetLabels(), utils.CopyMap(taskExecutionMetadata.GetLabels()), k8sPluginCfg.DefaultLabels)) + pod.SetLabels(utils.UnionMaps(k8sPluginCfg.DefaultLabels, pod.GetLabels(), utils.CopyMap(taskExecutionMetadata.GetLabels()))) pod.SetName(taskExecutionMetadata.GetTaskExecutionID().GetGeneratedName()) if !cfg.RemoteClusterConfig.Enabled {