-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathaveraging.py
63 lines (59 loc) · 2.15 KB
/
averaging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy
import torch
import numpy as np
from torch import nn
### average of weights ###
def average_weights(w):
    """Return the element-wise average of a collection of weights.

    The type of ``w[0]`` selects the branch: when it is a ``numpy.ndarray``
    every entry of ``w`` is treated as an array; otherwise every entry is
    treated as a mapping (e.g. a ``state_dict``) from key to tensor, and the
    average is taken key by key over the keys of ``w[0]``.

    :param w: non-empty, integer-indexable collection of arrays or mappings
    :return: a new array, or a deep copy of ``w[0]`` whose values are
        replaced by the averaged tensors; the inputs are never modified
    :raises IndexError: if ``w`` is an empty list
    """
    count = len(w)

    if isinstance(w[0], np.ndarray):
        # Accumulate out of place: an in-place ``+=`` on an integer array
        # raises a casting error as soon as a later entry holds floats.
        total = w[0]
        for i in range(1, count):
            total = total + w[i]
        return total / count

    # Deep copy keeps the container type (and any attributes attached to
    # it) of the first mapping while leaving the caller's data untouched.
    w_avg = copy.deepcopy(w[0])
    for k in w_avg:
        # Same reasoning as above: out-of-place addition lets torch promote
        # mixed dtypes instead of failing on an in-place integer update.
        total = w_avg[k]
        for i in range(1, count):
            total = total + w[i][k]
        w_avg[k] = torch.div(total, count)
    return w_avg
### average over multiple experiments; f is a two-dimensional input ###
def average_experiments(f):
    """Average a two-dimensional input column by column.

    ``f[i][j]`` is read as the j-th value of the i-th experiment; the number
    of columns is taken from ``f[0]``.

    :param f: two-dimensional, integer-indexable input (one row per experiment)
    :return: list holding, for each column, the sum over all rows divided
        by the number of rows
    """
    averaged = []
    for col in range(len(f[0])):
        # Plain running sum starting at 0, added row by row in order.
        acc = 0
        for row in range(len(f)):
            acc += f[row][col]
        averaged.append(acc / len(f))
    return averaged
def _to_cpu(value):
    """Return ``value`` moved to host memory when it is a torch tensor."""
    # ``Tensor.cpu()`` hands back the same object for tensors that already
    # live on the CPU, so this is free in the CPU-only case.
    return value.cpu() if isinstance(value, torch.Tensor) else value


def average_FSVRG_weights(w, ag_scalar, net, gpu=-1):
    """
    Use the FSVRG algorithm to update the global parameters.

    For every key ``k`` of the global state ``w_t`` the returned value is
    ``w_t[k] + ag_scalar * sum_l(w[l][0] * (w[l][1][k] - w_t[k])) / sum_l(w[l][0])``.

    :param w: list of client entries; ``w[l][1]`` is the client's state_dict
        and ``w[l][0]`` the factor its update is weighted by (presumably the
        client's local data size -- confirm against the caller)
    :param ag_scalar: scalar used as a simplification of the A matrix
    :param net: global net model; only ``net.state_dict()`` is read and the
        model itself is not modified
    :param gpu: kept for backward compatibility; tensors are now always
        moved to the CPU before being handed to numpy, so CUDA tensors work
        even with the default value
    :return: global state_dict -- a deep copy of ``net.state_dict()`` whose
        values are replaced by the results of ``np.add``
    :raises ValueError: if the client weighting factors sum to zero (which
        includes an empty ``w``)
    """
    total_size = np.array(np.sum([u[0] for u in w]))
    if total_size == 0:
        # Dividing by zero below would silently fill the model with NaN/inf.
        raise ValueError(
            "average_FSVRG_weights: client sizes sum to zero, cannot aggregate"
        )

    w_t = copy.deepcopy(net.state_dict())

    # sg[k] accumulates the size-weighted differences between each client's
    # parameters and the current global parameters.
    sg = {key: np.zeros(w_t[key].shape) for key in w_t}
    for client in w:
        client_size = client[0]
        client_state = client[1]
        for k in sg:
            delta = _to_cpu(client_state[k] - w_t[k])
            sg[k] = np.add(sg[k], client_size * delta)

    # Apply the scaled, size-normalised aggregate to the global parameters.
    for key in w_t:
        w_t[key] = np.add(
            _to_cpu(w_t[key]), np.divide(ag_scalar * sg[key], total_size)
        )
    return w_t