This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmodeling.py
273 lines (222 loc) · 9.52 KB
/
modeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torch.nn as nn
from torch.optim import LBFGS, SGD
from torch.optim.lr_scheduler import StepLR
import dataloading
import modules
import resnet
import util
# Names of the layer classes that privatize_model / unprivatize_model swap
# between `torch.nn` and the project-local `modules` package; each name is
# looked up with getattr on both of them.
MODULE_TYPES = [
    "Linear",
    "Conv2d",
    "GroupNorm",
]
def add_l2_regularization(criterion, model, regularization_param):
    """
    Adds an L2-regularizer on the parameters in the `model` to the loss function
    `criterion`. The regularization parameter is given by `regularization_param`.

    Returns a callable `regularized_loss(predictions, targets)` computing

        criterion(predictions, targets)
            + (regularization_param / 2) * sum_p ||p||^2

    where `p` ranges over every tensor in `model.parameters()` (weights and
    biases alike). The parameters are read each time the returned callable is
    invoked, so the penalty always reflects the model's current values.
    """
    def regularized_loss(predictions, targets):
        loss = criterion(predictions, targets)
        for param in model.parameters():
            flat = param.flatten()
            # Out-of-place addition on purpose: an in-place `+=` would modify the
            # tensor returned by `criterion`, which autograd may have saved for
            # its backward pass (e.g. a sigmoid/exp output), making backward() fail.
            loss = loss + (regularization_param / 2.) * flat.dot(flat)
        return loss
    return regularized_loss
def _replace_batchnorm_with_groupnorm(model):
    """
    Replaces, in place, every `nn.BatchNorm2d` submodule of `model` by an
    `nn.GroupNorm` over the same number of channels with `min(32, num_features)`
    groups. The new layer is affine iff the old one had both a weight and a bias.
    """
    # Snapshot the layers first so the module tree is not mutated while the
    # `named_modules()` generator is still traversing it.
    batchnorm_layers = [
        (qualified_name, layer)
        for qualified_name, layer in model.named_modules()
        if isinstance(layer, nn.BatchNorm2d)
    ]
    for qualified_name, layer in batchnorm_layers:
        group_norm = nn.GroupNorm(
            min(32, layer.num_features),
            layer.num_features,
            affine=(layer.weight is not None and layer.bias is not None),
        )
        # walk down the dotted path to the direct parent, then swap the child:
        *parent_path, child_name = qualified_name.split(".")
        parent = model
        for part in parent_path:
            parent = parent._modules[part]
        parent._modules[child_name] = group_norm


def initialize_model(num_inputs, num_outputs, model="linear", device="cpu"):
    """
    Initializes a model and optionally moves it to the GPU.

    Args:
        num_inputs: number of input features of the "linear" model.
        num_outputs: number of outputs of the "linear" model.
        model: "linear" for `nn.Linear(num_inputs, num_outputs)`, or a name
            starting with "resnet" that names a constructor in the project-local
            `resnet` module. That constructor is called without arguments, so
            `num_inputs` / `num_outputs` are currently ignored for ResNets.
            All `nn.BatchNorm2d` layers of a ResNet are replaced by
            `nn.GroupNorm` layers.
        device: "gpu" to move the model to CUDA (asserts CUDA is available);
            any other value leaves the model where it was constructed.

    Returns:
        The initialized `nn.Module`.

    Raises:
        ValueError: if `model` is neither "linear" nor starts with "resnet".
        AssertionError: if the requested ResNet does not exist, or `device` is
            "gpu" but CUDA is unavailable.
    """
    model_name = model
    if model_name == "linear":
        model = nn.Linear(num_inputs, num_outputs)
    elif model_name.startswith("resnet"):
        # get a vanilla ResNet model:
        assert hasattr(resnet, model_name), f"Unknown model: {model_name}"
        model = getattr(resnet, model_name)()
        # TODO: Add checks that number of inputs and outputs match.
        # privatize_model only converts the layer types in MODULE_TYPES (which
        # include GroupNorm but not BatchNorm) and rejects other weighted layers:
        _replace_batchnorm_with_groupnorm(model)
    else:
        raise ValueError(f"Unknown model: {model_name}")

    # copy model to GPU(s) and return:
    if device == "gpu":
        assert torch.cuda.is_available(), "CUDA is not available on this machine."
        logging.info("Copying model to GPU...")
        model.cuda()
    return model
def privatize_model(model, clip, std):
    """
    Converts a "normal" model into a model that computes private gradients.

    Every submodule whose type is one of the `torch.nn` layers named in
    `MODULE_TYPES` has its class swapped, in place, for the same-named class in
    the project-local `modules` package, and gets `clip` and `std` attached as
    tensors. Submodules that already use one of those private classes keep their
    class and only have `clip` / `std` (re)set, so converting an already
    (partially) privatized model is safe.

    Args:
        model: the `nn.Module` to convert; it is modified in place.
        clip: value stored as `module.clip` (via `torch.tensor`) on each layer.
        std: value stored as `module.std` (via `torch.tensor`) on each layer.

    Returns:
        `model` itself.

    Raises:
        NotImplementedError: if any other submodule has a `weight` or `bias`
            attribute, because its privacy conversion is not implemented.
    """
    types = tuple(getattr(nn, mod_type) for mod_type in MODULE_TYPES)
    private_types = tuple(getattr(modules, mod_type) for mod_type in MODULE_TYPES)
    for module in model.modules():
        if isinstance(module, private_types):
            # Already converted (e.g. second call). Previously such layers fell
            # through to the error branch below because they have a weight.
            pass
        elif isinstance(module, types):
            module.__class__ = getattr(modules, type(module).__name__)
        elif hasattr(module, "weight") or hasattr(module, "bias"):
            raise NotImplementedError(
                f"Privacy conversion of {type(module)} not implemented."
            )
        else:
            continue  # no weight/bias (containers, activations): nothing to do
        module.clip = torch.tensor(clip)
        module.std = torch.tensor(std)
    return model
def unprivatize_model(model):
    """
    Converts a model that computes private gradients into a "normal" model.

    Each submodule whose class is one of the private layer classes from the
    project-local `modules` package (those named in `MODULE_TYPES`) gets the
    `torch.nn` class of the same name back and loses its `clip` and `std`
    attributes. The model is modified in place and also returned.
    """
    private_classes = tuple(getattr(modules, name) for name in MODULE_TYPES)
    private_layers = filter(
        lambda candidate: isinstance(candidate, private_classes), model.modules()
    )
    for layer in private_layers:
        layer.__class__ = getattr(nn, type(layer).__name__)
        del layer.clip, layer.std
    return model
def train_model(model, dataset, optimizer="lbfgs", batch_size=128, num_epochs=100,
                learning_rate=1., criterion=None, augmentation=False, momentum=0.9,
                use_lr_scheduler=True, visualizer=None, title=None):
    """
    Trains `model` on samples from the specified `dataset` using the specified
    `optimizer` ("lbfgs" or "sgd") with batch size `batch_size` for `num_epochs`
    epochs to minimize the specified `criterion` (default = `nn.CrossEntropyLoss`).

    For L-BFGS, the batch size is ignored and full gradients are used (a single
    batch of `len(dataset["targets"])` samples). The `learning_rate` is only
    used as initial value; step sizes are determined by a strong-Wolfe line
    search. Data augmentation is not allowed (AssertionError) and no learning
    rate scheduler is used.

    For SGD, the initial learning rate is set to `learning_rate` and, if
    `use_lr_scheduler` is set, multiplied by 0.1 every `max(1, num_epochs // 4)`
    epochs. The schedule is not stepped after the final epoch, so e.g. 100
    epochs see three decays. Training uses plain (non-Nesterov) momentum
    `momentum` (default 0.9). Optionally, data `augmentation` can be enabled.

    The average training loss is logged every 10 epochs and after the final
    epoch; if a visdom `visualizer` is given, the same values are plotted in a
    window with the specified `title`. The model is left in eval mode.

    Raises:
        ValueError: if `optimizer` is neither "sgd" nor "lbfgs".
    """
    # set up criterion and learning curve:
    model.train()
    device = next(model.parameters()).device
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    window = None  # visdom window handle, created by the first plot

    # set up optimizer and learning rate scheduler:
    scheduler = None
    if optimizer == "sgd":
        optimizer = SGD(model.parameters(), lr=learning_rate, momentum=momentum)
        scheduler = StepLR(optimizer, step_size=max(1, num_epochs // 4), gamma=0.1)
    elif optimizer == "lbfgs":
        assert not augmentation, "Cannot use data augmentation with L-BFGS."
        optimizer = LBFGS(
            model.parameters(),
            lr=learning_rate,
            tolerance_grad=1e-4,
            line_search_fn="strong_wolfe",
        )
        batch_size = len(dataset["targets"])  # full-batch optimization
    else:
        raise ValueError(f"Unknown optimizer: {optimizer}")
    use_lr_scheduler = use_lr_scheduler and scheduler is not None

    # create data sampler:
    transform = dataloading.data_augmentation() if augmentation else None
    datasampler = dataloading.load_datasampler(
        dataset, batch_size=batch_size, transform=transform
    )

    # perform training epochs:
    for epoch in range(num_epochs):
        num_samples, total_loss = 0, 0.
        for sample in datasampler():
            # copy sample to correct device if needed:
            for key in sample.keys():
                if sample[key].device != device:
                    sample[key] = sample[key].to(device=device)

            # L-BFGS re-evaluates the objective during its line search, so the
            # forward-backward pass is passed to step() as a closure:
            def loss_closure():
                optimizer.zero_grad()
                predictions = model(sample["features"])
                loss = criterion(predictions, sample["targets"])
                loss.backward()
                return loss

            # perform parameter update:
            loss = optimizer.step(closure=loss_closure)

            # aggregate loss values for monitoring:
            batch_len = sample["features"].size(0)
            total_loss += loss.item() * batch_len
            num_samples += batch_len

        # decay learning rate (SGD only); nothing to decay after the last epoch:
        if use_lr_scheduler and epoch != num_epochs - 1:
            scheduler.step()

        # report progress every 10 epochs and always for the final epoch:
        if epoch % 10 == 0 or epoch == num_epochs - 1:
            # guard against a sampler that produced no samples this epoch:
            average_loss = total_loss / num_samples if num_samples else float("nan")
            logging.info(f" => epoch {epoch + 1}: loss = {average_loss}")
            if visualizer is not None:
                window = util.learning_curve(
                    visualizer,
                    torch.LongTensor([epoch + 1]),
                    torch.DoubleTensor([average_loss]),
                    window=window,
                    title=title,
                )

    # we are done training:
    model.eval()
def test_model(model, dataset, batch_size=128, augmentation=False):
    """
    Evaluates `model` on `dataset` and returns its predictions.

    The model is switched to eval mode and run without gradient tracking on
    batches of `batch_size` samples drawn without shuffling. Each batch is moved
    to the device of the model's parameters when needed. The per-batch outputs
    are concatenated along dim 0. If `augmentation` is set, the test-time
    transform `dataloading.data_augmentation(train=False)` is applied.
    """
    model.eval()
    target_device = next(model.parameters()).device
    if augmentation:
        transform = dataloading.data_augmentation(train=False)
    else:
        transform = None
    batches = dataloading.load_datasampler(
        dataset, batch_size=batch_size, transform=transform, shuffle=False
    )

    outputs = []
    for batch in batches():
        # move every tensor in the batch next to the model's parameters:
        for field in list(batch.keys()):
            tensor = batch[field]
            if tensor.device != target_device:
                batch[field] = tensor.to(device=target_device)
        with torch.no_grad():
            batch_output = model(batch["features"])
        outputs.append(batch_output)
    return torch.cat(outputs, dim=0)
def get_parameter_vector(model):
    """
    Flattens and concatenates parameters into a single 1-D tensor.

    `model` is either an `nn.Module`, whose `.parameters()` are used, or any
    other iterable of parameter tensors. Anything else raises ValueError.
    """
    if isinstance(model, nn.Module):
        params = model.parameters()
    elif hasattr(model, "__iter__"):
        params = model
    else:
        raise ValueError("Model is not nn.Module or iterable.")
    return torch.nn.utils.parameters_to_vector(params)
def set_parameter_vector(model, parameters):
    """
    Copies consecutive slices of the 1-D tensor `parameters` into parameters.

    `model` is either an `nn.Module`, whose `.parameters()` are overwritten in
    order, or any other iterable of parameter tensors. Anything else raises
    ValueError. Returns None.
    """
    if isinstance(model, nn.Module):
        destination = model.parameters()
    elif hasattr(model, "__iter__"):
        destination = model
    else:
        raise ValueError("Model is not nn.Module or iterable.")
    torch.nn.utils.vector_to_parameters(parameters, destination)