Skip to content

Commit

Permalink
maml_omniglot model with functorch (#1179)
Browse files Browse the repository at this point in the history
Summary:
This model is very similar to the maml_omniglot model in torchbench, on which it was originally based. It also loads its sample inputs from that model's directory, because both models take the same inputs.

Pull Request resolved: #1179

Test Plan: - python test.py -k TestBenchmark.test_functorch_maml_omniglot_train_cuda

Reviewed By: xuzhao9

Differential Revision: D39671689

Pulled By: zou3519

fbshipit-source-id: 2c1c6b92de5d32c0b37de5d069caa0f7bbb61e7d
  • Loading branch information
zou3519 authored and facebook-github-bot committed Sep 21, 2022
1 parent e64fc73 commit c80d22b
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 2 deletions.
105 changes: 105 additions & 0 deletions torchbenchmark/models/functorch_maml_omniglot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from functorch import make_functional_with_buffers, vmap, grad
import functools
from pathlib import Path
from typing import Tuple

from ...util.model import BenchmarkModel
from torchbenchmark.tasks import OTHER


def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
    """Adapt the parameters on a support set, then score them on a query set.

    Args:
        net: A ``(params, buffers, fnet)`` triple, where ``fnet`` is called
            as ``fnet(params, buffers, x)`` and returns logits.
        n_inner_iter: Number of gradient-descent adaptation steps to take.
        x_spt: Support inputs used for adaptation.
        y_spt: Support labels used for adaptation.
        x_qry: Query inputs used to score the adapted parameters.
        y_qry: Query labels used to score the adapted parameters.

    Returns:
        A ``(qry_loss, qry_acc)`` pair: the cross-entropy of the adapted
        parameters on the query set, and the fraction of query examples
        whose arg-max prediction (over dim 1) equals the label.
    """
    params, buffers, fnet = net

    def support_loss(candidate_params, task_buffers, inputs, targets):
        # Cross-entropy of the functional network under candidate parameters.
        return F.cross_entropy(fnet(candidate_params, task_buffers, inputs), targets)

    # ``grad`` differentiates with respect to the first argument, i.e. the
    # parameters being adapted.
    support_grad = grad(support_loss)

    adapted = params
    for _step in range(n_inner_iter):
        step_grads = support_grad(adapted, buffers, x_spt, y_spt)
        # Plain gradient-descent update with a fixed step size of 1e-1.
        adapted = [weight - g * 1e-1 for weight, g in zip(adapted, step_grads)]

    # The adapted parameters induce a loss and an accuracy on the query set;
    # the caller uses the loss to update the meta-parameters.
    query_logits = fnet(adapted, buffers, x_qry)
    predictions = query_logits.argmax(dim=1)
    num_correct = (predictions == y_qry).sum()

    return F.cross_entropy(query_logits, y_qry), num_correct / x_qry.size(0)


class Model(BenchmarkModel):
    """Benchmark model that meta-trains a small conv net using functorch.

    The network is a stack of three conv / batch-norm / ReLU / max-pool
    stages followed by a flatten and a linear layer with ``n_way`` outputs.
    ``train`` runs one meta-training step by vmapping ``loss_for_task`` over
    the loaded meta-batch; ``eval`` runs a single forward pass.
    """

    task = OTHER.OTHER_TASKS
    # NOTE(review): the sibling maml_omniglot model sets these to its task
    # number (32) -- confirm whether this model should match.
    DEFAULT_TRAIN_BSIZE = 1
    DEFAULT_EVAL_BSIZE = 1
    ALLOW_CUSTOMIZE_BSIZE = False

    def __init__(self, test, device, jit=False, batch_size=None, extra_args=None):
        """Build the network and load the sample meta-batch.

        Args:
            test: Forwarded unchanged to ``BenchmarkModel.__init__``.
            device: Device the network and the inputs are moved to.
            jit: Forwarded unchanged to ``BenchmarkModel.__init__``.
            batch_size: Forwarded unchanged to ``BenchmarkModel.__init__``.
            extra_args: Optional list forwarded to
                ``BenchmarkModel.__init__``; ``None`` means an empty list.
        """
        # A ``None`` default avoids sharing a single mutable list object
        # between every instantiation of the model.
        if extra_args is None:
            extra_args = []
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)

        n_way = 5
        self.model = nn.Sequential(
            *self._conv_block(1),
            *self._conv_block(64),
            *self._conv_block(64),
            nn.Flatten(),
            nn.Linear(64, n_way),
        ).to(device)

        # The sample inputs live next to the sibling ``maml_omniglot`` model.
        # NOTE(review): ``torch.load`` unpickles the file; this is only
        # appropriate because the file ships with the repository.
        batch_path = Path(__file__).parent.parent / 'maml_omniglot' / 'batch.pt'
        raw_inputs = torch.load(batch_path)
        self.meta_inputs = tuple(torch.from_numpy(array).to(self.device) for array in raw_inputs)
        # The first entry along dim 0 of the first meta-input doubles as the
        # example input for plain forward passes.
        self.example_inputs = (self.meta_inputs[0][0],)

    @staticmethod
    def _conv_block(in_channels, inplace_relu=True):
        """Return the modules of one conv / batch-norm / ReLU / pool stage."""
        return (
            nn.Conv2d(in_channels, 64, 3),
            nn.BatchNorm2d(64, affine=True, track_running_stats=False),
            nn.ReLU(inplace=inplace_relu),
            nn.MaxPool2d(2, 2),
        )

    def get_module(self):
        """Return the network together with its example inputs."""
        return self.model, self.example_inputs

    def train(self):
        """Run one meta-training step over the loaded meta-batch."""
        model = self.model
        model.train()
        fnet, params, buffers = make_functional_with_buffers(model)
        net = (params, buffers, fnet)
        meta_opt = optim.Adam(params, lr=1e-3)

        # A batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = self.meta_inputs

        n_inner_iter = 5
        meta_opt.zero_grad()

        # In parallel, train one model per task. There is a support (x, y)
        # for each task and a query (x, y) for each task.
        compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter)
        qry_losses, _qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)

        # The MAML loss is the sum of the per-task query losses.
        qry_losses.sum().backward()

        meta_opt.step()

    def eval(self) -> Tuple[torch.Tensor]:
        """Run one gradient-free forward pass and return it as a 1-tuple."""
        model, (example_input,) = self.get_module()
        model.eval()
        with torch.no_grad():
            out = model(example_input)
        return (out,)
9 changes: 9 additions & 0 deletions torchbenchmark/models/functorch_maml_omniglot/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import subprocess
import sys


def pip_install_requirements():
    """Quietly pip-install the packages listed in ``requirements.txt``.

    pip is run as a module of the current interpreter, and the requirements
    path is relative to the current working directory.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    pip_command = [sys.executable, '-m', 'pip', 'install']
    pip_command += ['-q', '-r', 'requirements.txt']
    subprocess.check_call(pip_command)


# Install the requirements when this file is executed as a script.
if __name__ == '__main__':
    pip_install_requirements()
7 changes: 7 additions & 0 deletions torchbenchmark/models/functorch_maml_omniglot/metadata.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Metadata for the functorch_maml_omniglot benchmark model.
# NOTE(review): the meaning of each flag is defined by the benchmark harness,
# which is not visible here -- confirm before changing any value.
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false
# Presumably the configurations this model does not support (here: jit
# enabled) -- verify against the harness.
not_implemented:
- jit: true
Empty file.
19 changes: 17 additions & 2 deletions torchbenchmark/models/maml_omniglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,23 @@

class Model(BenchmarkModel):
task = OTHER.OTHER_TASKS
DEFAULT_TRAIN_BSIZE = 1
DEFAULT_EVAL_BSIZE = 1
# batch size in the traditional sense doesn't apply to this maml model.
# Instead, there is a task number (32 in this case) and K-shot
# (5 in this case) and these were chosen to be representative of
# the training done in the original MAML paper
# (https://arxiv.org/pdf/1703.03400.pdf)
#
# The goal of MAML is to train a model in a way that if one brings along
# a new task and K data points, then the model generalizes well on the
# test set for that task.
#
# The task number (also known as the meta-batch size) is the number of
# independent tasks the model gets trained on.
# K-shot means that each task only sees K data points.
#
# We've set the following variables to be equal to the task number.
DEFAULT_TRAIN_BSIZE = 32
DEFAULT_EVAL_BSIZE = 32
ALLOW_CUSTOMIZE_BSIZE = False

def __init__(self, test, device, jit, batch_size=None, extra_args=[]):
Expand Down

0 comments on commit c80d22b

Please sign in to comment.