diff --git a/.flake8 b/.flake8
index d2b9576..27d0132 100644
--- a/.flake8
+++ b/.flake8
@@ -6,6 +6,6 @@ max-line-length = 88
# D100-D107: Missing docstrings
# D200: One-line docstring should fit on one line with quotes.
extend-ignore = E203,E402,F401,D100,D101,D102,D103,D104,D105,D106,D107,D200
-docstring-convention = numpy
+; docstring-convention = numpy
# Ignore missing docstrings within unit testing functions.
per-file-ignores = **/tests/:D100,D101,D102,D103,D104,D105,D106,D107
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 7833a1a..cf5ef8d 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -1,5 +1,5 @@
-![Coverage Badge](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/vbadenas/9b54bd086e121233d2ad9a62d2136258/raw/frarch__pull_##.json)
+![Coverage Badge](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/victorbadenas/9b54bd086e121233d2ad9a62d2136258/raw/frarch__pull_##.json)
**Notes for reviewer:**
diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml
new file mode 100644
index 0000000..8363e5a
--- /dev/null
+++ b/.github/workflows/docs.yaml
@@ -0,0 +1,66 @@
+# Docs workflow
+#
+# Ensures that the docs can be built with sphinx.
+# - On every push and PR, checks the HTML documentation builds on linux.
+# - On every PR and tag, checks the documentation builds as a PDF on linux.
+# - If your repository is public, on pushes to the default branch (i.e. either
+# master or main), the HTML documentation is pushed to the gh-pages branch,
+# which is automatically rendered at the publicly accessible url
+# https://USER.github.io/PACKAGE/
+
+name: docs
+
+on: [push, pull_request]
+
+jobs:
+ docs-html:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Build HTML docs
+ uses: ammaraskar/sphinx-action@master
+ with:
+ docs-folder: "docs/"
+
+ - name: Determine default branch
+ run: |
+ DEFAULT_BRANCH=$(git remote show origin | awk '/HEAD branch/ {print $NF}')
+ echo "default_branch=$DEFAULT_BRANCH" >> $GITHUB_ENV
+ echo "default_branch_ref=refs/heads/$DEFAULT_BRANCH" >> $GITHUB_ENV
+
+ - name: Determine whether repo is public
+ run: |
+ REMOTE_HTTP=$(git remote get-url origin | sed -e "s|:\([^/]\)|/\1|g" -e "s|^git@|https://|" -e "s|\.git$||")
+ echo "Probing $REMOTE_HTTP"
+ if wget -q --method=HEAD ${REMOTE_HTTP}; then IS_PUBLIC=1; else IS_PUBLIC=0; fi
+ echo "is_public=$IS_PUBLIC"
+ echo "is_public=$IS_PUBLIC" >> $GITHUB_ENV
+
+ - name: Deploy to GitHub Pages
+ uses: peaceiris/actions-gh-pages@v3
+ if: github.ref == env.default_branch_ref && env.is_public == 1
+ with:
+ github_token: ${{ secrets.GITHUB_TOKEN }}
+ publish_dir: "docs/_build/html/"
+
+ # docs-pdf:
+ # if: |
+ # github.event_name == 'pull_request' ||
+ # startsWith(github.ref, 'refs/tags/')
+ # runs-on: ubuntu-latest
+ # steps:
+ # - uses: actions/checkout@v2
+
+ # - name: Build PDF docs
+ # uses: ammaraskar/sphinx-action@master
+ # with:
+ # docs-folder: "docs/"
+ # pre-build-command: "apt-get update -y && apt-get install -y latexmk texlive-latex-recommended texlive-latex-extra texlive-fonts-recommended"
+ # build-command: "make latexpdf"
+
+ # - uses: actions/upload-artifact@v2
+ # if: startsWith(github.ref, 'refs/tags')
+ # with:
+ # name: Documentation
+ # path: docs/_build/latex/*.pdf
diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml
index 8b495e2..63bef86 100644
--- a/.github/workflows/python-app.yml
+++ b/.github/workflows/python-app.yml
@@ -1,23 +1,9 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
-name: frarch CI
-
-on:
- push:
- branches:
- - master
- - development
- - feature_*
- - fix_*
- tags:
- - v*
- pull_request:
- branches:
- - master
- - development
- - feature_*
- - fix_*
+name: CI
+
+on: [push, pull_request]
jobs:
build:
@@ -28,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v2
with:
@@ -42,27 +28,30 @@ jobs:
unittest:
strategy:
matrix:
- os: [ubuntu-latest, macos-latest, windows-latest]
python: [3.7, 3.8, 3.9]
- runs-on: ${{ matrix.os }}
- needs: [ build ]
+ runs-on: ubuntu-latest
+ needs: [build]
+
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
+
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python }}
+
- name: install
run: |
python -m pip install --upgrade pip
pip install coverage
python setup.py -q install
+
- name: test
run: |
- coverage run -m unittest discover
- coverage report -m --omit='tests/*'
+ coverage run -m unittest discover -s tests/unit
+ coverage report -m --omit='tests/unit/*'
- - if: ${{ matrix.python == '3.9' && matrix.os == 'ubuntu-latest' }}
+ - if: ${{ matrix.python == '3.9' }}
name: Get Coverage for badge
run: |
coverage json
@@ -75,7 +64,7 @@ jobs:
echo $BRANCH_NAME
echo "BRANCH=$(echo ${BRANCH_NAME})" >> $GITHUB_ENV
- - if: ${{ matrix.python == '3.9' && matrix.os == 'ubuntu-latest' }}
+ - if: ${{ matrix.python == '3.9' }}
name: Create the Badge
uses: schneegans/dynamic-badges-action@v1.0.0
with:
@@ -85,14 +74,62 @@ jobs:
label: Test Coverage
message: ${{ env.COVERAGE }}
color: green
- namedLogo: Python
+ namedLogo: codecov
+
+ unittest-cross-platform:
+ strategy:
+ matrix:
+ os: [macos-latest, windows-latest]
+ runs-on: ${{ matrix.os }}
+ needs: [build]
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Set up Python 3.9
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: install
+ run: |
+ python -m pip install --upgrade pip
+ pip install coverage
+ python setup.py -q install
+
+ - name: test
+ run: |
+ coverage run -m unittest discover -s tests/unit
+ coverage report -m --omit='tests/unit/*'
- release-and-package:
+ functional:
+ strategy:
+ matrix:
+ python: [3.7, 3.8, 3.9]
runs-on: ubuntu-latest
needs: [build, unittest]
steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python ${{ matrix.python }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python }}
+ - name: install
+ run: |
+ python -m pip install --upgrade pip
+ pip install coverage
+ python setup.py -q install
+ - name: functional-test
+ run: |
+ cd tests/functional
+ ./runFunctionalTests.sh
+
+ package-and-release:
+ runs-on: ubuntu-latest
+ needs: [build, unittest, functional]
+ steps:
- name: checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v3
- name: Set up Python 3.7
uses: actions/setup-python@v2
diff --git a/LICENSE b/LICENSE
index f58ebef..2e0ec1d 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,22 +1,190 @@
-The MIT License (MIT)
+ Apache License
+ Version 2.0, January 2004
+ https://www.apache.org/licenses/
-Copyright (c) 2021 Scott Lowe
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
+ 1. Definitions.
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ Copyright 2013-2018 Docker, Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 42d92c3..a3178de 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,199 @@
-# **Frarch**
+# Frarch
-![Coverage Badge](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/vbadenas/9b54bd086e121233d2ad9a62d2136258/raw/frarch__heads_master.json)
+![Coverage Badge](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/victorbadenas/9b54bd086e121233d2ad9a62d2136258/raw/frarch__heads_master.json&style=flat)
+![Pytorch](https://img.shields.io/static/v1?label=PyTorch&message=v1.9.1&color=orange&style=flat&logo=pytorch)
+![python](https://img.shields.io/pypi/pyversions/frarch?logo=python&style=flat)
-**Fra**mework for Pyto**rch** experiments inspired in [speechbrain's](https://speechbrain.github.io/) workflow using [hyperpyyaml](https://github.com/speechbrain/HyperPyYAML) configuration files.
+![CI](https://github.com/victorbadenas/frarch/actions/workflows/python-app.yml/badge.svg?style=flat)
+[![PyPI version fury.io](https://badge.fury.io/py/frarch.svg?style=flat)](https://pypi.python.org/pypi/frarch/)
+![license](https://img.shields.io/github/license/victorbadenas/frarch?style=flat)
+
+Frarch is a **Fra**mework for Pyto**rch** experiments inspired by [speechbrain's](https://speechbrain.github.io/) workflow using [hyperpyyaml](https://github.com/speechbrain/HyperPyYAML) configuration files. Frarch aims to minimize the code needed to perform an experiment while organizing the output models and the log files for the experiment as well as the configuration files used to train them in an organised manner.
+
+## Features
+
+- `CPU` and `CUDA` computations. Note that CUDA must be installed for Pytorch and as such frarch to compute in an NVIDIA GPU. Multi-GPU is not supported at the moment, but will be supported in the future.
+- Minimize the size of training scripts.
+- Support for Python's 3.7, 3.8 and 3.9 versions
+- yaml definition of training hyperparameters.
+- organisation of output models and their hyperparameters, training scripts and logs.
+
+## Quick installation
+
+The frarch package is evolving and not yet in a stable release. Documentation will be added as the package progresses. The package can be installed via PyPI or via github for the users that what to modify the contents of the package.
+
+### PyPI installation
+
+Once the python environment has been created, you can install frarch by executing:
+
+```bash
+pip install frarch
+```
+
+Then frarch can be used in a python script using:
+
+```python
+import frarch as fr
+```
+
+### Github install
+
+Once the python environment has been created, you can install frarch by executing:
+
+```bash
+git clone https://github.com/victorbadenas/frarch.git
+cd frarch
+python setup.py install
+```
+
+for development instead of the last command, run `python setup.py develop` to be able to hot reload changes to the package.
+
+### Test
+
+To run the tests for the frarch package:
+
+```bash
+python setup.py install
+python -m unittest discover
+```
+
+### Documentation
+
+To create the documentation, run the following command:
+
+```bash
+make -C docs html
+sensible-browser docs/_build/html/index.html
+make -C docs latexpdf
+```
+
+## Running an experiment
+
+Frarch provides training classes such as [`ClassifierTrainer`](https://victorbadenas.github.io/frarch/source/packages/frarch.train.classifier_trainer.html) which provides methods to train a classifier model.
+
+### Example Python trainer script
+
+In this example we present a sample training script for training the MNIST dataset.
+
+```python
+from hyperpyyaml import load_hyperpyyaml
+
+import frarch as fr
+
+from frarch.utils.data import build_experiment_structure
+from frarch.utils.stages import Stage
+
+
+class MNISTTrainer(fr.train.ClassifierTrainer):
+ def forward(self, batch, stage):
+ inputs, _ = batch
+ inputs = inputs.to(self.device)
+ return self.modules.model(inputs)
+
+ def compute_loss(self, predictions, batch, stage):
+ _, labels = batch
+ labels = labels.to(self.device)
+ return self.hparams["loss"](predictions, labels)
+
+ def on_stage_end(self, stage, loss=None, epoch=None):
+ if stage == Stage.VALID:
+ if self.checkpointer is not None:
+ self.checkpointer.save(epoch=self.current_epoch, current_step=self.step)
+
+
+if __name__ == "__main__":
+ hparam_file, args = fr.parse_arguments()
+
+ with open(hparam_file, "r") as hparam_file_handler:
+ hparams = load_hyperpyyaml(
+ hparam_file_handler, args, overrides_must_match=False
+ )
+
+ build_experiment_structure(
+ hparam_file,
+ overrides=args,
+ experiment_folder=hparams["experiment_folder"],
+ debug=hparams["debug"],
+ )
+
+ trainer = MNISTTrainer(
+ modules=hparams["modules"],
+ opt_class=hparams["opt_class"],
+ hparams=hparams,
+ checkpointer=hparams["checkpointer"],
+ )
+
+ trainer.fit(
+ train_set=hparams["train_dataset"],
+ valid_set=hparams["valid_dataset"],
+ train_loader_kwargs=hparams["dataloader_options"],
+ valid_loader_kwargs=hparams["dataloader_options"],
+ )
+```
+
+And the hparams yaml file used to configure the experiment:
+
+```yaml
+# seeds
+seed: 42
+__set_seed: !apply:torch.manual_seed [!ref ]
+experiment_name: "mnist"
+experiment_folder: "results/mnist_demo/"
+device: "cpu"
+
+# data folder
+data_folder: /tmp/
+
+# training parameters
+epochs: 2
+batch_size: 128
+shuffle: True
+num_clases: 10
+
+transform_tensor: !new:torchvision.transforms.ToTensor
+preprocessing: !new:torchvision.transforms.Compose
+ transforms: [
+ !ref ,
+ ]
+
+# dataset object
+train_dataset: !new:torchvision.datasets.MNIST
+ root: !ref
+ train: true
+ download: true
+ transform: !ref
+
+valid_dataset: !new:torchvision.datasets.MNIST
+ root: !ref
+ train: false
+ download: true
+ transform: !ref
+
+# dataloader options
+dataloader_options:
+ batch_size: !ref
+ shuffle: !ref
+ num_workers: 8
+
+opt_class: !name:torch.optim.Adam
+ lr: 0.001
+
+loss: !new:torch.nn.CrossEntropyLoss
+
+model: !apply:torchvision.models.vgg11
+ pretrained: false
+
+modules:
+ model: !ref
+
+checkpointer: !new:frarch.modules.Checkpointer
+ save_path: !ref
+ modules: !ref
+
+```
+
+For the code execution run:
+
+```bash
+python train.py mnist.yaml
+```
diff --git a/docs/logo.png b/docs/logo.png
new file mode 100644
index 0000000..79610b7
Binary files /dev/null and b/docs/logo.png differ
diff --git a/experiments/caltech101/hparams/caltech101_vgg11_same.yaml b/experiments/caltech101/hparams/caltech101_vgg11.yaml
similarity index 100%
rename from experiments/caltech101/hparams/caltech101_vgg11_same.yaml
rename to experiments/caltech101/hparams/caltech101_vgg11.yaml
diff --git a/experiments/caltech101/hparams/caltech101_vgg11_same_normalized.yaml b/experiments/caltech101/hparams/caltech101_vgg11_same_normalized.yaml
deleted file mode 100644
index df72ec9..0000000
--- a/experiments/caltech101/hparams/caltech101_vgg11_same_normalized.yaml
+++ /dev/null
@@ -1,85 +0,0 @@
-################################################
-# #
-# Model: VGG11 Same Padding Normalized #
-# Author: vbadenas #
-# #
-################################################
-
-# seeds
-seed: 42
-__set_seed: !apply:torch.manual_seed [!ref ]
-experiment_name: caltech101_vgg11_normalized
-experiment_folder: "results/caltech101_vgg11_normalized/"
-device: "cuda:0"
-train_interval: 100
-
-# data folder
-data_folder: !PLACEHOLDER
-
-# training parameters
-epochs: 20
-batch_size: 16
-shuffle: True
-
-num_classes: 101
-
-transform_tensor: !new:torchvision.transforms.ToTensor
-transform_random_resized_crop: !new:torchvision.transforms.RandomResizedCrop
- size: 224
-transform_resize: !new:torchvision.transforms.Resize
- size: [224, 224]
-transform_normalize: !new:torchvision.transforms.Normalize
- mean: [0.5460, 0.5286, 0.5019]
- std: [0.2426, 0.2396, 0.2414]
-
-transform_train: !new:torchvision.transforms.Compose
- transforms: [
- !ref ,
- !ref ,
- !ref
- ]
-
-transform_valid: !new:torchvision.transforms.Compose
- transforms: [
- !ref ,
- !ref ,
- !ref
- ]
-
-train_dataset: !new:frarch.datasets.Caltech101
- subset: train
- root: !ref
- transform: !ref
-
-valid_dataset: !new:frarch.datasets.Caltech101
- subset: valid
- root: !ref
- transform: !ref
-
-# dataloader options
-dataloader_options:
- batch_size: !ref
- shuffle: !ref
- num_workers: 40
-
-opt_class: !name:torch.optim.Adam
- lr: 0.0001
-
-loss: !new:torch.nn.CrossEntropyLoss
-
-metrics: !new:frarch.modules.metrics.MetricsWrapper
- classification_error: !new:frarch.modules.metrics.ClassificationError
-
-model: !apply:torchvision.models.vgg11
- pretrained: False
- num_classes: !ref
-
-modules:
- model: !ref
-
-checkpointer: !new:frarch.modules.Checkpointer
- save_path: !ref
- reference_metric: "classification_error"
- save_best_only: True
- modules:
- model: !ref
diff --git a/experiments/caltech101/hparams/caltech101_vgg11_same_normalized_augmented.yaml b/experiments/caltech101/hparams/caltech101_vgg11_same_normalized_augmented.yaml
deleted file mode 100644
index b584862..0000000
--- a/experiments/caltech101/hparams/caltech101_vgg11_same_normalized_augmented.yaml
+++ /dev/null
@@ -1,88 +0,0 @@
-################################################
-# #
-# Model: VGG11 Same Padding Normalized Aug #
-# Author: vbadenas #
-# #
-################################################
-
-# seeds
-seed: 42
-__set_seed: !apply:torch.manual_seed [!ref ]
-experiment_name: caltech101_vgg11_normalized_aug
-experiment_folder: "results/caltech101_vgg11_normalized_aug/"
-device: "cuda:0"
-train_interval: 100
-
-# data folder
-data_folder: !PLACEHOLDER
-
-# training parameters
-epochs: 50
-batch_size: 16
-shuffle: True
-
-num_classes: 101
-
-transform_tensor: !new:torchvision.transforms.ToTensor
-transform_random_resized_crop: !new:torchvision.transforms.RandomResizedCrop
- size: 224
-transform_resize: !new:torchvision.transforms.Resize
- size: [224, 224]
-transform_normalize: !new:torchvision.transforms.Normalize
- mean: [0.5460, 0.5286, 0.5019]
- std: [0.2426, 0.2396, 0.2414]
-transform_powerpil: !new:frarch.datasets.transforms.PowerPIL
-
-transform_train: !new:torchvision.transforms.Compose
- transforms: [
- !ref ,
- !ref ,
- !ref ,
- !ref
- ]
-
-transform_valid: !new:torchvision.transforms.Compose
- transforms: [
- !ref ,
- !ref ,
- !ref ,
- !ref
- ]
-
-train_dataset: !new:frarch.datasets.Caltech101
- subset: train
- root: !ref
- transform: !ref
-
-valid_dataset: !new:frarch.datasets.Caltech101
- subset: valid
- root: !ref
- transform: !ref
-
-# dataloader options
-dataloader_options:
- batch_size: !ref
- shuffle: !ref
- num_workers: 40
-
-opt_class: !name:torch.optim.Adam
- lr: 0.0001
-
-loss: !new:torch.nn.CrossEntropyLoss
-
-metrics: !new:frarch.modules.metrics.MetricsWrapper
- classification_error: !new:frarch.modules.metrics.ClassificationError
-
-model: !apply:torchvision.models.vgg11
- pretrained: False
- num_classes: !ref
-
-modules:
- model: !ref
-
-checkpointer: !new:frarch.modules.Checkpointer
- save_path: !ref
- reference_metric: "classification_error"
- save_best_only: True
- modules:
- model: !ref
diff --git a/experiments/caltech101/hparams/caltech101_vgg11_valid.yaml b/experiments/caltech101/hparams/caltech101_vgg11_valid.yaml
deleted file mode 100644
index c56f598..0000000
--- a/experiments/caltech101/hparams/caltech101_vgg11_valid.yaml
+++ /dev/null
@@ -1,81 +0,0 @@
-################################################
-# #
-# Model: VGG11 Valid Padding #
-# Author: vbadenas #
-# #
-################################################
-
-# seeds
-seed: 42
-__set_seed: !apply:torch.manual_seed [!ref ]
-experiment_name: caltech101_vgg11
-experiment_folder: "results/caltech101_vgg11/"
-device: "cpu"
-train_interval: 100
-
-# data folder
-data_folder: !PLACEHOLDER
-
-# training parameters
-epochs: 20
-batch_size: 16
-shuffle: True
-padding: valid
-
-num_classes: 101
-
-transform_tensor: !new:torchvision.transforms.ToTensor
-transform_random_resized_crop: !new:torchvision.transforms.RandomResizedCrop
- size: 224
-transform_resize: !new:torchvision.transforms.Resize
- size: [224, 224]
-
-transform_train: !new:torchvision.transforms.Compose
- transforms: [
- !ref ,
- !ref ,
- ]
-
-transform_valid: !new:torchvision.transforms.Compose
- transforms: [
- !ref ,
- !ref ,
- ]
-
-train_dataset: !new:frarch.datasets.Caltech101
- subset: train
- root: !ref
- transform: !ref
-
-valid_dataset: !new:frarch.datasets.Caltech101
- subset: valid
- root: !ref
- transform: !ref
-
-# dataloader options
-dataloader_options:
- batch_size: !ref
- shuffle: !ref
- num_workers: 40
-
-opt_class: !name:torch.optim.Adam
- lr: 0.0001
-
-loss: !new:torch.nn.CrossEntropyLoss
-
-metrics: !new:frarch.modules.metrics.MetricsWrapper
- classification_error: !new:frarch.modules.metrics.ClassificationError
-
-model: !apply:torchvision.models.vgg11
- pretrained: False
- num_classes: !ref
-
-modules:
- model: !ref
-
-checkpointer: !new:frarch.modules.Checkpointer
- save_path: !ref
- reference_metric: "classification_error"
- save_best_only: True
- modules:
- model: !ref
diff --git a/experiments/caltech101/hparams/caltech101_vgg11_valid_normalized.yaml b/experiments/caltech101/hparams/caltech101_vgg11_valid_normalized.yaml
deleted file mode 100644
index 7945f59..0000000
--- a/experiments/caltech101/hparams/caltech101_vgg11_valid_normalized.yaml
+++ /dev/null
@@ -1,86 +0,0 @@
-################################################
-# #
-# Model: VGG11 Valid Padding Normalized #
-# Author: vbadenas #
-# #
-################################################
-
-# seeds
-seed: 42
-__set_seed: !apply:torch.manual_seed [!ref ]
-experiment_name: caltech101_vgg11_normalized
-experiment_folder: "results/caltech101_vgg11_normalized/"
-device: "cuda:0"
-train_interval: 100
-
-# data folder
-data_folder: !PLACEHOLDER
-
-# training parameters
-epochs: 20
-batch_size: 16
-shuffle: True
-padding: valid
-
-num_classes: 101
-
-transform_tensor: !new:torchvision.transforms.ToTensor
-transform_random_resized_crop: !new:torchvision.transforms.RandomResizedCrop
- size: 224
-transform_resize: !new:torchvision.transforms.Resize
- size: [224, 224]
-transform_normalize: !new:torchvision.transforms.Normalize
- mean: [0.5460, 0.5286, 0.5019]
- std: [0.2426, 0.2396, 0.2414]
-
-transform_train: !new:torchvision.transforms.Compose
- transforms: [
- !ref ,
- !ref ,
- !ref
- ]
-
-transform_valid: !new:torchvision.transforms.Compose
- transforms: [
- !ref ,
- !ref ,
- !ref
- ]
-
-train_dataset: !new:frarch.datasets.Caltech101
- subset: train
- root: !ref
- transform: !ref
-
-valid_dataset: !new:frarch.datasets.Caltech101
- subset: valid
- root: !ref
- transform: !ref
-
-# dataloader options
-dataloader_options:
- batch_size: !ref
- shuffle: !ref
- num_workers: 40
-
-opt_class: !name:torch.optim.Adam
- lr: 0.0001
-
-loss: !new:torch.nn.CrossEntropyLoss
-
-metrics: !new:frarch.modules.metrics.MetricsWrapper
- classification_error: !new:frarch.modules.metrics.ClassificationError
-
-model: !apply:torchvision.models.vgg11
- pretrained: False
- num_classes: !ref
-
-modules:
- model: !ref
-
-checkpointer: !new:frarch.modules.Checkpointer
- save_path: !ref
- reference_metric: "classification_error"
- save_best_only: True
- modules:
- model: !ref
diff --git a/experiments/caltech101/hparams/caltech101_vgg11_valid_normalized_augmented.yaml b/experiments/caltech101/hparams/caltech101_vgg11_valid_normalized_augmented.yaml
deleted file mode 100644
index 461c764..0000000
--- a/experiments/caltech101/hparams/caltech101_vgg11_valid_normalized_augmented.yaml
+++ /dev/null
@@ -1,89 +0,0 @@
-################################################
-# #
-# Model: VGG11 Valid Padding Normalized Aug #
-# Author: vbadenas #
-# #
-################################################
-
-# seeds
-seed: 42
-__set_seed: !apply:torch.manual_seed [!ref ]
-experiment_name: caltech101_vgg11_normalized_aug
-experiment_folder: "results/caltech101_vgg11_normalized_aug/"
-device: "cuda:0"
-train_interval: 100
-
-# data folder
-data_folder: !PLACEHOLDER
-
-# training parameters
-epochs: 50
-batch_size: 16
-shuffle: True
-padding: valid
-
-num_classes: 101
-
-transform_tensor: !new:torchvision.transforms.ToTensor
-transform_random_resized_crop: !new:torchvision.transforms.RandomResizedCrop
- size: 224
-transform_resize: !new:torchvision.transforms.Resize
- size: [224, 224]
-transform_normalize: !new:torchvision.transforms.Normalize
- mean: [0.5460, 0.5286, 0.5019]
- std: [0.2426, 0.2396, 0.2414]
-transform_powerpil: !new:frarch.datasets.transforms.PowerPIL
-
-transform_train: !new:torchvision.transforms.Compose
- transforms: [
- !ref ,
- !ref ,
- !ref ,
- !ref
- ]
-
-transform_valid: !new:torchvision.transforms.Compose
- transforms: [
- !ref ,
- !ref ,
- !ref ,
- !ref
- ]
-
-train_dataset: !new:frarch.datasets.Caltech101
- subset: train
- root: !ref
- transform: !ref
-
-valid_dataset: !new:frarch.datasets.Caltech101
- subset: valid
- root: !ref
- transform: !ref
-
-# dataloader options
-dataloader_options:
- batch_size: !ref
- shuffle: !ref
- num_workers: 40
-
-opt_class: !name:torch.optim.Adam
- lr: 0.0001
-
-loss: !new:torch.nn.CrossEntropyLoss
-
-metrics: !new:frarch.modules.metrics.MetricsWrapper
- classification_error: !new:frarch.modules.metrics.ClassificationError
-
-model: !apply:torchvision.models.vgg11
- pretrained: False
- num_classes: !ref
-
-modules:
- model: !ref
-
-checkpointer: !new:frarch.modules.Checkpointer
- save_path: !ref
- reference_metric: "classification_error"
- save_best_only: True
- modules:
- model: !ref
diff --git a/experiments/caltech101/train_caltech101.py b/experiments/caltech101/train_caltech101.py
index a7a0bab..9a26221 100644
--- a/experiments/caltech101/train_caltech101.py
+++ b/experiments/caltech101/train_caltech101.py
@@ -27,50 +27,25 @@
class Caltech101Trainer(fr.train.ClassifierTrainer):
- def __init__(self, *args, **kwargs):
- super(Caltech101Trainer, self).__init__(*args, **kwargs)
- if "padding" in self.hparams:
- if self.hparams["padding"] == "valid":
- self.change_model_padding()
- elif self.hparams["padding"] == "same":
- logger.info("padding not changed. Defaulting to same.")
- else:
- logger.warning(
- "padding configuration not understood. Defaulting to same."
- )
-
- def change_model_padding(self):
- for layer_name, layer in self.modules.model.named_modules():
- if isinstance(layer, torch.nn.Conv2d):
- padding_conf = self.hparams["padding"]
- logger.info(
- f"Changing {layer_name}'s padding from same to {padding_conf}"
- )
- layer._reversed_padding_repeated_twice = (0, 0, 0, 0)
- layer.padding = (0, 0)
- self.modules.model.avgpool.output_size = (3, 3)
- self.modules.model.classifier[0] = torch.nn.Linear(512 * 3 * 3, 4096)
- self.modules = self.modules.to(self.device)
-
- def forward(self, batch, stage):
+ def _forward(self, batch, stage):
inputs, _ = batch
inputs = inputs.to(self.device)
return self.modules.model(inputs)
- def compute_loss(self, predictions, batch, stage):
+ def _compute_loss(self, predictions, batch, stage):
_, labels = batch
labels = labels.to(self.device)
loss = self.hparams["loss"](predictions, labels)
self.hparams["metrics"].update(predictions, labels)
return loss
- def on_stage_start(self, stage, loss=None, epoch=None):
+ def _on_stage_start(self, stage, loss=None, epoch=None):
self.hparams["metrics"].reset()
if self.debug:
metrics = self.hparams["metrics"].get_metrics(mode="mean")
logger.debug(metrics)
- def on_stage_end(self, stage, loss=None, epoch=None):
+ def _on_stage_end(self, stage, loss=None, epoch=None):
metrics = self.hparams["metrics"].get_metrics(mode="mean")
metrics_string = "".join([f"{k}={v:.4f}" for k, v in metrics.items()])
diff --git a/experiments/fashion_mnist/train_fashion_mnist.py b/experiments/fashion_mnist/train_fashion_mnist.py
index 8e0fd0f..e7c68d9 100644
--- a/experiments/fashion_mnist/train_fashion_mnist.py
+++ b/experiments/fashion_mnist/train_fashion_mnist.py
@@ -31,13 +31,13 @@
class FMNISTTrainer(fr.train.ClassifierTrainer):
- def forward(self, batch, stage):
+ def _forward(self, batch, stage):
inputs, _ = batch
inputs = inputs.to(self.device)
embeddings = self.modules.model(inputs)
return self.modules.classifier(embeddings)
- def compute_loss(self, predictions, batch, stage):
+ def _compute_loss(self, predictions, batch, stage):
_, labels = batch
labels = labels.to(self.device)
loss = self.hparams["loss"](predictions, labels)
@@ -45,14 +45,14 @@ def compute_loss(self, predictions, batch, stage):
self.hparams["metrics"].update(predictions, labels)
return loss
- def on_stage_start(self, stage, loss=None, epoch=None):
+ def _on_stage_start(self, stage, loss=None, epoch=None):
if stage == Stage.VALID:
self.hparams["metrics"].reset()
if self.debug:
metrics = self.hparams["metrics"].get_metrics(mode="mean")
logger.debug(metrics)
- def on_stage_end(self, stage, loss=None, epoch=None):
+ def _on_stage_end(self, stage, loss=None, epoch=None):
if stage == Stage.VALID:
metrics = self.hparams["metrics"].get_metrics(mode="mean")
metrics_string = "".join([f"{k}=={v:.4f}" for k, v in metrics.items()])
diff --git a/experiments/mit67/trainmit67.py b/experiments/mit67/trainmit67.py
index 3ac38b5..73d2b94 100644
--- a/experiments/mit67/trainmit67.py
+++ b/experiments/mit67/trainmit67.py
@@ -26,23 +26,23 @@
class Mit67Trainer(fr.train.ClassifierTrainer):
- def forward(self, batch, stage):
+ def _forward(self, batch, stage):
inputs, _ = batch
inputs = inputs.to(self.device)
outputs = self.modules.model(inputs)
return self.modules.classifier(outputs)
- def compute_loss(self, predictions, batch, stage):
+ def _compute_loss(self, predictions, batch, stage):
_, labels = batch
labels = labels.to(self.device)
loss = self.hparams["loss"](predictions, labels)
self.hparams["metrics"].update(predictions, labels)
return loss
- def on_stage_start(self, stage, loss=None, epoch=None):
+ def _on_stage_start(self, stage, loss=None, epoch=None):
self.hparams["metrics"].reset()
- def on_stage_end(self, stage, loss=None, epoch=None):
+ def _on_stage_end(self, stage, loss=None, epoch=None):
metrics = self.hparams["metrics"].get_metrics(mode="mean")
metrics_string = "".join([f"{k}={v:.4f}" for k, v in metrics.items()])
diff --git a/experiments/mnist/train_mnist.py b/experiments/mnist/train_mnist.py
index 69f384c..0bd8531 100644
--- a/experiments/mnist/train_mnist.py
+++ b/experiments/mnist/train_mnist.py
@@ -26,13 +26,13 @@
class MNISTTrainer(fr.train.ClassifierTrainer):
- def forward(self, batch, stage):
+ def _forward(self, batch, stage):
inputs, _ = batch
inputs = inputs.to(self.device)
embeddings = self.modules.model(inputs)
return self.modules.classifier(embeddings)
- def compute_loss(self, predictions, batch, stage):
+ def _compute_loss(self, predictions, batch, stage):
_, labels = batch
labels = labels.to(self.device)
loss = self.hparams["loss"](predictions, labels)
@@ -40,14 +40,14 @@ def compute_loss(self, predictions, batch, stage):
self.hparams["metrics"].update(predictions, labels)
return loss
- def on_stage_start(self, stage, loss=None, epoch=None):
+ def _on_stage_start(self, stage, loss=None, epoch=None):
if stage == Stage.VALID:
self.hparams["metrics"].reset()
if self.debug:
metrics = self.hparams["metrics"].get_metrics(mode="mean")
logger.debug(metrics)
- def on_stage_end(self, stage, loss=None, epoch=None):
+ def _on_stage_end(self, stage, loss=None, epoch=None):
if stage == Stage.VALID:
metrics = self.hparams["metrics"].get_metrics(mode="mean")
metrics_string = "".join([f"{k}=={v:.4f}" for k, v in metrics.items()])
diff --git a/experiments/oxfordPets/train_oxfordPets.py b/experiments/oxfordPets/train_oxfordPets.py
index a1d210a..372691c 100644
--- a/experiments/oxfordPets/train_oxfordPets.py
+++ b/experiments/oxfordPets/train_oxfordPets.py
@@ -26,25 +26,25 @@
class OxfordPetsTrainer(fr.train.ClassifierTrainer):
- def forward(self, batch, stage):
+ def _forward(self, batch, stage):
inputs, _ = batch
inputs = inputs.to(self.device)
return self.modules.model(inputs)
- def compute_loss(self, predictions, batch, stage):
+ def _compute_loss(self, predictions, batch, stage):
_, labels = batch
labels = labels.to(self.device)
loss = self.hparams["loss"](predictions, labels)
self.hparams["metrics"].update(predictions, labels)
return loss
- def on_stage_start(self, stage, loss=None, epoch=None):
+ def _on_stage_start(self, stage, loss=None, epoch=None):
self.hparams["metrics"].reset()
if self.debug:
metrics = self.hparams["metrics"].get_metrics(mode="mean")
logger.debug(metrics)
- def on_stage_end(self, stage, loss=None, epoch=None):
+ def _on_stage_end(self, stage, loss=None, epoch=None):
metrics = self.hparams["metrics"].get_metrics(mode="mean")
metrics_string = "".join([f"{k}={v:.4f}" for k, v in metrics.items()])
diff --git a/experiments/places365/train_places365.py b/experiments/places365/train_places365.py
index d3e49d1..b7c632d 100644
--- a/experiments/places365/train_places365.py
+++ b/experiments/places365/train_places365.py
@@ -31,12 +31,12 @@
class PlacesTrainer(fr.train.ClassifierTrainer):
- def forward(self, batch, stage):
+ def _forward(self, batch, stage):
inputs, _ = batch
inputs = inputs.to(self.device)
return self.modules.model(inputs)
- def compute_loss(self, predictions, batch, stage):
+ def _compute_loss(self, predictions, batch, stage):
_, labels = batch
labels = labels.to(self.device)
loss = self.hparams["loss"](predictions, labels)
@@ -44,10 +44,10 @@ def compute_loss(self, predictions, batch, stage):
self.hparams["metrics"].update(predictions, labels)
return loss
- def on_stage_start(self, stage, loss=None, epoch=None):
+ def _on_stage_start(self, stage, loss=None, epoch=None):
self.hparams["metrics"].reset()
- def on_stage_end(self, stage, loss=None, epoch=None):
+ def _on_stage_end(self, stage, loss=None, epoch=None):
if stage == Stage.VALID:
metrics = self.hparams["metrics"].get_metrics(mode="mean")
metrics_string = "".join([f"{k}=={v:.4f}" for k, v in metrics.items()])
@@ -62,7 +62,7 @@ def on_stage_end(self, stage, loss=None, epoch=None):
**metrics, epoch=self.current_epoch, current_step=self.step
)
- def save_intra_epoch_ckpt(self):
+ def _save_intra_epoch_ckpt(self):
if self.checkpointer is not None:
self.checkpointer.save(
epoch=self.current_epoch, current_step=self.step, intra_epoch=True
diff --git a/frarch/__meta__.py b/frarch/__meta__.py
index 00a7fc4..355faba 100644
--- a/frarch/__meta__.py
+++ b/frarch/__meta__.py
@@ -1,12 +1,8 @@
-# `name` is the name of the package as used for `pip install package`
name = "frarch"
-# `path` is the name of the package for `import package`
path = name.lower().replace("-", "_").replace(" ", "_")
-# Version number should follow https://python.org/dev/peps/pep-0440 and
-# https://semver.org
-version = "0.1.3"
+version = "0.1.5"
author = "vbadenas"
author_email = "victor.badenas@gmail.com"
-description = "Training Framework for PyTorch projects" # One-liner
-url = "https://github.com/vbadenas/frarch" # your project homepage
-license = "MIT" # See https://choosealicense.com
+description = "Training Framework for PyTorch projects"
+url = "https://github.com/victorbadenas/frarch"
+license = "Apache 2.0"
diff --git a/frarch/datasets/caltech101.py b/frarch/datasets/caltech101.py
index 3033dea..9621e64 100644
--- a/frarch/datasets/caltech101.py
+++ b/frarch/datasets/caltech101.py
@@ -3,8 +3,9 @@
import random
from collections import Counter
from pathlib import Path
-from typing import Callable, Union
+from typing import Callable, Iterable, List, Mapping, Tuple, Union
+import torch
from PIL import Image
from torch.utils.data import Dataset
@@ -14,13 +15,43 @@
class Caltech101(Dataset):
+ """Caltech 101 dataset object.
+
+ Data loader for the Caltech 101 dataset for object classification. The dataset can
+ be obtained from
+ http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz.
+
+ Args:
+ subset (str): "train" or "valid". Subset to load. Defaults to "train".
+ transform (Callable): a callable object that takes an `PIL.Image` object and
+ returns a modified `PIL.Image` object. Defaults to None, which won't apply
+ any transformation.
+ target_transform (Callable): a callable object that the label data and returns
+ modified label data. Defaults to None, which won't apply any transformation.
+ root (Union[str, Path]): root directory for the dataset. Defaults to `./data/`.
+
+ References:
+ - http://www.vision.caltech.edu/Image_Datasets/Caltech101/
+
+ Examples:
+ Simple usage of the dataset class::
+
+ from frarch.datasets import Caltech101
+ from frarch.utils.data import create_dataloader
+ from torchvision.transforms import ToTensor
+ dataset = Caltech101("train", ToTensor, None, "./data/")
+ dataloader = create_dataloader(dataset)
+ for batch_idx, (batch, labels) in enumerate(dataloader):
+ # process batch
+ """
+
def __init__(
self,
subset: str = "train",
transform: Callable = None,
target_transform: Callable = None,
root: Union[str, Path] = "./data/",
- ):
+ ) -> None:
if subset not in ["train", "valid"]:
raise ValueError(f"set must be train or test not {subset}")
@@ -48,7 +79,7 @@ def __init__(
f" in {len(self.classes)} classes"
)
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
path, target = self.images[index]
img = Image.open(path).convert("RGB")
if self.transform is not None:
@@ -57,42 +88,47 @@ def __getitem__(self, index):
target = self.target_transform(target)
return img, target
- def __len__(self):
+ def __len__(self) -> int:
return len(self.images)
- def get_number_classes(self):
+ def get_number_classes(self) -> int:
+ """Get number of target labels.
+
+ Returns:
+ int: number of target labels.
+ """
return len(self.classes)
- def _detect_dataset(self):
+ def _detect_dataset(self) -> bool:
if not self.root.exists():
return False
else:
num_images = len(self._get_file_paths())
return num_images > 0
- def _build_and_load_lst(self):
+ def _build_and_load_lst(self) -> None:
all_paths = self._get_file_paths()
self._build_and_load_class_map(all_paths)
self._load_train_test_files(all_paths)
- def _build_and_load_class_map(self, all_paths):
+ def _build_and_load_class_map(self, all_paths: List[Path]) -> None:
if not self.mapper_path.exists():
self._build_class_mapper(all_paths)
self._load_class_map()
- def _build_class_mapper(self, all_paths):
+ def _build_class_mapper(self, all_paths: List[Path]) -> None:
classes_set = set(map(lambda path: path.parts[-2], all_paths))
logger.info(f"found {len(classes_set)} classes.")
class_mapper = dict(zip(classes_set, range(len(classes_set))))
logger.info(f"class mapper built: {class_mapper}")
self._dump_class_map(class_mapper)
- def _load_train_test_files(self, all_paths):
+ def _load_train_test_files(self, all_paths: List[Path]) -> None:
if not self.train_lst_path.exists() and not self.valid_lst_path.exists():
self._build_train_test_files(all_paths)
self._load_set()
- def _build_train_test_files(self, all_paths):
+ def _build_train_test_files(self, all_paths: List[Path]) -> None:
classes_list = list(map(lambda path: path.parts[-2], all_paths))
instance_counter = Counter(classes_list)
@@ -132,16 +168,16 @@ def _build_train_test_files(self, all_paths):
self._write_lst(self.valid_lst_path, valid_instances)
@staticmethod
- def _write_lst(path, instances):
+ def _write_lst(path: Path, instances: Iterable) -> None:
with path.open("w") as f:
for line in instances:
f.write(",".join(map(str, line)) + "\n")
- def _get_file_paths(self):
+ def _get_file_paths(self) -> List[Path]:
all_files = list(self.root.glob("*/*.jpg"))
return list(filter(lambda x: x.parts[-2] != "BACKGROUND_Google", all_files))
- def _load_set(self):
+ def _load_set(self) -> None:
path = self.train_lst_path if self.set == "train" else self.valid_lst_path
with path.open("r") as f:
self.images = []
@@ -149,10 +185,10 @@ def _load_set(self):
path, label = line.strip().split(",")
self.images.append((path, int(label)))
- def _dump_class_map(self, class_mapper):
+ def _dump_class_map(self, class_mapper: Mapping) -> None:
with self.mapper_path.open("w") as f:
json.dump(class_mapper, f)
- def _load_class_map(self):
+ def _load_class_map(self) -> None:
with self.mapper_path.open("r") as f:
self.classes = json.load(f)
diff --git a/frarch/datasets/mit67.py b/frarch/datasets/mit67.py
index e9be64a..31bfdbb 100644
--- a/frarch/datasets/mit67.py
+++ b/frarch/datasets/mit67.py
@@ -4,9 +4,10 @@
import tarfile
from collections import Counter
from pathlib import Path
-from typing import Callable, Union
+from typing import Callable, List, Union
from urllib.parse import urlparse
+import torch
from PIL import Image
from torch.utils.data import Dataset
@@ -21,6 +22,41 @@
class Mit67(Dataset):
+ """Mit 67 dataset object.
+
+ Data loader for the Mit 67 dataset for indoor scene recognition. The dataset can
+ be obtained from
+ http://groups.csail.mit.edu/vision/LabelMe/NewImages/indoorCVPR_09.tar.
+
+ Args:
+ train (bool): True for loading the train subset and False for valid. Defaults
+ to True.
+ transform (Callable): a callable object that takes an `PIL.Image` object and
+ returns a modified `PIL.Image` object. Defaults to None, which won't apply
+ any transformation.
+ target_transform (Callable): a callable object that the label data and returns
+ modified label data. Defaults to None, which won't apply any transformation.
+ download (bool): True for downloading and storing the dataset data in the `root`
+ directory if it's not present. Defaults to True.
+ root (Union[str, Path]): root directory for the dataset.
+ Defaults to `~/.cache/frarch/datasets/mit67/`.
+
+ References:
+ - http://web.mit.edu/torralba/www/indoor.html
+
+ Examples:
+ Simple usage of the dataset class::
+
+ from frarch.datasets import Mit67
+ from frarch.utils.data import create_dataloader
+ from torchvision.transforms import ToTensor
+
+ dataset = Mit67(True, ToTensor, None, True, "./data/")
+ dataloader = create_dataloader(dataset)
+ for batch_idx, (batch, labels) in enumerate(dataloader):
+ # process batch
+ """
+
def __init__(
self,
train: bool = True,
@@ -28,7 +64,7 @@ def __init__(
target_transform: Callable = None,
download: bool = True,
root: Union[str, Path] = "~/.cache/frarch/datasets/mit67/",
- ):
+ ) -> None:
self.root = Path(root).expanduser()
self.set = "train" if train else "test"
self.transform = transform
@@ -39,7 +75,7 @@ def __init__(
self.mapper_path = self.root / "class_map.json"
if download and not self._detect_dataset():
- self.download_mit_dataset()
+ self._download_mit_dataset()
if not self._detect_dataset():
raise DatasetNotFoundError(
f"download flag not set and dataset not present in {self.root}"
@@ -52,7 +88,7 @@ def __init__(
f" in {len(self.classes)} classes"
)
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Union[torch.Tensor, int]:
path, target = self.images[index]
img = Image.open(path).convert("RGB")
if self.transform is not None:
@@ -61,13 +97,18 @@ def __getitem__(self, index):
target = self.target_transform(target)
return img, target
- def __len__(self):
+ def __len__(self) -> int:
return len(self.images)
- def get_number_classes(self):
+ def get_number_classes(self) -> int:
+ """Get number of target labels.
+
+ Returns:
+ int: number of target labels.
+ """
return len(self.classes)
- def download_mit_dataset(self):
+ def _download_mit_dataset(self) -> None:
self.root.mkdir(parents=True, exist_ok=True)
# download train/val images/annotations
@@ -87,28 +128,28 @@ def download_mit_dataset(self):
logger.info("[dataset] Done!")
cached_file.unlink()
- def _get_file_paths(self):
+ def _get_file_paths(self) -> List[Path]:
return list(self.root.glob("Images/*/*.jpg"))
- def _detect_dataset(self):
+ def _detect_dataset(self) -> bool:
if not self.root.exists():
return False
else:
num_images = len(self._get_file_paths())
return num_images > 0
- def _build_and_load_data_files(self):
+ def _build_and_load_data_files(self) -> None:
all_paths = self._get_file_paths()
self._load_class_map(all_paths)
self._load_train_test_files(all_paths)
- def _load_class_map(self, all_paths):
+ def _load_class_map(self, all_paths: List[Path]) -> None:
if not self.mapper_path.exists():
self._build_class_mapper(all_paths)
with self.mapper_path.open("r") as f:
self.classes = json.load(f)
- def _build_class_mapper(self, all_paths):
+ def _build_class_mapper(self, all_paths: List[Path]) -> None:
classes_set = set(map(lambda path: path.parts[-2], all_paths))
logger.info(f"found {len(classes_set)} classes.")
class_mapper = dict(zip(classes_set, range(len(classes_set))))
@@ -116,12 +157,12 @@ def _build_class_mapper(self, all_paths):
with self.mapper_path.open("w") as f:
json.dump(class_mapper, f)
- def _load_train_test_files(self, all_paths):
+ def _load_train_test_files(self, all_paths: List[Path]) -> None:
if not self.train_lst_path.exists() and not self.valid_lst_path.exists():
self._build_train_test_files(all_paths)
self._load_set(self.set)
- def _build_train_test_files(self, all_paths):
+ def _build_train_test_files(self, all_paths: List[Path]) -> None:
classes_list = list(map(lambda path: path.parts[-2], all_paths))
instance_counter = Counter(classes_list)
@@ -164,7 +205,7 @@ def _build_train_test_files(self, all_paths):
for line in valid_instances:
f.write(",".join(map(str, line)) + "\n")
- def _load_set(self, set):
+ def _load_set(self, set: str) -> None:
path = self.train_lst_path if set == "train" else self.valid_lst_path
with path.open("r") as f:
self.images = []
diff --git a/frarch/datasets/oxford_pets.py b/frarch/datasets/oxford_pets.py
index 4669e2a..5fdb337 100644
--- a/frarch/datasets/oxford_pets.py
+++ b/frarch/datasets/oxford_pets.py
@@ -1,9 +1,10 @@
import logging
import tarfile
from pathlib import Path
-from typing import Callable, Union
+from typing import Callable, List, Union
from urllib.parse import urlparse
+import torch
from PIL import Image
from torch.utils.data import Dataset
@@ -19,6 +20,41 @@
class OxfordPets(Dataset):
+ """Oxford Pets dataset object.
+
+ Data loader for the Oxford Pets dataset for pet recognition. The dataset can be
+ obtained from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz and
+ their corresponding labels in
+ https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz.
+
+ Args:
+ subset (str): "train" or "valid". Subset to load. Defaults to "train".
+ transform (Callable): a callable object that takes an `PIL.Image` object and
+ returns a modified `PIL.Image` object. Defaults to None, which won't apply
+ any transformation.
+ target_transform (Callable): a callable object that the label data and returns
+ modified label data. Defaults to None, which won't apply any transformation.
+ download (bool): True for downloading and storing the dataset data in the `root`
+ directory if it's not present. Defaults to True.
+ root (Union[str, Path]): root directory for the dataset.
+ Defaults to `~/.cache/frarch/datasets/oxford_pets/`.
+
+ References:
+ - https://www.robots.ox.ac.uk/~vgg/data/pets/
+
+ Examples:
+ Simple usage of the dataset class::
+
+ from frarch.datasets import Mit67
+ from frarch.utils.data import create_dataloader
+ from torchvision.transforms import ToTensor
+
+ dataset = OxfordPets( "train", ToTensor, None, True, "./data/")
+ dataloader = create_dataloader(dataset)
+ for batch_idx, (batch, labels) in enumerate(dataloader):
+ # process batch
+ """
+
def __init__(
self,
subset: str = "train",
@@ -26,7 +62,7 @@ def __init__(
target_transform: Callable = None,
download: bool = True,
root: Union[str, Path] = "~/.cache/frarch/datasets/oxford_pets/",
- ):
+ ) -> None:
if subset not in ["train", "valid"]:
raise ValueError(f"set must be train or test not {subset}")
@@ -41,8 +77,8 @@ def __init__(
self.valid_lst_path = self.root / "annotations" / "test.txt"
if download and not self._detect_dataset():
- self.download_dataset()
- self.download_annotations()
+ self._download_dataset()
+ self._download_annotations()
if not self._detect_dataset():
raise DatasetNotFoundError(
f"download flag not set and dataset not present in {self.root}"
@@ -55,7 +91,7 @@ def __init__(
f" in {len(self.classes)} classes"
)
- def __getitem__(self, index):
+ def __getitem__(self, index: int) -> Union[torch.Tensor, int]:
path, target = self.images[index]
img = Image.open(path).convert("RGB")
if self.transform is not None:
@@ -64,19 +100,24 @@ def __getitem__(self, index):
target = self.target_transform(target)
return img, target
- def __len__(self):
+ def __len__(self) -> int:
return len(self.images)
- def get_number_classes(self):
+ def get_number_classes(self) -> int:
+ """Get number of target labels.
+
+ Returns:
+ int: number of target labels.
+ """
return len(self.classes)
- def download_annotations(self):
- self.download_file("images")
+ def _download_annotations(self) -> None:
+ self._download_file("images")
- def download_dataset(self):
- self.download_file("annotations")
+ def _download_dataset(self) -> None:
+ self._download_file("annotations")
- def download_file(self, url_key):
+ def _download_file(self, url_key: str) -> None:
self.root.mkdir(parents=True, exist_ok=True)
# download train/val images/annotations
@@ -96,10 +137,10 @@ def download_file(self, url_key):
logger.info(f"Done! Removing dached file {cached_file}...")
cached_file.unlink()
- def _get_file_paths(self):
+ def _get_file_paths(self) -> List[Path]:
return list(self.images_root.glob("*.jpg"))
- def _detect_dataset(self):
+ def _detect_dataset(self) -> bool:
if not self.root.exists():
return False
else:
@@ -109,7 +150,7 @@ def _detect_dataset(self):
)
return num_images > 0 and annotations_present
- def _load_set(self):
+ def _load_set(self) -> None:
path = self.train_lst_path if self.set == "train" else self.valid_lst_path
with path.open("r") as f:
self.images = []
diff --git a/frarch/datasets/transforms/pil_transforms.py b/frarch/datasets/transforms/pil_transforms.py
index 803de84..6d39cb3 100644
--- a/frarch/datasets/transforms/pil_transforms.py
+++ b/frarch/datasets/transforms/pil_transforms.py
@@ -1,5 +1,5 @@
import random
-from typing import Callable, Iterable, Union
+from typing import Callable, Iterable, Optional, Union
import PIL.Image as im
import PIL.ImageEnhance as ie
@@ -12,10 +12,17 @@
class RandomFlip(object):
"""Randomly flips the given PIL.Image.
- Probability of 0.25 horizontal, 0.25 vertical, 0.5 as is
+ Probability of 0.25 horizontal flip, 0.25 vertical flip, 0.5 no flip.
+
+ Example:
+ Simple usage of the class::
+
+ image = PIL.Image.open("testimage.jpg")
+ random_flip = RandomFlip()
+ processed_image = random_flip(image)
"""
- def __call__(self, img: Image):
+ def __call__(self, img: Image) -> Image:
if not isinstance(img, Image):
raise ValueError(f"img is {type(img)} not a PIL.Image object")
@@ -33,9 +40,16 @@ class RandomRotate(object):
"""Randomly rotate the given PIL.Image.
Probability of 1/6 90°, 1/6 180°, 1/6 270°, 1/2 as is.
+
+ Example:
+ Simple usage of the class::
+
+ image = PIL.Image.open("testimage.jpg")
+ random_rotate = RandomRotate()
+ processed_image = random_rotate(image)
"""
- def __call__(self, img: Image):
+ def __call__(self, img: Image) -> Image:
if not isinstance(img, Image):
raise ValueError(f"img is {type(img)} not a PIL.Image object")
@@ -52,7 +66,23 @@ def __call__(self, img: Image):
class PILColorBalance(object):
- def __init__(self, var: float):
+ """Randomly dim or enhance color of an image.
+
+ Randomly dim or enhance color of an image given as an input. Given var value,
+ enhance color from a 1-var factor to a 1+var factor.
+
+ Args:
+ var (float): float value to get random alpha value from [1-var, 1+var].
+
+ Example:
+ Simple usage of the class::
+
+ image = PIL.Image.open("testimage.jpg")
+ random_color_balance = PILColorBalance()
+ processed_image = random_color_balance(image)
+ """
+
+ def __init__(self, var: float) -> None:
if not isinstance(var, float):
raise ValueError(f"{self.__class__.__name__}.var must be a float value")
if var < 0 or var > 1:
@@ -61,7 +91,7 @@ def __init__(self, var: float):
)
self.var = var
- def __call__(self, img: Image):
+ def __call__(self, img: Image) -> Image:
if not isinstance(img, Image):
raise ValueError(f"img is {type(img)} not a PIL.Image object")
@@ -70,7 +100,23 @@ def __call__(self, img: Image):
class PILContrast(object):
- def __init__(self, var: float):
+ """Randomly dim or enhance contrast of an image.
+
+ Randomly dim or enhance contrast of an image given as an input. Given var value,
+ enhance contrast from a 1-var factor to a 1+var factor.
+
+ Args:
+ var (float): float value to get random alpha value from [1-var, 1+var].
+
+ Example:
+ Simple usage of the class::
+
+ image = PIL.Image.open("testimage.jpg")
+ random_contrast = PILContrast()
+ processed_image = random_contrast(image)
+ """
+
+ def __init__(self, var: float) -> None:
if not isinstance(var, float):
raise ValueError(f"{self.__class__.__name__}.var must be a float value")
if var < 0 or var > 1:
@@ -79,7 +125,7 @@ def __init__(self, var: float):
)
self.var = var
- def __call__(self, img: Image):
+ def __call__(self, img: Image) -> Image:
if not isinstance(img, Image):
raise ValueError(f"img is {type(img)} not a PIL.Image object")
@@ -88,7 +134,23 @@ def __call__(self, img: Image):
class PILBrightness(object):
- def __init__(self, var: float):
+ """Randomly dim or enhance brightness of an image.
+
+ Randomly dim or enhance brightness of an image given as an input. Given var value,
+ enhance brightness from a 1-var factor to a 1+var factor.
+
+ Args:
+ var (float): float value to get random alpha value from [1-var, 1+var].
+
+ Example:
+ Simple usage of the class::
+
+ image = PIL.Image.open("testimage.jpg")
+ random_brightness = PILBrightness()
+ processed_image = random_brightness(image)
+ """
+
+ def __init__(self, var: float) -> None:
if not isinstance(var, float):
raise ValueError(f"{self.__class__.__name__}.var must be a float value")
if var < 0 or var > 1:
@@ -97,7 +159,7 @@ def __init__(self, var: float):
)
self.var = var
- def __call__(self, img: Image):
+ def __call__(self, img: Image) -> Image:
if not isinstance(img, Image):
raise ValueError(f"img is {type(img)} not a PIL.Image object")
@@ -106,7 +168,23 @@ def __call__(self, img: Image):
class PILSharpness(object):
- def __init__(self, var: float):
+ """Randomly dim or enhance sharpness of an image.
+
+ Randomly dim or enhance sharpness of an image given as an input. Given var value,
+ enhance sharpness from a 1-var factor to a 1+var factor.
+
+ Args:
+ var (float): float value to get random alpha value from [1-var, 1+var].
+
+ Example:
+ Simple usage of the class::
+
+ image = PIL.Image.open("testimage.jpg")
+ random_sharpness = PILSharpness()
+ processed_image = random_sharpness(image)
+ """
+
+ def __init__(self, var: float) -> None:
if not isinstance(var, float):
raise ValueError(f"{self.__class__.__name__}.var must be a float value")
if var < 0 or var > 1:
@@ -115,7 +193,7 @@ def __init__(self, var: float):
)
self.var = var
- def __call__(self, img: Image):
+ def __call__(self, img: Image) -> Image:
if not isinstance(img, Image):
raise ValueError(f"img is {type(img)} not a PIL.Image object")
@@ -124,9 +202,30 @@ def __call__(self, img: Image):
class RandomOrder(object):
- """Composes several transforms together in random order."""
+ """Composes several transforms together in random order.
+
+ Args:
+ transforms (Iterable[Callable]): An iterable containing callable objects
+ which take PIL.Image as input and outputs the same object type.
+
+ Example:
+ Simple usage of the class::
+
+ image = PIL.Image.open("testimage.jpg")
+ random_sharpness = RandomOrder(
+ [
+ RandomFlip(),
+ RandomRotate(),
+ PILColorBalance(0.1),
+ PILContrast(0.1),
+ PILBrightness(0.1),
+ PILSharpness(0.1),
+ ]
+ )
+ processed_image = random_sharpness(image)
+ """
- def __init__(self, transforms: Union[Iterable[Callable], NoneType]):
+ def __init__(self, transforms: Optional[Iterable[Callable]]) -> None:
if not isinstance(transforms, (Iterable, NoneType)):
raise ValueError("transforms must be an iterable object")
if transforms is not None:
@@ -136,7 +235,7 @@ def __init__(self, transforms: Union[Iterable[Callable], NoneType]):
raise ValueError("all objects in transforms must be callable")
self.transforms = transforms
- def __call__(self, img: Image):
+ def __call__(self, img: Image) -> Image:
if not isinstance(img, Image):
raise ValueError(f"img is {type(img)} not a PIL.Image object")
@@ -149,6 +248,20 @@ def __call__(self, img: Image):
class PowerPIL(RandomOrder):
+ """Composes several transforms together in random order.
+
+ Args:
+ transforms (Iterable[Callable]): An iterable containing callable objects
+ which take PIL.Image as input and outputs the same object type.
+
+ Example:
+ Simple usage of the class::
+
+ image = PIL.Image.open("testimage.jpg")
+ powerpil = PowerPIL(True, True, 0.1, 0.1, 0.1, 0.1)
+ processed_image = powerpil(image)
+ """
+
def __init__(
self,
rotate: bool = True,
@@ -157,7 +270,7 @@ def __init__(
contrast: float = 0.4,
brightness: float = 0.4,
sharpness: float = 0.4,
- ):
+ ) -> None:
self._check_parameters(
rotate, flip, colorbalance, contrast, brightness, sharpness
)
@@ -176,7 +289,9 @@ def __init__(
self.transforms.append(PILSharpness(sharpness))
@staticmethod
- def _check_parameters(rotate, flip, colorbalance, contrast, brightness, sharpness):
+ def _check_parameters(
+ rotate, flip, colorbalance, contrast, brightness, sharpness
+ ) -> None:
if not isinstance(rotate, bool):
raise ValueError("rotate must be boolean")
if not isinstance(flip, bool):
diff --git a/frarch/models/classification/cnn/__init__.py b/frarch/models/classification/cnn/__init__.py
index 779569d..f532d2e 100644
--- a/frarch/models/classification/cnn/__init__.py
+++ b/frarch/models/classification/cnn/__init__.py
@@ -1,6 +1,6 @@
from .fashionCNN import FashionClassifier, FashionCNN
+from .mitCNNs import MitCNN, MitCNNClassifier
from .mnistCNN import MNISTCNN, MNISTClassifier
-from .smallCNNs import MitCNN, MitCNNClassifier
from .vgg import (
VGG,
VGGClassifier,
diff --git a/frarch/models/classification/cnn/fashionCNN.py b/frarch/models/classification/cnn/fashionCNN.py
index a2747f5..3c2b927 100644
--- a/frarch/models/classification/cnn/fashionCNN.py
+++ b/frarch/models/classification/cnn/fashionCNN.py
@@ -1,8 +1,16 @@
+import torch
import torch.nn as nn
class FashionCNN(nn.Module):
- def __init__(self, out_size=128):
+ """Small CNN network for FashionMNIST dataset.
+
+ Args:
+ out_size (int): Size of the output embedding for the feature extraction
+ network. Defaults to 128.
+ """
+
+ def __init__(self, out_size: int = 128) -> None:
super(FashionCNN, self).__init__()
self.layer1 = nn.Sequential(
@@ -24,7 +32,17 @@ def __init__(self, out_size=128):
self.fc2 = nn.Linear(in_features=512, out_features=out_size)
self.relu = nn.ReLU()
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Define the computation performed at every call.
+
+ forward computation for FashionCNN.
+
+ Args:
+ x (torch.Tensor): input to the model.
+
+ Returns:
+ torch.Tensor: output of the model.
+ """
out = self.layer1(x)
out = self.layer2(out)
out = out.view(out.size(0), -1)
@@ -36,9 +54,26 @@ def forward(self, x):
class FashionClassifier(nn.Module):
- def __init__(self, embedding_size=128, classes=10):
+ """Classifier network for FashionCNN.
+
+ Args:
+ embedding_size (int): embedding size from FashionCNN network. Defaults to 128.
+ classes (int): number of output classes for the classifier. Defaults to 10.
+ """
+
+ def __init__(self, embedding_size: int = 128, classes: int = 10) -> None:
super(FashionClassifier, self).__init__()
self.fc = nn.Linear(in_features=embedding_size, out_features=classes)
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Define the computation performed at every call.
+
+ forward computation for FashionCNN.
+
+ Args:
+ x (torch.Tensor): input to the model.
+
+ Returns:
+ torch.Tensor: output of the model.
+ """
return self.fc(x)
diff --git a/frarch/models/classification/cnn/smallCNNs.py b/frarch/models/classification/cnn/mitCNNs.py
similarity index 51%
rename from frarch/models/classification/cnn/smallCNNs.py
rename to frarch/models/classification/cnn/mitCNNs.py
index 01a65d2..9b6faf0 100644
--- a/frarch/models/classification/cnn/smallCNNs.py
+++ b/frarch/models/classification/cnn/mitCNNs.py
@@ -1,9 +1,19 @@
+import torch
import torch.nn as nn
import torch.nn.functional as F
class MitCNN(nn.Module):
- def __init__(self, input_channels=1, embedding_size=256):
+ """Small CNN network for Mit67 dataset.
+
+ Args:
+ input_channels (int): Number of input channels in the input tensor.
+ Defaults to 1.
+ embedding_size (int): Size of the output embedding for the feature extraction
+ network. Defaults to 256.
+ """
+
+ def __init__(self, input_channels: int = 1, embedding_size: int = 256) -> None:
super(MitCNN, self).__init__()
self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding="same")
self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding="same")
@@ -12,7 +22,17 @@ def __init__(self, input_channels=1, embedding_size=256):
self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding="same")
self.fc1 = nn.Linear(256, embedding_size)
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Define the computation performed at every call.
+
+ forward computation for MitCNN.
+
+ Args:
+ x (torch.Tensor): input to the model.
+
+ Returns:
+ torch.Tensor: output of the model.
+ """
x = F.relu(self.conv1(x))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = F.dropout(x, p=0.5, training=self.training)
@@ -29,10 +49,27 @@ def forward(self, x):
class MitCNNClassifier(nn.Module):
- def __init__(self, embedding_size=256, num_classes=10):
+ """Classifier network for MitCNN.
+
+ Args:
+ embedding_size (int): embedding size from MitCNN network. Defaults to 256.
+ classes (int): number of output classes for the classifier. Defaults to 10.
+ """
+
+ def __init__(self, embedding_size: int = 256, num_classes: int = 10) -> None:
super(MitCNNClassifier, self).__init__()
self.fc2 = nn.Linear(embedding_size, num_classes)
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Define the computation performed at every call.
+
+ forward computation for MitCNNClassifier.
+
+ Args:
+ x (torch.Tensor): input to the model.
+
+ Returns:
+ torch.Tensor: output of the model.
+ """
x = self.fc2(x)
return F.log_softmax(x, dim=1)
diff --git a/frarch/models/classification/cnn/mnistCNN.py b/frarch/models/classification/cnn/mnistCNN.py
index 5bc5dfb..576d69d 100644
--- a/frarch/models/classification/cnn/mnistCNN.py
+++ b/frarch/models/classification/cnn/mnistCNN.py
@@ -1,16 +1,36 @@
+import torch
import torch.nn as nn
import torch.nn.functional as F
class MNISTCNN(nn.Module):
- def __init__(self, input_channels=1, embedding_size=256):
+ """Small CNN network for MNIST dataset.
+
+ Args:
+ input_channels (int): Number of input channels in the input tensor.
+ Defaults to 1.
+ embedding_size (int): Size of the output embedding for the feature extraction
+ network. Defaults to 256.
+ """
+
+ def __init__(self, input_channels: int = 1, embedding_size: int = 256) -> None:
super(MNISTCNN, self).__init__()
self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=10, padding="same")
self.conv2 = nn.Conv2d(32, 32, kernel_size=10, padding="same")
self.conv3 = nn.Conv2d(32, 64, kernel_size=10, padding="same")
self.fc1 = nn.Linear(64 * 7 * 7, embedding_size)
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Define the computation performed at every call.
+
+ forward computation for MNISTCNN.
+
+ Args:
+ x (torch.Tensor): input to the model.
+
+ Returns:
+ torch.Tensor: output of the model.
+ """
x = F.relu(self.conv1(x))
x = F.relu(F.max_pool2d(self.conv2(x), 2))
x = F.dropout(x, p=0.5, training=self.training)
@@ -23,10 +43,27 @@ def forward(self, x):
class MNISTClassifier(nn.Module):
- def __init__(self, embedding_size=256, num_classes=10):
+ """Classifier network for MNISTCNN.
+
+ Args:
+ embedding_size (int): embedding size from MNISTCNN network. Defaults to 256.
+ classes (int): number of output classes for the classifier. Defaults to 10.
+ """
+
+ def __init__(self, embedding_size: int = 256, num_classes: int = 10) -> None:
super(MNISTClassifier, self).__init__()
self.fc2 = nn.Linear(embedding_size, num_classes)
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Define the computation performed at every call.
+
+ forward computation for MNISTClassifier.
+
+ Args:
+ x (torch.Tensor): input to the model.
+
+ Returns:
+ torch.Tensor: output of the model.
+ """
x = self.fc2(x)
return F.log_softmax(x, dim=1)
diff --git a/frarch/models/classification/cnn/resnet.py b/frarch/models/classification/cnn/resnet.py
index f875d25..eb71a7f 100644
--- a/frarch/models/classification/cnn/resnet.py
+++ b/frarch/models/classification/cnn/resnet.py
@@ -158,6 +158,23 @@ def forward(self, x: Tensor) -> Tensor:
class ResNet(nn.Module):
+ """Residual network architecture model.
+
+ Args:
+ block (ResNetBlock): basic construction block for architecture.
+ layers (List[int]): layer configuration. I.e. [2, 2, 2, 2].
+ num_classes (int)): output classes. Defaults to 1000.
+ zero_init_residual (bool): True to initialize the residual layers with zeros.
+ Default False.
+ groups (int): groups for convolutional layers in residual blocks.
+ Defaults to 1.
+ width_per_group (int): Width for convolutional residual blocks. Defaults to 64.
+ replace_stride_with_dilation (Optional[List[bool]]): Optional list of boolean
+ values to replace stride with dilation. Defaults to None.
+ norm_layer (Optional[Callable[..., nn.Module]]): Norm layer. If None, defaults
+ to BatchNorm2d. Defautls to None.
+ """
+
def __init__(
self,
block: ResNetBlock,
@@ -273,7 +290,6 @@ def _make_layer(
return nn.Sequential(*layers)
def _forward_impl(self, x: Tensor) -> Tensor:
- # See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
@@ -291,6 +307,16 @@ def _forward_impl(self, x: Tensor) -> Tensor:
return x
def forward(self, x: Tensor) -> Tensor:
+ """Define the computation performed at every call.
+
+ forward computation for ResNet.
+
+ Args:
+ x (torch.Tensor): input to the model.
+
+ Returns:
+ torch.Tensor: output of the model.
+ """
return self._forward_impl(x)
@@ -315,8 +341,7 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
From `"Deep Residual Learning for Image Recognition"
`_.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
@@ -329,8 +354,7 @@ def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
From `"Deep Residual Learning for Image Recognition"
`_.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
@@ -343,8 +367,7 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
From `"Deep Residual Learning for Image Recognition"
`_.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
@@ -357,8 +380,7 @@ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
From `"Deep Residual Learning for Image Recognition"
`_.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
@@ -373,8 +395,7 @@ def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
From `"Deep Residual Learning for Image Recognition"
`_.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
@@ -391,8 +412,7 @@ def resnext50_32x4d(
From `"Aggregated Residual Transformation for Deep Neural Networks"
`_.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
@@ -411,8 +431,7 @@ def resnext101_32x8d(
From `"Aggregated Residual Transformation for Deep Neural Networks"
`_.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
@@ -434,8 +453,7 @@ def wide_resnet50_2(
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
@@ -456,8 +474,7 @@ def wide_resnet101_2(
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
diff --git a/frarch/models/classification/cnn/vgg.py b/frarch/models/classification/cnn/vgg.py
index 3f0d542..c2e871d 100644
--- a/frarch/models/classification/cnn/vgg.py
+++ b/frarch/models/classification/cnn/vgg.py
@@ -12,15 +12,14 @@
import logging
from collections import OrderedDict
-from typing import Any, Dict, List, Union, cast
+from typing import Any, Dict, Iterable, List, Mapping, Union, cast
import torch
import torch.nn as nn
+import torchvision
from torch.hub import load_state_dict_from_url
-__all__ = [
- "VGG",
-]
+__all__ = torchvision.models.vgg.__all__
l11 = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]
l13 = [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]
@@ -88,26 +87,48 @@
class VGG(nn.Module):
+ """VGG network definition.
+
+ From `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
+ `_.
+
+ Args:
+ layers_cfg (List[Union[str, int]]): vgg layer configuration.
+ batch_norm (bool, optional): Boolean flag for doing batch normalization.
+ Defaults to True.
+ init_weights (bool, optional): Force weight initialization.
+ Defaults to True.
+ pretrained (bool, optional): Get pretrained model. Defaults to False.
+ """
+
def __init__(
self,
layers_cfg: List[Union[str, int]],
batch_norm: bool = True,
init_weights: bool = True,
pretrained: bool = False,
- padding: str = "same",
) -> None:
super(VGG, self).__init__()
self.layers_cfg = layers_cfg
self.batch_norm = batch_norm
- self.padding = padding
- self.features = self.make_layers(layers_cfg, batch_norm)
- self.avgpool = nn.AdaptiveAvgPool2d((7, 7) if padding == "same" else (3, 3))
+ self.features = self._make_layers(layers_cfg, batch_norm)
+ self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
if init_weights:
self._initialize_weights()
elif pretrained:
self._load_pretrained(layers_cfg, batch_norm)
def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Define the computation performed at every call.
+
+ forward computation for VGG.
+
+ Args:
+ x (torch.Tensor): input to the model.
+
+ Returns:
+ torch.Tensor: output of the model.
+ """
x = self.features(x)
x = self.avgpool(x)
return torch.flatten(x, 1)
@@ -125,7 +146,7 @@ def _initialize_weights(self) -> None:
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
- def make_layers(
+ def _make_layers(
self, layers_cfg: List[Union[str, int]], batch_norm: bool = False
) -> nn.Sequential:
layers: List[nn.Module] = []
@@ -135,7 +156,7 @@ def make_layers(
layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
else:
v = cast(int, v)
- conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=self.padding)
+ conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding="same")
layers.append(conv2d)
if batch_norm:
layers.append(nn.BatchNorm2d(v))
@@ -167,11 +188,24 @@ def _load_pretrained(self, layers_cfg, batch_norm):
class VGGClassifier(nn.Module):
+ """VGG classifier network.
+
+ Args:
+ num_classes ([type]): number of output classes.
+ init_weights (bool, optional): Initialize weights. Defaults to True.
+ pretrained (bool, optional): Pretrained architecture. Defaults to False.
+ arch (str, optional): Architecture to load as pretrained. Defaults to "".
+ """
+
def __init__(
- self, num_classes, init_weights=True, pretrained=False, arch="", padding="same"
- ):
+ self,
+ num_classes: int,
+ init_weights: bool = True,
+ pretrained: bool = False,
+ arch: str = "",
+ ) -> None:
super(VGGClassifier, self).__init__()
- self.in_features = 25088 if padding == "same" else 4608
+ self.in_features = 25088
self.num_classes = num_classes
self.classifier = nn.Sequential(
nn.Linear(self.in_features, 4096),
@@ -200,19 +234,31 @@ def _initialize_weights(self) -> None:
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
- def _load_pretrained(self, arch):
+ def _load_pretrained(self, arch: str) -> None:
state_dict = load_state_dict_from_url(model_urls[arch], progress=True)
_, classifier_state_dict = split_state_dict(
state_dict, "features", "classifier"
)
self.load_state_dict(classifier_state_dict)
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Define the computation performed at every call.
+
+ forward computation for VGG.
+
+ Args:
+ x (torch.Tensor): input to the model.
+
+ Returns:
+ torch.Tensor: output of the model.
+ """
x = self.classifier(x)
return x
-def split_state_dict(state_dict, *search_strings):
+def split_state_dict(
+ state_dict: Mapping, *search_strings: Iterable[str]
+) -> List[Mapping]:
results = [OrderedDict() for _ in range(len(search_strings))]
for i, string in enumerate(search_strings):
for k in state_dict.keys():
@@ -255,96 +301,104 @@ def vggclassifier(
def vgg11(pretrained: bool = False, **kwargs: Any) -> VGG:
- r"""VGG 11-layer model from `_.
+ """Create VGG 11-layer model.
- The required minimum input size of the model is 32x32.
+ VGG 11-layer model from `"Very Deep Convolutional Networks for Large-Scale Image
+ Recognition" `_. The required minimum input
+ size of the model is 32x32.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg("11", False, pretrained, **kwargs)
def vgg11_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
- r"""VGG 11-layer model with batch normalization `_.
+ """Create VGG 11-layer with batch normalization model.
+ VGG 11-layer model with batch normalization from `"Very Deep Convolutional Networks
+ for Large-Scale Image Recognition" `_.
The required minimum input size of the model is 32x32.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg("11", True, pretrained, **kwargs)
def vgg13(pretrained: bool = False, **kwargs: Any) -> VGG:
- r"""VGG 13-layer model `_.
+ """Create VGG 13-layer model.
- The required minimum input size of the model is 32x32.
+ VGG 13-layer model from `"Very Deep Convolutional Networks for Large-Scale Image
+ Recognition" `_. The required minimum input
+ size of the model is 32x32.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg("13", False, pretrained, **kwargs)
def vgg13_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
- r"""VGG 13-layer model with batch normalization `_.
+ """Create VGG 13-layer with batch normalization model.
+ VGG 13-layer model with batch normalization from `"Very Deep Convolutional Networks
+ for Large-Scale Image Recognition" `_.
The required minimum input size of the model is 32x32.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg("13", True, pretrained, **kwargs)
def vgg16(pretrained: bool = False, **kwargs: Any) -> VGG:
- r"""VGG 16-layer model `_.
+ """Create VGG 16-layer model.
- The required minimum input size of the model is 32x32.
+ VGG 16-layer model from `"Very Deep Convolutional Networks for Large-Scale Image
+ Recognition" `_. The required minimum input
+ size of the model is 32x32.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg("16", False, pretrained, **kwargs)
def vgg16_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
- r"""VGG 16-layer model with batch normalization `_.
+ """Create VGG 16-layer with batch normalization model.
+ VGG 16-layer model with batch normalization from `"Very Deep Convolutional Networks
+ for Large-Scale Image Recognition" `_.
The required minimum input size of the model is 32x32.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg("16", True, pretrained, **kwargs)
def vgg19(pretrained: bool = False, **kwargs: Any) -> VGG:
- r"""VGG 19-layer model `_.
+ """Create VGG 19-layer model.
- The required minimum input size of the model is 32x32.
+ VGG 19-layer model from `"Very Deep Convolutional Networks for Large-Scale Image
+ Recognition" `_. The required minimum input
+ size of the model is 32x32.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg("19", False, pretrained, **kwargs)
def vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
- r"""VGG 19-layer modelith batch normalization `_.
+ """Create VGG 19-layer with batch normalization model.
+ VGG 19-layer model with batch normalization from `"Very Deep Convolutional Networks
+ for Large-Scale Image Recognition" `_.
The required minimum input size of the model is 32x32.
- Args
- ----
+ Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
return _vgg("19", True, pretrained, **kwargs)
diff --git a/frarch/models/mock.py b/frarch/models/mock.py
index 7f17664..219242d 100644
--- a/frarch/models/mock.py
+++ b/frarch/models/mock.py
@@ -2,8 +2,21 @@
class BypassModel(torch.nn.Module):
- def __init__(self):
+ """Bypass mock model.
+
+ Module that returns the same argument that is given.
+ """
+
+ def __init__(self) -> None:
super(BypassModel, self).__init__()
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Compute forward for bypass model.
+
+ Args:
+ x (torch.Tensor): input to the model.
+
+ Returns:
+ torch.Tensor: same tensor as the input.
+ """
return x
diff --git a/frarch/modules/checkpointer.py b/frarch/modules/checkpointer.py
index 4dc31c5..55a113c 100644
--- a/frarch/modules/checkpointer.py
+++ b/frarch/modules/checkpointer.py
@@ -14,25 +14,51 @@
class Checkpointer:
+ """Class for managing checkpoints.
+
+ Args:
+ save_path (Union[str, Path]): folder to store the checkpoint and training
+ data to.
+ modules (Mapping[str, torch.nn.Module]): dict-like structure with modules.
+ save_best_only (bool, optional): If true, save only the best model according to
+ some metric. If True, reference metric should be specified. If False, save
+ all end of epoch checkpoints. Defaults to False.
+ reference_metric (str, optional): Metric to use to determine the best model when
+ save_best_only is True. Must be a string in the keys of the modules
+ dict-like structure. Defaults to None.
+ mode (str, optional): min if lower is better, max if higher is better.
+ Examples: min for error and max for accuracy. Defaults to "min".
+
+ Raises:
+ ValueError: modules are not a dict or torch.nn.ModuleDict instance
+ ValueError: modules in modules dict-like don't have string keys or
+ torch.nn.Module values
+ ValueError: path must be a string or Path object
+ ValueError: metadata key is reserved for the metadata.json object.
+ ValueError: metric mode is not min or max
+ ValueError: save_best_only is True and no metric is defined.
+ """
+
def __init__(
self,
save_path: Union[str, Path],
modules: Mapping[str, torch.nn.Module],
- save_best_only: bool = True,
+ save_best_only: bool = False,
reference_metric: str = None,
mode: str = "min",
- ):
+ ) -> None:
if not isinstance(modules, (dict, torch.nn.ModuleDict)):
raise ValueError("modules must be a dict or torch.nn.ModuleDict instance")
- elif not all(isinstance(k, str) and isinstance(v, torch.nn.Module) for k, v in modules.items()):
+ elif not all(
+ isinstance(k, str) and isinstance(v, torch.nn.Module)
+ for k, v in modules.items()
+ ):
raise ValueError("modules must have string keys and torch.nn.Module values")
if not isinstance(save_path, (str, Path)):
raise ValueError("path must be a string or Path object")
if "metadata" in modules:
- raise ValueError(
- "metadata in modules is reserved for metadata json object"
- )
+ raise ValueError("metadata in modules is reserved for metadata json object")
if mode not in METRIC_MODES:
raise ValueError(f"metric mode must be in {METRIC_MODES} not {mode}")
if save_best_only and reference_metric is None:
@@ -49,12 +75,13 @@ def __init__(
self.reference_metric = reference_metric
self.mode = mode
- def save_initial_weights(self):
+ def save_initial_weights(self) -> None:
+ """Save weights with which the model has been initialized."""
time_str = str(datetime.now())
ckpt_folder = "initial_weights"
paths = self._build_paths(ckpt_folder)
self._save_modules(paths)
- self._save_json(
+ self._save_metadata(
time_str,
paths["metadata"],
epoch=-1,
@@ -69,13 +96,23 @@ def save(
intra_epoch: bool = False,
extra_data: Dict = None,
**metrics: Dict[str, Any],
- ):
+ ) -> None:
+ """Save checkpoint.
+
+ Args:
+ epoch (int): current epoch index.
+ current_step (int): current batch index.
+ intra_epoch (bool, optional): boolean flag to indicate if the checkpoint is
+ intra epoch if true and end of epoch if false. Defaults to False.
+ extra_data (Dict, optional): extra metadata in json format to add to the
+ metadata.json file. Defaults to None.
+ """
time_str = str(datetime.now())
ckpt_folder = f"ckpt_{time_str.replace(' ', '_')}"
paths = self._build_paths(ckpt_folder)
self._save_modules(paths)
- self._save_json(
+ self._save_metadata(
time_str,
paths["metadata"],
epoch=epoch,
@@ -90,9 +127,56 @@ def save(
logger.info(f"Saved end_of_epoch model to {ckpt_folder}")
if not intra_epoch and self.reference_metric is not None:
- self.update_best_metric(**metrics)
+ self._update_best_metric(**metrics)
if self.save_best_only:
- self.remove_old_ckpts(ckpt_folder)
+ self._remove_old_ckpts(ckpt_folder)
+
+ def load(self, mode="last", **load_kwargs) -> None:
+ """Load checkpoint from folder.
+
+ Args:
+ mode (str, optional): last for loading the last checkpoint stored and best
+ to load the model with the mest metric. Defaults to "last".
+
+ Raises:
+ ValueError: mode is not best or last
+ """
+ if mode == "best":
+ return self._load_best_checkpoint(**load_kwargs)
+ elif mode == "last":
+ return self._load_last_checkpoint(**load_kwargs)
+ else:
+ raise ValueError('load\'s mode kwarg can be "best" or "last"')
+
+ def exists_checkpoint(self) -> bool:
+ """Check if save_path contains a checkpoint folder.
+
+ Returns:
+ bool: True if it contains a checkpoint folder, False if not.
+ """
+ for folder in self.base_path.iterdir():
+ if self._is_ckpt_dir(folder):
+ return True
+ return False
+
+ @property
+ def current_epoch(self) -> int:
+ if len(self.metadata) >= 0 and "epoch" in self.metadata:
+ return int(self.metadata["epoch"])
+ return 0
+
+ @property
+ def next_epoch(self) -> int:
+ if self._is_intraepoch():
+ return self.current_epoch
+ else:
+ return self.current_epoch + 1
+
+ @property
+ def step(self) -> int:
+ if self._is_intraepoch():
+ return self.metadata["step"]
+ return 0
def _build_paths(self, ckpt_folder_name: str) -> dict:
paths = {}
@@ -110,7 +194,7 @@ def _save_modules(self, paths: Dict[str, Path]):
paths[module_name],
)
- def _save_json(
+ def _save_metadata(
self,
time_str: str,
metadata_path: Union[Path, str],
@@ -119,7 +203,7 @@ def _save_json(
intra_epoch: bool = False,
extra_data: Dict = None,
**metrics,
- ):
+ ) -> None:
self.metadata = {
"intra_epoch": intra_epoch,
"step": step,
@@ -132,14 +216,14 @@ def _save_json(
with open(metadata_path, "w") as metadata_handler:
json.dump(self.metadata, metadata_handler, indent=4)
- def update_best_metric(self, **metrics):
+ def _update_best_metric(self, **metrics):
if self.best_metric is None:
self.best_metric = metrics[self.reference_metric]
else:
- if self.is_better(metrics[self.reference_metric], self.best_metric):
+ if self._is_better(metrics[self.reference_metric], self.best_metric):
self.best_metric = metrics[self.reference_metric]
- def remove_old_ckpts(self, curr_ckpt_folder):
+ def _remove_old_ckpts(self, curr_ckpt_folder):
for old_ckpt in self.base_path.iterdir():
if old_ckpt.name == curr_ckpt_folder or not str(old_ckpt.name).startswith(
"ckpt_"
@@ -150,56 +234,42 @@ def remove_old_ckpts(self, curr_ckpt_folder):
if self.reference_metric not in old_metadata:
shutil.rmtree(old_ckpt)
- elif not self.is_better(
+ elif not self._is_better(
old_metadata[self.reference_metric], self.best_metric
):
shutil.rmtree(old_ckpt)
- def is_better(self, new_metric, old_metric) -> bool:
+ def _is_better(self, new_metric, old_metric) -> bool:
if self.mode == "min":
return new_metric <= old_metric
elif self.mode == "max":
return new_metric >= old_metric
- def load(self, mode="last", **load_kwargs) -> bool:
- if mode == "best":
- return self.load_best_checkpoint(**load_kwargs)
- elif mode == "last":
- return self.load_last_checkpoint(**load_kwargs)
- else:
- raise ValueError('load\'s mode kwarg can be "best" or "last"')
-
- def exists_checkpoint(self) -> bool:
- for folder in self.base_path.iterdir():
- if self.is_ckpt_dir(folder):
- return True
- return False
-
- def load_best_checkpoint(self, **load_kwargs) -> bool:
- ckpts_meta = self.load_checkpoints_meta()
+ def _load_best_checkpoint(self, **load_kwargs) -> None:
+ ckpts_meta = self._load_checkpoints_meta()
ckpts_meta.pop("initial_weights")
cmp_fn = min if self.mode == "min" else max
best_ckpt_name = cmp_fn(
ckpts_meta, key=lambda i: ckpts_meta[i][self.reference_metric]
)
- return self.load_checkpoint_from_folder(
+ return self._load_checkpoint_from_folder(
best_ckpt_name, ckpts_meta[best_ckpt_name], **load_kwargs
)
- def load_last_checkpoint(self, **load_kwargs) -> bool:
- ckpts_meta = self.load_checkpoints_meta()
+ def _load_last_checkpoint(self, **load_kwargs) -> None:
+ ckpts_meta = self._load_checkpoints_meta()
latest_ckpt_name = max(ckpts_meta, key=lambda i: ckpts_meta[i]["time"])
- return self.load_checkpoint_from_folder(
+ return self._load_checkpoint_from_folder(
latest_ckpt_name, ckpts_meta[latest_ckpt_name], **load_kwargs
)
- def load_checkpoints_meta(self) -> dict:
+ def _load_checkpoints_meta(self) -> dict:
ckpts_meta = {}
for folder in self.base_path.iterdir():
if not folder.is_dir():
continue
- if self.is_ckpt_dir(folder):
+ if self._is_ckpt_dir(folder):
metadata_path = folder / "metadata.json"
with open(metadata_path, "r") as f:
ckpts_meta[folder.name] = json.load(f)
@@ -210,15 +280,14 @@ def load_checkpoints_meta(self) -> dict:
return ckpts_meta
- def load_checkpoint_from_folder(self, ckpt_folder_name, metadata, **load_kwargs) -> bool:
+ def _load_checkpoint_from_folder(
+ self, ckpt_folder_name, metadata, **load_kwargs
+ ) -> None:
paths = self._build_paths(ckpt_folder_name)
for module_name in self.modules:
try:
self.modules[module_name].load_state_dict(
- torch.load(
- paths[module_name],
- **load_kwargs
- )
+ torch.load(paths[module_name], **load_kwargs)
)
except Exception as e:
logger.error(f"Failed loading ckpt from {ckpt_folder_name}.")
@@ -228,30 +297,10 @@ def load_checkpoint_from_folder(self, ckpt_folder_name, metadata, **load_kwargs)
logger.info(
f"Loaded ckpt from epoch {self.current_epoch} from {ckpt_folder_name}"
)
- return True
@staticmethod
- def is_ckpt_dir(path: Union[str, Path]):
+ def _is_ckpt_dir(path: Union[str, Path]):
return str(path.name).startswith("ckpt_") or str(path.name) == "initial_weights"
- def is_intraepoch(self) -> bool:
+ def _is_intraepoch(self) -> bool:
return self.metadata["intra_epoch"]
-
- @property
- def current_epoch(self) -> int:
- if len(self.metadata) >= 0 and "epoch" in self.metadata:
- return int(self.metadata["epoch"])
- return 0
-
- @property
- def next_epoch(self) -> int:
- if self.is_intraepoch():
- return self.current_epoch
- else:
- return self.current_epoch + 1
-
- @property
- def step(self) -> int:
- if self.is_intraepoch():
- return self.metadata["step"]
- return 0
diff --git a/frarch/modules/metrics/base.py b/frarch/modules/metrics/base.py
index 49349e4..4669eb8 100644
--- a/frarch/modules/metrics/base.py
+++ b/frarch/modules/metrics/base.py
@@ -1,20 +1,71 @@
+import abc
+from typing import Any
+
+import torch
+
AGGREGATION_MODES = ["mean", "max", "min"]
-class Metric:
- def __init__(self):
+class Metric(metaclass=abc.ABCMeta):
+ """abstract class for Metric objects.
+
+ Example:
+ Simple usage of the Metric class::
+ class MyMetric(Metric):
+ def _update(self, predictions, truth):
+ # compute some metric
+ return metric_value
+
+ model = MyModel()
+ mymetric = MyMetric()
+ for batch, labels in dataset:
+ predictions = model(batch)
+ mymetric.update(predictions, labels)
+ print(mymetric.get_metric(mode="mean"))
+
+ """
+
+ def __init__(self) -> None:
self.reset()
- def reset(self):
+ def reset(self) -> None:
+ """Clear metrics from class."""
self.metrics = []
- def update(self):
- raise NotImplementedError
+ def update(self, predictions: torch.Tensor, truth: torch.Tensor) -> None:
+ """Compute metric value and append to the metrics array.
- def __len__(self):
+ Args:
+ predictions (torch.Tensor): output tensors from model.
+ truth (torch.Tensor): ground truth tensor.
+ """
+ self.metrics.append(self._update(predictions, truth))
+
+ @abc.abstractmethod
+ def _update(self, predictions: torch.Tensor, truth: torch.Tensor) -> Any:
+ """Compute the metric value.
+
+ Args:
+ predictions (torch.Tensor): output tensors from model.
+ truth (torch.Tensor): ground truth tensor.
+ """
+
+ def __len__(self) -> int:
return len(self.metrics)
- def get_metric(self, mode="mean"):
+ def get_metric(self, mode="mean") -> float:
+ """Aggregate all values stored in the metric class.
+
+ Args:
+ mode (str, optional): aggregation type. mean, max or min.
+ Defaults to "mean".
+
+ Raises:
+ ValueError: aggregation mode not supported
+
+ Returns:
+ float: aggregated metric.
+ """
if len(self) == 0:
return 0.0
diff --git a/frarch/modules/metrics/classification_error.py b/frarch/modules/metrics/classification_error.py
index 8548524..94de47d 100644
--- a/frarch/modules/metrics/classification_error.py
+++ b/frarch/modules/metrics/classification_error.py
@@ -4,8 +4,20 @@
class ClassificationError(Metric):
- def update(self, predictions, truth):
+ """Classification error metric.
+
+ Example:
+ Sample code for use of the ClassificationError metric class:
+ model = MyModel()
+ error = ClassificationError()
+ for batch, labels in dataset:
+ predictions = model(batch)
+ error.update(predictions, labels)
+ print(error.get_metric(mode="mean"))
+ """
+
+ def _update(self, predictions: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
if predictions.shape[0] != truth.shape[0]:
raise ValueError(f"mismatched shapes {predictions.shape} != {truth.shape}")
predictions = torch.argmax(predictions, dim=-1)
- self.metrics.append((predictions != truth).float().mean().item())
+ return (predictions != truth).float().mean().item()
diff --git a/frarch/modules/metrics/wrapper.py b/frarch/modules/metrics/wrapper.py
index 7ce09e0..fdaed84 100644
--- a/frarch/modules/metrics/wrapper.py
+++ b/frarch/modules/metrics/wrapper.py
@@ -1,22 +1,54 @@
+from typing import Any, Dict
+
from .base import Metric
class MetricsWrapper:
- def __init__(self, **kwargs):
+ """Store a set of metrics and perform operations in all of them simultaneously.
+
+ Example:
+ Sample code for metrics wrapper:
+ metrics_wrapper = MetricsWrapper(
+ metric_str0=Metric0(),
+ metric_str1=Metric1(),
+ )
+ model = Model()
+
+ for batch, labels in dataset:
+ predictions = model(batch)
+ metrics_wrapper.update(predictions, labels)
+ print(metrics_wrapper.get_metrics())
+ # prints {"metric_str0": 0.0, "metric_str1": 1.0}
+ """
+
+ def __init__(self, **kwargs: Metric) -> None:
+ """Initialize metrics in wrapper.
+
+ Raises:
+ ValueError: if any of the values in kwargs don't inherit from metric.
+ """
for k, v in kwargs.items():
if not isinstance(v, Metric):
raise ValueError(f"value for key {k} should inherit from Metric")
setattr(self, k, v)
- def reset(self):
+ def reset(self) -> None:
+ """Call reset in all metrics in the MetricsWrapper class."""
for _, v in self.__dict__.items():
v.reset()
- def update(self, *args, **kwargs):
+ def update(self, *args: Any, **kwargs: Any) -> None:
+ """Call update on all metrics in MetricsWrapper class."""
for _, v in self.__dict__.items():
v.update(*args, **kwargs)
- def get_metrics(self, *args, **kwargs):
+ def get_metrics(self, *args: Any, **kwargs: Any) -> Dict[str, Metric]:
+ """Build a dict with aggregated metrics.
+
+ Returns:
+ Dict[str, Metric]: dict with metric names as keys and aggregated metrics\
+ as values.
+ """
metrics = {}
for k, v in self.__dict__.items():
metrics[k] = v.get_metric(*args, **kwargs)
diff --git a/frarch/parser.py b/frarch/parser.py
index 2ea6a72..d1bc176 100644
--- a/frarch/parser.py
+++ b/frarch/parser.py
@@ -1,16 +1,14 @@
import argparse
from pathlib import Path
+from typing import Any, Dict, Tuple
-def parse_arguments():
- """Parse command-line arguments to the experiment.
+def parse_arguments() -> Tuple[str, Dict[str, Any]]:
+ """Parse arguments from command line.
- Returns
- -------
- param_file : str
- The location of the parameters file.
- parameters : argparse.NameSpace
- options
+ Returns:
+ params_file (str): hyperparams file path.
+ args (Dict[str, Any]]): arguments from argparse.Namespace.
"""
parser = argparse.ArgumentParser(
description="Run an experiment",
diff --git a/frarch/train/base_trainer.py b/frarch/train/base_trainer.py
index 2e9a7ce..f1cd28c 100644
--- a/frarch/train/base_trainer.py
+++ b/frarch/train/base_trainer.py
@@ -14,14 +14,17 @@ class definition of a base trainer.
import logging
import sys
+from typing import Any, Mapping, Optional, Type, Union
import torch
+from torch.utils.data import DataLoader, Dataset
+from frarch.modules import Checkpointer
from frarch.utils.stages import Stage
logger = logging.getLogger(__name__)
PYTHON_VERSION_MAJOR = 3
-PYTHON_VERSION_MINOR = 6
+PYTHON_VERSION_MINOR = 7
default_values = {
"debug": False,
@@ -35,7 +38,28 @@ class definition of a base trainer.
class BaseTrainer:
- def __init__(self, modules, opt_class, hparams, checkpointer=None):
+ """Abstract class for trainer managers.
+
+ Args:
+ modules (Mapping[str, torch.nn.Module]): trainable modules in the training.
+ opt_class (Type[torch.optim.Optimizer]): optimizer class for training.
+ hparams (Mapping[str, Any]): hparams dict-like structure from hparams file.
+ checkpointer (Optional[Checkpointer], optional): Checkpointer class for saving
+ the model and the hyperparameters needed. If None, no checkpoints are saved.
+ Defaults to None.
+
+ Raises:
+ ValueError: ckpt_interval_minutes must be > 0 or None
+ SystemError: Python version not supported. Python version must be >= 3.7
+ """
+
+ def __init__(
+ self,
+ modules: Mapping[str, torch.nn.Module],
+ opt_class: Type[torch.optim.Optimizer],
+ hparams: Mapping[str, Any],
+ checkpointer: Optional[Checkpointer] = None,
+ ) -> None:
self.hparams = hparams
self.opt_class = opt_class
self.checkpointer = checkpointer
@@ -73,12 +97,34 @@ def __init__(self, modules, opt_class, hparams, checkpointer=None):
self.avg_train_loss = 0.0
self.step = 0
- def __call__(self, *args, **kwargs):
+ def __call__(self, *args, **kwargs) -> None:
+ """Alias for fit."""
return self.fit(*args, **kwargs)
- def on_fit_start(self):
+ def fit(
+ self,
+ train_set: Union[Dataset, DataLoader],
+ valid_set: Optional[Union[Dataset, DataLoader]] = None,
+ train_loader_kwargs: dict = None,
+ valid_loader_kwargs: dict = None,
+ ) -> None:
+ """Fit the modules to the dataset. Main function of the Trainer class.
+
+ Args:
+ train_set (Union[Dataset, DataLoader]): dataset for training.
+ valid_set (Optional[Union[Dataset, DataLoader]], optional): dataset for
+ validation. If not provided, validation will not be performed.
+ Defaults to None.
+ train_loader_kwargs (dict, optional): optional kwargs for train dataloader.
+ Defaults to None.
+ valid_loader_kwargs (dict, optional): optional kwargs for valid dataloader.
+ Defaults to None.
+ """
+ raise NotImplementedError
+
+ def _on_fit_start(self) -> None:
# Initialize optimizers
- self.init_optimizers()
+ self._init_optimizers()
# set first epoch index
self.start_epoch = 0
@@ -96,61 +142,58 @@ def on_fit_start(self):
)
if self.start_epoch == 0:
- self.save_initial_weights()
+ self._save_initial_weights()
- def save_initial_weights(self):
+ def _save_initial_weights(self) -> None:
if self.checkpointer is not None:
self.checkpointer.save_initial_weights()
- def init_optimizers(self):
+ def _init_optimizers(self) -> None:
if self.opt_class is not None:
self.optimizer = self.opt_class(self.modules.parameters())
- def on_fit_end(self, epoch=None):
+ def _on_fit_end(self, epoch: Optional[int] = None) -> None:
pass
- def on_stage_start(self, stage, epoch=None):
+ def _on_stage_start(self, stage: Stage, epoch: Optional[int] = None) -> None:
pass
- def on_stage_end(self, stage, loss=None, epoch=None):
+ def _on_stage_end(
+ self, stage: Stage, loss=None, epoch: Optional[int] = None
+ ) -> None:
pass
- def on_train_interval(self, epoch=None):
+ def _on_train_interval(self, epoch: Optional[int] = None) -> None:
pass
- def save_intra_epoch_ckpt(self):
+ def _save_intra_epoch_ckpt(self) -> None:
raise NotImplementedError
- def forward(self, batch, stage):
+ def _forward(self, batch: torch.Tensor, stage: Stage) -> torch.Tensor:
raise NotImplementedError
- def evaluate_batch(self, batch, stage):
- out = self.forward(batch, stage=stage)
- loss = self.compute_loss(out, batch, stage=stage)
+ def _evaluate_batch(self, batch: torch.Tensor, stage: Stage) -> torch.Tensor:
+ out = self._forward(batch, stage=stage)
+ loss = self._compute_loss(out, batch, stage=stage)
return loss.detach().cpu()
- def compute_loss(self, predictions, batch, stage):
+ def _compute_loss(
+ self, predictions: torch.Tensor, batch: torch.Tensor, stage: Stage
+ ) -> torch.Tensor:
raise NotImplementedError
- def fit_batch(self, batch):
+ def _fit_batch(self, batch: torch.Tensor) -> torch.Tensor:
self.optimizer.zero_grad()
- outputs = self.forward(batch, Stage.TRAIN)
- loss = self.compute_loss(outputs, batch, Stage.TRAIN)
+ outputs = self._forward(batch, Stage.TRAIN)
+ loss = self._compute_loss(outputs, batch, Stage.TRAIN)
loss.backward()
self.optimizer.step()
return loss.detach().cpu()
- def update_average(self, loss, avg_loss):
+ def _update_average(
+ self, loss: torch.Tensor, avg_loss: torch.Tensor
+ ) -> torch.Tensor:
if torch.isfinite(loss):
avg_loss -= avg_loss / self.step
avg_loss += float(loss) / self.step
return avg_loss
-
- def fit(
- self,
- train_set,
- valid_set=None,
- train_loader_kwargs: dict = None,
- valid_loader_kwargs: dict = None,
- ):
- raise NotImplementedError
diff --git a/frarch/train/classifier_trainer.py b/frarch/train/classifier_trainer.py
index 1067b3a..b5ad0a7 100644
--- a/frarch/train/classifier_trainer.py
+++ b/frarch/train/classifier_trainer.py
@@ -14,8 +14,10 @@
import logging
import time
+from typing import Iterable, Optional, Union
import torch
+from torch.utils.data import DataLoader, Dataset
from frarch.utils.data import create_dataloader
from frarch.utils.stages import Stage
@@ -26,13 +28,40 @@
class ClassifierTrainer(BaseTrainer):
+ """Trainer class for classifiers.
+
+ Args:
+ modules (Mapping[str, torch.nn.Module]): trainable modules in the training.
+ opt_class (Type[torch.optim.Optimizer]): optimizer class for training.
+ hparams (Mapping[str, Any]): hparams dict-like structure from hparams file.
+ checkpointer (Optional[Checkpointer], optional): Checkpointer class for saving
+ the model and the hyperparameters needed. If None, no checkpoints are saved.
+ Defaults to None.
+
+ Raises:
+ ValueError: ckpt_interval_minutes must be > 0 or None
+ SystemError: Python version not supported. Python version must be >= 3.7
+ """
+
def fit(
self,
- train_set,
- valid_set=None,
+ train_set: Union[Dataset, DataLoader],
+ valid_set: Optional[Union[Dataset, DataLoader]] = None,
train_loader_kwargs: dict = None,
valid_loader_kwargs: dict = None,
- ):
+ ) -> None:
+ """Fit the modules to the dataset. Main function of the Trainer class.
+
+ Args:
+ train_set (Union[Dataset, DataLoader]): dataset for training.
+ valid_set (Optional[Union[Dataset, DataLoader]], optional): dataset for
+ validation. If not provided, validation will not be performed.
+ Defaults to None.
+ train_loader_kwargs (dict, optional): optional kwargs for train dataloader.
+ Defaults to None.
+ valid_loader_kwargs (dict, optional): optional kwargs for valid dataloader.
+ Defaults to None.
+ """
if train_loader_kwargs is None:
train_loader_kwargs = {}
if valid_loader_kwargs is None:
@@ -45,16 +74,16 @@ def fit(
):
valid_set = create_dataloader(valid_set, **valid_loader_kwargs)
- self.on_fit_start()
+ self._on_fit_start()
for self.current_epoch in range(self.start_epoch, self.hparams["epochs"]):
- self.on_stage_start(Stage.TRAIN, self.current_epoch)
+ self._on_stage_start(Stage.TRAIN, self.current_epoch)
self.modules.train()
last_ckpt_time = time.time()
- t = self.get_iterable(
+ t = self._get_iterable(
train_set,
desc=f"Epoch {self.current_epoch} train",
initial=self.step,
@@ -62,14 +91,14 @@ def fit(
)
for batch in t:
self.step += 1
- loss = self.fit_batch(batch)
- self.avg_train_loss = self.update_average(loss, self.avg_train_loss)
- self.update_progress(
+ loss = self._fit_batch(batch)
+ self.avg_train_loss = self._update_average(loss, self.avg_train_loss)
+ self._update_progress(
t, self.step, stage="train", train_loss=self.avg_train_loss
)
if not (self.step % self.train_interval):
- self.on_train_interval(self.current_epoch)
+ self._on_train_interval(self.current_epoch)
if self.debug and self.step >= self.debug_batches:
break
@@ -80,45 +109,47 @@ def fit(
and time.time() - last_ckpt_time
>= self.ckpt_interval_minutes * 60.0
):
- self.save_intra_epoch_ckpt()
+ self._save_intra_epoch_ckpt()
last_ckpt_time = time.time()
if not self.noprogressbar:
t.close()
# Run train "on_stage_end" on all processes
- self.on_stage_end(Stage.TRAIN, self.avg_train_loss, self.current_epoch)
+ self._on_stage_end(Stage.TRAIN, self.avg_train_loss, self.current_epoch)
# Validation stage
if valid_set is not None:
- self.on_stage_start(Stage.VALID, self.current_epoch)
+ self._on_stage_start(Stage.VALID, self.current_epoch)
self.modules.eval()
avg_valid_loss = 0.0
valid_step = 0
with torch.no_grad():
- t = self.get_iterable(
+ t = self._get_iterable(
valid_set,
desc=f"Epoch {self.current_epoch} valid",
dynamic_ncols=True,
)
for batch in t:
valid_step += 1
- loss = self.evaluate_batch(batch, stage=Stage.VALID)
- avg_valid_loss = self.update_average(loss, avg_valid_loss)
- self.update_progress(
+ loss = self._evaluate_batch(batch, stage=Stage.VALID)
+ avg_valid_loss = self._update_average(loss, avg_valid_loss)
+ self._update_progress(
t, valid_step, stage="valid", valid_loss=avg_valid_loss
)
+ if self.debug and self.step >= self.debug_batches:
+ break
if not self.noprogressbar:
t.close()
# Only run validation "on_stage_end" on main process
- self.on_stage_end(Stage.VALID, avg_valid_loss, self.current_epoch)
+ self._on_stage_end(Stage.VALID, avg_valid_loss, self.current_epoch)
self.step = 0
self.avg_train_loss = 0.0
- self.on_fit_end()
+ self._on_fit_end()
- def get_iterable(self, dataset: torch.utils.data.DataLoader, **kwargs):
+ def _get_iterable(self, dataset: torch.utils.data.DataLoader, **kwargs) -> Iterable:
if not self.noprogressbar:
from tqdm import tqdm
@@ -128,14 +159,16 @@ def get_iterable(self, dataset: torch.utils.data.DataLoader, **kwargs):
logger.warning(f"Running {self.__class__.__name__} without tqdm")
return dataset
- def update_progress(self, iterable, step, stage=None, **kwargs):
+ def _update_progress(
+ self, iterable: Iterable, step: int, stage: Stage, **kwargs
+ ) -> None:
if self.noprogressbar:
if not step % self.train_interval or (step >= len(iterable)) or (step == 1):
- self.update_progress_console(iterable, step, stage=stage, **kwargs)
+ self._update_progress_console(iterable, step, stage=stage, **kwargs)
else:
- self.update_progress_tqdm(iterable, **kwargs)
+ self._update_progress_tqdm(iterable, **kwargs)
- def update_progress_tqdm(self, iterable, **kwargs):
+ def _update_progress_tqdm(self, iterable: Iterable, **kwargs) -> None:
if "metrics" in self.hparams:
iterable.set_postfix(
**kwargs,
@@ -144,7 +177,9 @@ def update_progress_tqdm(self, iterable, **kwargs):
else:
iterable.set_postfix(**kwargs)
- def update_progress_console(self, iterable, step, stage=None, **kwargs):
+ def _update_progress_console(
+ self, iterable: Iterable, step: int, stage: Stage, **kwargs
+ ) -> None:
kwargs_string = ", ".join([f"{k}={v:.4f}" for k, v in kwargs.items()])
if "metrics" in self.hparams:
metrics = self.hparams["metrics"].get_metrics(mode="mean")
@@ -152,7 +187,7 @@ def update_progress_console(self, iterable, step, stage=None, **kwargs):
else:
metrics_string = ""
- print(
+ logger.info(
f"Epoch {self.current_epoch} {stage}: step {step}/{len(iterable)}"
f" -> {kwargs_string}, {metrics_string}"
)
diff --git a/frarch/utils/data/__init__.py b/frarch/utils/data/__init__.py
index c4aef32..63317d0 100644
--- a/frarch/utils/data/__init__.py
+++ b/frarch/utils/data/__init__.py
@@ -1,10 +1,11 @@
import logging
from pathlib import Path
-from typing import Union
+from typing import Any, Mapping, Optional, Union
from urllib.request import urlretrieve
import torch
from hyperpyyaml import resolve_references
+from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from frarch.utils.logging import create_logger_file
@@ -12,18 +13,39 @@
logger = logging.getLogger(__name__)
-def create_dataloader(dataset: torch.utils.data.Dataset, **dataloader_kwargs):
- if not isinstance(dataset, torch.utils.data.Dataset):
+def create_dataloader(dataset: Dataset, **dataloader_kwargs) -> DataLoader:
+ """Create dataloader from dataset.
+
+ Args:
+ dataset (torch.utils.data.Dataset): dataset object to feed onto DataLoader.
+
+ Raises:
+ ValueError: dataset is not a Dataset or does not inherit from it.
+
+ Returns:
+ torch.utils.data.DataLoader: dataloader with the dataset object.
+ """
+ if not isinstance(dataset, Dataset):
raise ValueError("dataset needs to be a child or torch.utils.data.Dataset")
- return torch.utils.data.DataLoader(dataset, **dataloader_kwargs)
+ return DataLoader(dataset, **dataloader_kwargs)
def build_experiment_structure(
- hparams_file,
- overrides=None,
- experiment_folder: str = "results/debug/",
+ hparams_file: Union[str, Path],
+ experiment_folder: Union[str, Path],
+ overrides: Mapping = None,
debug: bool = False,
):
+ """Construct experiment folder hierarchy on experiment_folder.
+
+ Args:
+ hparams_file (Union[str, Path]): hparams configuration file path.
+ experiment_folder (Union[str, Path]): Folder where to store experiment files.
+ Defaults to "results/debug/".
+ overrides (Mapping, optional): Parameters to override on hparams file.
+ Defaults to None.
+ debug (bool, optional): debug flag. Defaults to False.
+ """
if overrides is None:
overrides = {}
@@ -44,28 +66,20 @@ def build_experiment_structure(
logger.info(f"experiment folder {str(base_path)} created successfully")
-def download_url(url, destination=None, progress_bar=True):
+def download_url(
+ url: str, destination: Optional[str] = None, progress_bar: bool = True
+) -> str:
"""Download a URL to a local file.
- Parameters
- ----------
- url : str
- The URL to download.
- destination : str, None
- The destination of the file. If None is given the file is saved to
- a temporary directory.
- progress_bar : bool
- Whether to show a command-line progress bar while downloading.
-
- Returns
- -------
- filename : str
- The location of the downloaded file.
-
- Notes
- -----
- Progress bar use/example adapted from tqdm documentation:
- https://github.com/tqdm/tqdm
+ Args:
+ url (str): The URL to download.
+ destination (Optional[str], optional): The destination of the file. If None is
+ given the file is saved to a temporary directory. Defaults to None.
+ progress_bar (bool, optional): The destination of the file. If None is given
+ the file is saved to a temporary directory. Defaults to True.
+
+ Returns:
+ str: filename of downloaded file.
"""
def update_progressbar(t):
@@ -87,11 +101,30 @@ def inner(b=1, bsize=1, tsize=None):
)
else:
filename, _ = urlretrieve(url, filename=destination)
+ return filename
-def tensorInDevice(data, device="cpu", **kwargs):
+def tensor_in_device(data: Any, device: str = "cpu", **kwargs) -> torch.Tensor:
+ """Create tensor in device.
+
+ Args:
+ data (Any): data on the tensor.
+ device (str, optional): string of the device to be created in.
+ Defaults to "cpu".
+
+ Returns:
+ torch.Tensor: tensor in device.
+ """
return torch.Tensor(data, **kwargs).to(device)
-def read_file(filepath: Union[str, Path]):
+def read_file(filepath: Union[str, Path]) -> str:
+ """Read contents of file.
+
+ Args:
+ filepath (Union[str, Path]): path to file.
+
+ Returns:
+ str: contents of the file
+ """
return Path(filepath).read_text()
diff --git a/frarch/utils/exceptions.py b/frarch/utils/exceptions.py
index 6b1871a..df3d52e 100644
--- a/frarch/utils/exceptions.py
+++ b/frarch/utils/exceptions.py
@@ -5,12 +5,13 @@
class DatasetNotFoundError(Exception):
"""Exception raised for OS dataset errors.
- Args
- ----
- path([Path, str]): [path where the dataset should be]
+ Args:
+ path ([Path, str]): path where the dataset should be
"""
- def __init__(self, path: Union[str, Path], msg="Dataset not found in path {path}"):
+ def __init__(
+ self, path: Union[str, Path], msg: str = "Dataset not found in path {path}"
+ ) -> None:
self.path = path
self.msg = msg
super().__init__(self.msg.format(path=path))
diff --git a/frarch/utils/logging/create_logger.py b/frarch/utils/logging/create_logger.py
index 254827b..e7e5ca1 100644
--- a/frarch/utils/logging/create_logger.py
+++ b/frarch/utils/logging/create_logger.py
@@ -4,7 +4,22 @@
from typing import Union
-def create_logger_file(log_file_path: Union[str, Path], debug=False, stdout=False):
+def create_logger_file(
+ log_file_path: Union[str, Path], debug: bool = False, stdout: bool = False
+) -> None:
+ """Create logger file in file path.
+
+ Args:
+ log_file_path (Union[str, Path]): path to logger file.
+ debug (bool, optional): log debug statements. Defaults to False.
+ stdout (bool, optional): log to stdout as well as to the log file.
+ Defaults to False.
+
+ Raises:
+ ValueError: log_file_path is not a str or pathlib.Path object.
+ ValueError: debug is not boolean.
+ ValueError: stdout is not boolean.
+ """
if not isinstance(log_file_path, (str, Path)):
raise ValueError("path must be a string or Path object")
if not isinstance(debug, bool):
diff --git a/renovate.json b/renovate.json
new file mode 100644
index 0000000..f45d8f1
--- /dev/null
+++ b/renovate.json
@@ -0,0 +1,5 @@
+{
+ "extends": [
+ "config:base"
+ ]
+}
diff --git a/tests/functional/runFunctionalTests.sh b/tests/functional/runFunctionalTests.sh
new file mode 100755
index 0000000..de8dcaa
--- /dev/null
+++ b/tests/functional/runFunctionalTests.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+
+set -e
+
+test_training () {
+ experimentFolder=$1
+ cd ${experimentFolder}
+ resultsFolder=$(cat train.yaml | grep -Po '^experiment_folder: "\K.*(?=")')
+ python train.py train.yaml --device cpu --debug --debug_batches 10
+ [ $(ls ${resultsFolder}/save/*/*.pt | wc -l) -gt 0 ]
+ [ $(ls ${resultsFolder}/save/*/*.json | wc -l) -gt 0 ]
+ [ $(ls ${resultsFolder}/train.yaml | wc -l) -gt 0 ]
+ cd -
+}
+
+[ $(ls -d */results/ | wc -l) -gt 0 ] && rm -rf */results/
+test_training train_mnist
+
+set +e
diff --git a/tests/functional/train_mnist/train.py b/tests/functional/train_mnist/train.py
new file mode 100644
index 0000000..131addc
--- /dev/null
+++ b/tests/functional/train_mnist/train.py
@@ -0,0 +1,52 @@
+from hyperpyyaml import load_hyperpyyaml
+
+import frarch as fr
+from frarch.utils.data import build_experiment_structure
+from frarch.utils.stages import Stage
+
+
+class MNISTTrainer(fr.train.ClassifierTrainer):
+ def _forward(self, batch, stage):
+ inputs, _ = batch
+ inputs = inputs.to(self.device)
+ return self.modules.model(inputs)
+
+ def _compute_loss(self, predictions, batch, stage):
+ _, labels = batch
+ labels = labels.to(self.device)
+ return self.hparams["loss"](predictions, labels)
+
+ def _on_stage_end(self, stage, loss=None, epoch=None):
+ if stage == Stage.VALID:
+ if self.checkpointer is not None:
+ self.checkpointer.save(epoch=self.current_epoch, current_step=self.step)
+
+
+if __name__ == "__main__":
+ hparam_file, args = fr.parse_arguments()
+
+ with open(hparam_file, "r") as hparam_file_handler:
+ hparams = load_hyperpyyaml(
+ hparam_file_handler, args, overrides_must_match=False
+ )
+
+ build_experiment_structure(
+ hparam_file,
+ overrides=args,
+ experiment_folder=hparams["experiment_folder"],
+ debug=hparams["debug"],
+ )
+
+ trainer = MNISTTrainer(
+ modules=hparams["modules"],
+ opt_class=hparams["opt_class"],
+ hparams=hparams,
+ checkpointer=hparams["checkpointer"],
+ )
+
+ trainer.fit(
+ train_set=hparams["train_dataset"],
+ valid_set=hparams["valid_dataset"],
+ train_loader_kwargs=hparams["dataloader_options"],
+ valid_loader_kwargs=hparams["dataloader_options"],
+ )
diff --git a/tests/functional/train_mnist/train.yaml b/tests/functional/train_mnist/train.yaml
new file mode 100644
index 0000000..7f6816f
--- /dev/null
+++ b/tests/functional/train_mnist/train.yaml
@@ -0,0 +1,67 @@
+# seeds
+seed: 42
+__set_seed: !apply:torch.manual_seed [!ref ]
+experiment_name: "mnist"
+experiment_folder: "results/mnist_functional/"
+device: "cpu"
+
+# data folder
+data_folder: /tmp/
+
+# training parameters
+epochs: 2
+batch_size: 16
+shuffle: True
+num_clases: 10
+
+transform_resize: !new:torchvision.transforms.Resize
+ size: 32
+transform_grayscale: !new:torchvision.transforms.Grayscale
+ num_output_channels: 3
+transform_tensor: !new:torchvision.transforms.ToTensor
+transform_normalize: !new:torchvision.transforms.Normalize
+ mean: 0.1307
+ std: 0.3081
+
+preprocessing: !new:torchvision.transforms.Compose
+ transforms: [
+ !ref ,
+ !ref ,
+ !ref ,
+ !ref
+ ]
+
+# dataset object
+train_dataset: !new:torchvision.datasets.MNIST
+ root: !ref
+ train: true
+ download: true
+ transform: !ref
+
+valid_dataset: !new:torchvision.datasets.MNIST
+ root: !ref
+ train: false
+ download: true
+ transform: !ref
+
+# dataloader options
+dataloader_options:
+ batch_size: !ref
+ shuffle: !ref
+ num_workers: 8
+
+opt_class: !name:torch.optim.Adam
+ lr: 0.001
+
+loss: !new:torch.nn.CrossEntropyLoss
+
+model: !apply:torchvision.models.vgg11
+ pretrained: false
+
+modules:
+ model: !ref
+
+checkpointer: !new:frarch.modules.Checkpointer
+ save_path: !ref
+ save_best_only: false
+ modules: !ref
diff --git a/tests/__init__.py b/tests/unit/__init__.py
similarity index 100%
rename from tests/__init__.py
rename to tests/unit/__init__.py
diff --git a/tests/test_checkpointer.py b/tests/unit/test_checkpointer.py
similarity index 61%
rename from tests/test_checkpointer.py
rename to tests/unit/test_checkpointer.py
index 0e11a19..8496be3 100644
--- a/tests/test_checkpointer.py
+++ b/tests/unit/test_checkpointer.py
@@ -1,12 +1,14 @@
-import torch
import copy
-import unittest
-import shutil
import json
+import shutil
+import unittest
from pathlib import Path
+
+import torch
+
from frarch.modules.checkpointer import Checkpointer
-DATA_FOLDER = Path("./tests/data/")
+DATA_FOLDER = Path(__file__).resolve().parent.parent / "data"
class MockModel(torch.nn.Module):
@@ -29,9 +31,7 @@ def tearDown(self):
def test_init_ok(self):
ckpter = Checkpointer(
- save_path=self.TMP_CKPT_PATH,
- modules=self.modules,
- save_best_only=False
+ save_path=self.TMP_CKPT_PATH, modules=self.modules, save_best_only=False
)
self.assertEqual(self.modules, ckpter.modules)
self.assertEqual(self.TMP_CKPT_PATH / "save", ckpter.base_path)
@@ -43,73 +43,58 @@ def test_init_ok(self):
def test_init_path_not_string(self):
with self.assertRaises(ValueError):
- Checkpointer(
- save_path=0.0,
- modules=self.modules
- )
+ Checkpointer(save_path=0.0, modules=self.modules)
def test_init_modules_not_ok(self):
with self.assertRaises(ValueError):
- Checkpointer(
- save_path=self.TMP_CKPT_PATH,
- modules=[MockModel()]
- )
+ Checkpointer(save_path=self.TMP_CKPT_PATH, modules=[MockModel()])
def test_metadata_in_modules(self):
nok_modules = copy.deepcopy(self.modules)
nok_modules["metadata"] = MockModel()
with self.assertRaises(ValueError):
- Checkpointer(
- save_path=self.TMP_CKPT_PATH,
- modules=nok_modules
- )
+ Checkpointer(save_path=self.TMP_CKPT_PATH, modules=nok_modules)
def test_key_not_string_in_modules(self):
nok_modules = dict(copy.deepcopy(self.modules))
nok_modules[0] = MockModel()
with self.assertRaises(ValueError):
- Checkpointer(
- save_path=self.TMP_CKPT_PATH,
- modules=nok_modules
- )
+ Checkpointer(save_path=self.TMP_CKPT_PATH, modules=nok_modules)
def test_value_not_module_in_modules(self):
nok_modules = dict(copy.deepcopy(self.modules))
nok_modules["module3"] = "not-a-module"
with self.assertRaises(ValueError):
- Checkpointer(
- save_path=self.TMP_CKPT_PATH,
- modules=nok_modules
- )
+ Checkpointer(save_path=self.TMP_CKPT_PATH, modules=nok_modules)
def test_mode_not_valid(self):
with self.assertRaises(ValueError):
Checkpointer(
- save_path=self.TMP_CKPT_PATH,
- modules=self.modules,
- mode="not-valid"
+ save_path=self.TMP_CKPT_PATH, modules=self.modules, mode="not-valid"
)
def test_save_best_only_no_reference_metric(self):
with self.assertRaises(ValueError):
Checkpointer(
- save_path=self.TMP_CKPT_PATH,
- modules=self.modules,
- save_best_only=True
+ save_path=self.TMP_CKPT_PATH, modules=self.modules, save_best_only=True
)
def test_save_initial_weights(self):
ckpter = Checkpointer(
- save_path=self.TMP_CKPT_PATH,
- modules=self.modules,
- save_best_only=False
+ save_path=self.TMP_CKPT_PATH, modules=self.modules, save_best_only=False
)
ckpter.save_initial_weights()
- self.assertTrue((self.TMP_CKPT_PATH/"save"/"initial_weights").exists())
- metadata = read_json(self.TMP_CKPT_PATH/"save"/"initial_weights"/"metadata.json")
+ self.assertTrue((self.TMP_CKPT_PATH / "save" / "initial_weights").exists())
+ metadata = read_json(
+ self.TMP_CKPT_PATH / "save" / "initial_weights" / "metadata.json"
+ )
self.assertEqual(metadata["epoch"], -1)
- self.assertTrue((self.TMP_CKPT_PATH/"save"/"initial_weights"/"model.pt").exists())
- self.assertTrue((self.TMP_CKPT_PATH/"save"/"initial_weights"/"model2.pt").exists())
+ self.assertTrue(
+ (self.TMP_CKPT_PATH / "save" / "initial_weights" / "model.pt").exists()
+ )
+ self.assertTrue(
+ (self.TMP_CKPT_PATH / "save" / "initial_weights" / "model2.pt").exists()
+ )
def test_save_end_of_epoch(self):
Checkpointer(
@@ -117,15 +102,12 @@ def test_save_end_of_epoch(self):
modules=self.modules,
save_best_only=False,
reference_metric=None,
- mode="min"
+ mode="min",
).save(
- epoch=1,
- current_step=1000,
- intra_epoch=False,
- extra_data={"test": "test"}
+ epoch=1, current_step=1000, intra_epoch=False, extra_data={"test": "test"}
)
- pt_paths = list(self.TMP_CKPT_PATH.glob('**/*.pt'))
- metadata_paths = list(self.TMP_CKPT_PATH.glob('**/metadata.json'))
+ pt_paths = list(self.TMP_CKPT_PATH.glob("**/*.pt"))
+ metadata_paths = list(self.TMP_CKPT_PATH.glob("**/metadata.json"))
self.assertEqual(len(pt_paths), len(self.modules))
metadata = read_json(metadata_paths[0])
self.assertFalse(metadata["intra_epoch"])
@@ -139,14 +121,14 @@ def test_save_intra_epoch(self):
modules=self.modules,
save_best_only=False,
reference_metric=None,
- mode="min"
+ mode="min",
).save(
epoch=1,
current_step=1000,
intra_epoch=True,
)
- pt_paths = list(self.TMP_CKPT_PATH.glob('**/*.pt'))
- metadata_paths = list(self.TMP_CKPT_PATH.glob('**/metadata.json'))
+ pt_paths = list(self.TMP_CKPT_PATH.glob("**/*.pt"))
+ metadata_paths = list(self.TMP_CKPT_PATH.glob("**/metadata.json"))
self.assertEqual(len(pt_paths), len(self.modules))
metadata = read_json(metadata_paths[0])
self.assertTrue(metadata["intra_epoch"])
@@ -159,14 +141,11 @@ def test_save_extradata(self):
modules=self.modules,
save_best_only=False,
reference_metric=None,
- mode="min"
+ mode="min",
).save(
- epoch=1,
- current_step=1000,
- intra_epoch=True,
- extra_data={"test": "test"}
+ epoch=1, current_step=1000, intra_epoch=True, extra_data={"test": "test"}
)
- metadata_paths = list(self.TMP_CKPT_PATH.glob('**/metadata.json'))
+ metadata_paths = list(self.TMP_CKPT_PATH.glob("**/metadata.json"))
metadata = read_json(metadata_paths[0])
self.assertDictEqual(metadata["extra_info"], {"test": "test"})
@@ -176,15 +155,10 @@ def test_save_metric(self):
modules=self.modules,
save_best_only=True,
reference_metric="metric",
- mode="min"
- )
- ckpter.save(
- epoch=1,
- current_step=1000,
- intra_epoch=False,
- metric=.5
+ mode="min",
)
- self.assertEqual(ckpter.best_metric, .5)
+ ckpter.save(epoch=1, current_step=1000, intra_epoch=False, metric=0.5)
+ self.assertEqual(ckpter.best_metric, 0.5)
def test_save_metric_update_min(self):
ckpter = Checkpointer(
@@ -192,21 +166,11 @@ def test_save_metric_update_min(self):
modules=self.modules,
save_best_only=True,
reference_metric="metric",
- mode="min"
- )
- ckpter.save(
- epoch=1,
- current_step=1000,
- intra_epoch=False,
- metric=.5
- )
- ckpter.save(
- epoch=1,
- current_step=1000,
- intra_epoch=False,
- metric=.1
+ mode="min",
)
- self.assertEqual(ckpter.best_metric, .1)
+ ckpter.save(epoch=1, current_step=1000, intra_epoch=False, metric=0.5)
+ ckpter.save(epoch=1, current_step=1000, intra_epoch=False, metric=0.1)
+ self.assertEqual(ckpter.best_metric, 0.1)
def test_save_metric_update_max(self):
ckpter = Checkpointer(
@@ -214,21 +178,11 @@ def test_save_metric_update_max(self):
modules=self.modules,
save_best_only=True,
reference_metric="metric",
- mode="max"
- )
- ckpter.save(
- epoch=1,
- current_step=1000,
- intra_epoch=False,
- metric=.5
- )
- ckpter.save(
- epoch=1,
- current_step=1000,
- intra_epoch=False,
- metric=.1
+ mode="max",
)
- self.assertEqual(ckpter.best_metric, .5)
+ ckpter.save(epoch=1, current_step=1000, intra_epoch=False, metric=0.5)
+ ckpter.save(epoch=1, current_step=1000, intra_epoch=False, metric=0.1)
+ self.assertEqual(ckpter.best_metric, 0.5)
def test_load_checkpoint(self):
modules = copy.deepcopy(self.modules)
@@ -237,14 +191,9 @@ def test_load_checkpoint(self):
modules=modules,
save_best_only=True,
reference_metric="metric",
- mode="max"
- )
- ckpter.save(
- epoch=1,
- current_step=1000,
- intra_epoch=False,
- metric=.5
+ mode="max",
)
+ ckpter.save(epoch=1, current_step=1000, intra_epoch=False, metric=0.5)
modules.model.fc = torch.nn.Linear(2, 1)
ckpter.load(mode="last")
self.assertTrue((self.modules.model.fc.weight == modules.model.fc.weight).all())
@@ -256,47 +205,33 @@ def test_load_checkpoint_map_location(self):
modules=modules,
save_best_only=True,
reference_metric="metric",
- mode="max"
- )
- ckpter.save(
- epoch=1,
- current_step=1000,
- intra_epoch=False,
- metric=.5
+ mode="max",
)
+ ckpter.save(epoch=1, current_step=1000, intra_epoch=False, metric=0.5)
modules.model.fc = torch.nn.Linear(2, 1)
- ckpter.load(mode="last", map_location='cpu')
+ ckpter.load(mode="last", map_location="cpu")
self.assertTrue((self.modules.model.fc.weight == modules.model.fc.weight).all())
def test_properies_end_of_epoch(self):
ckpter = Checkpointer(
- save_path=self.TMP_CKPT_PATH,
- modules=self.modules,
- save_best_only=False
- )
- ckpter.save(
- epoch=1,
- current_step=1000,
- intra_epoch=False,
- metric=.5
+ save_path=self.TMP_CKPT_PATH, modules=self.modules, save_best_only=False
)
- self.assertEqual(ckpter.is_intraepoch(), False)
+ ckpter.save(epoch=1, current_step=1000, intra_epoch=False, metric=0.5)
+ self.assertEqual(ckpter._is_intraepoch(), False)
self.assertEqual(ckpter.current_epoch, 1)
self.assertEqual(ckpter.next_epoch, 2)
self.assertEqual(ckpter.step, 0)
def test_properies_intra_epoch(self):
ckpter = Checkpointer(
- save_path=self.TMP_CKPT_PATH,
- modules=self.modules,
- save_best_only=False
+ save_path=self.TMP_CKPT_PATH, modules=self.modules, save_best_only=False
)
ckpter.save(
epoch=1,
current_step=1000,
intra_epoch=True,
)
- self.assertEqual(ckpter.is_intraepoch(), True)
+ self.assertEqual(ckpter._is_intraepoch(), True)
self.assertEqual(ckpter.current_epoch, 1)
self.assertEqual(ckpter.next_epoch, 1)
self.assertEqual(ckpter.step, 1000)
diff --git a/tests/unit/test_cnn_classifiers.py b/tests/unit/test_cnn_classifiers.py
new file mode 100644
index 0000000..dee2638
--- /dev/null
+++ b/tests/unit/test_cnn_classifiers.py
@@ -0,0 +1,115 @@
+import unittest
+from typing import Tuple
+
+import torch
+
+from frarch.models.classification.cnn import (
+ MNISTCNN,
+ FashionClassifier,
+ FashionCNN,
+ MitCNN,
+ MitCNNClassifier,
+ MNISTClassifier,
+ vgg11,
+ vgg11_bn,
+ vgg13,
+ vgg13_bn,
+ vgg16,
+ vgg16_bn,
+ vgg19,
+ vgg19_bn,
+ vggclassifier,
+)
+
+VGG_CONFIGS = {
+ "vgg11": vgg11,
+ "vgg11_bn": vgg11_bn,
+ "vgg13": vgg13,
+ "vgg13_bn": vgg13_bn,
+ "vgg16": vgg16,
+ "vgg16_bn": vgg16_bn,
+ "vgg19": vgg19,
+ "vgg19_bn": vgg19_bn,
+}
+
+
+def forward_model(model: torch.nn.Module, shape: Tuple):
+ return model(torch.rand(shape))
+
+
+class TestCNNClassifiers(unittest.TestCase):
+ def test_MNISTCNN_init(self):
+ MNISTCNN(input_channels=1, embedding_size=256)
+
+ def test_MNISTCNN_forward(self):
+ model = MNISTCNN(input_channels=1, embedding_size=256)
+ out = forward_model(model, (2, 1, 28, 28))
+ self.assertEqual(out.shape, (2, 256))
+
+ def test_MNISTClassifier_init(self):
+ MNISTClassifier(embedding_size=256, num_classes=10)
+
+ def test_MNISTClassifier_forward(self):
+ model = MNISTClassifier(embedding_size=256, num_classes=10)
+ out = forward_model(model, (2, 256))
+ self.assertEqual(out.shape, (2, 10))
+
+ def test_FashionCNN_init(self):
+ FashionCNN(out_size=256)
+
+ def test_FashionCNN_forward(self):
+ model = FashionCNN(out_size=256)
+ out = forward_model(model, (2, 1, 32, 32))
+ self.assertEqual(out.shape, (2, 256))
+
+ def test_FashionClassifier_init(self):
+ FashionClassifier(embedding_size=256, classes=10)
+
+ def test_FashionClassifier_forward(self):
+ model = FashionClassifier(embedding_size=256, classes=10)
+ out = forward_model(model, (2, 256))
+ self.assertEqual(out.shape, (2, 10))
+
+ def test_MitCNN_init(self):
+ MitCNN(input_channels=1, embedding_size=256)
+
+ def test_MitCNN_forward(self):
+ model = MitCNN(input_channels=1, embedding_size=256)
+ out = forward_model(model, (2, 1, 32, 32))
+ self.assertEqual(out.shape, (2, 256))
+
+ def test_MitCNNClassifier_init(self):
+ MitCNNClassifier(embedding_size=256, num_classes=10)
+
+ def test_MitCNNClassifier_forward(self):
+ model = MitCNNClassifier(embedding_size=256, num_classes=10)
+ out = forward_model(model, (2, 256))
+ self.assertEqual(out.shape, (2, 10))
+
+ def test_VGGClassifier_init(self):
+ vggclassifier(False, num_classes=100)
+
+ def test_VGGClassifier_forward(self):
+ model = vggclassifier(False, 100)
+ out = forward_model(model, (2, 25088))
+ self.assertEqual(out.shape, (2, 100))
+
+ def test_vggconfigurations_init(self):
+ for conf_name, conf_fn in VGG_CONFIGS.items():
+ try:
+ conf_fn()
+ except Exception as e:
+ raise Exception(f"Exception thrown for {conf_name}") from e
+
+ def test_vggconfigurations_forward(self):
+ for conf_name, conf_fn in VGG_CONFIGS.items():
+ model = conf_fn()
+ out = forward_model(model, (2, 3, 244, 244))
+
+ self.assertEqual(
+ out.shape, (2, 25088), msg=f"shape mismatch for {conf_name} arch"
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_dataset.py b/tests/unit/test_dataset.py
similarity index 99%
rename from tests/test_dataset.py
rename to tests/unit/test_dataset.py
index 6cf44d2..49e790b 100644
--- a/tests/test_dataset.py
+++ b/tests/unit/test_dataset.py
@@ -7,7 +7,7 @@
from frarch import datasets
from frarch.utils.exceptions import DatasetNotFoundError
-DATA_FOLDER = Path("./tests/data/")
+DATA_FOLDER = Path(__file__).resolve().parent.parent / "data"
class TestCaltech101(unittest.TestCase):
diff --git a/tests/test_image_transforms.py b/tests/unit/test_image_transforms.py
similarity index 99%
rename from tests/test_image_transforms.py
rename to tests/unit/test_image_transforms.py
index 2594d8b..e9cfe19 100644
--- a/tests/test_image_transforms.py
+++ b/tests/unit/test_image_transforms.py
@@ -7,7 +7,7 @@
import frarch.datasets.transforms as t
-DATA_FOLDER = Path("./tests/data/")
+DATA_FOLDER = Path(__file__).resolve().parent.parent / "data"
class TestImageTransforms(unittest.TestCase):
diff --git a/tests/test_metrics.py b/tests/unit/test_metrics.py
similarity index 69%
rename from tests/test_metrics.py
rename to tests/unit/test_metrics.py
index 9ca9665..27fd0f8 100644
--- a/tests/test_metrics.py
+++ b/tests/unit/test_metrics.py
@@ -1,50 +1,49 @@
-import torch
import unittest
+from typing import Any
+
+import torch
+
from frarch.modules import metrics
from frarch.modules.metrics.base import Metric
class MockMetric(Metric):
- def update(self, data):
- self.metrics.append(data)
+ def _update(self, data: Any, truth: Any) -> Any:
+ return data
class TestMetrics(unittest.TestCase):
def test_metricBase_constructor(self):
- m = Metric()
+ m = MockMetric()
self.assertEqual(len(m.metrics), 0)
def test_metricBase_length(self):
- m = Metric()
+ m = MockMetric()
self.assertEqual(len(m.metrics), len(m))
- def test_metricBase_virtual_update(self):
- with self.assertRaises(NotImplementedError):
- Metric().update()
-
def test_metricBase_get_metric_mean(self):
- m = Metric()
+ m = MockMetric()
m.metrics = [0, 1, 2]
self.assertEqual(m.get_metric(mode="mean"), 1)
def test_metricBase_get_metric_max(self):
- m = Metric()
+ m = MockMetric()
m.metrics = [0, 1, 2]
self.assertEqual(m.get_metric(mode="max"), 2)
def test_metricBase_get_metric_min(self):
- m = Metric()
+ m = MockMetric()
m.metrics = [0, 1, 2]
self.assertEqual(m.get_metric(mode="min"), 0)
def test_metricBase_agg_mode_not_valid(self):
- m = Metric()
+ m = MockMetric()
m.metrics = [0, 1, 2]
with self.assertRaises(ValueError):
m.get_metric(mode="not-valid")
def test_metricBase_empty_metric(self):
- m = Metric()
+ m = MockMetric()
self.assertEquals(m.get_metric(), 0.0)
def test_classification_error_update_accurate(self):
@@ -69,25 +68,16 @@ def test_classification_error_mismatch(self):
m.update(predictions, truth)
def test_metricsWrapper_init(self):
- mw = metrics.MetricsWrapper(
- metric0=MockMetric(),
- metric1=MockMetric()
- )
+ mw = metrics.MetricsWrapper(metric0=MockMetric(), metric1=MockMetric())
self.assertTrue(hasattr(mw, "metric0"))
self.assertTrue(hasattr(mw, "metric1"))
def test_metricsWrapper_not_metric(self):
with self.assertRaises(ValueError):
- metrics.MetricsWrapper(
- metric0=MockMetric(),
- metric1="not-a-metric"
- )
+ metrics.MetricsWrapper(metric0=MockMetric(), metric1="not-a-metric")
def test_metricsWrapper_reset(self):
- mw = metrics.MetricsWrapper(
- metric0=MockMetric(),
- metric1=MockMetric()
- )
+ mw = metrics.MetricsWrapper(metric0=MockMetric(), metric1=MockMetric())
mw.metric0.metrics = [0, 1, 2]
mw.metric1.metrics = [0]
mw.reset()
@@ -95,41 +85,29 @@ def test_metricsWrapper_reset(self):
self.assertEqual(len(mw.metric1), 0)
def test_metricsWrapper_update(self):
- mw = metrics.MetricsWrapper(
- metric0=MockMetric(),
- metric1=MockMetric()
- )
- mw.update(1)
+ mw = metrics.MetricsWrapper(metric0=MockMetric(), metric1=MockMetric())
+ mw.update(1, 1)
self.assertEquals(mw.metric0.metrics, [1])
self.assertEquals(mw.metric1.metrics, [1])
def test_metricsWrapper_get_metrics_mean(self):
- mw = metrics.MetricsWrapper(
- metric0=MockMetric(),
- metric1=MockMetric()
- )
- mw.update(0)
- mw.update(1)
+ mw = metrics.MetricsWrapper(metric0=MockMetric(), metric1=MockMetric())
+ mw.update(0, 0)
+ mw.update(1, 1)
m = mw.get_metrics(mode="mean")
self.assertDictEqual(m, {"metric0": 0.5, "metric1": 0.5})
def test_metricsWrapper_get_metrics_max(self):
- mw = metrics.MetricsWrapper(
- metric0=MockMetric(),
- metric1=MockMetric()
- )
- mw.update(0)
- mw.update(1)
+ mw = metrics.MetricsWrapper(metric0=MockMetric(), metric1=MockMetric())
+ mw.update(0, 0)
+ mw.update(1, 1)
m = mw.get_metrics(mode="max")
self.assertDictEqual(m, {"metric0": 1, "metric1": 1})
def test_metricsWrapper_get_metrics_min(self):
- mw = metrics.MetricsWrapper(
- metric0=MockMetric(),
- metric1=MockMetric()
- )
- mw.update(0)
- mw.update(1)
+ mw = metrics.MetricsWrapper(metric0=MockMetric(), metric1=MockMetric())
+ mw.update(0, 0)
+ mw.update(1, 1)
m = mw.get_metrics(mode="min")
self.assertDictEqual(m, {"metric0": 0, "metric1": 0})
diff --git a/tests/unit/test_train.py b/tests/unit/test_train.py
new file mode 100644
index 0000000..4204f09
--- /dev/null
+++ b/tests/unit/test_train.py
@@ -0,0 +1,99 @@
+import unittest
+from pathlib import Path
+
+import torch
+
+from frarch.train import BaseTrainer, ClassifierTrainer
+
+DATA_FOLDER = Path(__file__).resolve().parent.parent / "data"
+
+
+class MockModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.fc = torch.nn.Linear(10, 2)
+
+ def forward(self, inputs):
+ return self.fc(inputs)
+
+
+class MockDataset(torch.utils.data.Dataset):
+ def __init__(self, n_classes=2):
+ super().__init__()
+ self.n_classes = n_classes
+
+ def __getitem__(self, idx):
+ return torch.rand((10,)), 0
+
+ def __len__(self):
+ return 10
+
+
+class TestClassifierTrainer(ClassifierTrainer):
+ def _forward(self, batch, stage):
+ inputs, _ = batch
+ inputs = inputs.to(self.device)
+ return self.modules.model(inputs)
+
+ def _compute_loss(self, predictions, batch, stage):
+ _, labels = batch
+ labels = labels.to(self.device)
+ return self.hparams["loss"](predictions, labels)
+
+
+class TestTrainers(unittest.TestCase):
+ model = MockModel()
+ train_dataset = MockDataset()
+ test_dataset = MockDataset()
+ opt_class = torch.optim.Adam
+
+ def test_init(self):
+ BaseTrainer({"model": self.model}, self.opt_class, {"noprogressbar": True})
+
+ def test_init_ckpt_interval_negative(self):
+ with self.assertRaises(ValueError):
+ BaseTrainer(
+ {"model": self.model}, self.opt_class, {"ckpt_interval_minutes": -1}
+ )
+
+ def test_ClassifierTrainer_init(self):
+ TestClassifierTrainer(
+ {"model": self.model}, self.opt_class, {"noprogressbar": True}
+ )
+
+ def test_ClassifierTrainer_fit(self):
+ trainer = TestClassifierTrainer(
+ {"model": self.model},
+ self.opt_class,
+ {"epochs": 1, "loss": torch.nn.CrossEntropyLoss(), "noprogressbar": True},
+ )
+ trainer.fit(
+ train_set=self.train_dataset,
+ valid_set=self.test_dataset,
+ )
+
+ def test_ClassifierTrainer_fit_epochs_not_specified(self):
+ trainer = TestClassifierTrainer(
+ {"model": self.model},
+ self.opt_class,
+ {"loss": torch.nn.CrossEntropyLoss(), "noprogressbar": True},
+ )
+ with self.assertRaises(KeyError):
+ trainer.fit(
+ train_set=self.train_dataset,
+ valid_set=self.test_dataset,
+ )
+
+ def test_ClassifierTrainer_fit_loss_not_specified(self):
+ trainer = TestClassifierTrainer(
+ {"model": self.model}, self.opt_class, {"epochs": 1, "noprogressbar": True}
+ )
+ with self.assertRaises(KeyError):
+ trainer.fit(
+ train_set=self.train_dataset,
+ valid_set=self.test_dataset,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_utils.py b/tests/unit/test_utils.py
similarity index 95%
rename from tests/test_utils.py
rename to tests/unit/test_utils.py
index f1f3666..36da2b7 100644
--- a/tests/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -7,7 +7,7 @@
from frarch.utils import data, exceptions, logging
-DATA_FOLDER = Path("./tests/data/")
+DATA_FOLDER = Path(__file__).resolve().parent.parent / "data"
class MockDataset(torch.utils.data.Dataset):
@@ -51,13 +51,13 @@ def test_create_dataloader_no_dataset(self):
def test_tensorInDevice(self):
tensor_data = list(range(10))
- tensor = data.tensorInDevice(tensor_data)
+ tensor = data.tensor_in_device(tensor_data)
self.assertIsInstance(tensor, torch.Tensor)
self.assertTupleEqual(tensor.shape, (10,))
def test_tensorInDevice_cpu(self):
tensor_data = list(range(10))
- tensor = data.tensorInDevice(tensor_data, device="cpu")
+ tensor = data.tensor_in_device(tensor_data, device="cpu")
self.assertEqual(str(tensor.device), "cpu")
def test_read_file(self):