agents.py
import multiprocessing as mp

import numpy as np
import torch
from phe import paillier

def encrypt(args):
    """Encrypt one scalar; args is a (public_key, value) tuple for Pool.map."""
    key, x = args
    return key.encrypt(float(x))


def decrypt(args):
    """Decrypt one ciphertext; args is a (private_key, ciphertext) tuple."""
    key, x = args
    try:
        return key.decrypt(x)
    except OverflowError:
        # Log the ciphertext that failed to decrypt and fall back to zero.
        print(x)
        return 0.
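
# A hedged round-trip sketch for the helpers above (only the phe API is
# assumed; the tiny 128-bit key mirrors the one used by Central below):
#
#   pub, priv = paillier.generate_paillier_keypair(n_length=128)
#   ct = encrypt((pub, 3.14))
#   pt = decrypt((priv, ct))  # pt is approximately 3.14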

class Central:
    def __init__(self, model, optim, encryption=False):
        self.model = model
        self.optim = optim
        # Encryption-based setup: a pool for parallel (de)cryption and a
        # Paillier keyring holding the key pair. 128-bit keys keep the demo
        # fast but are far too small for real security.
        self.pool = None
        self.keyring = None
        if encryption:
            self.pool = mp.Pool(1)
            self.keyring = paillier.PaillierPrivateKeyring()
            self.public_key, self.private_key = (
                paillier.generate_paillier_keypair(self.keyring, n_length=128))

    def __del__(self):
        if self.pool is not None:
            self.pool.close()

    def update_model(self, ups):
        """
        Update the central model with the new gradients.
        ups is a list of weight gradients, one entry per named parameter.
        """
        self.optim.zero_grad()
        for i, (_, paramval) in enumerate(self.model.named_parameters()):
            if self.keyring is not None:
                print('Decrypting {} ...'.format(ups[i].shape))
                nargs = [(
                    self.private_key, x) for _, x in np.ndenumerate(ups[i])]
                update = np.reshape(
                    self.pool.map(decrypt, nargs), ups[i].shape)
                update = torch.FloatTensor(update)
                # Move the decrypted gradient to the parameter's device
                # rather than assuming CUDA is available.
                paramval.grad = update.to(paramval.device)
            else:
                paramval.grad = ups[i]
        self.optim.step()
        self.optim.zero_grad()

    def init_adv(self, model):
        self.adv = model

    def get_key(self):
        """
        Return the Paillier public key, or None when encryption is disabled.
        """
        if self.keyring is not None:
            return self.public_key
        return None
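
# Hedged sketch of the intended key handoff (illustrative, not from this
# file): the central party generates the keypair and shares only the public
# key, e.g. Worker(loss_fn, key=central.get_key()), so workers can encrypt
# gradients that only Central can decrypt.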

class Worker:
    def __init__(self, loss, key=None):
        self.model = None
        self.loss = loss
        self.key = key
        if self.key is not None:
            self.pool = mp.Pool(1)

    def __del__(self):
        if self.key is not None:
            self.pool.close()

    def fwd_bkwd(self, inp, outp):
        """Run one forward/backward pass and return per-layer weight grads."""
        pred = self.model(inp)
        lossval = self.loss(pred, outp)
        lossval.backward()
        weightgrads = []
        for _, paramval in self.model.named_parameters():
            if self.key is not None:
                # Encrypt each gradient element with the public key.
                grads = np.array(paramval.grad.cpu())
                print('Encrypting {} ...'.format(grads.shape))
                nargs = [(self.key, x) for _, x in np.ndenumerate(grads)]
                grads_e = np.reshape(
                    self.pool.map(encrypt, nargs), grads.shape)
                weightgrads.append(grads_e)
            else:
                weightgrads.append(paramval.grad)
        return weightgrads
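
# Note: with a key set, fwd_bkwd returns numpy arrays of phe EncryptedNumber
# objects. Paillier is additively homomorphic, so an aggregation rule can sum
# such ciphertexts element-wise without ever seeing the plaintext gradients.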

class Agg:
    def __init__(self, rule):
        # rule: a function that takes a list of per-worker gradient updates
        # and aggregates them into a single update.
        self.rule = rule
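
# A minimal end-to-end sketch of how these classes fit together on the
# unencrypted path. The model, data, and averaging rule below are
# illustrative assumptions, not part of this file.
if __name__ == '__main__':
    model = torch.nn.Linear(4, 2)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    central = Central(model, optim, encryption=False)

    worker = Worker(torch.nn.functional.mse_loss)
    worker.model = model  # worker trains against the central model

    # Average per-layer gradients across workers (here, just one).
    agg = Agg(lambda updates: [torch.stack(g).mean(dim=0)
                               for g in zip(*updates)])

    inp, outp = torch.randn(8, 4), torch.randn(8, 2)
    ups = agg.rule([worker.fwd_bkwd(inp, outp)])
    central.update_model(ups)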