From 180088b2730c4ed8200d25baa97b696b09ba8855 Mon Sep 17 00:00:00 2001 From: Michal Wozniak Date: Tue, 15 Oct 2024 10:40:42 +0200 Subject: [PATCH] TAS: Introduce scheduling gate utils --- pkg/controller/jobs/pod/pod_controller.go | 51 +++--- pkg/controller/jobs/pod/pod_webhook.go | 5 +- pkg/util/pod/pod.go | 58 +++++++ pkg/util/pod/pod_test.go | 199 ++++++++++++++++++++++ 4 files changed, 278 insertions(+), 35 deletions(-) create mode 100644 pkg/util/pod/pod.go create mode 100644 pkg/util/pod/pod_test.go diff --git a/pkg/controller/jobs/pod/pod_controller.go b/pkg/controller/jobs/pod/pod_controller.go index d99a7c64a6..0ee17b65e1 100644 --- a/pkg/controller/jobs/pod/pod_controller.go +++ b/pkg/controller/jobs/pod/pod_controller.go @@ -56,13 +56,13 @@ import ( "sigs.k8s.io/kueue/pkg/util/kubeversion" "sigs.k8s.io/kueue/pkg/util/maps" "sigs.k8s.io/kueue/pkg/util/parallelize" + utilpod "sigs.k8s.io/kueue/pkg/util/pod" utilslices "sigs.k8s.io/kueue/pkg/util/slices" ) const ( SchedulingGateName = "kueue.x-k8s.io/admission" FrameworkName = "pod" - gateNotFound = -1 ConditionTypeTerminationTarget = "TerminationTarget" errMsgIncorrectGroupRoleCount = "pod group can't include more than 8 roles" IsGroupWorkloadAnnotationKey = "kueue.x-k8s.io/is-group-workload" @@ -177,23 +177,12 @@ func (p *Pod) Object() client.Object { return &p.pod } -// gateIndex returns the index of the Kueue scheduling gate for corev1.Pod. -// If the scheduling gate is not found, returns -1. 
-func gateIndex(p *corev1.Pod) int { - for i := range p.Spec.SchedulingGates { - if p.Spec.SchedulingGates[i].Name == SchedulingGateName { - return i - } - } - return gateNotFound -} - func isPodTerminated(p *corev1.Pod) bool { return p.Status.Phase == corev1.PodFailed || p.Status.Phase == corev1.PodSucceeded } func podSuspended(p *corev1.Pod) bool { - return isPodTerminated(p) || gateIndex(p) != gateNotFound + return isPodTerminated(p) || isGated(p) } func isUnretriablePod(pod corev1.Pod) bool { @@ -238,18 +227,6 @@ func (p *Pod) Suspend() { // Not implemented because this is not called when JobWithCustomStop is implemented. } -// ungatePod removes the kueue scheduling gate from the pod. -// Returns true if the pod has been ungated and false otherwise. -func ungatePod(pod *corev1.Pod) bool { - idx := gateIndex(pod) - if idx != gateNotFound { - pod.Spec.SchedulingGates = append(pod.Spec.SchedulingGates[:idx], pod.Spec.SchedulingGates[idx+1:]...) - return true - } - - return false -} - // Run will inject the node affinity and podSet counts extracting from workload to job and unsuspend it. 
func (p *Pod) Run(ctx context.Context, c client.Client, podSetsInfo []podset.PodSetInfo, recorder record.EventRecorder, msg string) error { log := ctrl.LoggerFrom(ctx) @@ -259,12 +236,12 @@ func (p *Pod) Run(ctx context.Context, c client.Client, podSetsInfo []podset.Pod return fmt.Errorf("%w: expecting 1 pod set got %d", podset.ErrInvalidPodsetInfo, len(podSetsInfo)) } - if gateIndex(&p.pod) == gateNotFound { + if !isGated(&p.pod) { return nil } if err := clientutil.Patch(ctx, c, &p.pod, true, func() (bool, error) { - ungatePod(&p.pod) + ungate(&p.pod) return true, podset.Merge(&p.pod.ObjectMeta, &p.pod.Spec, podSetsInfo[0]) }); err != nil { return err @@ -280,12 +257,12 @@ func (p *Pod) Run(ctx context.Context, c client.Client, podSetsInfo []podset.Pod return parallelize.Until(ctx, len(p.list.Items), func(i int) error { pod := &p.list.Items[i] - if gateIndex(pod) == gateNotFound { + if !isGated(pod) { return nil } if err := clientutil.Patch(ctx, c, pod, true, func() (bool, error) { - ungatePod(pod) + ungate(pod) roleHash, err := getRoleHash(*pod) if err != nil { @@ -854,8 +831,8 @@ func sortActivePods(activePods []corev1.Pod) { if iFin != jFin { return iFin } - iGated := gateIndex(pi) != gateNotFound - jGated := gateIndex(pj) != gateNotFound + iGated := isGated(pi) + jGated := isGated(pj) // Prefer to keep pods that aren't gated. 
if iGated != jGated { return !iGated @@ -1354,3 +1331,15 @@ func IsPodOwnerManagedByKueue(p *Pod) bool { func GetWorkloadNameForPod(podName string, podUID types.UID) string { return jobframework.GetWorkloadNameForOwnerWithGVK(podName, podUID, gvk) } + +func isGated(pod *corev1.Pod) bool { + return utilpod.HasGate(pod, SchedulingGateName) +} + +func ungate(pod *corev1.Pod) bool { + return utilpod.Ungate(pod, SchedulingGateName) +} + +func gate(pod *corev1.Pod) bool { + return utilpod.Gate(pod, SchedulingGateName) +} diff --git a/pkg/controller/jobs/pod/pod_webhook.go b/pkg/controller/jobs/pod/pod_webhook.go index 0879d16ff9..8927000cc9 100644 --- a/pkg/controller/jobs/pod/pod_webhook.go +++ b/pkg/controller/jobs/pod/pod_webhook.go @@ -179,10 +179,7 @@ func (w *PodWebhook) Default(ctx context.Context, obj runtime.Object) error { } pod.pod.Labels[ManagedLabelKey] = ManagedLabelValue - if gateIndex(&pod.pod) == gateNotFound { - log.V(5).Info("Adding gate") - pod.pod.Spec.SchedulingGates = append(pod.pod.Spec.SchedulingGates, corev1.PodSchedulingGate{Name: SchedulingGateName}) - } + gate(&pod.pod) if podGroupName(pod.pod) != "" { if err := pod.addRoleHash(); err != nil { diff --git a/pkg/util/pod/pod.go b/pkg/util/pod/pod.go new file mode 100644 index 0000000000..5a996ca45b --- /dev/null +++ b/pkg/util/pod/pod.go @@ -0,0 +1,58 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package pod + +import ( + "slices" + + corev1 "k8s.io/api/core/v1" +) + +// HasGate checks if the pod has a scheduling gate with a specified name. +func HasGate(pod *corev1.Pod, gateName string) bool { + return gateIndex(pod, gateName) >= 0 +} + +// Ungate removes scheduling gate from the Pod if present. +// Returns true if the pod has been updated and false otherwise. +func Ungate(pod *corev1.Pod, gateName string) bool { + if idx := gateIndex(pod, gateName); idx >= 0 { + pod.Spec.SchedulingGates = slices.Delete(pod.Spec.SchedulingGates, idx, idx+1) + return true + } + return false +} + +// Gate adds the scheduling gate to the Pod if it is not already present. +// Returns true if the pod has been updated and false otherwise. +func Gate(pod *corev1.Pod, gateName string) bool { + if !HasGate(pod, gateName) { + pod.Spec.SchedulingGates = append(pod.Spec.SchedulingGates, corev1.PodSchedulingGate{ + Name: gateName, + }) + return true + } + return false +} + +// gateIndex returns the index of the scheduling gate with the specified name in the Pod. +// If the scheduling gate is not found, returns -1. +func gateIndex(p *corev1.Pod, gateName string) int { + return slices.IndexFunc(p.Spec.SchedulingGates, func(g corev1.PodSchedulingGate) bool { + return g.Name == gateName + }) +} diff --git a/pkg/util/pod/pod_test.go b/pkg/util/pod/pod_test.go new file mode 100644 index 0000000000..41126ba13f --- /dev/null +++ b/pkg/util/pod/pod_test.go @@ -0,0 +1,199 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pod + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + corev1 "k8s.io/api/core/v1" +) + +func TestHasGate(t *testing.T) { + testCases := map[string]struct { + gateName string + pod corev1.Pod + want bool + }{ + "scheduling gate present": { + gateName: "example.com/gate", + pod: corev1.Pod{ + Spec: corev1.PodSpec{ + SchedulingGates: []corev1.PodSchedulingGate{ + { + Name: "example.com/gate", + }, + }, + }, + }, + want: true, + }, + "another gate present": { + gateName: "example.com/gate", + pod: corev1.Pod{ + Spec: corev1.PodSpec{ + SchedulingGates: []corev1.PodSchedulingGate{ + { + Name: "example.com/gate2", + }, + }, + }, + }, + want: false, + }, + "no scheduling gates": { + pod: corev1.Pod{}, + want: false, + }, + } + + for desc, tc := range testCases { + t.Run(desc, func(t *testing.T) { + got := HasGate(&tc.pod, tc.gateName) + if got != tc.want { + t.Errorf("Unexpected result: want=%v, got=%v", tc.want, got) + } + }) + } +} + +func TestUngate(t *testing.T) { + testCases := map[string]struct { + gateName string + pod corev1.Pod + wantPod corev1.Pod + want bool + }{ + "ungate when scheduling gate present": { + gateName: "example.com/gate", + pod: corev1.Pod{ + Spec: corev1.PodSpec{ + SchedulingGates: []corev1.PodSchedulingGate{ + { + Name: "example.com/gate", + }, + }, + }, + }, + wantPod: corev1.Pod{}, + want: true, + }, + "ungate when scheduling gate missing": { + gateName: "example.com/gate", + pod: corev1.Pod{ + Spec: corev1.PodSpec{ + SchedulingGates: []corev1.PodSchedulingGate{ + { + Name: "example.com/gate2", + }, + }, + }, + }, + wantPod: corev1.Pod{ + Spec: corev1.PodSpec{ + SchedulingGates: []corev1.PodSchedulingGate{ + { + Name: "example.com/gate2", + }, + }, + }, + }, + want: false, + }, + } + for desc, tc := range testCases { + t.Run(desc, func(t *testing.T) { + got := Ungate(&tc.pod, 
tc.gateName) + if got != tc.want { + t.Errorf("Unexpected result: want=%v, got=%v", tc.want, got) + } + if diff := cmp.Diff(tc.wantPod.Spec.SchedulingGates, tc.pod.Spec.SchedulingGates, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Unexpected scheduling gates\ndiff=%s", diff) + } + }) + } +} + +func TestGate(t *testing.T) { + testCases := map[string]struct { + gateName string + pod corev1.Pod + wantPod corev1.Pod + want bool + }{ + "gate when scheduling gate present": { + gateName: "example.com/gate", + pod: corev1.Pod{ + Spec: corev1.PodSpec{ + SchedulingGates: []corev1.PodSchedulingGate{ + { + Name: "example.com/gate", + }, + }, + }, + }, + wantPod: corev1.Pod{ + Spec: corev1.PodSpec{ + SchedulingGates: []corev1.PodSchedulingGate{ + { + Name: "example.com/gate", + }, + }, + }, + }, + want: false, + }, + "gate when scheduling gate missing": { + gateName: "example.com/gate", + pod: corev1.Pod{ + Spec: corev1.PodSpec{ + SchedulingGates: []corev1.PodSchedulingGate{ + { + Name: "example.com/gate2", + }, + }, + }, + }, + wantPod: corev1.Pod{ + Spec: corev1.PodSpec{ + SchedulingGates: []corev1.PodSchedulingGate{ + { + Name: "example.com/gate2", + }, + { + Name: "example.com/gate", + }, + }, + }, + }, + want: true, + }, + } + + for desc, tc := range testCases { + t.Run(desc, func(t *testing.T) { + got := Gate(&tc.pod, tc.gateName) + if got != tc.want { + t.Errorf("Unexpected result: want=%v, got=%v", tc.want, got) + } + if diff := cmp.Diff(tc.wantPod.Spec.SchedulingGates, tc.pod.Spec.SchedulingGates); diff != "" { + t.Errorf("Unexpected scheduling gates\ndiff=%s", diff) + } + }) + } +}