Skip to content

Commit

Permalink
Add example of maml training on omniglot
Browse files Browse the repository at this point in the history
This is related to pytorch#328. This PR adds an actually correct
implementation of maml to the repo. The previous implementation doesn't
actually compute higher order gradients where it is supposed to.

I'm not familiar with how torchbench works so please let me know if
there are additional files that need to be modified.

Test Plan:

Ran the following:
```
python test.py -k test_maml_omniglot_example_cpu
python test.py -k test_maml_omniglot_eval_cpu
python test.py -k test_maml_omniglot_train_cpu
```

Future work:
- Delete the maml example that is currently in this repo (or rename it
to make it clear that it's doing something different from the paper that
it is trying to reproduce).
  • Loading branch information
zou3519 committed May 24, 2021
1 parent baa48a1 commit 98898c2
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 0 deletions.
110 changes: 110 additions & 0 deletions torchbenchmark/models/maml_omniglot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# This file was adapted from
# https://github.com/facebookresearch/higher/blob/master/examples/maml-omniglot.py
# It comes with the following license.
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import higher

from ...util.model import BenchmarkModel


class Model(BenchmarkModel):
def __init__(self, device=None, jit=False):
super().__init__()
self.device = device
self.jit = jit

n_way = 5
net = nn.Sequential(
nn.Conv2d(1, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 64, 3),
nn.BatchNorm2d(64, momentum=1, affine=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Flatten(),
nn.Linear(64, n_way)).to(device)
self.model = net

root = str(Path(__file__).parent)
self.meta_inputs = torch.load(f'{root}/batch.pt')
self.meta_inputs = tuple([torch.from_numpy(i).to(self.device) for i in self.meta_inputs])

self.example_inputs = (self.meta_inputs[0][0],)

def get_module(self):
if self.jit:
raise NotImplementedError()

return self.model, self.example_inputs

def train(self, niter=3):
if self.jit:
raise NotImplementedError()

net, _ = self.get_module()
x_spt, y_spt, x_qry, y_qry = self.meta_inputs
meta_opt = optim.Adam(net.parameters(), lr=1e-3)

for _ in range(niter):
task_num, setsz, c_, h, w = x_spt.size()
querysz = x_qry.size(1)

n_inner_iter = 5
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

meta_opt.zero_grad()
for i in range(task_num):
with higher.innerloop_ctx(
net, inner_opt, copy_initial_weights=False
) as (fnet, diffopt):
for _ in range(n_inner_iter):
spt_logits = fnet(x_spt[i])
spt_loss = F.cross_entropy(spt_logits, y_spt[i])
diffopt.step(spt_loss)

qry_logits = fnet(x_qry[i])
qry_loss = F.cross_entropy(qry_logits, y_qry[i])
qry_loss.backward()

meta_opt.step()

def eval(self, niter=1):
if self.jit:
raise NotImplementedError()

model, (example_input,) = self.get_module()
for i in range(niter):
model(example_input)


if __name__ == "__main__":
m = Model(device="cuda", jit=False)
module, example_inputs = m.get_module()
module(*example_inputs)
m.train(niter=1)
m.eval(niter=1)
Binary file added torchbenchmark/models/maml_omniglot/batch.pt
Binary file not shown.
9 changes: 9 additions & 0 deletions torchbenchmark/models/maml_omniglot/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import subprocess
import sys


def pip_install_requirements():
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])

if __name__ == '__main__':
pip_install_requirements()
1 change: 1 addition & 0 deletions torchbenchmark/models/maml_omniglot/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
higher

0 comments on commit 98898c2

Please sign in to comment.