
Commit

adding resnets
Devon committed Jan 15, 2019
1 parent 593ae61 commit 76a443c
Showing 6 changed files with 323 additions and 16 deletions.
11 changes: 8 additions & 3 deletions README.md
@@ -1,6 +1,6 @@
# Deep InfoMax (DIM)

[UPDATE]: this work has been accepted as an oral presentation at ICLR 2019.
This work has been accepted as an oral presentation at ICLR 2019.
We are gradually updating the repository over the next few weeks to reflect experiments in the camera-ready version.

Learning Deep Representations by Mutual Information Estimation and Maximization
@@ -13,9 +13,10 @@ https://arxiv.org/abs/1808.06670
* Latest code for dot-product style scoring function for local DIM (single or multiple globals).
* JSD / NCE / DV losses (in addition, f-divergences: KL, reverse KL, squared Hellinger, and chi-squared).
* Convnet and folded convnet (strided crops) architectures.
* Resnet and folded resnet architectures.

### TODO
* Resnet and folded resnet architectures and training classifiers keeping the encoder fixed (evaluation).
* Training classifiers keeping the encoder fixed (evaluation).
* NDM, MINE, SVM, and MS-SSIM evaluation.
* Global DIM and prior matching.
* Coordinate and occlusion tasks.
@@ -52,8 +53,12 @@ For CIFAR10 on a DCGAN architecture, try:
You should get 71-72% accuracy or better in the pretraining step alone (this classifier is included for monitoring purposes only).
Note that this won't get you all the way to reproducing the results in the paper: for that, the classifier needs to be retrained with the encoder held fixed.
Support for training a classifier with the representations fixed is coming soon.

For a folded Resnet (strided crops) with the noise-contrastive estimation (NCE) loss, try:

$ python scripts/deep_infomax.py --d.source CIFAR10 --encoder_config foldresnet19_32x32 --mode nce -n DIM_CIFAR10_FoldedResnet --d.copy_to_local --t.epochs 1000
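
The unfolded Resnet configs added in this commit should work the same way. Assuming the config key drops the leading underscore from its name in cortex_DIM/configs/resnets.py (as foldresnet19_32x32 does above), something along these lines should select it:

$ python scripts/deep_infomax.py --d.source CIFAR10 --encoder_config resnet19_32x32 --mode nce -n DIM_CIFAR10_Resnet --d.copy_to_local --t.epochs 1000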

For STL-10 on folded 64x64 Alexnet (strided crops) with multiple globals and the noise-contrastive estimation type loss, try:
For STL-10 on folded 64x64 Alexnet with multiple globals and the NCE-type loss, try:

$ python scripts/deep_infomax.py --d.sources STL10 --d.data_args "dict(stl_resize_only=True)" --d.n_workers 32 -n DIM_STL --t.epochs 200 --d.copy_to_local --encoder_config foldmultialex64x64 --mode nce --global_units 0

10 changes: 5 additions & 5 deletions cortex_DIM/configs/resnets.py
@@ -2,11 +2,11 @@
"""

from cortex_DIM.networks.dim_encoders import DIMResnet, DIMFoldedResnet
from cortex_DIM.nn_modules.encoder import ResnetEncoder, FoldedResnetEncoder


_resnet19_32x32 = dict(
Encoder=DIMResnet,
Encoder=ResnetEncoder,
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
res_args=[
([(64, 1, 1, 0, True, False, 'ReLU', None),
@@ -40,7 +40,7 @@
)

_foldresnet19_32x32 = dict(
Encoder=DIMFoldedResnet,
Encoder=FoldedResnetEncoder,
crop_size=8,
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
res_args=[
@@ -75,7 +75,7 @@
)

_resnet34_32x32 = dict(
Encoder=DIMResnet,
Encoder=ResnetEncoder,
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
res_args=[
([(64, 1, 1, 0, True, False, 'ReLU', None),
@@ -109,7 +109,7 @@
)

_foldresnet34_32x32 = dict(
Encoder=DIMFoldedResnet,
Encoder=FoldedResnetEncoder,
crop_size=8,
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
res_args=[
10 changes: 5 additions & 5 deletions cortex_DIM/nn_modules/encoder.py
@@ -5,7 +5,7 @@
import torch

from cortex_DIM.nn_modules.convnet import Convnet, FoldedConvnet
#from cortex_DIM.nn_modules import ResNet, FoldedResNet
from cortex_DIM.nn_modules.resnet import ResNet, FoldedResNet


def create_encoder(Module):
@@ -88,9 +88,9 @@ class FoldedConvnetEncoder(create_encoder(FoldedConvnet)):
pass


#class DIMResnet(create_dim_encoder(ResNet)):
# pass
class ResnetEncoder(create_encoder(ResNet)):
pass


#class DIMFoldedResnet(create_dim_encoder(FoldedResNet)):
# pass
class FoldedResnetEncoder(create_encoder(FoldedResNet)):
pass
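
The body of create_encoder is not shown in this diff, but the pattern these classes rely on, generating a named encoder subclass from a backbone network class, can be sketched independently. The following is a minimal, hypothetical illustration of that class-factory pattern (make_encoder is an invented name; the repository's create_encoder presumably adds DIM-specific wiring on top of this):

def make_encoder(backbone_cls):
    '''Returns a named encoder subclass of the given backbone class.

    Hypothetical sketch only; the real factory is cortex_DIM.nn_modules.encoder.create_encoder.
    '''

    class Encoder(backbone_cls):
        pass

    Encoder.__name__ = '{}Encoder'.format(backbone_cls.__name__)
    return Encoder


# Usage mirrors the classes above, e.g.:
#   class ResnetEncoder(make_encoder(ResNet)): pass
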
297 changes: 297 additions & 0 deletions cortex_DIM/nn_modules/resnet.py
@@ -0,0 +1,297 @@
'''Module for making resnet encoders.
'''

import torch
import torch.nn as nn

from cortex_DIM.nn_modules.convnet import Convnet
from cortex_DIM.nn_modules.misc import Fold, Unfold, View


_nonlin_idx = 6


class ResBlock(Convnet):
'''Residual block for ResNet
'''

def create_layers(self, shape, conv_args=None):
'''Creates layers
Args:
shape: Shape of input.
conv_args: Layer arguments for block.
'''

# Move nonlinearity to a separate step for residual.
final_nonlin = conv_args[-1][_nonlin_idx]
conv_args[-1] = list(conv_args[-1])
conv_args[-1][_nonlin_idx] = None
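        # The appended entry has no convolution; it becomes a final layer that only
        # applies the deferred nonlinearity after the residual is added (see forward).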
conv_args.append((None, 0, 0, 0, False, False, final_nonlin, None))

super().create_layers(shape, conv_args=conv_args)

if self.conv_shape != shape:
dim_x, dim_y, dim_in = shape
dim_x_, dim_y_, dim_out = self.conv_shape
stride = dim_x // dim_x_
next_x, _ = self.next_size(dim_x, dim_y, 1, stride, 0)
assert next_x == dim_x_, (self.conv_shape, shape)

self.downsample = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(dim_out),
)
else:
self.downsample = None

def forward(self, x: torch.Tensor):
'''Forward pass
Args:
x: Input.
Returns:
            torch.Tensor: Output of the residual block.
'''

if self.downsample is not None:
residual = self.downsample(x)
else:
residual = x

x = self.conv_layers[-1](self.conv_layers[:-1](x) + residual)

return x


class ResNet(Convnet):
def create_layers(self, shape, conv_before_args=None, res_args=None, conv_after_args=None, fc_args=None):
'''Creates layers
Args:
shape: Shape of the input.
conv_before_args: Arguments for convolutional layers before residuals.
res_args: Residual args.
conv_after_args: Arguments for convolutional layers after residuals.
fc_args: Fully-connected arguments.
'''

dim_x, dim_y, dim_in = shape
shape = (dim_x, dim_y, dim_in)
self.conv_before_layers, self.conv_before_shape = self.create_conv_layers(shape, conv_before_args)
self.res_layers, self.res_shape = self.create_res_layers(self.conv_before_shape, res_args)
self.conv_after_layers, self.conv_after_shape = self.create_conv_layers(self.res_shape, conv_after_args)

dim_x, dim_y, dim_out = self.conv_after_shape
dim_r = dim_x * dim_y * dim_out
self.reshape = View(-1, dim_r)
self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

def create_res_layers(self, shape, block_args=None):
'''Creates a set of residual blocks.
Args:
shape: input shape.
block_args: Arguments for blocks.
Returns:
nn.Sequential: sequence of residual blocks.
'''

res_layers = nn.Sequential()
block_args = block_args or []

for i, (conv_args, n_blocks) in enumerate(block_args):
block = ResBlock(shape, conv_args=conv_args)
res_layers.add_module('block_{}_0'.format(i), block)

for j in range(1, n_blocks):
shape = block.conv_shape
block = ResBlock(shape, conv_args=conv_args)
res_layers.add_module('block_{}_{}'.format(i, j), block)
shape = block.conv_shape

return res_layers, shape

def forward(self, x: torch.Tensor, return_full_list=False):
'''Forward pass
Args:
x: Input.
return_full_list: Optional, returns all layer outputs.
Returns:
            Tuple of (conv_before, res, conv_after, fc) outputs; each element is a
            tensor, or a list of per-layer tensors if return_full_list is True.
'''

if return_full_list:
conv_before_out = []
for conv_layer in self.conv_before_layers:
x = conv_layer(x)
conv_before_out.append(x)
else:
            conv_before_out = self.conv_before_layers(x)
x = conv_before_out

if return_full_list:
res_out = []
for res_layer in self.res_layers:
x = res_layer(x)
res_out.append(x)
else:
res_out = self.res_layers(x)
x = res_out

if return_full_list:
conv_after_out = []
for conv_layer in self.conv_after_layers:
x = conv_layer(x)
conv_after_out.append(x)
else:
conv_after_out = self.conv_after_layers(x)
x = conv_after_out

x = self.reshape(x)

if return_full_list:
fc_out = []
for fc_layer in self.fc_layers:
x = fc_layer(x)
fc_out.append(x)
else:
fc_out = self.fc_layers(x)

return conv_before_out, res_out, conv_after_out, fc_out


class FoldedResNet(ResNet):
'''Resnet with strided crop input.
'''

def create_layers(self, shape, crop_size=8, conv_before_args=None, res_args=None,
conv_after_args=None, fc_args=None):
'''Creates layers
Args:
shape: Shape of the input.
crop_size: Size of the crops.
conv_before_args: Arguments for convolutional layers before residuals.
res_args: Residual args.
conv_after_args: Arguments for convolutional layers after residuals.
fc_args: Fully-connected arguments.
'''
self.crop_size = crop_size

dim_x, dim_y, dim_in = shape
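        # Number of crops per side after refolding: with crops of size crop_size taken
        # at a stride of crop_size // 2 (half-overlapping), this is 2 * (dim_x // crop_size) - 1.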
self.final_size = 2 * (dim_x // self.crop_size) - 1

self.unfold = Unfold(dim_x, self.crop_size)
self.refold = Fold(dim_x, self.crop_size)

shape = (self.crop_size, self.crop_size, dim_in)
self.conv_before_layers, self.conv_before_shape = self.create_conv_layers(shape, conv_before_args)

self.res_layers, self.res_shape = self.create_res_layers(self.conv_before_shape, res_args)
self.conv_after_layers, self.conv_after_shape = self.create_conv_layers(self.res_shape, conv_after_args)
self.conv_after_shape = self.res_shape

dim_x, dim_y, dim_out = self.conv_after_shape
dim_r = dim_x * dim_y * dim_out
self.reshape = View(-1, dim_r)
self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

def create_res_layers(self, shape, block_args=None):
'''Creates a set of residual blocks.
Args:
shape: input shape.
block_args: Arguments for blocks.
Returns:
nn.Sequential: sequence of residual blocks.
'''

res_layers = nn.Sequential()
block_args = block_args or []

for i, (conv_args, n_blocks) in enumerate(block_args):
block = ResBlock(shape, conv_args=conv_args)
res_layers.add_module('block_{}_0'.format(i), block)

for j in range(1, n_blocks):
shape = block.conv_shape
block = ResBlock(shape, conv_args=conv_args)
res_layers.add_module('block_{}_{}'.format(i, j), block)
shape = block.conv_shape
dim_x, dim_y = shape[:2]

if dim_x != dim_y:
raise ValueError('dim_x and dim_y do not match.')

if dim_x == 1:
shape = (self.final_size, self.final_size, shape[2])

return res_layers, shape

def forward(self, x: torch.Tensor, return_full_list=False):
'''Forward pass
Args:
x: Input.
return_full_list: Optional, returns all layer outputs.
Returns:
            Tuple of (conv_before, res, conv_after, fc) outputs; each element is a
            tensor, or a list of per-layer tensors if return_full_list is True.
'''
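        # Split the input into strided (overlapping) crops before the per-crop layers.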
x = self.unfold(x)

conv_before_out = []
for conv_layer in self.conv_before_layers:
x = conv_layer(x)
if x.size(2) == 1:
x = self.refold(x)
conv_before_out.append(x)

res_out = []
for res_layer in self.res_layers:
x = res_layer(x)
res_out.append(x)

if x.size(2) == 1:
x = self.refold(x)
res_out[-1] = x

conv_after_out = []
for conv_layer in self.conv_after_layers:
x = conv_layer(x)
if x.size(2) == 1:
x = self.refold(x)
conv_after_out.append(x)

x = self.reshape(x)

if return_full_list:
fc_out = []
for fc_layer in self.fc_layers:
x = fc_layer(x)
fc_out.append(x)
else:
fc_out = self.fc_layers(x)

if not return_full_list:
conv_before_out = conv_before_out[-1]
res_out = res_out[-1]
conv_after_out = conv_after_out[-1]

return conv_before_out, res_out, conv_after_out, fc_out
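
As a side note for readers who want the core residual-block idea without the Convnet plumbing, below is a minimal, self-contained PyTorch sketch of the same design ResBlock uses: the block's final nonlinearity is deferred until after the shortcut is added back, and a 1x1 convolution plus batch norm projects the shortcut whenever the output shape changes. SimpleResBlock is an illustrative analogue written for this note, not part of the repository.

import torch
import torch.nn as nn


class SimpleResBlock(nn.Module):
    '''Minimal residual block mirroring the deferred-nonlinearity design above.'''

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Main branch: two 3x3 convolutions; the final ReLU is deferred.
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        # Shortcut: identity when shapes match, 1x1 projection + batch norm otherwise.
        if stride != 1 or in_ch != out_ch:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.downsample = None
        self.nonlin = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x if self.downsample is None else self.downsample(x)
        # The nonlinearity is applied only after the shortcut is added back in.
        return self.nonlin(self.body(x) + residual)


# Example: a downsampling block going from 64 to 128 channels at stride 2.
# block = SimpleResBlock(64, 128, stride=2)
# y = block(torch.randn(8, 64, 32, 32))  # -> shape (8, 128, 16, 16)
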
