-
Notifications
You must be signed in to change notification settings - Fork 103
/
train.py
213 lines (184 loc) · 5.87 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import print_function
import argparse
import torch.nn.functional as F
import torch.utils.data
import egg.core as core
from egg.zoo.simple_autoenc.archs import Receiver, Sender
from egg.zoo.simple_autoenc.features import OneHotLoader
def get_params(params):
    """Parse the command-line options of this training script.

    Registers the script-specific options on a fresh ``ArgumentParser`` and
    hands the parser, together with ``params``, to ``core.init``.

    Args:
        params: sequence of command-line argument strings (the script entry
            point passes ``sys.argv[1:]``).

    Returns:
        Whatever ``core.init(parser, params)`` returns; ``main`` reads the
        options declared here from it as attributes.
    """
    # NOTE(review): main() also reads opts.cuda, opts.batch_size,
    # opts.vocab_size, opts.max_len and opts.n_epochs, none of which are
    # declared here -- presumably core.init registers them; confirm.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_features",
        type=int,
        default=10,
        help='Dimensionality of the "concept" space (default: 10)',
    )
    parser.add_argument(
        "--batches_per_epoch",
        type=int,
        default=1000,
        help="Number of batches per epoch (default: 1000)",
    )

    parser.add_argument(
        "--sender_hidden",
        type=int,
        default=10,
        help="Size of the hidden layer of Sender (default: 10)",
    )
    parser.add_argument(
        "--receiver_hidden",
        type=int,
        default=10,
        help="Size of the hidden layer of Receiver (default: 10)",
    )

    parser.add_argument(
        "--sender_embedding",
        type=int,
        default=10,
        help="Dimensionality of the embedding hidden layer for Sender (default: 10)",
    )
    parser.add_argument(
        "--receiver_embedding",
        type=int,
        default=10,
        help="Dimensionality of the embedding hidden layer for Receiver (default: 10)",
    )

    parser.add_argument(
        "--sender_cell",
        type=str,
        default="rnn",
        help="Type of the cell used for Sender {rnn, gru, lstm} (default: rnn)",
    )
    parser.add_argument(
        "--receiver_cell",
        type=str,
        default="rnn",
        help="Type of the cell used for Receiver {rnn, gru, lstm} (default: rnn)",
    )

    parser.add_argument(
        "--sender_entropy_coeff",
        type=float,
        default=1e-1,
        help="The entropy regularisation coefficient for Sender (default: 1e-1)",
    )
    parser.add_argument(
        "--receiver_entropy_coeff",
        type=float,
        default=1e-1,
        help="The entropy regularisation coefficient for Receiver (default: 1e-1)",
    )

    parser.add_argument(
        "--sender_lr",
        type=float,
        default=1e-3,
        help="Learning rate for Sender's parameters (default: 1e-3)",
    )
    parser.add_argument(
        "--receiver_lr",
        type=float,
        default=1e-3,
        help="Learning rate for Receiver's parameters (default: 1e-3)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="GS temperature for the sender (default: 1.0)",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="rf",
        # The two literals are concatenated by Python, so the separating
        # space has to be part of one of them.
        help="Selects whether Reinforce or GumbelSoftmax relaxation is used for training {rf, gs} "
        "(default: rf)",
    )

    args = core.init(parser, params)
    return args
def loss(sender_input, _message, _receiver_input, receiver_output, _labels, _aux_input):
    """Per-sample cross-entropy between the receiver output and the sender input.

    The target class of every row is the argmax of ``sender_input`` along
    dim 1. Only ``sender_input`` and ``receiver_output`` are used; the
    remaining arguments are ignored.

    Returns:
        A tuple ``(per_sample_loss, {"acc": hits})`` where ``per_sample_loss``
        is the unreduced cross-entropy and ``hits`` is a detached float
        tensor holding 1.0 where the receiver's argmax equals the target and
        0.0 elsewhere.
    """
    target = sender_input.argmax(dim=1)
    predicted = receiver_output.argmax(dim=1)

    # Accuracy is reported only, so it is detached from the autograd graph.
    hits = (predicted == target).detach().float()
    per_sample_loss = F.cross_entropy(receiver_output, target, reduction="none")
    return per_sample_loss, {"acc": hits}
def main(params):
    """Build the Sender/Receiver game selected by ``--mode`` and train it.

    Args:
        params: sequence of command-line argument strings, forwarded to
            ``get_params``.

    Raises:
        NotImplementedError: if ``--mode`` is neither ``rf`` nor ``gs``
            (the comparison is case-insensitive).
    """
    opts = get_params(params)
    # Lower-case once so both branch tests below share the same value.
    mode = opts.mode.lower()

    train_loader = OneHotLoader(
        n_features=opts.n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
    )
    # Same configuration as the training loader, but with an explicit seed.
    test_loader = OneHotLoader(
        n_features=opts.n_features,
        batch_size=opts.batch_size,
        batches_per_epoch=opts.batches_per_epoch,
        seed=7,
    )

    # Base agents; each branch below wraps them in the RNN wrappers of the
    # chosen training mode.
    sender = Sender(n_hidden=opts.sender_hidden, n_features=opts.n_features)
    receiver = Receiver(n_features=opts.n_features, n_hidden=opts.receiver_hidden)

    if mode == "rf":
        sender = core.RnnSenderReinforce(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
        )
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnReinforce(
            sender,
            receiver,
            loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
        )
        callbacks = []
    elif mode == "gs":
        sender = core.RnnSenderGS(
            sender,
            opts.vocab_size,
            opts.sender_embedding,
            opts.sender_hidden,
            cell=opts.sender_cell,
            max_len=opts.max_len,
            temperature=opts.temperature,
        )
        receiver = core.RnnReceiverGS(
            receiver,
            opts.vocab_size,
            opts.receiver_embedding,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )
        game = core.SenderReceiverRnnGS(sender, receiver, loss)
        # Only the GS branch registers a temperature callback for the sender.
        callbacks = [core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1)]
    else:
        raise NotImplementedError(f"Unknown training mode, {opts.mode}")

    # Two parameter groups so Sender and Receiver can use different
    # learning rates.
    optimizer = torch.optim.Adam(
        [
            {"params": game.sender.parameters(), "lr": opts.sender_lr},
            {"params": game.receiver.parameters(), "lr": opts.receiver_lr},
        ]
    )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=callbacks + [core.ConsoleLogger(as_json=True)],
    )

    trainer.train(n_epochs=opts.n_epochs)

    core.close()
if __name__ == "__main__":
    # Script entry point: pass every command-line argument except the
    # program name on to main().
    import sys

    cli_args = sys.argv[1:]
    main(cli_args)