import matplotlib.pyplot as plt
import torch
import ImageTools
from torch import nn
from torch.nn import functional
from torch.nn.functional import interpolate
from torch.nn.functional import one_hot
from torch import autograd
import numpy as np
import math  # just so I don't use numpy by accident

k_logistic = 30  # the logistic function coefficient
threshold = 0.5
modes = ['bilinear', 'trilinear']


def return_args(parser):
    parser.add_argument('-d', '--directory', type=str, default='default',
                        help='Stores the progress output in the given '
                             'directory name.')
    parser.add_argument('-sf', '--scale_factor', type=float, default=4,
                        help='Scale factor between high-res and low-res.')
    parser.add_argument("--down_sample", default=False, action="store_true",
                        help="Down-samples the input for G for testing "
                             "purposes.")
    parser.add_argument("--super_sampling", default=False,
                        action="store_true",
                        help="When comparing super-res and low-res, instead "
                             "of blurring, picks one voxel with "
                             "nearest-neighbour interpolation.")
    parser.add_argument("--squash_phases", default=False, action="store_true",
                        help="All material phases in low-res are the same.")
    parser.add_argument("--anisotropic", default=False, action="store_true",
                        help="The material is anisotropic (requires "
                             "different Ds).")
    parser.add_argument("--with_rotation", default=False, action="store_true",
                        help="Create rotations and mirrors for the BM.")
    parser.add_argument("--separator", default=False, action="store_true",
                        help="Different voxel-wise loss for separator "
                             "material.")
    parser.add_argument('-rotations_bool', nargs='+', type=int,
                        default=[0, 0, 1],
                        help="If the material is anisotropic, specify which "
                             "images can be augmented (rotations and "
                             "mirrors).")
    parser.add_argument('-g_image_path', type=str,
                        help="Path to the LR 3D volume.")
    parser.add_argument('-d_image_path', nargs='+', type=str,
                        help="Path to the HR 2D slice. If isotropic, 3 paths "
                             "are needed, in the correct order.")
    parser.add_argument('-phases_idx', '--phases_low_res_idx', nargs='+',
                        type=int, default=[1, 2])
    parser.add_argument('-d_dimensions', '--d_dimensions_to_check',
                        nargs='+', type=int, default=[0, 1, 2])
    parser.add_argument('-volume_size_to_evaluate', nargs='+', type=int,
                        default=[128, 128, 128])
    parser.add_argument('-wd', '--widthD', type=int, default=9,
                        help='Hyper-parameter for the width of the '
                             'Discriminator network.')
    parser.add_argument('-wg', '--widthG', type=int, default=8,
                        help='Hyper-parameter for the width of the Generator '
                             'network.')
    parser.add_argument('-n_res', '--n_res_blocks', type=int, default=1,
                        help='Number of residual blocks in the network.')
    parser.add_argument('-n_dims', '--n_dims', type=int, default=3,
                        help='The generated image dimension (and input '
                             'dimension); can be either 2 or 3.')
    parser.add_argument('-gu', '--g_update', type=int, default=5,
                        help='Number of iterations the generator waits '
                             'before being updated.')
    parser.add_argument('-e', '--num_epochs', type=int, default=500,
                        help='Number of epochs.')
    parser.add_argument('-pix_d', '--pixel_coefficient_distance', type=int,
                        default=10,
                        help='The coefficient of the pixel distance loss '
                             'added to the cost of G.')
    parser.add_argument('-g_epoch_id', type=str, default='',
                        help='Since more than one G is saved during a run, '
                             'a specific G can be chosen for evaluation.')
    args, unknown = parser.parse_known_args()
    return args
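# Example usage of return_args (a minimal sketch; the parser description and
# flag values below are illustrative, not taken from the original run
# scripts):
#
#   import argparse
#   parser = argparse.ArgumentParser(description='Super-res training')
#   args = return_args(parser)
#   # e.g. invoked as: python train.py -sf 4 -n_dims 3 --squash_phases
#   print(args.scale_factor, args.n_dims, args.squash_phases)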
def forty_five_deg_masks(batch_size, phases, high_l):
    """
    :param batch_size: batch size for the images for the making of the mask.
    :param phases: number of phases.
    :param high_l: the length of the high resolution.
    :return: a list of boolean masks of the 45-degree-angle slices along the
        z-axis of the 3d volume (masks for both diagonal directions).
    """
    over_sqrt_2 = int(high_l / math.sqrt(2))  # high_l in the diagonal
    # create the masks:
    masks = []
    for m in range(high_l - over_sqrt_2):
        mask1 = torch.zeros((batch_size, phases, *[high_l] * 3),
                            dtype=torch.bool)
        mask2 = torch.zeros(mask1.size(), dtype=torch.bool)
        mask3 = torch.zeros(mask1.size(), dtype=torch.bool)
        mask4 = torch.zeros(mask1.size(), dtype=torch.bool)
        if m == 0:  # the two main diagonals appear only once
            for i in range(over_sqrt_2):
                mask1[..., m + i, i, :] = True
                mask3[..., i + m, -1 - i, :] = True
            masks.extend([mask1, mask3])
        else:  # off-diagonal slices come in symmetric pairs
            for i in range(over_sqrt_2):
                mask1[..., m + i, i, :] = True
                mask2[..., i, m + i, :] = True
                mask3[..., i + m, -1 - i, :] = True
                mask4[..., i, -1 - (i + m), :] = True
            masks.extend([mask1, mask2, mask3, mask4])
    return masks


def to_slice(k, forty_five_deg, D_dimensions_to_check):
    """
    :param k: axis idx.
    :param forty_five_deg: bool determining whether to slice in 45 deg.
    :param D_dimensions_to_check: the dimensions the user wants to check.
    :return: whether to slice the volume in this axis (axis 2 is also
        sliced when 45-degree slices are requested).
    """
    if k not in D_dimensions_to_check:
        if k != 2:
            return False
        if not forty_five_deg:
            return False
    return True


def forty_five_deg_slices(masks, volume_input):
    """
    :param masks: the masks of the 45-degree-angle slices.
    :param volume_input: the volume to slice.
    :return: the slices, concatenated along the batch dimension.
    """
    tensors = []
    batch_size, phases, high_l = volume_input.size()[:3]
    for mask in masks:
        # the result of the mask on the input:
        slice_mask = volume_input[mask].view(batch_size, phases, -1, high_l)
        # add the slice after up-sampling to the wanted size:
        tensors.append(interpolate(slice_mask, size=(high_l, high_l),
                                   mode=modes[0]))
    return torch.cat(tensors, dim=0)  # concat tensors along batch_size


def calc_gradient_penalty(netD, real_data, fake_data, batch_size, l, device,
                          gp_lambda, nc):
    """
    Calculate the gradient penalty for a batch of real and fake data.
    :param netD: Discriminator network.
    :param real_data:
    :param fake_data:
    :param batch_size:
    :param l: image size.
    :param device:
    :param gp_lambda: learning parameter for GP.
    :param nc: channels.
    :return: gradient penalty.
    """
    # sample and reshape random numbers
    alpha = torch.rand(batch_size, 1, device=device)
    num_images = real_data.size()[0]
    alpha = alpha.expand(batch_size,
                         int(real_data.numel() / batch_size)).contiguous()
    alpha = alpha.view(num_images, nc, l, l)
    # create the interpolated dataset between real and fake samples:
    interpolates = (alpha * real_data.detach()
                    + (1 - alpha) * fake_data.detach())
    interpolates.requires_grad_(True)
    # pass interpolates through netD
    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(
                                  disc_interpolates.size(), device=device),
                              create_graph=True, only_inputs=True)[0]
    # extract the grads and calculate gp
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() \
        * gp_lambda
    return gradient_penalty
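# Example usage of calc_gradient_penalty inside a WGAN-GP discriminator step
# (a minimal sketch; `netD`, `real_batch` and `fake_batch` are hypothetical
# stand-ins, and l/nc/gp_lambda values are illustrative):
#
#   gp = calc_gradient_penalty(netD, real_batch, fake_batch,
#                              batch_size=real_batch.size(0), l=64,
#                              device=device, gp_lambda=10, nc=3)
#   d_loss = netD(fake_batch).mean() - netD(real_batch).mean() + gp
#   d_loss.backward()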
class DownSample(nn.Module):
    """
    Calculates the down-sampled version of the generated volume. Can also
    be used to generate a low-res volume from a high-res volume for
    evaluation reasons.
    """

    def __init__(self, squash, n_dims, low_res_idx, scale_factor, device,
                 super_sampling=False, separator=False):
        """
        :param squash: whether to squash all material phases together for
            down-sampling (used when it is hard to distinguish between
            material phases in low resolution, e.g. an SOFC cathode).
        :param n_dims: 2d to 2d or 3d to 3d.
        :param low_res_idx: the indices of phases to down-sample.
        :param scale_factor: scale factor between high-res and low-res.
        :param device: the device the object is on.
        :param super_sampling: use nearest-neighbour interpolation instead
            of blurring.
        :param separator: different voxel-wise loss for the separator
            material.
        """
        super(DownSample, self).__init__()
        self.squash = squash
        self.n_dims = n_dims
        # Here we want to compare the pore as well:
        self.low_res_idx = torch.cat((torch.zeros(1).to(low_res_idx),
                                      low_res_idx))
        self.low_res_len = self.low_res_idx.numel()  # how many phases
        self.scale_factor = scale_factor
        self.device = device
        self.separator = separator
        self.voxel_wise_loss = nn.MSELoss()  # the voxel-wise loss
        # Calculate the gaussian kernel and make the 3d convolution:
        self.gaussian_k = self.calc_gaussian_kernel_3d(self.scale_factor)
        # Reshape to a convolutional weight of shape
        # (out_channels, in_channels / groups, kD, kH, kW):
        self.gaussian_k = self.gaussian_k.view(1, 1, *self.gaussian_k.size())
        self.gaussian_k = self.gaussian_k.repeat(
            self.low_res_len,
            *[1] * (self.gaussian_k.dim() - 1)).to(self.device)
        self.groups = self.low_res_len  # ensures that each phase will be
        # blurred independently.
        self.gaussian_conv = functional.conv3d
        self.softmax = functional.softmax
        self.super_sampling = super_sampling

    def voxel_wise_distance(self, generated_im, low_res):
        """
        Calculates and returns the voxel-wise distance between the low-res
        image and the down-sampling of the high-res generated image.
        :return: the normalized distance (divided by the number of voxels
            of the low-resolution image).
        """
        down_sampled_im = self(generated_im)
        if self.separator:
            # No punishment for making more material where pore is in
            # low_res. All low-res phases which are not pore are to be
            # matched:
            low_res = low_res[:, 1:]
            down_sampled_im = down_sampled_im[:, 1:]
            down_sampled_im = down_sampled_im * low_res
            return self.voxel_wise_loss(low_res, down_sampled_im)
        # There is a double error for a mismatch:
        mse_loss = self.voxel_wise_loss(low_res, down_sampled_im)
        return mse_loss * self.low_res_len / 2  # to standardize the loss.

    def forward(self, generated_im, low_res_input=False):
        """
        Apply a gaussian filter to the generated image, then down-sample.
        """
        # First choose the material phases in the image:
        low_res_phases = torch.index_select(generated_im, 1,
                                            self.low_res_idx)
        if self.squash:  # all phases of material are the same in low-res
            # sum all the material phases:
            low_res_phases = torch.sum(low_res_phases,
                                       dim=1).unsqueeze(dim=1)
        # if it is super-sampling, return nearest-neighbour interpolation:
        if self.super_sampling:
            return interpolate(low_res_phases,
                               scale_factor=1 / self.scale_factor,
                               mode='nearest')
        # Then gaussian-blur the chosen phases of the generated image:
        blurred_im = self.gaussian_conv(input=low_res_phases,
                                        weight=self.gaussian_k,
                                        padding='same', groups=self.groups)
        # Then down-sample using (bi/tri)linear interpolation:
        blurred_low_res = interpolate(blurred_im,
                                      scale_factor=1 / self.scale_factor,
                                      mode=modes[self.n_dims - 2])
        if low_res_input:  # calculate a low-res input.
            return self.get_low_res_input(blurred_low_res)
        # Multiplying the softmax probabilities by a large number to get a
        # differentiable argmax function to avoid blocky super-res volumes:
        return self.softmax(blurred_low_res * 100, dim=1)

    def get_low_res_input(self, blurred_image):
        """
        If only the low-res input is to be calculated for an evaluation
        study.
        :param blurred_image: the image after it has been blurred and
            down-sampled.
        :return: a one-hot volume of size
            batch_size x low_res_phases x *low_res_vol_dimensions.
        """
        # Adding a little noise for the (0.5, 0.5) scenarios:
        blurred_image += (torch.rand(blurred_image.size(),
                                     device=blurred_image.device)
                          - 0.5) / 1000
        num_phases = blurred_image.size()[1]
        blurred_image = torch.argmax(blurred_image, dim=1)  # find max phase
        one_hot_vol = one_hot(blurred_image, num_classes=num_phases)
        # move the one-hot (last) dimension back to the channel dimension:
        return one_hot_vol.permute(0, -1, *range(1, self.n_dims + 1))

    @staticmethod
    def calc_gaussian_kernel_3d(scale_factor):
        """
        :param scale_factor: the scale factor used between the low- and
            high-res volumes.
        :return: a gaussian-blur 3d kernel for blurring before
            interpolating.
        """
        ks = math.ceil(scale_factor)  # the kernel size
        if ks % 2 == 0:
            ks -= 1  # if even, the closest odd number from below.
        # The same default sigma as in transforms.functional.gaussian_blur:
        sigma = 0.3 * ((ks - 1) * 0.5 - 1) + 0.8
        ts = torch.linspace(-(ks // 2), ks // 2, ks)
        gauss = torch.exp(-(ts / sigma) ** 2 / 2)
        kernel_1d = gauss / gauss.sum()  # normalization
        # the 3d gaussian kernel is the outer product of three 1d kernels:
        kernel_3d = torch.einsum('i,j,k->ijk', kernel_1d, kernel_1d,
                                 kernel_1d)
        return kernel_3d
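# Example usage of DownSample (a minimal sketch; the tensors and phase
# indices are illustrative, assuming a one-hot generated volume with 5
# channels where channels 1 and 2 are the phases visible in low-res):
#
#   down = DownSample(squash=False, n_dims=3,
#                     low_res_idx=torch.LongTensor([1, 2]).to(device),
#                     scale_factor=4, device=device)
#   fake_low_res = down(generated_volume)  # soft, differentiable output
#   eval_input = down(generated_volume, low_res_input=True)  # hard one-hot
#   dist = down.voxel_wise_distance(generated_volume, real_low_res)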
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    downsample_test = DownSample(squash=False, n_dims=3,
                                 low_res_idx=torch.LongTensor([1]).to(device),
                                 scale_factor=4, device=device)
    gen_im = torch.zeros(1, 5, 4, 4, 4).to(device)
    gen_im[0, 1, 2:, 2:, 2:] = 1
    gen_im[0, 2, :2, :2, :2] = 1
    low_res = torch.zeros(1, 2, 1, 1, 1).to(device)
    low_res[:] = 1
    res1 = downsample_test(gen_im)
    res2 = downsample_test(gen_im, low_res_input=True)
    loss = downsample_test.voxel_wise_distance(gen_im, low_res)
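    # A quick sanity check (sketch): with scale_factor=4, the 4^3 one-hot
    # volume should reduce to a 1^3 low-res volume with 2 channels
    # (pore + phase 1); the expected sizes below are assumptions, not
    # asserted:
    print(res1.size())  # expected: torch.Size([1, 2, 1, 1, 1])
    print(res2.size())  # expected: torch.Size([1, 2, 1, 1, 1])
    print(loss.item())  # scalar voxel-wise distance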