diff --git a/trainer.ipynb b/trainer.ipynb new file mode 100644 index 00000000..0bfac9c2 --- /dev/null +++ b/trainer.ipynb @@ -0,0 +1,190 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "trainer.ipynb", + "provenance": [], + "collapsed_sections": [], + "authorship_tag": "ABX9TyMcgoPC4IPbj3zqM6RAGqkW", + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "background_save": true + }, + "id": "qZttVqkV83fp" + }, + "source": [ + "import torch\n", + "import numpy as np" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "17hmBjqM9JcR" + }, + "source": [ + "def fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval, metrics=[],\n", + " start_epoch=0):\n", + " \"\"\"\n", + " Loaders, model, loss function and metrics should work together for a given task,\n", + " i.e. The model should be able to process data output of loaders,\n", + " loss function should process target output of loaders and outputs from the model\n", + "\n", + " Examples: Classification: batch loader, classification model, NLL loss, accuracy metric\n", + " Siamese network: Siamese loader, siamese model, contrastive loss\n", + " Online triplet learning: batch loader, embedding model, online triplet loss\n", + " \"\"\"\n", + " for epoch in range(0, start_epoch):\n", + " scheduler.step()\n", + "\n", + " for epoch in range(start_epoch, n_epochs):\n", + " scheduler.step()\n", + "\n", + " # Train stage\n", + " train_loss, metrics = train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics)\n", + "\n", + " message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(epoch + 1, n_epochs, train_loss)\n", + " for metric in metrics:\n", + " message += '\\t{}: {}'.format(metric.name(), metric.value())\n", + "\n", + " val_loss, metrics = test_epoch(val_loader, model, loss_fn, cuda, metrics)\n", + " val_loss /= len(val_loader)\n", + "\n", + " message += '\\nEpoch: {}/{}. Validation set: Average loss: {:.4f}'.format(epoch + 1, n_epochs,\n", + " val_loss)\n", + " for metric in metrics:\n", + " message += '\\t{}: {}'.format(metric.name(), metric.value())\n", + "\n", + " print(message)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "lUHBriPj9eDw" + }, + "source": [ + "def train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics):\n", + " for metric in metrics:\n", + " metric.reset()\n", + "\n", + " model.train()\n", + " losses = []\n", + " total_loss = 0\n", + "\n", + " for batch_idx, (data, target) in enumerate(train_loader):\n", + " target = target if len(target) > 0 else None\n", + " if not type(data) in (tuple, list):\n", + " data = (data,)\n", + " if cuda:\n", + " data = tuple(d.cuda() for d in data)\n", + " if target is not None:\n", + " target = target.cuda()\n", + "\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(*data)\n", + "\n", + " if type(outputs) not in (tuple, list):\n", + " outputs = (outputs,)\n", + "\n", + " loss_inputs = outputs\n", + " if target is not None:\n", + " target = (target,)\n", + " loss_inputs += target\n", + "\n", + " loss_outputs = loss_fn(*loss_inputs)\n", + " loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs\n", + " losses.append(loss.item())\n", + " total_loss += loss.item()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " for metric in metrics:\n", + " metric(outputs, target, loss_outputs)\n", + "\n", + " if batch_idx % log_interval == 0:\n", + " message = 'Train: [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " batch_idx * len(data[0]), len(train_loader.dataset),\n", + " 100. * batch_idx / len(train_loader), np.mean(losses))\n", + " for metric in metrics:\n", + " message += '\\t{}: {}'.format(metric.name(), metric.value())\n", + "\n", + " print(message)\n", + " losses = []\n", + "\n", + " total_loss /= (batch_idx + 1)\n", + " return total_loss, metrics" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "75jrvpdx9oNG" + }, + "source": [ + "def test_epoch(val_loader, model, loss_fn, cuda, metrics):\n", + " with torch.no_grad():\n", + " for metric in metrics:\n", + " metric.reset()\n", + " model.eval()\n", + " val_loss = 0\n", + " for batch_idx, (data, target) in enumerate(val_loader):\n", + " target = target if len(target) > 0 else None\n", + " if not type(data) in (tuple, list):\n", + " data = (data,)\n", + " if cuda:\n", + " data = tuple(d.cuda() for d in data)\n", + " if target is not None:\n", + " target = target.cuda()\n", + "\n", + " outputs = model(*data)\n", + "\n", + " if type(outputs) not in (tuple, list):\n", + " outputs = (outputs,)\n", + " loss_inputs = outputs\n", + " if target is not None:\n", + " target = (target,)\n", + " loss_inputs += target\n", + "\n", + " loss_outputs = loss_fn(*loss_inputs)\n", + " loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs\n", + " val_loss += loss.item()\n", + "\n", + " for metric in metrics:\n", + " metric(outputs, target, loss_outputs)\n", + "\n", + " return val_loss, metrics\n" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file