Commit 91f8812 (1 parent: d5f4b27)
16 changed files with 1,347 additions and 0 deletions.
README.md (new file, @@ -0,0 +1,13 @@)
# Tensor Parallelism

Code accompanying the deep-dive [blog post on Tensor Parallelism](https://determined.ai/blog/tp).

- The MLP and TP MLP layers are in `layer.py`
- Matmul profiling code is in `matmul_profiling.py`
- MLP TP profiling code is in `tp_profiling.py`
- Tests of the tensor-sum rearrangement are in `test_dot_product_{local,distributed}.py`

## Contributors

- [Garrett Goon](https://github.com/garrett361)
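The local dot-product test listed above exercises the core tensor-parallel identity: splitting the first matrix into column blocks and the second into matching row blocks, multiplying block by block, and summing the partial products reproduces the full matmul. The following is a minimal local sketch of that check, illustrative only; the sizes and variable names are arbitrary and not taken from the test file.

import torch

# Shard count standing in for the tensor-parallel degree; sizes are arbitrary.
tp_degree = 4
d_model = 64

# Double precision keeps floating-point error negligible for the comparison.
A = torch.randn(d_model, d_model, dtype=torch.float64)
B = torch.randn(d_model, d_model, dtype=torch.float64)

# Column-shard A, row-shard B, multiply shard-by-shard, and sum the partial products.
A_shards = A.chunk(tp_degree, dim=1)
B_shards = B.chunk(tp_degree, dim=0)
rearranged = sum(a @ b for a, b in zip(A_shards, B_shards))

# The rearranged sum matches the unsharded matmul.
torch.testing.assert_close(rearranged, A @ B)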
layer.py (new file, @@ -0,0 +1,138 @@)
from typing import Any, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn


class MLP(nn.Module):
    """
    Basic MLP (multi-layer perceptron) layer. Dropout is neglected.
    """

    def __init__(
        self,
        d_model: int,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        self.lin_0 = nn.Linear(d_model, 4 * d_model, device=device, dtype=dtype)
        self.act_fn = nn.GELU()
        self.lin_1 = nn.Linear(4 * d_model, d_model, device=device, dtype=dtype)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = self.lin_0(inputs)
        x = self.act_fn(x)
        x = self.lin_1(x)
        return x


class AllReduceFwdIdentityBwd(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any, inputs: torch.Tensor, group: Optional[dist.ProcessGroup] = None
    ) -> torch.Tensor:
        inputs = inputs.clone()
        dist.all_reduce(inputs, group=group)
        return inputs

    @staticmethod
    def backward(ctx: Any, grad_outputs: torch.Tensor) -> tuple[torch.Tensor, None]:
        return grad_outputs, None


class IdentityFwdAllReduceBwd(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any, inputs: torch.Tensor, group: Optional[dist.ProcessGroup] = None
    ) -> torch.Tensor:
        ctx.group = group
        return inputs

    @staticmethod
    def backward(ctx: Any, grad_outputs: torch.Tensor) -> tuple[torch.Tensor, None]:
        grad_outputs = grad_outputs.clone()
        dist.all_reduce(grad_outputs, group=ctx.group)
        return grad_outputs, None


class LinearShardedOutputs(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        group: dist.ProcessGroup,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        sharded_out_features, remainder = divmod(out_features, group.size())
        assert not remainder, "out_features must be divisible by the ProcessGroup size"
        super().__init__(
            in_features=in_features,
            out_features=sharded_out_features,
            device=device,
            dtype=dtype,
        )

        self.group = group

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Wrap the unsharded inputs for backwards-pass correctness.
        x = IdentityFwdAllReduceBwd.apply(inputs, self.group)
        x = super().forward(x)
        return x


class LinearShardedInputs(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        group: dist.ProcessGroup,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        sharded_in_features, remainder = divmod(in_features, group.size())
        assert not remainder, "in_features must be divisible by the ProcessGroup size"
        super().__init__(
            in_features=sharded_in_features,
            out_features=out_features,
            device=device,
            dtype=dtype,
        )
        self.group = group

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = inputs @ self.weight.T
        # Wrap the matmul output in an all-reduce for forward-pass correctness.
        x = AllReduceFwdIdentityBwd.apply(x, self.group)
        # Crucial: add the bias _after_ the all-reduce.
        x = x + self.bias
        return x


class MLPTP(MLP):
    """
    Basic Tensor Parallel MLP (multi-layer perceptron) layer. Dropout is neglected.
    """

    def __init__(
        self,
        d_model: int,
        group: Optional[dist.ProcessGroup] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        nn.Module.__init__(self)
        # Fall back to the WORLD process group if none is provided.
        group = group or dist.group.WORLD

        self.lin_0 = LinearShardedOutputs(
            d_model, 4 * d_model, group=group, device=device, dtype=dtype
        )
        self.act_fn = nn.GELU()
        self.lin_1 = LinearShardedInputs(
            4 * d_model, d_model, group=group, device=device, dtype=dtype
        )
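For context, MLPTP is meant to be driven from an ordinary torch.distributed setup in which every rank holds the same replicated activations; the sharded linears keep the hidden 4 * d_model dimension split across ranks and all-reduce back to replicated outputs. Below is a hypothetical usage sketch: the launch command, dimensions, and tensor shapes are assumptions, not part of this commit.

# Hypothetical driver, e.g. launched with: torchrun --nproc_per_node=2 tp_mlp_demo.py
import os

import torch
import torch.distributed as dist

from layer import MLPTP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

d_model = 1024
mlp_tp = MLPTP(d_model=d_model, device="cuda")

# The TP MLP expects identical inputs on every rank; broadcasting from rank 0 guarantees this.
inputs = torch.randn(8, d_model, device="cuda")
dist.broadcast(inputs, src=0)

outputs = mlp_tp(inputs)  # Replicated (8, d_model) outputs on every rank.
outputs.sum().backward()

dist.destroy_process_group()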
matmul_profiling.py (new file, @@ -0,0 +1,91 @@)
import gc
import logging

import determined as det
import torch

import utils

"""
Script for profiling square matmuls on a single GPU.
"""


def profile_and_report(
    core_context: det.core.Context,
    d_model: int,
    num_repeats: int,
    num_warmups: int,
    dtype: torch.dtype = torch.bfloat16,
) -> None:
    A = torch.randn(d_model, d_model, device="cuda", dtype=dtype)
    B = torch.randn(d_model, d_model, device="cuda", dtype=dtype)

    # Use CUDA events for accurate timing.
    timer = utils.CUDAEventTimer()
    torch.cuda.synchronize()

    # Warmups
    for _ in range(num_warmups):
        A @ B

    # Timed region.
    for _ in range(num_repeats):
        with timer:
            A @ B

    # Mean and std TFLOP computations
    flops = 2 * d_model**3
    time_s_t = torch.tensor(timer.time_s_list)
    tflop_s_gpu_t = flops / time_s_t / 1e12
    metrics = {
        "d_model": d_model,
        "time_s": timer.time_s_mean,
        "time_s_std": timer.time_s_std,
        "tflop_s_gpu": tflop_s_gpu_t.mean().item(),
        "tflop_s_gpu_std": tflop_s_gpu_t.std().item(),
    }

    # Use d_model as the x-axis for plotting purposes.
    core_context.train.report_metrics(group="matmul", steps_completed=d_model, metrics=metrics)

    # Memory management
    del A
    del B
    gc.collect()
    torch.cuda.empty_cache()


def main(
    core_context: det.core.Context,
    d_model_min: int,
    d_model_max: int,
    d_model_step: int,
    num_repeats: int,
    num_warmups: int,
) -> None:
    for d_model in range(d_model_min, d_model_max + 1, d_model_step):
        profile_and_report(
            core_context=core_context,
            d_model=d_model,
            num_repeats=num_repeats,
            num_warmups=num_warmups,
        )


if __name__ == "__main__":
    info = det.get_cluster_info()
    assert info, "This script must run on a determined cluster."
    hparams = info.trial.hparams

    with det.core.init() as core_context:
        logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT)

        main(
            core_context=core_context,
            d_model_min=hparams["d_model_min"],
            d_model_max=hparams["d_model_max"],
            d_model_step=hparams["d_model_step"],
            num_repeats=hparams["num_repeats"],
            num_warmups=hparams["num_warmups"],
        )
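The utils module is not among the files shown in this view, so the exact CUDAEventTimer is unknown. A plausible sketch consistent with how it is used above (a context manager exposing time_s_list, time_s_mean, and time_s_std) might look like this; it is an assumption, not the repository's implementation.

import torch


class CUDAEventTimer:
    """Hypothetical context manager that times GPU work with CUDA events."""

    def __init__(self) -> None:
        self.time_s_list: list[float] = []

    def __enter__(self) -> "CUDAEventTimer":
        self._start = torch.cuda.Event(enable_timing=True)
        self._end = torch.cuda.Event(enable_timing=True)
        self._start.record()
        return self

    def __exit__(self, *exc) -> None:
        self._end.record()
        # Wait for the recorded events before reading the elapsed time.
        torch.cuda.synchronize()
        # elapsed_time returns milliseconds; convert to seconds.
        self.time_s_list.append(self._start.elapsed_time(self._end) / 1e3)

    @property
    def time_s_mean(self) -> float:
        return torch.tensor(self.time_s_list).mean().item()

    @property
    def time_s_std(self) -> float:
        return torch.tensor(self.time_s_list).std().item()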
Matmul profiling experiment configuration (YAML, new file, @@ -0,0 +1,20 @@)
name: Matmul Profiling
# Adjust the workspace and project names, as appropriate.
workspace: TP Blog Post
project: Matmul Profiling
resources:
  slots_per_trial: 1
searcher:
  name: single
  metric: not_used
  max_length: 1
hyperparameters:
  d_model_min: 256
  d_model_max: 16384
  d_model_step: 256
  num_warmups: 5
  num_repeats: 100
entrypoint: >-
  python3 -m determined.launch.torch_distributed
  python3 matmul_profiling.py
max_restarts: 0
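Assuming this configuration is saved alongside the scripts (for example as matmul_profiling.yaml, a hypothetical filename not shown in this view), the experiment would typically be submitted to a Determined cluster with `det experiment create matmul_profiling.yaml .`, where the trailing `.` uploads the current directory as the experiment context.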