So far we've just been using data parallel techniques. You may have heard of other parallelism techniques, and indeed the Llama 405B paper actually uses 4D parallelism when training the 405B model:
- Data parallel (FSDP as we've learned)
- Tensor parallel (this chapter)
- Context parallel (for long context lengths)
- Pipeline/model parallel
In this chapter we are going to dive into what tensor parallelism is, before we think about combining it with other types.
TP splits the model weights AND computation across multiple GPUs.
FSDP splits the model weights, but it gathers them back for the computation. Splitting the computation across GPUs is the difference.
As a result, the world size that data parallelism sees is scaled down by your tensor parallel size, so the cost of the data parallel allgathers/allreduces is reduced. This becomes a big factor when your cluster is large, and TP is a very effective way to scale up!
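To make that concrete, here's a tiny sketch using hypothetical cluster numbers:

```python
# Hypothetical cluster: 256 GPUs total with a tensor parallel size of 8.
world_size = 256
tp_size = 8

dp_size = world_size // tp_size  # 32
# The data parallel collectives (e.g. FSDP all-gather/reduce-scatter) now only
# span 32 ranks instead of 256, while the TP collectives stay inside each
# group of 8 GPUs (typically a single node with fast interconnect).
print(f"dp group size: {dp_size}, tp group size: {tp_size}")
```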
Here are the benefits of this:
- The peak GPU memory is reduced: instead of each GPU loading the full weights for each layer, each GPU now only loads `1/num_gpus` of the weights (see the sketch after this list).
- We now have `per GPU memory * num_gpus` as our amount of memory to use for each layer.
- Less allgather/allreduce cost
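Here is a minimal sketch of the memory math for a single weight matrix, using hypothetical sizes:

```python
# One square linear layer weight with a hypothetical hidden_dim of 4096 in bf16.
hidden_dim = 4096
tp_size = 8
bytes_per_param = 2  # bf16

full_bytes = hidden_dim * hidden_dim * bytes_per_param  # what every GPU holds without TP
sharded_bytes = full_bytes // tp_size                   # what each GPU holds with TP=8
print(f"{full_bytes / 2**20:.0f} MiB -> {sharded_bytes / 2**20:.0f} MiB per GPU")
# 32 MiB -> 4 MiB per GPU for this one layer
```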
Here are the downsides:
- Global batch size is reduced
- Increased code complexity
Note that this can only really be applied to certain modules, but most of the modules in an LLM work with it.
Since we are splitting computation across GPUs, all GPUs in the same group need to receive the same input. (That is why the global batch size is reduced).
First we are going to create our device mesh. A device mesh is just a way to view your devices in an N-dimensional way. So if you have 8 GPUs, you could organize them into a device mesh like `(2, 2, 2)`, `(2, 4)`, `(4, 2)`, or even `(1, 8)`.
The reason this is helpful is because we are going to name these dimensions, much like we do with tensor dimensions. Similar to how we have a batch and sequence dimension, for our device mesh we are going to have a data parallel and tensor parallel dimension.
import torch
import torch.distributed as dist

# world_size is the total number of GPUs across all nodes
# (e.g. dist.get_world_size()).
gpus_on_node = torch.cuda.device_count()
num_nodes = world_size // gpus_on_node
mesh = dist.device_mesh.init_device_mesh(
    "cuda",
    (num_nodes, gpus_on_node),
    mesh_dim_names=("dp", "tp"),
)
So if we have 4 GPUs total, and a `(2, 2)` device mesh, here are the assignments (there's a quick way to verify this yourself after the table):
| | DP rank | TP rank |
|---|---|---|
| GPU 0 | 0 | 0 |
| GPU 1 | 0 | 1 |
| GPU 2 | 1 | 0 |
| GPU 3 | 1 | 1 |
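If you want to sanity check this on your own setup, you can print the ranks each GPU sees (a small sketch, assuming the mesh and process group above are already initialized):

```python
# Each process prints its global rank plus its coordinates in the device mesh.
dp_rank = mesh["dp"].get_local_rank()
tp_rank = mesh["tp"].get_local_rank()
print(f"global rank {dist.get_rank()}: dp rank {dp_rank}, tp rank {tp_rank}")
```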
These assignments don't actually mean anything unless we update the rest of our code to use the device mesh, so let's see how we do that!
A lot of the PyTorch distributed APIs actually take an optional `mesh: Optional[DeviceMesh] = None` argument; we just haven't used it so far.
The first place is actually our data sampler, and this is how all the GPUs in the same tensor parallel group end up receiving the same input:
sampler=DistributedSampler(
    ...,
    num_replicas=mesh["dp"].size(),
    # NOTE: every GPU on a node will have the same "dp" rank,
    # meaning they will all receive the same input!
    rank=mesh["dp"].get_local_rank(),
)
Continuing the 4 GPU example above, here are the arguments each GPU would pass to DistributedSampler:
| | num_replicas | rank |
|---|---|---|
| GPU 0 | 2 | 0 |
| GPU 1 | 2 | 0 |
| GPU 2 | 2 | 1 |
| GPU 3 | 2 | 1 |
This is because our DP dimension has size 2, and the DP rank column in our first table above is exactly the local rank that we pass to DistributedSampler.
Here's the code first, and then there are graphics after it that explain how this works. Note that we are passing `mesh["tp"]` to the API, which means this is happening across our tensor parallel group!
# assumes: import torch.distributed.tensor.parallel as tp
for layer in model.model.layers:
    tp.parallelize_module(
        layer,
        mesh["tp"],
        {
            "self_attn.q_proj": tp.ColwiseParallel(),
            "self_attn.k_proj": tp.ColwiseParallel(),
            "self_attn.v_proj": tp.ColwiseParallel(),
            "self_attn.o_proj": tp.RowwiseParallel(),
            "mlp.gate_proj": tp.ColwiseParallel(),
            "mlp.up_proj": tp.ColwiseParallel(),
            "mlp.down_proj": tp.RowwiseParallel(),
        },
    )
Our first three linear layers in self attention (the q/k/v projections) are all colwise linear. This means we are sharding the weight matrix along dimension 0, since `nn.Linear` stores its weight in a transposed `(out_features, in_features)` format (see the small sketch after the figure below). The remainder of the attention block (including the attention computation itself) runs on this sharded output, so attention actually operates on smaller tensors.
Image Source: PyTorchLightning
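To make the transposed-storage point concrete, here is a small single-process sketch with hypothetical sizes (this is plain `nn.Linear`, before any parallelization):

```python
import torch.nn as nn

# nn.Linear stores its weight as (out_features, in_features), i.e. "transposed"
# relative to the math y = x @ W. ColwiseParallel shards dimension 0 of this
# stored weight, which is the output feature dimension.
q_proj = nn.Linear(4096, 4096, bias=False)
print(q_proj.weight.shape)  # torch.Size([4096, 4096])

# With a TP size of 8, each rank would hold a (512, 4096) shard of this weight
# and produce a (batch, seq, 512) slice of the q projection output.
```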
The final layer in our self attention block is another linear layer (o_proj). Note that we are doing rowwise parallel here. This actually lets us "recombine" across our tp dimension, as shown here:
Image Source: PyTorchLightning
So the final output of self attention will be replicated again.
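Here is a toy, single-process sketch of why the colwise-then-rowwise pattern recombines correctly: each rank computes a partial result, and the all-reduce at the end of `RowwiseParallel` sums those partials into the full output (hypothetical sizes, no actual distributed calls):

```python
import torch

tp_size = 2
X = torch.randn(4, 8)    # replicated input, e.g. (batch*seq, hidden)
A = torch.randn(8, 6)    # colwise-sharded weight: rank i holds columns i*3:(i+1)*3
B = torch.randn(6, 8)    # rowwise-sharded weight: rank i holds rows i*3:(i+1)*3

partials = []
for rank in range(tp_size):
    A_shard = A[:, rank * 3:(rank + 1) * 3]
    B_shard = B[rank * 3:(rank + 1) * 3, :]
    partials.append(X @ A_shard @ B_shard)  # each rank's local computation

# Summing the partials (what the all-reduce does) matches the unsharded result.
assert torch.allclose(sum(partials), X @ A @ B, atol=1e-4)
```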
The embedding weight gets sharded along dimension 1, meaning each GPU holds a different slice of each token's embedding vector:
| Embedding Weight Shape | Sharded Shape |
|---|---|
| (vocab_size, hidden_dim) | (vocab_size, hidden_dim / mesh["tp"].size()) |
In a normal embedding layer it:
- Takes input tokens of `shape=(batch, seq)`
- Outputs embeddings of `shape=(batch, seq, hidden_dim)`

Now that we've sharded the embedding weight tensor, the layer will actually output:
- Sharded output embeddings of `shape=(batch, seq, hidden_dim / mesh["tp"].size())`
We have a problem though: the colwise pieces of the `self_attn` module will receive the output of the embedding layer, and `ColwiseParallel` actually expects its input to be replicated, not sharded. So we need to do an allgather on the tensor to replicate it across the group (i.e. it will be back to `shape=(batch, seq, hidden_dim)`). Luckily we can just specify this additional transformation with the `output_layouts` argument:
# assumes: from torch.distributed._tensor import Replicate
tp.parallelize_module(
    model,
    mesh["tp"],
    {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Replicate())},
)
tp.parallelize_module(
    model,
    mesh["tp"],
    {
        "lm_head": tp.ColwiseParallel(
            output_layouts=Replicate()
        ),
    },
)
We have to include `Replicate()` here because our loss expects replicated tensors, but colwise by default shards on the last dimension.
For normalization layers, it works a bit differently. We don't actually shard the layer's weights at all; instead, we shard the layer's input along the sequence dimension!
So our computation is split, and we need to do some work to join the results back together for the other modules:
 for layer in model.model.layers:
     tp.parallelize_module(
         layer,
         mesh["tp"],
         {
+            "input_layernorm": tp.SequenceParallel(),
+            "self_attn": tp.PrepareModuleInput(
+                input_kwarg_layouts={"hidden_states": Shard(dim=1)},
+                desired_input_kwarg_layouts={"hidden_states": Replicate()},
+            ),
             "self_attn.q_proj": tp.ColwiseParallel(),
             "self_attn.k_proj": tp.ColwiseParallel(),
             "self_attn.v_proj": tp.ColwiseParallel(),
-            "self_attn.o_proj": tp.RowwiseParallel(),
+            "self_attn.o_proj": tp.RowwiseParallel(output_layouts=Shard(1)),
+            "post_attention_layernorm": tp.SequenceParallel(),
+            "mlp": tp.PrepareModuleInput(
+                input_layouts=Shard(dim=1),
+                desired_input_layouts=Replicate(),
+            ),
             "mlp.gate_proj": tp.ColwiseParallel(),
             "mlp.up_proj": tp.ColwiseParallel(),
-            "mlp.down_proj": tp.RowwiseParallel(),
+            "mlp.down_proj": tp.RowwiseParallel(output_layouts=Shard(1)),
         },
     )
The `PrepareModuleInput` objects transform how the tensors are split up. E.g. for `self_attn`, the `hidden_states` input is sharded along dimension 1 because of the `SequenceParallel` layer before it, but all the `ColwiseParallel` layers expect their input to be replicated.
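Conceptually, the `PrepareModuleInput` for `self_attn` is just a redistribution of the incoming `DTensor` from `Shard(1)` to `Replicate()`, i.e. an allgather along the sequence dimension. A rough by-hand equivalent (a sketch, assuming `hidden_states` is already a `DTensor` placed on `mesh["tp"]`):

```python
# What PrepareModuleInput does for us automatically:
hidden_states = hidden_states.redistribute(
    device_mesh=mesh["tp"], placements=[Replicate()]
)
```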
We also need to change our embedding layer: since its output now feeds into a `SequenceParallel` layer, we need to shard it along dimension 1:
 tp.parallelize_module(
     model,
     mesh["tp"],
-    {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Replicate())},
+    {"model.embed_tokens": tp.ColwiseParallel(output_layouts=Shard(1))},
 )
We actually need an additional change because of this, due to `transformers`-specific code: it computes the sequence length based on the output of the embedding layer, which will now be wrong since we are sharding that output along the sequence dimension. Passing `position_ids` explicitly fixes this, but it's very implementation specific:
 with timers["data"], torch.no_grad():
     batch = next(batches)
     batch = {k: v.to(device=device) for k, v in batch.items()}
+    batch["position_ids"] = torch.arange(
+        0, args.seq_length, device=device, dtype=torch.long
+    ).unsqueeze(0)
And here is the diff for our final output from the network:
 tp.parallelize_module(
     model,
     mesh["tp"],
     {
+        "model.norm": tp.SequenceParallel(),
         "lm_head": tp.ColwiseParallel(
+            input_layouts=Shard(1),
             output_layouts=Replicate(),
         ),
     },
 )
There's an additional API for parallelizing the loss computation across the class (vocab) dimension (at the time of writing it only works for cross entropy). We first need to use this context manager around our loss computation:
with tp.loss_parallel(), timers["forward"]:
    outputs = model(**batch)

with tp.loss_parallel(), timers["backward"]:
    outputs.loss.backward()
Then we also need to update the output of our `lm_head`, because `loss_parallel` requires a different sharding layout and a `DTensor` output:
 tp.parallelize_module(
     model,
     mesh["tp"],
     {
         "model.norm": tp.SequenceParallel(),
         "lm_head": tp.ColwiseParallel(
             input_layouts=Shard(1),
-            output_layouts=Replicate(),
+            output_layouts=Shard(-1),
+            use_local_output=False,
         ),
     },
 )
`use_local_output=False` tells PyTorch to return a `DTensor` from the operation, instead of a normal `Tensor`.
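As a sanity check, you can verify inside the `loss_parallel` context that the logits are indeed a `DTensor` (assuming the model output exposes a `.logits` attribute, as transformers causal LM outputs do):

```python
from torch.distributed._tensor import DTensor

with tp.loss_parallel():
    outputs = model(**batch)
    # lm_head was configured with use_local_output=False, so the logits are a
    # DTensor sharded on the class (vocab) dimension.
    assert isinstance(outputs.logits, DTensor)
```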
Because each GPU no longer processes its own unique batch (each data parallel group does), we just need to update our throughput calculation to use our device mesh:
 if state["global_step"] % args.log_freq == 0:
-    tok_per_step = world_size * args.batch_size * args.seq_length
+    tok_per_step = mesh["dp"].size() * args.batch_size * args.seq_length
     ms_per_step = sum(t.avg_elapsed_ms() for t in timers.values())
Here are some results from launching training for Llama 3.1 8B on a single node of 8x H100s:
Command:
HF_HOME=/home/ubuntu/.cache/huggingface OMP_NUM_THREADS=26 torchrun --standalone --nproc-per-node gpu train_llm.py --experiment-name tp-llama-8b --dataset-name tatsu-lab/alpaca --model-name meta-llama/Llama-3.1-8B --log-freq 10 --batch-size 16 --seq-length 1024 --num-epochs 1
For completeness, here are the relevant docs/guides from PyTorch on how to achieve this:
Here is a brief explanation of how the API we are using works:
- `tp.RowwiseParallel()` shards the module's weights row-wise.
  - Inputs are sharded on the last dimension by default
  - Outputs are replicated on all workers by default
- `tp.ColwiseParallel()` shards the module's weights column-wise.
  - Inputs are replicated on all workers by default
  - Outputs are sharded on the last dimension by default
- `tp.SequenceParallel()` shards the input/output along dimension 1. Module weights are NOT sharded.
- `tp.PrepareModuleInput()` lets you change the sharding configuration of input tensors
- `torch.distributed._tensor.Shard(dim=X)` indicates a tensor should be sharded along dimension X
- `torch.distributed._tensor.Replicate()` indicates a tensor should be replicated among all workers.
How all of these things interact is actually very subtle and complex, which is why this guide is useful!
You can also change most of the default behavior with arguments to these classes. For example, you can configure RowwiseParallel to assume its input is replicated instead of sharded, as shown below.
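Here is a hypothetical override (not one we use above) that tells `RowwiseParallel` its input arrives replicated rather than sharded on the last dimension:

```python
# "some_linear" is a hypothetical submodule name; this just illustrates
# overriding the default input layout of RowwiseParallel.
tp.parallelize_module(
    layer,
    mesh["tp"],
    {"some_linear": tp.RowwiseParallel(input_layouts=Replicate())},
)
```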