This is a pedagogical implementation of the MAML (Model-Agnostic Meta-Learning) algorithm.
To run MAML on the Omniglot data, use:
```sh
python \
    --data-folder <path to omniglot data> \
    --num-classes 5 \
    --num-meta-test-classes 5 \
    --num-samples-per-class 1 \
    --num-meta-test-samples-per-class 1 \
    --batch-size 32 \
    --inner-update-lr 0.4 \
    --meta-lr 0.001 \
    --num-meta-train-iterations 2000 \
    --num-meta-test-iterations 200 \
    --num-meta-validation-iterations 4 \
    --num-inner-updates 1 \
    --meta-test-num-inner-updates 1 \
    --dim-hidden 16 \
    --validation-frequency 100 \
    --device cpu
```
The plot below compares validation accuracy with and without MAML on 1-shot, 5-way classification.
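In N-way, K-shot classification (here 5-way, 1-shot), each task draws N classes and K labeled examples per class for the inner update, plus held-out query examples for the meta-update. The `--num-classes` and `--num-samples-per-class` flags above presumably set N and K. The sketch below is plain Python and assumes a hypothetical `dataset` dict mapping each class (an Omniglot character) to a list of its examples; `sample_task` and the `num_queries` split are illustrative, not this repository's data loader.

```python
import random


def sample_task(dataset, num_classes, num_shots, num_queries, rng=random):
    """Sample one N-way, K-shot task (hypothetical helper, not the repo's loader).

    dataset: dict mapping a class name to a list of examples of that class.
    Returns (support, query), each a list of (example, label) pairs whose
    labels are re-indexed 0..num_classes-1 for this task only.
    """
    classes = rng.sample(sorted(dataset), num_classes)
    support, query = [], []
    for label, cls in enumerate(classes):
        # Draw disjoint support and query examples for this class
        examples = rng.sample(dataset[cls], num_shots + num_queries)
        support += [(x, label) for x in examples[:num_shots]]
        query += [(x, label) for x in examples[num_shots:]]
    return support, query


# Toy "dataset": 10 characters with 20 drawings each (strings stand in for images)
toy = {f"char{c}": [f"char{c}_img{i}" for i in range(20)] for c in range(10)}
support, query = sample_task(toy, num_classes=5, num_shots=1, num_queries=3)
print(len(support), len(query))  # 5 15
```

Labels are re-indexed within each task, so the same character can get a different label in different tasks; the model cannot memorize a fixed class mapping and has to adapt from the support set.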
The core idea is to have PyTorch modules that accept their parameters as an argument to the `forward` function. This code is inspired by
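```python
from collections import OrderedDict
import torch.nn as nn
import torch.nn.functional as F

class MetaLinear(nn.Linear):
    def forward(self, input, params=None):
        if params is None:
            # No explicit params: use this layer's own (meta) parameters
            params = OrderedDict(self.named_parameters())
            weight = params.get("weight", None)
            bias = params.get("bias", None)
        else:
            # Model-wide params keyed by full name; `module_prefix` (this layer's
            # name) is not an nn.Linear attribute and must be set on the layer.
            weight = params.get(self.module_prefix + ".weight", None)
            bias = params.get(self.module_prefix + ".bias", None)
        return F.linear(input, weight, bias)
```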
These modules make it easy to keep track of the model parameters and the meta-parameters separately.
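To illustrate why a params-aware `forward` is handy, here is a dependency-free sketch using plain Python scalars instead of PyTorch tensors. `Linear1D` and `TinyModel` are hypothetical names: the model passes one name-keyed parameter dict down to every layer, and each layer picks out its own entries through a `module_prefix`, mirroring the lookup in `MetaLinear`. Evaluating with an adapted copy leaves the meta-parameters untouched, which is what lets MAML keep both sets around at once.

```python
class Linear1D:
    """Hypothetical scalar stand-in for MetaLinear: f(x) = weight * x + bias."""

    def __init__(self, name, weight, bias):
        self.module_prefix = name  # this layer's key prefix in a model-wide dict
        self.weight, self.bias = weight, bias

    def forward(self, x, params=None):
        if params is None:
            # No explicit params: use the layer's own (meta) parameters
            return self.weight * x + self.bias
        # Otherwise read this layer's entries from the model-wide dict
        w = params[self.module_prefix + ".weight"]
        b = params[self.module_prefix + ".bias"]
        return w * x + b


class TinyModel:
    """Two stacked layers; params are keyed "<layer name>.<param name>"."""

    def __init__(self):
        self.layers = [Linear1D("fc1", 2.0, 1.0), Linear1D("fc2", -1.0, 0.5)]

    def named_parameters(self):
        # Returns a fresh dict, so editing it never mutates the layers
        return {
            f"{layer.module_prefix}.{name}": getattr(layer, name)
            for layer in self.layers
            for name in ("weight", "bias")
        }

    def __call__(self, x, params=None):
        # The same params dict (or None) is passed to every layer
        for layer in self.layers:
            x = layer.forward(x, params)
        return x


model = TinyModel()
adapted = model.named_parameters()  # a copy, e.g. the input to an inner update
adapted["fc2.weight"] = -2.0        # pretend an inner step changed one entry
print(model(1.0))                   # -2.5: meta-parameters
print(model(1.0, adapted))          # -5.5: adapted parameters
print(model.named_parameters()["fc2.weight"])  # -1.0: meta-parameters untouched
```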
```python
def apply_inner_update(
    model, loss_fn, inner_update_lr, num_inner_updates, inps, labels
):
    """Perform the inner gradient update `num_inner_updates` times.

    The whole batch is composed of a single task.

    Args:
        model: Model
        loss_fn: Loss function
        inner_update_lr: Learning rate for the inner update
        num_inner_updates: Number of inner gradient-descent steps to take
        inps: Inputs for the inner gradient loop; the inner-update loss is
            computed on this data. Size = [batch_size, 1, 28, 28], where
            28 x 28 is the width and height of an Omniglot image.
        labels: Labels corresponding to `inps`. Size = [batch_size]

    Returns:
        params: Updated params of the model after `num_inner_updates` steps
    """
    # Start from the initial parameters of the model
    params = None
    for step in range(num_inner_updates):
        # One inner gradient step; each step starts from the latest params
        logits = model(inps, params)
        loss = loss_fn(logits, labels)
        # `get_updated_params` is defined elsewhere (not shown here)
        updated_params = get_updated_params(
            loss, model, params, inner_update_lr
        )
        # The next iteration uses the new value of params
        params = updated_params
    return params
```
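The outer (meta) loop around an inner update like `apply_inner_update` can be sketched end to end on a toy problem. This is a dependency-free illustration, not this repository's training loop: it swaps the CNN for a one-parameter model f(x) = w * x on random linear-regression tasks, so the second-order MAML meta-gradient can be written by hand. Each inner step is w' = w - inner_lr * dL_support/dw, the meta-objective is L_query(w'), and by the chain rule its gradient is dL_query/dw' * dw'/dw. All names (`make_task`, `meta_gradient`, `meta_loss`) are hypothetical; in PyTorch the same derivative would presumably come from autograd instead.

```python
import random


def loss(w, xs, ys):
    # Mean squared error of the one-parameter model f(x) = w * x
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def grad(w, xs, ys):
    # dL/dw for the MSE above
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def hess(xs):
    # d2L/dw2 = mean(2 x^2), constant in w because the model is linear in w
    return sum(2 * x * x for x in xs) / len(xs)


def inner_update(w, xs, ys, inner_lr, num_inner_updates):
    # Plain gradient descent on the support set
    for _ in range(num_inner_updates):
        w = w - inner_lr * grad(w, xs, ys)
    return w


def meta_gradient(w, task, inner_lr, num_inner_updates):
    (xs_s, ys_s), (xs_q, ys_q) = task
    w_adapted = inner_update(w, xs_s, ys_s, inner_lr, num_inner_updates)
    # Chain rule through every inner step: dw_adapted/dw = (1 - inner_lr * H) ** k.
    # This is the second-order term that first-order MAML drops.
    dadapted_dw = (1 - inner_lr * hess(xs_s)) ** num_inner_updates
    return grad(w_adapted, xs_q, ys_q) * dadapted_dw


def make_task(slope, rng, n=5):
    # One regression task y = slope * x with separate support and query points
    def points():
        xs = [rng.uniform(-1, 1) for _ in range(n)]
        return xs, [slope * x for x in xs]
    return points(), points()


def meta_loss(w, tasks, inner_lr, num_inner_updates):
    # Query loss after adaptation, averaged over tasks (what MAML minimizes)
    total = 0.0
    for (xs_s, ys_s), (xs_q, ys_q) in tasks:
        w_adapted = inner_update(w, xs_s, ys_s, inner_lr, num_inner_updates)
        total += loss(w_adapted, xs_q, ys_q)
    return total / len(tasks)


rng = random.Random(0)
tasks = [make_task(rng.uniform(0.5, 2.5), rng) for _ in range(20)]
w, inner_lr, meta_lr = 0.0, 0.1, 0.05
before = meta_loss(w, tasks, inner_lr, 1)
for _ in range(200):
    # Meta-update: average the per-task meta-gradients, then step the initialization
    g = sum(meta_gradient(w, t, inner_lr, 1) for t in tasks) / len(tasks)
    w -= meta_lr * g
after = meta_loss(w, tasks, inner_lr, 1)
print(before, after)
```

Dropping the `(1 - inner_lr * H) ** k` factor gives the first-order MAML approximation, which ignores how the adapted parameters depend on the initialization.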
Please use the issue tracker to report any issues.