public-asgd.py
import argparse
import os
import threading
import time

import torch
from torch import optim
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.distributed.rpc as rpc
from torchvision import transforms, datasets, models

model_dict = {'resnet18': models.resnet18, 'resnet50': models.resnet50, 'vgg16': models.vgg16, 'alexnet': models.alexnet,
              'googlenet': models.googlenet, 'inception': models.inception_v3,
              'densenet121': models.densenet121, 'mobilenet': models.mobilenet_v2}
class ParameterServer(object):
    """
    The parameter server (PS) updates model parameters with gradients from the workers
    and sends the updated parameters back to the workers.
    """

    def __init__(self, model, num_workers, lr):
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.num_workers = num_workers
        # initialize model parameters
        assert model in model_dict.keys(), \
            f'model {model} is not in the model list: {list(model_dict.keys())}'
        self.model = model_dict[model](num_classes=10)
        # zero gradients
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr, momentum=0.9)

    def get_model(self):
        return self.model

    @staticmethod
    @rpc.functions.async_execution
    def update_and_fetch_model(ps_rref, grads, worker_rank):
        self = ps_rref.local_value()
        with self.lock:
            print(f'PS updates parameters based on gradients from worker{worker_rank}')
            # update model parameters
            for p, g in zip(self.model.parameters(), grads):
                p.grad = g
            self.optimizer.step()
            self.optimizer.zero_grad()
            fut = self.future_model
            fut.set_result(self.model)
            self.future_model = torch.futures.Future()
        return fut
def run_worker(ps_rref, rank, data_dir, batch_size, num_epochs):
    """
    A worker pulls model parameters from the PS, computes gradients on a mini-batch
    from its data partition, and pushes the gradients to the PS.
    """
    # prepare dataset; ImageFolder expects one subdirectory per class under data_dir
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose(
        [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
    train_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # set device: worker ranks start at 1, so worker r uses GPU r-1 when CUDA is available
    device_id = rank - 1
    # device_id = 0
    device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu")

    criterion = nn.CrossEntropyLoss()

    # get the initial model from the PS
    m = ps_rref.rpc_sync().get_model().to(device)
    print(f'worker{rank} starts training')
    tt0 = time.time()

    for i in range(num_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            output = m(data)
            loss = criterion(output, target)
            loss.backward()
            print("worker{:d} | Epoch: {:3d} | Batch: {:3d} | Loss: {:6.2f}"
                  .format(rank, (i + 1), (batch_idx + 1), loss.item()))
            # send gradients to the PS and fetch updated model parameters
            m = rpc.rpc_sync(to=ps_rref.owner(),
                             func=ParameterServer.update_and_fetch_model,
                             args=(ps_rref, [p.grad for p in m.cpu().parameters()], rank)
                             ).to(device)

    tt1 = time.time()
    print("Time: {:.2f} seconds".format((tt1 - tt0)))
def main():
    parser = argparse.ArgumentParser(description="Train models on Imagenette under ASGD")
    parser.add_argument("--model", type=str, default="resnet18", help="Model architecture to train (a key of model_dict).")
    parser.add_argument("--rank", type=int, default=1, help="Global rank of this process.")
    parser.add_argument("--world_size", type=int, default=3, help="Total number of processes (1 PS + workers).")
    parser.add_argument("--data_dir", type=str, default="./imagenette2/val", help="Location of the dataset.")
    parser.add_argument("--master_addr", type=str, default="localhost", help="Address of the master.")
    parser.add_argument("--master_port", type=str, default="29600", help="Port that the master is listening on.")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size of each worker during training.")
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    args = parser.parse_args()

    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port
    options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16, rpc_timeout=0)

    if args.rank == 0:
        # initialize the PS, then dispatch training jobs to the workers
        print(f"PS{args.rank} initializing")
        rpc.init_rpc(f"PS{args.rank}", rank=args.rank, world_size=args.world_size, rpc_backend_options=options)
        print(f"PS{args.rank} initialized")

        ps_rref = rpc.RRef(ParameterServer(args.model, args.world_size, args.lr))
        futs = []
        for r in range(1, args.world_size):
            worker = f'worker{r}'
            futs.append(rpc.rpc_async(to=worker,
                                      func=run_worker,
                                      args=(ps_rref, r, args.data_dir, args.batch_size, args.num_epochs)))
        torch.futures.wait_all(futs)
        print("Finished training")
    else:
        # initialize a worker; it passively waits for RPCs from the PS
        print(f"worker{args.rank} initializing")
        rpc.init_rpc(f"worker{args.rank}", rank=args.rank, world_size=args.world_size, rpc_backend_options=options)
        print(f"worker{args.rank} initialized")

    rpc.shutdown()


if __name__ == "__main__":
    main()
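
# Example launch (a minimal sketch, not part of the original script): with the
# default world_size of 3, start one PS process (rank 0) and two worker processes
# (ranks 1 and 2), all pointing at the same master address and port. The data
# directory below is the script's default; any ImageFolder-style directory
# (one subdirectory per class) should work.
#
#   python public-asgd.py --rank 0 --world_size 3 --master_addr localhost --master_port 29600
#   python public-asgd.py --rank 1 --world_size 3 --model resnet18 --data_dir ./imagenette2/val
#   python public-asgd.py --rank 2 --world_size 3 --model resnet18 --data_dir ./imagenette2/val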