Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Pass K8sPluginConfig to spark driver and executor pods #patch #271

Merged
Merged 13 commits on Nov 1, 2022.
18 changes: 17 additions & 1 deletion go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
for _, envVar := range envVars {
sparkEnvVars[envVar.Name] = envVar.Value
}

for k, v := range config.GetK8sPluginConfig().DefaultEnvVarsFromEnv {
sparkEnvVars[k] = v
}
[Review thread: hamersaw marked this conversation as resolved.]

sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))

serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
Expand All @@ -99,24 +104,34 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
}
driverSpec := sparkOp.DriverSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity,
Annotations: annotations,
Labels: labels,
EnvVars: sparkEnvVars,
Image: &container.Image,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
},
ServiceAccount: &serviceAccountName,
}

executorSpec := sparkOp.ExecutorSpec{
SparkPodSpec: sparkOp.SparkPodSpec{
Affinity: config.GetK8sPluginConfig().DefaultAffinity,
Annotations: annotations,
Labels: labels,
Image: &container.Image,
EnvVars: sparkEnvVars,
SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
},
}

Expand Down Expand Up @@ -227,9 +242,10 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo

// Add Tolerations/NodeSelector to only Executor pods.
if taskCtx.TaskExecutionMetadata().IsInterruptible() {
j.Spec.Executor.Tolerations = config.GetK8sPluginConfig().InterruptibleTolerations
j.Spec.Executor.Tolerations = append(j.Spec.Executor.Tolerations, config.GetK8sPluginConfig().InterruptibleTolerations...)
j.Spec.Executor.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector
}

return j, nil
}

Expand Down
115 changes: 98 additions & 17 deletions go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,47 @@ func TestBuildResourceSpark(t *testing.T) {
dnsOptVal1 := "1"
dnsOptVal2 := "1"
dnsOptVal3 := "3"

// Set scheduler
schedulerName := "custom-scheduler"

// Node selectors
defaultNodeSelector := map[string]string{
"x/default": "true",
}
interruptibleNodeSelector := map[string]string{
"x/interruptible": "true",
}

defaultPodHostNetwork := true

defaultEnvVars := make(map[string]string)
defaultEnvVars["foo"] = "bar"
[Review thread: hamersaw marked this conversation as resolved.]

defaultEnvVarsFromEnv := make(map[string]string)
defaultEnvVarsFromEnv["fooEnv"] = "barEnv"

// Default affinity/anti-affinity
defaultAffinity := &corev1.Affinity{
NodeAffinity: &corev1.NodeAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
NodeSelectorTerms: []corev1.NodeSelectorTerm{
{
MatchExpressions: []corev1.NodeSelectorRequirement{
{
Key: "x/default",
Operator: corev1.NodeSelectorOpIn,
Values: []string{"true"},
},
},
},
},
},
},
}

assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
DefaultAffinity: defaultAffinity,
DefaultPodSecurityContext: &corev1.PodSecurityContext{
RunAsUser: &runAsUser,
},
Expand All @@ -378,17 +418,29 @@ func TestBuildResourceSpark(t *testing.T) {
},
Searches: []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"},
},
InterruptibleNodeSelector: map[string]string{
"x/interruptible": "true",
DefaultTolerations: []corev1.Toleration{
{
Key: "x/flyte",
Value: "default",
Operator: "Equal",
Effect: "NoSchedule",
},
},
DefaultNodeSelector: defaultNodeSelector,
InterruptibleNodeSelector: interruptibleNodeSelector,
InterruptibleTolerations: []corev1.Toleration{
{
Key: "x/flyte",
Value: "interruptible",
Operator: "Equal",
Effect: "NoSchedule",
},
}}),
},
SchedulerName: schedulerName,
EnableHostNetworkingPod: &defaultPodHostNetwork,
DefaultEnvVars: defaultEnvVars,
DefaultEnvVarsFromEnv: defaultEnvVarsFromEnv,
}),
)
resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true))
assert.Nil(t, err)
Expand Down Expand Up @@ -438,19 +490,36 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)
assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
assert.Equal(t, dummySparkConf["spark.batchScheduler"], *sparkApp.Spec.BatchScheduler)
assert.Equal(t, schedulerName, *sparkApp.Spec.Executor.SchedulerName)
assert.Equal(t, schedulerName, *sparkApp.Spec.Driver.SchedulerName)
assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Executor.HostNetwork)
assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Driver.HostNetwork)

// Validate Interruptible Toleration and NodeSelector set for Executor but not Driver.
assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector))

assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
tolDriverDefault := sparkApp.Spec.Driver.Tolerations[0]
assert.Equal(t, tolDriverDefault.Key, "x/flyte")
assert.Equal(t, tolDriverDefault.Value, "default")
assert.Equal(t, tolDriverDefault.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tolDriverDefault.Effect, corev1.TaintEffect("NoSchedule"))

assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector))

tol := sparkApp.Spec.Executor.Tolerations[0]
assert.Equal(t, tol.Key, "x/flyte")
assert.Equal(t, tol.Value, "interruptible")
assert.Equal(t, tol.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tol.Effect, corev1.TaintEffect("NoSchedule"))
assert.Equal(t, interruptibleNodeSelector, sparkApp.Spec.Executor.NodeSelector)

tolExecDefault := sparkApp.Spec.Executor.Tolerations[0]
assert.Equal(t, tolExecDefault.Key, "x/flyte")
assert.Equal(t, tolExecDefault.Value, "default")
assert.Equal(t, tolExecDefault.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tolExecDefault.Effect, corev1.TaintEffect("NoSchedule"))

tolExecInterrupt := sparkApp.Spec.Executor.Tolerations[1]
assert.Equal(t, tolExecInterrupt.Key, "x/flyte")
assert.Equal(t, tolExecInterrupt.Value, "interruptible")
assert.Equal(t, tolExecInterrupt.Operator, corev1.TolerationOperator("Equal"))
assert.Equal(t, tolExecInterrupt.Effect, corev1.TaintEffect("NoSchedule"))
assert.Equal(t, "true", sparkApp.Spec.Executor.NodeSelector["x/interruptible"])

for confKey, confVal := range dummySparkConf {
Expand Down Expand Up @@ -485,6 +554,12 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"])

assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)
assert.Equal(t, sparkApp.Spec.Driver.EnvVars["foo"], defaultEnvVars["foo"])
assert.Equal(t, sparkApp.Spec.Executor.EnvVars["foo"], defaultEnvVars["foo"])
assert.Equal(t, sparkApp.Spec.Driver.EnvVars["fooEnv"], defaultEnvVarsFromEnv["fooEnv"])
assert.Equal(t, sparkApp.Spec.Executor.EnvVars["fooEnv"], defaultEnvVarsFromEnv["fooEnv"])
assert.Equal(t, sparkApp.Spec.Driver.Affinity, defaultAffinity)
assert.Equal(t, sparkApp.Spec.Executor.Affinity, defaultAffinity)

// Case 2: Driver/Executor request cores set.
dummyConfWithRequest := make(map[string]string)
Expand Down Expand Up @@ -514,10 +589,16 @@ func TestBuildResourceSpark(t *testing.T) {
assert.True(t, ok)

// Validate Interruptible Toleration and NodeSelector not set for both Driver and Executors.
assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, 0, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Executor.NodeSelector))
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Executor.NodeSelector)
assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Key, "x/flyte")
assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Value, "default")
assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Key, "x/flyte")
assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Value, "default")

// Case 4: Invalid Spark Task-Template
taskTemplate.Custom = nil
Expand Down