Commit

blog: tp (#22)
garrett361 authored Jul 8, 2024
1 parent d5f4b27 commit 91f8812
Showing 16 changed files with 1,347 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -16,6 +16,7 @@ This repository contains a variety of Determined examples that are not actively
| [LLM Finetuning 2](blog/llm-finetuning-2) | Finetuning Mistral-7B on Text-to-SQL using LoRA and DeepSpeed. |
| [LLM Finetuning 3](blog/llm-finetuning-3) | Finetuning Gemma-2B using DPO. |
| [Python SDK demo](blog/python_sdk_demo) | Example usage of the Determined Python SDK to run and administer experiments. |
| [Tensor Parallelism](blog/tp) | Profiling tensor parallelism in PyTorch. |

## Computer Vision

5 changes: 5 additions & 0 deletions blog/act-mem-2/README.md
@@ -9,3 +9,8 @@ memory.
- `attn_script.py` shows the cost of activation memory in the attention layer.
- Tests of the code are in `test.py`.
- See `requirements.txt` for versions the code was built against.


## Contributors

- [Garrett Goon](https://github.com/garrett361)
13 changes: 13 additions & 0 deletions blog/tp/README.md
@@ -0,0 +1,13 @@
# Tensor Parallelism

Code accompanying the deep-dive [blog post on Tensor Parallelism](https://determined.ai/blog/tp).

- The MLP and TP MLP layers are in `layers.py`.
- Matmul profiling code is in `matmul_profiling.py`.
- TP MLP profiling code is in `tp_profiling.py`.
- Tests of the rearranged tensor sums are in `test_dot_product_{local,distributed}.py` (see the sketch below).
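These tests check the simple rearrangement of sums that the sharded layers rely on: splitting the contraction dimension of a matmul across workers and summing the partial products reproduces the full result. A minimal local sketch of that identity (not the repository's actual tests) looks like:

```python
import torch

# Chunk the inner (contraction) dimension and sum the partial matmuls;
# this is the identity the sharded linear layers exploit.
torch.manual_seed(0)
world_size, d_model = 4, 64
x = torch.randn(8, d_model)
weight = torch.randn(d_model, d_model)

full = x @ weight
partial_sums = sum(
    x_shard @ w_shard
    for x_shard, w_shard in zip(
        x.chunk(world_size, dim=-1), weight.chunk(world_size, dim=0)
    )
)
torch.testing.assert_close(full, partial_sums)
```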


## Contributors

- [Garrett Goon](https://github.com/garrett361)
138 changes: 138 additions & 0 deletions blog/tp/layers.py
@@ -0,0 +1,138 @@
from typing import Any, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn


class MLP(nn.Module):
"""
Basic MLP (multi-layer perceptron) layer. Dropout is neglected.
"""

def __init__(
self,
d_model: int,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()

self.lin_0 = nn.Linear(d_model, 4 * d_model, device=device, dtype=dtype)
self.act_fn = nn.GELU()
self.lin_1 = nn.Linear(4 * d_model, d_model, device=device, dtype=dtype)

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
x = self.lin_0(inputs)
x = self.act_fn(x)
x = self.lin_1(x)
return x


class AllReduceFwdIdentityBwd(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any, inputs: torch.Tensor, group: Optional[dist.ProcessGroup] = None
) -> torch.Tensor:
inputs = inputs.clone()
dist.all_reduce(inputs, group=group)
return inputs

@staticmethod
def backward(ctx: Any, grad_outputs: torch.Tensor) -> tuple[torch.Tensor, None]:
return grad_outputs, None


class IdentityFwdAllReduceBwd(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any, inputs: torch.Tensor, group: Optional[dist.ProcessGroup] = None
) -> torch.Tensor:
ctx.group = group
return inputs

@staticmethod
def backward(ctx: Any, grad_outputs: torch.Tensor) -> tuple[torch.Tensor, None]:
grad_outputs = grad_outputs.clone()
dist.all_reduce(grad_outputs, group=ctx.group)
return grad_outputs, None


class LinearShardedOutputs(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
group: dist.ProcessGroup,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
sharded_out_features, remainder = divmod(out_features, group.size())
assert not remainder, "out_features must be divisible by the ProcessGroup size"
super().__init__(
in_features=in_features,
out_features=sharded_out_features,
device=device,
dtype=dtype,
)

self.group = group

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
# Wrap the unsharded inputs for backwards-pass correctness.
x = IdentityFwdAllReduceBwd.apply(inputs, self.group)
x = super().forward(x)
return x


class LinearShardedInputs(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
group: dist.ProcessGroup,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
sharded_in_features, remainder = divmod(in_features, group.size())
assert not remainder, "in_features must be divisible by the ProcessGroup size"
super().__init__(
in_features=sharded_in_features,
out_features=out_features,
device=device,
dtype=dtype,
)
self.group = group

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
x = inputs @ self.weight.T
        # Wrap the matmul in an all-reduce for forwards-pass correctness.
x = AllReduceFwdIdentityBwd.apply(x, self.group)
# Crucial: add the bias _after_ the all-reduce.
x = x + self.bias
return x


class MLPTP(MLP):
"""
Basic Tensor Parallel MLP (multi-layer perceptron) layer. Dropout is neglected.
"""

def __init__(
self,
d_model: int,
group: Optional[dist.ProcessGroup] = None,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
nn.Module.__init__(self)
        # Fall back to the WORLD process group if no group is provided.
group = group or dist.group.WORLD

self.lin_0 = LinearShardedOutputs(
d_model, 4 * d_model, group=group, device=device, dtype=dtype
)
self.act_fn = nn.GELU()
self.lin_1 = LinearShardedInputs(
4 * d_model, d_model, group=group, device=device, dtype=dtype
)
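For orientation, here is a minimal usage sketch, not part of this commit, that runs `MLPTP` against the unsharded `MLP` under `torchrun`; the weight-copying and tolerance choices are assumptions rather than the blog's exact test setup.

```python
"""Hypothetical check that MLPTP matches MLP. Launch with, e.g.:

    torchrun --nproc_per_node=2 check_mlp_tp.py
"""
import torch
import torch.distributed as dist

from layers import MLP, MLPTP

if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    d_model, batch_size = 128, 4
    torch.manual_seed(42)  # Identical unsharded init on every rank.
    mlp = MLP(d_model, device=device)
    mlp_tp = MLPTP(d_model, device=device)

    # Copy the matching shard of the unsharded weights into the TP layers.
    with torch.no_grad():
        mlp_tp.lin_0.weight.copy_(mlp.lin_0.weight.chunk(world_size, dim=0)[rank])
        mlp_tp.lin_0.bias.copy_(mlp.lin_0.bias.chunk(world_size, dim=0)[rank])
        mlp_tp.lin_1.weight.copy_(mlp.lin_1.weight.chunk(world_size, dim=1)[rank])
        mlp_tp.lin_1.bias.copy_(mlp.lin_1.bias)

    inputs = torch.randn(batch_size, d_model, device=device)
    torch.testing.assert_close(mlp(inputs), mlp_tp(inputs), atol=1e-4, rtol=1e-4)
    dist.destroy_process_group()
```

The essential point the sketch exercises: `LinearShardedOutputs` shards its output features, the elementwise GELU acts on each shard independently, and `LinearShardedInputs` shards its input features, so a single all-reduce per forward pass (and one per backward pass) recovers the unsharded result.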
Binary file added blog/tp/matmul.png
91 changes: 91 additions & 0 deletions blog/tp/matmul_profiling.py
@@ -0,0 +1,91 @@
import gc
import logging

import determined as det
import torch

import utils

"""
Script for profiling square matmuls on a single GPU.
"""


def profile_and_report(
core_context: det.core.Context,
d_model: int,
num_repeats: int,
num_warmups: int,
dtype: torch.dtype = torch.bfloat16,
) -> None:
A = torch.randn(d_model, d_model, device="cuda", dtype=dtype)
B = torch.randn(d_model, d_model, device="cuda", dtype=dtype)

# Use CUDA events for accurate timing.
timer = utils.CUDAEventTimer()
torch.cuda.synchronize()

# Warmups
for _ in range(num_warmups):
A @ B

# Timed region.
for _ in range(num_repeats):
with timer:
A @ B

    # Mean and std TFLOP/s computations
flops = 2 * d_model**3
time_s_t = torch.tensor(timer.time_s_list)
tflop_s_gpu_t = flops / time_s_t / 1e12
metrics = {
"d_model": d_model,
"time_s": timer.time_s_mean,
"time_s_std": timer.time_s_std,
"tflop_s_gpu": tflop_s_gpu_t.mean().item(),
"tflop_s_gpu_std": tflop_s_gpu_t.std().item(),
}

# Use d_model as the x-axis for plotting purposes.
core_context.train.report_metrics(group="matmul", steps_completed=d_model, metrics=metrics)

# Memory management
del A
del B
gc.collect()
torch.cuda.empty_cache()


def main(
core_context: det.core.Context,
d_model_min: int,
d_model_max: int,
d_model_step: int,
num_repeats: int,
num_warmups: int,
) -> None:
for d_model in range(d_model_min, d_model_max + 1, d_model_step):
profile_and_report(
core_context=core_context,
d_model=d_model,
num_repeats=num_repeats,
num_warmups=num_warmups,
)


if __name__ == "__main__":
info = det.get_cluster_info()
    assert info, "This script must run on a Determined cluster."
hparams = info.trial.hparams

with det.core.init() as core_context:
logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT)

main(
core_context=core_context,
d_model_min=hparams["d_model_min"],
d_model_max=hparams["d_model_max"],
d_model_step=hparams["d_model_step"],
num_repeats=hparams["num_repeats"],
num_warmups=hparams["num_warmups"],
)
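The script imports a `utils` module from the same commit that is not shown in this diff. For context, a hypothetical sketch of a `CUDAEventTimer` with the interface used above (`time_s_list`, `time_s_mean`, `time_s_std`), built on `torch.cuda.Event`, could look like:

```python
import torch


class CUDAEventTimer:
    """Hypothetical sketch of a CUDA-event-based timer context manager.

    Records start/stop events around each `with` block and accumulates the
    elapsed times (in seconds) in `time_s_list`.
    """

    def __init__(self) -> None:
        self.time_s_list: list[float] = []

    def __enter__(self) -> "CUDAEventTimer":
        self._start = torch.cuda.Event(enable_timing=True)
        self._stop = torch.cuda.Event(enable_timing=True)
        self._start.record()
        return self

    def __exit__(self, *args) -> None:
        self._stop.record()
        # Wait for the recorded events before reading the elapsed time.
        torch.cuda.synchronize()
        self.time_s_list.append(self._start.elapsed_time(self._stop) / 1e3)  # ms -> s

    @property
    def time_s_mean(self) -> float:
        return torch.tensor(self.time_s_list).mean().item()

    @property
    def time_s_std(self) -> float:
        return torch.tensor(self.time_s_list).std().item()
```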
20 changes: 20 additions & 0 deletions blog/tp/matmul_profiling.yaml
@@ -0,0 +1,20 @@
name: Matmul Profiling
# Adjust the workspace and project names, as appropriate.
workspace: TP Blog Post
project: Matmul Profiling
resources:
slots_per_trial: 1
searcher:
name: single
metric: not_used
max_length: 1
hyperparameters:
d_model_min: 256
d_model_max: 16384
d_model_step: 256
num_warmups: 5
num_repeats: 100
entrypoint: >-
python3 -m determined.launch.torch_distributed
python3 matmul_profiling.py
max_restarts: 0
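Assuming the standard Determined CLI, a config like this one is typically submitted from the `blog/tp` directory with something like `det experiment create matmul_profiling.yaml .`; as the comment in the config notes, adjust the workspace and project names to ones that exist on your cluster.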
Binary file added blog/tp/mlp_tp.png