diff --git a/README.md b/README.md
index 8fc9c61..bdda839 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,62 @@
-# DIM
+# Deep InfoMax (DIM)
+
+[UPDATE]: This work has been accepted as an oral presentation at ICLR 2019.
+We are gradually updating the repository over the next few weeks to reflect the experiments in the camera-ready version.
+
Learning Deep Representations by Mutual Information Estimation and Maximization
+Sample code for the local-only objective described in:
+https://openreview.net/forum?id=Bklr3j0cKX
+https://arxiv.org/abs/1808.06670
+
+### Completed
+[Updated 1/15/2019]
+* Latest code for dot-product style scoring function for local DIM (single or multiple globals).
+* NCE / DV losses.
+* Convnet and folded convnet (strided crops) architectures.
+
+### TODO
+* ResNet and folded ResNet architectures, and training classifiers with the encoder held fixed (evaluation).
+* NDM, MINE, SVM, and MS-SSIM evaluation.
+* Global DIM and prior matching.
+* Coordinate and occlusion tasks.
+* Add nearest neighbor analysis.
+
+### Installation / requirements
+
+This is a package, so to install just run:
+
+ $ pip install .
+
+This package installs the dev branch of cortex: https://github.com/rdevon/cortex
+
+cortex requires Python 3.5+ (not tested above 3.7). Note that cortex is in early beta, but it is usable for this demo.
+
+cortex optionally uses visdom for visualization: https://github.com/facebookresearch/visdom
+
+You will need to do:
+
+ $ cortex setup
+
+See the cortex README for more info, email us, or submit an issue for bugs.
+
+### Usage
+
+To get the full set of commands, try:
+
+ $ python scripts/deep_infomax.py --help
+
+For CIFAR10 on a DCGAN architecture, try:
+
+ $ python scripts/deep_infomax.py --d.source CIFAR10 -n DIM_CIFAR10 --d.copy_to_local --t.epochs 1000
+
+You should get around 71-72% classification accuracy in the pretraining step alone (the classifiers here are included for monitoring purposes only).
+Note that this won't get you all the way to reproducing the results in the paper: for that, the classifier needs to be retrained with the encoder held fixed.
+Support for training a classifier with the representations fixed is coming soon.
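+
+As a rough sketch of that evaluation step (not the exact script used for the paper), one could
+freeze the trained encoder and fit a small classifier on its global output; here `encoder` and
+`loader` are assumed to already exist:
+
+    import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
+
+    encoder.eval()
+    for p in encoder.parameters():
+        p.requires_grad = False  # hold the encoder fixed
+
+    clf = nn.Linear(64, 10).cuda()  # 64 global units, 10 CIFAR10 classes
+    opt = torch.optim.Adam(clf.parameters(), lr=1e-4)
+    for images, targets in loader:
+        _, _, _, _, glob = encoder(images.cuda())  # returns (local, conv, multi, fc, global)
+        loss = F.cross_entropy(clf(glob), targets.cuda())
+        opt.zero_grad()
+        loss.backward()
+        opt.step()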
+
+For STL-10 on a folded 64x64 AlexNet (strided crops) with multiple globals and the noise-contrastive estimation (NCE) loss, try:
+
+ $ python scripts/deep_infomax.py --d.sources STL10 --d.data_args "dict(stl_resize_only=True)" --d.n_workers 32 -n DIM_STL --t.epochs 200 --d.copy_to_local --encoder_config foldmultialex64x64 --mode nce --global_units 0
+
+### Deep Infomax
+
+TODO: visual guide to Deep Infomax.
\ No newline at end of file
diff --git a/cortex_DIM/__init__.py b/cortex_DIM/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cortex_DIM/configs/__init__.py b/cortex_DIM/configs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cortex_DIM/configs/convnets.py b/cortex_DIM/configs/convnets.py
new file mode 100644
index 0000000..77c5e84
--- /dev/null
+++ b/cortex_DIM/configs/convnets.py
@@ -0,0 +1,98 @@
+'''Basic convnet hyperparameters.
+
+conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
+fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)
+
+'''
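+# For example, (64, 4, 2, 1, True, False, 'ReLU', None) below describes a conv layer with
+# 64 output channels, a 4x4 kernel, stride 2, padding 1, batch norm, no dropout, ReLU, and
+# no pooling; this reading follows the unpacking in Convnet.create_conv_layers.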
+
+from cortex_DIM.nn_modules.encoder import ConvnetEncoder, FoldedConvnetEncoder
+
+
+# Basic DCGAN-like encoders
+
+_basic28x28 = dict(
+ Encoder=ConvnetEncoder,
+ conv_args=[(64, 5, 2, 2, True, False, 'ReLU', None),
+ (128, 5, 2, 2, True, False, 'ReLU', None)],
+    fc_args=[(1024, True, False, 'ReLU')],
+ local_idx=1,
+ fc_idx=0
+)
+
+_basic32x32 = dict(
+ Encoder=ConvnetEncoder,
+ conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None),
+ (128, 4, 2, 1, True, False, 'ReLU', None),
+ (256, 4, 2, 1, True, False, 'ReLU', None)],
+ fc_args=[(1024, True, False, 'ReLU')],
+ local_idx=1,
+ conv_idx=2,
+ fc_idx=0
+)
+
+_basic64x64 = dict(
+ Encoder=ConvnetEncoder,
+ conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None),
+ (128, 4, 2, 1, True, False, 'ReLU', None),
+ (256, 4, 2, 1, True, False, 'ReLU', None),
+ (512, 4, 2, 1, True, False, 'ReLU', None)],
+ fc_args=[(1024, True, False, 'ReLU')],
+ local_idx=2,
+ conv_idx=3,
+ fc_idx=0
+)
+
+# Alexnet-like encoders
+
+_alex64x64 = dict(
+ Encoder=ConvnetEncoder,
+ conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
+ (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
+ (384, 3, 1, 1, True, False, 'ReLU', None),
+ (384, 3, 1, 1, True, False, 'ReLU', None),
+ (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))],
+ fc_args=[(4096, True, False, 'ReLU'),
+ (4096, True, False, 'ReLU')],
+ local_idx=2,
+ conv_idx=4,
+ fc_idx=1
+)
+
+_foldalex64x64 = dict(
+ Encoder=FoldedConvnetEncoder,
+ crop_size=16,
+ conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
+ (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
+ (384, 3, 1, 1, True, False, 'ReLU', None),
+ (384, 3, 1, 1, True, False, 'ReLU', None),
+ (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))],
+ fc_args=[(4096, True, False, 'ReLU'),
+ (4096, True, False, 'ReLU')],
+ local_idx=4,
+ fc_idx=1
+)
+
+_foldmultialex64x64 = dict(
+ Encoder=FoldedConvnetEncoder,
+ crop_size=16,
+ conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
+ (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
+ (384, 3, 1, 1, True, False, 'ReLU', None),
+ (384, 3, 1, 1, True, False, 'ReLU', None),
+ (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
+ (192, 3, 1, 0, True, False, 'ReLU', None),
+ (192, 1, 1, 0, True, False, 'ReLU', None)],
+ fc_args=[(4096, True, False, 'ReLU')],
+ local_idx=4,
+ multi_idx=6,
+ fc_idx=1
+)
+
+configs = dict(
+ basic28x28=_basic28x28,
+ basic32x32=_basic32x32,
+ basic64x64=_basic64x64,
+ alex64x64=_alex64x64,
+ foldalex64x64=_foldalex64x64,
+ foldmultialex64x64=_foldmultialex64x64
+)
\ No newline at end of file
diff --git a/cortex_DIM/configs/resnets.py b/cortex_DIM/configs/resnets.py
new file mode 100644
index 0000000..46ef6fd
--- /dev/null
+++ b/cortex_DIM/configs/resnets.py
@@ -0,0 +1,151 @@
+"""Configurations for ResNets
+
+"""
+
+from cortex_DIM.networks.dim_encoders import DIMResnet, DIMFoldedResnet
+
+
+_resnet19_32x32 = dict(
+ Encoder=DIMResnet,
+ conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
+ res_args=[
+ ([(64, 1, 1, 0, True, False, 'ReLU', None),
+ (64, 3, 1, 1, True, False, 'ReLU', None),
+ (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(64, 1, 1, 0, True, False, 'ReLU', None),
+ (64, 3, 1, 1, True, False, 'ReLU', None),
+ (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(128, 1, 1, 0, True, False, 'ReLU', None),
+ (128, 3, 2, 1, True, False, 'ReLU', None),
+ (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(128, 1, 1, 0, True, False, 'ReLU', None),
+ (128, 3, 1, 1, True, False, 'ReLU', None),
+ (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(256, 1, 1, 0, True, False, 'ReLU', None),
+ (256, 3, 2, 1, True, False, 'ReLU', None),
+ (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(256, 1, 1, 0, True, False, 'ReLU', None),
+ (256, 3, 1, 1, True, False, 'ReLU', None),
+ (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1)
+ ],
+ fc_args=[(1024, True, False, 'ReLU')],
+ local_idx=4,
+ fc_idx=0
+)
+
+_foldresnet19_32x32 = dict(
+ Encoder=DIMFoldedResnet,
+ crop_size=8,
+ conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
+ res_args=[
+ ([(64, 1, 1, 0, True, False, 'ReLU', None),
+ (64, 3, 1, 1, True, False, 'ReLU', None),
+ (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(64, 1, 1, 0, True, False, 'ReLU', None),
+ (64, 3, 1, 1, True, False, 'ReLU', None),
+ (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(128, 1, 1, 0, True, False, 'ReLU', None),
+ (128, 3, 2, 1, True, False, 'ReLU', None),
+ (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(128, 1, 1, 0, True, False, 'ReLU', None),
+ (128, 3, 1, 1, True, False, 'ReLU', None),
+ (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(256, 1, 1, 0, True, False, 'ReLU', None),
+ (256, 3, 2, 1, True, False, 'ReLU', None),
+ (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(256, 1, 1, 0, True, False, 'ReLU', None),
+ (256, 3, 1, 1, True, False, 'ReLU', None),
+ (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1)
+ ],
+ fc_args=[(1024, True, False, 'ReLU')],
+ local_idx=6,
+ fc_idx=0
+)
+
+_resnet34_32x32 = dict(
+ Encoder=DIMResnet,
+ conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
+ res_args=[
+ ([(64, 1, 1, 0, True, False, 'ReLU', None),
+ (64, 3, 1, 1, True, False, 'ReLU', None),
+ (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(64, 1, 1, 0, True, False, 'ReLU', None),
+ (64, 3, 1, 1, True, False, 'ReLU', None),
+ (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 2),
+ ([(128, 1, 1, 0, True, False, 'ReLU', None),
+ (128, 3, 2, 1, True, False, 'ReLU', None),
+ (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(128, 1, 1, 0, True, False, 'ReLU', None),
+ (128, 3, 1, 1, True, False, 'ReLU', None),
+ (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 5),
+ ([(256, 1, 1, 0, True, False, 'ReLU', None),
+ (256, 3, 2, 1, True, False, 'ReLU', None),
+ (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(256, 1, 1, 0, True, False, 'ReLU', None),
+ (256, 3, 1, 1, True, False, 'ReLU', None),
+ (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 2)
+ ],
+ fc_args=[(1024, True, False, 'ReLU')],
+ local_idx=2,
+ fc_idx=0
+)
+
+_foldresnet34_32x32 = dict(
+ Encoder=DIMFoldedResnet,
+ crop_size=8,
+ conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
+ res_args=[
+ ([(64, 1, 1, 0, True, False, 'ReLU', None),
+ (64, 3, 1, 1, True, False, 'ReLU', None),
+ (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(64, 1, 1, 0, True, False, 'ReLU', None),
+ (64, 3, 1, 1, True, False, 'ReLU', None),
+ (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 2),
+ ([(128, 1, 1, 0, True, False, 'ReLU', None),
+ (128, 3, 2, 1, True, False, 'ReLU', None),
+ (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(128, 1, 1, 0, True, False, 'ReLU', None),
+ (128, 3, 1, 1, True, False, 'ReLU', None),
+ (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 5),
+ ([(256, 1, 1, 0, True, False, 'ReLU', None),
+ (256, 3, 2, 1, True, False, 'ReLU', None),
+ (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 1),
+ ([(256, 1, 1, 0, True, False, 'ReLU', None),
+ (256, 3, 1, 1, True, False, 'ReLU', None),
+ (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
+ 2)
+ ],
+ fc_args=[(1024, True, False, 'ReLU')],
+ local_idx=12,
+ fc_idx=0
+)
+
+configs = dict(
+ resnet19_32x32=_resnet19_32x32,
+ resnet34_32x32=_resnet34_32x32,
+ foldresnet19_32x32=_foldresnet19_32x32,
+ foldresnet34_32x32=_foldresnet34_32x32
+)
\ No newline at end of file
diff --git a/cortex_DIM/functions/__init__.py b/cortex_DIM/functions/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cortex_DIM/functions/dim_losses.py b/cortex_DIM/functions/dim_losses.py
new file mode 100644
index 0000000..7de64ea
--- /dev/null
+++ b/cortex_DIM/functions/dim_losses.py
@@ -0,0 +1,224 @@
+'''cortex_DIM losses.
+
+'''
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+from cortex_DIM.functions.gan_losses import get_positive_expectation, get_negative_expectation
+
+
+def fenchel_dual_loss(l, g, measure=None):
+ '''Computes the f-divergence distance between positive and negative joint distributions.
+
+ Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD),
+    Squared Hellinger `H2`, Chi-squared `X2`, `KL`, and reverse KL `RKL`.
+
+ Args:
+ l: Local feature map.
+ g: Global features.
+ measure: f-divergence measure.
+
+ Returns:
+ torch.Tensor: Loss.
+
+ '''
+ N, local_units, n_locs = l.size()
+ l = l.permute(0, 2, 1)
+ l = l.reshape(-1, local_units)
+
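+    # The outer product of global and local features gives an N x N x n_locs score tensor;
+    # the diagonal entries (mask) are positive pairs, the off-diagonal entries negatives.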
+ u = torch.mm(g, l.t())
+ u = u.reshape(N, N, -1)
+ mask = torch.eye(N).cuda()
+ n_mask = 1 - mask
+
+ E_pos = get_positive_expectation(u, measure, average=False).mean(2)
+ E_neg = get_negative_expectation(u, measure, average=False).mean(2)
+ E_pos = (E_pos * mask).sum() / mask.sum()
+ E_neg = (E_neg * n_mask).sum() / n_mask.sum()
+ loss = E_neg - E_pos
+ return loss
+
+
+def multi_fenchel_dual_loss(l, m, measure=None):
+ '''Computes the f-divergence distance between positive and negative joint distributions.
+
+ Used for multiple globals.
+
+ Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD),
+    Squared Hellinger `H2`, Chi-squared `X2`, `KL`, and reverse KL `RKL`.
+
+ Args:
+ l: Local feature map.
+ m: Multiple globals feature map.
+ measure: f-divergence measure.
+
+ Returns:
+ torch.Tensor: Loss.
+
+ '''
+ N, units, n_locals = l.size()
+ n_multis = m.size(2)
+
+ l = l.view(N, units, n_locals)
+ l = l.permute(0, 2, 1)
+ l = l.reshape(-1, units)
+
+ m = m.view(N, units, n_multis)
+ m = m.permute(0, 2, 1)
+ m = m.reshape(-1, units)
+
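+    # Score every (image, global vector) against every (image, location); the result is
+    # reshaped to (N, N, n_locals, n_multis) so the diagonal over the first two dims is positive.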
+ u = torch.mm(m, l.t())
+ u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)
+
+ mask = torch.eye(N).cuda()
+ n_mask = 1 - mask
+
+ E_pos = get_positive_expectation(u, measure, average=False).mean(2).mean(2)
+ E_neg = get_negative_expectation(u, measure, average=False).mean(2).mean(2)
+ E_pos = (E_pos * mask).sum() / mask.sum()
+ E_neg = (E_neg * n_mask).sum() / n_mask.sum()
+ loss = E_neg - E_pos
+ return loss
+
+
+def nce_loss(l, g):
+ '''Computes the noise contrastive estimation-based loss.
+
+ Args:
+ l: Local feature map.
+ g: Global features.
+
+ Returns:
+ torch.Tensor: Loss.
+
+ '''
+ N, local_units, n_locs = l.size()
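+    # Positive logits: each local location is scored against its own image's global vector.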
+ l_p = l.permute(0, 2, 1)
+ u_p = torch.matmul(l_p, g.unsqueeze(dim=2))
+
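+    # Negative logits: score all (global, location) pairs, then mask out same-image pairs below.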
+ l_n = l_p.reshape(-1, local_units)
+ u_n = torch.mm(g, l_n.t())
+ u_n = u_n.reshape(N, N, n_locs)
+
+ mask = torch.eye(N).unsqueeze(dim=2).cuda()
+ n_mask = 1 - mask
+
+ u_n = (n_mask * u_n) - (10. * (1 - n_mask)) # mask out "self" examples
+ u_n = u_n.reshape(N, -1).unsqueeze(dim=1).expand(-1, n_locs, -1)
+
+ pred_lgt = torch.cat([u_p, u_n], dim=2)
+ pred_log = F.log_softmax(pred_lgt, dim=2)
+ loss = -pred_log[:, :, 0].mean()
+ return loss
+
+
+def multi_nce_loss(l, m):
+    '''Computes the noise contrastive estimation-based loss.
+
+    Used for multiple globals.
+
+ Args:
+ l: Local feature map.
+ m: Multiple globals feature map.
+
+ Returns:
+ torch.Tensor: Loss.
+
+ '''
+ N, units, n_locals = l.size()
+    _, _, n_multis = m.size()
+
+ l = l.view(N, units, n_locals)
+ m = m.view(N, units, n_multis)
+ l_p = l.permute(0, 2, 1)
+ m_p = m.permute(0, 2, 1)
+ u_p = torch.matmul(l_p, m).unsqueeze(2)
+
+ l_n = l_p.reshape(-1, units)
+ m_n = m_p.reshape(-1, units)
+ u_n = torch.mm(m_n, l_n.t())
+ u_n = u_n.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)
+
+ mask = torch.eye(N)[:, :, None, None].cuda()
+ n_mask = 1 - mask
+
+ u_n = (n_mask * u_n) - (10. * (1 - n_mask)) # mask out "self" examples
+ u_n = u_n.reshape(N, N * n_locals, n_multis).unsqueeze(dim=1).expand(-1, n_locals, -1, -1)
+
+ pred_lgt = torch.cat([u_p, u_n], dim=2)
+ pred_log = F.log_softmax(pred_lgt, dim=2)
+ loss = -pred_log[:, :, 0].mean()
+
+ return loss
+
+
+def donsker_varadhan_loss(l, g):
+    '''Computes the Donsker-Varadhan-based loss.
+
+ Args:
+ l: Local feature map.
+ g: Global features.
+
+ Returns:
+ torch.Tensor: Loss.
+
+ '''
+ N, local_units, n_locs = l.size()
+ l = l.permute(0, 2, 1)
+ l = l.reshape(-1, local_units)
+
+ u = torch.mm(g, l.t())
+ u = u.reshape(N, N, n_locs)
+
+ mask = torch.eye(N).cuda()
+ n_mask = (1 - mask)[:, :, None]
+
+ E_pos = (u.mean(2) * mask).sum() / mask.sum()
+
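+    # Numerically stable log-mean-exp over negative pairs: push same-image (diagonal) scores
+    # far down, subtract the global max before exponentiating, and add it back afterwards.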
+ u -= 100 * (1 - n_mask)
+ u_max = torch.max(u)
+ E_neg = torch.log((n_mask * torch.exp(u - u_max)).sum() + 1e-6) + u_max - math.log(n_mask.sum())
+ loss = E_neg - E_pos
+ return loss
+
+
+def multi_donsker_varadhan_loss(l, m):
+    '''Computes the Donsker-Varadhan-based loss.
+
+    Used for multiple globals.
+
+ Args:
+ l: Local feature map.
+ m: Multiple globals feature map.
+
+ Returns:
+ torch.Tensor: Loss.
+
+ '''
+ N, units, n_locals = l.size()
+ n_multis = m.size(2)
+
+ l = l.view(N, units, n_locals)
+ l = l.permute(0, 2, 1)
+ l = l.reshape(-1, units)
+
+ m = m.view(N, units, n_multis)
+ m = m.permute(0, 2, 1)
+ m = m.reshape(-1, units)
+
+ u = torch.mm(m, l.t())
+ u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)
+
+    mask = torch.eye(N).cuda()
+    n_mask = (1 - mask)[:, :, None, None]
+
+    E_pos = (u.mean(2).mean(2) * mask).sum() / mask.sum()
+
+ u -= 100 * (1 - n_mask)
+ u_max = torch.max(u)
+ E_neg = torch.log((n_mask * torch.exp(u - u_max)).sum() + 1e-6) + u_max - math.log(n_mask.sum())
+ loss = E_neg - E_pos
+ return loss
\ No newline at end of file
diff --git a/cortex_DIM/functions/gan_losses.py b/cortex_DIM/functions/gan_losses.py
new file mode 100644
index 0000000..55d1848
--- /dev/null
+++ b/cortex_DIM/functions/gan_losses.py
@@ -0,0 +1,95 @@
+"""
+
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+from cortex_DIM.functions.misc import log_sum_exp
+
+
+def raise_measure_error(measure):
+ supported_measures = ['GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1']
+ raise NotImplementedError(
+ 'Measure `{}` not supported. Supported: {}'.format(measure,
+ supported_measures))
+
+
+def get_positive_expectation(p_samples, measure, average=True):
+ """Computes the positive part of a divergence / difference.
+
+ Args:
+ p_samples: Positive samples.
+ measure: Measure to compute for.
+ average: Average the result over samples.
+
+ Returns:
+ torch.Tensor
+
+ """
+ log_2 = math.log(2.)
+
+ if measure == 'GAN':
+ Ep = - F.softplus(-p_samples)
+ elif measure == 'JSD':
+ Ep = log_2 - F.softplus(- p_samples)
+ elif measure == 'X2':
+ Ep = p_samples ** 2
+ elif measure == 'KL':
+ Ep = p_samples + 1.
+ elif measure == 'RKL':
+ Ep = -torch.exp(-p_samples)
+ elif measure == 'DV':
+ Ep = p_samples
+ elif measure == 'H2':
+ Ep = 1. - torch.exp(-p_samples)
+ elif measure == 'W1':
+ Ep = p_samples
+ else:
+ raise_measure_error(measure)
+
+ if average:
+ return Ep.mean()
+ else:
+ return Ep
+
+
+def get_negative_expectation(q_samples, measure, average=True):
+ """Computes the negative part of a divergence / difference.
+
+ Args:
+ q_samples: Negative samples.
+ measure: Measure to compute for.
+ average: Average the result over samples.
+
+ Returns:
+ torch.Tensor
+
+ """
+ log_2 = math.log(2.)
+
+ if measure == 'GAN':
+ Eq = F.softplus(-q_samples) + q_samples
+ elif measure == 'JSD':
+ Eq = F.softplus(-q_samples) + q_samples - log_2
+ elif measure == 'X2':
+ Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2)
+ elif measure == 'KL':
+ Eq = torch.exp(q_samples)
+ elif measure == 'RKL':
+ Eq = q_samples - 1.
+ elif measure == 'DV':
+ Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0))
+ elif measure == 'H2':
+ Eq = torch.exp(q_samples) - 1.
+ elif measure == 'W1':
+ Eq = q_samples
+ else:
+ raise_measure_error(measure)
+
+ if average:
+ return Eq.mean()
+ else:
+ return Eq
\ No newline at end of file
diff --git a/cortex_DIM/functions/misc.py b/cortex_DIM/functions/misc.py
new file mode 100644
index 0000000..5ff72de
--- /dev/null
+++ b/cortex_DIM/functions/misc.py
@@ -0,0 +1,39 @@
+"""Miscilaneous functions.
+
+"""
+
+import torch
+
+
+def log_sum_exp(x, axis=None):
+ """Log sum exp function
+
+ Args:
+ x: Input.
+ axis: Axis over which to perform sum.
+
+ Returns:
+ torch.Tensor: log sum exp
+
+ """
+ x_max = torch.max(x, axis)[0]
+ y = torch.log((torch.exp(x - x_max)).sum(axis)) + x_max
+ return y
+
+
+def random_permute(X):
+ """Randomly permutes a tensor.
+
+ Args:
+ X: Input tensor.
+
+ Returns:
+ torch.Tensor
+
+ """
+ X = X.transpose(1, 2)
+ b = torch.rand((X.size(0), X.size(1))).cuda()
+ idx = b.sort(0)[1]
+    adx = torch.arange(0, X.size(1)).long()
+ X = X[idx, adx[None, :]].transpose(1, 2)
+ return X
diff --git a/cortex_DIM/nn_modules/__init__.py b/cortex_DIM/nn_modules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cortex_DIM/nn_modules/convnet.py b/cortex_DIM/nn_modules/convnet.py
new file mode 100644
index 0000000..6a5ac50
--- /dev/null
+++ b/cortex_DIM/nn_modules/convnet.py
@@ -0,0 +1,352 @@
+'''Convnet encoder module.
+
+'''
+
+import torch
+import torch.nn as nn
+
+from cortex.built_ins.networks.utils import get_nonlinearity
+
+from cortex_DIM.nn_modules.misc import Fold, Unfold, View
+
+
+def infer_conv_size(w, k, s, p):
+ '''Infers the next size after convolution.
+
+ Args:
+ w: Input size.
+ k: Kernel size.
+ s: Stride.
+ p: Padding.
+
+ Returns:
+ int: Output size.
+
+ '''
+ x = (w - k + 2 * p) // s + 1
+ return x
+
+
+class Convnet(nn.Module):
+ '''Basic convnet convenience class.
+
+ Attributes:
+ conv_layers: nn.Sequential of nn.Conv2d layers with batch norm,
+ dropout, nonlinearity.
+ fc_layers: nn.Sequential of nn.Linear layers with batch norm,
+ dropout, nonlinearity.
+ reshape: Simple reshape layer.
+ conv_shape: Shape of the convolutional output.
+
+ '''
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.create_layers(*args, **kwargs)
+
+ def create_layers(self, shape, conv_args=None, fc_args=None):
+ '''Creates layers
+
+ conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
+ fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)
+
+ Args:
+ shape: Shape of input.
+ conv_args: List of tuple of convolutional arguments.
+ fc_args: List of tuple of fully-connected arguments.
+ '''
+
+ self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)
+
+ dim_x, dim_y, dim_out = self.conv_shape
+ dim_r = dim_x * dim_y * dim_out
+ self.reshape = View(-1, dim_r)
+ self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)
+
+ def create_conv_layers(self, shape, conv_args):
+ '''Creates a set of convolutional layers.
+
+ Args:
+ shape: Input shape.
+ conv_args: List of tuple of convolutional arguments.
+
+ Returns:
+ nn.Sequential: a sequence of convolutional layers.
+
+ '''
+
+ conv_layers = nn.Sequential()
+ conv_args = conv_args or []
+
+ dim_x, dim_y, dim_in = shape
+
+ for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
+ name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
+ conv_block = nn.Sequential()
+
+ if dim_out is not None:
+ conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p, bias=not(batch_norm))
+ nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
+ conv_block.add_module(name + 'conv', conv)
+ dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
+ else:
+ dim_out = dim_in
+
+ if dropout:
+ conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
+ if batch_norm:
+ bn = nn.BatchNorm2d(dim_out)
+ conv_block.add_module(name + 'bn', bn)
+
+ if nonlinearity:
+ nonlinearity = get_nonlinearity(nonlinearity)
+ conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity)
+
+ if pool:
+ (pool_type, kernel, stride) = pool
+ Pool = getattr(nn, pool_type)
+ conv_block.add_module(name + 'pool', Pool(kernel_size=kernel, stride=stride))
+ dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)
+
+ conv_layers.add_module(name, conv_block)
+
+ dim_in = dim_out
+
+ dim_out = dim_in
+
+ return conv_layers, (dim_x, dim_y, dim_out)
+
+ def create_linear_layers(self, dim_in, fc_args):
+ '''
+
+ Args:
+ dim_in: Number of input units.
+ fc_args: List of tuple of fully-connected arguments.
+
+ Returns:
+ nn.Sequential.
+
+ '''
+
+ fc_layers = nn.Sequential()
+ fc_args = fc_args or []
+
+ for i, (dim_out, batch_norm, dropout, nonlinearity) in enumerate(fc_args):
+ name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
+ fc_block = nn.Sequential()
+
+ if dim_out is not None:
+ fc_block.add_module(name + 'fc', nn.Linear(dim_in, dim_out))
+ else:
+ dim_out = dim_in
+
+ if dropout:
+ fc_block.add_module(name + 'do', nn.Dropout(p=dropout))
+ if batch_norm:
+ bn = nn.BatchNorm1d(dim_out)
+ fc_block.add_module(name + 'bn', bn)
+
+ if nonlinearity:
+ nonlinearity = get_nonlinearity(nonlinearity)
+ fc_block.add_module(nonlinearity.__class__.__name__, nonlinearity)
+
+ fc_layers.add_module(name, fc_block)
+
+ dim_in = dim_out
+
+ return fc_layers, dim_in
+
+ def next_size(self, dim_x, dim_y, k, s, p):
+ '''Infers the next size of a convolutional layer.
+
+ Args:
+ dim_x: First dimension.
+ dim_y: Second dimension.
+ k: Kernel size.
+ s: Stride.
+ p: Padding.
+
+ Returns:
+ (int, int): (First output dimension, Second output dimension)
+
+ '''
+ if isinstance(k, int):
+ kx, ky = (k, k)
+ else:
+ kx, ky = k
+
+ if isinstance(s, int):
+ sx, sy = (s, s)
+ else:
+ sx, sy = s
+
+ if isinstance(p, int):
+ px, py = (p, p)
+ else:
+ px, py = p
+ return (infer_conv_size(dim_x, kx, sx, px),
+ infer_conv_size(dim_y, ky, sy, py))
+
+ def forward(self, x: torch.Tensor, return_full_list=False):
+ '''Forward pass
+
+ Args:
+ x: Input.
+ return_full_list: Optional, returns all layer outputs.
+
+ Returns:
+ torch.Tensor or list of torch.Tensor.
+
+ '''
+ if return_full_list:
+ conv_out = []
+ for conv_layer in self.conv_layers:
+ x = conv_layer(x)
+ conv_out.append(x)
+        else:
+            conv_out = self.conv_layers(x)
+            x = conv_out
+
+ x = self.reshape(x)
+
+ if return_full_list:
+ fc_out = []
+ for fc_layer in self.fc_layers:
+ x = fc_layer(x)
+ fc_out.append(x)
+ else:
+ fc_out = self.fc_layers(x)
+
+ return conv_out, fc_out
+
+
+class FoldedConvnet(Convnet):
+ '''Convnet with strided crop input.
+
+ '''
+
+ def create_layers(self, shape, crop_size=8, conv_args=None, fc_args=None):
+ '''Creates layers
+
+ conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
+ fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)
+
+ Args:
+ shape: Shape of input.
+ crop_size: Size of crops
+ conv_args: List of tuple of convolutional arguments.
+ fc_args: List of tuple of fully-connected arguments.
+ '''
+
+ self.crop_size = crop_size
+
+ dim_x, dim_y, dim_in = shape
+ if dim_x != dim_y:
+ raise ValueError('x and y dimensions must be the same to use Folded encoders.')
+
+ self.final_size = 2 * (dim_x // self.crop_size) - 1
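+        # With stride = crop_size // 2, this is the number of crops per side; once the crops
+        # collapse to 1x1, Fold rearranges them back into a final_size x final_size map.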
+
+ self.unfold = Unfold(dim_x, self.crop_size)
+ self.refold = Fold(dim_x, self.crop_size)
+
+ shape = (self.crop_size, self.crop_size, dim_in)
+
+ self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)
+
+ dim_x, dim_y, dim_out = self.conv_shape
+ dim_r = dim_x * dim_y * dim_out
+ self.reshape = View(-1, dim_r)
+ self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)
+
+ def create_conv_layers(self, shape, conv_args):
+ '''Creates a set of convolutional layers.
+
+ Args:
+ shape: Input shape.
+ conv_args: List of tuple of convolutional arguments.
+
+ Returns:
+ nn.Sequential: A sequence of convolutional layers.
+
+ '''
+
+ conv_layers = nn.Sequential()
+ conv_args = conv_args or []
+ dim_x, dim_y, dim_in = shape
+
+ for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
+ name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
+ conv_block = nn.Sequential()
+
+ if dim_out is not None:
+ conv_block.add_module(name + 'conv',
+ nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p, bias=not(batch_norm)))
+ dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
+ else:
+ dim_out = dim_in
+
+ if dropout:
+ conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
+ if batch_norm:
+ conv_block.add_module(name + 'bn', nn.BatchNorm2d(dim_out))
+
+ if nonlinearity:
+ nonlinearity = get_nonlinearity(nonlinearity)
+ conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity)
+
+ if pool:
+ (pool_type, kernel, stride) = pool
+ Pool = getattr(nn, pool_type)
+ conv_block.add_module('pool', Pool(kernel_size=kernel, stride=stride))
+ dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)
+
+ conv_layers.add_module(name, conv_block)
+
+ dim_in = dim_out
+
+ if dim_x != dim_y:
+ raise ValueError('dim_x and dim_y do not match.')
+
+ if dim_x == 1:
+ dim_x = self.final_size
+ dim_y = self.final_size
+
+ dim_out = dim_in
+
+ return conv_layers, (dim_x, dim_y, dim_out)
+
+ def forward(self, x: torch.Tensor, return_full_list=False):
+ '''Forward pass
+
+ Args:
+ x: Input.
+ return_full_list: Optional, returns all layer outputs.
+
+ Returns:
+ torch.Tensor or list of torch.Tensor.
+
+ '''
+
+ x = self.unfold(x)
+
+ conv_out = []
+ for conv_layer in self.conv_layers:
+ x = conv_layer(x)
+ if x.size(2) == 1:
+ x = self.refold(x)
+ conv_out.append(x)
+
+ x = self.reshape(x)
+
+ if return_full_list:
+ fc_out = []
+ for fc_layer in self.fc_layers:
+ x = fc_layer(x)
+ fc_out.append(x)
+ else:
+ fc_out = self.fc_layers(x)
+
+ if not return_full_list:
+ conv_out = conv_out[-1]
+
+ return conv_out, fc_out
\ No newline at end of file
diff --git a/cortex_DIM/nn_modules/encoder.py b/cortex_DIM/nn_modules/encoder.py
new file mode 100644
index 0000000..52db297
--- /dev/null
+++ b/cortex_DIM/nn_modules/encoder.py
@@ -0,0 +1,96 @@
+'''Basic cortex_DIM encoder.
+
+'''
+
+import torch
+
+from cortex_DIM.nn_modules.convnet import Convnet, FoldedConvnet
+#from cortex_DIM.nn_modules import ResNet, FoldedResNet
+
+
+def create_encoder(Module):
+ class Encoder(Module):
+ '''Encoder used for cortex_DIM.
+
+ '''
+
+ def __init__(self, *args, local_idx=None, multi_idx=None, conv_idx=None, fc_idx=None, **kwargs):
+ '''
+
+ Args:
+ args: Arguments for parent class.
+ local_idx: Index in list of convolutional layers for local features.
+ multi_idx: Index in list of convolutional layers for multiple globals.
+ conv_idx: Index in list of convolutional layers for intermediate features.
+ fc_idx: Index in list of fully-connected layers for intermediate features.
+ kwargs: Keyword arguments for the parent class.
+ '''
+
+ super().__init__(*args, **kwargs)
+
+ if local_idx is None:
+ raise ValueError('`local_idx` must be set')
+
+ conv_idx = conv_idx or local_idx
+
+ self.local_idx = local_idx
+ self.multi_idx = multi_idx
+ self.conv_idx = conv_idx
+ self.fc_idx = fc_idx
+
+ def forward(self, x: torch.Tensor):
+ '''
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ local_out, multi_out, hidden_out, global_out
+
+ '''
+
+ outs = super().forward(x, return_full_list=True)
+ if len(outs) == 2:
+ conv_out, fc_out = outs
+ else:
+ conv_before_out, res_out, conv_after_out, fc_out = outs
+ conv_out = conv_before_out + res_out + conv_after_out
+
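+            # local_idx / multi_idx / conv_idx index into this list of convolutional outputs;
+            # see cortex_DIM/configs for the per-architecture values.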
+ local_out = conv_out[self.local_idx]
+
+ if self.multi_idx is not None:
+ multi_out = conv_out[self.multi_idx]
+ else:
+ multi_out = None
+
+ if len(fc_out) > 0:
+ if self.fc_idx is not None:
+ hidden_out = fc_out[self.fc_idx]
+ else:
+ hidden_out = None
+ global_out = fc_out[-1]
+ else:
+ hidden_out = None
+ global_out = None
+
+ conv_out = conv_out[self.conv_idx]
+
+ return local_out, conv_out, multi_out, hidden_out, global_out
+
+ return Encoder
+
+
+class ConvnetEncoder(create_encoder(Convnet)):
+ pass
+
+
+class FoldedConvnetEncoder(create_encoder(FoldedConvnet)):
+ pass
+
+
+#class DIMResnet(create_dim_encoder(ResNet)):
+# pass
+
+
+#class DIMFoldedResnet(create_dim_encoder(FoldedResNet)):
+# pass
diff --git a/cortex_DIM/nn_modules/mi_networks.py b/cortex_DIM/nn_modules/mi_networks.py
new file mode 100644
index 0000000..5f9de51
--- /dev/null
+++ b/cortex_DIM/nn_modules/mi_networks.py
@@ -0,0 +1,106 @@
+"""Module for networks used for computing MI.
+
+"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from cortex_DIM.nn_modules.misc import Permute
+
+
+class MIFCNet(nn.Module):
+ """Simple custom network for computing MI.
+
+ """
+ def __init__(self, n_input, n_units):
+ """
+
+ Args:
+ n_input: Number of input units.
+ n_units: Number of output units.
+ """
+ super().__init__()
+
+ assert(n_units >= n_input)
+
+ self.linear_shortcut = nn.Linear(n_input, n_units)
+ self.block_nonlinear = nn.Sequential(
+ nn.Linear(n_input, n_units),
+ nn.BatchNorm1d(n_units),
+ nn.ReLU(),
+ nn.Linear(n_units, n_units)
+ )
+
+ # initialize the initial projection to a sort of noisy copy
+ eye_mask = np.zeros((n_units, n_input), dtype=np.uint8)
+ for i in range(n_input):
+ eye_mask[i, i] = 1
+
+ self.linear_shortcut.weight.data.uniform_(-0.01, 0.01)
+ self.linear_shortcut.weight.data.masked_fill_(torch.tensor(eye_mask), 1.)
+
+ def forward(self, x):
+ """
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ torch.Tensor: network output.
+
+ """
+ h = self.block_nonlinear(x) + self.linear_shortcut(x)
+ return h
+
+
+class MI1x1ConvNet(nn.Module):
+ """Simple custorm 1x1 convnet.
+
+ """
+ def __init__(self, n_input, n_units):
+ """
+
+ Args:
+ n_input: Number of input units.
+ n_units: Number of output units.
+ """
+
+ super().__init__()
+
+ self.block_nonlinear = nn.Sequential(
+ nn.Conv2d(n_input, n_units, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(n_units),
+ nn.ReLU(),
+ nn.Conv2d(n_units, n_units, kernel_size=1, stride=1, padding=0, bias=True),
+ )
+
+ self.block_ln = nn.Sequential(
+ Permute(0, 2, 3, 1),
+ nn.LayerNorm(n_units),
+ Permute(0, 3, 1, 2)
+ )
+
+ self.linear_shortcut = nn.Conv2d(n_input, n_units, kernel_size=1,
+ stride=1, padding=0, bias=False)
+
+ # initialize shortcut to be like identity (if possible)
+ if n_units >= n_input:
+ eye_mask = np.zeros((n_units, n_input, 1, 1), dtype=np.uint8)
+ for i in range(n_input):
+ eye_mask[i, i, 0, 0] = 1
+ self.linear_shortcut.weight.data.uniform_(-0.01, 0.01)
+ self.linear_shortcut.weight.data.masked_fill_(torch.tensor(eye_mask), 1.)
+
+ def forward(self, x):
+ """
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ torch.Tensor: network output.
+
+ """
+ h = self.block_ln(self.block_nonlinear(x) + self.linear_shortcut(x))
+ return h
\ No newline at end of file
diff --git a/cortex_DIM/nn_modules/misc.py b/cortex_DIM/nn_modules/misc.py
new file mode 100644
index 0000000..9909808
--- /dev/null
+++ b/cortex_DIM/nn_modules/misc.py
@@ -0,0 +1,130 @@
+'''Various miscellaneous modules
+
+'''
+
+import torch
+
+
+class View(torch.nn.Module):
+ """Basic reshape module.
+
+ """
+ def __init__(self, *shape):
+ """
+
+ Args:
+ *shape: Input shape.
+ """
+ super().__init__()
+ self.shape = shape
+
+ def forward(self, input):
+ """Reshapes tensor.
+
+ Args:
+ input: Input tensor.
+
+ Returns:
+ torch.Tensor: Flattened tensor.
+
+ """
+ return input.view(*self.shape)
+
+
+class Unfold(torch.nn.Module):
+ """Module for unfolding tensor.
+
+ Performs strided crops on 2d (image) tensors. Stride is assumed to be half the crop size.
+
+ """
+ def __init__(self, img_size, fold_size):
+ """
+
+ Args:
+ img_size: Input size.
+ fold_size: Crop size.
+ """
+ super().__init__()
+
+ fold_stride = fold_size // 2
+ self.fold_size = fold_size
+ self.fold_stride = fold_stride
+ self.n_locs = 2 * (img_size // fold_size) - 1
+ self.unfold = torch.nn.Unfold((self.fold_size, self.fold_size),
+ stride=(self.fold_stride, self.fold_stride))
+
+ def forward(self, x):
+ """Unfolds tensor.
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ torch.Tensor: Unfolded tensor.
+
+ """
+ N = x.size(0)
+ x = self.unfold(x).reshape(N, -1, self.fold_size, self.fold_size, self.n_locs * self.n_locs)\
+ .permute(0, 4, 1, 2, 3)\
+ .reshape(N * self.n_locs * self.n_locs, -1, self.fold_size, self.fold_size)
+ return x
+
+
+class Fold(torch.nn.Module):
+ """Module (re)folding tensor.
+
+ Undoes the strided crops above. Works only on 1x1.
+
+ """
+ def __init__(self, img_size, fold_size):
+ """
+
+ Args:
+ img_size: Images size.
+ fold_size: Crop size.
+ """
+ super().__init__()
+ self.n_locs = 2 * (img_size // fold_size) - 1
+
+ def forward(self, x):
+ """(Re)folds tensor.
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ torch.Tensor: Refolded tensor.
+
+ """
+ dim_c, dim_x, dim_y = x.size()[1:]
+ x = x.reshape(-1, self.n_locs * self.n_locs, dim_c, dim_x * dim_y)
+ x = x.reshape(-1, self.n_locs * self.n_locs, dim_c, dim_x * dim_y)\
+ .permute(0, 2, 3, 1)\
+ .reshape(-1, dim_c * dim_x * dim_y, self.n_locs, self.n_locs).contiguous()
+ return x
+
+
+class Permute(torch.nn.Module):
+ """Module for permuting axes.
+
+ """
+ def __init__(self, *perm):
+ """
+
+ Args:
+ *perm: Permute axes.
+ """
+ super().__init__()
+ self.perm = perm
+
+ def forward(self, input):
+ """Permutes axes of tensor.
+
+ Args:
+ input: Input tensor.
+
+ Returns:
+ torch.Tensor: permuted tensor.
+
+ """
+ return input.permute(*self.perm)
diff --git a/scripts/deep_infomax.py b/scripts/deep_infomax.py
new file mode 100644
index 0000000..0d3732f
--- /dev/null
+++ b/scripts/deep_infomax.py
@@ -0,0 +1,235 @@
+'''Deep InfoMax (DIM)
+
+'''
+
+import logging
+
+from cortex.main import run
+from cortex.plugins import ModelPlugin
+from cortex.built_ins.models.classifier import SimpleClassifier
+
+from cortex_DIM.nn_modules.mi_networks import MIFCNet, MI1x1ConvNet
+from cortex_DIM.functions.dim_losses import fenchel_dual_loss, multi_donsker_varadhan_loss, nce_loss, \
+ multi_nce_loss, donsker_varadhan_loss, multi_fenchel_dual_loss
+from cortex_DIM.configs.convnets import configs as convnet_configs
+
+
+logger = logging.getLogger('cortex_DIM')
+
+
+class DIM(ModelPlugin):
+ '''Deep InfoMax
+
+ '''
+
+ defaults = dict(
+ data=dict(batch_size=dict(train=64, test=64),
+ inputs=dict(inputs='images'), skip_last_batch=True),
+ train=dict(save_on_lowest='losses.encoder', epochs=1000),
+ model=dict(
+ classifier_c_args=dict(dropout=0.1, dim_h=[200], batch_norm=True),
+ classifier_m_args=dict(dropout=0.1, dim_h=[200], batch_norm=True),
+ classifier_f_args=dict(dropout=0.1, dim_h=[200], batch_norm=True),
+ classifier_g_args=dict(dropout=0.1, dim_h=[200], batch_norm=True)),
+ optimizer=dict(learning_rate=1e-4)
+ )
+
+ def __init__(self, Classifier=SimpleClassifier):
+ super().__init__()
+
+ self.classifier_c = Classifier(
+ nets=dict(classifier='classifier_c'),
+ kwargs=dict(classifier_args='classifier_c_args'))
+ self.classifier_m = Classifier(
+ nets=dict(classifier='classifier_m'),
+ kwargs=dict(classifier_args='classifier_m_args'))
+ self.classifier_f = Classifier(
+ nets=dict(classifier='classifier_f'),
+ kwargs=dict(classifier_args='classifier_f_args'))
+ self.classifier_g = Classifier(
+ nets=dict(classifier='classifier_g'),
+ kwargs=dict(classifier_args='classifier_g_args'))
+
+ def build(self, global_units=64, mi_units=1024, encoder_config='basic32x32',
+ encoder_args={}):
+ '''
+
+ Args:
+ global_units: Number of global units.
+ mi_units: Number of units for MI estimation.
+ encoder_config: Config of encoder. See `cortex_DIM/configs` for details.
+ encoder_args: Additional dictionary to update encoder.
+
+ '''
+
+ # Draw data to help with shape inference.
+ self.data.reset(mode='test', make_pbar=False)
+ self.data.next()
+
+ dim_c, dim_x, dim_y = self.get_dims('images')
+ input_shape = (dim_x, dim_y, dim_c)
+
+ # Create encoder.
+        encoder_args_ = convnet_configs.get(encoder_config, None)
+        if encoder_args_ is None:
+            raise ValueError('encoder_config `{}` not supported'.format(encoder_config))
+        encoder_args_ = dict(**encoder_args_)  # copy so the shared config dict is not mutated
+        encoder_args_.update(**encoder_args)
+        encoder_args = encoder_args_
+
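+        # A final linear layer of size `global_units` provides the global feature vector;
+        # with global_units=0 the fc head is dropped, disabling the global DIM term (local-only DIM).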
+ if global_units > 0:
+ if 'fc_args' in list(encoder_args.keys()):
+ encoder_args['fc_args'].append((global_units, False, False, None))
+ else:
+ encoder_args['fc_args'] = [(global_units, False, False, None)]
+ else:
+ if 'fc_args' in list(encoder_args.keys()):
+ encoder_args.pop('fc_args')
+
+ Encoder = encoder_args.pop('Encoder')
+ self.nets.encoder = Encoder(input_shape, **encoder_args)
+
+ # Create MI nn_modules and classifiers for monitoring.
+ S = self.inputs('images').cpu()
+ L, C, M, F, G = self.nets.encoder(S)
+
+ local_units, locals_x, locals_y = L.size()[1:]
+ self.nets.local_net = MI1x1ConvNet(local_units, mi_units)
+
+ conv_units, conv_x, conv_y = C.size()[1:]
+ self.classifier_c.build(dim_in=conv_units * conv_x * conv_y)
+
+ if M is not None:
+ multi_units, multis_x, multis_y = M.size()[1:]
+ self.nets.multi_net = MI1x1ConvNet(multi_units, mi_units)
+ self.classifier_m.build(dim_in=multi_units * multis_x * multis_y)
+
+ if F is not None:
+ fc_units = F.size(1)
+ self.classifier_f.build(dim_in=fc_units)
+
+ if G is not None:
+ self.nets.global_net = MIFCNet(global_units, mi_units)
+ self.classifier_g.build(dim_in=global_units)
+
+ def routine(self, measure='JSD', mode='fd'):
+ '''
+
+ Args:
+            measure: Type of f-divergence. For use with mode `fd`.
+            mode: Loss mode. Fenchel-dual `fd`, NCE `nce`, or Donsker-Varadhan `dv`.
+
+ '''
+ X, Y = self.inputs('images', 'targets')
+ L, C, M, F, G = self.nets.encoder(X)
+
+ if G is not None:
+ # Add a global-local loss.
+ local_global_loss = self.local_global_loss(L, G, measure, mode)
+ self.losses.global_net = local_global_loss
+ else:
+ local_global_loss = 0.
+
+ if M is not None:
+ # Add a multi-global local loss.
+ local_multi_loss = self.local_multi_loss(L, M, measure, mode)
+ self.losses.multi_net = local_multi_loss
+ else:
+ local_multi_loss = 0.
+
+ self.losses.encoder = local_global_loss + local_multi_loss
+ self.losses.local_net = local_global_loss + local_multi_loss
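+        # The MI objective trains the encoder together with the local/global score networks;
+        # the classifiers below receive detached features, so they only monitor representation quality.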
+
+ # Classifiers
+ units, dim_x, dim_y = C.size()[1:]
+ C = C.view(-1, units * dim_x * dim_y)
+ self.classifier_c.routine(C.detach(), Y)
+
+ if M is not None:
+ units, dim_x, dim_y = M.size()[1:]
+ M = M.view(-1, units * dim_x * dim_y)
+ self.classifier_m.routine(M.detach(), Y)
+
+ if F is not None:
+ self.classifier_f.routine(F.detach(), Y)
+
+ if G is not None:
+ self.classifier_g.routine(G.detach(), Y)
+
+ def local_global_loss(self, l, g, measure, mode):
+ '''
+
+ Args:
+ l: Local feature map.
+ g: Global features.
+            measure: Type of f-divergence. For use with mode `fd`.
+            mode: Loss mode. Fenchel-dual `fd`, NCE `nce`, or Donsker-Varadhan `dv`.
+
+ Returns:
+ torch.Tensor: Loss.
+
+ '''
+ l_enc = self.nets.local_net(l)
+ g_enc = self.nets.global_net(g)
+ N, local_units, dim_x, dim_y = l_enc.size()
+ l_enc = l_enc.view(N, local_units, -1)
+
+ if mode == 'fd':
+ loss = fenchel_dual_loss(l_enc, g_enc, measure=measure)
+ elif mode == 'nce':
+ loss = nce_loss(l_enc, g_enc)
+ elif mode == 'dv':
+ loss = donsker_varadhan_loss(l_enc, g_enc)
+ else:
+ raise NotImplementedError(mode)
+
+ return loss
+
+ def local_multi_loss(self, l, m, measure, mode):
+ '''
+
+ Args:
+ l: Local feature map.
+ m: Multiple globals feature map.
+            measure: Type of f-divergence. For use with mode `fd`.
+            mode: Loss mode. Fenchel-dual `fd`, NCE `nce`, or Donsker-Varadhan `dv`.
+
+ Returns:
+ torch.Tensor: Loss.
+
+ '''
+ l_enc = self.nets.local_net(l)
+ m_enc = self.nets.multi_net(m)
+ N, local_units, dim_x, dim_y = l_enc.size()
+ l_enc = l_enc.view(N, local_units, -1)
+ m_enc = m_enc.view(N, local_units, -1)
+
+ if mode == 'fd':
+ loss = multi_fenchel_dual_loss(l_enc, m_enc, measure=measure)
+ elif mode == 'nce':
+ loss = multi_nce_loss(l_enc, m_enc)
+ elif mode == 'dv':
+ loss = multi_donsker_varadhan_loss(l_enc, m_enc)
+ else:
+ raise NotImplementedError(mode)
+
+ return loss
+
+ def train_step(self):
+ """One step in training.
+
+ """
+ self.data.next()
+ self.routine()
+ self.optimizer_step()
+
+ def eval_step(self):
+ """One step in evaluation.
+
+ """
+ self.data.next()
+ self.routine()
+
+
+if __name__ == '__main__':
+ run(DIM())
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..23d6694
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,11 @@
+from setuptools import setup
+
+setup(name='cortex_DIM',
+ version='0.1',
+ description='The Deep InfoMax package',
+ author='R Devon Hjelm',
+ author_email='erroneus@gmail.com',
+ packages=['cortex_DIM', 'cortex_DIM.configs', 'cortex_DIM.functions', 'cortex_DIM.nn_modules'],
+ install_requires=['cortex==0.12'],
+ dependency_links=['git+https://github.com/rdevon/cortex.git@dev#egg=cortex-0.12'],
+ zip_safe=False)