Commit 30c69f8 (1 parent: b74edec), showing 16 changed files with 1,655 additions and 1 deletion.
.gitignore
@@ -1,3 +1,7 @@
# Data and scratch files
data/
scratch/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
README.md
@@ -1 +1,36 @@
# ammi
# Adversarial Maximal Mutual Information (AMMI)
This is a PyTorch implementation of AMMI [1]. Install dependencies (Python 3) and download data by running
```
pip install -r requirements.txt
./get_data.sh
```
Specifically, the experiments were run with Python `3.8.3` and PyTorch `1.5.3` using NVIDIA Quadro RTX 6000s (CUDA version `10.2`).

### Quick start
**Unsupervised document hashing** on Reuters using 16 bits
```bash
python ammi.py reuters16_ammi data/document_hashing/reuters.tfidf.mat --train --raw_prior
```
Output is logged to `reuters16_ammi.log`. To do **predictive document hashing**, simply switch the dataset, for instance
```bash
python ammi.py toy data/related_articles/article_pairs_tfidf_small.p --train --raw_prior --num_retrieve 10
```
The VAE and DVQ baselines can be run similarly by replacing `ammi.py` with `vae.py` or `dvq.py`.

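For instance, a hypothetical VAE run on the same 16-bit Reuters setup might look like the following (the full flag set accepted by `vae.py` is not shown in this commit; see [`commands.txt`](commands.txt) for the tuned settings actually used):
```bash
python vae.py reuters16_vae data/document_hashing/reuters.tfidf.mat --train --num_features 16
```
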
### Reproducibility
See [`commands.txt`](commands.txt) for the hyperparameters used in the paper. They were tuned by random grid search on validation data, for instance
```bash
python ammi.py tmc64_ammi data/document_hashing/tmc.tfidf.mat --train --num_features 64 --num_runs 100 --cuda
python ammi.py wdw128_ammi data/related_articles/article_pairs_tfidf.p --train --num_features 128 --num_runs 100 --cuda --num_workers 8
```

### References
[1] [Learning Discrete Structured Representations by Adversarially Maximizing Mutual Information (Stratos and Wiseman, 2020)](https://arxiv.org/abs/2004.03991)
```
@article{stratos2020learning,
  title={Learning Discrete Structured Representations by Adversarially Maximizing Mutual Information},
  author={Stratos, Karl and Wiseman, Sam},
  journal={arXiv preprint arXiv:2004.03991},
  year={2020}
}
```
ammi.py
@@ -0,0 +1,198 @@
import argparse
import entropy as ent
import torch
import torch.nn as nn

from model import Model
from pytorch_helper import get_init_function, FF


class AMMI(Model):

    def __init__(self, hparams):
        super().__init__(hparams=hparams)

    def define_parameters(self):
        self.entropy = EntropyHelper(self.hparams)
        self.pZ_Y = Posterior(self.hparams.num_features,
                              self.hparams.order_posterior,
                              self.data.vocab_size,
                              self.hparams.num_layers_posterior,
                              self.hparams.dim_hidden)

        if not self.hparams.brute:
            self.qZ = Prior(self.hparams.num_features,
                            self.hparams.order_prior,
                            self.hparams.num_layers,  # Using general num_layers
                            self.hparams.dim_hidden,  # Using general dim_hidden
                            self.hparams.raw_prior)

        if self.multiview:
            self.qZ_X = Posterior(self.hparams.num_features,
                                  self.hparams.order_posterior,
                                  self.data.vocab_size,
                                  self.hparams.num_layers_posterior,
                                  self.hparams.dim_hidden)

        self.apply(get_init_function(self.hparams.init))
        self.lr_prior = self.hparams.lr if self.hparams.lr_prior < 0 else \
            self.hparams.lr_prior

    def forward(self, Y, X=None):
        # Posterior Bernoulli logits/probabilities p(Z|Y) for the m binary
        # features.
        P_ = self.pZ_Y(Y)
        P = torch.sigmoid(P_)

        # Conditional entropy term H(Z|X): use the second view's posterior in
        # the two-view setting, else the same posterior logits.
        Q_ = self.qZ_X(X) if self.multiview else P_
        hZ_cond = self.entropy.hZ_X(P, Q_)

        if self.hparams.brute:
            hZ = self.entropy.hZ(P)
        else:
            # Adversarial estimate of H(Z): take a few gradient steps fitting
            # the prior qZ to the (detached) posterior, tightening the
            # cross-entropy upper bound on the entropy.
            optimizer_prior = torch.optim.Adam(self.qZ.parameters(),
                                               lr=self.lr_prior)
            for _ in range(self.hparams.num_steps_prior):
                optimizer_prior.zero_grad()
                hZ = self.entropy.hZ_X(P.detach(), self.qZ())
                hZ.backward()
                nn.utils.clip_grad_norm_(self.qZ.parameters(),
                                         self.hparams.clip)
                optimizer_prior.step()

            hZ = self.entropy.hZ_X(P, self.qZ())

        # Minimizing H(Z|X) - weight * H(Z) maximizes a (weighted) estimate of
        # the mutual information between the input and the discrete code Z.
        loss = hZ_cond - self.hparams.entropy_weight * hZ

        return {'loss': loss, 'hZ_cond': hZ_cond, 'hZ': hZ}

    def configure_optimizers(self):
        params = list(self.pZ_Y.parameters())
        if self.multiview:
            params += list(self.qZ_X.parameters())
        return [torch.optim.Adam(params, lr=self.hparams.lr)]

    def configure_gradient_clippers(self):
        clippers = [(self.pZ_Y.parameters(), self.hparams.clip)]
        if self.multiview:
            clippers.append((self.qZ_X.parameters(), self.hparams.clip))
        return clippers

    def encode_discrete(self, Y):
        P = torch.sigmoid(self.pZ_Y(Y))
        encodings = self.entropy.viterbi(P)[0]
        return encodings  # {0,1}^{B x m}

    def get_hparams_grid(self):
        grid = Model.get_general_hparams_grid()
        grid.update({
            'lr_prior': [0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001],
            'entropy_weight': [1, 1.5, 2, 2.5, 3, 3.5],
            'num_steps_prior': [1, 2, 4],
            'dim_hidden': [8, 12, 16, 20, 24, 28],
            'num_layers': [0, 1, 2],
            'raw_prior': [False, False, False, True],
        })
        return grid

    @staticmethod
    def get_model_specific_argparser():
        parser = Model.get_general_argparser()

        parser.add_argument('--order_posterior', type=int, default=0,
                            help='Markov order of posterior [%(default)d]')
        parser.add_argument('--order_prior', type=int, default=3,
                            help='Markov order of prior [%(default)d]')
        parser.add_argument('--num_layers_posterior', type=int, default=0,
                            help='num layers in posterior [%(default)d]')
        parser.add_argument('--num_steps_prior', type=int, default=4,
                            help='num gradient steps on prior per loss '
                                 '[%(default)d]')
        parser.add_argument('--raw_prior', action='store_true',
                            help='raw logit embeddings for prior encoder?')
        parser.add_argument('--lr_prior', type=float, default=-1,
                            help='initial learning rate for prior (same as lr '
                                 ' if -1) [%(default)g]')
        parser.add_argument('--brute', action='store_true',
                            help='brute-force entropy calculation?')
        parser.add_argument('--entropy_weight', type=float, default=2,
                            help='entropy weight in MI [%(default)g]')

        return parser


class EntropyHelper(nn.Module):

    def __init__(self, hparams):
        super().__init__()
        self.register_buffer('quads',
                             ent.precompute_quads(hparams.order_posterior))
        assert hparams.order_prior >= hparams.order_posterior
        device = torch.device('cuda' if hparams.cuda else 'cpu')
        self.buffs = ent.precompute_buffers(hparams.batch_size,
                                            hparams.order_posterior,
                                            hparams.order_prior,
                                            device)
        if hparams.brute:
            self.register_buffer('I', ent.precompute_I(hparams.num_features,
                                                       hparams.order_posterior))

    def hZ_X(self, P, Q_):
        if len(Q_.size()) == 2:
            Q_ = Q_.repeat(P.size(0), 1, 1)
        return ent.estimate_hZ_X(P, Q_, quads=self.quads, buffers=self.buffs)

    def hZ(self, P):
        return ent.estimate_hZ(P, I=self.I.repeat(P.size(0), 1, 1))

    def viterbi(self, P):
        return ent.compute_viterbi(P, quads=self.quads)


class Posterior(nn.Module):

    def __init__(self, num_features, markov_order, dim_input, num_layers,
                 dim_hidden):
        super(Posterior, self).__init__()
        self.num_features = num_features

        num_logits = num_features * pow(2, markov_order)
        self.ff = FF(dim_input, dim_hidden, num_logits, num_layers)

    def forward(self, inputs):
        logits = self.ff(inputs).view(inputs.size(0), self.num_features, -1)
        P_ = torch.cat([-logits, logits], dim=2)  # B x m x 2^(o+1)
        return P_


class Prior(nn.Module):

    def __init__(self, num_features, markov_order, num_layers, dim_hidden,
                 raw=False):
        super(Prior, self).__init__()
        self.raw = raw

        if raw:
            self.theta = nn.Embedding(num_features, pow(2, markov_order))
        else:
            self.theta = nn.Embedding(num_features, dim_hidden)
            self.ff = FF(dim_hidden, dim_hidden, pow(2, markov_order),
                         num_layers)

    def forward(self):
        logits = self.theta.weight if self.raw else self.ff(self.theta.weight)
        R_ = torch.cat([-logits, logits], dim=1)  # m x 2^(r+1)
        return R_

if __name__ == '__main__':
    argparser = AMMI.get_model_specific_argparser()
    hparams = argparser.parse_args()
    model = AMMI(hparams)
    if hparams.train:
        model.run_training_sessions()
    else:
        model.load()
        print('Loaded model with: %s' % model.flag_hparams())

    val_perf, test_perf = model.run_test()
    print('Val: {:8.2f}'.format(val_perf))
    print('Test: {:8.2f}'.format(test_perf))
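As an illustration only (not part of this commit), here is a hedged sketch of how a trained model might be used to produce binary hash codes with `encode_discrete` above. It assumes the two positional arguments of the general argparser are the model path and the data path (as in the README commands), that `model.load()` restores the trained weights, and that the `Model` base class keeps its loaded `Data` object as `model.data`:

```python
# Hypothetical sketch (not part of this commit): encode a batch of
# validation documents into {0,1} hash codes with a trained AMMI model.
from ammi import AMMI

argparser = AMMI.get_model_specific_argparser()
# Assumption: the two positional arguments are model path and data path.
hparams = argparser.parse_args(['reuters16_ammi',
                                'data/document_hashing/reuters.tfidf.mat'])
model = AMMI(hparams)
model.load()  # assumption: restores weights saved under the model path

# Assumption: the Model base class exposes its Data object as model.data.
_, val_loader, _ = model.data.get_loaders(batch_size=64, num_workers=0)
Y, _ = next(iter(val_loader))
codes = model.encode_discrete(Y)  # B x num_features tensor of 0/1 codes
print(codes.shape)
```
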
commands.txt
@@ -0,0 +1,16 @@
python ammi.py scratch/tmc16_ammi data/document_hashing/tmc.tfidf.mat --train --num_features 16 --dim_hidden 24 --num_layers 1 --batch_size 64 --lr 0.01 --init 0.5 --clip 1 --seed 4997 --cuda --num_steps_prior 4 --raw_prior --entropy_weight 3 # 70.96
python ammi.py scratch/tmc32_ammi data/document_hashing/tmc.tfidf.mat --train --num_features 32 --dim_hidden 16 --num_layers 0 --batch_size 256 --lr 0.003 --init 0.5 --clip 10 --seed 87751 --cuda --num_steps_prior 1 --entropy_weight 2 # 74.16
python ammi.py scratch/tmc64_ammi data/document_hashing/tmc.tfidf.mat --train --num_features 64 --dim_hidden 20 --num_layers 0 --batch_size 128 --lr 0.03 --init 0.5 --clip 5 --seed 24547 --cuda --num_steps_prior 2 --entropy_weight 3.5 # 75.22
python ammi.py scratch/tmc128_ammi data/document_hashing/tmc.tfidf.mat --train --num_features 128 --dim_hidden 12 --num_layers 1 --batch_size 128 --lr 0.03 --init 0.5 --clip 1 --seed 69075 --cuda --num_steps_prior 1 --entropy_weight 3.5 # 76.27
python ammi.py scratch/ng16_ammi data/document_hashing/ng20.tfidf.mat --train --num_features 16 --dim_hidden 28 --num_layers 1 --batch_size 64 --lr 0.003 --init 0.1 --clip 5 --seed 57000 --cuda --num_steps_prior 4 --raw_prior --entropy_weight 2.5 # 55.18
python ammi.py scratch/ng32_ammi data/document_hashing/ng20.tfidf.mat --train --num_features 32 --dim_hidden 28 --num_layers 2 --batch_size 16 --lr 0.01 --init 0.05 --clip 1 --seed 28082 --cuda --num_steps_prior 4 --raw_prior --entropy_weight 3.5 # 59.56
python ammi.py scratch/ng64_ammi data/document_hashing/ng20.tfidf.mat --train --num_features 64 --dim_hidden 20 --num_layers 0 --batch_size 128 --lr 0.03 --init 0.5 --clip 5 --seed 24547 --cuda --num_steps_prior 2 --entropy_weight 3.5 # 63.98
python ammi.py scratch/ng128_ammi data/document_hashing/ng20.tfidf.mat --train --num_features 128 --dim_hidden 8 --num_layers 0 --batch_size 128 --lr 0.01 --init 0.5 --clip 5 --seed 38414 --cuda --num_steps_prior 4 --entropy_weight 3.5 # 66.18
python ammi.py scratch/reuters16_ammi data/document_hashing/reuters.tfidf.mat --train --num_features 16 --dim_hidden 16 --num_layers 0 --batch_size 16 --lr 0.01 --init 0.1 --clip 10 --seed 9061 --cuda --num_steps_prior 4 --raw_prior --entropy_weight 2 # 81.73
python ammi.py scratch/reuters32_ammi data/document_hashing/reuters.tfidf.mat --train --num_features 32 --dim_hidden 20 --num_layers 1 --batch_size 64 --lr 0.003 --init 0.5 --clip 1 --seed 67425 --cuda --num_steps_prior 1 --entropy_weight 2 # 84.46
python ammi.py scratch/reuters64_ammi data/document_hashing/reuters.tfidf.mat --train --num_features 64 --dim_hidden 28 --num_layers 1 --batch_size 16 --lr 0.003 --init 0.5 --clip 1 --seed 86088 --cuda --num_steps_prior 1 --entropy_weight 3 # 85.06
python ammi.py scratch/reuters128_ammi data/document_hashing/reuters.tfidf.mat --train --num_features 128 --dim_hidden 24 --num_layers 1 --batch_size 16 --lr 0.03 --init 0.1 --clip 5 --seed 36782 --cuda --num_steps_prior 2 --raw_prior --entropy_weight 2.5 # 86.02

python ammi.py scratch/wdw128_ammi data/related_articles/article_pairs_tfidf.p --train --num_features 128 --dim_hidden 24 --num_layers 0 --batch_size 64 --lr 0.001 --init 0 --clip 10 --seed 31552 --cuda --num_steps_prior 2 --raw_prior --lr_prior 0.01 --entropy_weight 3.5 --num_workers 8 # 79.93

python vae.py scratch/wdw128_vae data/related_articles/article_pairs_tfidf.p --train --num_features 128 --dim_hidden 600 --num_layers 0 --batch_size 16 --lr 0.0001 --init 0.01 --clip 1 --seed 22325 --cuda --num_components 80 --beta 2 --num_workers 8 # 76.41
@@ -0,0 +1,89 @@
import numpy as np
import pickle
import scipy.io
import torch

from torch.utils.data import Dataset, DataLoader, TensorDataset


class Data:

    def __init__(self, file_path):
        self.file_path = file_path
        self.load_datasets()

    def get_loaders(self, batch_size, num_workers, shuffle_train=False,
                    get_test=True):
        train_loader = DataLoader(self.train_dataset, batch_size=batch_size,
                                  num_workers=num_workers,
                                  shuffle=shuffle_train)
        val_loader = DataLoader(self.val_dataset, batch_size=batch_size,
                                num_workers=num_workers, shuffle=False)
        test_loader = DataLoader(self.test_dataset, batch_size=batch_size,
                                 num_workers=num_workers, shuffle=False) \
            if get_test else None
        return train_loader, val_loader, test_loader

    def load_datasets(self):
        raise NotImplementedError


class LabeledDocuments(Data):

    def __init__(self, file_path):
        super().__init__(file_path=file_path)

    def load_datasets(self):
        dataset = scipy.io.loadmat(self.file_path)

        # (num documents) x (vocab size) tensors containing tf-idf values
        Y_train = torch.from_numpy(dataset['train'].toarray()).float()
        Y_val = torch.from_numpy(dataset['cv'].toarray()).float()
        Y_test = torch.from_numpy(dataset['test'].toarray()).float()

        # (num documents) x (num labels) tensors containing {0,1}
        L_train = torch.from_numpy(dataset['gnd_train']).float()
        L_val = torch.from_numpy(dataset['gnd_cv']).float()
        L_test = torch.from_numpy(dataset['gnd_test']).float()

        self.train_dataset = TensorDataset(Y_train, L_train)
        self.val_dataset = TensorDataset(Y_val, L_val)
        self.test_dataset = TensorDataset(Y_test, L_test)

        self.vocab_size = self.train_dataset[0][0].size(0)
        self.num_labels = self.train_dataset[0][1].size(0)

class ArticlePairs(Data):

    def __init__(self, file_path):
        super().__init__(file_path=file_path)

    def load_datasets(self):
        # Each pair consists of two dicts of the form
        # {word index: tf-idf value}.
        (train_pairs, val_pairs, test_pairs, self.vocab) \
            = pickle.load(open(self.file_path, 'rb'))

        self.train_dataset = ArticlePairDataset(train_pairs, len(self.vocab))
        self.val_dataset = ArticlePairDataset(val_pairs, len(self.vocab))
        self.test_dataset = ArticlePairDataset(test_pairs, len(self.vocab))
        self.vocab_size = len(self.vocab)


class ArticlePairDataset(Dataset):
    def __init__(self, dict_pairs, vocab_size):
        self.dict_pairs = dict_pairs
        self.vocab_size = vocab_size

    def __len__(self):
        return len(self.dict_pairs)

    def __getitem__(self, index):
        article_X, article_Y = self.dict_pairs[index]
        return self.vectorize(article_X), self.vectorize(article_Y)

    def vectorize(self, article):
        # Scatter the sparse {word index: tf-idf} dict into a dense vector.
        u = torch.zeros(self.vocab_size)
        for i, weight in article.items():
            u[i] = weight
        return u
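For orientation, a minimal usage sketch of the loaders defined above (not part of this commit). It assumes this module is importable as `data` (the filename is not shown in this excerpt) and that `get_data.sh` has placed the Reuters .mat file at the path used in the README:

```python
# Hypothetical sketch (not part of this commit): load the Reuters tf-idf
# data and inspect one training batch.
from data import LabeledDocuments  # assumption: this module is data.py

docs = LabeledDocuments('data/document_hashing/reuters.tfidf.mat')
train_loader, val_loader, test_loader = docs.get_loaders(
    batch_size=16, num_workers=0, shuffle_train=True)

Y, L = next(iter(train_loader))  # tf-idf vectors and {0,1} label rows
print(Y.shape)  # (16, vocab_size)
print(L.shape)  # (16, num_labels)
```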