diff --git a/README.md b/README.md
index 8fc9c61..bdda839 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,62 @@
-# DIM
+# Deep InfoMax (DIM)
+
+[UPDATE]: this work has been accepted as an oral presentation at ICLR 2019.
+We are gradually updating the repository over the next few weeks to reflect the experiments in the camera-ready version.
+
 Learning Deep Representations by Mutual Information Estimation and Maximization
+Sample code for the local-only objective in
+https://openreview.net/forum?id=Bklr3j0cKX
+https://arxiv.org/abs/1808.06670
+
+### Completed
+[Updated 1/15/2019]
+* Latest code for the dot-product-style scoring function for local DIM (single or multiple globals).
+* NCE / DV losses.
+* Convnet and folded convnet (strided-crop) architectures.
+
+### TODO
+* Resnet and folded resnet architectures, and training classifiers with the encoder kept fixed (evaluation).
+* NDM, MINE, SVM, and MS-SSIM evaluation.
+* Global DIM and prior matching.
+* Coordinate and occlusion tasks.
+* Add nearest-neighbor analysis.
+
+### Installation / requirements
+
+This is a package, so to install just run:
+
+    $ pip install .
+
+This package installs the dev branch of cortex: https://github.com/rdevon/cortex
+
+cortex requires Python 3.5+ (not tested above 3.7). Note that cortex is in early beta, but it is usable for this demo.
+
+cortex optionally requires visdom for visualization: https://github.com/facebookresearch/visdom
+
+You will also need to run:
+
+    $ cortex setup
+
+See the cortex README for more info, or email us (or submit an issue for legitimate bugs).
+
+### Usage
+
+To get the full set of commands, try:
+
+    $ python scripts/deep_infomax.py --help
+
+For CIFAR10 on a DCGAN architecture, try:
+
+    $ python scripts/deep_infomax.py --d.source CIFAR10 -n DIM_CIFAR10 --d.copy_to_local --t.epochs 1000
+
+You should get 71-72% accuracy in the pretraining step alone (this classifier is included for monitoring purposes only).
+Note that this won't get you all the way to reproducing the results in the paper: for that, the classifier needs to be retrained with the encoder held fixed.
+Support for training a classifier with the representations fixed is coming soon.
+
+For STL-10 on a folded 64x64 Alexnet (strided crops) with multiple globals and the noise-contrastive estimation (NCE) loss, try:
+
+    $ python scripts/deep_infomax.py --d.sources STL10 --d.data_args "dict(stl_resize_only=True)" --d.n_workers 32 -n DIM_STL --t.epochs 200 --d.copy_to_local --encoder_config foldmultialex64x64 --mode nce --global_units 0
+
+### Deep InfoMax
+
+TODO: visual guide to Deep InfoMax.
\ No newline at end of file
diff --git a/cortex_DIM/__init__.py b/cortex_DIM/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cortex_DIM/configs/__init__.py b/cortex_DIM/configs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cortex_DIM/configs/convnets.py b/cortex_DIM/configs/convnets.py
new file mode 100644
index 0000000..77c5e84
--- /dev/null
+++ b/cortex_DIM/configs/convnets.py
@@ -0,0 +1,98 @@
+'''Basic convnet hyperparameters. 
+ +conv_args are in format (dim_h, f_size, stride, pad batch_norm, dropout, nonlinearity, pool) +fc_args are in format (dim_h, batch_norm, dropout, nonlinearity) + +''' + +from cortex_DIM.nn_modules.encoder import ConvnetEncoder, FoldedConvnetEncoder + + +# Basic DCGAN-like encoders + +_basic28x28 = dict( + Encoder=ConvnetEncoder, + conv_args=[(64, 5, 2, 2, True, False, 'ReLU', None), + (128, 5, 2, 2, True, False, 'ReLU', None)], + fc_args=[(1024, True, False, 'ReLU', None)], + local_idx=1, + fc_idx=0 +) + +_basic32x32 = dict( + Encoder=ConvnetEncoder, + conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None), + (128, 4, 2, 1, True, False, 'ReLU', None), + (256, 4, 2, 1, True, False, 'ReLU', None)], + fc_args=[(1024, True, False, 'ReLU')], + local_idx=1, + conv_idx=2, + fc_idx=0 +) + +_basic64x64 = dict( + Encoder=ConvnetEncoder, + conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None), + (128, 4, 2, 1, True, False, 'ReLU', None), + (256, 4, 2, 1, True, False, 'ReLU', None), + (512, 4, 2, 1, True, False, 'ReLU', None)], + fc_args=[(1024, True, False, 'ReLU')], + local_idx=2, + conv_idx=3, + fc_idx=0 +) + +# Alexnet-like encoders + +_alex64x64 = dict( + Encoder=ConvnetEncoder, + conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), + (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), + (384, 3, 1, 1, True, False, 'ReLU', None), + (384, 3, 1, 1, True, False, 'ReLU', None), + (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))], + fc_args=[(4096, True, False, 'ReLU'), + (4096, True, False, 'ReLU')], + local_idx=2, + conv_idx=4, + fc_idx=1 +) + +_foldalex64x64 = dict( + Encoder=FoldedConvnetEncoder, + crop_size=16, + conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), + (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), + (384, 3, 1, 1, True, False, 'ReLU', None), + (384, 3, 1, 1, True, False, 'ReLU', None), + (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))], + fc_args=[(4096, True, False, 'ReLU'), + (4096, True, False, 'ReLU')], + local_idx=4, + fc_idx=1 +) + +_foldmultialex64x64 = dict( + Encoder=FoldedConvnetEncoder, + crop_size=16, + conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), + (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), + (384, 3, 1, 1, True, False, 'ReLU', None), + (384, 3, 1, 1, True, False, 'ReLU', None), + (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), + (192, 3, 1, 0, True, False, 'ReLU', None), + (192, 1, 1, 0, True, False, 'ReLU', None)], + fc_args=[(4096, True, False, 'ReLU')], + local_idx=4, + multi_idx=6, + fc_idx=1 +) + +configs = dict( + basic28x28=_basic28x28, + basic32x32=_basic32x32, + basic64x64=_basic64x64, + alex64x64=_alex64x64, + foldalex64x64=_foldalex64x64, + foldmultialex64x64=_foldmultialex64x64 +) \ No newline at end of file diff --git a/cortex_DIM/configs/resnets.py b/cortex_DIM/configs/resnets.py new file mode 100644 index 0000000..46ef6fd --- /dev/null +++ b/cortex_DIM/configs/resnets.py @@ -0,0 +1,151 @@ +"""Configurations for ResNets + +""" + +from cortex_DIM.networks.dim_encoders import DIMResnet, DIMFoldedResnet + + +_resnet19_32x32 = dict( + Encoder=DIMResnet, + conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)], + res_args=[ + ([(64, 1, 1, 0, True, False, 'ReLU', None), + (64, 3, 1, 1, True, False, 'ReLU', None), + (64 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(64, 1, 1, 0, True, False, 'ReLU', None), + (64, 3, 1, 1, True, False, 'ReLU', None), + (64 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(128, 1, 1, 0, True, False, 
'ReLU', None), + (128, 3, 2, 1, True, False, 'ReLU', None), + (128 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(128, 1, 1, 0, True, False, 'ReLU', None), + (128, 3, 1, 1, True, False, 'ReLU', None), + (128 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(256, 1, 1, 0, True, False, 'ReLU', None), + (256, 3, 2, 1, True, False, 'ReLU', None), + (256 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(256, 1, 1, 0, True, False, 'ReLU', None), + (256, 3, 1, 1, True, False, 'ReLU', None), + (256 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1) + ], + fc_args=[(1024, True, False, 'ReLU')], + local_idx=4, + fc_idx=0 +) + +_foldresnet19_32x32 = dict( + Encoder=DIMFoldedResnet, + crop_size=8, + conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)], + res_args=[ + ([(64, 1, 1, 0, True, False, 'ReLU', None), + (64, 3, 1, 1, True, False, 'ReLU', None), + (64 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(64, 1, 1, 0, True, False, 'ReLU', None), + (64, 3, 1, 1, True, False, 'ReLU', None), + (64 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(128, 1, 1, 0, True, False, 'ReLU', None), + (128, 3, 2, 1, True, False, 'ReLU', None), + (128 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(128, 1, 1, 0, True, False, 'ReLU', None), + (128, 3, 1, 1, True, False, 'ReLU', None), + (128 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(256, 1, 1, 0, True, False, 'ReLU', None), + (256, 3, 2, 1, True, False, 'ReLU', None), + (256 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(256, 1, 1, 0, True, False, 'ReLU', None), + (256, 3, 1, 1, True, False, 'ReLU', None), + (256 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1) + ], + fc_args=[(1024, True, False, 'ReLU')], + local_idx=6, + fc_idx=0 +) + +_resnet34_32x32 = dict( + Encoder=DIMResnet, + conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)], + res_args=[ + ([(64, 1, 1, 0, True, False, 'ReLU', None), + (64, 3, 1, 1, True, False, 'ReLU', None), + (64 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(64, 1, 1, 0, True, False, 'ReLU', None), + (64, 3, 1, 1, True, False, 'ReLU', None), + (64 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 2), + ([(128, 1, 1, 0, True, False, 'ReLU', None), + (128, 3, 2, 1, True, False, 'ReLU', None), + (128 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(128, 1, 1, 0, True, False, 'ReLU', None), + (128, 3, 1, 1, True, False, 'ReLU', None), + (128 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 5), + ([(256, 1, 1, 0, True, False, 'ReLU', None), + (256, 3, 2, 1, True, False, 'ReLU', None), + (256 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(256, 1, 1, 0, True, False, 'ReLU', None), + (256, 3, 1, 1, True, False, 'ReLU', None), + (256 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 2) + ], + fc_args=[(1024, True, False, 'ReLU')], + local_idx=2, + fc_idx=0 +) + +_foldresnet34_32x32 = dict( + Encoder=DIMFoldedResnet, + crop_size=8, + conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)], + res_args=[ + ([(64, 1, 1, 0, True, False, 'ReLU', None), + (64, 3, 1, 1, True, False, 'ReLU', None), + (64 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(64, 1, 1, 0, True, False, 'ReLU', None), + (64, 3, 1, 1, True, False, 'ReLU', None), + (64 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 2), + ([(128, 1, 1, 0, True, False, 'ReLU', None), + (128, 3, 2, 1, True, False, 'ReLU', None), + (128 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(128, 1, 1, 0, True, False, 'ReLU', None), + (128, 3, 1, 1, True, False, 'ReLU', None), + (128 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 5), 
+ ([(256, 1, 1, 0, True, False, 'ReLU', None), + (256, 3, 2, 1, True, False, 'ReLU', None), + (256 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 1), + ([(256, 1, 1, 0, True, False, 'ReLU', None), + (256, 3, 1, 1, True, False, 'ReLU', None), + (256 * 4, 1, 1, 0, True, False, 'ReLU', None)], + 2) + ], + fc_args=[(1024, True, False, 'ReLU')], + local_idx=12, + fc_idx=0 +) + +configs = dict( + resnet19_32x32=_resnet19_32x32, + resnet34_32x32=_resnet34_32x32, + foldresnet19_32x32=_foldresnet19_32x32, + foldresnet34_32x32=_foldresnet34_32x32 +) \ No newline at end of file diff --git a/cortex_DIM/functions/__init__.py b/cortex_DIM/functions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cortex_DIM/functions/dim_losses.py b/cortex_DIM/functions/dim_losses.py new file mode 100644 index 0000000..7de64ea --- /dev/null +++ b/cortex_DIM/functions/dim_losses.py @@ -0,0 +1,224 @@ +'''cortex_DIM losses. + +''' + +import math + +import torch +import torch.nn.functional as F + +from cortex_DIM.functions.gan_losses import get_positive_expectation, get_negative_expectation + + +def fenchel_dual_loss(l, g, measure=None): + '''Computes the f-divergence distance between positive and negative joint distributions. + + Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD), + Squared Hellinger `H2`, Chi-squeared `X2`, `KL`, and reverse KL `RKL`. + + Args: + l: Local feature map. + g: Global features. + measure: f-divergence measure. + + Returns: + torch.Tensor: Loss. + + ''' + N, local_units, n_locs = l.size() + l = l.permute(0, 2, 1) + l = l.reshape(-1, local_units) + + u = torch.mm(g, l.t()) + u = u.reshape(N, N, -1) + mask = torch.eye(N).cuda() + n_mask = 1 - mask + + E_pos = get_positive_expectation(u, measure, average=False).mean(2) + E_neg = get_negative_expectation(u, measure, average=False).mean(2) + E_pos = (E_pos * mask).sum() / mask.sum() + E_neg = (E_neg * n_mask).sum() / n_mask.sum() + loss = E_neg - E_pos + return loss + + +def multi_fenchel_dual_loss(l, m, measure=None): + '''Computes the f-divergence distance between positive and negative joint distributions. + + Used for multiple globals. + + Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD), + Squared Hellinger `H2`, Chi-squeared `X2`, `KL`, and reverse KL `RKL`. + + Args: + l: Local feature map. + m: Multiple globals feature map. + measure: f-divergence measure. + + Returns: + torch.Tensor: Loss. + + ''' + N, units, n_locals = l.size() + n_multis = m.size(2) + + l = l.view(N, units, n_locals) + l = l.permute(0, 2, 1) + l = l.reshape(-1, units) + + m = m.view(N, units, n_multis) + m = m.permute(0, 2, 1) + m = m.reshape(-1, units) + + u = torch.mm(m, l.t()) + u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1) + + mask = torch.eye(N).cuda() + n_mask = 1 - mask + + E_pos = get_positive_expectation(u, measure, average=False).mean(2).mean(2) + E_neg = get_negative_expectation(u, measure, average=False).mean(2).mean(2) + E_pos = (E_pos * mask).sum() / mask.sum() + E_neg = (E_neg * n_mask).sum() / n_mask.sum() + loss = E_neg - E_pos + return loss + + +def nce_loss(l, g): + '''Computes the noise contrastive estimation-based loss. + + Args: + l: Local feature map. + g: Global features. + + Returns: + torch.Tensor: Loss. 
+ + ''' + N, local_units, n_locs = l.size() + l_p = l.permute(0, 2, 1) + u_p = torch.matmul(l_p, g.unsqueeze(dim=2)) + + l_n = l_p.reshape(-1, local_units) + u_n = torch.mm(g, l_n.t()) + u_n = u_n.reshape(N, N, n_locs) + + mask = torch.eye(N).unsqueeze(dim=2).cuda() + n_mask = 1 - mask + + u_n = (n_mask * u_n) - (10. * (1 - n_mask)) # mask out "self" examples + u_n = u_n.reshape(N, -1).unsqueeze(dim=1).expand(-1, n_locs, -1) + + pred_lgt = torch.cat([u_p, u_n], dim=2) + pred_log = F.log_softmax(pred_lgt, dim=2) + loss = -pred_log[:, :, 0].mean() + return loss + + +def multi_nce_loss(l, m): + ''' + + Used for multiple globals. + + Args: + l: Local feature map. + m: Multiple globals feature map. + + Returns: + torch.Tensor: Loss. + + ''' + N, units, n_locals = l.size() + _, _ , n_multis = m.size() + + l = l.view(N, units, n_locals) + m = m.view(N, units, n_multis) + l_p = l.permute(0, 2, 1) + m_p = m.permute(0, 2, 1) + u_p = torch.matmul(l_p, m).unsqueeze(2) + + l_n = l_p.reshape(-1, units) + m_n = m_p.reshape(-1, units) + u_n = torch.mm(m_n, l_n.t()) + u_n = u_n.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1) + + mask = torch.eye(N)[:, :, None, None].cuda() + n_mask = 1 - mask + + u_n = (n_mask * u_n) - (10. * (1 - n_mask)) # mask out "self" examples + u_n = u_n.reshape(N, N * n_locals, n_multis).unsqueeze(dim=1).expand(-1, n_locals, -1, -1) + + pred_lgt = torch.cat([u_p, u_n], dim=2) + pred_log = F.log_softmax(pred_lgt, dim=2) + loss = -pred_log[:, :, 0].mean() + + return loss + + +def donsker_varadhan_loss(l, g): + ''' + + Args: + l: Local feature map. + g: Global features. + + Returns: + torch.Tensor: Loss. + + ''' + N, local_units, n_locs = l.size() + l = l.permute(0, 2, 1) + l = l.reshape(-1, local_units) + + u = torch.mm(g, l.t()) + u = u.reshape(N, N, n_locs) + + mask = torch.eye(N).cuda() + n_mask = (1 - mask)[:, :, None] + + E_pos = (u.mean(2) * mask).sum() / mask.sum() + + u -= 100 * (1 - n_mask) + u_max = torch.max(u) + E_neg = torch.log((n_mask * torch.exp(u - u_max)).sum() + 1e-6) + u_max - math.log(n_mask.sum()) + loss = E_neg - E_pos + return loss + + +def multi_donsker_varadhan_loss(l, m): + ''' + + Used for multiple globals. + + Args: + l: Local feature map. + m: Multiple globals feature map. + + Returns: + torch.Tensor: Loss. + + ''' + N, units, n_locals = l.size() + n_multis = m.size(2) + + l = l.view(N, units, n_locals) + l = l.permute(0, 2, 1) + l = l.reshape(-1, units) + + m = m.view(N, units, n_multis) + m = m.permute(0, 2, 1) + m = m.reshape(-1, units) + + u = torch.mm(m, l.t()) + u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1) + + mask = torch.eye(N).cuda() + n_mask = 1 - mask + + E_pos = (u.mean(2) * mask).sum() / mask.sum() + + u -= 100 * (1 - n_mask) + u_max = torch.max(u) + E_neg = torch.log((n_mask * torch.exp(u - u_max)).sum() + 1e-6) + u_max - math.log(n_mask.sum()) + loss = E_neg - E_pos + return loss \ No newline at end of file diff --git a/cortex_DIM/functions/gan_losses.py b/cortex_DIM/functions/gan_losses.py new file mode 100644 index 0000000..55d1848 --- /dev/null +++ b/cortex_DIM/functions/gan_losses.py @@ -0,0 +1,95 @@ +""" + +""" + +import math + +import torch +import torch.nn.functional as F + +from cortex_DIM.functions.misc import log_sum_exp + + +def raise_measure_error(measure): + supported_measures = ['GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1'] + raise NotImplementedError( + 'Measure `{}` not supported. 
Supported: {}'.format(measure, + supported_measures)) + + +def get_positive_expectation(p_samples, measure, average=True): + """Computes the positive part of a divergence / difference. + + Args: + p_samples: Positive samples. + measure: Measure to compute for. + average: Average the result over samples. + + Returns: + torch.Tensor + + """ + log_2 = math.log(2.) + + if measure == 'GAN': + Ep = - F.softplus(-p_samples) + elif measure == 'JSD': + Ep = log_2 - F.softplus(- p_samples) + elif measure == 'X2': + Ep = p_samples ** 2 + elif measure == 'KL': + Ep = p_samples + 1. + elif measure == 'RKL': + Ep = -torch.exp(-p_samples) + elif measure == 'DV': + Ep = p_samples + elif measure == 'H2': + Ep = 1. - torch.exp(-p_samples) + elif measure == 'W1': + Ep = p_samples + else: + raise_measure_error(measure) + + if average: + return Ep.mean() + else: + return Ep + + +def get_negative_expectation(q_samples, measure, average=True): + """Computes the negative part of a divergence / difference. + + Args: + q_samples: Negative samples. + measure: Measure to compute for. + average: Average the result over samples. + + Returns: + torch.Tensor + + """ + log_2 = math.log(2.) + + if measure == 'GAN': + Eq = F.softplus(-q_samples) + q_samples + elif measure == 'JSD': + Eq = F.softplus(-q_samples) + q_samples - log_2 + elif measure == 'X2': + Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2) + elif measure == 'KL': + Eq = torch.exp(q_samples) + elif measure == 'RKL': + Eq = q_samples - 1. + elif measure == 'DV': + Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0)) + elif measure == 'H2': + Eq = torch.exp(q_samples) - 1. + elif measure == 'W1': + Eq = q_samples + else: + raise_measure_error(measure) + + if average: + return Eq.mean() + else: + return Eq \ No newline at end of file diff --git a/cortex_DIM/functions/misc.py b/cortex_DIM/functions/misc.py new file mode 100644 index 0000000..5ff72de --- /dev/null +++ b/cortex_DIM/functions/misc.py @@ -0,0 +1,39 @@ +"""Miscilaneous functions. + +""" + +import torch + + +def log_sum_exp(x, axis=None): + """Log sum exp function + + Args: + x: Input. + axis: Axis over which to perform sum. + + Returns: + torch.Tensor: log sum exp + + """ + x_max = torch.max(x, axis)[0] + y = torch.log((torch.exp(x - x_max)).sum(axis)) + x_max + return y + + +def random_permute(X): + """Randomly permutes a tensor. + + Args: + X: Input tensor. + + Returns: + torch.Tensor + + """ + X = X.transpose(1, 2) + b = torch.rand((X.size(0), X.size(1))).cuda() + idx = b.sort(0)[1] + adx = torch.range(0, X.size(1) - 1).long() + X = X[idx, adx[None, :]].transpose(1, 2) + return X diff --git a/cortex_DIM/nn_modules/__init__.py b/cortex_DIM/nn_modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cortex_DIM/nn_modules/convnet.py b/cortex_DIM/nn_modules/convnet.py new file mode 100644 index 0000000..6a5ac50 --- /dev/null +++ b/cortex_DIM/nn_modules/convnet.py @@ -0,0 +1,352 @@ +'''Convnet encoder module. + +''' + +import torch +import torch.nn as nn + +from cortex.built_ins.networks.utils import get_nonlinearity + +from cortex_DIM.nn_modules.misc import Fold, Unfold, View + + +def infer_conv_size(w, k, s, p): + '''Infers the next size after convolution. + + Args: + w: Input size. + k: Kernel size. + s: Stride. + p: Padding. + + Returns: + int: Output size. + + ''' + x = (w - k + 2 * p) // s + 1 + return x + + +class Convnet(nn.Module): + '''Basic convnet convenience class. 
+ + Attributes: + conv_layers: nn.Sequential of nn.Conv2d layers with batch norm, + dropout, nonlinearity. + fc_layers: nn.Sequential of nn.Linear layers with batch norm, + dropout, nonlinearity. + reshape: Simple reshape layer. + conv_shape: Shape of the convolutional output. + + ''' + + def __init__(self, *args, **kwargs): + super().__init__() + self.create_layers(*args, **kwargs) + + def create_layers(self, shape, conv_args=None, fc_args=None): + '''Creates layers + + conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool) + fc_args are in format (dim_h, batch_norm, dropout, nonlinearity) + + Args: + shape: Shape of input. + conv_args: List of tuple of convolutional arguments. + fc_args: List of tuple of fully-connected arguments. + ''' + + self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args) + + dim_x, dim_y, dim_out = self.conv_shape + dim_r = dim_x * dim_y * dim_out + self.reshape = View(-1, dim_r) + self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args) + + def create_conv_layers(self, shape, conv_args): + '''Creates a set of convolutional layers. + + Args: + shape: Input shape. + conv_args: List of tuple of convolutional arguments. + + Returns: + nn.Sequential: a sequence of convolutional layers. + + ''' + + conv_layers = nn.Sequential() + conv_args = conv_args or [] + + dim_x, dim_y, dim_in = shape + + for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args): + name = '({}/{})_{}'.format(dim_in, dim_out, i + 1) + conv_block = nn.Sequential() + + if dim_out is not None: + conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p, bias=not(batch_norm)) + nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu') + conv_block.add_module(name + 'conv', conv) + dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p) + else: + dim_out = dim_in + + if dropout: + conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout)) + if batch_norm: + bn = nn.BatchNorm2d(dim_out) + conv_block.add_module(name + 'bn', bn) + + if nonlinearity: + nonlinearity = get_nonlinearity(nonlinearity) + conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity) + + if pool: + (pool_type, kernel, stride) = pool + Pool = getattr(nn, pool_type) + conv_block.add_module(name + 'pool', Pool(kernel_size=kernel, stride=stride)) + dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0) + + conv_layers.add_module(name, conv_block) + + dim_in = dim_out + + dim_out = dim_in + + return conv_layers, (dim_x, dim_y, dim_out) + + def create_linear_layers(self, dim_in, fc_args): + ''' + + Args: + dim_in: Number of input units. + fc_args: List of tuple of fully-connected arguments. + + Returns: + nn.Sequential. 
+ + ''' + + fc_layers = nn.Sequential() + fc_args = fc_args or [] + + for i, (dim_out, batch_norm, dropout, nonlinearity) in enumerate(fc_args): + name = '({}/{})_{}'.format(dim_in, dim_out, i + 1) + fc_block = nn.Sequential() + + if dim_out is not None: + fc_block.add_module(name + 'fc', nn.Linear(dim_in, dim_out)) + else: + dim_out = dim_in + + if dropout: + fc_block.add_module(name + 'do', nn.Dropout(p=dropout)) + if batch_norm: + bn = nn.BatchNorm1d(dim_out) + fc_block.add_module(name + 'bn', bn) + + if nonlinearity: + nonlinearity = get_nonlinearity(nonlinearity) + fc_block.add_module(nonlinearity.__class__.__name__, nonlinearity) + + fc_layers.add_module(name, fc_block) + + dim_in = dim_out + + return fc_layers, dim_in + + def next_size(self, dim_x, dim_y, k, s, p): + '''Infers the next size of a convolutional layer. + + Args: + dim_x: First dimension. + dim_y: Second dimension. + k: Kernel size. + s: Stride. + p: Padding. + + Returns: + (int, int): (First output dimension, Second output dimension) + + ''' + if isinstance(k, int): + kx, ky = (k, k) + else: + kx, ky = k + + if isinstance(s, int): + sx, sy = (s, s) + else: + sx, sy = s + + if isinstance(p, int): + px, py = (p, p) + else: + px, py = p + return (infer_conv_size(dim_x, kx, sx, px), + infer_conv_size(dim_y, ky, sy, py)) + + def forward(self, x: torch.Tensor, return_full_list=False): + '''Forward pass + + Args: + x: Input. + return_full_list: Optional, returns all layer outputs. + + Returns: + torch.Tensor or list of torch.Tensor. + + ''' + if return_full_list: + conv_out = [] + for conv_layer in self.conv_layers: + x = conv_layer(x) + conv_out.append(x) + else: + conv_out = self.conv_layers(x) + x = conv_out + + x = self.reshape(x) + + if return_full_list: + fc_out = [] + for fc_layer in self.fc_layers: + x = fc_layer(x) + fc_out.append(x) + else: + fc_out = self.fc_layers(x) + + return conv_out, fc_out + + +class FoldedConvnet(Convnet): + '''Convnet with strided crop input. + + ''' + + def create_layers(self, shape, crop_size=8, conv_args=None, fc_args=None): + '''Creates layers + + conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool) + fc_args are in format (dim_h, batch_norm, dropout, nonlinearity) + + Args: + shape: Shape of input. + crop_size: Size of crops + conv_args: List of tuple of convolutional arguments. + fc_args: List of tuple of fully-connected arguments. + ''' + + self.crop_size = crop_size + + dim_x, dim_y, dim_in = shape + if dim_x != dim_y: + raise ValueError('x and y dimensions must be the same to use Folded encoders.') + + self.final_size = 2 * (dim_x // self.crop_size) - 1 + + self.unfold = Unfold(dim_x, self.crop_size) + self.refold = Fold(dim_x, self.crop_size) + + shape = (self.crop_size, self.crop_size, dim_in) + + self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args) + + dim_x, dim_y, dim_out = self.conv_shape + dim_r = dim_x * dim_y * dim_out + self.reshape = View(-1, dim_r) + self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args) + + def create_conv_layers(self, shape, conv_args): + '''Creates a set of convolutional layers. + + Args: + shape: Input shape. + conv_args: List of tuple of convolutional arguments. + + Returns: + nn.Sequential: A sequence of convolutional layers. 
+ + ''' + + conv_layers = nn.Sequential() + conv_args = conv_args or [] + dim_x, dim_y, dim_in = shape + + for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args): + name = '({}/{})_{}'.format(dim_in, dim_out, i + 1) + conv_block = nn.Sequential() + + if dim_out is not None: + conv_block.add_module(name + 'conv', + nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p, bias=not(batch_norm))) + dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p) + else: + dim_out = dim_in + + if dropout: + conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout)) + if batch_norm: + conv_block.add_module(name + 'bn', nn.BatchNorm2d(dim_out)) + + if nonlinearity: + nonlinearity = get_nonlinearity(nonlinearity) + conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity) + + if pool: + (pool_type, kernel, stride) = pool + Pool = getattr(nn, pool_type) + conv_block.add_module('pool', Pool(kernel_size=kernel, stride=stride)) + dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0) + + conv_layers.add_module(name, conv_block) + + dim_in = dim_out + + if dim_x != dim_y: + raise ValueError('dim_x and dim_y do not match.') + + if dim_x == 1: + dim_x = self.final_size + dim_y = self.final_size + + dim_out = dim_in + + return conv_layers, (dim_x, dim_y, dim_out) + + def forward(self, x: torch.Tensor, return_full_list=False): + '''Forward pass + + Args: + x: Input. + return_full_list: Optional, returns all layer outputs. + + Returns: + torch.Tensor or list of torch.Tensor. + + ''' + + x = self.unfold(x) + + conv_out = [] + for conv_layer in self.conv_layers: + x = conv_layer(x) + if x.size(2) == 1: + x = self.refold(x) + conv_out.append(x) + + x = self.reshape(x) + + if return_full_list: + fc_out = [] + for fc_layer in self.fc_layers: + x = fc_layer(x) + fc_out.append(x) + else: + fc_out = self.fc_layers(x) + + if not return_full_list: + conv_out = conv_out[-1] + + return conv_out, fc_out \ No newline at end of file diff --git a/cortex_DIM/nn_modules/encoder.py b/cortex_DIM/nn_modules/encoder.py new file mode 100644 index 0000000..52db297 --- /dev/null +++ b/cortex_DIM/nn_modules/encoder.py @@ -0,0 +1,96 @@ +'''Basic cortex_DIM encoder. + +''' + +import torch + +from cortex_DIM.nn_modules.convnet import Convnet, FoldedConvnet +#from cortex_DIM.nn_modules import ResNet, FoldedResNet + + +def create_encoder(Module): + class Encoder(Module): + '''Encoder used for cortex_DIM. + + ''' + + def __init__(self, *args, local_idx=None, multi_idx=None, conv_idx=None, fc_idx=None, **kwargs): + ''' + + Args: + args: Arguments for parent class. + local_idx: Index in list of convolutional layers for local features. + multi_idx: Index in list of convolutional layers for multiple globals. + conv_idx: Index in list of convolutional layers for intermediate features. + fc_idx: Index in list of fully-connected layers for intermediate features. + kwargs: Keyword arguments for the parent class. + ''' + + super().__init__(*args, **kwargs) + + if local_idx is None: + raise ValueError('`local_idx` must be set') + + conv_idx = conv_idx or local_idx + + self.local_idx = local_idx + self.multi_idx = multi_idx + self.conv_idx = conv_idx + self.fc_idx = fc_idx + + def forward(self, x: torch.Tensor): + ''' + + Args: + x: Input tensor. 
+ + Returns: + local_out, multi_out, hidden_out, global_out + + ''' + + outs = super().forward(x, return_full_list=True) + if len(outs) == 2: + conv_out, fc_out = outs + else: + conv_before_out, res_out, conv_after_out, fc_out = outs + conv_out = conv_before_out + res_out + conv_after_out + + local_out = conv_out[self.local_idx] + + if self.multi_idx is not None: + multi_out = conv_out[self.multi_idx] + else: + multi_out = None + + if len(fc_out) > 0: + if self.fc_idx is not None: + hidden_out = fc_out[self.fc_idx] + else: + hidden_out = None + global_out = fc_out[-1] + else: + hidden_out = None + global_out = None + + conv_out = conv_out[self.conv_idx] + + return local_out, conv_out, multi_out, hidden_out, global_out + + return Encoder + + +class ConvnetEncoder(create_encoder(Convnet)): + pass + + +class FoldedConvnetEncoder(create_encoder(FoldedConvnet)): + pass + + +#class DIMResnet(create_dim_encoder(ResNet)): +# pass + + +#class DIMFoldedResnet(create_dim_encoder(FoldedResNet)): +# pass diff --git a/cortex_DIM/nn_modules/mi_networks.py b/cortex_DIM/nn_modules/mi_networks.py new file mode 100644 index 0000000..5f9de51 --- /dev/null +++ b/cortex_DIM/nn_modules/mi_networks.py @@ -0,0 +1,106 @@ +"""Module for networks used for computing MI. + +""" + +import numpy as np +import torch +import torch.nn as nn + +from cortex_DIM.nn_modules.misc import Permute + + +class MIFCNet(nn.Module): + """Simple custom network for computing MI. + + """ + def __init__(self, n_input, n_units): + """ + + Args: + n_input: Number of input units. + n_units: Number of output units. + """ + super().__init__() + + assert(n_units >= n_input) + + self.linear_shortcut = nn.Linear(n_input, n_units) + self.block_nonlinear = nn.Sequential( + nn.Linear(n_input, n_units), + nn.BatchNorm1d(n_units), + nn.ReLU(), + nn.Linear(n_units, n_units) + ) + + # initialize the initial projection to a sort of noisy copy + eye_mask = np.zeros((n_units, n_input), dtype=np.uint8) + for i in range(n_input): + eye_mask[i, i] = 1 + + self.linear_shortcut.weight.data.uniform_(-0.01, 0.01) + self.linear_shortcut.weight.data.masked_fill_(torch.tensor(eye_mask), 1.) + + def forward(self, x): + """ + + Args: + x: Input tensor. + + Returns: + torch.Tensor: network output. + + """ + h = self.block_nonlinear(x) + self.linear_shortcut(x) + return h + + +class MI1x1ConvNet(nn.Module): + """Simple custorm 1x1 convnet. + + """ + def __init__(self, n_input, n_units): + """ + + Args: + n_input: Number of input units. + n_units: Number of output units. + """ + + super().__init__() + + self.block_nonlinear = nn.Sequential( + nn.Conv2d(n_input, n_units, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(n_units), + nn.ReLU(), + nn.Conv2d(n_units, n_units, kernel_size=1, stride=1, padding=0, bias=True), + ) + + self.block_ln = nn.Sequential( + Permute(0, 2, 3, 1), + nn.LayerNorm(n_units), + Permute(0, 3, 1, 2) + ) + + self.linear_shortcut = nn.Conv2d(n_input, n_units, kernel_size=1, + stride=1, padding=0, bias=False) + + # initialize shortcut to be like identity (if possible) + if n_units >= n_input: + eye_mask = np.zeros((n_units, n_input, 1, 1), dtype=np.uint8) + for i in range(n_input): + eye_mask[i, i, 0, 0] = 1 + self.linear_shortcut.weight.data.uniform_(-0.01, 0.01) + self.linear_shortcut.weight.data.masked_fill_(torch.tensor(eye_mask), 1.) + + def forward(self, x): + """ + + Args: + x: Input tensor. + + Returns: + torch.Tensor: network output. 
+ + """ + h = self.block_ln(self.block_nonlinear(x) + self.linear_shortcut(x)) + return h \ No newline at end of file diff --git a/cortex_DIM/nn_modules/misc.py b/cortex_DIM/nn_modules/misc.py new file mode 100644 index 0000000..9909808 --- /dev/null +++ b/cortex_DIM/nn_modules/misc.py @@ -0,0 +1,130 @@ +'''Various miscellaneous modules + +''' + +import torch + + +class View(torch.nn.Module): + """Basic reshape module. + + """ + def __init__(self, *shape): + """ + + Args: + *shape: Input shape. + """ + super().__init__() + self.shape = shape + + def forward(self, input): + """Reshapes tensor. + + Args: + input: Input tensor. + + Returns: + torch.Tensor: Flattened tensor. + + """ + return input.view(*self.shape) + + +class Unfold(torch.nn.Module): + """Module for unfolding tensor. + + Performs strided crops on 2d (image) tensors. Stride is assumed to be half the crop size. + + """ + def __init__(self, img_size, fold_size): + """ + + Args: + img_size: Input size. + fold_size: Crop size. + """ + super().__init__() + + fold_stride = fold_size // 2 + self.fold_size = fold_size + self.fold_stride = fold_stride + self.n_locs = 2 * (img_size // fold_size) - 1 + self.unfold = torch.nn.Unfold((self.fold_size, self.fold_size), + stride=(self.fold_stride, self.fold_stride)) + + def forward(self, x): + """Unfolds tensor. + + Args: + x: Input tensor. + + Returns: + torch.Tensor: Unfolded tensor. + + """ + N = x.size(0) + x = self.unfold(x).reshape(N, -1, self.fold_size, self.fold_size, self.n_locs * self.n_locs)\ + .permute(0, 4, 1, 2, 3)\ + .reshape(N * self.n_locs * self.n_locs, -1, self.fold_size, self.fold_size) + return x + + +class Fold(torch.nn.Module): + """Module (re)folding tensor. + + Undoes the strided crops above. Works only on 1x1. + + """ + def __init__(self, img_size, fold_size): + """ + + Args: + img_size: Images size. + fold_size: Crop size. + """ + super().__init__() + self.n_locs = 2 * (img_size // fold_size) - 1 + + def forward(self, x): + """(Re)folds tensor. + + Args: + x: Input tensor. + + Returns: + torch.Tensor: Refolded tensor. + + """ + dim_c, dim_x, dim_y = x.size()[1:] + x = x.reshape(-1, self.n_locs * self.n_locs, dim_c, dim_x * dim_y) + x = x.reshape(-1, self.n_locs * self.n_locs, dim_c, dim_x * dim_y)\ + .permute(0, 2, 3, 1)\ + .reshape(-1, dim_c * dim_x * dim_y, self.n_locs, self.n_locs).contiguous() + return x + + +class Permute(torch.nn.Module): + """Module for permuting axes. + + """ + def __init__(self, *perm): + """ + + Args: + *perm: Permute axes. + """ + super().__init__() + self.perm = perm + + def forward(self, input): + """Permutes axes of tensor. + + Args: + input: Input tensor. + + Returns: + torch.Tensor: permuted tensor. 
+ + """ + return input.permute(*self.perm) diff --git a/scripts/deep_infomax.py b/scripts/deep_infomax.py new file mode 100644 index 0000000..0d3732f --- /dev/null +++ b/scripts/deep_infomax.py @@ -0,0 +1,235 @@ +'''Deep Implicit Infomax + +''' + +import logging + +from cortex.main import run +from cortex.plugins import ModelPlugin +from cortex.built_ins.models.classifier import SimpleClassifier + +from cortex_DIM.nn_modules.mi_networks import MIFCNet, MI1x1ConvNet +from cortex_DIM.functions.dim_losses import fenchel_dual_loss, multi_donsker_varadhan_loss, nce_loss, \ + multi_nce_loss, donsker_varadhan_loss, multi_fenchel_dual_loss +from cortex_DIM.configs.convnets import configs as convnet_configs + + +logger = logging.getLogger('cortex_DIM') + + +class DIM(ModelPlugin): + '''Deep InfoMax + + ''' + + defaults = dict( + data=dict(batch_size=dict(train=64, test=64), + inputs=dict(inputs='images'), skip_last_batch=True), + train=dict(save_on_lowest='losses.encoder', epochs=1000), + model=dict( + classifier_c_args=dict(dropout=0.1, dim_h=[200], batch_norm=True), + classifier_m_args=dict(dropout=0.1, dim_h=[200], batch_norm=True), + classifier_f_args=dict(dropout=0.1, dim_h=[200], batch_norm=True), + classifier_g_args=dict(dropout=0.1, dim_h=[200], batch_norm=True)), + optimizer=dict(learning_rate=1e-4) + ) + + def __init__(self, Classifier=SimpleClassifier): + super().__init__() + + self.classifier_c = Classifier( + nets=dict(classifier='classifier_c'), + kwargs=dict(classifier_args='classifier_c_args')) + self.classifier_m = Classifier( + nets=dict(classifier='classifier_m'), + kwargs=dict(classifier_args='classifier_m_args')) + self.classifier_f = Classifier( + nets=dict(classifier='classifier_f'), + kwargs=dict(classifier_args='classifier_f_args')) + self.classifier_g = Classifier( + nets=dict(classifier='classifier_g'), + kwargs=dict(classifier_args='classifier_g_args')) + + def build(self, global_units=64, mi_units=1024, encoder_config='basic32x32', + encoder_args={}): + ''' + + Args: + global_units: Number of global units. + mi_units: Number of units for MI estimation. + encoder_config: Config of encoder. See `cortex_DIM/configs` for details. + encoder_args: Additional dictionary to update encoder. + + ''' + + # Draw data to help with shape inference. + self.data.reset(mode='test', make_pbar=False) + self.data.next() + + dim_c, dim_x, dim_y = self.get_dims('images') + input_shape = (dim_x, dim_y, dim_c) + + # Create encoder. + encoder_args_ = convnet_configs.get(encoder_config, None) + if encoder_args_ is None: + raise logger.warning('encoder_type `{}` not supported'.format(encoder_type)) + encoder_args_ = {} + encoder_args_.update(**encoder_args) + encoder_args = encoder_args_ + + if global_units > 0: + if 'fc_args' in list(encoder_args.keys()): + encoder_args['fc_args'].append((global_units, False, False, None)) + else: + encoder_args['fc_args'] = [(global_units, False, False, None)] + else: + if 'fc_args' in list(encoder_args.keys()): + encoder_args.pop('fc_args') + + Encoder = encoder_args.pop('Encoder') + self.nets.encoder = Encoder(input_shape, **encoder_args) + + # Create MI nn_modules and classifiers for monitoring. 
+ S = self.inputs('images').cpu() + L, C, M, F, G = self.nets.encoder(S) + + local_units, locals_x, locals_y = L.size()[1:] + self.nets.local_net = MI1x1ConvNet(local_units, mi_units) + + conv_units, conv_x, conv_y = C.size()[1:] + self.classifier_c.build(dim_in=conv_units * conv_x * conv_y) + + if M is not None: + multi_units, multis_x, multis_y = M.size()[1:] + self.nets.multi_net = MI1x1ConvNet(multi_units, mi_units) + self.classifier_m.build(dim_in=multi_units * multis_x * multis_y) + + if F is not None: + fc_units = F.size(1) + self.classifier_f.build(dim_in=fc_units) + + if G is not None: + self.nets.global_net = MIFCNet(global_units, mi_units) + self.classifier_g.build(dim_in=global_units) + + def routine(self, measure='JSD', mode='fd'): + ''' + + Args: + measure: Type of f-divergence. For use with mode `fd` + mode: Loss mode. Fenchel-dual `fd`, NCE `nce`, or Donsker-Vadadhan `dv`. + + ''' + X, Y = self.inputs('images', 'targets') + L, C, M, F, G = self.nets.encoder(X) + + if G is not None: + # Add a global-local loss. + local_global_loss = self.local_global_loss(L, G, measure, mode) + self.losses.global_net = local_global_loss + else: + local_global_loss = 0. + + if M is not None: + # Add a multi-global local loss. + local_multi_loss = self.local_multi_loss(L, M, measure, mode) + self.losses.multi_net = local_multi_loss + else: + local_multi_loss = 0. + + self.losses.encoder = local_global_loss + local_multi_loss + self.losses.local_net = local_global_loss + local_multi_loss + + # Classifiers + units, dim_x, dim_y = C.size()[1:] + C = C.view(-1, units * dim_x * dim_y) + self.classifier_c.routine(C.detach(), Y) + + if M is not None: + units, dim_x, dim_y = M.size()[1:] + M = M.view(-1, units * dim_x * dim_y) + self.classifier_m.routine(M.detach(), Y) + + if F is not None: + self.classifier_f.routine(F.detach(), Y) + + if G is not None: + self.classifier_g.routine(G.detach(), Y) + + def local_global_loss(self, l, g, measure, mode): + ''' + + Args: + l: Local feature map. + g: Global features. + measure: Type of f-divergence. For use with mode `fd` + mode: Loss mode. Fenchel-dual `fd`, NCE `nce`, or Donsker-Vadadhan `dv`. + + Returns: + torch.Tensor: Loss. + + ''' + l_enc = self.nets.local_net(l) + g_enc = self.nets.global_net(g) + N, local_units, dim_x, dim_y = l_enc.size() + l_enc = l_enc.view(N, local_units, -1) + + if mode == 'fd': + loss = fenchel_dual_loss(l_enc, g_enc, measure=measure) + elif mode == 'nce': + loss = nce_loss(l_enc, g_enc) + elif mode == 'dv': + loss = donsker_varadhan_loss(l_enc, g_enc) + else: + raise NotImplementedError(mode) + + return loss + + def local_multi_loss(self, l, m, measure, mode): + ''' + + Args: + l: Local feature map. + m: Multiple globals feature map. + measure: Type of f-divergence. For use with mode `fd` + mode: Loss mode. Fenchel-dual `fd`, NCE `nce`, or Donsker-Vadadhan `dv`. + + Returns: + torch.Tensor: Loss. + + ''' + l_enc = self.nets.local_net(l) + m_enc = self.nets.multi_net(m) + N, local_units, dim_x, dim_y = l_enc.size() + l_enc = l_enc.view(N, local_units, -1) + m_enc = m_enc.view(N, local_units, -1) + + if mode == 'fd': + loss = multi_fenchel_dual_loss(l_enc, m_enc, measure=measure) + elif mode == 'nce': + loss = multi_nce_loss(l_enc, m_enc) + elif mode == 'dv': + loss = multi_donsker_varadhan_loss(l_enc, m_enc) + else: + raise NotImplementedError(mode) + + return loss + + def train_step(self): + """One step in training. 
+ + """ + self.data.next() + self.routine() + self.optimizer_step() + + def eval_step(self): + """One step in evaluation. + + """ + self.data.next() + self.routine() + + +if __name__ == '__main__': + run(DIM()) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..23d6694 --- /dev/null +++ b/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup + +setup(name='cortex_DIM', + version='0.1', + description='The Deep InfoMax package', + author='R Devon Hjelm', + author_email='erroneus@gmail.com', + packages=['cortex_DIM', 'cortex_DIM.configs', 'cortex_DIM.functions', 'cortex_DIM.nn_modules'], + install_requires=['cortex==0.12'], + dependency_links=['git+https://github.com/rdevon/cortex.git@dev#egg=cortex-0.12'], + zip_safe=False)