-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathserver_improved.py
44 lines (34 loc) · 1.32 KB
/
server_improved.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -*- coding: utf-8 -*-
import numpy as np
class imp_server(object):
    """Aggregate statistics reported by a fixed pool of clients over ``narms`` arms.

    The server keeps, for every client, a row of local means and a row of
    local deltas (each of length ``narms``), plus the union of the sets the
    clients report. Global means are recomputed only once every client has
    reported local means since the last successful aggregation.

    Args:
        narms: Number of arms ``K``.
        nclients: Number of clients ``M``.
    """

    def __init__(self,
                 narms,
                 nclients):
        self.M = nclients
        self.K = narms
        # Row i holds the latest local means reported by client i; shape (M, K).
        self.local_means = np.zeros([self.M, self.K])
        # Average of local_means over clients, set by the last successful
        # global_mean_update; shape (K,).
        self.global_means = np.zeros(self.K)
        # Union of the sets passed to local_set_update since the last call to
        # global_mean_update (which clears it).
        self.global_set = set()
        # Row i holds the latest delta vector reported by client i; starts at ones.
        self.global_delta = np.ones([self.M, self.K])
        # NOTE(review): not read anywhere in this class; presumably used by
        # callers — confirm before removing.
        self.p = 1
        # c_local_stat[i] == 1 marks that client i has reported local means
        # since the last successful aggregation.
        self.c_local_stat = np.zeros(self.M)

    def local_mean_update(self, i, local_stat):
        """Store client ``i``'s local means and mark the client as reported.

        Args:
            i: Client index in ``[0, M)``.
            local_stat: Length-``K`` sequence of local means for client ``i``.
        """
        self.local_means[i] = local_stat
        self.c_local_stat[i] = 1

    def global_mean_update(self):
        """Try to aggregate the clients' local means into global means.

        The global set is cleared on every call, whether or not aggregation
        takes place.

        Returns:
            ``(True, global_means)`` when every client has reported since the
            last aggregation. In that case the report flags are reset, and the
            returned array is the server's own ``global_means`` attribute,
            not a copy. Otherwise ``(False, 0)``.
        """
        # NOTE(review): clearing global_set even when aggregation is skipped
        # discards sets reported since the last call — confirm this is intended.
        self.global_set = set()
        # Flags are 0/1, so the sum counts distinct clients that have reported.
        if self.c_local_stat.sum() < self.M:
            return False, 0
        self.global_means = self.local_means.mean(axis=0)
        self.c_local_stat = np.zeros(self.M)
        return True, self.global_means

    def local_set_update(self, i, local_set):
        """Merge client ``i``'s set into the global set.

        Args:
            i: Client index; accepted for interface symmetry but not used.
            local_set: Any iterable of hashable items (e.g. a set or a list).
        """
        # Rebind instead of mutating in place, so a set previously returned by
        # global_set_update is not changed behind the caller's back.
        self.global_set = self.global_set.union(local_set)

    def local_delta_update(self, i, local_delta):
        """Store client ``i``'s length-``K`` delta vector in row ``i`` of ``global_delta``."""
        self.global_delta[i] = local_delta

    def global_set_update(self):
        """Return the current union of reported sets (the live object, not a copy)."""
        return self.global_set

    def global_delta_update(self):
        """Return the ``(M, K)`` array of per-client deltas (the live array, not a copy)."""
        return self.global_delta