-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathLearnTools.py
320 lines (296 loc) · 14.9 KB
/
LearnTools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import matplotlib.pyplot as plt
import torch
import ImageTools
from torch import nn
from torch.nn import functional
from torch.nn.functional import interpolate
from torch.nn.functional import one_hot
from torch import autograd
import numpy as np
import math # just so I don't use numpy by accident
# Module-level configuration constants.
# NOTE(review): k_logistic and threshold are not referenced anywhere in this
# file's visible code — presumably consumed by sibling modules; confirm before
# removing.
k_logistic = 30 # the logistic function coefficient
threshold = 0.5  # binarization cut-off, apparently paired with the logistic — verify usage elsewhere
modes = ['bilinear', 'trilinear']  # interpolation mode chosen by dimensionality: modes[n_dims - 2]
def return_args(parser):
    """
    Register every command-line option used by the training/evaluation
    scripts on *parser*, then parse the command line.

    :param parser: an argparse.ArgumentParser instance to populate.
    :return: the parsed argparse.Namespace. Unknown flags are tolerated
        (parse_known_args is used), so extra CLI arguments do not raise.
    """
    # --- run bookkeeping --------------------------------------------------
    parser.add_argument('-d', '--directory', type=str, default='default',
                        help='Stores the progress output in the \
                        directory name given')
    # --- resolution / sampling behaviour ----------------------------------
    parser.add_argument('-sf', '--scale_factor', type=float, default=4,
                        help='scale factor between high res and low res.')
    parser.add_argument("--down_sample", default=False, action="store_true",
                        help="Down samples the input for G for testing "
                             "purposes.")
    parser.add_argument("--super_sampling", default=False,
                        action="store_true", help="When comparing super-res "
                        "and low-res, instead of blurring, it picks one voxel "
                        "with nearest-neighbour interpolation.")
    parser.add_argument("--squash_phases", default=False, action="store_true",
                        help="All material phases in low res are the same.")
    # --- material / augmentation flags ------------------------------------
    parser.add_argument("--anisotropic", default=False, action="store_true",
                        help="The material is anisotropic (requires dif Ds).")
    parser.add_argument("--with_rotation", default=False, action="store_true",
                        help="create rotations and mirrors for the BM.")
    parser.add_argument("--separator", default=False, action="store_true",
                        help="Different voxel-wise loss for separator "
                             "material.")
    parser.add_argument('-rotations_bool', nargs='+', type=int,
                        default=[0, 0, 1], help="If the material is "
                        "anisotropic, specify which images can be augmented "
                        "(rotations and mirrors)")
    # --- input data paths and phase/dimension selection -------------------
    parser.add_argument('-g_image_path', type=str, help="Path to the LR "
                        "3D volume")
    parser.add_argument('-d_image_path', nargs='+', type=str, help="Path to "
                        "the HR 2D slice, if Isotropic, 3 paths are needed, "
                        "in correct order")
    parser.add_argument('-phases_idx', '--phases_low_res_idx', nargs='+',
                        type=int, default=[1, 2])
    parser.add_argument('-d_dimensions', '--d_dimensions_to_check', nargs='+',
                        type=int, default=[0, 1, 2])
    parser.add_argument('-volume_size_to_evaluate', nargs='+', type=int,
                        default=[128, 128, 128])
    # --- network hyper-parameters -----------------------------------------
    parser.add_argument('-wd', '--widthD', type=int, default=9,
                        help='Hyper-parameter for \
                        the width of the Discriminator network')
    parser.add_argument('-wg', '--widthG', type=int, default=8,
                        help='Hyper-parameter for the \
                        width of the Generator network')
    parser.add_argument('-n_res', '--n_res_blocks', type=int, default=1,
                        help='Number of residual blocks in the network.')
    parser.add_argument('-n_dims', '--n_dims', type=int, default=3,
                        help='The generated image dimension (and input '
                             'dimension), can be either 2 or 3.')
    # --- training schedule -------------------------------------------------
    parser.add_argument('-gu', '--g_update', type=int, default=5,
                        help='Number of iterations the generator waits before '
                             'being updated')
    parser.add_argument('-e', '--num_epochs', type=int, default=500,
                        help='Number of epochs.')
    parser.add_argument('-pix_d', '--pixel_coefficient_distance', type=int,
                        default=10,
                        help='The coefficient of the pixel distance loss '
                             'added to the cost of G.')
    parser.add_argument('-g_epoch_id', type=str, default='', help='Since '
                        'more than 1 G is saved during a run, specific G can '
                        'be chosen for evaluation')
    # parse_known_args so that unrecognised flags are ignored, not fatal:
    args, unknown = parser.parse_known_args()
    return args
def forty_five_deg_masks(batch_size, phases, high_l):
    """
    Build boolean masks that select the 45-degree diagonal planes of a cube.

    :param batch_size: batch size for the images for the making of the mask.
    :param phases: number of phases (channel dimension of the masks).
    :param high_l: the edge length of the high-resolution volume.
    :return: a list of boolean masks, each of shape
        (batch_size, phases, high_l, high_l, high_l), covering both
        45-degree directions along the z-axis at every diagonal offset.
    """
    diag_len = int(high_l / math.sqrt(2))  # voxels that fit on one diagonal
    full_shape = (batch_size, phases, high_l, high_l, high_l)
    all_masks = []
    for offset in range(high_l - diag_len):
        down_right = torch.zeros(full_shape, dtype=torch.bool)
        right_down = torch.zeros(full_shape, dtype=torch.bool)
        down_left = torch.zeros(full_shape, dtype=torch.bool)
        left_down = torch.zeros(full_shape, dtype=torch.bool)
        for i in range(diag_len):
            down_right[..., offset + i, i, :] = True
            right_down[..., i, offset + i, :] = True
            down_left[..., i + offset, -1 - i, :] = True
            left_down[..., i, -1 - (i + offset), :] = True
        if offset == 0:
            # at zero offset the row- and column-shifted diagonals coincide,
            # so keep only one mask per direction to avoid duplicates
            all_masks.extend([down_right, down_left])
        else:
            all_masks.extend([down_right, right_down, down_left, left_down])
    return all_masks
def to_slice(k, forty_five_deg, D_dimensions_to_check):
    """
    Decide whether axis *k* of the volume should be sliced.

    :param k: axis idx.
    :param forty_five_deg: bool determining if to slice in 45 deg.
    :param D_dimensions_to_check: the dimensions requested by the user.
    :return: True when k is one of the requested dimensions, or when k is
        the z-axis (2) and 45-degree slicing is enabled; False otherwise.
    """
    requested = k in D_dimensions_to_check
    diagonal = (k == 2) and bool(forty_five_deg)
    return requested or diagonal
def forty_five_deg_slices(masks, volume_input):
    """
    Cut a volume along pre-computed 45-degree masks and resize each cut.

    :param masks: the boolean masks of the 45 degree angle slices.
    :param volume_input: the (batch, phases, l, ...) volume to slice.
    :return: all slices concatenated along the batch dimension, each
        up-sampled to (l, l) with bilinear interpolation.
    """
    n_batch, n_phases, side = volume_input.size()[:3]
    # 'bilinear' is this module's modes[0]; spelled out here for clarity
    resized = [
        interpolate(volume_input[m].view(n_batch, n_phases, -1, side),
                    size=(side, side), mode='bilinear')
        for m in masks
    ]
    return torch.cat(resized, dim=0)
def calc_gradient_penalty(netD, real_data, fake_data, batch_size, l, device,
                          gp_lambda, nc):
    """
    WGAN-GP gradient penalty for one batch of real and fake data.

    :param netD: Discriminator network.
    :param real_data: batch of real images.
    :param fake_data: batch of generated images.
    :param batch_size: number of samples in the batch.
    :param l: image side length.
    :param device: device on which the random mixing coefficients live.
    :param gp_lambda: learning parameter for GP (penalty weight).
    :param nc: number of channels.
    :return: scalar tensor gp_lambda * mean((||grad||_2 - 1) ** 2).
    """
    # one random mixing coefficient per sample, broadcast over every pixel
    mix = torch.rand(batch_size, 1, device=device)
    n_samples = real_data.size()[0]
    per_sample = int(real_data.numel() / batch_size)
    mix = mix.expand(batch_size, per_sample).contiguous().view(
        n_samples, nc, l, l)
    # a random point on the segment between each real and fake sample
    mixed = mix * real_data.detach() + (1 - mix) * fake_data.detach()
    mixed.requires_grad_(True)
    # critic score of the mixed samples, then its gradient w.r.t. the input
    critic_out = netD(mixed)
    grads = autograd.grad(outputs=critic_out, inputs=mixed,
                          grad_outputs=torch.ones(critic_out.size(),
                                                  device=device),
                          create_graph=True, only_inputs=True)[0]
    flat = grads.view(grads.size(0), -1)
    return ((flat.norm(2, dim=1) - 1) ** 2).mean() * gp_lambda
class DownSample(nn.Module):
    """
    Calculates the down-sampled version of the generated volume. Can also be
    used to generate a low-res volume from a high-res volume for evaluation
    reasons.
    """

    def __init__(self, squash, n_dims, low_res_idx, scale_factor,
                 device, super_sampling=False, separator=False):
        """
        :param squash: if to squash all material phases together for
            down-sampling (when it is hard to distinguish between material
            phases in low resolution, e.g. SOFC cathode).
        :param n_dims: 2d to 2d or 3d to 3d.
        :param low_res_idx: the indices of phases to down-sample (an integer
            tensor; its dtype/device are reused for the pore index).
        :param scale_factor: scale factor between high-res and low-res.
        :param device: the device the object is on.
        :param super_sampling: compare super-res and low-res by picking one
            voxel with nearest-neighbour interpolation instead of blurring.
        :param separator: different voxel-wise loss for the separator
            material.
        """
        super(DownSample, self).__init__()
        self.squash = squash
        self.n_dims = n_dims
        # Prepend phase 0 — here we want to compare the pore as well:
        self.low_res_idx = torch.cat((torch.zeros(1).to(low_res_idx),
                                      low_res_idx))
        self.low_res_len = self.low_res_idx.numel()  # how many phases
        self.scale_factor = scale_factor
        self.device = device
        self.separator = separator
        self.voxel_wise_loss = nn.MSELoss()  # the voxel-wise loss
        # Calculate the gaussian kernel and make the 3d convolution:
        self.gaussian_k = self.calc_gaussian_kernel_3d(self.scale_factor)
        # Reshape to a convolutional weight (1, 1, k, k, k), then repeat it
        # once per low-res phase -> (low_res_len, 1, k, k, k):
        self.gaussian_k = self.gaussian_k.view(1, 1, *self.gaussian_k.size())
        self.gaussian_k = self.gaussian_k.repeat(self.low_res_len, *[1] * (
                self.gaussian_k.dim() - 1)).to(self.device)
        self.groups = self.low_res_len  # ensures that each phase will be
        # blurred independently.
        self.gaussian_conv = functional.conv3d
        self.softmax = functional.softmax
        self.super_sampling = super_sampling

    def voxel_wise_distance(self, generated_im, low_res):
        """
        calculates and returns the pixel wise distance between the low-res
        image and the down sampling of the high-res generated image.

        :param generated_im: the high-res generated volume.
        :param low_res: the reference low-res volume to match.
        :return: the normalized distance (divided by the number of pixels of
            the low resolution image.)
        """
        down_sampled_im = self(generated_im)
        if self.separator:  # no punishment for making more material where
            # pore is in low_res. All low res phases which are not pore are
            # to be matched:
            low_res = low_res[:, 1:]
            down_sampled_im = down_sampled_im[:, 1:]
            down_sampled_im = down_sampled_im * low_res
            # use the loss module created once in __init__ instead of
            # instantiating a fresh MSELoss on every call:
            return self.voxel_wise_loss(low_res, down_sampled_im)
        # There is a double error for a mismatch:
        mse_loss = self.voxel_wise_loss(low_res, down_sampled_im)
        return mse_loss * self.low_res_len / 2  # to standardize the loss.

    def forward(self, generated_im, low_res_input=False):
        """
        Apply gaussian filter to the generated image and down-sample it.

        :param generated_im: per-phase volume of shape
            (batch, phases, *spatial dims).
        :param low_res_input: if True, return a one-hot low-res volume for
            evaluation instead of the differentiable soft version.
        :return: the down-sampled volume over the selected low-res phases.
        """
        # First choose the material phase in the image:
        low_res_phases = torch.index_select(generated_im, 1, self.low_res_idx)
        if self.squash:  # all phases of material are same in low-res
            # sum all the material phases:
            # NOTE(review): this leaves a single channel while gaussian_k /
            # groups were built for low_res_len channels — confirm squash is
            # only used together with super_sampling.
            low_res_phases = torch.sum(low_res_phases, dim=1).unsqueeze(
                dim=1)
        # if it is super-sampling, return nearest-neighbour interpolation:
        if self.super_sampling:
            return interpolate(low_res_phases, scale_factor=1 /
                               self.scale_factor, mode='nearest')
        # Then gaussian blur the low res phases generated image (grouped so
        # each phase is blurred independently):
        blurred_im = self.gaussian_conv(input=low_res_phases,
                                        weight=self.gaussian_k,
                                        padding='same', groups=self.groups)
        # Then downsample using (bi/tri)linear interpolation:
        blurred_low_res = interpolate(blurred_im,
                                      scale_factor=1 / self.scale_factor,
                                      mode=modes[self.n_dims - 2])
        if low_res_input:  # calculate a low-res input.
            return self.get_low_res_input(blurred_low_res)
        # Multiplying the softmax probabilities by a large number to get a
        # differentiable argmax function to avoid blocky super-res volumes:
        return self.softmax(blurred_low_res * 100, dim=1)

    def get_low_res_input(self, blurred_image):
        """
        If only the low-res input is to be calculated for evaluation study.

        :param blurred_image: the volume after it has been blurred and
            down-sampled.
        :return: a batch_size X low_res_phases X *low_res_vol_dimensions
            one-hot volume.
        """
        # Adding a little noise to break (0.5, 0.5) ties; done out-of-place
        # so the caller's tensor is not mutated as a side effect:
        blurred_image = blurred_image + (
            torch.rand(blurred_image.size(),
                       device=blurred_image.device) - 0.5) / 1000
        num_phases = blurred_image.size()[1]
        blurred_image = torch.argmax(blurred_image, dim=1)  # find max phase
        one_hot_vol = one_hot(blurred_image, num_classes=num_phases)
        # move the one-hot channel from the last position back to dim 1:
        return one_hot_vol.permute(0, -1, *torch.arange(1, self.n_dims + 1))

    @staticmethod
    def calc_gaussian_kernel_3d(scale_factor):
        """
        :param scale_factor: The scale factor used between the low- and
            high-res volumes.
        :return: A gaussian blur 3d kernel for blurring before interpolating
        """
        ks = math.ceil(scale_factor)  # the kernel size
        if ks % 2 == 0:
            ks -= 1  # if even, the closest odd number from below.
        # The same default sigma as in transforms.functional.gaussian_blur:
        sigma = 0.3 * ((ks - 1) * 0.5 - 1) + 0.8
        ts = torch.linspace(-(ks // 2), ks // 2, ks)
        gauss = torch.exp((-(ts / sigma) ** 2 / 2))
        kernel_1d = gauss / gauss.sum()  # Normalization
        # 3d gaussian kernel is the outer product of the 1d kernel with
        # itself along all three axes:
        kernel_3d = torch.einsum('i,j,k->ijk', kernel_1d, kernel_1d, kernel_1d)
        return kernel_3d
if __name__ == '__main__':
    # Quick smoke test of DownSample on a tiny two-material volume.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    sampler = DownSample(squash=False, n_dims=3,
                         low_res_idx=torch.LongTensor([1]).to(device),
                         scale_factor=4, device=device)
    generated = torch.zeros(1, 5, 4, 4, 4).to(device)
    generated[0, 1, 2:, 2:, 2:] = 1  # material phase 1 in one corner
    generated[0, 2, :2, :2, :2] = 1  # material phase 2 in the opposite corner
    target_low_res = torch.zeros(1, 2, 1, 1, 1).to(device)
    target_low_res[:] = 1
    down_sampled = sampler(generated)
    one_hot_low_res = sampler(generated, low_res_input=True)
    distance = sampler.voxel_wise_distance(generated, target_low_res)