Skip to content

Commit

Permalink
add unit test for tf
Browse files Browse the repository at this point in the history
add amend function

add test exit code

add scale up and down cases
  • Loading branch information
zw0610 committed Dec 21, 2021
1 parent b51bfda commit f50d1a0
Show file tree
Hide file tree
Showing 8 changed files with 1,916 additions and 91 deletions.
116 changes: 67 additions & 49 deletions pkg/common/util/v1/testutil/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,81 +15,99 @@
package testutil

import (
"context"
"fmt"
"testing"
"k8s.io/apimachinery/pkg/types"
"time"

v1 "k8s.io/api/core/v1"
commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
. "github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/tools/cache"

tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
)

const (
// Label keys for pods and services: the replica type and the replica index.
tfReplicaTypeLabel = "replica-type"
tfReplicaIndexLabel = "replica-index"
// DummyContainerName and DummyContainerImage are the name and image of the
// single placeholder container that NewBasePod puts into every test pod.
DummyContainerName = "dummy"
DummyContainerImage = "dummy/dummy:latest"
)

var (
// controllerKind is the TFJob kind qualified with tfv1.GroupVersion.
controllerKind = tfv1.GroupVersion.WithKind(TFJobKind)
)
func NewBasePod(name string, job metav1.Object, refs []metav1.OwnerReference) *corev1.Pod {

func NewBasePod(name string, tfJob *tfv1.TFJob) *v1.Pod {
return &v1.Pod{
return &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Labels: GenLabels(tfJob.Name),
Namespace: tfJob.Namespace,
OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)},
Labels: map[string]string{},
Namespace: job.GetNamespace(),
OwnerReferences: refs,
},
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: DummyContainerName,
Image: DummyContainerImage,
},
},
},
}
}

func NewPod(tfJob *tfv1.TFJob, typ string, index int) *v1.Pod {
pod := NewBasePod(fmt.Sprintf("%s-%d", typ, index), tfJob)
pod.Labels[tfReplicaTypeLabel] = typ
pod.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index)
func NewPod(job metav1.Object, typ string, index int, refs []metav1.OwnerReference) *corev1.Pod {
pod := NewBasePod(fmt.Sprintf("%s-%s-%d", job.GetName(), typ, index), job, refs)
pod.Labels[commonv1.ReplicaTypeLabelDeprecated] = typ
pod.Labels[commonv1.ReplicaTypeLabel] = typ
pod.Labels[commonv1.ReplicaIndexLabelDeprecated] = fmt.Sprintf("%d", index)
pod.Labels[commonv1.ReplicaIndexLabel] = fmt.Sprintf("%d", index)
return pod
}

// create count pods with the given phase for the given tfJob
func NewPodList(count int32, status v1.PodPhase, tfJob *tfv1.TFJob, typ string, start int32) []*v1.Pod {
pods := []*v1.Pod{}
// NewPodList create count pods with the given phase for the given tfJob
func NewPodList(count int32, status corev1.PodPhase, job metav1.Object, typ string, start int32, refs []metav1.OwnerReference) []*corev1.Pod {
pods := []*corev1.Pod{}
for i := int32(0); i < count; i++ {
newPod := NewPod(tfJob, typ, int(start+i))
newPod.Status = v1.PodStatus{Phase: status}
newPod := NewPod(job, typ, int(start+i), refs)
newPod.Status = corev1.PodStatus{Phase: status}
pods = append(pods, newPod)
}
return pods
}

func SetPodsStatuses(podIndexer cache.Indexer, tfJob *tfv1.TFJob, typ string, pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32, t *testing.T) {
func SetPodsStatusesV2(client client.Client, job metav1.Object, typ string,
pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32,
refs []metav1.OwnerReference, basicLabels map[string]string) {
timeout := 10 * time.Second
interval := 1000 * time.Millisecond
var index int32
for _, pod := range NewPodList(pendingPods, v1.PodPending, tfJob, typ, index) {
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
}
}
index += pendingPods
for i, pod := range NewPodList(activePods, v1.PodRunning, tfJob, typ, index) {
if restartCounts != nil {
pod.Status.ContainerStatuses = []v1.ContainerStatus{{RestartCount: restartCounts[i]}}
}
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
}
taskMap := map[corev1.PodPhase]int32{
corev1.PodFailed: failedPods,
corev1.PodPending: pendingPods,
corev1.PodRunning: activePods,
corev1.PodSucceeded: succeededPods,
}
index += activePods
for _, pod := range NewPodList(succeededPods, v1.PodSucceeded, tfJob, typ, index) {
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
}
}
index += succeededPods
for _, pod := range NewPodList(failedPods, v1.PodFailed, tfJob, typ, index) {
if err := podIndexer.Add(pod); err != nil {
t.Errorf("%s: unexpected error when adding pod %v", tfJob.Name, err)
ctx := context.Background()

for podPhase, desiredCount := range taskMap {
for i, pod := range NewPodList(desiredCount, podPhase, job, typ, index, refs) {
for k, v := range basicLabels {
pod.Labels[k] = v
}
_ = client.Create(ctx, pod)
launcherKey := types.NamespacedName{
Namespace: metav1.NamespaceDefault,
Name: pod.GetName(),
}
Eventually(func() error {
po := &corev1.Pod{}
if err := client.Get(ctx, launcherKey, po); err != nil {
return err
}
po.Status.Phase = podPhase
if podPhase == corev1.PodRunning && restartCounts != nil {
po.Status.ContainerStatuses = []corev1.ContainerStatus{{RestartCount: restartCounts[i]}}
}
return client.Status().Update(ctx, po)
}, timeout, interval).Should(BeNil())
}
index += desiredCount
}
}
67 changes: 45 additions & 22 deletions pkg/common/util/v1/testutil/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,48 +15,71 @@
package testutil

import (
"context"
"fmt"
"testing"

v1 "k8s.io/api/core/v1"
commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
. "github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/tools/cache"
"sigs.k8s.io/controller-runtime/pkg/client"
)

tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
const (
// DummyPortName and DummyPort are the name and number of the single
// placeholder port that NewBaseService sets on every test service.
DummyPortName = "dummy"
DummyPort int32 = 1221
)

func NewBaseService(name string, tfJob *tfv1.TFJob, t *testing.T) *v1.Service {
return &v1.Service{
func NewBaseService(name string, job metav1.Object, refs []metav1.OwnerReference) *corev1.Service {
return &corev1.Service{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Labels: GenLabels(tfJob.Name),
Namespace: tfJob.Namespace,
OwnerReferences: []metav1.OwnerReference{*metav1.NewControllerRef(tfJob, controllerKind)},
Labels: map[string]string{},
Namespace: job.GetNamespace(),
OwnerReferences: refs,
},
Spec: corev1.ServiceSpec{
Ports: []corev1.ServicePort{
{
Name: DummyPortName,
Port: DummyPort,
},
},
},
}
}

func NewService(tfJob *tfv1.TFJob, typ string, index int, t *testing.T) *v1.Service {
service := NewBaseService(fmt.Sprintf("%s-%d", typ, index), tfJob, t)
service.Labels[tfReplicaTypeLabel] = typ
service.Labels[tfReplicaIndexLabel] = fmt.Sprintf("%d", index)
return service
func NewService(job metav1.Object, typ string, index int, refs []metav1.OwnerReference) *corev1.Service {
svc := NewBaseService(fmt.Sprintf("%s-%s-%d", job.GetName(), typ, index), job, refs)
svc.Labels[commonv1.ReplicaTypeLabelDeprecated] = typ
svc.Labels[commonv1.ReplicaTypeLabel] = typ
svc.Labels[commonv1.ReplicaIndexLabelDeprecated] = fmt.Sprintf("%d", index)
svc.Labels[commonv1.ReplicaIndexLabel] = fmt.Sprintf("%d", index)
return svc
}

// NewServiceList creates count pods with the given phase for the given tfJob
func NewServiceList(count int32, tfJob *tfv1.TFJob, typ string, t *testing.T) []*v1.Service {
services := []*v1.Service{}
func NewServiceList(count int32, job metav1.Object, typ string, refs []metav1.OwnerReference) []*corev1.Service {
services := []*corev1.Service{}
for i := int32(0); i < count; i++ {
newService := NewService(tfJob, typ, int(i), t)
newService := NewService(job, typ, int(i), refs)
services = append(services, newService)
}
return services
}

func SetServices(serviceIndexer cache.Indexer, tfJob *tfv1.TFJob, typ string, activeWorkerServices int32, t *testing.T) {
for _, service := range NewServiceList(activeWorkerServices, tfJob, typ, t) {
if err := serviceIndexer.Add(service); err != nil {
t.Errorf("unexpected error when adding service %v", err)
func SetServicesV2(client client.Client, job metav1.Object, typ string, activeWorkerServices int32,
refs []metav1.OwnerReference, basicLabels map[string]string) {
ctx := context.Background()
for _, svc := range NewServiceList(activeWorkerServices, job, typ, refs) {
for k, v := range basicLabels {
svc.Labels[k] = v
}
err := client.Create(ctx, svc)
if errors.IsAlreadyExists(err) {
return
} else {
Expect(err).To(BeNil())
}
}
}
23 changes: 7 additions & 16 deletions pkg/common/util/v1/testutil/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
package testutil

import (
"strings"
"testing"

common "github.com/kubeflow/common/pkg/apis/common/v1"
commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
tfv1 "github.com/kubeflow/training-operator/pkg/apis/tensorflow/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -43,21 +42,13 @@ var (
ControllerName = "training-operator"
)

// GenLabels returns the base label set for objects that belong to the job
// jobName: the group label plus the job name under both the current and the
// deprecated job-name label key. Every "/" in jobName is replaced by "-" in
// the label values.
func GenLabels(jobName string) map[string]string {
	sanitized := strings.ReplaceAll(jobName, "/", "-")

	labels := make(map[string]string, 3)
	labels[LabelGroupName] = GroupName
	labels[JobNameLabel] = sanitized
	labels[DeprecatedLabelTFJobName] = sanitized
	return labels
}

func GenOwnerReference(tfjob *tfv1.TFJob) *metav1.OwnerReference {
func GenOwnerReference(job metav1.Object, apiVersion string, kind string) *metav1.OwnerReference {
boolPtr := func(b bool) *bool { return &b }
controllerRef := &metav1.OwnerReference{
APIVersion: tfv1.GroupVersion.Version,
Kind: TFJobKind,
Name: tfjob.Name,
UID: tfjob.UID,
APIVersion: apiVersion,
Kind: kind,
Name: job.GetName(),
UID: job.GetUID(),
BlockOwnerDeletion: boolPtr(true),
Controller: boolPtr(true),
}
Expand Down Expand Up @@ -85,7 +76,7 @@ func GetKey(tfJob *tfv1.TFJob, t *testing.T) string {
return key
}

func CheckCondition(tfJob *tfv1.TFJob, condition common.JobConditionType, reason string) bool {
func CheckCondition(tfJob *tfv1.TFJob, condition commonv1.JobConditionType, reason string) bool {
for _, v := range tfJob.Status.Conditions {
if v.Type == condition && v.Status == v1.ConditionTrue && v.Reason == reason {
return true
Expand Down
Loading

0 comments on commit f50d1a0

Please sign in to comment.