Commit
Devon committed on Jan 15, 2019 (1 parent: 7cc5de0, commit 91f46d3)
Showing 18 changed files with 1,627 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,62 @@ | ||
# DIM | ||
# Deep InfoMax (DIM) | ||
|
||
[UPDATE]: this work has been accepted as an oral presentation at ICLR 2019. | ||
We are gradually updating the repository over the next few weeks to reflect experiments in the camera-ready version. | ||
|
||
Learning Deep Representations by Mutual Information Estimation and Maximization | ||
Sample code to do the local-only objective in | ||
https://openreview.net/forum?id=Bklr3j0cKX | ||
https://arxiv.org/abs/1808.06670 | ||
|
||
### Completed | ||
[Updated 1/15/2019] | ||
* Latest code for dot-product style scoring function for local DIM (single or multiple globals). | ||
* NCE / DV losses. | ||
* Convnet and folded convnet (strided crops) architectures. | ||
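As a rough illustration of the dot-product scoring and the NCE / DV losses listed above, here is a minimal pure-Python sketch. It is not the repository's implementation: the function names (`dot_scores`, `nce_loss`, `dv_bound`) are hypothetical, and it assumes one local feature vector and one global vector per image, with the matching pair as the positive and all other pairs in the batch as negatives.

```python
import math


def dot_scores(local_feats, global_feats):
    """scores[i][j] is the dot product of local vector i and global vector j."""
    return [[sum(a * b for a, b in zip(l, g)) for g in global_feats]
            for l in local_feats]


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v))) over a non-empty list."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def nce_loss(scores):
    """Noise-contrastive (softmax) loss: the positive for row i is column i."""
    n = len(scores)
    return sum(log_sum_exp(row) - row[i] for i, row in enumerate(scores)) / n


def dv_bound(scores):
    """Donsker-Varadhan bound: E_pos[s] - log E_neg[exp(s)]. Needs a batch of 2+."""
    n = len(scores)
    positives = [scores[i][i] for i in range(n)]
    negatives = [scores[i][j] for i in range(n) for j in range(n) if i != j]
    log_mean_exp_neg = log_sum_exp(negatives) - math.log(len(negatives))
    return sum(positives) / n - log_mean_exp_neg


# Hypothetical toy batch: two images with 2-d features.
scores = dot_scores([[5.0, 0.0], [0.0, 5.0]], [[5.0, 0.0], [0.0, 5.0]])
print(nce_loss(scores), dv_bound(scores))
```

A lower NCE loss and a higher DV bound both indicate that matching local/global pairs score higher than mismatched ones; in training, the DV bound would be maximized (or its negative minimized).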
|
||
### TODO | ||
* Resnet and folded resnet architectures and training classifiers keeping the encoder fixed (evaluation). | ||
* NDM, MINE, SVM, and MS-SSIM evaluation. | ||
* Global DIM and prior matching. | ||
* Coordinate and occlusion tasks. | ||
* Add nearest neighbor analysis. | ||
|
||
### Installation / requirements | ||
|
||
This is a package, so to install just run: | ||
|
||
$ pip install . | ||
|
||
This package installs the dev branch of cortex: https://github.com/rdevon/cortex | ||
|
||
cortex requires Python 3.5+ (not tested on versions higher than 3.7). Note that cortex is in early beta, but it is usable for this demo. | ||
|
||
cortex optionally requires visdom: https://github.com/facebookresearch/visdom | ||
|
||
You will need to do: | ||
|
||
$ cortex setup | ||
|
||
See the cortex README for more info or email us (or submit an issue for legitimate bugs). | ||
|
||
### Usage | ||
|
||
To get the full set of commands, try: | ||
|
||
$ python scripts/deep_infomax.py --help | ||
|
||
For CIFAR10 on a DCGAN architecture, try: | ||
|
||
$ python scripts/deep_infomax.py --d.source CIFAR10 -n DIM_CIFAR10 --d.copy_to_local --t.epochs 1000 | ||
|
||
You should reach at least 71-72% in the pretraining step alone (this was included for monitoring purposes only). | ||
Note that this won't get you all the way to reproducing the results in the paper: for that, the classifier needs to be retrained with the encoder held fixed. | ||
Support for training a classifier with the representations fixed is coming soon. | ||
|
||
For STL-10 on folded 64x64 Alexnet (strided crops) with multiple globals and the noise-contrastive estimation type loss, try: | ||
|
||
$ python scripts/deep_infomax.py --d.sources STL10 --d.data_args "dict(stl_resize_only=True)" --d.n_workers 32 -n DIM_STL --t.epochs 200 --d.copy_to_local --encoder_config foldmultialex64x64 --mode nce --global_units 0 | ||
|
||
### Deep InfoMax | ||
|
||
TODO: visual guide to Deep InfoMax. |
Empty file.
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
'''Basic convnet hyperparameters. | ||
conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool) | ||
fc_args are in format (dim_h, batch_norm, dropout, nonlinearity) | ||
''' | ||
|
||
from cortex_DIM.nn_modules.encoder import ConvnetEncoder, FoldedConvnetEncoder | ||
|
||
|
||
# Basic DCGAN-like encoders | ||
|
||
_basic28x28 = dict( | ||
Encoder=ConvnetEncoder, | ||
conv_args=[(64, 5, 2, 2, True, False, 'ReLU', None), | ||
(128, 5, 2, 2, True, False, 'ReLU', None)], | ||
fc_args=[(1024, True, False, 'ReLU')], | ||
local_idx=1, | ||
fc_idx=0 | ||
) | ||
|
||
_basic32x32 = dict( | ||
Encoder=ConvnetEncoder, | ||
conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None), | ||
(128, 4, 2, 1, True, False, 'ReLU', None), | ||
(256, 4, 2, 1, True, False, 'ReLU', None)], | ||
fc_args=[(1024, True, False, 'ReLU')], | ||
local_idx=1, | ||
conv_idx=2, | ||
fc_idx=0 | ||
) | ||
|
||
_basic64x64 = dict( | ||
Encoder=ConvnetEncoder, | ||
conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None), | ||
(128, 4, 2, 1, True, False, 'ReLU', None), | ||
(256, 4, 2, 1, True, False, 'ReLU', None), | ||
(512, 4, 2, 1, True, False, 'ReLU', None)], | ||
fc_args=[(1024, True, False, 'ReLU')], | ||
local_idx=2, | ||
conv_idx=3, | ||
fc_idx=0 | ||
) | ||
|
||
# Alexnet-like encoders | ||
|
||
_alex64x64 = dict( | ||
Encoder=ConvnetEncoder, | ||
conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), | ||
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), | ||
(384, 3, 1, 1, True, False, 'ReLU', None), | ||
(384, 3, 1, 1, True, False, 'ReLU', None), | ||
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))], | ||
fc_args=[(4096, True, False, 'ReLU'), | ||
(4096, True, False, 'ReLU')], | ||
local_idx=2, | ||
conv_idx=4, | ||
fc_idx=1 | ||
) | ||
|
||
_foldalex64x64 = dict( | ||
Encoder=FoldedConvnetEncoder, | ||
crop_size=16, | ||
conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), | ||
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), | ||
(384, 3, 1, 1, True, False, 'ReLU', None), | ||
(384, 3, 1, 1, True, False, 'ReLU', None), | ||
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))], | ||
fc_args=[(4096, True, False, 'ReLU'), | ||
(4096, True, False, 'ReLU')], | ||
local_idx=4, | ||
fc_idx=1 | ||
) | ||
|
||
_foldmultialex64x64 = dict( | ||
Encoder=FoldedConvnetEncoder, | ||
crop_size=16, | ||
conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), | ||
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), | ||
(384, 3, 1, 1, True, False, 'ReLU', None), | ||
(384, 3, 1, 1, True, False, 'ReLU', None), | ||
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)), | ||
(192, 3, 1, 0, True, False, 'ReLU', None), | ||
(192, 1, 1, 0, True, False, 'ReLU', None)], | ||
fc_args=[(4096, True, False, 'ReLU')], | ||
local_idx=4, | ||
multi_idx=6, | ||
fc_idx=1 | ||
) | ||
|
||
configs = dict( | ||
basic28x28=_basic28x28, | ||
basic32x32=_basic32x32, | ||
basic64x64=_basic64x64, | ||
alex64x64=_alex64x64, | ||
foldalex64x64=_foldalex64x64, | ||
foldmultialex64x64=_foldmultialex64x64 | ||
) |
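To make the `conv_args` format above easier to read, the following sketch traces the spatial size of the feature map through a list of layer tuples using the standard convolution output-size formula. It is an illustration only, not part of the commit: the helper names `conv_out` and `trace_sizes` are hypothetical, and it assumes that a pool entry such as `('MaxPool2d', 3, 2)` means (layer name, kernel size, stride) with no padding.

```python
def conv_out(size, f_size, stride, pad):
    """Output size of a conv/pool layer: floor((size + 2*pad - f_size) / stride) + 1."""
    return (size + 2 * pad - f_size) // stride + 1


def trace_sizes(input_size, conv_args):
    """Return (dim_h, spatial_size) after each layer in conv_args."""
    sizes = []
    size = input_size
    for (dim_h, f_size, stride, pad,
         batch_norm, dropout, nonlinearity, pool) in conv_args:
        size = conv_out(size, f_size, stride, pad)
        if pool is not None:
            # Assumption: pool is (name, kernel_size, stride), zero padding.
            _, p_size, p_stride = pool
            size = conv_out(size, p_size, p_stride, 0)
        sizes.append((dim_h, size))
    return sizes


# conv_args copied from the basic32x32 and alex64x64 configs above.
basic32 = [(64, 4, 2, 1, True, False, 'ReLU', None),
           (128, 4, 2, 1, True, False, 'ReLU', None),
           (256, 4, 2, 1, True, False, 'ReLU', None)]
alex64 = [(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
          (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
          (384, 3, 1, 1, True, False, 'ReLU', None),
          (384, 3, 1, 1, True, False, 'ReLU', None),
          (192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))]

print(trace_sizes(32, basic32))
print(trace_sizes(64, alex64))
```

Under these assumptions, each stride-2 layer of the DCGAN-like encoder halves the feature map. If `local_idx` indexes into this layer list (also an assumption), the traced sizes show which map resolution the local features come from.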
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
"""Configurations for ResNets | ||
""" | ||
|
||
from cortex_DIM.networks.dim_encoders import DIMResnet, DIMFoldedResnet | ||
|
||
|
||
_resnet19_32x32 = dict( | ||
Encoder=DIMResnet, | ||
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)], | ||
res_args=[ | ||
([(64, 1, 1, 0, True, False, 'ReLU', None), | ||
(64, 3, 1, 1, True, False, 'ReLU', None), | ||
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(64, 1, 1, 0, True, False, 'ReLU', None), | ||
(64, 3, 1, 1, True, False, 'ReLU', None), | ||
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(128, 1, 1, 0, True, False, 'ReLU', None), | ||
(128, 3, 2, 1, True, False, 'ReLU', None), | ||
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(128, 1, 1, 0, True, False, 'ReLU', None), | ||
(128, 3, 1, 1, True, False, 'ReLU', None), | ||
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(256, 1, 1, 0, True, False, 'ReLU', None), | ||
(256, 3, 2, 1, True, False, 'ReLU', None), | ||
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(256, 1, 1, 0, True, False, 'ReLU', None), | ||
(256, 3, 1, 1, True, False, 'ReLU', None), | ||
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1) | ||
], | ||
fc_args=[(1024, True, False, 'ReLU')], | ||
local_idx=4, | ||
fc_idx=0 | ||
) | ||
|
||
_foldresnet19_32x32 = dict( | ||
Encoder=DIMFoldedResnet, | ||
crop_size=8, | ||
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)], | ||
res_args=[ | ||
([(64, 1, 1, 0, True, False, 'ReLU', None), | ||
(64, 3, 1, 1, True, False, 'ReLU', None), | ||
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(64, 1, 1, 0, True, False, 'ReLU', None), | ||
(64, 3, 1, 1, True, False, 'ReLU', None), | ||
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(128, 1, 1, 0, True, False, 'ReLU', None), | ||
(128, 3, 2, 1, True, False, 'ReLU', None), | ||
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(128, 1, 1, 0, True, False, 'ReLU', None), | ||
(128, 3, 1, 1, True, False, 'ReLU', None), | ||
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(256, 1, 1, 0, True, False, 'ReLU', None), | ||
(256, 3, 2, 1, True, False, 'ReLU', None), | ||
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(256, 1, 1, 0, True, False, 'ReLU', None), | ||
(256, 3, 1, 1, True, False, 'ReLU', None), | ||
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1) | ||
], | ||
fc_args=[(1024, True, False, 'ReLU')], | ||
local_idx=6, | ||
fc_idx=0 | ||
) | ||
|
||
_resnet34_32x32 = dict( | ||
Encoder=DIMResnet, | ||
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)], | ||
res_args=[ | ||
([(64, 1, 1, 0, True, False, 'ReLU', None), | ||
(64, 3, 1, 1, True, False, 'ReLU', None), | ||
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(64, 1, 1, 0, True, False, 'ReLU', None), | ||
(64, 3, 1, 1, True, False, 'ReLU', None), | ||
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
2), | ||
([(128, 1, 1, 0, True, False, 'ReLU', None), | ||
(128, 3, 2, 1, True, False, 'ReLU', None), | ||
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(128, 1, 1, 0, True, False, 'ReLU', None), | ||
(128, 3, 1, 1, True, False, 'ReLU', None), | ||
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
5), | ||
([(256, 1, 1, 0, True, False, 'ReLU', None), | ||
(256, 3, 2, 1, True, False, 'ReLU', None), | ||
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(256, 1, 1, 0, True, False, 'ReLU', None), | ||
(256, 3, 1, 1, True, False, 'ReLU', None), | ||
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
2) | ||
], | ||
fc_args=[(1024, True, False, 'ReLU')], | ||
local_idx=2, | ||
fc_idx=0 | ||
) | ||
|
||
_foldresnet34_32x32 = dict( | ||
Encoder=DIMFoldedResnet, | ||
crop_size=8, | ||
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)], | ||
res_args=[ | ||
([(64, 1, 1, 0, True, False, 'ReLU', None), | ||
(64, 3, 1, 1, True, False, 'ReLU', None), | ||
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(64, 1, 1, 0, True, False, 'ReLU', None), | ||
(64, 3, 1, 1, True, False, 'ReLU', None), | ||
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
2), | ||
([(128, 1, 1, 0, True, False, 'ReLU', None), | ||
(128, 3, 2, 1, True, False, 'ReLU', None), | ||
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(128, 1, 1, 0, True, False, 'ReLU', None), | ||
(128, 3, 1, 1, True, False, 'ReLU', None), | ||
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
5), | ||
([(256, 1, 1, 0, True, False, 'ReLU', None), | ||
(256, 3, 2, 1, True, False, 'ReLU', None), | ||
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
1), | ||
([(256, 1, 1, 0, True, False, 'ReLU', None), | ||
(256, 3, 1, 1, True, False, 'ReLU', None), | ||
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)], | ||
2) | ||
], | ||
fc_args=[(1024, True, False, 'ReLU')], | ||
local_idx=12, | ||
fc_idx=0 | ||
) | ||
|
||
configs = dict( | ||
resnet19_32x32=_resnet19_32x32, | ||
resnet34_32x32=_resnet34_32x32, | ||
foldresnet19_32x32=_foldresnet19_32x32, | ||
foldresnet34_32x32=_foldresnet34_32x32 | ||
) |
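Each `res_args` entry above pairs a list of three conv tuples with an integer. The sketch below reads that integer as a repeat count for the block, which is an assumption: the diff does not document it. The helpers `bottleneck`, `expand_res_args`, and `total_stride` are hypothetical names used only for illustration.

```python
def bottleneck(dim, stride):
    """Three conv tuples in the same layout as the res_args entries above."""
    return [(dim, 1, 1, 0, True, False, 'ReLU', None),
            (dim, 3, stride, 1, True, False, 'ReLU', None),
            (dim * 4, 1, 1, 0, True, False, 'ReLU', None)]


def expand_res_args(res_args):
    """Flatten (conv_args, n_repeats) pairs into one block per repeat (assumed meaning)."""
    blocks = []
    for conv_args, n_repeats in res_args:
        blocks.extend([conv_args] * n_repeats)
    return blocks


def total_stride(conv_before_args, blocks):
    """Product of all conv strides: the overall downsampling factor."""
    stride = 1
    for layer in conv_before_args:
        stride *= layer[2]
    for block in blocks:
        for layer in block:
            stride *= layer[2]
    return stride


# Same shape as the resnet34_32x32 config above.
conv_before = [(64, 3, 2, 1, True, False, 'ReLU', None)]
res_args = [(bottleneck(64, 1), 1), (bottleneck(64, 1), 2),
            (bottleneck(128, 2), 1), (bottleneck(128, 1), 5),
            (bottleneck(256, 2), 1), (bottleneck(256, 1), 2)]

blocks = expand_res_args(res_args)
print(len(blocks), sum(len(b) for b in blocks), total_stride(conv_before, blocks))
```

Read this way, the config expands to 12 blocks with an overall stride of 8, so a 32x32 input would be reduced to a 4x4 map before the fully connected layer.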
Empty file.