This example implements a popular meta-learning method called prototypical networks (ProtoNets) for few-shot learning on the Omniglot dataset. ProtoNets often outperform MAML on few-shot learning problems. Please consult the original paper and the code published by its authors for more details.
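For readers new to the method, the core ProtoNets classification step can be sketched as follows. Each class prototype is the mean embedding of that class's support examples, and each query is classified with a softmax over negative squared Euclidean distances to the prototypes, as described in the original paper. This is a minimal NumPy sketch, not the code in model_def.py. The function names are hypothetical, and the sketch assumes the embeddings have already been computed; in ProtoNets they come from a learned embedding network.

```python
import numpy as np


def compute_prototypes(support_emb: np.ndarray, support_labels: np.ndarray, n_way: int) -> np.ndarray:
    """Return one prototype per class: the mean support embedding, shape (n_way, dim)."""
    return np.stack([support_emb[support_labels == c].mean(axis=0) for c in range(n_way)])


def proto_log_probs(query_emb: np.ndarray, prototypes: np.ndarray) -> np.ndarray:
    """Log-softmax over negative squared Euclidean distances, shape (n_query, n_way)."""
    # Pairwise squared distances between every query and every prototype.
    sq_dist = ((query_emb[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    logits = -sq_dist
    # Numerically stable log-softmax along the class axis.
    logits = logits - logits.max(axis=1, keepdims=True)
    return logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))


def episode_loss_and_accuracy(query_emb, query_labels, prototypes):
    """Mean negative log-likelihood of the true class, and accuracy, for one episode."""
    log_p = proto_log_probs(query_emb, prototypes)
    loss = -log_p[np.arange(len(query_labels)), query_labels].mean()
    accuracy = (log_p.argmax(axis=1) == query_labels).mean()
    return float(loss), float(accuracy)
```

Because the prototypes are plain means, the model has no per-class parameters, which is why the same network can classify characters it never saw during training.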
- model_def.py: The core code for the model. This includes building and compiling the model.
- data.py: The code used to generate few-shot learning tasks.
- startup-hook.sh: Script that will be run automatically by Determined in each container launched for this experiment. The script installs an additional dependency and invokes fetch_data.sh to download the training data.
- fetch_data.sh: Script to download the Omniglot data.
- 20way1shot.yaml: The 20-way 1-shot configuration for the experiment.
- 20way5shot.yaml: The 20-way 5-shot configuration for the experiment.
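The "way" and "shot" in the configuration names describe how each few-shot task (episode) is built: sample N character classes, then K labeled support examples plus some query examples per class. The following is a minimal sketch of that sampling, assuming the dataset is held as a dict mapping each class name to its example identifiers. The function name, this data layout, and the n_query parameter are hypothetical and not taken from data.py.

```python
import random
from typing import Dict, List, Sequence, Tuple


def sample_episode(
    examples_by_class: Dict[str, Sequence[str]],
    n_way: int,
    k_shot: int,
    n_query: int,
    rng: random.Random,
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """Sample one N-way K-shot episode as (support, query) lists of (example, label)."""
    # Pick n_way distinct classes; sorting first makes results reproducible per seed.
    classes = rng.sample(sorted(examples_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        # Draw disjoint support and query examples from the same class.
        picks = rng.sample(list(examples_by_class[cls]), k_shot + n_query)
        support += [(x, label) for x in picks[:k_shot]]
        query += [(x, label) for x in picks[k_shot:]]
    return support, query
```

Labels are reassigned to 0..n_way-1 within each episode, so the model cannot rely on a fixed class index across tasks; for 20-way 1-shot, each episode has 20 support examples.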
We use the Omniglot download script from the meta-blocks repo. This script is called in startup-hook.sh so that the data folder is available before the trial starts; it should take less than 15 seconds to run.
If you have not yet installed Determined, installation instructions can be found under docs/install-admin.html or at https://docs.determined.ai/latest/index.html.
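Run the following command from this example's directory (the final . tells Determined to use the current directory as the model definition):

det -m <master host:port> experiment create -f 20way1shot.yaml .

The other configurations can be run by specifying the appropriate configuration file in place of 20way1shot.yaml.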
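To launch both configurations from a script, the command above can be assembled programmatically. This is a minimal sketch: the helper names and the localhost:8080 master address are assumptions for illustration, and the flags are copied verbatim from the command above. It only prints the commands unless dry_run is set to False.

```python
import shlex
import subprocess
from typing import List

# Configuration files listed in this example.
CONFIGS = ["20way1shot.yaml", "20way5shot.yaml"]


def build_create_command(master: str, config: str, model_dir: str = ".") -> List[str]:
    """Mirror: det -m <master host:port> experiment create -f <config> ."""
    return ["det", "-m", master, "experiment", "create", "-f", config, model_dir]


def submit(cmd: List[str], dry_run: bool = True) -> None:
    """Print the shell-quoted command, or run it when dry_run is False."""
    if dry_run:
        print(shlex.join(cmd))
        return
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Assumed master address; replace with your own <master host:port>.
    for cfg in CONFIGS:
        submit(build_create_command("localhost:8080", cfg))
```

Building an argument list instead of a single string avoids shell-quoting issues if a config path ever contains spaces.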
For 20-way 1-shot classification on Omniglot, this implementation should reach ~96% test accuracy within 20k batches and converge to over 97%, beating the 96% reported in the original paper. See an example learning curve below: