forked from pytorch/benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example of maml training on omniglot
This is related to pytorch#328. This PR adds an actually correct implementation of maml to the repo. The previous implementation doesn't actually compute higher order gradients where it is supposed to. I'm not familiar with how torchbench works so please let me know if there are additional files that need to be modified. Test Plan: Ran the following: ``` python test.py -k test_maml_omniglot_example_cpu python test.py -k test_maml_omniglot_eval_cpu python test.py -k test_maml_omniglot_train_cpu ``` Future work: - Delete the maml example that is currently in this repo (or rename it to make it clear that it's doing something different from the paper that it is trying to reproduce).
- Loading branch information
Showing
4 changed files
with
120 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# This file was adapted from | ||
# https://github.com/facebookresearch/higher/blob/master/examples/maml-omniglot.py | ||
# It comes with the following license. | ||
# | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
import torch.optim as optim | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from pathlib import Path | ||
import higher | ||
|
||
from ...util.model import BenchmarkModel | ||
|
||
|
||
class Model(BenchmarkModel): | ||
def __init__(self, device=None, jit=False): | ||
super().__init__() | ||
self.device = device | ||
self.jit = jit | ||
|
||
n_way = 5 | ||
net = nn.Sequential( | ||
nn.Conv2d(1, 64, 3), | ||
nn.BatchNorm2d(64, momentum=1, affine=True), | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(2, 2), | ||
nn.Conv2d(64, 64, 3), | ||
nn.BatchNorm2d(64, momentum=1, affine=True), | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(2, 2), | ||
nn.Conv2d(64, 64, 3), | ||
nn.BatchNorm2d(64, momentum=1, affine=True), | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(2, 2), | ||
nn.Flatten(), | ||
nn.Linear(64, n_way)).to(device) | ||
self.model = net | ||
|
||
root = str(Path(__file__).parent) | ||
self.meta_inputs = torch.load(f'{root}/batch.pt') | ||
self.meta_inputs = tuple([torch.from_numpy(i).to(self.device) for i in self.meta_inputs]) | ||
|
||
self.example_inputs = (self.meta_inputs[0][0],) | ||
|
||
def get_module(self): | ||
if self.jit: | ||
raise NotImplementedError() | ||
|
||
return self.model, self.example_inputs | ||
|
||
def train(self, niter=3): | ||
if self.jit: | ||
raise NotImplementedError() | ||
|
||
net, _ = self.get_module() | ||
x_spt, y_spt, x_qry, y_qry = self.meta_inputs | ||
meta_opt = optim.Adam(net.parameters(), lr=1e-3) | ||
|
||
for _ in range(niter): | ||
task_num, setsz, c_, h, w = x_spt.size() | ||
querysz = x_qry.size(1) | ||
|
||
n_inner_iter = 5 | ||
inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1) | ||
|
||
meta_opt.zero_grad() | ||
for i in range(task_num): | ||
with higher.innerloop_ctx( | ||
net, inner_opt, copy_initial_weights=False | ||
) as (fnet, diffopt): | ||
for _ in range(n_inner_iter): | ||
spt_logits = fnet(x_spt[i]) | ||
spt_loss = F.cross_entropy(spt_logits, y_spt[i]) | ||
diffopt.step(spt_loss) | ||
|
||
qry_logits = fnet(x_qry[i]) | ||
qry_loss = F.cross_entropy(qry_logits, y_qry[i]) | ||
qry_loss.backward() | ||
|
||
meta_opt.step() | ||
|
||
def eval(self, niter=1): | ||
if self.jit: | ||
raise NotImplementedError() | ||
|
||
model, (example_input,) = self.get_module() | ||
for i in range(niter): | ||
model(example_input) | ||
|
||
|
||
if __name__ == "__main__": | ||
m = Model(device="cuda", jit=False) | ||
module, example_inputs = m.get_module() | ||
module(*example_inputs) | ||
m.train(niter=1) | ||
m.eval(niter=1) |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import subprocess | ||
import sys | ||
|
||
|
||
def pip_install_requirements(): | ||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt']) | ||
|
||
if __name__ == '__main__': | ||
pip_install_requirements() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
higher |