Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TAS: Introduce pod scheduling gate utils #3234

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 20 additions & 31 deletions pkg/controller/jobs/pod/pod_controller.go
Original file line number Diff line number Diff line change
@@ -56,13 +56,13 @@ import (
"sigs.k8s.io/kueue/pkg/util/kubeversion"
"sigs.k8s.io/kueue/pkg/util/maps"
"sigs.k8s.io/kueue/pkg/util/parallelize"
utilpod "sigs.k8s.io/kueue/pkg/util/pod"
utilslices "sigs.k8s.io/kueue/pkg/util/slices"
)

const (
SchedulingGateName = "kueue.x-k8s.io/admission"
FrameworkName = "pod"
gateNotFound = -1
ConditionTypeTerminationTarget = "TerminationTarget"
errMsgIncorrectGroupRoleCount = "pod group can't include more than 8 roles"
IsGroupWorkloadAnnotationKey = "kueue.x-k8s.io/is-group-workload"
@@ -177,23 +177,12 @@ func (p *Pod) Object() client.Object {
return &p.pod
}

// gateIndex returns the index of the Kueue scheduling gate for corev1.Pod.
// If the scheduling gate is not found, returns -1.
func gateIndex(p *corev1.Pod) int {
for i := range p.Spec.SchedulingGates {
if p.Spec.SchedulingGates[i].Name == SchedulingGateName {
return i
}
}
return gateNotFound
}

// isPodTerminated reports whether the pod's phase is terminal, i.e. either
// PodFailed or PodSucceeded.
func isPodTerminated(p *corev1.Pod) bool {
	switch p.Status.Phase {
	case corev1.PodFailed, corev1.PodSucceeded:
		return true
	default:
		return false
	}
}

func podSuspended(p *corev1.Pod) bool {
return isPodTerminated(p) || gateIndex(p) != gateNotFound
return isPodTerminated(p) || isGated(p)
}

func isUnretriablePod(pod corev1.Pod) bool {
@@ -238,18 +227,6 @@ func (p *Pod) Suspend() {
// Not implemented because this is not called when JobWithCustomStop is implemented.
}

// ungatePod removes the kueue scheduling gate from the pod.
// Returns true if the pod has been ungated and false otherwise.
func ungatePod(pod *corev1.Pod) bool {
idx := gateIndex(pod)
if idx != gateNotFound {
pod.Spec.SchedulingGates = append(pod.Spec.SchedulingGates[:idx], pod.Spec.SchedulingGates[idx+1:]...)
return true
}

return false
}

// Run will inject the node affinity and podSet counts extracting from workload to job and unsuspend it.
func (p *Pod) Run(ctx context.Context, c client.Client, podSetsInfo []podset.PodSetInfo, recorder record.EventRecorder, msg string) error {
log := ctrl.LoggerFrom(ctx)
@@ -259,12 +236,12 @@ func (p *Pod) Run(ctx context.Context, c client.Client, podSetsInfo []podset.Pod
return fmt.Errorf("%w: expecting 1 pod set got %d", podset.ErrInvalidPodsetInfo, len(podSetsInfo))
}

if gateIndex(&p.pod) == gateNotFound {
if !isGated(&p.pod) {
return nil
}

if err := clientutil.Patch(ctx, c, &p.pod, true, func() (bool, error) {
ungatePod(&p.pod)
ungate(&p.pod)
return true, podset.Merge(&p.pod.ObjectMeta, &p.pod.Spec, podSetsInfo[0])
}); err != nil {
return err
@@ -280,12 +257,12 @@ func (p *Pod) Run(ctx context.Context, c client.Client, podSetsInfo []podset.Pod
return parallelize.Until(ctx, len(p.list.Items), func(i int) error {
pod := &p.list.Items[i]

if gateIndex(pod) == gateNotFound {
if !isGated(pod) {
return nil
}

if err := clientutil.Patch(ctx, c, pod, true, func() (bool, error) {
ungatePod(pod)
ungate(pod)

roleHash, err := getRoleHash(*pod)
if err != nil {
@@ -854,8 +831,8 @@ func sortActivePods(activePods []corev1.Pod) {
if iFin != jFin {
return iFin
}
iGated := gateIndex(pi) != gateNotFound
jGated := gateIndex(pj) != gateNotFound
iGated := isGated(pi)
jGated := isGated(pj)
// Prefer to keep pods that aren't gated.
if iGated != jGated {
return !iGated
@@ -1354,3 +1331,15 @@ func IsPodOwnerManagedByKueue(p *Pod) bool {
func GetWorkloadNameForPod(podName string, podUID types.UID) string {
return jobframework.GetWorkloadNameForOwnerWithGVK(podName, podUID, gvk)
}

// isGated reports whether the pod has the Kueue scheduling gate
// (SchedulingGateName) in its spec.
func isGated(pod *corev1.Pod) bool {
return utilpod.HasGate(pod, SchedulingGateName)
}

// ungate removes the Kueue scheduling gate (SchedulingGateName) from the pod.
// Returns true if the pod has been updated and false if the gate was absent.
func ungate(pod *corev1.Pod) bool {
return utilpod.Ungate(pod, SchedulingGateName)
}

// gate adds the Kueue scheduling gate (SchedulingGateName) to the pod unless
// it is already present.
// Returns true if the pod has been updated and false otherwise.
func gate(pod *corev1.Pod) bool {
return utilpod.Gate(pod, SchedulingGateName)
}
5 changes: 1 addition & 4 deletions pkg/controller/jobs/pod/pod_webhook.go
Original file line number Diff line number Diff line change
@@ -179,10 +179,7 @@ func (w *PodWebhook) Default(ctx context.Context, obj runtime.Object) error {
}
pod.pod.Labels[ManagedLabelKey] = ManagedLabelValue

if gateIndex(&pod.pod) == gateNotFound {
log.V(5).Info("Adding gate")
pod.pod.Spec.SchedulingGates = append(pod.pod.Spec.SchedulingGates, corev1.PodSchedulingGate{Name: SchedulingGateName})
}
gate(&pod.pod)

if podGroupName(pod.pod) != "" {
if err := pod.addRoleHash(); err != nil {
58 changes: 58 additions & 0 deletions pkg/util/pod/pod.go
mimowo marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
Copyright The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package pod

import (
"slices"

corev1 "k8s.io/api/core/v1"
)

// HasGate reports whether the pod's spec contains a scheduling gate whose
// name equals gateName.
func HasGate(pod *corev1.Pod, gateName string) bool {
	return slices.ContainsFunc(pod.Spec.SchedulingGates, func(gate corev1.PodSchedulingGate) bool {
		return gate.Name == gateName
	})
}

// Ungate removes the scheduling gate named gateName from the Pod, if present.
// Only the first matching entry is removed; the remaining gates keep their
// relative order.
// Returns true if the pod has been updated and false otherwise.
func Ungate(pod *corev1.Pod, gateName string) bool {
	idx := gateIndex(pod, gateName)
	if idx < 0 {
		// Nothing to remove: leave the pod untouched.
		return false
	}
	pod.Spec.SchedulingGates = slices.Delete(pod.Spec.SchedulingGates, idx, idx+1)
	return true
}

// Gate adds a scheduling gate named gateName to the Pod unless the Pod
// already has one with that name. A new gate is appended after any existing
// gates.
// Returns true if the pod has been updated and false otherwise.
func Gate(pod *corev1.Pod, gateName string) bool {
	if HasGate(pod, gateName) {
		// Already gated: avoid adding a duplicate entry.
		return false
	}
	newGate := corev1.PodSchedulingGate{Name: gateName}
	pod.Spec.SchedulingGates = append(pod.Spec.SchedulingGates, newGate)
	return true
}

// gateIndex returns the index of the first scheduling gate named gateName in
// the pod's spec. If no such scheduling gate is found, returns -1.
func gateIndex(p *corev1.Pod, gateName string) int {
	hasName := func(gate corev1.PodSchedulingGate) bool {
		return gate.Name == gateName
	}
	return slices.IndexFunc(p.Spec.SchedulingGates, hasName)
}
199 changes: 199 additions & 0 deletions pkg/util/pod/pod_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
/*
Copyright The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package pod

import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
corev1 "k8s.io/api/core/v1"
)

// TestHasGate verifies that HasGate reports the presence of a scheduling gate
// by name, including when the gate is not the first entry in the list and
// when the pod has no scheduling gates at all.
func TestHasGate(t *testing.T) {
	testCases := map[string]struct {
		gateName string
		pod      corev1.Pod
		want     bool
	}{
		"scheduling gate present": {
			gateName: "example.com/gate",
			pod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate",
						},
					},
				},
			},
			want: true,
		},
		"another gate present": {
			gateName: "example.com/gate",
			pod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate2",
						},
					},
				},
			},
			want: false,
		},
		// Guards against an implementation that only inspects the first gate.
		"scheduling gate present among other gates": {
			gateName: "example.com/gate",
			pod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate2",
						},
						{
							Name: "example.com/gate",
						},
					},
				},
			},
			want: true,
		},
		"no scheduling gates": {
			// Search for a real gate name so the case exercises the same
			// lookup as the others, just against an empty list.
			gateName: "example.com/gate",
			pod:      corev1.Pod{},
			want:     false,
		},
	}

	for desc, tc := range testCases {
		t.Run(desc, func(t *testing.T) {
			got := HasGate(&tc.pod, tc.gateName)
			if got != tc.want {
				t.Errorf("Unexpected result: want=%v, got=%v", tc.want, got)
			}
		})
	}
}

// TestUngate verifies that Ungate removes only the requested scheduling gate,
// leaves the other gates (and their order) untouched, and reports whether the
// pod was modified.
func TestUngate(t *testing.T) {
	testCases := map[string]struct {
		gateName string
		pod      corev1.Pod
		wantPod  corev1.Pod
		want     bool
	}{
		"ungate when scheduling gate present": {
			gateName: "example.com/gate",
			pod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate",
						},
					},
				},
			},
			wantPod: corev1.Pod{},
			want:    true,
		},
		"ungate when scheduling gate missing": {
			gateName: "example.com/gate",
			pod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate2",
						},
					},
				},
			},
			wantPod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate2",
						},
					},
				},
			},
			want: false,
		},
		// The removed gate sits between two others, so both the preceding and
		// the following entries must survive in their original order.
		"ungate preserves other scheduling gates": {
			gateName: "example.com/gate",
			pod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate2",
						},
						{
							Name: "example.com/gate",
						},
						{
							Name: "example.com/gate3",
						},
					},
				},
			},
			wantPod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate2",
						},
						{
							Name: "example.com/gate3",
						},
					},
				},
			},
			want: true,
		},
		"ungate when no scheduling gates": {
			gateName: "example.com/gate",
			pod:      corev1.Pod{},
			wantPod:  corev1.Pod{},
			want:     false,
		},
	}
	for desc, tc := range testCases {
		t.Run(desc, func(t *testing.T) {
			got := Ungate(&tc.pod, tc.gateName)
			if got != tc.want {
				t.Errorf("Unexpected result: want=%v, got=%v", tc.want, got)
			}
			// EquateEmpty: removing the last gate leaves an empty, non-nil
			// slice, which should compare equal to the nil slice in wantPod.
			if diff := cmp.Diff(tc.wantPod.Spec.SchedulingGates, tc.pod.Spec.SchedulingGates, cmpopts.EquateEmpty()); diff != "" {
				t.Errorf("Unexpected scheduling gates\ndiff=%s", diff)
			}
		})
	}
}

// TestGate verifies that Gate appends the requested scheduling gate only when
// it is not already present, and reports whether the pod was modified.
func TestGate(t *testing.T) {
	testCases := map[string]struct {
		gateName string
		pod      corev1.Pod
		wantPod  corev1.Pod
		want     bool
	}{
		"gate when scheduling gate present": {
			gateName: "example.com/gate",
			pod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate",
						},
					},
				},
			},
			wantPod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate",
						},
					},
				},
			},
			want: false,
		},
		"gate when scheduling gate missing": {
			gateName: "example.com/gate",
			pod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate2",
						},
					},
				},
			},
			wantPod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate2",
						},
						{
							Name: "example.com/gate",
						},
					},
				},
			},
			want: true,
		},
		// A pod with a nil gate list is the common starting point: the gate
		// must be added as the only entry.
		"gate when no scheduling gates": {
			gateName: "example.com/gate",
			pod:      corev1.Pod{},
			wantPod: corev1.Pod{
				Spec: corev1.PodSpec{
					SchedulingGates: []corev1.PodSchedulingGate{
						{
							Name: "example.com/gate",
						},
					},
				},
			},
			want: true,
		},
	}

	for desc, tc := range testCases {
		t.Run(desc, func(t *testing.T) {
			got := Gate(&tc.pod, tc.gateName)
			if got != tc.want {
				t.Errorf("Unexpected result: want=%v, got=%v", tc.want, got)
			}
			if diff := cmp.Diff(tc.wantPod.Spec.SchedulingGates, tc.pod.Spec.SchedulingGates); diff != "" {
				t.Errorf("Unexpected scheduling gates\ndiff=%s", diff)
			}
		})
	}
}