TPU Multi-Host Support #1913

Merged: 16 commits, Feb 22, 2024
8 changes: 5 additions & 3 deletions ray-operator/controllers/ray/raycluster_controller.go
@@ -706,14 +706,13 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv
return err
}

// Delete unhealthy worker Pods
// Delete unhealthy worker Pods.
deletedWorkers := make(map[string]struct{})
deleted := struct{}{}
numDeletedUnhealthyWorkerPods := 0
for _, workerPod := range workerPods.Items {
shouldDelete, reason := shouldDeletePod(workerPod, rayv1.WorkerNode)
r.Log.Info("reconcilePods", "worker Pod", workerPod.Name, "shouldDelete", shouldDelete, "reason", reason)
// TODO (kevin85421): We may need to allow users to configure how many `Failed` or `Succeeded` Pods should be kept for debugging purposes.
if shouldDelete {
numDeletedUnhealthyWorkerPods++
deletedWorkers[workerPod.Name] = deleted
@@ -758,7 +757,10 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv
runningPods.Items = append(runningPods.Items, pod)
}
}
diff := workerReplicas - int32(len(runningPods.Items))
// A replica can contain multiple hosts, so we need to calculate this based on the number of hosts per replica.
numExpectedPods := workerReplicas * worker.NumOfHosts
diff := numExpectedPods - int32(len(runningPods.Items))

r.Log.Info("reconcilePods", "workerReplicas", workerReplicas, "runningPods", len(runningPods.Items), "diff", diff)

if diff > 0 {
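The key change in the hunks above is that the expected worker Pod count is now the group's Replicas multiplied by NumOfHosts before the diff against running Pods is taken. A minimal, standalone Go sketch of that arithmetic (hypothetical local types, not the actual rayv1 structs):

package main

import "fmt"

// workerGroup mirrors only the two fields the calculation needs
// (hypothetical type, not the real KubeRay worker group spec).
type workerGroup struct {
	Replicas   int32
	NumOfHosts int32
}

// expectedWorkerPods returns how many worker Pods the group should run:
// each replica is a multi-host unit, so one Pod per host per replica.
func expectedWorkerPods(g workerGroup) int32 {
	return g.Replicas * g.NumOfHosts
}

func main() {
	g := workerGroup{Replicas: 3, NumOfHosts: 4}
	running := int32(10)
	diff := expectedWorkerPods(g) - running
	// Prints: expected=12 running=10 diff=2
	fmt.Printf("expected=%d running=%d diff=%d\n", expectedWorkerPods(g), running, diff)
}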
168 changes: 168 additions & 0 deletions ray-operator/controllers/ray/raycluster_controller_fake_test.go
@@ -59,6 +59,7 @@ var (
headGroupNameStr string
groupNameStr string
expectReplicaNum int32
expectNumOfHostNum int32
testPods []runtime.Object
testPodsNoHeadIP []runtime.Object
testRayCluster *rayv1.RayCluster
@@ -77,6 +78,7 @@ func setupTest(t *testing.T) {
headGroupNameStr = "head-group"
groupNameStr = "small-group"
expectReplicaNum = 3
expectNumOfHostNum = 1
workersToDelete = []string{"pod1", "pod2"}
headNodeIP = "1.2.3.4"
testPods = []runtime.Object{
@@ -324,6 +326,7 @@ func setupTest(t *testing.T) {
Replicas: pointer.Int32(expectReplicaNum),
MinReplicas: pointer.Int32(0),
MaxReplicas: pointer.Int32(10000),
NumOfHosts: expectNumOfHostNum,
GroupName: groupNameStr,
RayStartParams: map[string]string{
"port": "6379",
@@ -2340,6 +2343,171 @@ func TestReconcile_Replicas_Optional(t *testing.T) {
}
}

func TestReconcile_Multihost_Replicas(t *testing.T) {
setupTest(t)

// This test makes some assumptions about the testRayCluster object.
// (1) 1 workerGroup (2) disable autoscaling
assert.Equal(t, 1, len(testRayCluster.Spec.WorkerGroupSpecs), "This test assumes only one worker group.")

// Disable autoscaling so that the random Pod deletion is enabled.
Reviewer comment (Member):
Random pod deletion is a pretty bad behavior for a multi-host setup. I am considering disabling it.

// Set `NumOfHosts` to 4 to specify a multi-host group.
testRayCluster.Spec.EnableInTreeAutoscaling = pointer.Bool(false)
testRayCluster.Spec.WorkerGroupSpecs[0].ScaleStrategy.WorkersToDelete = []string{}
testRayCluster.Spec.WorkerGroupSpecs[0].NumOfHosts = 4

tests := map[string]struct {
replicas *int32
minReplicas *int32
maxReplicas *int32
desiredReplicas int
numOfHosts int
}{
"Replicas is nil": {
// If `Replicas` is nil, the controller will set the desired state of the workerGroup to `MinReplicas`*`NumOfHosts` Pods.
replicas: nil,
minReplicas: pointer.Int32(1),
maxReplicas: pointer.Int32(10000),
desiredReplicas: 1,
numOfHosts: 4,
},
"Replicas is smaller than MinReplicas": {
// If `Replicas` is smaller than `MinReplicas`, the controller will set the desired state of the workerGroup to `MinReplicas`*`NumOfHosts` Pods.
replicas: pointer.Int32(0),
minReplicas: pointer.Int32(1),
maxReplicas: pointer.Int32(10000),
desiredReplicas: 1,
numOfHosts: 4,
},
"Replicas is larger than MaxReplicas": {
// If `Replicas` is larger than `MaxReplicas`, the controller will set the desired state of the workerGroup to `MaxReplicas`*`NumOfHosts` Pods.
replicas: pointer.Int32(4),
minReplicas: pointer.Int32(1),
maxReplicas: pointer.Int32(3),
desiredReplicas: 3,
numOfHosts: 4,
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
cluster := testRayCluster.DeepCopy()
cluster.Spec.WorkerGroupSpecs[0].Replicas = tc.replicas
cluster.Spec.WorkerGroupSpecs[0].MinReplicas = tc.minReplicas
cluster.Spec.WorkerGroupSpecs[0].MaxReplicas = tc.maxReplicas

// This test makes some assumptions about the testPods object.
// `testPods` contains 6 pods, including 1 head pod and 5 worker pods.
assert.Equal(t, 6, len(testPods), "This test assumes the testPods object contains 6 pods.")
numHeadPods := 1
oldNumWorkerPods := len(testPods) - numHeadPods

// Initialize a fake client with newScheme and runtimeObjects.
fakeClient := clientFake.NewClientBuilder().WithRuntimeObjects(testPods...).Build()
ctx := context.Background()

// Get the pod list from the fake client.
podList := corev1.PodList{}
err := fakeClient.List(ctx, &podList, client.InNamespace(namespaceStr))
assert.Nil(t, err, "Fail to get pod list")
assert.Equal(t, oldNumWorkerPods+numHeadPods, len(podList.Items), "Init pod list len is wrong")

// Initialize a new RayClusterReconciler.
testRayClusterReconciler := &RayClusterReconciler{
Client: fakeClient,
Recorder: &record.FakeRecorder{},
Scheme: scheme.Scheme,
Log: ctrl.Log.WithName("controllers").WithName("RayCluster"),
}

// Reconcile the Pods. The controller should converge the worker Pods
// to desiredReplicas * numOfHosts for each test case.
err = testRayClusterReconciler.reconcilePods(ctx, cluster)
assert.Nil(t, err, "Fail to reconcile Pods")

err = fakeClient.List(ctx, &podList, &client.ListOptions{
LabelSelector: workerSelector,
Namespace: namespaceStr,
})
assert.Nil(t, err, "Fail to get pod list after reconcile")
assert.Equal(t, tc.desiredReplicas*tc.numOfHosts, len(podList.Items),
"Pod list is wrong after reconcile expect %d actual %d", tc.desiredReplicas*tc.numOfHosts, len(podList.Items))
})
}
}
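For readers who want the rule from the three cases above in one place: a rough, hypothetical helper that mirrors what the test comments describe (a nil Replicas falls back to MinReplicas, values are clamped into [MinReplicas, MaxReplicas], and the result is multiplied by NumOfHosts). This is a sketch of the expected arithmetic, not the controller's actual implementation:

package main

import "fmt"

// desiredWorkerPods clamps replicas into [minReplicas, maxReplicas], treating a
// nil replicas as minReplicas, then multiplies by numOfHosts.
// Hypothetical helper; the controller's real logic lives in reconcilePods and its helpers.
func desiredWorkerPods(replicas *int32, minReplicas, maxReplicas, numOfHosts int32) int32 {
	r := minReplicas
	if replicas != nil {
		r = *replicas
	}
	if r < minReplicas {
		r = minReplicas
	}
	if r > maxReplicas {
		r = maxReplicas
	}
	return r * numOfHosts
}

func main() {
	four := int32(4)
	// Mirrors the "Replicas is larger than MaxReplicas" case: clamp 4 to 3, then 3 * 4 hosts = 12 Pods.
	fmt.Println(desiredWorkerPods(&four, 1, 3, 4)) // 12
	// Mirrors the "Replicas is nil" case: fall back to MinReplicas 1, then 1 * 4 hosts = 4 Pods.
	fmt.Println(desiredWorkerPods(nil, 1, 10000, 4)) // 4
}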

func TestReconcile_NumOfHosts(t *testing.T) {
setupTest(t)

// This test makes some assumptions about the testRayCluster object.
// (1) 1 workerGroup (2) disable autoscaling
assert.Equal(t, 1, len(testRayCluster.Spec.WorkerGroupSpecs), "This test assumes only one worker group.")

// Disable autoscaling so that the random Pod deletion is enabled.
// Set `Replicas` to 1 and clear `WorkersToDelete`
testRayCluster.Spec.EnableInTreeAutoscaling = pointer.Bool(false)
testRayCluster.Spec.WorkerGroupSpecs[0].ScaleStrategy.WorkersToDelete = []string{}
testRayCluster.Spec.WorkerGroupSpecs[0].Replicas = pointer.Int32(1)

tests := map[string]struct {
replicas *int32
numOfHosts int32
}{
"NumOfHosts is 1": {
// If `NumOfHosts` is 1, the controller will set the desired state of the workerGroup to `Replicas` Pods.
replicas: pointer.Int32(1),
numOfHosts: 1,
},
"NumOfHosts is larger than 1": {
// If `NumOfHosts` is larger than 1, the controller will set the desired state of the workerGroup to `NumOfHosts` Pods.
replicas: pointer.Int32(1),
numOfHosts: 4,
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
cluster := testRayCluster.DeepCopy()
cluster.Spec.WorkerGroupSpecs[0].NumOfHosts = tc.numOfHosts

// Initialize a fake client with newScheme and runtimeObjects.
// The fake client will start with 1 head pod and 0 worker pods.
fakeClient := clientFake.NewClientBuilder().WithRuntimeObjects(testPods[0]).Build()
ctx := context.Background()

// Get the pod list from the fake client.
podList := corev1.PodList{}
err := fakeClient.List(ctx, &podList, client.InNamespace(namespaceStr))
assert.Nil(t, err, "Fail to get pod list")
assert.Equal(t, 1, len(podList.Items), "Init pod list len is wrong")

// Initialize a new RayClusterReconciler.
testRayClusterReconciler := &RayClusterReconciler{
Client: fakeClient,
Recorder: &record.FakeRecorder{},
Scheme: scheme.Scheme,
Log: ctrl.Log.WithName("controllers").WithName("RayCluster"),
}

err = testRayClusterReconciler.reconcilePods(ctx, cluster)
assert.Nil(t, err, "Fail to reconcile Pods")

err = fakeClient.List(ctx, &podList, &client.ListOptions{
LabelSelector: workerSelector,
Namespace: namespaceStr,
})
assert.Nil(t, err, "Fail to get pod list after reconcile")
if tc.numOfHosts > 1 {
assert.Equal(t, int(tc.numOfHosts), len(podList.Items),
"Number of worker pods is wrong after reconcile expect %d actual %d", int(tc.numOfHosts), len(podList.Items)-1)
} else {
assert.Equal(t, int(*tc.replicas), len(podList.Items),
"Replica number is wrong after reconcile expect %d actual %d", int(*tc.replicas), len(podList.Items))
}
})
}
}

func TestSumGPUs(t *testing.T) {
nvidiaGPUResourceName := corev1.ResourceName("nvidia.com/gpu")
googleTPUResourceName := corev1.ResourceName("google.com/tpu")
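Finally, a sketch of how a multi-host worker group could be declared in user code, mirroring the fields the test fixture sets (GroupName, Replicas, MinReplicas, MaxReplicas, NumOfHosts, RayStartParams). The import paths and the rayv1.WorkerGroupSpec type name are assumptions based on the file paths and identifiers in this diff, and the field set is not exhaustive:

package main

import (
	"fmt"

	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" // assumed import path
	"k8s.io/utils/pointer"
)

func main() {
	// Sketch only: a worker group whose replicas each span 4 hosts (4 Pods per replica).
	workerGroup := rayv1.WorkerGroupSpec{
		GroupName:   "small-group",
		Replicas:    pointer.Int32(3),
		MinReplicas: pointer.Int32(0),
		MaxReplicas: pointer.Int32(10000),
		NumOfHosts:  4,
		RayStartParams: map[string]string{
			"port": "6379",
		},
	}
	// With the change in this PR, the controller targets Replicas * NumOfHosts worker Pods.
	fmt.Println(*workerGroup.Replicas * workerGroup.NumOfHosts) // 12
}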