public version
karlstratos committed Jul 15, 2020
1 parent b74edec commit 30c69f8
Showing 16 changed files with 1,655 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
# Data and scratch files
data/
scratch/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
37 changes: 36 additions & 1 deletion README.md
@@ -1 +1,36 @@
# Adversarial Maximal Mutual Information (AMMI)
This is a PyTorch implementation of AMMI [1]. Install dependencies (Python 3) and download data by running
```
pip install -r requirements.txt
./get_data.sh
```
Specifically, the experiments were run with Python `3.8.3` and PyTorch `1.5.0` on NVIDIA Quadro RTX 6000 GPUs (CUDA version `10.2`).
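
For orientation, the loss implemented in `ammi.py` minimizes the conditional entropy of the code while adversarially maximizing an estimate of its entropy. A sketch of the objective as it appears in the code, where `p` is the encoder `pZ_Y`, `q` is the prior `qZ`, and `beta` is `--entropy_weight` (see [1] for the formal treatment):
```
\min_p \; H_p(Z \mid Y) \;-\; \beta \, \widehat{H}(Z),
\qquad
\widehat{H}(Z) \;=\; \min_q \; \mathbb{E}_{Z \sim p}\big[ -\log q(Z) \big]
```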

### Quick start
**Unsupervised document hashing** on Reuters using 16 bits
```bash
python ammi.py reuters16_ammi data/document_hashing/reuters.tfidf.mat --train --raw_prior
```
The output is logged to `reuters16_ammi.log`. You can switch the dataset to do **predictive document hashing** instead, for instance,
```bash
python ammi.py toy data/related_articles/article_pairs_tfidf_small.p --train --raw_prior --num_retrieve 10
```
The VAE and DVQ baselines can be run similarly by replacing `ammi.py` with `vae.py` or `dvq.py`.
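
Running any of these commands without `--train` loads the saved model and reports validation and test performance (see the `__main__` block of `ammi.py`). To extract the binary codes themselves, here is a minimal sketch (hypothetical usage: it assumes the general argparser takes the model path and data path as positionals, as the commands above suggest, and that the checkpoint from the quick-start run exists):
```python
import torch
from ammi import AMMI

# Hypothetical: rebuild hparams with the same positionals used for training.
hparams = AMMI.get_model_specific_argparser().parse_args(
    ['reuters16_ammi', 'data/document_hashing/reuters.tfidf.mat'])
model = AMMI(hparams)
model.load()  # restores the trained weights, as in ammi.py's __main__

with torch.no_grad():
    # LabeledDocuments batches are (tf-idf matrix, label matrix) pairs.
    _, val_loader, _ = model.data.get_loaders(hparams.batch_size, num_workers=0)
    Y, _ = next(iter(val_loader))
    codes = model.encode_discrete(Y)  # {0,1}^{batch x num_features}
print(codes[:2])
```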

### Reproducibility
See [`commands.txt`](commands.txt) for the hyperparameters used in the paper. They were tuned by random search over a hyperparameter grid on validation data, for instance
```bash
python ammi.py tmc64_ammi data/document_hashing/tmc.tfidf.mat --train --num_features 64 --num_runs 100 --cuda
python ammi.py wdw128_ammi data/related_articles/article_pairs_tfidf.p --train --num_features 128 --num_runs 100 --cuda --num_workers 8
```
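
The AMMI-specific search space is defined in `get_hparams_grid` in `ammi.py`. Below is a sketch of random search over that grid (an illustration of the mechanism, not the repository's exact `--num_runs` loop; note that listing `False` three times and `True` once makes `raw_prior` come up true in roughly a quarter of draws):
```python
import random

# Search space copied from AMMI.get_hparams_grid in ammi.py.
grid = {
    'lr_prior': [0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001],
    'entropy_weight': [1, 1.5, 2, 2.5, 3, 3.5],
    'num_steps_prior': [1, 2, 4],
    'dim_hidden': [8, 12, 16, 20, 24, 28],
    'num_layers': [0, 1, 2],
    'raw_prior': [False, False, False, True],
}

def sample_hparams(grid):
    # One run = one uniform draw per hyperparameter.
    return {name: random.choice(values) for name, values in grid.items()}

for run in range(3):
    print(sample_hparams(grid))
```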

### References
[1] [Learning Discrete Structured Representations by Adversarially Maximizing Mutual Information (Stratos and Wiseman, 2020)](https://arxiv.org/abs/2004.03991)
```
@article{stratos2020learning,
title={Learning Discrete Structured Representations by Adversarially Maximizing Mutual Information},
author={Stratos, Karl and Wiseman, Sam},
journal={arXiv preprint arXiv:2004.03991},
year={2020}
}
```
198 changes: 198 additions & 0 deletions ammi.py
@@ -0,0 +1,198 @@
import argparse
import entropy as ent
import torch
import torch.nn as nn

from model import Model
from pytorch_helper import get_init_function, FF


class AMMI(Model):

def __init__(self, hparams):
super().__init__(hparams=hparams)

def define_parameters(self):
self.entropy = EntropyHelper(self.hparams)
self.pZ_Y = Posterior(self.hparams.num_features,
self.hparams.order_posterior,
self.data.vocab_size,
self.hparams.num_layers_posterior,
self.hparams.dim_hidden)

if not self.hparams.brute:
self.qZ = Prior(self.hparams.num_features,
self.hparams.order_prior,
self.hparams.num_layers, # Using general num_layers
self.hparams.dim_hidden, # Using general dim_hidden
self.hparams.raw_prior)

if self.multiview:
self.qZ_X = Posterior(self.hparams.num_features,
self.hparams.order_posterior,
self.data.vocab_size,
self.hparams.num_layers_posterior,
self.hparams.dim_hidden)

self.apply(get_init_function(self.hparams.init))
self.lr_prior = self.hparams.lr if self.hparams.lr_prior < 0 else \
self.hparams.lr_prior

def forward(self, Y, X=None):
P_ = self.pZ_Y(Y)
P = torch.sigmoid(P_)

Q_ = self.qZ_X(X) if self.multiview else P_
hZ_cond = self.entropy.hZ_X(P, Q_)

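        # H(Z) term: computed exactly under --brute, otherwise estimated with
        # the adversarial prior qZ, which takes num_steps_prior inner gradient
        # steps to minimize the cross entropy E_p[-log qZ(Z)], an upper bound
        # on H(Z).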
if self.hparams.brute:
hZ = self.entropy.hZ(P)
else:
optimizer_prior = torch.optim.Adam(self.qZ.parameters(),
lr=self.lr_prior)
for _ in range(self.hparams.num_steps_prior):
optimizer_prior.zero_grad()
hZ = self.entropy.hZ_X(P.detach(), self.qZ())
hZ.backward()
nn.utils.clip_grad_norm_(self.qZ.parameters(),
self.hparams.clip)
optimizer_prior.step()

hZ = self.entropy.hZ_X(P, self.qZ())

loss = hZ_cond - self.hparams.entropy_weight * hZ

return {'loss': loss, 'hZ_cond': hZ_cond, 'hZ': hZ}

def configure_optimizers(self):
params = list(self.pZ_Y.parameters())
if self.multiview:
params += list(self.qZ_X.parameters())
return [torch.optim.Adam(params, lr=self.hparams.lr)]

def configure_gradient_clippers(self):
clippers = [(self.pZ_Y.parameters(), self.hparams.clip)]
if self.multiview:
clippers.append((self.qZ_X.parameters(), self.hparams.clip))
return clippers

def encode_discrete(self, Y):
P = torch.sigmoid(self.pZ_Y(Y))
encodings = self.entropy.viterbi(P)[0]
return encodings # {0,1}^{B x m}

def get_hparams_grid(self):
grid = Model.get_general_hparams_grid()
grid.update({
'lr_prior': [0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001],
'entropy_weight': [1, 1.5, 2, 2.5, 3, 3.5],
'num_steps_prior': [1, 2, 4],
'dim_hidden': [8, 12, 16, 20, 24, 28],
'num_layers': [0, 1, 2],
            'raw_prior': [False, False, False, True],  # True drawn w.p. 1/4
})
return grid

@staticmethod
def get_model_specific_argparser():
parser = Model.get_general_argparser()

parser.add_argument('--order_posterior', type=int, default=0,
help='Markov order of posterior [%(default)d]')
parser.add_argument('--order_prior', type=int, default=3,
help='Markov order of prior [%(default)d]')
parser.add_argument('--num_layers_posterior', type=int, default=0,
help='num layers in posterior [%(default)d]')
parser.add_argument('--num_steps_prior', type=int, default=4,
help='num gradient steps on prior per loss '
'[%(default)d]')
parser.add_argument('--raw_prior', action='store_true',
help='raw logit embeddings for prior encoder?')
        parser.add_argument('--lr_prior', type=float, default=-1,
                            help='initial learning rate for prior (same as lr '
                                 'if -1) [%(default)g]')
parser.add_argument('--brute', action='store_true',
help='brute-force entropy calculation?')
parser.add_argument('--entropy_weight', type=float, default=2,
help='entropy weight in MI [%(default)g]')

return parser


class EntropyHelper(nn.Module):

def __init__(self, hparams):
super().__init__()
self.register_buffer('quads',
ent.precompute_quads(hparams.order_posterior))
assert hparams.order_prior >= hparams.order_posterior
device = torch.device('cuda' if hparams.cuda else 'cpu')
self.buffs = ent.precompute_buffers(hparams.batch_size,
hparams.order_posterior,
hparams.order_prior,
device)
if hparams.brute:
self.register_buffer('I', ent.precompute_I(hparams.num_features,
hparams.order_posterior))

def hZ_X(self, P, Q_):
        if Q_.dim() == 2:  # prior logits (m x 2^(r+1)) are shared across batch
            Q_ = Q_.repeat(P.size(0), 1, 1)
return ent.estimate_hZ_X(P, Q_, quads=self.quads, buffers=self.buffs)

def hZ(self, P):
return ent.estimate_hZ(P, I=self.I.repeat(P.size(0), 1, 1))

def viterbi(self, P):
return ent.compute_viterbi(P, quads=self.quads)


class Posterior(nn.Module):

def __init__(self, num_features, markov_order, dim_input, num_layers,
dim_hidden):
        super().__init__()
self.num_features = num_features

num_logits = num_features * pow(2, markov_order)
self.ff = FF(dim_input, dim_hidden, num_logits, num_layers)

def forward(self, inputs):
logits = self.ff(inputs).view(inputs.size(0), self.num_features, -1)
P_ = torch.cat([-logits, logits], dim=2) # B x m x 2^(o+1)
return P_


class Prior(nn.Module):

def __init__(self, num_features, markov_order, num_layers, dim_hidden,
raw=False):
        super().__init__()
self.raw = raw

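        # raw=True: theta directly stores the 2^order logits per feature;
        # raw=False: a learned dim_hidden embedding per feature is mapped to
        # logits by a feedforward net.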
if raw:
self.theta = nn.Embedding(num_features, pow(2, markov_order))
else:
self.theta = nn.Embedding(num_features, dim_hidden)
self.ff = FF(dim_hidden, dim_hidden, pow(2, markov_order),
num_layers)

def forward(self):
logits = self.theta.weight if self.raw else self.ff(self.theta.weight)
R_ = torch.cat([-logits, logits], dim=1) # m x 2^(r+1)
return R_


if __name__ == '__main__':
argparser = AMMI.get_model_specific_argparser()
hparams = argparser.parse_args()
model = AMMI(hparams)
if hparams.train:
model.run_training_sessions()
else:
model.load()
print('Loaded model with: %s' % model.flag_hparams())

val_perf, test_perf = model.run_test()
print('Val: {:8.2f}'.format(val_perf))
print('Test: {:8.2f}'.format(test_perf))
16 changes: 16 additions & 0 deletions commands.txt
@@ -0,0 +1,16 @@
python ammi.py scratch/tmc16_ammi data/document_hashing/tmc.tfidf.mat --train --num_features 16 --dim_hidden 24 --num_layers 1 --batch_size 64 --lr 0.01 --init 0.5 --clip 1 --seed 4997 --cuda --num_steps_prior 4 --raw_prior --entropy_weight 3 # 70.96
python ammi.py scratch/tmc32_ammi data/document_hashing/tmc.tfidf.mat --train --num_features 32 --dim_hidden 16 --num_layers 0 --batch_size 256 --lr 0.003 --init 0.5 --clip 10 --seed 87751 --cuda --num_steps_prior 1 --entropy_weight 2 # 74.16
python ammi.py scratch/tmc64_ammi data/document_hashing/tmc.tfidf.mat --train --num_features 64 --dim_hidden 20 --num_layers 0 --batch_size 128 --lr 0.03 --init 0.5 --clip 5 --seed 24547 --cuda --num_steps_prior 2 --entropy_weight 3.5 # 75.22
python ammi.py scratch/tmc128_ammi data/document_hashing/tmc.tfidf.mat --train --num_features 128 --dim_hidden 12 --num_layers 1 --batch_size 128 --lr 0.03 --init 0.5 --clip 1 --seed 69075 --cuda --num_steps_prior 1 --entropy_weight 3.5 # 76.27
python ammi.py scratch/ng16_ammi data/document_hashing/ng20.tfidf.mat --train --num_features 16 --dim_hidden 28 --num_layers 1 --batch_size 64 --lr 0.003 --init 0.1 --clip 5 --seed 57000 --cuda --num_steps_prior 4 --raw_prior --entropy_weight 2.5 # 55.18
python ammi.py scratch/ng32_ammi data/document_hashing/ng20.tfidf.mat --train --num_features 32 --dim_hidden 28 --num_layers 2 --batch_size 16 --lr 0.01 --init 0.05 --clip 1 --seed 28082 --cuda --num_steps_prior 4 --raw_prior --entropy_weight 3.5 # 59.56
python ammi.py scratch/ng64_ammi data/document_hashing/ng20.tfidf.mat --train --num_features 64 --dim_hidden 20 --num_layers 0 --batch_size 128 --lr 0.03 --init 0.5 --clip 5 --seed 24547 --cuda --num_steps_prior 2 --entropy_weight 3.5 # 63.98
python ammi.py scratch/ng128_ammi data/document_hashing/ng20.tfidf.mat --train --num_features 128 --dim_hidden 8 --num_layers 0 --batch_size 128 --lr 0.01 --init 0.5 --clip 5 --seed 38414 --cuda --num_steps_prior 4 --entropy_weight 3.5 # 66.18
python ammi.py scratch/reuters16_ammi data/document_hashing/reuters.tfidf.mat --train --num_features 16 --dim_hidden 16 --num_layers 0 --batch_size 16 --lr 0.01 --init 0.1 --clip 10 --seed 9061 --cuda --num_steps_prior 4 --raw_prior --entropy_weight 2 # 81.73
python ammi.py scratch/reuters32_ammi data/document_hashing/reuters.tfidf.mat --train --num_features 32 --dim_hidden 20 --num_layers 1 --batch_size 64 --lr 0.003 --init 0.5 --clip 1 --seed 67425 --cuda --num_steps_prior 1 --entropy_weight 2 # 84.46
python ammi.py scratch/reuters64_ammi data/document_hashing/reuters.tfidf.mat --train --num_features 64 --dim_hidden 28 --num_layers 1 --batch_size 16 --lr 0.003 --init 0.5 --clip 1 --seed 86088 --cuda --num_steps_prior 1 --entropy_weight 3 # 85.06
python ammi.py scratch/reuters128_ammi data/document_hashing/reuters.tfidf.mat --train --num_features 128 --dim_hidden 24 --num_layers 1 --batch_size 16 --lr 0.03 --init 0.1 --clip 5 --seed 36782 --cuda --num_steps_prior 2 --raw_prior --entropy_weight 2.5 # 86.02

python ammi.py scratch/wdw128_ammi data/related_articles/article_pairs_tfidf.p --train --num_features 128 --dim_hidden 24 --num_layers 0 --batch_size 64 --lr 0.001 --init 0 --clip 10 --seed 31552 --cuda --num_steps_prior 2 --raw_prior --lr_prior 0.01 --entropy_weight 3.5 --num_workers 8 # 79.93

python vae.py scratch/wdw128_vae data/related_articles/article_pairs_tfidf.p --train --num_features 128 --dim_hidden 600 --num_layers 0 --batch_size 16 --lr 0.0001 --init 0.01 --clip 1 --seed 22325 --cuda --num_components 80 --beta 2 --num_workers 8 # 76.41
89 changes: 89 additions & 0 deletions data.py
@@ -0,0 +1,89 @@
import numpy as np
import pickle
import scipy.io
import torch

from torch.utils.data import Dataset, DataLoader, TensorDataset


class Data:

def __init__(self, file_path):
self.file_path = file_path
self.load_datasets()

def get_loaders(self, batch_size, num_workers, shuffle_train=False,
get_test=True):
train_loader = DataLoader(self.train_dataset, batch_size=batch_size,
num_workers=num_workers,
shuffle=shuffle_train)
val_loader = DataLoader(self.val_dataset, batch_size=batch_size,
num_workers=num_workers, shuffle=False)
test_loader = DataLoader(self.test_dataset, batch_size=batch_size,
num_workers=num_workers, shuffle=False) \
if get_test else None
return train_loader, val_loader, test_loader

def load_datasets(self):
raise NotImplementedError


class LabeledDocuments(Data):

def __init__(self, file_path):
super().__init__(file_path=file_path)

def load_datasets(self):
dataset = scipy.io.loadmat(self.file_path)

# (num documents) x (vocab size) tensors containing tf-idf values
Y_train = torch.from_numpy(dataset['train'].toarray()).float()
Y_val = torch.from_numpy(dataset['cv'].toarray()).float()
Y_test = torch.from_numpy(dataset['test'].toarray()).float()

# (num documents) x (num labels) tensors containing {0,1}
L_train = torch.from_numpy(dataset['gnd_train']).float()
L_val = torch.from_numpy(dataset['gnd_cv']).float()
L_test = torch.from_numpy(dataset['gnd_test']).float()

self.train_dataset = TensorDataset(Y_train, L_train)
self.val_dataset = TensorDataset(Y_val, L_val)
self.test_dataset = TensorDataset(Y_test, L_test)

self.vocab_size = self.train_dataset[0][0].size(0)
self.num_labels = self.train_dataset[0][1].size(0)


class ArticlePairs(Data):

def __init__(self, file_path):
super().__init__(file_path=file_path)

def load_datasets(self):
        # Each pair is a tuple of two dicts of the form {word index: tf-idf value}
(train_pairs, val_pairs, test_pairs, self.vocab) \
= pickle.load(open(self.file_path, 'rb'))

self.train_dataset = ArticlePairDataset(train_pairs, len(self.vocab))
self.val_dataset = ArticlePairDataset(val_pairs, len(self.vocab))
self.test_dataset = ArticlePairDataset(test_pairs, len(self.vocab))
self.vocab_size = len(self.vocab)


class ArticlePairDataset(Dataset):
def __init__(self, dict_pairs, vocab_size):
self.dict_pairs = dict_pairs
self.vocab_size = vocab_size

def __len__(self):
return len(self.dict_pairs)

def __getitem__(self, index):
article_X, article_Y = self.dict_pairs[index]
return self.vectorize(article_X), self.vectorize(article_Y)

def vectorize(self, article):
u = torch.zeros(self.vocab_size)
for i, weight in article.items():
u[i] = weight
return u
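
# Example of vectorize: the sparse dict {0: 1.5, 3: 0.2} with vocab_size=5
# densifies to tensor([1.5, 0.0, 0.0, 0.2, 0.0]); pairs stay sparse in the
# pickle and are densified per item in __getitem__.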
