Skip to content

Commit

Permalink
TAS: support rank-based ordering for JobSet
Browse files Browse the repository at this point in the history
  • Loading branch information
mimowo committed Nov 19, 2024
1 parent ec42d93 commit c999305
Show file tree
Hide file tree
Showing 4 changed files with 630 additions and 20 deletions.
96 changes: 78 additions & 18 deletions pkg/controller/tas/topology_ungater.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package tas
import (
"context"
"errors"
"fmt"
"slices"
"strconv"
"strings"
Expand All @@ -38,6 +39,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/handler"
"sigs.k8s.io/controller-runtime/pkg/predicate"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
jobset "sigs.k8s.io/jobset/api/jobset/v1alpha2"

configapi "sigs.k8s.io/kueue/apis/config/v1beta1"
kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
Expand All @@ -57,6 +59,11 @@ const (
ungateBatchPeriod = time.Second
)

type replicatedJobsInfo struct {
replicasCount int
jobIndexLabel string
}

var (
errPendingUngateOps = errors.New("pending ungate operations")
)
Expand Down Expand Up @@ -396,39 +403,92 @@ func assignGatedPodsToDomainsGreedy(
func readRanksIfAvailable(log logr.Logger,
psa *kueue.PodSetAssignment,
pods []*corev1.Pod) (map[int]*corev1.Pod, bool) {
if len(pods) == 0 {
// If there are no pods then we are done. We do this special check to
// ensure we have at least one pod as the code below determines if
// rank-ordering is enabled based on the first Pod.
return nil, false
}
if podIndexLabel, rjInfo := determineRanksLookup(pods[0]); podIndexLabel != nil {
result, err := readRanksForLabels(psa, pods, *podIndexLabel, rjInfo)
if err != nil {
log.Error(err, "failed to read rank information from Pods")
return nil, false
}
return result, true
}
// couldn't determine the labels to lookup the Pod ranks
return nil, false
}

func determineRanksLookup(pod *corev1.Pod) (*string, *replicatedJobsInfo) {
// Check if this is JobSet
if jobCount, _ := readIntFromLabel(pod, jobset.ReplicatedJobReplicas); jobCount != nil {
return ptr.To(batchv1.JobCompletionIndexAnnotation), &replicatedJobsInfo{
jobIndexLabel: jobset.JobIndexKey,
replicasCount: *jobCount,
}
}
// Check if this is batch/Job
if _, found := pod.Labels[batchv1.JobCompletionIndexAnnotation]; found {
return ptr.To(batchv1.JobCompletionIndexAnnotation), nil
}
return nil, nil
}

func readRanksForLabels(
psa *kueue.PodSetAssignment,
pods []*corev1.Pod,
podIndexLabel string,
rjInfo *replicatedJobsInfo,
) (map[int]*corev1.Pod, error) {
result := make(map[int]*corev1.Pod, 0)
count := int(*psa.Count)
podSetSize := int(*psa.Count)
for _, pod := range pods {
rank := readIntFromLabel(log, pod, batchv1.JobCompletionIndexAnnotation)
if rank == nil {
podIndex, err := readIntFromLabel(pod, podIndexLabel)
if err != nil {
// the Pod has no rank information - ranks cannot be used
return nil, false
return nil, err
}
if _, found := result[*rank]; found {
// there is a conflict in ranks, they cannot be used
return nil, false
rank := *podIndex
if rjInfo != nil {
jobIndex, err := readIntFromLabel(pod, rjInfo.jobIndexLabel)
if err != nil {
// the Pod has no Job index information - ranks cannot be used
return nil, err
}
singleJobSize := podSetSize / rjInfo.replicasCount
if *podIndex >= singleJobSize {
// the pod index exceeds size, this scenario is not
// supported by the rank-based ordering of pods.
return nil, fmt.Errorf("pod index %v of Pod %q exceeds the single Job size: %v", *podIndex, klog.KObj(pod), singleJobSize)
}
rank = *podIndex + *jobIndex*singleJobSize
}
if *rank >= count {
// the rank exceeds parallelism, this scenario is not supported by
// the rank-based ordering of pods.
return nil, false
if rank >= podSetSize {
// the rank exceeds the PodSet size, this scenario is not supported
// by the rank-based ordering of pods.
return nil, fmt.Errorf("rank %v of Pod %q exceeds PodSet size %v", rank, klog.KObj(pod), podSetSize)
}
result[*rank] = pod
if _, found := result[rank]; found {
// there is a conflict in ranks, they cannot be used
return nil, fmt.Errorf("conflicting rank %v found for pod %q", rank, klog.KObj(pod))
}
result[rank] = pod
}
return result, true
return result, nil
}

func readIntFromLabel(log logr.Logger, pod *corev1.Pod, labelKey string) *int {
func readIntFromLabel(pod *corev1.Pod, labelKey string) (*int, error) {
v, found := pod.Labels[labelKey]
if !found {
return nil
return nil, fmt.Errorf("no label %q for Pod %q", labelKey, klog.KObj(pod))
}
i, err := strconv.Atoi(v)
if err != nil {
log.Error(err, "failed to parse index annotation", "value", v)
return nil
return nil, fmt.Errorf("failed to parse label value %q for Pod %q", v, klog.KObj(pod))
}
return ptr.To(i)
return ptr.To(i), nil
}

func isAdmittedByTAS(w *kueue.Workload) bool {
Expand Down
Loading

0 comments on commit c999305

Please sign in to comment.