helloworld_ddp.py
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
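
# Hello-world example: benchmark a single Tutel MoE layer trained under PyTorch
# DistributedDataParallel (DDP). Expert parameters stay local to each device, while
# gate parameters are replicated and synchronized by DDP.
# One possible way to launch (one process per GPU; other launchers may also work):
#   python3 -m torch.distributed.run --nproc_per_node=<num_gpus> helloworld_ddp.py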
import os
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn
import argparse
from tutel import system
from tutel import moe as tutel_moe
# Compare version numbers numerically; a plain string comparison breaks for releases such as 1.13 or 2.x.
assert tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2]) >= (1, 8), "DDP-based MoE requires PyTorch >= 1.8.0"
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_tokens', type=int, default=512)
parser.add_argument('--model_dim', type=int, default=2048)
parser.add_argument('--hidden_size', type=int, default=2048)
parser.add_argument('--num_local_experts', type=int, default=2)
parser.add_argument('--dtype', type=str, default='float32')
parser.add_argument('--fp32_gate', default=False, action='store_true')
parser.add_argument('--top', type=int, default=2)
parser.add_argument('--a2a_ffn_overlap_degree', type=int, default=1)
parser.add_argument('--num_steps', type=int, default=100)
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_args()
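
# Initialize the distributed environment: NCCL backend for GPU runs, Gloo for CPU-only runs.
# `dist_print` is a rank-aware print helper provided by the parallel environment.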
parallel_env = system.init_data_model_parallel(backend='nccl' if args.device == 'cuda' else 'gloo')
dist_rank, dist_world_size, dist_print = parallel_env.global_rank, parallel_env.global_size, parallel_env.dist_print
args.local_rank = parallel_env.local_device.index
batch_size = args.batch_size
num_tokens = args.num_tokens
model_dim = args.model_dim
hidden_size = args.hidden_size
num_local_experts = args.num_local_experts
top_value = args.top
a2a_ffn_overlap_degree = args.a2a_ffn_overlap_degree
device = parallel_env.local_device
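
# torch.set_default_dtype below controls the precision of both the model parameters and
# the benchmark input tensor created later via torch.get_default_dtype().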
if args.dtype == 'float32':
    torch.set_default_dtype(torch.float32)
elif args.dtype == 'float64':
    torch.set_default_dtype(torch.float64)
elif args.dtype == 'float16':
    torch.set_default_dtype(torch.float16)
elif args.dtype == 'bfloat16':
    torch.set_default_dtype(torch.bfloat16)
else:
    raise Exception('Unrecognized data type specified: %s' % args.dtype)
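
# The MoE layer tags every expert parameter with a `skip_allreduce` attribute (via
# scan_expert_func) so DDP can later be told to leave those parameters alone: experts
# stay local to each device, while the gate is replicated and synchronized.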
class ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self._moe_layer = tutel_moe.moe_layer(
            gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate},
            experts = {'type': 'ffn', 'num_experts_per_device': num_local_experts, 'hidden_size_per_expert': hidden_size, 'activation_fn': lambda x: F.relu(x)},
            model_dim = model_dim,
            scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
            seeds = (1, dist_rank + 1, 1),
            a2a_ffn_overlap_degree = a2a_ffn_overlap_degree,
        )

        # Summary of different parameter types: gate, local_experts
        local_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='local_experts')])
        shared_count = sum([torch.numel(param) for name, param in self._moe_layer.get_parameter_iterator(param_type='gate')])
        dist_print('[Statistics] param count for MoE local_experts = %s, param count for MoE gate = %s.\n' % (local_count, shared_count))
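
    # The forward pass turns the MoE output into a toy classification problem: summing
    # over model_dim and applying log_softmax over the token dimension yields per-sample
    # log-probabilities that nll_loss below can consume.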
    def forward(self, input):
        result = self._moe_layer(input)
        result = F.log_softmax(torch.sum(result, dim=2), dim=1)
        return result
    # Important setting 1: tell PyTorch DDP to skip all-reduce for expert parameters
    def add_param_to_skip_allreduce(self, param_name):
        if not hasattr(self, '_ddp_params_and_buffers_to_ignore'):
            self._ddp_params_and_buffers_to_ignore = list()
        self._ddp_params_and_buffers_to_ignore.append(param_name)
model = ExampleModel().to(device)

# Important setting 2: iterate over all expert parameter objects and add their names to the skip list from setting 1
for name, param in model.named_parameters():
    if hasattr(param, 'skip_allreduce'):
        model.add_param_to_skip_allreduce(name)
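
# DDP reads `_ddp_params_and_buffers_to_ignore` from the wrapped module (PyTorch >= 1.8)
# and excludes those parameters from gradient all-reduce, so each rank trains its own
# experts independently while the shared gate parameters stay synchronized.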
if torch.distributed.is_initialized():
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
dist_print(model)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
torch.manual_seed(0)
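
# Synthetic benchmark data: the input is generated in float32 on CPU and then cast to the
# active default dtype; the labels are all zero (`random_(1)` samples from {0}), so the
# loss only exists to drive a full backward pass.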
x = torch.tensor(torch.randn([batch_size, num_tokens, model_dim], dtype=torch.float32, device='cpu').detach().numpy(), dtype=torch.get_default_dtype(), requires_grad=False, device=device)
y = torch.LongTensor(batch_size).random_(1).to(device)
tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, top_value, a2a_ffn_overlap_degree, device)
dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, a2a_ffn_overlap_degree = %s, device = `%s`' % tuples)
average_time, num_steps = 0, args.num_steps
for i in range(num_steps):
    t_start = system.record_time()

    optimizer.zero_grad()
    output = model(x)
    loss = F.nll_loss(output, y)
    loss.backward()
    optimizer.step()

    t_stop = system.record_time()
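    # Rough FLOPs estimate covering only the expert FFNs: two GEMMs over
    # (samples x model_dim x hidden_size), 2 FLOPs per multiply-add (the factor 4),
    # scaled by the top-k routing fan-out and by ~3x for forward plus backward.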
    num_global_experts = tutel_moe.moe_layer.global_expert_count(num_local_experts, group=system.get_local_session().model_group)
    args.top = min(args.top, num_global_experts)
    tflops = (batch_size * num_tokens * model_dim * hidden_size) * 4 * args.top * 3 * 1e-12 / (t_stop - t_start)
    dist_print('STEP-%s: loss = %.5f, step_time = %.6f sec, perf = %.2f tflops.' % (i, float(loss.data), t_stop - t_start, tflops))

    if i + 10 >= num_steps:
        average_time += t_stop - t_start
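
# Only the last 10 steps were accumulated above, so the reported average excludes
# early warm-up iterations.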
average_time /= 10
dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time)