Skip to content

Commit

Permalink
TAS: Add schedling gate for assigned PodTemplates
Browse files Browse the repository at this point in the history
  • Loading branch information
mimowo committed Oct 18, 2024
1 parent e676313 commit f783af8
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 18 deletions.
53 changes: 37 additions & 16 deletions pkg/podset/podset.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ import (
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"

kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/features"
utilmaps "sigs.k8s.io/kueue/pkg/util/maps"
)

Expand All @@ -40,12 +42,13 @@ var (
)

type PodSetInfo struct {
Name string
Count int32
Annotations map[string]string
Labels map[string]string
NodeSelector map[string]string
Tolerations []corev1.Toleration
Name string
Count int32
Annotations map[string]string
Labels map[string]string
NodeSelector map[string]string
Tolerations []corev1.Toleration
SchedulingGates []corev1.PodSchedulingGate
}

// FromAssignment returns a PodSetInfo based on the provided assignment and an error if unable
Expand All @@ -59,6 +62,11 @@ func FromAssignment(ctx context.Context, client client.Client, assignment *kueue
Labels: make(map[string]string),
Annotations: make(map[string]string),
}
if features.Enabled(features.TopologyAwareScheduling) && assignment.TopologyAssignment != nil {
info.SchedulingGates = append(info.SchedulingGates, corev1.PodSchedulingGate{
Name: kueuealpha.TopologySchedulingGate,
})
}
for _, flvRef := range assignment.Flavors {
if processedFlvs.Has(flvRef) {
continue
Expand Down Expand Up @@ -89,12 +97,13 @@ func FromUpdate(update *kueue.PodSetUpdate) PodSetInfo {
// FromPodSet returns a PodSeeInfo based on the provided PodSet
func FromPodSet(ps *kueue.PodSet) PodSetInfo {
return PodSetInfo{
Name: ps.Name,
Count: ps.Count,
Annotations: maps.Clone(ps.Template.Annotations),
Labels: maps.Clone(ps.Template.Labels),
NodeSelector: maps.Clone(ps.Template.Spec.NodeSelector),
Tolerations: slices.Clone(ps.Template.Spec.Tolerations),
Name: ps.Name,
Count: ps.Count,
Annotations: maps.Clone(ps.Template.Annotations),
Labels: maps.Clone(ps.Template.Labels),
NodeSelector: maps.Clone(ps.Template.Spec.NodeSelector),
Tolerations: slices.Clone(ps.Template.Spec.Tolerations),
SchedulingGates: slices.Clone(ps.Template.Spec.SchedulingGates),
}
}

Expand All @@ -118,6 +127,12 @@ func (podSetInfo *PodSetInfo) Merge(o PodSetInfo) error {
podSetInfo.Tolerations = append(podSetInfo.Tolerations, t)
}
}
// make sure we don't duplicate schedulingGates
for _, t := range o.SchedulingGates {
if slices.Index(podSetInfo.SchedulingGates, t) == -1 {
podSetInfo.SchedulingGates = append(podSetInfo.SchedulingGates, t)
}
}
return nil
}

Expand All @@ -135,10 +150,11 @@ func (podSetInfo *PodSetInfo) AddOrUpdateLabel(k, v string) {
// It returns error if there is a conflict.
func Merge(meta *metav1.ObjectMeta, spec *corev1.PodSpec, info PodSetInfo) error {
tmp := PodSetInfo{
Annotations: meta.Annotations,
Labels: meta.Labels,
NodeSelector: spec.NodeSelector,
Tolerations: spec.Tolerations,
Annotations: meta.Annotations,
Labels: meta.Labels,
NodeSelector: spec.NodeSelector,
Tolerations: spec.Tolerations,
SchedulingGates: spec.SchedulingGates,
}
if err := tmp.Merge(info); err != nil {
return err
Expand All @@ -147,6 +163,7 @@ func Merge(meta *metav1.ObjectMeta, spec *corev1.PodSpec, info PodSetInfo) error
meta.Labels = tmp.Labels
spec.NodeSelector = tmp.NodeSelector
spec.Tolerations = tmp.Tolerations
spec.SchedulingGates = tmp.SchedulingGates
return nil
}

Expand All @@ -170,6 +187,10 @@ func RestorePodSpec(meta *metav1.ObjectMeta, spec *corev1.PodSpec, info PodSetIn
spec.Tolerations = slices.Clone(info.Tolerations)
changed = true
}
if !slices.Equal(spec.SchedulingGates, info.SchedulingGates) {
spec.SchedulingGates = slices.Clone(info.SchedulingGates)
changed = true
}
return changed
}

Expand Down
123 changes: 121 additions & 2 deletions pkg/podset/podset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
package podset

import (
"context"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -26,7 +25,9 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/utils/ptr"

kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/features"
utiltesting "sigs.k8s.io/kueue/pkg/util/testing"
)

Expand Down Expand Up @@ -63,6 +64,8 @@ func TestFromAssignment(t *testing.T) {
Obj()

cases := map[string]struct {
enableTopologyAwareScheduling bool

assignment *kueue.PodSetAssignment
defaultCount int32
flavors []kueue.ResourceFlavor
Expand Down Expand Up @@ -163,10 +166,73 @@ func TestFromAssignment(t *testing.T) {
Tolerations: []corev1.Toleration{*toleration1.DeepCopy(), *toleration2.DeepCopy()},
},
},
"with topology assignment; TopologyAwareScheduling enabled - scheduling gate added": {
enableTopologyAwareScheduling: true,
assignment: &kueue.PodSetAssignment{
Name: "name",
Flavors: map[corev1.ResourceName]kueue.ResourceFlavorReference{
corev1.ResourceCPU: kueue.ResourceFlavorReference(flavor1.Name),
},
TopologyAssignment: &kueue.TopologyAssignment{
Levels: []string{"cloud.com/rack"},
Domains: []kueue.TopologyDomainAssignment{
{
Values: []string{"rack1"},
Count: 4,
},
},
},
},
defaultCount: 4,
flavors: []kueue.ResourceFlavor{*flavor1.DeepCopy()},
wantInfo: PodSetInfo{
Name: "name",
Count: 4,
NodeSelector: map[string]string{
"f1l1": "f1v1",
"f1l2": "f1v2",
},
Tolerations: []corev1.Toleration{*toleration1.DeepCopy(), *toleration2.DeepCopy()},
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: kueuealpha.TopologySchedulingGate,
},
},
},
},
"with topology assignment; TopologyAwareScheduling disabled - no scheduling gate added": {
assignment: &kueue.PodSetAssignment{
Name: "name",
Flavors: map[corev1.ResourceName]kueue.ResourceFlavorReference{
corev1.ResourceCPU: kueue.ResourceFlavorReference(flavor1.Name),
},
TopologyAssignment: &kueue.TopologyAssignment{
Levels: []string{"cloud.com/rack"},
Domains: []kueue.TopologyDomainAssignment{
{
Values: []string{"rack1"},
Count: 4,
},
},
},
},
defaultCount: 4,
flavors: []kueue.ResourceFlavor{*flavor1.DeepCopy()},
wantInfo: PodSetInfo{
Name: "name",
Count: 4,
NodeSelector: map[string]string{
"f1l1": "f1v1",
"f1l2": "f1v2",
},
Tolerations: []corev1.Toleration{*toleration1.DeepCopy(), *toleration2.DeepCopy()},
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
ctx := context.TODO()
ctx, _ := utiltesting.ContextWithLog(t)
features.SetFeatureGateDuringTest(t, features.TopologyAwareScheduling, tc.enableTopologyAwareScheduling)
client := utiltesting.NewClientBuilder().WithLists(&kueue.ResourceFlavorList{Items: tc.flavors}).Build()

gotInfo, gotError := FromAssignment(ctx, client, tc.assignment, tc.defaultCount)
Expand Down Expand Up @@ -309,6 +375,59 @@ func TestMergeRestore(t *testing.T) {
},
wantError: true,
},
"podset with scheduling gate; empty info": {
podSet: utiltesting.MakePodSet("", 1).
SchedulingGates(corev1.PodSchedulingGate{
Name: "example.com/gate",
}).
Obj(),
wantPodSet: utiltesting.MakePodSet("", 1).
SchedulingGates(corev1.PodSchedulingGate{
Name: "example.com/gate",
}).
Obj(),
},
"podset with scheduling gate; info re-adds the same": {
podSet: utiltesting.MakePodSet("", 1).
SchedulingGates(corev1.PodSchedulingGate{
Name: "example.com/gate",
}).
Obj(),
info: PodSetInfo{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate",
},
},
},
wantPodSet: utiltesting.MakePodSet("", 1).
SchedulingGates(corev1.PodSchedulingGate{
Name: "example.com/gate",
}).
Obj(),
},
"podset with scheduling gate; info adds another": {
podSet: utiltesting.MakePodSet("", 1).
SchedulingGates(corev1.PodSchedulingGate{
Name: "example.com/gate",
}).
Obj(),
info: PodSetInfo{
SchedulingGates: []corev1.PodSchedulingGate{
{
Name: "example.com/gate2",
},
},
},
wantPodSet: utiltesting.MakePodSet("", 1).
SchedulingGates(corev1.PodSchedulingGate{
Name: "example.com/gate",
}, corev1.PodSchedulingGate{
Name: "example.com/gate2",
}).
Obj(),
wantRestoreChanges: true,
},
}

for name, tc := range cases {
Expand Down

0 comments on commit f783af8

Please sign in to comment.