More robust and dynamic multi-slicing TPU training #690

Open
Ivan-Zhou opened this issue Aug 13, 2024 · 4 comments

Comments

@Ivan-Zhou (Contributor)

The current approach to multi-slice training is neither robust nor reliable on spot instances.

There have been some discussions inside CRFM and with the GCP team on this topic. I am creating this issue to capture the main ideas and threads.

Main Objectives

  • Implement a more robust system that provisions multiple TPU slices and allocates workloads to whichever slices are available, by constantly monitoring the status of each slice.
  • Take a step further and use Ray to coordinate multiple training jobs on a single slice, as well as a single training run across multiple slices.

Challenges

  • For spot instances in TRC, we cannot use GKE directly, because there is no way to disable TPU billing via GKE. As a result, many of the tricks and techniques built on top of GKE, including Pathways, are not available to us.

Possible Ideas

Use Ray for scheduling

Allen Wang from Google proposed using Ray to schedule and run workloads across slices. He put together a quick gist on how to run both single- and multi-slice workloads via Ray (>= 2.10.0). This covers the job-scheduling aspect and works regardless of whether the cluster is provisioned directly on VMs or on GKE.

To mitigate potential race conditions, Allen also added placement groups to pre-reserve existing TPU pod slices (ray_tpu.py), along with an example of how they can be used to run tasks (ray_tpu_task.py).

David's summarization:

  • We would spin up a Ray head node as a job scheduler.
  • TPU slices register with the head node (each worker runs ray start).
  • We launch a multislice TPU job by getting a set of named worker 0s, setting the environment variables, then launching the real job (see the sketch below).
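A minimal sketch of what that flow could look like, assuming a Ray (>= 2.10) cluster where each TPU VM host has joined via ray start and advertises a "TPU" resource. The worker-0 addresses, the MEGASCALE_* variable names, and run_training are illustrative assumptions, not Allen's actual gist:

```python
# Hypothetical sketch: launch one "worker 0" task per slice and wire up the
# multislice environment variables before starting the real JAX job.
import os
import ray

ray.init(address="auto")  # connect to the Ray head node acting as the scheduler

@ray.remote(resources={"TPU": 4})  # assumes each TPU VM host advertises a "TPU" resource
def run_training(slice_id: int, num_slices: int, coordinator_addr: str) -> int:
    # Variable names assumed from the JAX/MaxText multislice docs; verify for your setup.
    os.environ["MEGASCALE_COORDINATOR_ADDRESS"] = coordinator_addr
    os.environ["MEGASCALE_NUM_SLICES"] = str(num_slices)
    os.environ["MEGASCALE_SLICE_ID"] = str(slice_id)
    import jax  # import after the env vars are set so initialization picks them up
    jax.distributed.initialize()
    return jax.device_count()  # placeholder for "launch the real job"

# Hypothetical worker-0 addresses, one per slice; slice 0 acts as coordinator.
slice_worker0_addrs = ["10.0.0.2", "10.0.0.3"]
futures = [run_training.remote(i, len(slice_worker0_addrs), slice_worker0_addrs[0])
           for i in range(len(slice_worker0_addrs))]
print(ray.get(futures))
```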

Use a host to coordinate work and communicate gradients

@dlwh 's idea:

  • Spin up individual slices as more or less atomic units.
  • Each slice runs its own SPMD job as a Levanter instance.
  • However, each slice also phones home to a coordinator machine.
  • The coordinator assigns work units (batches) to slices and tells them how to communicate their gradients. JAX has experimental jax.lax.infeed and jax.lax.outfeed primitives for receiving values from and sending values to the host. The host receives the appropriate gradients from the device (using outfeed), communicates them to the other hosts (using a tree or something fancy), then sends the accumulated gradients back via infeed (see the sketch below).
    The trick will be scheduling this to maximize throughput, since I don't know how to tell XLA how long something will take.
    To make things reproducible, you'll have to be very careful, ensuring that you reduce batches in the same order even in the presence of slices dropping out. To do this, you will likely have to either use a lot of host memory and/or accept recomputing batches.
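For illustration, a minimal device-side sketch of that outfeed/infeed hand-off, assuming a single flattened gradient array. These are experimental JAX APIs, the host-side plumbing that actually reads the outfeed, reduces across slices, and writes the infeed is omitted, and the shape and function names are assumptions rather than Levanter code:

```python
# Device-side half of the exchange: push the local gradient to the host and
# block until the host feeds back the cross-slice accumulated gradient.
import jax
import jax.numpy as jnp
from jax import core, lax

GRAD_AVAL = core.ShapedArray((1024,), jnp.float32)  # hypothetical flat gradient shape

@jax.jit
def exchange_gradient(local_grad):
    token = lax.create_token()
    token = lax.outfeed(token, local_grad)                     # device -> host: local gradient
    (summed,), token = lax.infeed(token, shape=(GRAD_AVAL,))   # host -> device: reduced gradient
    return summed

# On the host side, each slice's coordinator process would read the outfeed,
# all-reduce with the other slices (in a fixed batch order for reproducibility),
# and push the result back through the infeed.
```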
@dlwh (Member) commented Aug 13, 2024

These two ideas can be fused pretty well I think, fwiw.

@Ivan-Zhou (Contributor, Author)

I could reproduce Allen's script on a single-slice v4 TPU, but not on multiple slices. That should not be a blocker for now, since we are not prioritizing multi-slice training.

Now that I think about it more, I realize this is not the shortest path. I should instead take Marin's existing Ray + TPU framework for launching data-preparation jobs as a reference; it seems to be a more applicable guide.

I will try to follow https://github.com/stanford-crfm/marin/tree/main/infra#maintaining-a-ray-cluster and build a PoC for training.

@dlwh (Member) commented Aug 21, 2024

sounds like a good plan! The main difference is that Allen's script handles multi-node TPU slices, which the marin cluster doesn't bother with.
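To make that difference concrete, here is a hedged sketch of handling a multi-host slice with Ray: reserve one bundle per TPU VM host with a placement group (in the spirit of Allen's ray_tpu.py), then run one task per host so the whole slice joins a single JAX job. The host count and the {"TPU": 4} per-host resource are assumptions about the cluster setup:

```python
# Hypothetical sketch: reserve every host of one multi-host slice, then run one
# per-host task so all hosts participate in the same SPMD job.
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init(address="auto")

HOSTS_PER_SLICE = 4  # hypothetical: e.g. a v4-32 slice spans 4 TPU VM hosts

# One bundle per host, spread across distinct nodes, to pre-reserve the slice.
pg = placement_group([{"TPU": 4}] * HOSTS_PER_SLICE, strategy="STRICT_SPREAD")
ray.get(pg.ready())

@ray.remote(resources={"TPU": 4})
def per_host_job(host_index: int):
    import jax
    jax.distributed.initialize()  # every host in the slice joins the same JAX job
    return host_index, jax.process_count()

futures = [
    per_host_job.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg)
    ).remote(i)
    for i in range(HOSTS_PER_SLICE)
]
print(ray.get(futures))
```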

@Ivan-Zhou (Contributor, Author)

I will run experiments with larger TPU nodes then.
