diff --git a/cluster-autoscaler/utils/tpu/tpu.go b/cluster-autoscaler/utils/tpu/tpu.go index 1929f51ba4b0..27bde6253136 100644 --- a/cluster-autoscaler/utils/tpu/tpu.go +++ b/cluster-autoscaler/utils/tpu/tpu.go @@ -17,18 +17,22 @@ limitations under the License. package tpu import ( + "strings" + apiv1 "k8s.io/api/core/v1" ) const ( - // ResourceTPU is the name of the TPU resource. - ResourceTPU = "cloud-tpus.google.com/v2" + // ResourceTPUPrefix is the prefix of the TPU resource names. + ResourceTPUPrefix = "cloud-tpus.google.com/" ) func hasTPURequest(pod *apiv1.Pod) bool { for _, container := range pod.Spec.Containers { - if _, found := container.Resources.Requests[ResourceTPU]; found { - return true + for name := range container.Resources.Requests { + if strings.HasPrefix(string(name), ResourceTPUPrefix) { + return true + } } } @@ -38,7 +42,11 @@ func hasTPURequest(pod *apiv1.Pod) bool { func clearTPURequest(pod *apiv1.Pod) *apiv1.Pod { sanitized := pod.DeepCopy() for _, container := range sanitized.Spec.Containers { - delete(container.Resources.Requests, ResourceTPU) + for name := range container.Resources.Requests { + if strings.HasPrefix(string(name), ResourceTPUPrefix) { + delete(container.Resources.Requests, name) + } + } } return sanitized diff --git a/cluster-autoscaler/utils/tpu/tpu_test.go b/cluster-autoscaler/utils/tpu/tpu_test.go index b8660ae09ed3..6446a3bd3f31 100644 --- a/cluster-autoscaler/utils/tpu/tpu_test.go +++ b/cluster-autoscaler/utils/tpu/tpu_test.go @@ -26,6 +26,11 @@ import ( apiv1 "k8s.io/api/core/v1" ) +var ( + ResourceTPUV2 = ResourceTPUPrefix + "v2" + ResourceTPUPreemptibleV2 = ResourceTPUPrefix + "preemptible-v2" +) + type requests map[apiv1.ResourceName]int64 type containerSpecs []requests @@ -53,16 +58,18 @@ func TestClearTPURequests(t *testing.T) { cpuPod := testPod("cpuPod", requests{apiv1.ResourceCPU: 10}) memoryPod := testPod("memoryPod", requests{apiv1.ResourceMemory: 100}) cpuMemoryPod := testPod("cpuMemoryPod", requests{apiv1.ResourceCPU: 10, apiv1.ResourceMemory: 30}, requests{apiv1.ResourceMemory: 20}) - tpuPod := testPod("tpuPod", requests{ResourceTPU: 1}) + tpuPod := testPod("tpuPod", requests{apiv1.ResourceName(ResourceTPUV2): 1}) sanitizedTPUPod := testPod("tpuPod", requests{}) - tpuMemoryPod := testPod("tpuMemoryPod", requests{ResourceTPU: 1, apiv1.ResourceMemory: 30}, requests{ResourceTPU: 2, apiv1.ResourceMemory: 13}) + preemptibleTPUPod := testPod("preemptibleTPUPod", requests{apiv1.ResourceName(ResourceTPUPreemptibleV2): 1}) + sanitizedPreemptibleTPUPod := testPod("preemptibleTPUPod", requests{}) + tpuMemoryPod := testPod("tpuMemoryPod", requests{apiv1.ResourceName(ResourceTPUV2): 1, apiv1.ResourceMemory: 30}, requests{apiv1.ResourceName(ResourceTPUV2): 2, apiv1.ResourceMemory: 13}) sanitizedTPUMemoryPod := testPod("tpuMemoryPod", requests{apiv1.ResourceMemory: 30}, requests{apiv1.ResourceMemory: 13}) podsWithoutTPUs := []*apiv1.Pod{cpuPod, memoryPod, cpuMemoryPod} - mixedPods := []*apiv1.Pod{cpuPod, tpuPod, memoryPod} - sanitizedMixedPods := []*apiv1.Pod{cpuPod, sanitizedTPUPod, memoryPod} - podsWithTPUs := []*apiv1.Pod{tpuPod, tpuMemoryPod} - sanitizedPodsWithTPUs := []*apiv1.Pod{sanitizedTPUPod, sanitizedTPUMemoryPod} + mixedPods := []*apiv1.Pod{cpuPod, tpuPod, preemptibleTPUPod, memoryPod} + sanitizedMixedPods := []*apiv1.Pod{cpuPod, sanitizedTPUPod, sanitizedPreemptibleTPUPod, memoryPod} + podsWithTPUs := []*apiv1.Pod{tpuPod, preemptibleTPUPod, tpuMemoryPod} + sanitizedPodsWithTPUs := []*apiv1.Pod{sanitizedTPUPod, sanitizedPreemptibleTPUPod, sanitizedTPUMemoryPod} testCases := []struct { desc string