Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Pass K8sPluginConfig to spark driver and executor pods #patch #271

Merged
merged 13 commits into from
Nov 1, 2022
21 changes: 19 additions & 2 deletions go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
for _, envVar := range envVars {
sparkEnvVars[envVar.Name] = envVar.Value
}

for k, v := range config.GetK8sPluginConfig().DefaultEnvVarsFromEnv {
sparkEnvVars[k] = v
}
hamersaw marked this conversation as resolved.
Show resolved Hide resolved

sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))

serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
Expand All @@ -99,24 +104,34 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
}
driverSpec := sparkOp.DriverSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity,
Annotations: annotations,
Labels: labels,
EnvVars: sparkEnvVars,
Image: &container.Image,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
},
ServiceAccount: &serviceAccountName,
}

executorSpec := sparkOp.ExecutorSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity,
Annotations: annotations,
Labels: labels,
Image: &container.Image,
EnvVars: sparkEnvVars,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
},
}

Expand Down Expand Up @@ -225,11 +240,13 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
j.Spec.MainClass = &sparkJob.MainClass
}

// Add Tolerations/NodeSelector to only Executor pods.
// Add Interruptible Tolerations/NodeSelector only to Executor pods.
// The Interruptible NodeSelector takes precedence over the DefaultNodeSelector.
if taskCtx.TaskExecutionMetadata().IsInterruptible() {
j.Spec.Executor.Tolerations = config.GetK8sPluginConfig().InterruptibleTolerations
j.Spec.Executor.Tolerations = append(j.Spec.Executor.Tolerations, config.GetK8sPluginConfig().InterruptibleTolerations...)
j.Spec.Executor.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector
}

return j, nil
}

Expand Down
123 changes: 104 additions & 19 deletions go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,47 @@ func TestBuildResourceSpark(t *testing.T) {
dnsOptVal1 := "1"
dnsOptVal2 := "1"
dnsOptVal3 := "3"

// Set scheduler
schedulerName := "custom-scheduler"

// Node selectors
defaultNodeSelector := map[string]string{
"x/default": "true",
}
interruptibleNodeSelector := map[string]string{
"x/interruptible": "true",
}

defaultPodHostNetwork := true

defaultEnvVars := make(map[string]string)
defaultEnvVars["foo"] = "bar"
hamersaw marked this conversation as resolved.
Show resolved Hide resolved

defaultEnvVarsFromEnv := make(map[string]string)
defaultEnvVarsFromEnv["fooEnv"] = "barEnv"

// Default affinity/anti-affinity
defaultAffinity := &corev1.Affinity{
NodeAffinity: &corev1.NodeAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
NodeSelectorTerms: []corev1.NodeSelectorTerm{
{
MatchExpressions: []corev1.NodeSelectorRequirement{
{
Key: "x/default",
Operator: corev1.NodeSelectorOpIn,
Values: []string{"true"},
},
},
},
},
},
},
}

assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
DefaultAffinity: defaultAffinity,
DefaultPodSecurityContext: &corev1.PodSecurityContext{
RunAsUser: &runAsUser,
},
Expand All @@ -378,17 +418,29 @@ func TestBuildResourceSpark(t *testing.T) {
},
Searches: []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"},
},
InterruptibleNodeSelector: map[string]string{
"x/interruptible": "true",
DefaultTolerations: []corev1.Toleration{
{
Key: "x/flyte",
Value: "default",
Operator: "Equal",
Effect: "NoSchedule",
},
},
DefaultNodeSelector: defaultNodeSelector,
InterruptibleNodeSelector: interruptibleNodeSelector,
InterruptibleTolerations: []corev1.Toleration{
{
Key: "x/flyte",
Value: "interruptible",
Operator: "Equal",
Effect: "NoSchedule",
},
}}),
},
SchedulerName: schedulerName,
EnableHostNetworkingPod: &defaultPodHostNetwork,
DefaultEnvVars: defaultEnvVars,
DefaultEnvVarsFromEnv: defaultEnvVarsFromEnv,
}),
)
resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true))
assert.Nil(t, err)
Expand Down Expand Up @@ -438,19 +490,39 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)
assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
assert.Equal(t, dummySparkConf["spark.batchScheduler"], *sparkApp.Spec.BatchScheduler)

// Validate Interruptible Toleration and NodeSelector set for Executor but not Driver.
assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector))

assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, schedulerName, *sparkApp.Spec.Executor.SchedulerName)
assert.Equal(t, schedulerName, *sparkApp.Spec.Driver.SchedulerName)
assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Executor.HostNetwork)
assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Driver.HostNetwork)

// Validate:
// * Interruptible Toleration and NodeSelector are set for the Executor but not the Driver.
// * Default NodeSelector is set for the Driver but overwritten with the Interruptible NodeSelector for the Executor.
// * Default Tolerations are set for both Driver and Executor.
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
tolDriverDefault := sparkApp.Spec.Driver.Tolerations[0]
assert.Equal(t, tolDriverDefault.Key, "x/flyte")
assert.Equal(t, tolDriverDefault.Value, "default")
assert.Equal(t, tolDriverDefault.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tolDriverDefault.Effect, corev1.TaintEffect("NoSchedule"))

assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector))

tol := sparkApp.Spec.Executor.Tolerations[0]
assert.Equal(t, tol.Key, "x/flyte")
assert.Equal(t, tol.Value, "interruptible")
assert.Equal(t, tol.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tol.Effect, corev1.TaintEffect("NoSchedule"))
assert.Equal(t, interruptibleNodeSelector, sparkApp.Spec.Executor.NodeSelector)

tolExecDefault := sparkApp.Spec.Executor.Tolerations[0]
assert.Equal(t, tolExecDefault.Key, "x/flyte")
assert.Equal(t, tolExecDefault.Value, "default")
assert.Equal(t, tolExecDefault.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tolExecDefault.Effect, corev1.TaintEffect("NoSchedule"))

tolExecInterrupt := sparkApp.Spec.Executor.Tolerations[1]
assert.Equal(t, tolExecInterrupt.Key, "x/flyte")
assert.Equal(t, tolExecInterrupt.Value, "interruptible")
assert.Equal(t, tolExecInterrupt.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tolExecInterrupt.Effect, corev1.TaintEffect("NoSchedule"))
assert.Equal(t, "true", sparkApp.Spec.Executor.NodeSelector["x/interruptible"])

for confKey, confVal := range dummySparkConf {
Expand Down Expand Up @@ -485,6 +557,12 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"])

assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)
assert.Equal(t, sparkApp.Spec.Driver.EnvVars["foo"], defaultEnvVars["foo"])
assert.Equal(t, sparkApp.Spec.Executor.EnvVars["foo"], defaultEnvVars["foo"])
assert.Equal(t, sparkApp.Spec.Driver.EnvVars["fooEnv"], defaultEnvVarsFromEnv["fooEnv"])
assert.Equal(t, sparkApp.Spec.Executor.EnvVars["fooEnv"], defaultEnvVarsFromEnv["fooEnv"])
assert.Equal(t, sparkApp.Spec.Driver.Affinity, defaultAffinity)
assert.Equal(t, sparkApp.Spec.Executor.Affinity, defaultAffinity)

// Case 2: Driver/Executor request cores set.
dummyConfWithRequest := make(map[string]string)
Expand Down Expand Up @@ -514,10 +592,17 @@ func TestBuildResourceSpark(t *testing.T) {
assert.True(t, ok)

// Validate Interruptible Toleration and NodeSelector not set for both Driver and Executors.
assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, 0, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Executor.NodeSelector))
// Validate that the default Toleration and NodeSelector are set for both Driver and Executors.
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Executor.NodeSelector)
assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Key, "x/flyte")
assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Value, "default")
assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Key, "x/flyte")
assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Value, "default")

// Case 4: Invalid Spark Task-Template
taskTemplate.Custom = nil
Expand Down