-
Notifications
You must be signed in to change notification settings - Fork 293
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
maml_omniglot model with functorch (#1179)
Summary: Very similar to the maml_omniglot model in torchbench (which is what we originally based this model off of). We also load some sample inputs from there because the model takes the same inputs. Pull Request resolved: #1179 Test Plan: - python test.py -k TestBenchmark.test_functorch_maml_omniglot_train_cuda Reviewed By: xuzhao9 Differential Revision: D39671689 Pulled By: zou3519 fbshipit-source-id: 2c1c6b92de5d32c0b37de5d069caa0f7bbb61e7d
- Loading branch information
1 parent
e64fc73
commit c80d22b
Showing
5 changed files
with
138 additions
and
2 deletions.
There are no files selected for viewing
105 changes: 105 additions & 0 deletions
105
torchbenchmark/models/functorch_maml_omniglot/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import torch | ||
import torch.optim as optim | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from functorch import make_functional_with_buffers, vmap, grad | ||
import functools | ||
from pathlib import Path | ||
from typing import Tuple | ||
|
||
from ...util.model import BenchmarkModel | ||
from torchbenchmark.tasks import OTHER | ||
|
||
|
||
def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry): | ||
params, buffers, fnet = net | ||
querysz = x_qry.size(0) | ||
|
||
def compute_loss(new_params, buffers, x, y): | ||
logits = fnet(new_params, buffers, x) | ||
loss = F.cross_entropy(logits, y) | ||
return loss | ||
|
||
new_params = params | ||
for _ in range(n_inner_iter): | ||
grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt) | ||
new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)] | ||
|
||
# The final set of adapted parameters will induce some | ||
# final loss and accuracy on the query dataset. | ||
# These will be used to update the model's meta-parameters. | ||
qry_logits = fnet(new_params, buffers, x_qry) | ||
qry_loss = F.cross_entropy(qry_logits, y_qry) | ||
qry_acc = (qry_logits.argmax( | ||
dim=1) == y_qry).sum() / querysz | ||
|
||
return qry_loss, qry_acc | ||
|
||
|
||
class Model(BenchmarkModel): | ||
task = OTHER.OTHER_TASKS | ||
DEFAULT_TRAIN_BSIZE = 1 | ||
DEFAULT_EVAL_BSIZE = 1 | ||
ALLOW_CUSTOMIZE_BSIZE = False | ||
|
||
def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]): | ||
super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args) | ||
|
||
n_way = 5 | ||
inplace_relu = True | ||
net = nn.Sequential( | ||
nn.Conv2d(1, 64, 3), | ||
nn.BatchNorm2d(64, affine=True, track_running_stats=False), | ||
nn.ReLU(inplace=inplace_relu), | ||
nn.MaxPool2d(2, 2), | ||
nn.Conv2d(64, 64, 3), | ||
nn.BatchNorm2d(64, affine=True, track_running_stats=False), | ||
nn.ReLU(inplace=inplace_relu), | ||
nn.MaxPool2d(2, 2), | ||
nn.Conv2d(64, 64, 3), | ||
nn.BatchNorm2d(64, affine=True, track_running_stats=False), | ||
nn.ReLU(inplace=inplace_relu), | ||
nn.MaxPool2d(2, 2), | ||
nn.Flatten(), | ||
nn.Linear(64, n_way)).to(device) | ||
|
||
self.model = net | ||
|
||
root = str(Path(__file__).parent.parent) | ||
self.meta_inputs = torch.load(f'{root}/maml_omniglot/batch.pt') | ||
self.meta_inputs = tuple([torch.from_numpy(i).to(self.device) for i in self.meta_inputs]) | ||
self.example_inputs = (self.meta_inputs[0][0],) | ||
|
||
def get_module(self): | ||
return self.model, self.example_inputs | ||
|
||
def train(self): | ||
model = self.model | ||
model.train() | ||
fnet, params, buffers = make_functional_with_buffers(self.model) | ||
net = (params, buffers, fnet) | ||
meta_opt = optim.Adam(params, lr=1e-3) | ||
|
||
# Sample a batch of support and query images and labels. | ||
x_spt, y_spt, x_qry, y_qry = self.meta_inputs | ||
task_num, setsz, c_, h, w = x_spt.size() | ||
|
||
n_inner_iter = 5 | ||
meta_opt.zero_grad() | ||
|
||
# In parallel, trains one model per task. There is a support (x, y) | ||
# for each task and a query (x, y) for each task. | ||
compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter) | ||
qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry) | ||
|
||
# Compute the maml loss by summing together the returned losses. | ||
qry_losses.sum().backward() | ||
|
||
meta_opt.step() | ||
|
||
def eval(self) -> Tuple[torch.Tensor]: | ||
model, (example_input,) = self.get_module() | ||
model.eval() | ||
with torch.no_grad(): | ||
out = model(example_input) | ||
return (out, ) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import subprocess | ||
import sys | ||
|
||
|
||
def pip_install_requirements(): | ||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt']) | ||
|
||
if __name__ == '__main__': | ||
pip_install_requirements() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
eval_benchmark: false | ||
eval_deterministic: false | ||
eval_nograd: true | ||
train_benchmark: false | ||
train_deterministic: false | ||
not_implemented: | ||
- jit: true |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters