This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Pass K8sPluginConfig to spark driver and executor pods #patch #271

Merged: 13 commits, Nov 1, 2022
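For context, the sketch below shows the kind of K8sPluginConfig that this change now propagates into the Spark driver and executor pod specs. It is illustrative only: the field values are made up, the import paths assume the flyteplugins module layout, and the fields shown mirror the ones exercised in spark_test.go further down.

package main

import (
	corev1 "k8s.io/api/core/v1"

	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
)

func main() {
	hostNetwork := true

	// Illustrative values only. SchedulerName, HostNetwork, DefaultNodeSelector,
	// DefaultTolerations and DefaultAffinity now reach both the driver and the
	// executor SparkPodSpec; the Interruptible* fields apply to the executor only.
	cfg := &config.K8sPluginConfig{
		SchedulerName:           "custom-scheduler",
		EnableHostNetworkingPod: &hostNetwork,
		DefaultNodeSelector:     map[string]string{"x/default": "true"},
		DefaultTolerations: []corev1.Toleration{
			{Key: "x/flyte", Value: "default", Operator: "Equal", Effect: "NoSchedule"},
		},
		DefaultAffinity: &corev1.Affinity{
			NodeAffinity: &corev1.NodeAffinity{
				RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
					NodeSelectorTerms: []corev1.NodeSelectorTerm{{
						MatchExpressions: []corev1.NodeSelectorRequirement{{
							Key:      "x/default",
							Operator: corev1.NodeSelectorOpIn,
							Values:   []string{"true"},
						}},
					}},
				},
			},
		},
		InterruptibleNodeSelector: map[string]string{"x/interruptible": "true"},
		InterruptibleTolerations: []corev1.Toleration{
			{Key: "x/flyte", Value: "interruptible", Operator: "Equal", Effect: "NoSchedule"},
		},
		InterruptibleNodeSelectorRequirement: &corev1.NodeSelectorRequirement{
			Key:      "x/interruptible",
			Operator: corev1.NodeSelectorOpIn,
			Values:   []string{"true"},
		},
	}

	if err := config.SetK8sPluginConfig(cfg); err != nil {
		panic(err)
	}
}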
31 changes: 19 additions & 12 deletions go/tasks/pluginmachinery/flytek8s/pod_helper.go
@@ -24,8 +24,8 @@ const OOMKilled = "OOMKilled"
const Interrupted = "Interrupted"
const SIGKILL = 137

// ApplyInterruptibleNodeAffinity configures the node-affinity for the pod using the configuration specified.
func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
// ApplyInterruptibleNodeSelectorRequirement configures the node selector requirement of the node-affinity using the configuration specified.
func ApplyInterruptibleNodeSelectorRequirement(interruptible bool, affinity *v1.Affinity) {
// Determine node selector terms to add to node affinity
var nodeSelectorRequirement v1.NodeSelectorRequirement
if interruptible {
@@ -40,24 +40,31 @@ func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
nodeSelectorRequirement = *config.GetK8sPluginConfig().NonInterruptibleNodeSelectorRequirement
}

if podSpec.Affinity == nil {
podSpec.Affinity = &v1.Affinity{}
if affinity.NodeAffinity == nil {
affinity.NodeAffinity = &v1.NodeAffinity{}
}
if podSpec.Affinity.NodeAffinity == nil {
podSpec.Affinity.NodeAffinity = &v1.NodeAffinity{}
if affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution == nil {
affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution = &v1.NodeSelector{}
}
if podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution == nil {
podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution = &v1.NodeSelector{}
}
if len(podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms) > 0 {
nodeSelectorTerms := podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms
if len(affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms) > 0 {
nodeSelectorTerms := affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms
for i := range nodeSelectorTerms {
nst := &nodeSelectorTerms[i]
nst.MatchExpressions = append(nst.MatchExpressions, nodeSelectorRequirement)
}
} else {
podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []v1.NodeSelectorTerm{v1.NodeSelectorTerm{MatchExpressions: []v1.NodeSelectorRequirement{nodeSelectorRequirement}}}
affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []v1.NodeSelectorTerm{v1.NodeSelectorTerm{MatchExpressions: []v1.NodeSelectorRequirement{nodeSelectorRequirement}}}
}

}

// ApplyInterruptibleNodeAffinity configures the node-affinity for the pod using the configuration specified.
func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
if podSpec.Affinity == nil {
podSpec.Affinity = &v1.Affinity{}
}

ApplyInterruptibleNodeSelectorRequirement(interruptible, podSpec.Affinity)
}

// UpdatePod updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
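The refactor above splits the node-selector-requirement logic out of ApplyInterruptibleNodeAffinity so that callers which already manage their own Affinity object can reuse it; the Spark plugin below does exactly that for the executor pod. A minimal sketch of such a caller, assuming it lives in the flytek8s package (affinityForTask is hypothetical and not part of this diff):

func affinityForTask(interruptible bool) *v1.Affinity {
	// Copy the platform-wide default so the shared config value is never mutated.
	affinity := config.GetK8sPluginConfig().DefaultAffinity.DeepCopy()
	if affinity == nil {
		affinity = &v1.Affinity{}
	}

	// Appends the interruptible (or non-interruptible) NodeSelectorRequirement to
	// every existing NodeSelectorTerm, creating a single term if none exist yet.
	ApplyInterruptibleNodeSelectorRequirement(interruptible, affinity)
	return affinity
}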
20 changes: 18 additions & 2 deletions go/tasks/plugins/k8s/spark/spark.go
@@ -90,6 +90,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
for _, envVar := range envVars {
sparkEnvVars[envVar.Name] = envVar.Value
}

sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))

serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
@@ -99,24 +100,34 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
}
driverSpec := sparkOp.DriverSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity,
Annotations: annotations,
Labels: labels,
EnvVars: sparkEnvVars,
Image: &container.Image,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
},
ServiceAccount: &serviceAccountName,
}

executorSpec := sparkOp.ExecutorSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity.DeepCopy(),
Annotations: annotations,
Labels: labels,
Image: &container.Image,
EnvVars: sparkEnvVars,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
},
}

@@ -225,11 +236,16 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
j.Spec.MainClass = &sparkJob.MainClass
}

// Add Tolerations/NodeSelector to only Executor pods.
// Add Interruptible Tolerations/NodeSelector to only Executor pods.
// The Interruptible NodeSelector takes precedence over the DefaultNodeSelector
if taskCtx.TaskExecutionMetadata().IsInterruptible() {
j.Spec.Executor.Tolerations = config.GetK8sPluginConfig().InterruptibleTolerations
j.Spec.Executor.Tolerations = append(j.Spec.Executor.Tolerations, config.GetK8sPluginConfig().InterruptibleTolerations...)
j.Spec.Executor.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector
}

// Add interruptible/non-interruptible node selector requirements to executor pod
flytek8s.ApplyInterruptibleNodeSelectorRequirement(taskCtx.TaskExecutionMetadata().IsInterruptible(), j.Spec.Executor.Affinity)

return j, nil
}
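Net effect, mirroring the assertions in spark_test.go further down: the driver keeps the configured DefaultAffinity as-is, while the executor gets a deep copy of it with the interruptible or non-interruptible requirement appended to each node selector term. With the illustrative config sketched near the top of this page and an interruptible task, the executor affinity comes out roughly as follows (corev1 is k8s.io/api/core/v1):

expectedExecutorAffinity := &corev1.Affinity{
	NodeAffinity: &corev1.NodeAffinity{
		RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
			NodeSelectorTerms: []corev1.NodeSelectorTerm{{
				MatchExpressions: []corev1.NodeSelectorRequirement{
					// [0] carried over from DefaultAffinity.
					{Key: "x/default", Operator: corev1.NodeSelectorOpIn, Values: []string{"true"}},
					// [1] appended by ApplyInterruptibleNodeSelectorRequirement.
					{Key: "x/interruptible", Operator: corev1.NodeSelectorOpIn, Values: []string{"true"}},
				},
			}},
		},
	},
}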

170 changes: 151 additions & 19 deletions go/tasks/plugins/k8s/spark/spark_test.go
@@ -3,6 +3,7 @@ package spark
import (
"context"
"fmt"
"os"
"strconv"
"testing"

@@ -353,7 +354,67 @@ func TestBuildResourceSpark(t *testing.T) {
dnsOptVal1 := "1"
dnsOptVal2 := "1"
dnsOptVal3 := "3"

// Set scheduler
schedulerName := "custom-scheduler"

// Node selectors
defaultNodeSelector := map[string]string{
"x/default": "true",
}
interruptibleNodeSelector := map[string]string{
"x/interruptible": "true",
}

defaultPodHostNetwork := true

// Default env vars passed explicitly and default env vars derived from environment
defaultEnvVars := make(map[string]string)
defaultEnvVars["foo"] = "bar"

defaultEnvVarsFromEnv := make(map[string]string)
targetKeyFromEnv := "TEST_VAR_FROM_ENV_KEY"
targetValueFromEnv := "TEST_VAR_FROM_ENV_VALUE"
os.Setenv(targetKeyFromEnv, targetValueFromEnv)
defer os.Unsetenv(targetKeyFromEnv)
defaultEnvVarsFromEnv["fooEnv"] = targetKeyFromEnv

// Default affinity/anti-affinity
defaultAffinity := &corev1.Affinity{
NodeAffinity: &corev1.NodeAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
NodeSelectorTerms: []corev1.NodeSelectorTerm{
{
MatchExpressions: []corev1.NodeSelectorRequirement{
{
Key: "x/default",
Operator: corev1.NodeSelectorOpIn,
Values: []string{"true"},
},
},
},
},
},
},
}

// interruptible/non-interruptible nodeselector requirement
interruptibleNodeSelectorRequirement := &corev1.NodeSelectorRequirement{
Key: "x/interruptible",
Operator: corev1.NodeSelectorOpIn,
Values: []string{"true"},
}

nonInterruptibleNodeSelectorRequirement := &corev1.NodeSelectorRequirement{
Key: "x/non-interruptible",
Operator: corev1.NodeSelectorOpIn,
Values: []string{"true"},
}

// NonInterruptibleNodeSelectorRequirement

assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
DefaultAffinity: defaultAffinity,
DefaultPodSecurityContext: &corev1.PodSecurityContext{
RunAsUser: &runAsUser,
},
@@ -378,17 +439,31 @@ func TestBuildResourceSpark(t *testing.T) {
},
Searches: []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"},
},
InterruptibleNodeSelector: map[string]string{
"x/interruptible": "true",
DefaultTolerations: []corev1.Toleration{
{
Key: "x/flyte",
Value: "default",
Operator: "Equal",
Effect: "NoSchedule",
},
},
DefaultNodeSelector: defaultNodeSelector,
InterruptibleNodeSelector: interruptibleNodeSelector,
InterruptibleTolerations: []corev1.Toleration{
{
Key: "x/flyte",
Value: "interruptible",
Operator: "Equal",
Effect: "NoSchedule",
},
}}),
},
InterruptibleNodeSelectorRequirement: interruptibleNodeSelectorRequirement,
NonInterruptibleNodeSelectorRequirement: nonInterruptibleNodeSelectorRequirement,
SchedulerName: schedulerName,
EnableHostNetworkingPod: &defaultPodHostNetwork,
DefaultEnvVars: defaultEnvVars,
DefaultEnvVarsFromEnv: defaultEnvVarsFromEnv,
}),
)
resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true))
assert.Nil(t, err)
@@ -438,19 +513,40 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)
assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
assert.Equal(t, dummySparkConf["spark.batchScheduler"], *sparkApp.Spec.BatchScheduler)

// Validate Interruptible Toleration and NodeSelector set for Executor but not Driver.
assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector))

assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, schedulerName, *sparkApp.Spec.Executor.SchedulerName)
assert.Equal(t, schedulerName, *sparkApp.Spec.Driver.SchedulerName)
assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Executor.HostNetwork)
assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Driver.HostNetwork)

// Validate
// * Interruptible Toleration and NodeSelector set for Executor but not Driver.
// * Default NodeSelector set for Driver but overwritten with Interruptible NodeSelector for Executor.
// * Default Tolerations set for both Driver and Executor.
// * Interruptible/Non-Interruptible NodeSelectorRequirements set for Executor Affinity but not Driver Affinity.
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
tolDriverDefault := sparkApp.Spec.Driver.Tolerations[0]
assert.Equal(t, tolDriverDefault.Key, "x/flyte")
assert.Equal(t, tolDriverDefault.Value, "default")
assert.Equal(t, tolDriverDefault.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tolDriverDefault.Effect, corev1.TaintEffect("NoSchedule"))

assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector))

tol := sparkApp.Spec.Executor.Tolerations[0]
assert.Equal(t, tol.Key, "x/flyte")
assert.Equal(t, tol.Value, "interruptible")
assert.Equal(t, tol.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tol.Effect, corev1.TaintEffect("NoSchedule"))
assert.Equal(t, interruptibleNodeSelector, sparkApp.Spec.Executor.NodeSelector)

tolExecDefault := sparkApp.Spec.Executor.Tolerations[0]
assert.Equal(t, tolExecDefault.Key, "x/flyte")
assert.Equal(t, tolExecDefault.Value, "default")
assert.Equal(t, tolExecDefault.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tolExecDefault.Effect, corev1.TaintEffect("NoSchedule"))

tolExecInterrupt := sparkApp.Spec.Executor.Tolerations[1]
assert.Equal(t, tolExecInterrupt.Key, "x/flyte")
assert.Equal(t, tolExecInterrupt.Value, "interruptible")
assert.Equal(t, tolExecInterrupt.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tolExecInterrupt.Effect, corev1.TaintEffect("NoSchedule"))
assert.Equal(t, "true", sparkApp.Spec.Executor.NodeSelector["x/interruptible"])

for confKey, confVal := range dummySparkConf {
@@ -485,6 +581,22 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"])

assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)
assert.Equal(t, sparkApp.Spec.Driver.EnvVars["foo"], defaultEnvVars["foo"])
assert.Equal(t, sparkApp.Spec.Executor.EnvVars["foo"], defaultEnvVars["foo"])
assert.Equal(t, sparkApp.Spec.Driver.EnvVars["fooEnv"], targetValueFromEnv)
assert.Equal(t, sparkApp.Spec.Executor.EnvVars["fooEnv"], targetValueFromEnv)
assert.Equal(t, sparkApp.Spec.Driver.Affinity, defaultAffinity)

assert.Equal(
t,
sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
)
assert.Equal(
t,
sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1],
*interruptibleNodeSelectorRequirement,
)

// Case 2: Driver/Executor request cores set.
dummyConfWithRequest := make(map[string]string)
@@ -514,10 +626,30 @@ func TestBuildResourceSpark(t *testing.T) {
assert.True(t, ok)

// Validate Interruptible Toleration and NodeSelector not set for both Driver and Executors.
assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, 0, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Executor.NodeSelector))
// Validate that the default Toleration and NodeSelector are set for both Driver and Executors.
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Executor.NodeSelector)
assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Key, "x/flyte")
assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Value, "default")
assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Key, "x/flyte")
assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Value, "default")

// Validate correct affinity and nodeselector requirements are set for both Driver and Executors.
assert.Equal(t, sparkApp.Spec.Driver.Affinity, defaultAffinity)
assert.Equal(
t,
sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
)
assert.Equal(
t,
sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1],
*nonInterruptibleNodeSelectorRequirement,
)

// Case 4: Invalid Spark Task-Template
taskTemplate.Custom = nil