diff --git a/torchbenchmark/models/functorch_maml_omniglot/__init__.py b/torchbenchmark/models/functorch_maml_omniglot/__init__.py
new file mode 100644
index 0000000000..faf16d73d1
--- /dev/null
+++ b/torchbenchmark/models/functorch_maml_omniglot/__init__.py
@@ -0,0 +1,105 @@
+import torch
+import torch.optim as optim
+import torch.nn as nn
+import torch.nn.functional as F
+from functorch import make_functional_with_buffers, vmap, grad
+import functools
+from pathlib import Path
+from typing import Tuple
+
+from ...util.model import BenchmarkModel
+from torchbenchmark.tasks import OTHER
+
+
+def loss_for_task(net, n_inner_iter, x_spt, y_spt, x_qry, y_qry):
+    params, buffers, fnet = net
+    querysz = x_qry.size(0)
+
+    def compute_loss(new_params, buffers, x, y):
+        logits = fnet(new_params, buffers, x)
+        loss = F.cross_entropy(logits, y)
+        return loss
+
+    new_params = params
+    for _ in range(n_inner_iter):
+        grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
+        new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)]
+
+    # The final set of adapted parameters will induce some
+    # final loss and accuracy on the query dataset.
+    # These will be used to update the model's meta-parameters.
+    qry_logits = fnet(new_params, buffers, x_qry)
+    qry_loss = F.cross_entropy(qry_logits, y_qry)
+    qry_acc = (qry_logits.argmax(
+        dim=1) == y_qry).sum() / querysz
+
+    return qry_loss, qry_acc
+
+
+class Model(BenchmarkModel):
+    task = OTHER.OTHER_TASKS
+    DEFAULT_TRAIN_BSIZE = 1
+    DEFAULT_EVAL_BSIZE = 1
+    ALLOW_CUSTOMIZE_BSIZE = False
+
+    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
+        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
+
+        n_way = 5
+        inplace_relu = True
+        net = nn.Sequential(
+            nn.Conv2d(1, 64, 3),
+            nn.BatchNorm2d(64, affine=True, track_running_stats=False),
+            nn.ReLU(inplace=inplace_relu),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(64, 64, 3),
+            nn.BatchNorm2d(64, affine=True, track_running_stats=False),
+            nn.ReLU(inplace=inplace_relu),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(64, 64, 3),
+            nn.BatchNorm2d(64, affine=True, track_running_stats=False),
+            nn.ReLU(inplace=inplace_relu),
+            nn.MaxPool2d(2, 2),
+            nn.Flatten(),
+            nn.Linear(64, n_way)).to(device)
+
+        self.model = net
+
+        root = str(Path(__file__).parent.parent)
+        self.meta_inputs = torch.load(f'{root}/maml_omniglot/batch.pt')
+        self.meta_inputs = tuple([torch.from_numpy(i).to(self.device) for i in self.meta_inputs])
+        self.example_inputs = (self.meta_inputs[0][0],)
+
+    def get_module(self):
+        return self.model, self.example_inputs
+
+    def train(self):
+        model = self.model
+        model.train()
+        fnet, params, buffers = make_functional_with_buffers(self.model)
+        net = (params, buffers, fnet)
+        meta_opt = optim.Adam(params, lr=1e-3)
+
+        # Sample a batch of support and query images and labels.
+        x_spt, y_spt, x_qry, y_qry = self.meta_inputs
+        task_num, setsz, c_, h, w = x_spt.size()
+
+        n_inner_iter = 5
+        meta_opt.zero_grad()
+
+        # In parallel, train one model per task. There is a support (x, y)
+        # for each task and a query (x, y) for each task.
+        compute_loss_for_task = functools.partial(loss_for_task, net, n_inner_iter)
+        qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry, y_qry)
+
+        # Compute the MAML loss by summing together the returned losses.
+        qry_losses.sum().backward()
+
+        meta_opt.step()
+
+    def eval(self) -> Tuple[torch.Tensor]:
+        model, (example_input,) = self.get_module()
+        model.eval()
+        with torch.no_grad():
+            out = model(example_input)
+        return (out, )
diff --git a/torchbenchmark/models/functorch_maml_omniglot/install.py b/torchbenchmark/models/functorch_maml_omniglot/install.py
new file mode 100644
index 0000000000..be308ead48
--- /dev/null
+++ b/torchbenchmark/models/functorch_maml_omniglot/install.py
@@ -0,0 +1,9 @@
+import subprocess
+import sys
+
+
+def pip_install_requirements():
+    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt'])
+
+if __name__ == '__main__':
+    pip_install_requirements()
diff --git a/torchbenchmark/models/functorch_maml_omniglot/metadata.yaml b/torchbenchmark/models/functorch_maml_omniglot/metadata.yaml
new file mode 100644
index 0000000000..24f57c317b
--- /dev/null
+++ b/torchbenchmark/models/functorch_maml_omniglot/metadata.yaml
@@ -0,0 +1,7 @@
+eval_benchmark: false
+eval_deterministic: false
+eval_nograd: true
+train_benchmark: false
+train_deterministic: false
+not_implemented:
+  - jit: true
\ No newline at end of file
diff --git a/torchbenchmark/models/functorch_maml_omniglot/requirements.txt b/torchbenchmark/models/functorch_maml_omniglot/requirements.txt
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchbenchmark/models/maml_omniglot/__init__.py b/torchbenchmark/models/maml_omniglot/__init__.py
index 92f512423d..5fc21c0e5a 100644
--- a/torchbenchmark/models/maml_omniglot/__init__.py
+++ b/torchbenchmark/models/maml_omniglot/__init__.py
@@ -30,8 +30,23 @@
 
 class Model(BenchmarkModel):
     task = OTHER.OTHER_TASKS
-    DEFAULT_TRAIN_BSIZE = 1
-    DEFAULT_EVAL_BSIZE = 1
+    # Batch size in the traditional sense doesn't apply to this MAML model.
+    # Instead, there is a task number (32 in this case) and a K-shot value
+    # (5 in this case), chosen to be representative of the training done in
+    # the original MAML paper
+    # (https://arxiv.org/pdf/1703.03400.pdf).
+    #
+    # The goal of MAML is to train a model such that, given a new task and
+    # K data points, the model generalizes well on the test set for that
+    # task.
+    #
+    # The task number (also known as the meta-batch size) is the number of
+    # independent tasks the model gets trained on.
+    # K-shot means that each task only sees K data points.
+    #
+    # We've set the following variables equal to the task number.
+    DEFAULT_TRAIN_BSIZE = 32
+    DEFAULT_EVAL_BSIZE = 32
     ALLOW_CUSTOMIZE_BSIZE = False
 
     def __init__(self, test, device, jit, batch_size=None, extra_args=[]):
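
The new benchmark relies on composing functorch's grad (for the inner-loop adaptation in loss_for_task) with vmap (to run all tasks in parallel inside train()). As a reviewer aid, here is a minimal sketch of that pattern using a toy linear model; the names and shapes (4 tasks, 5-way classification, 8 features) are made up for illustration and are not part of the patch.

# Illustrative only -- not part of this patch. A toy linear "net" stands in
# for the conv net in functorch_maml_omniglot/__init__.py.
import torch
import torch.nn.functional as F
from functorch import grad, vmap

def inner_loss(w, x, y):
    # Per-task loss of a linear classifier with weights w.
    return F.cross_entropy(x @ w, y)

def toy_loss_for_task(w, x_spt, y_spt, x_qry, y_qry):
    # One inner-loop SGD step on the support set...
    g = grad(inner_loss)(w, x_spt, y_spt)
    w_adapted = w - 1e-1 * g
    # ...then evaluate the adapted weights on the query set.
    return inner_loss(w_adapted, x_qry, y_qry)

w = torch.randn(8, 5, requires_grad=True)   # shared meta-parameters
x_spt, y_spt = torch.randn(4, 3, 8), torch.randint(5, (4, 3))
x_qry, y_qry = torch.randn(4, 6, 8), torch.randint(5, (4, 6))

# vmap over the leading task dimension; in_dims=None broadcasts w to every
# task (the patch instead closes over the functionalized parameters with
# functools.partial, which is equivalent for this purpose).
qry_losses = vmap(toy_loss_for_task, in_dims=(None, 0, 0, 0, 0))(
    w, x_spt, y_spt, x_qry, y_qry)

# Meta-gradient step, mirroring qry_losses.sum().backward() in train().
qry_losses.sum().backward()
print(qry_losses.shape, w.grad.shape)       # torch.Size([4]) torch.Size([8, 5])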