-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathlauncher_utils.py
153 lines (123 loc) · 5.22 KB
/
launcher_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from copy import deepcopy
from typing import Tuple, List
import torch.multiprocessing as mp
from fling.component.client import ClientTemplate
def _client_trainer(client: ClientTemplate, kwargs: dict) -> Tuple:
    r"""
    Overview:
        Worker entry point that runs the ``train`` operation of one client.
    Arguments:
        client: The client whose ``train`` method is called.
        kwargs: Keyword arguments forwarded unchanged to ``client.train``.
    Returns:
        A 2-tuple ``(result, client)``: the value returned by ``client.train`` followed by the client itself.
    """
    # The client is returned next to the result so the caller can read its state after the call.
    return client.train(**kwargs), client
def _client_tester(client: ClientTemplate, kwargs: dict) -> Tuple:
    r"""
    Overview:
        Worker entry point that runs the ``test`` operation of one client.
    Arguments:
        client: The client whose ``test`` method is called.
        kwargs: Keyword arguments forwarded unchanged to ``client.test``.
    Returns:
        A 2-tuple ``(result, client)``: the value returned by ``client.test`` followed by the client itself.
    """
    # The client is returned next to the result so the caller can read its state after the call.
    return client.test(**kwargs), client
def _client_finetuner(client: ClientTemplate, kwargs: dict) -> Tuple:
    r"""
    Overview:
        Worker entry point that runs the ``finetune`` operation of one client.
    Arguments:
        client: The client whose ``finetune`` method is called.
        kwargs: Keyword arguments forwarded unchanged to ``client.finetune``.
    Returns:
        A 2-tuple ``(result, client)``: the value returned by ``client.finetune`` followed by the client itself.
    """
    # The client is returned next to the result so the caller can read its state after the call.
    return client.finetune(**kwargs), client
# Dispatch table: task name accepted by the launchers -> worker function applied to each client.
op2func = {
    'train': _client_trainer,
    'test': _client_tester,
    'finetune': _client_finetuner,
}
def copy_attributes(src: object, dst: object) -> None:
    r"""
    Overview:
        Overwrite the attributes of ``dst`` with the instance attributes of ``src``.
        Only the names found in ``src.__dict__`` are transferred; attributes that exist on ``dst`` alone
        are left as they are. Values are assigned as-is (no copy is made of them).
        This function requires that src and dst are instances of the same class.
    Arguments:
        src: The object whose instance attributes are read.
        dst: The object whose attributes are over-written by src's.
    """
    for name in src.__dict__:
        # Read through ``getattr`` and write through ``setattr`` so normal attribute access rules apply.
        value = getattr(src, name)
        setattr(dst, name, value)
class SerialLauncher:
    r"""
    Overview:
        Use one process to serially execute operations on all clients.
    """

    def launch(self, clients: List, task_name: str, **kwargs) -> List:
        r"""
        Overview:
            Launch the task on each client, one after another, in the calling process.
        Arguments:
            clients: Clients to be launched.
            task_name: Task name of the operation in each client. Must be a key of ``op2func``.
            kwargs: Arguments required by corresponding operations (e.g. train, test, finetune)
        Returns:
            loggers: A list, each element corresponds to the logger generated by one client.
        Raises:
            ValueError: If ``task_name`` is not a key of ``op2func``.
        """
        # Get the operation function according to the task name.
        try:
            op_func = op2func[task_name]
        except KeyError:
            # The KeyError adds nothing for the caller, so it is dropped from the traceback.
            raise ValueError(f'Unrecognized task name: {task_name}') from None

        # Every worker function returns ``(logger, client)``. The client in that pair is the same
        # object that was passed in, so only the logger needs to be kept here.
        return [op_func(client, kwargs)[0] for client in clients]
class MultiProcessLauncher:
    r"""
    Overview:
        Accelerate the process of operations on each client.
        The operations are distributed over a pool of worker processes, and the calling process
        collects the results.
    """

    def __init__(self, num_proc: int):
        r"""
        Overview:
            Initialization for launcher.
        Arguments:
            num_proc: Number of processes used.
        """
        self.num_proc = num_proc

    def launch(self, clients: List, task_name: str, **kwargs) -> List:
        r"""
        Overview:
            Launch the tasks in each client using a pool of ``num_proc`` processes, then write the state
            of the clients returned by the workers back onto the original client objects.
        Arguments:
            clients: Clients to be launched.
            task_name: Task name of the operation in each client. Must be a key of ``op2func``.
            kwargs: Arguments required by corresponding operations (e.g. train, test, finetune)
        Returns:
            loggers: A list, each element corresponds to the logger generated by one client.
        Raises:
            ValueError: If ``task_name`` is not a key of ``op2func``.
        """
        # Get the operation function according to the task name.
        try:
            op_func = op2func[task_name]
        except KeyError:
            # The KeyError adds nothing for the caller, so it is dropped from the traceback.
            raise ValueError(f'Unrecognized task name: {task_name}') from None

        # Each task is a tuple that contains the client and the arguments of the operation.
        tasks = [(client, kwargs) for client in clients]
        with mp.Pool(self.num_proc) as pool:
            # Use starmap to apply the worker function to every task.
            results = pool.starmap(op_func, tasks)

        loggers = []
        # Results come back in task order, so they can be paired with the original clients.
        for client, (logger, new_client) in zip(clients, results):
            assert new_client.client_id == client.client_id, \
                f'Client mismatch: expected {client.client_id}, got {new_client.client_id}'
            # The workers run in other processes, so the clients they return are separate objects.
            # Copy their attributes onto the originals to make the updates visible to the caller.
            copy_attributes(src=new_client, dst=client)
            loggers.append(logger)
        return loggers
def get_launcher(args: dict) -> object:
    r"""
    Overview:
        Build the launcher according to the configurations.
    Arguments:
        args: The input configurations. Its ``launcher`` attribute is read; that entry must contain the \
            key ``name`` and, for the ``multiprocessing`` launcher, the constructor arguments.
    Returns:
        Corresponding launcher.
    Raises:
        ValueError: If the launcher name is neither ``serial`` nor ``multiprocessing``.
    """
    # Work on a copy, so that the ``pop()`` below leaves the caller's configuration intact.
    config = deepcopy(args.launcher)
    kind = config.pop('name')

    # Build different types of launchers.
    if kind == 'serial':
        return SerialLauncher()
    if kind == 'multiprocessing':
        # Whatever remains after removing ``name`` is passed on as constructor arguments.
        return MultiProcessLauncher(**config)
    raise ValueError(f'Unrecognized launcher type: {kind}')