first major update to DIM
Devon committed Jan 15, 2019
1 parent 7cc5de0 commit 91f46d3
Showing 18 changed files with 1,627 additions and 1 deletion.
14 changes: 14 additions & 0 deletions .idea/deployment.xml

15 changes: 15 additions & 0 deletions .idea/webServers.xml

62 changes: 61 additions & 1 deletion README.md
@@ -1,2 +1,62 @@
# DIM
# Deep InfoMax (DIM)

[UPDATE]: this work has been accepted as an oral presentation at ICLR 2019.
We are gradually updating the repository over the next few weeks to reflect experiments in the camera-ready version.

Learning Deep Representations by Mutual Information Estimation and Maximization
Sample code to do the local-only objective in
https://openreview.net/forum?id=Bklr3j0cKX
https://arxiv.org/abs/1808.06670

### Completed
[Updated 1/15/2019]
* Latest code for dot-product style scoring function for local DIM (single or multiple globals).
* NCE / DV losses.
* Convnet and folded convnet (strided crops) architectures.
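
As a rough illustration of what a dot-product scoring function and the two losses compute, here is a minimal sketch in plain Python. It is not the code from this repository: the function names are hypothetical, and it assumes that row `i` of the local features is paired with row `i` of the global features (all other pairings are treated as negatives). `dv_bound` is the Donsker-Varadhan estimate and `nce_bound` is the noise-contrastive (InfoNCE-style) estimate.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def logsumexp(xs):
    # Numerically stable log(sum(exp(x))).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def score_matrix(local_feats, global_feats):
    """scores[i][j] = <local_i, global_j>; the diagonal holds the positive pairs."""
    return [[dot(l, g) for g in global_feats] for l in local_feats]


def dv_bound(scores):
    """Donsker-Varadhan: E_joint[T] - log E_marginals[exp(T)]. Needs n >= 2."""
    n = len(scores)
    pos = [scores[i][i] for i in range(n)]
    neg = [scores[i][j] for i in range(n) for j in range(n) if i != j]
    return sum(pos) / n - (logsumexp(neg) - math.log(len(neg)))


def nce_bound(scores):
    """InfoNCE-style: mean log-softmax of the positive score, plus log(n)."""
    n = len(scores)
    log_probs = [scores[i][i] - logsumexp(scores[i]) for i in range(n)]
    return sum(log_probs) / n + math.log(n)


# Toy features (made up for illustration only).
local_feats = [[1.0, 0.0], [0.0, 1.0]]
global_feats = [[2.0, 0.0], [0.0, 2.0]]
scores = score_matrix(local_feats, global_feats)
print(dv_bound(scores), nce_bound(scores))
```

Note that the NCE-style estimate can never exceed `log(n)` for `n` samples, whereas the DV estimate is unbounded above.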

### TODO
* Resnet and folded resnet architectures and training classifiers keeping the encoder fixed (evaluation).
* NDM, MINE, SVM, and MS-SSIM evaluation.
* Global DIM and prior matching.
* Coordinate and occlusion tasks.
* Add nearest neighbor analysis.

### Installation / requirements

This is a package, so to install just run:

$ pip install .

This package installs the dev branch of cortex: https://github.com/rdevon/cortex

cortex requires Python 3.5+ (not tested on versions higher than 3.7). Note that cortex is in early beta stages, but it is usable for this demo.

cortex optionally requires visdom: https://github.com/facebookresearch/visdom

You will need to do:

$ cortex setup

See the cortex README for more info or email us (or submit an issue for legitimate bugs).

### Usage

To get the full set of commands, try:

$ python scripts/deep_infomax.py --help

For CIFAR10 on a DCGAN architecture, try:

$ python scripts/deep_infomax.py --d.source CIFAR10 -n DIM_CIFAR10 --d.copy_to_local --t.epochs 1000

You should reach 71-72% classification accuracy or better in the pretraining step alone (this number is reported for monitoring purposes only).
Note that this won't get you all the way to reproducing the results in the paper: for that, the classifier needs to be retrained with the encoder held fixed.
Support for training a classifier with the representations fixed is coming soon.

For STL-10 on folded 64x64 Alexnet (strided crops) with multiple globals and the noise-contrastive estimation type loss, try:

$ python scripts/deep_infomax.py --d.sources STL10 --d.data_args "dict(stl_resize_only=True)" --d.n_workers 32 -n DIM_STL --t.epochs 200 --d.copy_to_local --encoder_config foldmultialex64x64 --mode nce --global_units 0
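
To make "strided crops" concrete, the sketch below cuts a 64x64 image into 16x16 patches (the `crop_size` used by the folded Alexnet configs in this commit). The stride of 8 is an assumption made for illustration, since the README does not state it, and `crop_offsets` / `strided_crops` are hypothetical helpers rather than functions from this package.

```python
def crop_offsets(image_size, crop_size, stride):
    """Top-left offsets of the crops along one axis."""
    return list(range(0, image_size - crop_size + 1, stride))


def strided_crops(image, crop_size, stride):
    """Cut a square image (a list of rows) into crop_size x crop_size patches."""
    offsets = crop_offsets(len(image), crop_size, stride)
    return [[row[x:x + crop_size] for row in image[y:y + crop_size]]
            for y in offsets for x in offsets]


# A dummy 64x64 "image" whose pixels record their own (row, column) position.
image = [[(y, x) for x in range(64)] for y in range(64)]

# stride=8 (half the crop size) is an assumption, not taken from the source.
crops = strided_crops(image, crop_size=16, stride=8)
print(len(crops))  # 49 crops, i.e. a 7x7 grid
```

With these settings each axis has 7 crop positions, so an encoder applied independently to every crop would yield a 7x7 grid of per-crop outputs.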

### Deep Infomax

TODO: visual guide to Deep Infomax.
Empty file added cortex_DIM/__init__.py
Empty file.
Empty file added cortex_DIM/configs/__init__.py
Empty file.
98 changes: 98 additions & 0 deletions cortex_DIM/configs/convnets.py
@@ -0,0 +1,98 @@
'''Basic convnet hyperparameters.
conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)
'''

from cortex_DIM.nn_modules.encoder import ConvnetEncoder, FoldedConvnetEncoder


# Basic DCGAN-like encoders

_basic28x28 = dict(
    Encoder=ConvnetEncoder,
    conv_args=[(64, 5, 2, 2, True, False, 'ReLU', None),
               (128, 5, 2, 2, True, False, 'ReLU', None)],
    fc_args=[(1024, True, False, 'ReLU')],  # 4-tuple, matching the documented fc_args format
    local_idx=1,
    fc_idx=0
)

_basic32x32 = dict(
    Encoder=ConvnetEncoder,
    conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None),
               (128, 4, 2, 1, True, False, 'ReLU', None),
               (256, 4, 2, 1, True, False, 'ReLU', None)],
    fc_args=[(1024, True, False, 'ReLU')],
    local_idx=1,
    conv_idx=2,
    fc_idx=0
)

_basic64x64 = dict(
    Encoder=ConvnetEncoder,
    conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None),
               (128, 4, 2, 1, True, False, 'ReLU', None),
               (256, 4, 2, 1, True, False, 'ReLU', None),
               (512, 4, 2, 1, True, False, 'ReLU', None)],
    fc_args=[(1024, True, False, 'ReLU')],
    local_idx=2,
    conv_idx=3,
    fc_idx=0
)

# Alexnet-like encoders

_alex64x64 = dict(
    Encoder=ConvnetEncoder,
    conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
               (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
               (384, 3, 1, 1, True, False, 'ReLU', None),
               (384, 3, 1, 1, True, False, 'ReLU', None),
               (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))],
    fc_args=[(4096, True, False, 'ReLU'),
             (4096, True, False, 'ReLU')],
    local_idx=2,
    conv_idx=4,
    fc_idx=1
)

_foldalex64x64 = dict(
    Encoder=FoldedConvnetEncoder,
    crop_size=16,
    conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
               (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
               (384, 3, 1, 1, True, False, 'ReLU', None),
               (384, 3, 1, 1, True, False, 'ReLU', None),
               (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))],
    fc_args=[(4096, True, False, 'ReLU'),
             (4096, True, False, 'ReLU')],
    local_idx=4,
    fc_idx=1
)

_foldmultialex64x64 = dict(
    Encoder=FoldedConvnetEncoder,
    crop_size=16,
    conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
               (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
               (384, 3, 1, 1, True, False, 'ReLU', None),
               (384, 3, 1, 1, True, False, 'ReLU', None),
               (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
               (192, 3, 1, 0, True, False, 'ReLU', None),
               (192, 1, 1, 0, True, False, 'ReLU', None)],
    fc_args=[(4096, True, False, 'ReLU')],
    local_idx=4,
    multi_idx=6,
    fc_idx=1
)

configs = dict(
    basic28x28=_basic28x28,
    basic32x32=_basic32x32,
    basic64x64=_basic64x64,
    alex64x64=_alex64x64,
    foldalex64x64=_foldalex64x64,
    foldmultialex64x64=_foldmultialex64x64
)
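
The tuple format described in the module docstring can be made concrete with a small shape trace. The sketch below is not part of the repository: `conv_out` and `trace_sizes` are hypothetical helpers, and it assumes square inputs, the usual convolution output-size formula, and that a pool entry such as `('MaxPool2d', 3, 2)` means kernel size 3, stride 2, and no padding.

```python
def conv_out(size, kernel, stride, pad=0):
    """Output width of a conv or pool layer on a square input."""
    return (size + 2 * pad - kernel) // stride + 1


def trace_sizes(input_size, conv_args):
    """Return the spatial size after each entry of a conv_args list."""
    sizes = []
    size = input_size
    for dim_h, f_size, stride, pad, batch_norm, dropout, nonlin, pool in conv_args:
        size = conv_out(size, f_size, stride, pad)
        if pool is not None:
            # Assumption: pool = (layer_name, kernel_size, stride), no padding.
            _, pool_kernel, pool_stride = pool
            size = conv_out(size, pool_kernel, pool_stride)
        sizes.append(size)
    return sizes


# conv_args copied from the basic32x32 and alex64x64 configs.
basic32x32_conv = [(64, 4, 2, 1, True, False, 'ReLU', None),
                   (128, 4, 2, 1, True, False, 'ReLU', None),
                   (256, 4, 2, 1, True, False, 'ReLU', None)]

alex64x64_conv = [(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
                  (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
                  (384, 3, 1, 1, True, False, 'ReLU', None),
                  (384, 3, 1, 1, True, False, 'ReLU', None),
                  (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))]

print(trace_sizes(32, basic32x32_conv))  # [16, 8, 4]
print(trace_sizes(64, alex64x64_conv))   # [31, 15, 15, 15, 7]
```

A trace like this is a quick way to check that a new config ends at the spatial size the fully connected layers expect before launching a run.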
151 changes: 151 additions & 0 deletions cortex_DIM/configs/resnets.py
@@ -0,0 +1,151 @@
"""Configurations for ResNets
"""

from cortex_DIM.networks.dim_encoders import DIMResnet, DIMFoldedResnet


_resnet19_32x32 = dict(
    Encoder=DIMResnet,
    conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
    res_args=[
        ([(64, 1, 1, 0, True, False, 'ReLU', None),
          (64, 3, 1, 1, True, False, 'ReLU', None),
          (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(64, 1, 1, 0, True, False, 'ReLU', None),
          (64, 3, 1, 1, True, False, 'ReLU', None),
          (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(128, 1, 1, 0, True, False, 'ReLU', None),
          (128, 3, 2, 1, True, False, 'ReLU', None),
          (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(128, 1, 1, 0, True, False, 'ReLU', None),
          (128, 3, 1, 1, True, False, 'ReLU', None),
          (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(256, 1, 1, 0, True, False, 'ReLU', None),
          (256, 3, 2, 1, True, False, 'ReLU', None),
          (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(256, 1, 1, 0, True, False, 'ReLU', None),
          (256, 3, 1, 1, True, False, 'ReLU', None),
          (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1)
    ],
    fc_args=[(1024, True, False, 'ReLU')],
    local_idx=4,
    fc_idx=0
)

_foldresnet19_32x32 = dict(
    Encoder=DIMFoldedResnet,
    crop_size=8,
    conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
    res_args=[
        ([(64, 1, 1, 0, True, False, 'ReLU', None),
          (64, 3, 1, 1, True, False, 'ReLU', None),
          (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(64, 1, 1, 0, True, False, 'ReLU', None),
          (64, 3, 1, 1, True, False, 'ReLU', None),
          (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(128, 1, 1, 0, True, False, 'ReLU', None),
          (128, 3, 2, 1, True, False, 'ReLU', None),
          (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(128, 1, 1, 0, True, False, 'ReLU', None),
          (128, 3, 1, 1, True, False, 'ReLU', None),
          (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(256, 1, 1, 0, True, False, 'ReLU', None),
          (256, 3, 2, 1, True, False, 'ReLU', None),
          (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(256, 1, 1, 0, True, False, 'ReLU', None),
          (256, 3, 1, 1, True, False, 'ReLU', None),
          (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1)
    ],
    fc_args=[(1024, True, False, 'ReLU')],
    local_idx=6,
    fc_idx=0
)

_resnet34_32x32 = dict(
    Encoder=DIMResnet,
    conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
    res_args=[
        ([(64, 1, 1, 0, True, False, 'ReLU', None),
          (64, 3, 1, 1, True, False, 'ReLU', None),
          (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(64, 1, 1, 0, True, False, 'ReLU', None),
          (64, 3, 1, 1, True, False, 'ReLU', None),
          (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         2),
        ([(128, 1, 1, 0, True, False, 'ReLU', None),
          (128, 3, 2, 1, True, False, 'ReLU', None),
          (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(128, 1, 1, 0, True, False, 'ReLU', None),
          (128, 3, 1, 1, True, False, 'ReLU', None),
          (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         5),
        ([(256, 1, 1, 0, True, False, 'ReLU', None),
          (256, 3, 2, 1, True, False, 'ReLU', None),
          (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(256, 1, 1, 0, True, False, 'ReLU', None),
          (256, 3, 1, 1, True, False, 'ReLU', None),
          (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         2)
    ],
    fc_args=[(1024, True, False, 'ReLU')],
    local_idx=2,
    fc_idx=0
)

_foldresnet34_32x32 = dict(
    Encoder=DIMFoldedResnet,
    crop_size=8,
    conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
    res_args=[
        ([(64, 1, 1, 0, True, False, 'ReLU', None),
          (64, 3, 1, 1, True, False, 'ReLU', None),
          (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(64, 1, 1, 0, True, False, 'ReLU', None),
          (64, 3, 1, 1, True, False, 'ReLU', None),
          (64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         2),
        ([(128, 1, 1, 0, True, False, 'ReLU', None),
          (128, 3, 2, 1, True, False, 'ReLU', None),
          (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(128, 1, 1, 0, True, False, 'ReLU', None),
          (128, 3, 1, 1, True, False, 'ReLU', None),
          (128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         5),
        ([(256, 1, 1, 0, True, False, 'ReLU', None),
          (256, 3, 2, 1, True, False, 'ReLU', None),
          (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         1),
        ([(256, 1, 1, 0, True, False, 'ReLU', None),
          (256, 3, 1, 1, True, False, 'ReLU', None),
          (256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
         2)
    ],
    fc_args=[(1024, True, False, 'ReLU')],
    local_idx=12,
    fc_idx=0
)

configs = dict(
    resnet19_32x32=_resnet19_32x32,
    resnet34_32x32=_resnet34_32x32,
    foldresnet19_32x32=_foldresnet19_32x32,
    foldresnet34_32x32=_foldresnet34_32x32
)
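
The `res_args` lists above repeat the same three-layer pattern with different widths and strides. As a sketch of how such a list could be generated and sanity-checked, the helpers below (`bottleneck`, `build_res_args`, `count_convs`, all hypothetical and not part of this package) rebuild the `resnet19_32x32` layout. It assumes that the integer paired with each block is a repeat count, which the source does not state explicitly.

```python
def bottleneck(width, stride):
    """Three conv specs (1x1, 3x3, 1x1) in the conv_args tuple format."""
    return [(width, 1, 1, 0, True, False, 'ReLU', None),
            (width, 3, stride, 1, True, False, 'ReLU', None),
            (width * 4, 1, 1, 0, True, False, 'ReLU', None)]


def build_res_args(plan):
    """plan is a list of (width, stride, repeats) triples."""
    return [(bottleneck(width, stride), repeats)
            for width, stride, repeats in plan]


def count_convs(conv_before_args, res_args):
    # Assumption: the integer in each res_args entry is a repeat count.
    return len(conv_before_args) + sum(
        len(layers) * repeats for layers, repeats in res_args)


conv_before_args = [(64, 3, 2, 1, True, False, 'ReLU', None)]

# (width, stride, repeats) for each block of resnet19_32x32.
plan19 = [(64, 1, 1), (64, 1, 1), (128, 2, 1),
          (128, 1, 1), (256, 2, 1), (256, 1, 1)]
res_args19 = build_res_args(plan19)

print(count_convs(conv_before_args, res_args19))  # 19
```

Under the repeat-count assumption, this layout has 19 convolutional layers, which is consistent with the config's name.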