From f50d1a0092d735fb27270e90f4566a4f869c0853 Mon Sep 17 00:00:00 2001 From: lucaswzhang Date: Mon, 20 Dec 2021 17:32:34 +0800 Subject: [PATCH] add unit test for tf add amend function add test exit code add scale up and down cases --- pkg/common/util/v1/testutil/pod.go | 116 ++-- pkg/common/util/v1/testutil/service.go | 67 ++- pkg/common/util/v1/testutil/util.go | 23 +- pkg/controller.v1/tensorflow/job_test.go | 525 +++++++++++++++++ pkg/controller.v1/tensorflow/pod_test.go | 539 +++++++++++++++++ pkg/controller.v1/tensorflow/status_test.go | 611 ++++++++++++++++++++ pkg/controller.v1/tensorflow/suite_test.go | 52 +- pkg/controller.v1/tensorflow/util_test.go | 74 +++ 8 files changed, 1916 insertions(+), 91 deletions(-) create mode 100644 pkg/controller.v1/tensorflow/job_test.go create mode 100644 pkg/controller.v1/tensorflow/pod_test.go create mode 100644 pkg/controller.v1/tensorflow/status_test.go create mode 100644 pkg/controller.v1/tensorflow/util_test.go diff --git a/pkg/common/util/v1/testutil/pod.go b/pkg/common/util/v1/testutil/pod.go index adce63fa32..6b7b4620d7 100644 --- a/pkg/common/util/v1/testutil/pod.go +++ b/pkg/common/util/v1/testutil/pod.go @@ -15,81 +15,99 @@ package testutil import ( + "context" "fmt" - "testing" + "k8s.io/apimachinery/pkg/types" + "time" - v1 "k8s.io/api/core/v1" + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/tools/cache" - - tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" + "sigs.k8s.io/controller-runtime/pkg/client" ) const ( - // labels for pods and servers. - tfReplicaTypeLabel = "replica-type" - tfReplicaIndexLabel = "replica-index" + DummyContainerName = "dummy" + DummyContainerImage = "dummy/dummy:latest" ) -var ( - controllerKind = tfv1.GroupVersion.WithKind(TFJobKind) -) +func NewBasePod(name string, job metav1.Object, refs []metav1.OwnerReference) *corev1.Pod { -func NewBasePod(name string, tfJob *tfv1.TFJob) *v1.Pod { - return &v1.Pod{ + return &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: name, - Labels: GenLabels(tfJob.Name), - Namespace: tfJob.Namespace, - OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)}, + Labels: map[string]string{}, + Namespace: job.GetNamespace(), + OwnerReferences: refs, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: DummyContainerName, + Image: DummyContainerImage, + }, + }, }, } } -func NewPod(tfJob *tfv1.TFJob, typ string, index int) *v1.Pod { - pod := NewBasePod(fmt.Sprintf("%s-%d", typ, index), tfJob) - pod.Labels[tfReplicaTypeLabel] = typ - pod.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index) +func NewPod(job metav1.Object, typ string, index int, refs []metav1.OwnerReference) *corev1.Pod { + pod := NewBasePod(fmt.Sprintf("%s-%s-%d", job.GetName(), typ, index), job, refs) + pod.Labels[commonv1.ReplicaTypeLabelDeprecated] = typ + pod.Labels[commonv1.ReplicaTypeLabel] = typ + pod.Labels[commonv1.ReplicaIndexLabelDeprecated] = fmt.Sprintf("%d", index) + pod.Labels[commonv1.ReplicaIndexLabel] = fmt.Sprintf("%d", index) return pod } -// create count pods with the given phase for the given tfJob -func NewPodList(count int32, status v1.PodPhase, tfJob *tfv1.TFJob, typ string, start int32) []*v1.Pod { - pods := []*v1.Pod{} +// NewPodList create count pods with the given phase for the given tfJob +func NewPodList(count int32, status corev1.PodPhase, job 
metav1.Object, typ string, start int32, refs []metav1.OwnerReference) []*corev1.Pod { + pods := []*corev1.Pod{} for i := int32(0); i < count; i++ { - newPod := NewPod(tfJob, typ, int(start+i)) - newPod.Status = v1.PodStatus{Phase: status} + newPod := NewPod(job, typ, int(start+i), refs) + newPod.Status = corev1.PodStatus{Phase: status} pods = append(pods, newPod) } return pods } -func SetPodsStatuses(podIndexer cache.Indexer, tfJob *tfv1.TFJob, typ string, pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32, t *testing.T) { +func SetPodsStatusesV2(client client.Client, job metav1.Object, typ string, + pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32, + refs []metav1.OwnerReference, basicLabels map[string]string) { + timeout := 10 * time.Second + interval := 1000 * time.Millisecond var index int32 - for _, pod := range NewPodList(pendingPods, v1.PodPending, tfJob, typ, index) { - if err := podIndexer.Add(pod); err != nil { - t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) - } - } - index += pendingPods - for i, pod := range NewPodList(activePods, v1.PodRunning, tfJob, typ, index) { - if restartCounts != nil { - pod.Status.ContainerStatuses = []v1.ContainerStatus{{RestartCount: restartCounts[i]}} - } - if err := podIndexer.Add(pod); err != nil { - t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) - } + taskMap := map[corev1.PodPhase]int32{ + corev1.PodFailed: failedPods, + corev1.PodPending: pendingPods, + corev1.PodRunning: activePods, + corev1.PodSucceeded: succeededPods, } - index += activePods - for _, pod := range NewPodList(succeededPods, v1.PodSucceeded, tfJob, typ, index) { - if err := podIndexer.Add(pod); err != nil { - t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) - } - } - index += succeededPods - for _, pod := range NewPodList(failedPods, v1.PodFailed, tfJob, typ, index) { - if err := podIndexer.Add(pod); err != nil { - t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err) + ctx := context.Background() + + for podPhase, desiredCount := range taskMap { + for i, pod := range NewPodList(desiredCount, podPhase, job, typ, index, refs) { + for k, v := range basicLabels { + pod.Labels[k] = v + } + _ = client.Create(ctx, pod) + launcherKey := types.NamespacedName{ + Namespace: metav1.NamespaceDefault, + Name: pod.GetName(), + } + Eventually(func() error { + po := &corev1.Pod{} + if err := client.Get(ctx, launcherKey, po); err != nil { + return err + } + po.Status.Phase = podPhase + if podPhase == corev1.PodRunning && restartCounts != nil { + po.Status.ContainerStatuses = []corev1.ContainerStatus{{RestartCount: restartCounts[i]}} + } + return client.Status().Update(ctx, po) + }, timeout, interval).Should(BeNil()) } + index += desiredCount } } diff --git a/pkg/common/util/v1/testutil/service.go b/pkg/common/util/v1/testutil/service.go index 2bf6448f5f..d8786a108e 100644 --- a/pkg/common/util/v1/testutil/service.go +++ b/pkg/common/util/v1/testutil/service.go @@ -15,48 +15,71 @@ package testutil import ( + "context" "fmt" - "testing" - - v1 "k8s.io/api/core/v1" + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + . 
"github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/tools/cache" + "sigs.k8s.io/controller-runtime/pkg/client" +) - tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" +const ( + DummyPortName = "dummy" + DummyPort int32 = 1221 ) -func NewBaseService(name string, tfJob *tfv1.TFJob, t *testing.T) *v1.Service { - return &v1.Service{ +func NewBaseService(name string, job metav1.Object, refs []metav1.OwnerReference) *corev1.Service { + return &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: name, - Labels: GenLabels(tfJob.Name), - Namespace: tfJob.Namespace, - OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)}, + Labels: map[string]string{}, + Namespace: job.GetNamespace(), + OwnerReferences: refs, + }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + { + Name: DummyPortName, + Port: DummyPort, + }, + }, }, } } -func NewService(tfJob *tfv1.TFJob, typ string, index int, t *testing.T) *v1.Service { - service := NewBaseService(fmt.Sprintf("%s-%d", typ, index), tfJob, t) - service.Labels[tfReplicaTypeLabel] = typ - service.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index) - return service +func NewService(job metav1.Object, typ string, index int, refs []metav1.OwnerReference) *corev1.Service { + svc := NewBaseService(fmt.Sprintf("%s-%s-%d", job.GetName(), typ, index), job, refs) + svc.Labels[commonv1.ReplicaTypeLabelDeprecated] = typ + svc.Labels[commonv1.ReplicaTypeLabel] = typ + svc.Labels[commonv1.ReplicaIndexLabelDeprecated] = fmt.Sprintf("%d", index) + svc.Labels[commonv1.ReplicaIndexLabel] = fmt.Sprintf("%d", index) + return svc } // NewServiceList creates count pods with the given phase for the given tfJob -func NewServiceList(count int32, tfJob *tfv1.TFJob, typ string, t *testing.T) []*v1.Service { - services := []*v1.Service{} +func NewServiceList(count int32, job metav1.Object, typ string, refs []metav1.OwnerReference) []*corev1.Service { + services := []*corev1.Service{} for i := int32(0); i < count; i++ { - newService := NewService(tfJob, typ, int(i), t) + newService := NewService(job, typ, int(i), refs) services = append(services, newService) } return services } -func SetServices(serviceIndexer cache.Indexer, tfJob *tfv1.TFJob, typ string, activeWorkerServices int32, t *testing.T) { - for _, service := range NewServiceList(activeWorkerServices, tfJob, typ, t) { - if err := serviceIndexer.Add(service); err != nil { - t.Errorf("unexpected error when adding service %v", err) +func SetServicesV2(client client.Client, job metav1.Object, typ string, activeWorkerServices int32, + refs []metav1.OwnerReference, basicLabels map[string]string) { + ctx := context.Background() + for _, svc := range NewServiceList(activeWorkerServices, job, typ, refs) { + for k, v := range basicLabels { + svc.Labels[k] = v + } + err := client.Create(ctx, svc) + if errors.IsAlreadyExists(err) { + return + } else { + Expect(err).To(BeNil()) } } } diff --git a/pkg/common/util/v1/testutil/util.go b/pkg/common/util/v1/testutil/util.go index 5337ad04f2..14c02d06b2 100644 --- a/pkg/common/util/v1/testutil/util.go +++ b/pkg/common/util/v1/testutil/util.go @@ -15,10 +15,9 @@ package testutil import ( - "strings" "testing" - common "github.com/kubeflow/common/pkg/apis/common/v1" + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" 
v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -43,21 +42,13 @@ var ( ControllerName = "training-operator" ) -func GenLabels(jobName string) map[string]string { - return map[string]string{ - LabelGroupName: GroupName, - JobNameLabel: strings.Replace(jobName, "/", "-", -1), - DeprecatedLabelTFJobName: strings.Replace(jobName, "/", "-", -1), - } -} - -func GenOwnerReference(tfjob *tfv1.TFJob) *metav1.OwnerReference { +func GenOwnerReference(job metav1.Object, apiVersion string, kind string) *metav1.OwnerReference { boolPtr := func(b bool) *bool { return &b } controllerRef := &metav1.OwnerReference{ - APIVersion: tfv1.GroupVersion.Version, - Kind: TFJobKind, - Name: tfjob.Name, - UID: tfjob.UID, + APIVersion: apiVersion, + Kind: kind, + Name: job.GetName(), + UID: job.GetUID(), BlockOwnerDeletion: boolPtr(true), Controller: boolPtr(true), } @@ -85,7 +76,7 @@ func GetKey(tfJob *tfv1.TFJob, t *testing.T) string { return key } -func CheckCondition(tfJob *tfv1.TFJob, condition common.JobConditionType, reason string) bool { +func CheckCondition(tfJob *tfv1.TFJob, condition commonv1.JobConditionType, reason string) bool { for _, v := range tfJob.Status.Conditions { if v.Type == condition && v.Status == v1.ConditionTrue && v.Reason == reason { return true diff --git a/pkg/controller.v1/tensorflow/job_test.go b/pkg/controller.v1/tensorflow/job_test.go new file mode 100644 index 0000000000..b22d77d5b2 --- /dev/null +++ b/pkg/controller.v1/tensorflow/job_test.go @@ -0,0 +1,525 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +import ( + "context" + "fmt" + "time" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/controller.v1/common" + commonutil "github.com/kubeflow/common/pkg/util" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/uuid" + "sigs.k8s.io/controller-runtime/pkg/client" + + tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" + "github.com/kubeflow/training-operator/pkg/common/util/v1/testutil" +) + +var _ = Describe("TFJob controller", func() { + // Define utility constants for object names and testing timeouts/durations and intervals. 
+ const ( + timeout = 10 * time.Second + interval = 1000 * time.Millisecond + ) + + Context("Test Add TFJob", func() { + It("should get the exact TFJob", func() { + By("submitting an TFJob") + + testJobName := "test-case-12" + testNamespace := metav1.NamespaceDefault + + decoyJobName := "decoy-case-34" + + ctx := context.Background() + + tfJob := testutil.NewTFJob(1, 0) + tfJob.SetName(testJobName) + tfJob.SetNamespace(testNamespace) + + decoyJob := testutil.NewTFJob(2, 3) + decoyJob.SetName(decoyJobName) + decoyJob.SetNamespace(testNamespace) + + Expect(testK8sClient.Create(ctx, tfJob)).Should(Succeed()) + Expect(testK8sClient.Create(ctx, decoyJob)).Should(Succeed()) + + key := types.NamespacedName{ + Namespace: testNamespace, + Name: testJobName, + } + Eventually(func() error { + job := &tfv1.TFJob{} + return reconciler.Get(ctx, key, job) + }, timeout, interval).Should(BeNil()) + + Expect(testK8sClient.Delete(ctx, tfJob)).Should(Succeed()) + Expect(testK8sClient.Delete(ctx, decoyJob)).Should(Succeed()) + }) + }) + + Context("Test Copy Labels and Annotation", func() { + It("should copy labels and annotation from the spec to generated Pods", func() { + ctx := context.Background() + testAnnotationKey := "annotation1" + testAnnotationVal := "1" + testLabelKey := "label1" + testLabelVal := "1" + + testJobName := "test-copy-labels-anno" + tfjob := testutil.NewTFJob(1, 0) + tfjob.SetName(testJobName) + annotations := map[string]string{ + testAnnotationKey: testAnnotationVal, + } + labels := map[string]string{ + testLabelKey: testLabelVal, + } + tfjob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].Template.Labels = labels + tfjob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].Template.Annotations = annotations + + By("submitting an TFJob with specific labels and annotations") + Expect(testK8sClient.Create(ctx, tfjob)).Should(Succeed()) + + Eventually(func() error { + pod := &corev1.Pod{} + key := types.NamespacedName{ + Namespace: metav1.NamespaceDefault, + Name: common.GenGeneralName(tfjob.Name, "worker", "0"), + } + err := testK8sClient.Get(ctx, key, pod) + if err != nil { + return err + } + + if pod.Annotations == nil { + return fmt.Errorf("annotation of %s/%s is nil", pod.GetNamespace(), pod.GetName()) + } + if val, exist := pod.Annotations[testAnnotationKey]; exist { + if val != testAnnotationVal { + return fmt.Errorf("annotation of %s not match with %s", testAnnotationKey, testAnnotationVal) + } + } else { + return fmt.Errorf("annotation %s not found", testAnnotationKey) + } + + if pod.Labels == nil { + return fmt.Errorf("label of %s/%s is nil", pod.GetNamespace(), pod.GetName()) + } + if val, exist := pod.Labels[testLabelKey]; exist { + if val != testLabelVal { + return fmt.Errorf("annotation of %s not match with %s", testLabelKey, testLabelVal) + } + } else { + return fmt.Errorf("label %s not found", testLabelKey) + } + + return nil + }, timeout, interval).Should(BeNil()) + }) + }) + + Context("Test Delete Pods and Services", func() { + It("it should clean associated Pods and Services according to clean policy", func() { + type testCase struct { + description string + tfJob *tfv1.TFJob + + pendingWorkerPods int32 + activeWorkerPods int32 + succeededWorkerPods int32 + failedWorkerPods int32 + + pendingPSPods int32 + activePSPods int32 + succeededPSPods int32 + failedPSPods int32 + + activeWorkerServices int32 + activePSServices int32 + + expectedPodRemaining int + } + + testCases := []testCase{ + { + description: "4 workers and 2 ps is running, policy is all", + tfJob: 
testutil.NewTFJobWithCleanPolicy(0, 4, 2, commonv1.CleanPodPolicyAll), + + pendingWorkerPods: 0, + activeWorkerPods: 4, + succeededWorkerPods: 0, + failedWorkerPods: 0, + + pendingPSPods: 0, + activePSPods: 2, + succeededPSPods: 0, + failedPSPods: 0, + + activeWorkerServices: 4, + activePSServices: 2, + + expectedPodRemaining: 0, + }, + { + description: "4 workers and 2 ps is running, policy is running", + tfJob: testutil.NewTFJobWithCleanPolicy(0, 4, 2, commonv1.CleanPodPolicyRunning), + + pendingWorkerPods: 0, + activeWorkerPods: 4, + succeededWorkerPods: 0, + failedWorkerPods: 0, + + pendingPSPods: 0, + activePSPods: 2, + succeededPSPods: 0, + failedPSPods: 0, + + activeWorkerServices: 4, + activePSServices: 2, + + expectedPodRemaining: 0, + }, + { + description: "4 workers and 2 ps is succeeded, policy is running", + tfJob: testutil.NewTFJobWithCleanPolicy(0, 4, 2, commonv1.CleanPodPolicyRunning), + + pendingWorkerPods: 0, + activeWorkerPods: 0, + succeededWorkerPods: 4, + failedWorkerPods: 0, + + pendingPSPods: 0, + activePSPods: 0, + succeededPSPods: 2, + failedPSPods: 0, + + activeWorkerServices: 4, + activePSServices: 2, + + expectedPodRemaining: 6, + }, + { + description: "4 workers and 2 ps is succeeded, policy is None", + tfJob: testutil.NewTFJobWithCleanPolicy(0, 4, 2, commonv1.CleanPodPolicyNone), + + pendingWorkerPods: 0, + activeWorkerPods: 0, + succeededWorkerPods: 4, + failedWorkerPods: 0, + + pendingPSPods: 0, + activePSPods: 0, + succeededPSPods: 2, + failedPSPods: 0, + + activeWorkerServices: 4, + activePSServices: 2, + + expectedPodRemaining: 6, + }, + } + + jobNameTemplate := "test-del-pod-svc-%d" + for idx, tc := range testCases { + By(fmt.Sprintf("preparing cases %s", tc.description)) + ctx := context.Background() + tc.tfJob.SetName(fmt.Sprintf(jobNameTemplate, idx)) + tc.tfJob.SetUID(uuid.NewUUID()) + Expect(commonutil.UpdateJobConditions(&tc.tfJob.Status, commonv1.JobSucceeded, tfJobSucceededReason, "")).Should(Succeed()) + + refs := []metav1.OwnerReference{ + *reconciler.GenOwnerReference(tc.tfJob), + } + + basicLabels := reconciler.GenLabels(tc.tfJob.GetName()) + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: basicLabels, + }) + Expect(err).Should(BeNil()) + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + + By("creating Services and Pods with designed phases") + testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, + tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, + nil, refs, basicLabels) + testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelPS, + tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods, + nil, refs, basicLabels) + + testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, tc.activeWorkerServices, refs, basicLabels) + testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelPS, tc.activePSServices, refs, basicLabels) + + podList := &corev1.PodList{} + Expect(testK8sClient.List(ctx, podList, listOpt)).Should(Succeed()) + Expect(len(podList.Items)).To(Equal( + int(tc.pendingPSPods + tc.activePSPods + tc.failedPSPods + tc.succeededPSPods + + tc.pendingWorkerPods + tc.activeWorkerPods + tc.failedWorkerPods + tc.succeededWorkerPods))) + + By("calling ReconcileJob") + _ = reconciler.ReconcileJobs(tc.tfJob, tc.tfJob.Spec.TFReplicaSpecs, tc.tfJob.Status, &tc.tfJob.Spec.RunPolicy) + + podList = &corev1.PodList{} + Expect(testK8sClient.List(ctx, podList, listOpt, 
client.InNamespace(tc.tfJob.GetNamespace()))).Should(Succeed()) + podRemainingCount := len(podList.Items) + Expect(podRemainingCount).To(Equal(tc.expectedPodRemaining)) + + svcList := &corev1.ServiceList{} + Expect(testK8sClient.List(ctx, svcList, listOpt)).Should(Succeed()) + svcRemainingCount := len(svcList.Items) + Expect(svcRemainingCount).To(Equal(tc.expectedPodRemaining)) + } + }) + }) + + Context("Test Active Deadline Seconds", func() { + It("clean desired Pods and Services according to TFJob config", func() { + type testCase struct { + description string + tfJob *tfv1.TFJob + + pendingWorkerPods int32 + activeWorkerPods int32 + succeededWorkerPods int32 + failedWorkerPods int32 + + pendingPSPods int32 + activePSPods int32 + succeededPSPods int32 + failedPSPods int32 + + activeWorkerServices int32 + activePSServices int32 + + expectedPodRemaining int + } + + ads2 := int64(2) + adsTest2 := &ads2 + testCases := []testCase{ + { + description: "4 workers and 2 ps is running, ActiveDeadlineSeconds unset", + tfJob: testutil.NewTFJobWithActiveDeadlineSeconds(0, 4, 2, nil), + + pendingWorkerPods: 0, + activeWorkerPods: 4, + succeededWorkerPods: 0, + failedWorkerPods: 0, + + pendingPSPods: 0, + activePSPods: 2, + succeededPSPods: 0, + failedPSPods: 0, + + activeWorkerServices: 4, + activePSServices: 2, + + expectedPodRemaining: 6, + }, + { + description: "4 workers and 2 ps is running, ActiveDeadlineSeconds is 2", + tfJob: testutil.NewTFJobWithActiveDeadlineSeconds(0, 4, 2, adsTest2), + + pendingWorkerPods: 0, + activeWorkerPods: 4, + succeededWorkerPods: 0, + failedWorkerPods: 0, + + pendingPSPods: 0, + activePSPods: 2, + succeededPSPods: 0, + failedPSPods: 0, + + activeWorkerServices: 4, + activePSServices: 2, + + expectedPodRemaining: 0, + }, + } + jobNameTemplate := "test-ads-%d" + for idx, tc := range testCases { + By(fmt.Sprintf("preparing cases %s", tc.description)) + ctx := context.Background() + tc.tfJob.SetName(fmt.Sprintf(jobNameTemplate, idx)) + tc.tfJob.SetUID(uuid.NewUUID()) + + refs := []metav1.OwnerReference{ + *reconciler.GenOwnerReference(tc.tfJob), + } + + basicLabels := reconciler.GenLabels(tc.tfJob.GetName()) + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: basicLabels, + }) + Expect(err).Should(BeNil()) + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + + By("creating Services and Pods with designed phases") + testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, + tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, + nil, refs, basicLabels) + testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelPS, + tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods, + nil, refs, basicLabels) + + testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, tc.activeWorkerServices, refs, basicLabels) + testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelPS, tc.activePSServices, refs, basicLabels) + + podList := &corev1.PodList{} + Expect(testK8sClient.List(ctx, podList, listOpt)).Should(Succeed()) + Expect(len(podList.Items)).To(Equal( + int(tc.pendingPSPods + tc.activePSPods + tc.failedPSPods + tc.succeededPSPods + + tc.pendingWorkerPods + tc.activeWorkerPods + tc.failedWorkerPods + tc.succeededWorkerPods))) + + By("waiting enough time") + now := metav1.Now() + tc.tfJob.Status.StartTime = &now + ads := tc.tfJob.Spec.RunPolicy.ActiveDeadlineSeconds + if ads != nil { + dur := time.Second * time.Duration(*ads) + 
time.Sleep(dur) + } + + By("calling ReconcileJob") + _ = reconciler.ReconcileJobs(tc.tfJob, tc.tfJob.Spec.TFReplicaSpecs, tc.tfJob.Status, &tc.tfJob.Spec.RunPolicy) + + podList = &corev1.PodList{} + Expect(testK8sClient.List(ctx, podList, listOpt, client.InNamespace(tc.tfJob.GetNamespace()))).Should(Succeed()) + podRemainingCount := len(podList.Items) + Expect(podRemainingCount).To(Equal(tc.expectedPodRemaining)) + + svcList := &corev1.ServiceList{} + Expect(testK8sClient.List(ctx, svcList, listOpt)).Should(Succeed()) + svcRemainingCount := len(svcList.Items) + Expect(svcRemainingCount).To(Equal(tc.expectedPodRemaining)) + } + }) + }) + + Context("Test Backoff For On Failure(", func() { + It("clean desired Pods and Services according to TFJob config", func() { + type testCase struct { + description string + tfJob *tfv1.TFJob + + pendingWorkerPods int32 + activeWorkerPods int32 + succeededWorkerPods int32 + failedWorkerPods int32 + + restartCounts []int32 + + pendingPSPods int32 + activePSPods int32 + succeededPSPods int32 + failedPSPods int32 + + activeWorkerServices int32 + activePSServices int32 + + expectedPodRemaining int + } + + backoffLimit4 := int32(4) + backoffLimitTest4 := &backoffLimit4 + testCases := []testCase{ + { + description: "4 workers each having 1 restartCount and 2 ps is running, backoffLimit 4 ", + tfJob: testutil.NewTFJobWithBackoffLimit(0, 4, 2, backoffLimitTest4), + + pendingWorkerPods: 0, + activeWorkerPods: 4, + succeededWorkerPods: 0, + failedWorkerPods: 0, + + restartCounts: []int32{1, 1, 1, 1}, + + pendingPSPods: 0, + activePSPods: 2, + succeededPSPods: 0, + failedPSPods: 0, + + activeWorkerServices: 4, + activePSServices: 2, + + expectedPodRemaining: 0, + }, + } + + jobNameTemplate := "test-bof-%d" + for idx, tc := range testCases { + By(fmt.Sprintf("preparing cases %s", tc.description)) + ctx := context.Background() + tc.tfJob.SetName(fmt.Sprintf(jobNameTemplate, idx)) + tc.tfJob.SetUID(uuid.NewUUID()) + + refs := []metav1.OwnerReference{ + *reconciler.GenOwnerReference(tc.tfJob), + } + + basicLabels := reconciler.GenLabels(tc.tfJob.GetName()) + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: basicLabels, + }) + Expect(err).Should(BeNil()) + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + + By("creating Services and Pods with designed phases") + testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, + tc.pendingWorkerPods, tc.activeWorkerPods, tc.succeededWorkerPods, tc.failedWorkerPods, + tc.restartCounts, refs, basicLabels) + testutil.SetPodsStatusesV2(testK8sClient, tc.tfJob, testutil.LabelPS, + tc.pendingPSPods, tc.activePSPods, tc.succeededPSPods, tc.failedPSPods, + tc.restartCounts, refs, basicLabels) + + testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelWorker, tc.activeWorkerServices, refs, basicLabels) + testutil.SetServicesV2(testK8sClient, tc.tfJob, testutil.LabelPS, tc.activePSServices, refs, basicLabels) + + podList := &corev1.PodList{} + Expect(testK8sClient.List(ctx, podList, listOpt)).Should(Succeed()) + Expect(len(podList.Items)).To(Equal( + int(tc.pendingPSPods + tc.activePSPods + tc.failedPSPods + tc.succeededPSPods + + tc.pendingWorkerPods + tc.activeWorkerPods + tc.failedWorkerPods + tc.succeededWorkerPods))) + + By("calling ReconcileJob") + _ = reconciler.ReconcileJobs(tc.tfJob, tc.tfJob.Spec.TFReplicaSpecs, tc.tfJob.Status, &tc.tfJob.Spec.RunPolicy) + + podList = &corev1.PodList{} + Expect(testK8sClient.List(ctx, podList, listOpt, 
client.InNamespace(tc.tfJob.GetNamespace()))).Should(Succeed()) + podRemainingCount := len(podList.Items) + Expect(podRemainingCount).To(Equal(tc.expectedPodRemaining)) + + svcList := &corev1.ServiceList{} + Expect(testK8sClient.List(ctx, svcList, listOpt)).Should(Succeed()) + svcRemainingCount := len(svcList.Items) + Expect(svcRemainingCount).To(Equal(tc.expectedPodRemaining)) + } + }) + }) + +}) diff --git a/pkg/controller.v1/tensorflow/pod_test.go b/pkg/controller.v1/tensorflow/pod_test.go new file mode 100644 index 0000000000..a20c12215f --- /dev/null +++ b/pkg/controller.v1/tensorflow/pod_test.go @@ -0,0 +1,539 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +import ( + "context" + "fmt" + "os" + "time" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/core" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/uuid" + "sigs.k8s.io/controller-runtime/pkg/client" + + tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" + "github.com/kubeflow/training-operator/pkg/common/util/v1/testutil" +) + +var _ = Describe("TFJob controller", func() { + const ( + timeout = 10 * time.Second + interval = 1000 * time.Millisecond + ) + + Context("Test ClusterSpec", func() { + It("should generate desired cluster spec", func() { + type tc struct { + tfJob *tfv1.TFJob + rt string + index string + customClusterDomain string + expectedClusterSpec string + } + testCase := []tc{ + { + tfJob: testutil.NewTFJobWithNamespace(1, 0, "ns0"), + rt: "worker", + index: "0", + customClusterDomain: "", + expectedClusterSpec: "", + }, + { + tfJob: testutil.NewTFJobWithNamespace(1, 0, "ns1"), + rt: "worker", + index: "0", + customClusterDomain: "tf.training.com", + expectedClusterSpec: "", + }, + { + tfJob: testutil.NewTFJobWithNamespace(1, 1, "ns2"), + rt: "worker", + index: "0", + customClusterDomain: "tf.training.org", + expectedClusterSpec: `{"cluster":{"ps":["` + testutil.TestTFJobName + + `-ps-0.ns2.svc.tf.training.org:2222"],"worker":["` + testutil.TestTFJobName + + `-worker-0.ns2.svc.tf.training.org:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`, + }, + { + tfJob: testutil.NewTFJobWithEvaluatorAndNamespace(1, 1, 1, "ns3"), + rt: "worker", + index: "0", + customClusterDomain: "tf.training.io", + expectedClusterSpec: `{"cluster":{"evaluator":["` + testutil.TestTFJobName + + `-evaluator-0.ns3.svc.tf.training.io:2222"],"ps":["` + testutil.TestTFJobName + + `-ps-0.ns3.svc.tf.training.io:2222"],"worker":["` + testutil.TestTFJobName + + `-worker-0.ns3.svc.tf.training.io:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`, + }, + { + tfJob: testutil.NewTFJobWithEvaluatorAndNamespace(1, 1, 1, "ns3"), + rt: "worker", + index: 
"0", + customClusterDomain: "", + expectedClusterSpec: `{"cluster":{"evaluator":["` + testutil.TestTFJobName + + `-evaluator-0.ns3.svc:2222"],"ps":["` + testutil.TestTFJobName + + `-ps-0.ns3.svc:2222"],"worker":["` + testutil.TestTFJobName + + `-worker-0.ns3.svc:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`, + }, + } + + for _, c := range testCase { + c.tfJob.SetName("test-tfjob") + c.tfJob.SetUID(uuid.NewUUID()) + _ = os.Setenv(EnvCustomClusterDomain, c.customClusterDomain) + + podTemplate := c.tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].Template.DeepCopy() + + podTemplate.Name = core.GenGeneralName(c.tfJob.GetName(), c.rt, c.index) + + if podTemplate.Labels == nil { + podTemplate.Labels = map[string]string{} + } + + jobName := c.tfJob.GetName() + labels := reconciler.GenLabels(jobName) + labels[commonv1.ReplicaTypeLabelDeprecated] = c.rt + labels[commonv1.ReplicaTypeLabel] = c.rt + labels[commonv1.ReplicaIndexLabelDeprecated] = c.index + labels[commonv1.ReplicaIndexLabel] = c.index + + Expect(reconciler.SetClusterSpec(c.tfJob, podTemplate, c.rt, c.index)).Should(Succeed()) + + if c.expectedClusterSpec == "" { + Expect(len(podTemplate.Spec.Containers[0].Env)).Should(Equal(0)) + } else { + actual := podTemplate.Spec.Containers[0].Env[0].Value + reconciler.Log.Info("printing cluster spec", "expected", c.expectedClusterSpec, "actual pod", podTemplate) + Expect(actual).Should(Equal(c.expectedClusterSpec)) + } + } + }) + }) + + Context("Test IsDistributed", func() { + It("should returns correctly", func() { + type tc struct { + tfJob *tfv1.TFJob + expected bool + } + testCase := []tc{ + { + tfJob: testutil.NewTFJob(1, 0), + expected: false, + }, + { + tfJob: testutil.NewTFJob(1, 1), + expected: true, + }, + { + tfJob: testutil.NewTFJob(0, 1), + expected: false, + }, + { + tfJob: testutil.NewTFJobWithChief(1, 0), + expected: true, + }, + } + for _, c := range testCase { + Expect(isDistributed(c.tfJob)).To(Equal(c.expected)) + } + }) + }) + + Context("Test Restart Policy", func() { + It("should assign proper restart policy to pod", func() { + type tc struct { + tfJob *tfv1.TFJob + expectedRestartPolicy corev1.RestartPolicy + expectedType commonv1.ReplicaType + } + testCase := []tc{ + func() tc { + tfJob := testutil.NewTFJob(1, 0) + specRestartPolicy := commonv1.RestartPolicyExitCode + tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: corev1.RestartPolicyNever, + expectedType: tfv1.TFReplicaTypeWorker, + } + }(), + func() tc { + tfJob := testutil.NewTFJob(1, 0) + specRestartPolicy := commonv1.RestartPolicyNever + tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: corev1.RestartPolicyNever, + expectedType: tfv1.TFReplicaTypeWorker, + } + }(), + func() tc { + tfJob := testutil.NewTFJob(1, 0) + specRestartPolicy := commonv1.RestartPolicyAlways + tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: corev1.RestartPolicyAlways, + expectedType: tfv1.TFReplicaTypeWorker, + } + }(), + func() tc { + tfJob := testutil.NewTFJob(1, 0) + specRestartPolicy := commonv1.RestartPolicyOnFailure + tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].RestartPolicy = specRestartPolicy + return tc{ + tfJob: tfJob, + expectedRestartPolicy: corev1.RestartPolicyOnFailure, + expectedType: tfv1.TFReplicaTypeWorker, + } + }(), + } + for 
_, c := range testCase { + spec := c.tfJob.Spec.TFReplicaSpecs[c.expectedType] + podTemplate := spec.Template + setRestartPolicy(&podTemplate, spec) + Expect(podTemplate.Spec.RestartPolicy).To(Equal(c.expectedRestartPolicy)) + } + }) + }) + + Context("Test Exit Code", func() { + It("should delete designated Pod", func() { + By("Creating TFJob \"test-exit-code\" with 1 worker only") + ctx := context.Background() + + tfJob := testutil.NewTFJob(1, 0) + tfJob.SetName("test-exit-code") + tfJob.SetUID(uuid.NewUUID()) + tfJob.Spec.TFReplicaSpecs[tfv1.TFReplicaTypeWorker].RestartPolicy = commonv1.RestartPolicyExitCode + + refs := []metav1.OwnerReference{ + *reconciler.GenOwnerReference(tfJob), + } + By("creating worker Pod") + pod := testutil.NewPod(tfJob, testutil.LabelWorker, 0, refs) + basicLabels := reconciler.GenLabels(tfJob.GetName()) + for k, v := range basicLabels { + pod.Labels[k] = v + } + Expect(testK8sClient.Create(ctx, pod)).Should(Succeed()) + + po := &corev1.Pod{} + key := types.NamespacedName{Namespace: metav1.NamespaceDefault, Name: pod.GetName()} + Expect(testK8sClient.Get(ctx, key, po)).Should(Succeed()) + po.Status.Phase = corev1.PodFailed + po.Spec.Containers = append(pod.Spec.Containers, corev1.Container{}) + po.Status.ContainerStatuses = append(po.Status.ContainerStatuses, corev1.ContainerStatus{ + Name: tfv1.DefaultContainerName, + State: corev1.ContainerState{ + Terminated: &corev1.ContainerStateTerminated{ + ExitCode: 130, + }, + }, + }) + Expect(testK8sClient.Status().Update(ctx, po)) + + _ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy) + + Eventually(func() bool { + noPod := &corev1.Pod{} + err := testK8sClient.Get(ctx, key, noPod) + if err == nil { + return false + } + return errors.IsNotFound(err) + }, timeout, interval).Should(BeTrue()) + }) + }) + + Describe("Test Scale Down", func() { + It("should delete redundant Pods", func() { + ctx := context.Background() + + tfJob := testutil.NewTFJob(2, 0) + //tfJob.SelfLink = "/api/v1/namespaces/default/tfjob/test-tfjob" + tfJob.SetName("test-scale-down") + tfJob.SetUID(uuid.NewUUID()) + tfJob.Spec.EnableDynamicWorker = true + + refs := []metav1.OwnerReference{*reconciler.GenOwnerReference(tfJob)} + + pods := []*corev1.Pod{ + testutil.NewPod(tfJob, testutil.LabelWorker, 0, refs), + testutil.NewPod(tfJob, testutil.LabelWorker, 1, refs), + testutil.NewPod(tfJob, testutil.LabelWorker, 2, refs), + } + + for i := range pods { + pod := pods[i] + for k, v := range reconciler.GenLabels(tfJob.GetName()) { + pod.Labels[k] = v + } + Expect(testK8sClient.Create(ctx, pod)).Should(Succeed()) + } + + // Ensure the created Pods are all in cache + Eventually(func() error { + podList := &corev1.PodList{} + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: reconciler.GenLabels(tfJob.GetName()), + }) + if err != nil { + return err + } + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + err = testK8sClient.List(ctx, podList, listOpt) + if err != nil { + return err + } + if len(podList.Items) != 3 { + return fmt.Errorf("expecting %d Pods while got %d", 3, len(podList.Items)) + } + return nil + }, timeout, interval).Should(BeNil()) + + _ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy) + + noKey := types.NamespacedName{ + Namespace: metav1.NamespaceDefault, + Name: pods[2].GetName(), + } + Eventually(func() bool { + noPod := &corev1.Pod{} + err := testK8sClient.Get(ctx, noKey, noPod) + if err 
== nil { + return false + } + return errors.IsNotFound(err) + }, timeout, interval).Should(BeTrue()) + }) + }) + + Describe("Test Scale Up", func() { + It("should create missing Pods", func() { + ctx := context.Background() + + tfJob := testutil.NewTFJob(3, 0) + tfJob.SetName("test-scale-up") + tfJob.SetUID(uuid.NewUUID()) + tfJob.Spec.EnableDynamicWorker = true + + refs := []metav1.OwnerReference{*reconciler.GenOwnerReference(tfJob)} + + pods := []*corev1.Pod{ + testutil.NewPod(tfJob, testutil.LabelWorker, 0, refs), + } + + for i := range pods { + pod := pods[i] + for k, v := range reconciler.GenLabels(tfJob.GetName()) { + pod.Labels[k] = v + } + Expect(testK8sClient.Create(ctx, pod)).Should(Succeed()) + } + + // Ensure the created Pods are all in cache + Eventually(func() error { + podList := &corev1.PodList{} + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: reconciler.GenLabels(tfJob.GetName()), + }) + if err != nil { + return err + } + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + err = testK8sClient.List(ctx, podList, listOpt) + if err != nil { + return err + } + if len(podList.Items) != 1 { + return fmt.Errorf("before reconciling, expecting %d Pods while got %d", 1, len(podList.Items)) + } + return nil + }, timeout, interval).Should(BeNil()) + + _ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy) + + // Check if there are two more Pods created + Eventually(func() error { + podList := &corev1.PodList{} + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: reconciler.GenLabels(tfJob.GetName()), + }) + if err != nil { + return err + } + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + err = testK8sClient.List(ctx, podList, listOpt) + if err != nil { + return err + } + if len(podList.Items) != 3 { + return fmt.Errorf("after reconciling, expecting %d Pods while got %d", 3, len(podList.Items)) + } + return nil + }, timeout, interval).Should(BeNil()) + }) + }) + + Describe("TestIsWorker0Completed", func() { + It("should match expected result", func() { + newInt32 := func(in int32) *int32 { + return &in + } + tests := []struct { + // worker failed, succeeded, running num + workers [3]int32 + tfJob *tfv1.TFJob + replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec + expected bool + expectedErr bool + }{ + { + workers: [3]int32{0, 0, 1}, + tfJob: testutil.NewTFJobV2(1, 1, 0, 0, 0), + expected: false, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeWorker: { + Replicas: newInt32(1), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + tfv1.TFReplicaTypePS: { + Replicas: newInt32(1), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + { + workers: [3]int32{0, 1, 0}, + tfJob: testutil.NewTFJobV2(1, 0, 0, 0, 0), + expected: true, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeWorker: { + Replicas: newInt32(1), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + { + workers: [3]int32{0, 0, 0}, + tfJob: testutil.NewTFJobV2(0, 0, 1, 0, 0), + expected: true, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeMaster: { + Replicas: newInt32(1), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + { + workers: [3]int32{0, 0, 0}, + tfJob: testutil.NewTFJobV2(0, 0, 0, 1, 0), + expected: true, + expectedErr: false, + replicas: 
map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeChief: { + Replicas: newInt32(1), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + { + workers: [3]int32{1, 1, 0}, + tfJob: testutil.NewTFJobV2(2, 0, 0, 0, 0), + expected: true, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeWorker: { + Replicas: newInt32(2), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + { + workers: [3]int32{1, 0, 1}, + tfJob: testutil.NewTFJobV2(2, 0, 0, 0, 0), + expected: false, + expectedErr: false, + replicas: map[commonv1.ReplicaType]*commonv1.ReplicaSpec{ + tfv1.TFReplicaTypeWorker: { + Replicas: newInt32(2), + Template: testutil.NewTFReplicaSpecTemplate(), + }, + }, + }, + } + + jobNameTemplate := "test-worker0-complete-%d" + for i, tt := range tests { + tt.tfJob.SetName(fmt.Sprintf(jobNameTemplate, i)) + tt.tfJob.SetUID(uuid.NewUUID()) + // only related to worker status + initializeReplicaStatuses(&tt.tfJob.Status, tfv1.TFReplicaTypeWorker) + // set status and add pod to indexer + setStatusForTest(tt.tfJob, tfv1.TFReplicaTypeWorker, tt.workers[0], tt.workers[1], tt.workers[2], false, true, testK8sClient) + + // Adding this section to make sure all pods are created and cached + Eventually(func() error { + podList := &corev1.PodList{} + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: reconciler.GenLabels(tt.tfJob.GetName()), + }) + if err != nil { + return err + } + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + err = testK8sClient.List(context.Background(), podList, listOpt) + if err != nil { + return nil + } + totalExpectedPodCount := tt.workers[0] + tt.workers[1] + tt.workers[2] + if len(podList.Items) != int(totalExpectedPodCount) { + return fmt.Errorf("pod number (%d) for %s not match for expected pod number %d", + len(podList.Items), tt.tfJob.GetName(), totalExpectedPodCount) + } + return nil + }, timeout, interval).Should(BeNil()) + + got, err := reconciler.IsWorker0Completed(tt.tfJob, tt.replicas) + + if err != nil { + Expect(err).To(Equal(tt.expectedErr)) + } else { + Expect(got).To(Equal(tt.expected)) + } + } + }) + }) +}) diff --git a/pkg/controller.v1/tensorflow/status_test.go b/pkg/controller.v1/tensorflow/status_test.go new file mode 100644 index 0000000000..10d9c02869 --- /dev/null +++ b/pkg/controller.v1/tensorflow/status_test.go @@ -0,0 +1,611 @@ +// Copyright 2021 The Kubeflow Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +import ( + "context" + "fmt" + "time" + + "k8s.io/apimachinery/pkg/util/uuid" + + "k8s.io/apimachinery/pkg/types" + + commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" + "github.com/kubeflow/common/pkg/util" + . "github.com/onsi/ginkgo" + . 
"github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + + tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1" + "github.com/kubeflow/training-operator/pkg/common/util/v1/testutil" +) + +var _ = Describe("TFJob controller", func() { + // Define utility constants for object names and testing timeouts/durations and intervals. + const ( + timeout = 10 * time.Second + interval = 1000 * time.Millisecond + ) + + Context("Test Failed", func() { + It("should update TFJob with failed status", func() { + By("creating a TFJob with replicaStatues initialized") + tfJob := testutil.NewTFJob(3, 0) + initializeReplicaStatuses(&tfJob.Status, tfv1.TFReplicaTypeWorker) + + By("prepare pod") + refs := []metav1.OwnerReference{ + *reconciler.GenOwnerReference(tfJob), + } + pod := testutil.NewBasePod("pod", tfJob, refs) + pod.Status.Phase = v1.PodFailed + + By("update job replica statuses") + updateJobReplicaStatuses(&tfJob.Status, tfv1.TFReplicaTypeWorker, pod) + Expect(tfJob.Status.ReplicaStatuses[tfv1.TFReplicaTypeWorker].Failed).Should(Equal(int32(1))) + + By("update job status") + Expect(reconciler.UpdateJobStatus(tfJob, tfJob.Spec.TFReplicaSpecs, &tfJob.Status)).To(Succeed()) + + By("finding failed job status") + found := false + for _, condition := range tfJob.Status.Conditions { + if condition.Type == commonv1.JobFailed { + found = true + } + } + Expect(found).To(BeTrue()) + }) + }) + + Context("Test Status", func() { + It("should update TFJob with desired status", func() { + type testCase struct { + description string + tfJob *tfv1.TFJob + + expectedFailedPS int32 + expectedSucceededPS int32 + expectedActivePS int32 + + expectedFailedWorker int32 + expectedSucceededWorker int32 + expectedActiveWorker int32 + + expectedFailedChief int32 + expectedSucceededChief int32 + expectedActiveChief int32 + + restart bool + worker0Completed bool + + expectedType commonv1.JobConditionType + } + + testCases := []testCase{ + { + description: "Chief worker is succeeded", + tfJob: testutil.NewTFJobWithChief(1, 0), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 1, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 1, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobSucceeded, + }, + { + description: "Chief worker is running", + tfJob: testutil.NewTFJobWithChief(1, 0), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 0, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 1, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobRunning, + }, + { + description: "Chief worker is failed", + tfJob: testutil.NewTFJobWithChief(1, 0), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 0, + expectedActiveWorker: 0, + expectedFailedChief: 1, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobFailed, + }, + { + description: "(No chief worker) Worker is failed", + tfJob: testutil.NewTFJob(1, 0), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 1, + expectedSucceededWorker: 0, + 
expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobFailed, + }, + { + description: "(No chief worker) Worker is succeeded", + tfJob: testutil.NewTFJob(1, 0), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 1, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobSucceeded, + }, + { + description: "(No chief worker) Worker is running", + tfJob: testutil.NewTFJob(1, 0), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 0, + expectedActiveWorker: 1, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobRunning, + }, + { + description: "(No chief worker) 2 workers are succeeded, 2 workers are active", + tfJob: testutil.NewTFJob(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 0, + expectedSucceededWorker: 2, + expectedActiveWorker: 2, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobRunning, + }, + { + description: "(No chief worker) 2 workers are running, 2 workers are failed", + tfJob: testutil.NewTFJob(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 2, + expectedSucceededWorker: 0, + expectedActiveWorker: 2, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobFailed, + }, + { + description: "(No chief worker) 2 workers are succeeded, 2 workers are failed", + tfJob: testutil.NewTFJob(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 2, + expectedSucceededWorker: 2, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobFailed, + }, + { + description: "(No chief worker) worker-0 are succeeded, 3 workers are active", + tfJob: testutil.NewTFJob(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 0, + expectedSucceededWorker: 1, + expectedActiveWorker: 3, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: commonv1.JobSucceeded, + }, + { + description: "(No chief worker, successPolicy: AllWorkers) worker-0 are succeeded, 3 workers are active", + tfJob: testutil.NewTFJobWithSuccessPolicy(4, 0, tfv1.SuccessPolicyAllWorkers), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 1, + expectedActiveWorker: 3, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: commonv1.JobRunning, + }, + { + description: "(No chief worker, successPolicy: AllWorkers) 4 workers are succeeded", + tfJob: testutil.NewTFJobWithSuccessPolicy(4, 0, tfv1.SuccessPolicyAllWorkers), + expectedFailedPS: 0, + expectedSucceededPS: 0, + 
expectedActivePS: 0, + expectedFailedWorker: 0, + expectedSucceededWorker: 4, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: commonv1.JobSucceeded, + }, + { + description: "(No chief worker, successPolicy: AllWorkers) worker-0 is succeeded, 2 workers are running, 1 worker is failed", + tfJob: testutil.NewTFJobWithSuccessPolicy(4, 0, tfv1.SuccessPolicyAllWorkers), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 0, + expectedFailedWorker: 1, + expectedSucceededWorker: 1, + expectedActiveWorker: 2, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: true, + expectedType: commonv1.JobFailed, + }, + { + description: "Chief is running, workers are failed", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 4, + expectedSucceededWorker: 0, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 1, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobRunning, + }, + { + description: "Chief is running, workers are succeeded", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 0, + expectedSucceededWorker: 4, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 1, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobRunning, + }, + { + description: "Chief is running, a PS is failed", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 1, + expectedSucceededPS: 0, + expectedActivePS: 1, + expectedFailedWorker: 0, + expectedSucceededWorker: 4, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 0, + expectedActiveChief: 1, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobFailed, + }, + { + description: "Chief is failed, workers are succeeded", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 0, + expectedSucceededWorker: 4, + expectedActiveWorker: 0, + expectedFailedChief: 1, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobFailed, + }, + { + description: "Chief is succeeded, workers are failed", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 4, + expectedSucceededWorker: 0, + expectedActiveWorker: 0, + expectedFailedChief: 0, + expectedSucceededChief: 1, + expectedActiveChief: 0, + restart: false, + worker0Completed: false, + expectedType: commonv1.JobSucceeded, + }, + { + description: "Chief is failed and restarting", + tfJob: testutil.NewTFJobWithChief(4, 2), + expectedFailedPS: 0, + expectedSucceededPS: 0, + expectedActivePS: 2, + expectedFailedWorker: 4, + expectedSucceededWorker: 0, + expectedActiveWorker: 0, + expectedFailedChief: 1, + expectedSucceededChief: 0, + expectedActiveChief: 0, + restart: true, + worker0Completed: false, + expectedType: commonv1.JobRestarting, + }, + } + + jobNameTemplate := "test-status-%d" + for i, c := range testCases { + reconciler.Log.Info("testing case", "description", c.description) + c.tfJob.SetName(fmt.Sprintf(jobNameTemplate, 
i)) + c.tfJob.SetUID(uuid.NewUUID()) + + initializeReplicaStatuses(&c.tfJob.Status, tfv1.TFReplicaTypeWorker) + initializeReplicaStatuses(&c.tfJob.Status, tfv1.TFReplicaTypeChief) + initializeReplicaStatuses(&c.tfJob.Status, tfv1.TFReplicaTypePS) + + setStatusForTest(c.tfJob, tfv1.TFReplicaTypePS, c.expectedFailedPS, c.expectedSucceededPS, c.expectedActivePS, c.restart, c.worker0Completed, testK8sClient) + setStatusForTest(c.tfJob, tfv1.TFReplicaTypeWorker, c.expectedFailedWorker, c.expectedSucceededWorker, c.expectedActiveWorker, c.restart, c.worker0Completed, testK8sClient) + setStatusForTest(c.tfJob, tfv1.TFReplicaTypeChief, c.expectedFailedChief, c.expectedSucceededChief, c.expectedActiveChief, c.restart, c.worker0Completed, testK8sClient) + + // Adding this section to make sure all pods are created and cached + Eventually(func() error { + podList := &corev1.PodList{} + basicLabels := reconciler.GenLabels(c.tfJob.GetName()) + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: basicLabels, + }) + if err != nil { + return err + } + listOpt := client.MatchingLabelsSelector{ + Selector: selector, + } + err = testK8sClient.List(context.Background(), podList, listOpt) + if err != nil { + return nil + } + totalExpectedPodCount := c.expectedFailedPS + c.expectedSucceededPS + c.expectedActivePS + + c.expectedFailedWorker + c.expectedSucceededWorker + c.expectedActiveWorker + + c.expectedFailedChief + c.expectedSucceededChief + c.expectedActiveChief + if len(podList.Items) != int(totalExpectedPodCount) { + return fmt.Errorf("pod number (%d) for %s not match for expected pod number %d", + len(podList.Items), c.tfJob.GetName(), totalExpectedPodCount) + } + return nil + }, timeout, interval).Should(BeNil()) + + _ = reconciler.ReconcileJobs(c.tfJob, c.tfJob.Spec.TFReplicaSpecs, c.tfJob.Status, &c.tfJob.Spec.RunPolicy) + + Expect(filterOutConditionTest(c.tfJob.Status)).Should(Succeed()) + + reconciler.Log.Info("checking status", "tfJob.Status", c.tfJob.Status) + found := false + for _, condition := range c.tfJob.Status.Conditions { + if condition.Type == c.expectedType { + found = true + } + } + Expect(found).To(BeTrue()) + reconciler.Log.Info("passed!", + "job name", c.tfJob.GetName(), "job description", c.description) + } + }) + }) +}) + +func setStatusForTest(tfJob *tfv1.TFJob, rtype commonv1.ReplicaType, failed, succeeded, active int32, restart bool, worker0Completed bool, client client.Client) { + if restart == true { + tfJob.Spec.TFReplicaSpecs[rtype].RestartPolicy = commonv1.RestartPolicyExitCode + } + + basicLabels := reconciler.GenLabels(tfJob.GetName()) + + const ( + timeout = 10 * time.Second + interval = 1000 * time.Millisecond + ) + + ctx := context.Background() + + var typ string + switch rtype { + case tfv1.TFReplicaTypeWorker: + typ = testutil.LabelWorker + case tfv1.TFReplicaTypePS: + typ = testutil.LabelPS + case tfv1.TFReplicaTypeChief: + typ = testutil.LabelChief + default: + fmt.Println("wrong type") + } + refs := []metav1.OwnerReference{ + *reconciler.GenOwnerReference(tfJob), + } + + var i int32 + index := 0 + for i = 0; i < succeeded; i++ { + pod := testutil.NewPod(tfJob, typ, index, refs) + for k, v := range basicLabels { + pod.Labels[k] = v + } + po := &corev1.Pod{} + _ = client.Create(ctx, pod) + key := genKeyFromJob(pod) + Eventually(func() error { + if err := client.Get(ctx, key, po); err != nil { + return err + } + + po.Status.Phase = corev1.PodSucceeded + if worker0Completed == true && rtype == tfv1.TFReplicaTypeWorker && index == 0 { 
+				po.Status.ContainerStatuses = []corev1.ContainerStatus{
+					{
+						Name: tfv1.DefaultContainerName,
+						State: corev1.ContainerState{
+							Terminated: &corev1.ContainerStateTerminated{
+								ExitCode: int32(0), // exit with 0
+							},
+						},
+					},
+				}
+			}
+
+			return client.Status().Update(ctx, po)
+		}, timeout, interval).Should(BeNil())
+
+		updateJobReplicaStatuses(&tfJob.Status, rtype, po)
+
+		index++
+	}
+	for i = 0; i < failed; i++ {
+		pod := testutil.NewPod(tfJob, typ, index, refs)
+		for k, v := range basicLabels {
+			pod.Labels[k] = v
+		}
+		po := &corev1.Pod{}
+		_ = client.Create(ctx, pod)
+		key := genKeyFromJob(pod)
+		Eventually(func() error {
+			if err := client.Get(ctx, key, po); err != nil {
+				return err
+			}
+
+			po.Status.Phase = corev1.PodFailed
+			if restart {
+				if po.Status.ContainerStatuses == nil {
+					po.Status.ContainerStatuses = []corev1.ContainerStatus{
+						{
+							Name: tfv1.DefaultContainerName,
+							State: corev1.ContainerState{
+								Terminated: &corev1.ContainerStateTerminated{
+									ExitCode: int32(130), // 130 is a retryable exit code
+								},
+							},
+						},
+					}
+				}
+			}
+
+			return client.Status().Update(ctx, po)
+		}, timeout, interval).Should(BeNil())
+
+		updateJobReplicaStatuses(&tfJob.Status, rtype, po)
+		index++
+	}
+	for i = 0; i < active; i++ {
+		pod := testutil.NewPod(tfJob, typ, index, refs)
+		for k, v := range basicLabels {
+			pod.Labels[k] = v
+		}
+		po := &corev1.Pod{}
+		Expect(client.Create(ctx, pod)).Should(Succeed())
+		key := genKeyFromJob(pod)
+		Eventually(func() error {
+			if err := client.Get(ctx, key, po); err != nil {
+				return err
+			}
+
+			po.Status.Phase = corev1.PodRunning
+
+			return client.Status().Update(ctx, po)
+		}, timeout, interval).Should(BeNil())
+
+		updateJobReplicaStatuses(&tfJob.Status, rtype, po)
+		index++
+	}
+}
+
+func genKeyFromJob(job client.Object) types.NamespacedName {
+	ns := metav1.NamespaceDefault
+	if job.GetNamespace() != "" {
+		ns = job.GetNamespace()
+	}
+	return types.NamespacedName{
+		Namespace: ns,
+		Name:      job.GetName(),
+	}
+}
+
+func filterOutConditionTest(status commonv1.JobStatus) error {
+	flag := util.IsFailed(status) || util.IsSucceeded(status)
+	for _, condition := range status.Conditions {
+		if flag && condition.Type == commonv1.JobRunning && condition.Status == corev1.ConditionTrue {
+			return fmt.Errorf("unexpected condition: job has already succeeded or failed but JobRunning is still true")
+		}
+	}
+	return nil
+}
diff --git a/pkg/controller.v1/tensorflow/suite_test.go b/pkg/controller.v1/tensorflow/suite_test.go
index 640c6284c5..0be8a261fb 100644
--- a/pkg/controller.v1/tensorflow/suite_test.go
+++ b/pkg/controller.v1/tensorflow/suite_test.go
@@ -15,8 +15,14 @@ package tensorflow
 
 import (
+	"context"
+	"fmt"
+	corev1 "k8s.io/api/core/v1"
 	"path/filepath"
 	"testing"
+	"time"
+
+	ctrl "sigs.k8s.io/controller-runtime"
 
 	. "github.com/onsi/ginkgo"
 	. "github.com/onsi/gomega"
 
@@ -34,8 +40,13 @@ import (
 
 // These tests use Ginkgo (BDD-style Go testing framework). Refer to
 // http://onsi.github.io/ginkgo/ to learn more about Ginkgo.
-var k8sClient client.Client
-var testEnv *envtest.Environment
+var (
+	testK8sClient client.Client
+	testEnv       *envtest.Environment
+	testCtx       context.Context
+	testCancel    context.CancelFunc
+	reconciler    *TFJobReconciler
+)
 
 func TestAPIs(t *testing.T) {
 	RegisterFailHandler(Fail)
@@ -46,8 +57,14 @@ func TestAPIs(t *testing.T) {
 }
 
 var _ = BeforeSuite(func() {
+	const (
+		timeout  = 10 * time.Second
+		interval = 1000 * time.Millisecond
+	)
 	logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true)))
 
+	testCtx, testCancel = context.WithCancel(context.TODO())
+
 	By("bootstrapping test environment")
 	testEnv = &envtest.Environment{
 		CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "manifests", "base", "crds")},
@@ -63,14 +80,41 @@ var _ = BeforeSuite(func() {
 
 	//+kubebuilder:scaffold:scheme
 
-	k8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
+	testK8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme})
+	Expect(err).NotTo(HaveOccurred())
+	Expect(testK8sClient).NotTo(BeNil())
+
+	mgr, err := ctrl.NewManager(cfg, ctrl.Options{
+		MetricsBindAddress: "0",
+	})
 	Expect(err).NotTo(HaveOccurred())
-	Expect(k8sClient).NotTo(BeNil())
 
+	reconciler = NewReconciler(mgr, false)
+	Expect(reconciler.SetupWithManager(mgr)).NotTo(HaveOccurred())
+
+	go func() {
+		defer GinkgoRecover()
+		err = mgr.Start(testCtx)
+		Expect(err).ToNot(HaveOccurred(), "failed to run manager")
+	}()
+
+	// Make sure the manager's cache has started before any tests run.
+	Eventually(func() error {
+		nsList := &corev1.NamespaceList{}
+		if err := testK8sClient.List(context.Background(), nsList); err != nil {
+			return err
+		} else if len(nsList.Items) < 1 {
+			return fmt.Errorf("expected at least one namespace, got %d", len(nsList.Items))
+		}
+		return nil
+	}, timeout, interval).Should(BeNil())
 }, 60)
 
 var _ = AfterSuite(func() {
 	By("tearing down the test environment")
+	testCancel()
+	// Give the manager five seconds to shut down before stopping the test environment.
+	time.Sleep(5 * time.Second)
 	err := testEnv.Stop()
 	Expect(err).NotTo(HaveOccurred())
 })
diff --git a/pkg/controller.v1/tensorflow/util_test.go b/pkg/controller.v1/tensorflow/util_test.go
new file mode 100644
index 0000000000..e3d573cccb
--- /dev/null
+++ b/pkg/controller.v1/tensorflow/util_test.go
@@ -0,0 +1,76 @@
+// Copyright 2021 The Kubeflow Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
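+
+// util_test.go exercises the reconciler's GenOwnerReference and GenLabels helpers.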
+
+package tensorflow
+
+import (
+	"testing"
+
+	commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
+
+	tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/util/uuid"
+)
+
+func TestGenOwnerReference(t *testing.T) {
+	testName := "test-tfjob"
+	testUID := uuid.NewUUID()
+	tfJob := &tfv1.TFJob{
+		ObjectMeta: metav1.ObjectMeta{
+			Name: testName,
+			UID:  testUID,
+		},
+	}
+
+	ref := reconciler.GenOwnerReference(tfJob)
+	if ref.UID != testUID {
+		t.Errorf("Expected UID %s, got %s", testUID, ref.UID)
+	}
+	if ref.Name != testName {
+		t.Errorf("Expected Name %s, got %s", testName, ref.Name)
+	}
+	if ref.APIVersion != tfv1.SchemeGroupVersion.String() {
+		t.Errorf("Expected APIVersion %s, got %s", tfv1.SchemeGroupVersion.String(), ref.APIVersion)
+	}
+}
+
+func TestGenLabels(t *testing.T) {
+	testJobName := "test/key"
+	expectedVal := "test-key"
+
+	labels := reconciler.GenLabels(testJobName)
+	jobNameLabel := commonv1.JobNameLabel
+	jobNameLabelDeprecated := commonv1.JobNameLabelDeprecated
+
+	if labels[jobNameLabel] != expectedVal {
+		t.Errorf("Expected %s %s, got %s", jobNameLabel, expectedVal, labels[jobNameLabel])
+	}
+
+	if labels[jobNameLabelDeprecated] != expectedVal {
+		t.Errorf("Expected %s %s, got %s", jobNameLabelDeprecated, expectedVal, labels[jobNameLabelDeprecated])
+	}
+
+	if labels[commonv1.GroupNameLabelDeprecated] != tfv1.GroupVersion.Group {
+		t.Errorf("Expected %s %s, got %s", commonv1.GroupNameLabelDeprecated, tfv1.GroupVersion.Group,
+			labels[commonv1.GroupNameLabelDeprecated])
+	}
+
+	if labels[commonv1.OperatorNameLabel] != controllerName {
+		t.Errorf("Expected %s %s, got %s", commonv1.OperatorNameLabel, controllerName,
+			labels[commonv1.OperatorNameLabel])
+	}
+}