From 5e81ec34318cbe958830ac58938fc6a6a79f190a Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Fri, 23 Feb 2024 18:17:46 +0100 Subject: [PATCH 01/12] Updated project metadata & _pipreqs_ python package for project packaging purposes. --- .gitignore | 11 ++++++--- requirements_python3.10.txt | 1 + setup.py | 46 ++++++++++++++++++++++++++++++++++++- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 1e6cfe2e..78670368 100644 --- a/.gitignore +++ b/.gitignore @@ -7,13 +7,18 @@ checkMyCode.sh !.gitignore +# Build utilities docs/_build + +# Data examples/ecologists/micro_geolifeclef*/dataset/patches examples/ecologists/micro_geolifeclef*/dataset/rasters +examples/ecologists/sentinel-2a/dataset/*.tif examples/inference/micro_geolifeclef*/dataset/patches examples/inference/micro_geolifeclef*/dataset/rasters - -examples/ecologists/sentinel-2a/dataset/*.tif examples/inference/sentinel-2a/dataset/*.tif - examples/kaggle/geolifeclef2022/dataset + +# Packaging +dist/ +build/ diff --git a/requirements_python3.10.txt b/requirements_python3.10.txt index d830d0c7..1d333669 100644 --- a/requirements_python3.10.txt +++ b/requirements_python3.10.txt @@ -22,6 +22,7 @@ omegaconf==2.3.0 opencv-python==4.7.0.72 pandas==2.1.1 Pillow==10.2.0 +pipreqs==0.5.0 planetary-computer==0.4.9 pydocstyle==6.3.0 pyflakes==3.0.1 diff --git a/setup.py b/setup.py index 915509b4..30d2cfc3 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,47 @@ from setuptools import setup, find_packages -setup(name="malpolon", packages=find_packages()) +setup(name="malpolon", + version="1.0.0", + description="Malpolon v1.0.0.0", + author="Theo Larcher, Titouan Lorieul", + author_email="theo.larcher@inria.fr, titouan.lorieul@inria.fr", + url="https://github.com/plantnet/malpolon", + classifiers=[ + "Intended Audience :: Developpers, Ecologists, Researchers", + "License :: MIT License", + "Programming Language :: Python :: 3.10"], + package_dir={"": "malpolon"}, + 
packages=find_packages(where="malpolon", exclude="malpolon.tests"), + python_requires=">=3.10, <4", + install_requires=[ + "Cartopy>=0.21.1", + "kaggle>=1.5.16", + "matplotlib>=3.8.0", + "numpy>=1.26.4", + "omegaconf>=2.3.0", + "pandas>=2.2.1", + "Pillow>=10.0.1", + "Pillow>=10.2.0", + "planetary_computer>=0.4.9", + "pyproj>=3.6.1", + "pystac>=1.6.1", + "pytest>=7.2.2", + "pytorch_lightning>=2.1.0", + "rasterio>=1.3.8.post1", + "scikit_learn>=1.1.3", + "Shapely>=2.0.3", + "tifffile>=2022.10.10", + "timm>=0.9.2", + "torch>=2.1.0", + "torchgeo>=0.5.0", + "torchmetrics>=1.2.0", + "torchvision>=0.16.0", + "tqdm>=4.66.1" + ], + project_urls={ + "Bug Reports": "https://github.com/plantnet/malpolon/issues/new?assignees=aerodynamic-sauce-pan&labels=bug&projects=&template=bug_report.md&title=%5BBUG%5D", + "Feature request": "https://github.com/plantnet/malpolon/issues/new?assignees=aerodynamic-sauce-pan&labels=enhancement&projects=&template=enhancement.md&title=%5BEnhancement%5D", + "Host organizer": "https://plantnet.org/", + "Source": "https://github.com/plantnet/malpolon", + }, +) From b00d3e627f73c15775dda5ffccea29ab9a0d473c Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Fri, 23 Feb 2024 18:27:35 +0100 Subject: [PATCH 02/12] Try: fix pipreqs package conflict --- environment_python3.10.yml | 1 + requirements_python3.10.txt | 1 - setup.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/environment_python3.10.yml b/environment_python3.10.yml index b4124cb7..ec10bf22 100644 --- a/environment_python3.10.yml +++ b/environment_python3.10.yml @@ -7,5 +7,6 @@ dependencies: - cmake==3.26.3 - pip==23.2.1 - pip: + - --no-deps pipreqs==0.5.0 - -r requirements_python3.10.txt prefix: ~/envs/malpolon diff --git a/requirements_python3.10.txt b/requirements_python3.10.txt index 1d333669..d830d0c7 100644 --- a/requirements_python3.10.txt +++ b/requirements_python3.10.txt @@ -22,7 +22,6 @@ omegaconf==2.3.0 opencv-python==4.7.0.72 pandas==2.1.1 
Pillow==10.2.0 -pipreqs==0.5.0 planetary-computer==0.4.9 pydocstyle==6.3.0 pyflakes==3.0.1 diff --git a/setup.py b/setup.py index 30d2cfc3..ebcee92a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup(name="malpolon", version="1.0.0", - description="Malpolon v1.0.0.0", + description="Malpolon v1.0.0", author="Theo Larcher, Titouan Lorieul", author_email="theo.larcher@inria.fr, titouan.lorieul@inria.fr", url="https://github.com/plantnet/malpolon", From 8f8e36d86124ef66c592d1115b48ec2c2ab887f8 Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Fri, 23 Feb 2024 18:49:30 +0100 Subject: [PATCH 03/12] Updated setup.py for PyPi package publication --- setup.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index ebcee92a..ae2e854a 100644 --- a/setup.py +++ b/setup.py @@ -7,9 +7,20 @@ author_email="theo.larcher@inria.fr, titouan.lorieul@inria.fr", url="https://github.com/plantnet/malpolon", classifiers=[ - "Intended Audience :: Developpers, Ecologists, Researchers", - "License :: MIT License", - "Programming Language :: Python :: 3.10"], + "Development Status :: 3 - Alpha", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.10", + "Environment :: GPU :: NVIDIA CUDA", + "Operating System :: Unix", + "Operating System :: MacOS, + "Typing :: Typed", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Processing", + "Topic :: Scientific/Engineering :: Image Recognition", + "Topic :: Scientific/Engineering :: GIS" + ], package_dir={"": "malpolon"}, packages=find_packages(where="malpolon", exclude="malpolon.tests"), python_requires=">=3.10, <4", From 9b658eb4b0955e4acbfb7e2b88e5e9b20b3ec850 Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Fri, 23 Feb 2024 18:50:09 +0100 Subject: [PATCH 04/12] Updated setup.py for PyPi package 
publication --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ae2e854a..d0ff5888 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ "Programming Language :: Python :: 3.10", "Environment :: GPU :: NVIDIA CUDA", "Operating System :: Unix", - "Operating System :: MacOS, + "Operating System :: MacOS", "Typing :: Typed", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence", From 7e8f6486a27ebf0d8ffe6100532e63265a027dc3 Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Fri, 23 Feb 2024 19:23:00 +0100 Subject: [PATCH 05/12] Updated credits info and added file docstrings when missing. --- .../micro_geolifeclef2022/cnn_on_rgb_nir_patches.py | 8 ++++++++ .../micro_geolifeclef2022/cnn_on_rgb_patches.py | 8 ++++++++ .../ecologists/micro_geolifeclef2022/transforms.py | 3 ++- .../ecologists/sentinel-2a/cnn_on_rgbnir_torchgeo.py | 3 ++- examples/ecologists/sentinel-2a/transforms.py | 10 +++++++++- .../micro_geolifeclef2022/cnn_on_rgb_nir_patches.py | 11 ++++++++--- .../micro_geolifeclef2022/cnn_on_rgb_patches.py | 10 ++++++++-- .../inference/micro_geolifeclef2022/transforms.py | 3 ++- .../inference/sentinel-2a/cnn_on_rgbnir_torchgeo.py | 5 ++--- examples/inference/sentinel-2a/transforms.py | 10 +++++++++- .../kaggle/geolifeclef2022/cnn_on_rgb_patches.py | 12 +++++++++--- .../cnn_on_rgb_temperature_patches.py | 10 +++++++++- .../geolifeclef2022/cnn_on_temperature_patches.py | 10 +++++++++- examples/kaggle/geolifeclef2022/transforms.py | 9 +++++++++ .../geolifeclef2023/cnn_on_rgbnir_glc23_patches.py | 2 +- malpolon/data/data_module.py | 3 +-- malpolon/data/datasets/geolifeclef2022.py | 2 +- malpolon/data/environmental_raster.py | 9 +++++++-- malpolon/data/get_jpeg_patches_stats.py | 2 ++ malpolon/data/utils.py | 6 +++++- malpolon/models/model_builder.py | 11 +++++++++++ malpolon/models/multi_modal.py | 9 ++++++++- malpolon/models/standard_prediction_systems.py | 2 +- 
malpolon/models/utils.py | 6 +++++- malpolon/plot/history.py | 7 +++++++ malpolon/plot/map.py | 9 ++++++++- malpolon/tests/test_environmental_raster.py | 5 +++++ malpolon/tests/test_geolifeclef2022_dataset.py | 9 +++++++-- malpolon/tests/test_models.py | 6 ++++++ malpolon/tests/test_standard_prediction_systems.py | 5 +++++ malpolon/tests/test_torchgeo_datasets.py | 5 +++++ setup.py | 2 +- 32 files changed, 180 insertions(+), 32 deletions(-) diff --git a/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py b/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py index 971114cc..4eca4423 100644 --- a/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py +++ b/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py @@ -1,3 +1,11 @@ +"""Main script to run training or inference on microlifeclef2022 dataset. + +Uses RGB and Near infra-red pre-extracted patches from the dataset. + +Author: Titouan Lorieul + Theo Larcher +""" + from __future__ import annotations import os diff --git a/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_patches.py b/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_patches.py index 0e0a1734..ae5d7642 100644 --- a/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_patches.py +++ b/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_patches.py @@ -1,3 +1,11 @@ +"""Main script to run training or inference on microlifeclef2022 dataset. + +Uses RGB pre-extracted patches from the dataset. + +Author: Titouan Lorieul + Theo Larcher +""" + from __future__ import annotations import os diff --git a/examples/ecologists/micro_geolifeclef2022/transforms.py b/examples/ecologists/micro_geolifeclef2022/transforms.py index 71ad830c..da97c7ea 100644 --- a/examples/ecologists/micro_geolifeclef2022/transforms.py +++ b/examples/ecologists/micro_geolifeclef2022/transforms.py @@ -3,9 +3,10 @@ These transform classes can be called during training loops to perform data augmentation. 
-Author: Titouan Lorieul +Author: Titouan Lorieul Theo Larcher """ + import numpy as np import torch from torchvision import transforms diff --git a/examples/ecologists/sentinel-2a/cnn_on_rgbnir_torchgeo.py b/examples/ecologists/sentinel-2a/cnn_on_rgbnir_torchgeo.py index 97e27ac9..1bfe2722 100644 --- a/examples/ecologists/sentinel-2a/cnn_on_rgbnir_torchgeo.py +++ b/examples/ecologists/sentinel-2a/cnn_on_rgbnir_torchgeo.py @@ -3,8 +3,9 @@ This script runs the RasterSentinel2 dataset class by default. Author: Theo Larcher - Titouan Lorieul + Titouan Lorieul """ + from __future__ import annotations import os diff --git a/examples/ecologists/sentinel-2a/transforms.py b/examples/ecologists/sentinel-2a/transforms.py index 266eaa78..1cb661c3 100644 --- a/examples/ecologists/sentinel-2a/transforms.py +++ b/examples/ecologists/sentinel-2a/transforms.py @@ -1,5 +1,13 @@ -import numpy as np +"""Collection of custom PyTorch friendly transform classes. + +These transform classes can be called during training loops to perform +data augmentation. +Author: Titouan Lorieul + Theo Larcher +""" + +import numpy as np import torch from torchvision import transforms diff --git a/examples/inference/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py b/examples/inference/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py index 2ec35dd3..e1a51773 100644 --- a/examples/inference/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py +++ b/examples/inference/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py @@ -1,9 +1,14 @@ -from __future__ import annotations +"""Main script to run inference on microlifeclef2022 dataset. + +Uses RGB and Near infra-red pre-extracted patches from the dataset. 
-import os +Author: Titouan Lorieul + Theo Larcher +""" + +from __future__ import annotations import hydra -import numpy as np import pytorch_lightning as pl import torch from omegaconf import DictConfig diff --git a/examples/inference/micro_geolifeclef2022/cnn_on_rgb_patches.py b/examples/inference/micro_geolifeclef2022/cnn_on_rgb_patches.py index 85227df4..97049cdc 100644 --- a/examples/inference/micro_geolifeclef2022/cnn_on_rgb_patches.py +++ b/examples/inference/micro_geolifeclef2022/cnn_on_rgb_patches.py @@ -1,6 +1,12 @@ -from __future__ import annotations +"""Main script to run inference on microlifeclef2022 dataset. + +Uses RGB pre-extracted patches from the dataset. -import os +Author: Titouan Lorieul + Theo Larcher +""" + +from __future__ import annotations import hydra import pytorch_lightning as pl diff --git a/examples/inference/micro_geolifeclef2022/transforms.py b/examples/inference/micro_geolifeclef2022/transforms.py index 71ad830c..da97c7ea 100644 --- a/examples/inference/micro_geolifeclef2022/transforms.py +++ b/examples/inference/micro_geolifeclef2022/transforms.py @@ -3,9 +3,10 @@ These transform classes can be called during training loops to perform data augmentation. -Author: Titouan Lorieul +Author: Titouan Lorieul Theo Larcher """ + import numpy as np import torch from torchvision import transforms diff --git a/examples/inference/sentinel-2a/cnn_on_rgbnir_torchgeo.py b/examples/inference/sentinel-2a/cnn_on_rgbnir_torchgeo.py index e1e2cc5d..ee7ad15c 100644 --- a/examples/inference/sentinel-2a/cnn_on_rgbnir_torchgeo.py +++ b/examples/inference/sentinel-2a/cnn_on_rgbnir_torchgeo.py @@ -3,11 +3,10 @@ This script runs the RasterSentinel2 dataset class by default. 
Author: Theo Larcher - Titouan Lorieul + Titouan Lorieul """ -from __future__ import annotations -import os +from __future__ import annotations import hydra import pytorch_lightning as pl diff --git a/examples/inference/sentinel-2a/transforms.py b/examples/inference/sentinel-2a/transforms.py index 266eaa78..1cb661c3 100644 --- a/examples/inference/sentinel-2a/transforms.py +++ b/examples/inference/sentinel-2a/transforms.py @@ -1,5 +1,13 @@ -import numpy as np +"""Collection of custom PyTorch friendly transform classes. + +These transform classes can be called during training loops to perform +data augmentation. +Author: Titouan Lorieul + Theo Larcher +""" + +import numpy as np import torch from torchvision import transforms diff --git a/examples/kaggle/geolifeclef2022/cnn_on_rgb_patches.py b/examples/kaggle/geolifeclef2022/cnn_on_rgb_patches.py index 4167ce52..49597851 100644 --- a/examples/kaggle/geolifeclef2022/cnn_on_rgb_patches.py +++ b/examples/kaggle/geolifeclef2022/cnn_on_rgb_patches.py @@ -1,8 +1,14 @@ -import os +"""Main script to run training on microlifeclef2022 dataset. + +Uses RGB pre-extracted patches from the dataset. +This script was created for Kaggle participants of the GeoLifeCLEF 2022 +challenge. + +Author: Titouan Lorieul +""" import hydra -import pytorch_lightning as pl -import torchmetrics.functional as Fmetrics +import pytorch_lightning as p from omegaconf import DictConfig from pytorch_lightning.callbacks import ModelCheckpoint from torchvision import transforms diff --git a/examples/kaggle/geolifeclef2022/cnn_on_rgb_temperature_patches.py b/examples/kaggle/geolifeclef2022/cnn_on_rgb_temperature_patches.py index b32fbd79..80bc888e 100644 --- a/examples/kaggle/geolifeclef2022/cnn_on_rgb_temperature_patches.py +++ b/examples/kaggle/geolifeclef2022/cnn_on_rgb_temperature_patches.py @@ -1,4 +1,12 @@ -import os +"""Main script to run training on microlifeclef2022 dataset. 
+ +Uses RGB pre-extracted patches and temperature rasters from the dataset. +This script was created for Kaggle participants of the GeoLifeCLEF 2022 +challenge. + +Author: Titouan Lorieul +""" + from pathlib import Path import hydra diff --git a/examples/kaggle/geolifeclef2022/cnn_on_temperature_patches.py b/examples/kaggle/geolifeclef2022/cnn_on_temperature_patches.py index 4a22bb00..410d82ed 100644 --- a/examples/kaggle/geolifeclef2022/cnn_on_temperature_patches.py +++ b/examples/kaggle/geolifeclef2022/cnn_on_temperature_patches.py @@ -1,4 +1,12 @@ -import os +"""Main script to run training on microlifeclef2022 dataset. + +Uses temperature rasters from the dataset. +This script was created for Kaggle participants of the GeoLifeCLEF 2022 +challenge. + +Author: Titouan Lorieul +""" + from pathlib import Path import hydra diff --git a/examples/kaggle/geolifeclef2022/transforms.py b/examples/kaggle/geolifeclef2022/transforms.py index 0f155988..ff60e0d3 100644 --- a/examples/kaggle/geolifeclef2022/transforms.py +++ b/examples/kaggle/geolifeclef2022/transforms.py @@ -1,3 +1,12 @@ +"""Collection of custom PyTorch friendly transform classes. + +These transform classes can be called during training loops to perform +data augmentation. + +Author: Titouan Lorieul + Theo Larcher +""" + import numpy as np import torch from torchvision import transforms diff --git a/examples/kaggle/geolifeclef2023/cnn_on_rgbnir_glc23_patches.py b/examples/kaggle/geolifeclef2023/cnn_on_rgbnir_glc23_patches.py index bcb4c401..e03ead44 100644 --- a/examples/kaggle/geolifeclef2023/cnn_on_rgbnir_glc23_patches.py +++ b/examples/kaggle/geolifeclef2023/cnn_on_rgbnir_glc23_patches.py @@ -3,8 +3,8 @@ This script runs the RasterSentinel2 dataset class by default. 
Author: Theo Larcher - Titouan Lorieul """ + from __future__ import annotations import hydra diff --git a/malpolon/data/data_module.py b/malpolon/data/data_module.py index f74e2321..e093d335 100644 --- a/malpolon/data/data_module.py +++ b/malpolon/data/data_module.py @@ -1,8 +1,7 @@ """This module provides a base class for data modules. Author: Theo Larcher - Titouan Lorieul - + Titouan Lorieul """ from __future__ import annotations diff --git a/malpolon/data/datasets/geolifeclef2022.py b/malpolon/data/datasets/geolifeclef2022.py index f6f8eba3..2aae7896 100644 --- a/malpolon/data/datasets/geolifeclef2022.py +++ b/malpolon/data/datasets/geolifeclef2022.py @@ -3,7 +3,7 @@ This module has since been updated for GeoLifeCLEF2023 Author: Benjamin Deneu - Titouan Lorieul + Titouan Lorieul License: GPLv3 Python version: 3.8 diff --git a/malpolon/data/environmental_raster.py b/malpolon/data/environmental_raster.py index 2d9c1e1d..7fd52b68 100644 --- a/malpolon/data/environmental_raster.py +++ b/malpolon/data/environmental_raster.py @@ -1,13 +1,18 @@ +"""Custom classes to handle environmental rasters without torchgeo. + +Author: Titouan Lorieul +""" + from __future__ import annotations + import warnings from pathlib import Path -from typing import Any, Optional, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional, Union import matplotlib.pyplot as plt import numpy as np import rasterio - if TYPE_CHECKING: import numpy.typing as npt diff --git a/malpolon/data/get_jpeg_patches_stats.py b/malpolon/data/get_jpeg_patches_stats.py index 8365c689..8b57da8d 100644 --- a/malpolon/data/get_jpeg_patches_stats.py +++ b/malpolon/data/get_jpeg_patches_stats.py @@ -3,6 +3,8 @@ When dealing with a large amount of files it should be run only once, and the statistics should be stored in a separate .csv for later use. 
+ +Author: Theo Larcher """ import argparse diff --git a/malpolon/data/utils.py b/malpolon/data/utils.py index df342e96..7597e803 100644 --- a/malpolon/data/utils.py +++ b/malpolon/data/utils.py @@ -1,4 +1,8 @@ -"""This file compiles useful functions related to data and file handling.""" +"""This file compiles useful functions related to data and file handling. + +Author: Theo Larcher +""" + from __future__ import annotations import os diff --git a/malpolon/models/model_builder.py b/malpolon/models/model_builder.py index 05621ef4..7e04e9dc 100644 --- a/malpolon/models/model_builder.py +++ b/malpolon/models/model_builder.py @@ -1,3 +1,14 @@ +"""This module provides classes to build your PyTorch models. + +Classes listed in this module allow to select a model from your +provider (timm, torchvision...), retrieve it with or without +pre-trained weights, and modify it by adding or removing layers. + +Author: Titouan Lorieul + Theo Larcher + +""" + from __future__ import annotations from typing import TYPE_CHECKING diff --git a/malpolon/models/multi_modal.py b/malpolon/models/multi_modal.py index 7f395ae6..63d415e1 100644 --- a/malpolon/models/multi_modal.py +++ b/malpolon/models/multi_modal.py @@ -1,10 +1,17 @@ +"""This module provides classes for advanced model building. + +Author: Titouan Lorieul + +""" + from __future__ import annotations + from typing import TYPE_CHECKING import torch -from torch import nn from pytorch_lightning.strategies import SingleDeviceStrategy, StrategyRegistry from pytorch_lightning.utilities import move_data_to_device +from torch import nn from .utils import check_model diff --git a/malpolon/models/standard_prediction_systems.py b/malpolon/models/standard_prediction_systems.py index dca873c4..02ba8848 100644 --- a/malpolon/models/standard_prediction_systems.py +++ b/malpolon/models/standard_prediction_systems.py @@ -1,6 +1,6 @@ """This module provides classes wrapping pytorchlightning training modules. 
-Author: Titouan Lorieul +Author: Titouan Lorieul Theo Larcher """ diff --git a/malpolon/models/utils.py b/malpolon/models/utils.py index 4182837e..4c934e08 100644 --- a/malpolon/models/utils.py +++ b/malpolon/models/utils.py @@ -1,4 +1,8 @@ -"""This file compiles useful functions related to models.""" +"""This file compiles useful functions related to models. + +Author: Theo Larcher + Titouan Lorieul +""" from __future__ import annotations diff --git a/malpolon/plot/history.py b/malpolon/plot/history.py index 3ca08c18..ffe94cc2 100644 --- a/malpolon/plot/history.py +++ b/malpolon/plot/history.py @@ -1,4 +1,10 @@ +"""Utilities used for plotting purposes. + +Author: Titouan Lorieul +""" + from __future__ import annotations + from typing import Optional import matplotlib.pyplot as plt @@ -109,6 +115,7 @@ def plot_history( if __name__ == "__main__": import argparse + import pandas as pd parser = argparse.ArgumentParser(description="plots the training curves") diff --git a/malpolon/plot/map.py b/malpolon/plot/map.py index 314ef7b4..f89d8cfe 100644 --- a/malpolon/plot/map.py +++ b/malpolon/plot/map.py @@ -1,5 +1,12 @@ +"""Utilities for plotting maps. + +Author: Titouan Lorieul + +""" + from __future__ import annotations -from typing import Optional, TYPE_CHECKING + +from typing import TYPE_CHECKING, Optional import matplotlib.pyplot as plt diff --git a/malpolon/tests/test_environmental_raster.py b/malpolon/tests/test_environmental_raster.py index 9719c981..9cfbf1a6 100644 --- a/malpolon/tests/test_environmental_raster.py +++ b/malpolon/tests/test_environmental_raster.py @@ -1,3 +1,8 @@ +"""This script tests the environmental raster module. 
+ +Author: Titouan Lorieul +""" + from pathlib import Path import numpy as np diff --git a/malpolon/tests/test_geolifeclef2022_dataset.py b/malpolon/tests/test_geolifeclef2022_dataset.py index 054025da..3f045a02 100644 --- a/malpolon/tests/test_geolifeclef2022_dataset.py +++ b/malpolon/tests/test_geolifeclef2022_dataset.py @@ -1,11 +1,16 @@ +"""This script tests the GeoLifeCLEF2022 dataset module. + +Author: Titouan Lorieul +""" + from pathlib import Path import numpy as np import pytest +from malpolon.data.datasets.geolifeclef2022 import ( + GeoLifeCLEF2022Dataset, load_patch, visualize_observation_patch) from malpolon.data.environmental_raster import PatchExtractor -from malpolon.data.datasets.geolifeclef2022 import load_patch, GeoLifeCLEF2022Dataset, visualize_observation_patch - DATA_PATH = Path("malpolon/tests/data/glc22") diff --git a/malpolon/tests/test_models.py b/malpolon/tests/test_models.py index d9b9cab7..d621b36a 100644 --- a/malpolon/tests/test_models.py +++ b/malpolon/tests/test_models.py @@ -1,3 +1,9 @@ +"""This script tests the models module. + +Author: Titouan Lorieul + Theo Larcher +""" + import timm import torch from torchvision import models diff --git a/malpolon/tests/test_standard_prediction_systems.py b/malpolon/tests/test_standard_prediction_systems.py index 33269770..e4e5cb71 100644 --- a/malpolon/tests/test_standard_prediction_systems.py +++ b/malpolon/tests/test_standard_prediction_systems.py @@ -1,3 +1,8 @@ +"""This script tests the standard prediction systems module. + +Author: Theo Larcher +""" + import numpy as np import timm diff --git a/malpolon/tests/test_torchgeo_datasets.py b/malpolon/tests/test_torchgeo_datasets.py index 3256fc8b..b0aa362f 100644 --- a/malpolon/tests/test_torchgeo_datasets.py +++ b/malpolon/tests/test_torchgeo_datasets.py @@ -1,3 +1,8 @@ +"""This script tests the torchgeo datasets module. 
+ +Author: Theo Larcher +""" + from pathlib import Path import numpy as np diff --git a/setup.py b/setup.py index d0ff5888..fa467ccd 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ version="1.0.0", description="Malpolon v1.0.0", author="Theo Larcher, Titouan Lorieul", - author_email="theo.larcher@inria.fr, titouan.lorieul@inria.fr", + author_email="theo.larcher@inria.fr, titouan.lorieul@gmail.com", url="https://github.com/plantnet/malpolon", classifiers=[ "Development Status :: 3 - Alpha", From e8c6b82bd3db94efd775c86ad41e68510af90d3a Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Fri, 23 Feb 2024 19:58:38 +0100 Subject: [PATCH 06/12] Added missing docstrings --- malpolon/models/model_builder.py | 98 ++++++++++++++++++++++++++++++++ malpolon/models/multi_modal.py | 44 +++++++++++++- 2 files changed, 141 insertions(+), 1 deletion(-) diff --git a/malpolon/models/model_builder.py b/malpolon/models/model_builder.py index 7e04e9dc..18b3ecf3 100644 --- a/malpolon/models/model_builder.py +++ b/malpolon/models/model_builder.py @@ -25,6 +25,7 @@ class _ModelBuilder: + """General class to build models.""" providers: dict[str, Provider] = {} modifiers: dict[str, Modifier] = {} @@ -36,6 +37,28 @@ def build_model( model_kwargs: dict = {}, modifiers: dict[str, Optional[dict[str, Any]]] = {}, ) -> nn.Module: + """Return a built model with the given provider and modifiers. 
+ + Parameters + ---------- + provider_name : str + source of the model's provider, valid values are: + [`timm`, `torchvision`] + model_name : str + name of the model to retrieve from the provider + model_args : list, optional + model arguments to pass on when building it, by default [] + model_kwargs : dict, optional + model kwargs, by default {} + modifiers : dict[str, Optional[dict[str, Any]]], optional + modifiers to call on the model after it is built, + by default {} + + Returns + ------- + nn.Module + built and mofified model + """ provider = self.providers[provider_name] model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None} model = provider(model_name, *model_args, **model_kwargs) @@ -48,24 +71,77 @@ def build_model( return model def register_provider(self, provider_name: str, provider: Provider) -> None: + """Register a provider to the model builder. + + Parameters + ---------- + provider_name : str + name of the provider, valid values are: + [`timm`, `torchvision`] + provider : Provider + callable provider function + """ self.providers[provider_name] = provider def register_modifier(self, modifier_name: str, modifier: Modifier) -> None: + """Register a modifier to the model builder. + + Parameters + ---------- + modifier_name : str + name of the modifier, valid values are: + [`change_first_convolutional_layer`, `change_last_layer`, `change_last_layer_to_identity`] + modifier : Modifier + modifier callable function + """ self.modifiers[modifier_name] = modifier def torchvision_model_provider( model_name: str, *model_args: Any, **model_kwargs: Any ) -> nn.Module: + """Return a model from torchvision's library. + This method uses tochvision's API to retrieve a model from its + library. 
+ + Parameters + ---------- + model_name : str + name of the model to retrieve from torchvision's library + + Returns + ------- + nn.Module + model object + """ model = getattr(models, model_name) model = model(*model_args, **model_kwargs) return model + def timm_model_provider( model_name: str, *model_args: Any, **model_kwargs: Any ) -> nn.Module: + """Return a model from timm's library. + + This method uses timm's API to retrieve a model from its library. + + Parameters + ---------- + model_name : str + name of the model to retrieve from timm's library + Returns + ------- + nn.Module + model object + + Raises + ------ + ValueError + if the model name is not listed in TIMM's library + """ available_models = timm.list_models() if model_name in available_models: model = timm.create_model(model_name, *model_args, **model_kwargs) @@ -80,6 +156,28 @@ def timm_model_provider( def _find_module_of_type( module: nn.Module, module_type: type, order: str ) -> tuple[nn.Module, str]: + """Find the first or last module of a given type in a module. + + Parameters + ---------- + module : nn.Module + torch module to search in (_e.g.: torch model_) + module_type : type + module type to search for (_e.g.: nn.Conv2d_) + order : str + order to search for the module, valid values are: + [`first`, `last`] + + Returns + ------- + tuple[nn.Module, str] + module and its name + + Raises + ------ + ValueError + if the order is not valid + """ if order == "first": modules = module.named_children() elif order == "last": diff --git a/malpolon/models/multi_modal.py b/malpolon/models/multi_modal.py index 63d415e1..38379c87 100644 --- a/malpolon/models/multi_modal.py +++ b/malpolon/models/multi_modal.py @@ -1,7 +1,7 @@ """This module provides classes for advanced model building. Author: Titouan Lorieul - + Theo Larcher """ from __future__ import annotations @@ -20,11 +20,32 @@ class MultiModalModel(nn.Module): + """Base multi-modal model. 
+ + This class builds an aggregation of multiple models from the passed + on config file values, one for each modality, splits the training + routine per modality and then aggregates the features from each + modality after each forward pass. + """ def __init__( self, modality_models: Union[nn.Module, Mapping], aggregator_model: Union[nn.Module, Mapping], ): + """Class constructor. + + Parameters + ---------- + modality_models : Union[nn.Module, Mapping] + dictionary of modality names and their respective models to + pass on to the model builder + aggregator_model : Union[nn.Module, Mapping] + Model strategy to aggregate the features from each modality. + Can either be a PyTorch module directly (in this case, the + module will be directly called), or a mapping in the same + fashion as for buiding the modality models, in which case + the model builder will be called again. + """ super().__init__() for modality_name, model in modality_models.items(): @@ -46,12 +67,29 @@ def forward(self, x: list[Any]) -> Any: class HomogeneousMultiModalModel(MultiModalModel): + """Straightforward multi-modal model.""" def __init__( self, modality_names: list, modalities_model: dict, aggregator_model: Union[nn.Module, Mapping], ): + """Class constructor. + + Parameters + ---------- + modality_names : list + list of modalities names + modalities_model : dict + dictionary of modality names and their respective models to + pass on to the model builder + aggregator_model : Union[nn.Module, Mapping] + Model strategy to aggregate the features from each modality. + Can either be a PyTorch module directly (in this case, the + module will be directly called), or a mapping in the same + fashion as for buiding the modality models, in which case + the model builder will be called again. 
+ """ self.modality_names = modality_names self.modalities_model = modalities_model @@ -62,6 +100,10 @@ def __init__( class ParallelMultiModalModelStrategy(SingleDeviceStrategy): + """Model parallelism strategy for multi-modal models. + + WARNING: STILL UNDER DEVELOPMENT. + """ strategy_name = "parallel_multi_modal_model" def __init__( From bc629b8c8d0a0f93f72d330e68eec2f2e2d8645f Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Mon, 26 Feb 2024 12:38:45 +0100 Subject: [PATCH 07/12] Fixed docstrings and linting for v1.0.0 --- malpolon/check_install.py | 6 ++ malpolon/data/datasets/geolifeclef2022.py | 4 +- malpolon/data/datasets/torchgeo_sentinel2.py | 2 +- malpolon/data/environmental_raster.py | 71 ++++++++----------- malpolon/data/utils.py | 6 +- malpolon/logging.py | 50 ++++++------- malpolon/models/model_builder.py | 15 ++-- malpolon/models/multi_modal.py | 2 + .../models/standard_prediction_systems.py | 4 +- malpolon/models/utils.py | 33 ++++++--- malpolon/plot/history.py | 13 ++-- malpolon/plot/map.py | 9 ++- 12 files changed, 114 insertions(+), 101 deletions(-) diff --git a/malpolon/check_install.py b/malpolon/check_install.py index 41ef4df6..8240914b 100644 --- a/malpolon/check_install.py +++ b/malpolon/check_install.py @@ -1,9 +1,15 @@ +"""This module checks the installation of PyTorch and GPU libraries. 
+ +Author: Titouan Lorieul +""" + import os import torch def print_cuda_info(): + """Print information about the CUDA/PyTorch installation.""" print(f"Using PyTorch version {torch.__version__}") print(f"CUDA available: {torch.cuda.is_available()} (version: {torch.version.cuda})") print(f"cuDNN available: {torch.backends.cudnn.enabled} (version: {torch.backends.cudnn.version()})") diff --git a/malpolon/data/datasets/geolifeclef2022.py b/malpolon/data/datasets/geolifeclef2022.py index 2aae7896..73980b2b 100644 --- a/malpolon/data/datasets/geolifeclef2022.py +++ b/malpolon/data/datasets/geolifeclef2022.py @@ -347,9 +347,9 @@ def download(self): return try: - import kaggle + import kaggle # pylint: disable=C0415,W0611 # noqa: F401 except OSError as error: - raise OSError("Have you properly set up your Kaggle API token ? For more information, please refer to section 'Authentication' of the kaggle documentation : https://www.kaggle.com/docs/api"+msg) from error + raise OSError("Have you properly set up your Kaggle API token ? For more information, please refer to section 'Authentication' of the kaggle documentation : https://www.kaggle.com/docs/api") from error answer = input("You are about to download the GeoLifeClef2022 dataset which weighs ~62 GB. Do you want to continue ? 
[y/n]") if answer.lower() in ["y", "yes"]: diff --git a/malpolon/data/datasets/torchgeo_sentinel2.py b/malpolon/data/datasets/torchgeo_sentinel2.py index 463ff254..c34b3256 100644 --- a/malpolon/data/datasets/torchgeo_sentinel2.py +++ b/malpolon/data/datasets/torchgeo_sentinel2.py @@ -254,7 +254,7 @@ def plot( class RasterSentinel2GLC23(RasterSentinel2): - """Adaptation of RasterSentinel2 for new GLC23 observations""" + """Adaptation of RasterSentinel2 for new GLC23 observations.""" filename_glob = "*.tif" filename_regex = r"(?Pred|green|blue|nir)_2021" all_bands = ["red", "green", "blue", "nir"] diff --git a/malpolon/data/environmental_raster.py b/malpolon/data/environmental_raster.py index 7fd52b68..30bbe788 100644 --- a/malpolon/data/environmental_raster.py +++ b/malpolon/data/environmental_raster.py @@ -35,8 +35,8 @@ # fmt: on -class Raster(object): - """Loads a GeoTIFF file and extract patches for a single environmental raster +class Raster(): + """Loads a GeoTIFF file and extract patches for a single environmental raster. Parameters ---------- @@ -62,11 +62,7 @@ def __init__( ): path = Path(path) if not path.exists(): - raise ValueError( - "path should be the path to a raster, given non-existant path: {}".format( - path - ) - ) + raise ValueError(f"path should be the path to a raster, given non-existant path: {path}") self.path = path self.name = path.name @@ -75,7 +71,7 @@ def __init__( self.nan = nan # Loading the raster - filename = path / "{}_{}.tif".format(self.name, country) + filename = path / f"{self.name}_{country}.tif" with rasterio.open(filename) as dataset: self.dataset = dataset raster = dataset.read(1, masked=True, out_dtype=np.float32) @@ -93,7 +89,8 @@ def __init__( self.shape = self.raster.shape def _extract_patch(self, coordinates: Coordinates) -> Patch: - """Extracts the patch around the given GPS coordinates. + """Extract the patch around the given GPS coordinates. + Avoid using this method directly. 
Parameter @@ -139,7 +136,9 @@ def _extract_patch(self, coordinates: Coordinates) -> Patch: return patch def __len__(self) -> int: - """Number of bands in the raster (should always be equal to 1). + """Return the number of bands in the raster. + + Should always be equal to 1. Returns ------- @@ -149,7 +148,7 @@ def __len__(self) -> int: return self.dataset.count def __getitem__(self, coordinates: Coordinates) -> Patch: - """Extracts the patch around the given GPS coordinates. + """Extract the patch around the given GPS coordinates. Parameters ---------- @@ -166,20 +165,17 @@ def __getitem__(self, coordinates: Coordinates) -> Patch: except IndexError as e: if self.out_of_bounds == "error": raise e - else: - if self.out_of_bounds == "warn": - warnings.warn( - "GPS coordinates ({}, {}) out of bounds".format(*coordinates) - ) + if self.out_of_bounds == "warn": + warnings.warn(f"GPS coordinates ({coordinates[0]}, {coordinates[1]}) out of bounds") - if self.size == 1: - patch = np.array([self.nan], dtype=np.float32) - else: - patch = np.full( - (1, self.size, self.size), fill_value=self.nan, dtype=np.float32 - ) + if self.size == 1: + patch = np.array([self.nan], dtype=np.float32) + else: + patch = np.full( + (1, self.size, self.size), fill_value=self.nan, dtype=np.float32 + ) - return patch + return patch def __repr__(self) -> str: return str(self) @@ -188,7 +184,7 @@ def __str__(self) -> str: return "name: " + self.name + "\n" -class PatchExtractor(object): +class PatchExtractor(): """Handles the loading and extraction of an environmental tensor from multiple rasters given GPS coordinates. 
Parameters @@ -202,11 +198,7 @@ class PatchExtractor(object): def __init__(self, root_path: Union[str, Path], size: int = 256): self.root_path = Path(root_path) if not self.root_path.exists(): - raise ValueError( - "root_path should be the directory containing the rasters, given a non-existant path: {}".format( - root_path - ) - ) + raise ValueError("root_path should be the directory containing the rasters, given a non-existant path: {root_path}") self.size = size @@ -214,7 +206,7 @@ def __init__(self, root_path: Union[str, Path], size: int = 256): self.rasters_us: list[Raster] = [] def add_all_rasters(self, **kwargs: Any) -> None: - """Add all variables (rasters) available + """Add all variables (rasters) available. Parameters ---------- @@ -225,7 +217,7 @@ def add_all_rasters(self, **kwargs: Any) -> None: self.append(raster_name, **kwargs) def add_all_bioclimatic_rasters(self, **kwargs: Any) -> None: - """Add all bioclimatic variables (rasters) available + """Add all bioclimatic variables (rasters) available. Parameters ---------- @@ -236,7 +228,7 @@ def add_all_bioclimatic_rasters(self, **kwargs: Any) -> None: self.append(raster_name, **kwargs) def add_all_pedologic_rasters(self, **kwargs: Any) -> None: - """Add all pedologic variables (rasters) available + """Add all pedologic variables (rasters) available. Parameters ---------- @@ -247,7 +239,7 @@ def add_all_pedologic_rasters(self, **kwargs: Any) -> None: self.append(raster_name, **kwargs) def append(self, raster_name: str, **kwargs: Any) -> None: - """Loads and appends a single raster to the rasters already loaded. + """Load and append a single raster to the rasters already loaded. Can be useful to load only a subset of rasters or to pass configurations specific to each raster. 
@@ -270,7 +262,7 @@ def clean(self) -> None: self.rasters_us = [] def _get_rasters_list(self, coordinates: Coordinates) -> list[Raster]: - """Returns the list of rasters from the appropriate country + """Return the list of rasters from the appropriate country. Parameters ---------- @@ -284,8 +276,7 @@ def _get_rasters_list(self, coordinates: Coordinates) -> list[Raster]: """ if coordinates[1] > -10.0: return self.rasters_fr - else: - return self.rasters_us + return self.rasters_us def __repr__(self) -> str: return str(self) @@ -301,7 +292,7 @@ def __str__(self) -> str: return result def __getitem__(self, coordinates: Coordinates) -> npt.NDArray[np.float32]: - """Extracts the patches around the given GPS coordinates for all the previously loaded rasters. + """Extract the patches around the given GPS coordinates for all the previously loaded rasters. Parameters ---------- @@ -317,7 +308,7 @@ def __getitem__(self, coordinates: Coordinates) -> npt.NDArray[np.float32]: return np.concatenate([r[coordinates] for r in rasters]) def __len__(self) -> int: - """Number of variables/rasters loaded. + """Return the number of variables/rasters loaded. Returns ------- @@ -334,7 +325,7 @@ def plot( fig: Optional[plt.Figure] = None, resolution: float = 1.0, ) -> Optional[plt.Figure]: - """Plot an environmental tensor (only works if size > 1) + """Plot an environmental tensor (only works if size > 1). 
Parameters ---------- @@ -394,7 +385,7 @@ def plot( ax.set_title(k[0], fontsize=20) fig.colorbar(im, ax=ax) - for ax in axes[len(metadata) :]: + for ax in axes[len(metadata):]: ax.axis("off") fig.tight_layout() diff --git a/malpolon/data/utils.py b/malpolon/data/utils.py index 7597e803..9f94605c 100644 --- a/malpolon/data/utils.py +++ b/malpolon/data/utils.py @@ -17,7 +17,7 @@ def is_bbox_contained( bbox1: Union[Iterable, BoundingBox], bbox2: Union[Iterable, BoundingBox], - method: str['shapely', 'manual', 'torchgeo'] = 'shapely' + method: str = 'shapely' ) -> bool: """Determine if a 2D bbox in included inside of another. @@ -62,7 +62,7 @@ def is_bbox_contained( def is_point_in_bbox( point: Iterable, bbox: Iterable, - method: str['shapely', 'manual'] = 'shapely' + method: str = 'shapely' ) -> bool: """Determine if a 2D point in included inside of a 2D bounding box. @@ -119,7 +119,7 @@ def to_one_hot_encoding( list One-hot encoded labels. """ - labels_predict = [labels_predict] if type(labels_predict) is int else labels_predict + labels_predict = [labels_predict] if isinstance(labels_predict, int) else labels_predict n_classes = len(labels_target) one_hot_labels = np.zeros(n_classes, dtype=np.float32) one_hot_labels[np.in1d(labels_target, labels_predict)] = 1 diff --git a/malpolon/logging.py b/malpolon/logging.py index 6a513f54..40215a9b 100644 --- a/malpolon/logging.py +++ b/malpolon/logging.py @@ -1,18 +1,26 @@ +"""Thie module defines custom methods for model logging purposes. + +Author: Titouan Lorieul +""" + from __future__ import annotations -from typing import TYPE_CHECKING import logging +from typing import TYPE_CHECKING from pytorch_lightning.callbacks import Callback if TYPE_CHECKING: from typing import Any + import pytorch_lightning as pl + from .models.standard_prediction_systems import GenericPredictionSystem def str_object(obj: Any) -> str: - """ + """Format an object to string. 
+ Formats an object to printing by returning a string containing the class name and attributes (both name and values) @@ -24,7 +32,6 @@ class name and attributes (both name and values) ------- str: string containing class name and attributes. """ - class_name = obj.__class__.__name__ attributes = obj.__dict__ @@ -38,16 +45,16 @@ class name and attributes (both name and values) filtered_attributes.append((key, val)) formatted_attributes = ", ".join( - map(lambda x: "{}={}".format(*x), filtered_attributes) + map(lambda x: f"{*x,}={filtered_attributes}") ) - return "{}(\n {}\n)".format(class_name, formatted_attributes) + return f"{class_name}(\n {formatted_attributes}\n)".format(class_name, formatted_attributes) class Summary(Callback): - """ + """Log model summary at the beginning of training. + FIXME handle multi validation data loaders, combined datasets """ - def __init__(self) -> None: self.logger = logging.getLogger("malpolon") @@ -59,47 +66,40 @@ def _log_data_loading_summary(self, data_loader, split: str) -> None: else: dataset = data_loader.dataset - from torch.utils.data import Subset + from torch.utils.data import Subset # pylint: disable=C0415 if isinstance(dataset, Subset): dataset = dataset.dataset - logger.info("{} dataset: {}".format(split, dataset)) - logger.info("{} set size: {}".format(split, len(dataset))) + logger.info("%s dataset: %s", split, dataset) + logger.info("%s set size: %s", split, len(dataset)) if split == "Train" and hasattr(dataset, "n_classes"): - logger.info("Number of classes: {}".format(dataset.n_classes)) + logger.info("Number of classes: %s", dataset.n_classes) if hasattr(dataset, "transform"): - logger.info("{} data transformations: {}".format(split, dataset.transform)) + logger.info("%s data transformations: %s", split, dataset.transform) if hasattr(dataset, "target_transform"): - logger.info( - "{} data target transformations: {}".format( - split, dataset.target_transform - ) - ) + logger.info("%s data target transformations: 
%s", split, dataset.target_transform) - logger.info( - "{} data sampler: {}".format(split, str_object(data_loader.sampler)) - ) + logger.info("%s data sampler: %s", split, str_object(data_loader.sampler)) if hasattr(data_loader, "loaders"): batch_sampler = data_loader.loaders.batch_sampler else: batch_sampler = data_loader.batch_sampler - logger.info( - "{} data batch sampler: {}".format(split, str_object(batch_sampler)) - ) + logger.info("%s data batch sampler: %s", split, str_object(batch_sampler)) - def on_train_start(self, trainer: pl.Trainer, model: GenericPredictionSystem) -> None: + def on_train_start(self, trainer: pl.Trainer, pl_module: GenericPredictionSystem) -> None: logger = self.logger + model = pl_module logger.info("\n# Model specification") logger.info(model.model) logger.info(model.loss) logger.info(model.optimizer) - logger.info("Metrics: {}".format(model.metrics)) + logger.info("Metrics: %s", model.metrics) logger.info("\n# Data loading information") logger.info("\n## Training data") diff --git a/malpolon/models/model_builder.py b/malpolon/models/model_builder.py index 18b3ecf3..a2aa0fe0 100644 --- a/malpolon/models/model_builder.py +++ b/malpolon/models/model_builder.py @@ -184,16 +184,15 @@ def _find_module_of_type( modules = reversed(list(module.named_children())) else: raise ValueError( - "order must be either 'first' or 'last', given {}".format(order) + f"order must be either 'first' or 'last', given {order}" ) for child_name, child in modules: if isinstance(child, module_type): return module, child_name - else: - res = _find_module_of_type(child, module_type, order) - if res[1] != "": - return res + res = _find_module_of_type(child, module_type, order) + if res[1] != "": + return res return module, "" @@ -203,8 +202,7 @@ def change_first_convolutional_layer_modifier( num_input_channels: int, new_conv_layer_init_func: Optional[Callable[[nn.Conv2d, nn.Conv2d], None]] = None, ) -> nn.Module: - """ - Removes the first registered convolutional 
layer of a model and replaces it by a new convolutional layer with the provided number of input channels. + """Remove the first registered convolutional layer of a model and replaces it by a new convolutional layer with the provided number of input channels. Parameters ---------- @@ -284,8 +282,7 @@ def change_last_layer_modifier( def change_last_layer_to_identity_modifier(model: nn.Module) -> nn.Module: - """ - Removes the last linear layer of a model and replaces it by an nn.Identity layer. + """Remove the last linear layer of a model and replaces it by an nn.Identity layer. Parameters ---------- diff --git a/malpolon/models/multi_modal.py b/malpolon/models/multi_modal.py index 38379c87..8a0fdad9 100644 --- a/malpolon/models/multi_modal.py +++ b/malpolon/models/multi_modal.py @@ -116,6 +116,7 @@ def __init__( super().__init__("cuda:0", accelerator, checkpoint_io, precision_plugin) def model_to_device(self) -> None: + """TODO: Docstring.""" model = self.model.model self.modalites_names = model.modalities_models.keys() num_modalities = len(self.modalities_names) @@ -137,6 +138,7 @@ def model_to_device(self) -> None: def batch_to_device( self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0 ) -> Any: + """TODO: Docstring.""" x, target = batch for modality_name in self.modalities_models: diff --git a/malpolon/models/standard_prediction_systems.py b/malpolon/models/standard_prediction_systems.py index 02ba8848..664f0c57 100644 --- a/malpolon/models/standard_prediction_systems.py +++ b/malpolon/models/standard_prediction_systems.py @@ -2,7 +2,6 @@ Author: Titouan Lorieul Theo Larcher - """ from __future__ import annotations @@ -82,7 +81,7 @@ def _step( x, y = batch y_hat = self(x) - loss = self.loss(y_hat, self._cast_type_to_loss(y)) # Shape mismatch for binary: need to 'y = y.unsqueeze(1)' (or use .reshape(2)) to cast from [2] to [2,1] and cast y to float with .float() + loss = self.loss(y_hat, self._cast_type_to_loss(y)) # Shape mismatch 
for binary: need to 'y = y.unsqueeze(1)' (or use .reshape(2)) to cast from [2] to [2,1] and cast y to float with .float() self.log(f"{split}_loss", loss, **log_kwargs) for metric_name, metric_func in self.metrics.items(): @@ -285,7 +284,6 @@ def __init__( if True performs preprocessing operations on the hyperparameters, by default True """ - if hparams_preprocess: task = task.split('classification_')[1] metrics = check_metric(metrics) diff --git a/malpolon/models/utils.py b/malpolon/models/utils.py index 4c934e08..29134cbd 100644 --- a/malpolon/models/utils.py +++ b/malpolon/models/utils.py @@ -7,6 +7,7 @@ from __future__ import annotations import signal +import sys from pathlib import Path from typing import Mapping, Union @@ -29,28 +30,42 @@ def __init__(self, trainer): signal.signal(signal.SIGINT, self.signal_handler) def save_checkpoint(self): + """Save the latest checkpoint.""" print("Saving lastest checkpoint...") self.trainer.save_checkpoint(self.ckpt_dir_path) def signal_handler(self, sig, frame): + """Attempt to save the latest checkpoint in case of crash.""" print(f"Received signal {sig}. Performing cleanup...") self.save_checkpoint() - exit(0) + sys.exit(0) -def check_metric(metrics: OmegaConf) -> bool: - """_summary_ + +def check_metric(metrics: OmegaConf) -> OmegaConf: + """Ensure user's model metrics are valid. + + Users can either choose from a list of predefined metrics or + define their own custom metrics. This function binds the user's + metrics with their corresponding callable function from + torchmetrics, by reading the values in `metrics` which is a + dict-like structure returned by hydra when reading the config + file. + If the user chose predefined metrics, the function will + automatically bind the corresponding callable function from + torchmetrics. + If the user chose custom metrics, the function checks that they + also provided the callable function to compute the metric. 
Parameters ---------- - metric_name : str - _description_ - metric_type : str, optional - _description_, by default 'classification' + metrics: OmegaConf + user's input metrics, read from the config file via hydra, in + a dict-like structure Returns ------- - bool - _description_ + OmegaConf + user's metrics with their corresponding callable function """ try: metrics = OmegaConf.to_container(metrics) diff --git a/malpolon/plot/history.py b/malpolon/plot/history.py index ffe94cc2..949ea827 100644 --- a/malpolon/plot/history.py +++ b/malpolon/plot/history.py @@ -12,6 +12,7 @@ def escape_tex(s: str) -> str: + """Escape special characters for LaTeX rendering.""" if not plt.rcParams["text.usetex"]: return s @@ -22,7 +23,7 @@ def escape_tex(s: str) -> str: def plot_metric(df_metrics: pd.DataFrame, metric: str, ax: plt.Axis) -> plt.Axis: - """Plot specific metric monitored during model training history + """Plot specific metric monitored during model training history. Parameters ---------- @@ -76,7 +77,7 @@ def plot_history( fig: Optional[plt.Figure] = None, axes: Optional[list[plt.Axis]] = None, ) -> tuple[plt.Figure, list[plt.Axis]]: - """Plot model training history + """Plot model training history. 
Parameters ---------- @@ -102,7 +103,7 @@ def plot_history( axes = fig.subplots(nrows=nrows, ncols=ncols) - empty_axes = axes.ravel()[-(len(axes) - len(base_metrics) + 1) :] + empty_axes = axes.ravel()[-(len(axes) - len(base_metrics) + 1):] for ax in empty_axes: ax.axis("off") @@ -141,10 +142,10 @@ def plot_history( title = args.title[i] if i < len(args.title) else "" df = pd.read_csv(filename, index_col=args.index_column) - fig, axes = plot_history(df) - fig.canvas.manager.set_window_title(filename) + fig_hist, _ = plot_history(df) + fig_hist.canvas.manager.set_window_title(filename) if title: - fig.suptitle(title) + fig_hist.suptitle(title) plt.show() diff --git a/malpolon/plot/map.py b/malpolon/plot/map.py index f89d8cfe..a42e5427 100644 --- a/malpolon/plot/map.py +++ b/malpolon/plot/map.py @@ -20,7 +20,7 @@ def plot_map( extent: Optional[npt.ArrayLike] = None, ax: Optional[plt.Axes] = None, ) -> plt.Axes: - """Plots a map to show the observations on + """Plot a map on which to show the observations. Parameters ---------- @@ -43,8 +43,11 @@ def plot_map( elif region is None and extent is None: raise ValueError("Either region or extent must be set") - import cartopy.crs as ccrs - import cartopy.feature as cfeature + # Import outside toplevel to ease package management, especially + # when working on a computing cluster because cartopy requires + # binaries to be installed. + import cartopy.crs as ccrs # pylint: disable=C0415 + import cartopy.feature as cfeature # pylint: disable=C0415 if ax is None: ax = plt.axes(projection=ccrs.PlateCarree()) From 8f4f2739397d47cb02d17d97849b45d985a94182 Mon Sep 17 00:00:00 2001 From: Theo Larcher <42494948+tlarcher@users.noreply.github.com> Date: Mon, 26 Feb 2024 12:45:05 +0100 Subject: [PATCH 08/12] v1.0.0 Code documentation (#43) * Updated credits info and added file docstrings when missing. 
* Added missing docstrings * Fixed docstrings and linting for v1.0.0 --- .../cnn_on_rgb_nir_patches.py | 8 ++ .../cnn_on_rgb_patches.py | 8 ++ .../micro_geolifeclef2022/transforms.py | 3 +- .../sentinel-2a/cnn_on_rgbnir_torchgeo.py | 3 +- examples/ecologists/sentinel-2a/transforms.py | 10 +- .../cnn_on_rgb_nir_patches.py | 11 +- .../cnn_on_rgb_patches.py | 10 +- .../micro_geolifeclef2022/transforms.py | 3 +- .../sentinel-2a/cnn_on_rgbnir_torchgeo.py | 5 +- examples/inference/sentinel-2a/transforms.py | 10 +- .../geolifeclef2022/cnn_on_rgb_patches.py | 12 +- .../cnn_on_rgb_temperature_patches.py | 10 +- .../cnn_on_temperature_patches.py | 10 +- examples/kaggle/geolifeclef2022/transforms.py | 9 ++ .../cnn_on_rgbnir_glc23_patches.py | 2 +- malpolon/check_install.py | 6 + malpolon/data/data_module.py | 3 +- malpolon/data/datasets/geolifeclef2022.py | 6 +- malpolon/data/datasets/torchgeo_sentinel2.py | 2 +- malpolon/data/environmental_raster.py | 80 ++++++----- malpolon/data/get_jpeg_patches_stats.py | 2 + malpolon/data/utils.py | 12 +- malpolon/logging.py | 50 +++---- malpolon/models/model_builder.py | 124 ++++++++++++++++-- malpolon/models/multi_modal.py | 53 +++++++- .../models/standard_prediction_systems.py | 6 +- malpolon/models/utils.py | 39 ++++-- malpolon/plot/history.py | 20 ++- malpolon/plot/map.py | 18 ++- malpolon/tests/test_environmental_raster.py | 5 + .../tests/test_geolifeclef2022_dataset.py | 9 +- malpolon/tests/test_models.py | 6 + .../tests/test_standard_prediction_systems.py | 5 + malpolon/tests/test_torchgeo_datasets.py | 5 + setup.py | 2 +- 35 files changed, 434 insertions(+), 133 deletions(-) diff --git a/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py b/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py index 971114cc..4eca4423 100644 --- a/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py +++ b/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py @@ -1,3 +1,11 @@ +"""Main 
script to run training or inference on microlifeclef2022 dataset. + +Uses RGB and Near infra-red pre-extracted patches from the dataset. + +Author: Titouan Lorieul + Theo Larcher +""" + from __future__ import annotations import os diff --git a/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_patches.py b/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_patches.py index 0e0a1734..ae5d7642 100644 --- a/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_patches.py +++ b/examples/ecologists/micro_geolifeclef2022/cnn_on_rgb_patches.py @@ -1,3 +1,11 @@ +"""Main script to run training or inference on microlifeclef2022 dataset. + +Uses RGB pre-extracted patches from the dataset. + +Author: Titouan Lorieul + Theo Larcher +""" + from __future__ import annotations import os diff --git a/examples/ecologists/micro_geolifeclef2022/transforms.py b/examples/ecologists/micro_geolifeclef2022/transforms.py index 71ad830c..da97c7ea 100644 --- a/examples/ecologists/micro_geolifeclef2022/transforms.py +++ b/examples/ecologists/micro_geolifeclef2022/transforms.py @@ -3,9 +3,10 @@ These transform classes can be called during training loops to perform data augmentation. -Author: Titouan Lorieul +Author: Titouan Lorieul Theo Larcher """ + import numpy as np import torch from torchvision import transforms diff --git a/examples/ecologists/sentinel-2a/cnn_on_rgbnir_torchgeo.py b/examples/ecologists/sentinel-2a/cnn_on_rgbnir_torchgeo.py index 97e27ac9..1bfe2722 100644 --- a/examples/ecologists/sentinel-2a/cnn_on_rgbnir_torchgeo.py +++ b/examples/ecologists/sentinel-2a/cnn_on_rgbnir_torchgeo.py @@ -3,8 +3,9 @@ This script runs the RasterSentinel2 dataset class by default. 
Author: Theo Larcher - Titouan Lorieul + Titouan Lorieul """ + from __future__ import annotations import os diff --git a/examples/ecologists/sentinel-2a/transforms.py b/examples/ecologists/sentinel-2a/transforms.py index 266eaa78..1cb661c3 100644 --- a/examples/ecologists/sentinel-2a/transforms.py +++ b/examples/ecologists/sentinel-2a/transforms.py @@ -1,5 +1,13 @@ -import numpy as np +"""Collection of custom PyTorch friendly transform classes. + +These transform classes can be called during training loops to perform +data augmentation. +Author: Titouan Lorieul + Theo Larcher +""" + +import numpy as np import torch from torchvision import transforms diff --git a/examples/inference/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py b/examples/inference/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py index 2ec35dd3..e1a51773 100644 --- a/examples/inference/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py +++ b/examples/inference/micro_geolifeclef2022/cnn_on_rgb_nir_patches.py @@ -1,9 +1,14 @@ -from __future__ import annotations +"""Main script to run inference on microlifeclef2022 dataset. + +Uses RGB and Near infra-red pre-extracted patches from the dataset. -import os +Author: Titouan Lorieul + Theo Larcher +""" + +from __future__ import annotations import hydra -import numpy as np import pytorch_lightning as pl import torch from omegaconf import DictConfig diff --git a/examples/inference/micro_geolifeclef2022/cnn_on_rgb_patches.py b/examples/inference/micro_geolifeclef2022/cnn_on_rgb_patches.py index 85227df4..97049cdc 100644 --- a/examples/inference/micro_geolifeclef2022/cnn_on_rgb_patches.py +++ b/examples/inference/micro_geolifeclef2022/cnn_on_rgb_patches.py @@ -1,6 +1,12 @@ -from __future__ import annotations +"""Main script to run inference on microlifeclef2022 dataset. + +Uses RGB pre-extracted patches from the dataset. 
-import os +Author: Titouan Lorieul + Theo Larcher +""" + +from __future__ import annotations import hydra import pytorch_lightning as pl diff --git a/examples/inference/micro_geolifeclef2022/transforms.py b/examples/inference/micro_geolifeclef2022/transforms.py index 71ad830c..da97c7ea 100644 --- a/examples/inference/micro_geolifeclef2022/transforms.py +++ b/examples/inference/micro_geolifeclef2022/transforms.py @@ -3,9 +3,10 @@ These transform classes can be called during training loops to perform data augmentation. -Author: Titouan Lorieul +Author: Titouan Lorieul Theo Larcher """ + import numpy as np import torch from torchvision import transforms diff --git a/examples/inference/sentinel-2a/cnn_on_rgbnir_torchgeo.py b/examples/inference/sentinel-2a/cnn_on_rgbnir_torchgeo.py index e1e2cc5d..ee7ad15c 100644 --- a/examples/inference/sentinel-2a/cnn_on_rgbnir_torchgeo.py +++ b/examples/inference/sentinel-2a/cnn_on_rgbnir_torchgeo.py @@ -3,11 +3,10 @@ This script runs the RasterSentinel2 dataset class by default. Author: Theo Larcher - Titouan Lorieul + Titouan Lorieul """ -from __future__ import annotations -import os +from __future__ import annotations import hydra import pytorch_lightning as pl diff --git a/examples/inference/sentinel-2a/transforms.py b/examples/inference/sentinel-2a/transforms.py index 266eaa78..1cb661c3 100644 --- a/examples/inference/sentinel-2a/transforms.py +++ b/examples/inference/sentinel-2a/transforms.py @@ -1,5 +1,13 @@ -import numpy as np +"""Collection of custom PyTorch friendly transform classes. + +These transform classes can be called during training loops to perform +data augmentation. 
+Author: Titouan Lorieul + Theo Larcher +""" + +import numpy as np import torch from torchvision import transforms diff --git a/examples/kaggle/geolifeclef2022/cnn_on_rgb_patches.py b/examples/kaggle/geolifeclef2022/cnn_on_rgb_patches.py index 4167ce52..49597851 100644 --- a/examples/kaggle/geolifeclef2022/cnn_on_rgb_patches.py +++ b/examples/kaggle/geolifeclef2022/cnn_on_rgb_patches.py @@ -1,8 +1,14 @@ -import os +"""Main script to run training on microlifeclef2022 dataset. + +Uses RGB pre-extracted patches from the dataset. +This script was created for Kaggle participants of the GeoLifeCLEF 2022 +challenge. + +Author: Titouan Lorieul +""" import hydra -import pytorch_lightning as pl -import torchmetrics.functional as Fmetrics +import pytorch_lightning as p from omegaconf import DictConfig from pytorch_lightning.callbacks import ModelCheckpoint from torchvision import transforms diff --git a/examples/kaggle/geolifeclef2022/cnn_on_rgb_temperature_patches.py b/examples/kaggle/geolifeclef2022/cnn_on_rgb_temperature_patches.py index b32fbd79..80bc888e 100644 --- a/examples/kaggle/geolifeclef2022/cnn_on_rgb_temperature_patches.py +++ b/examples/kaggle/geolifeclef2022/cnn_on_rgb_temperature_patches.py @@ -1,4 +1,12 @@ -import os +"""Main script to run training on microlifeclef2022 dataset. + +Uses RGB pre-extracted patches and temperature rasters from the dataset. +This script was created for Kaggle participants of the GeoLifeCLEF 2022 +challenge. + +Author: Titouan Lorieul +""" + from pathlib import Path import hydra diff --git a/examples/kaggle/geolifeclef2022/cnn_on_temperature_patches.py b/examples/kaggle/geolifeclef2022/cnn_on_temperature_patches.py index 4a22bb00..410d82ed 100644 --- a/examples/kaggle/geolifeclef2022/cnn_on_temperature_patches.py +++ b/examples/kaggle/geolifeclef2022/cnn_on_temperature_patches.py @@ -1,4 +1,12 @@ -import os +"""Main script to run training on microlifeclef2022 dataset. + +Uses temperature rasters from the dataset. 
+This script was created for Kaggle participants of the GeoLifeCLEF 2022 +challenge. + +Author: Titouan Lorieul +""" + from pathlib import Path import hydra diff --git a/examples/kaggle/geolifeclef2022/transforms.py b/examples/kaggle/geolifeclef2022/transforms.py index 0f155988..ff60e0d3 100644 --- a/examples/kaggle/geolifeclef2022/transforms.py +++ b/examples/kaggle/geolifeclef2022/transforms.py @@ -1,3 +1,12 @@ +"""Collection of custom PyTorch friendly transform classes. + +These transform classes can be called during training loops to perform +data augmentation. + +Author: Titouan Lorieul + Theo Larcher +""" + import numpy as np import torch from torchvision import transforms diff --git a/examples/kaggle/geolifeclef2023/cnn_on_rgbnir_glc23_patches.py b/examples/kaggle/geolifeclef2023/cnn_on_rgbnir_glc23_patches.py index bcb4c401..e03ead44 100644 --- a/examples/kaggle/geolifeclef2023/cnn_on_rgbnir_glc23_patches.py +++ b/examples/kaggle/geolifeclef2023/cnn_on_rgbnir_glc23_patches.py @@ -3,8 +3,8 @@ This script runs the RasterSentinel2 dataset class by default. Author: Theo Larcher - Titouan Lorieul """ + from __future__ import annotations import hydra diff --git a/malpolon/check_install.py b/malpolon/check_install.py index 41ef4df6..8240914b 100644 --- a/malpolon/check_install.py +++ b/malpolon/check_install.py @@ -1,9 +1,15 @@ +"""This module checks the installation of PyTorch and GPU libraries. 
+ +Author: Titouan Lorieul +""" + import os import torch def print_cuda_info(): + """Print information about the CUDA/PyTorch installation.""" print(f"Using PyTorch version {torch.__version__}") print(f"CUDA available: {torch.cuda.is_available()} (version: {torch.version.cuda})") print(f"cuDNN available: {torch.backends.cudnn.enabled} (version: {torch.backends.cudnn.version()})") diff --git a/malpolon/data/data_module.py b/malpolon/data/data_module.py index f74e2321..e093d335 100644 --- a/malpolon/data/data_module.py +++ b/malpolon/data/data_module.py @@ -1,8 +1,7 @@ """This module provides a base class for data modules. Author: Theo Larcher - Titouan Lorieul - + Titouan Lorieul """ from __future__ import annotations diff --git a/malpolon/data/datasets/geolifeclef2022.py b/malpolon/data/datasets/geolifeclef2022.py index f6f8eba3..73980b2b 100644 --- a/malpolon/data/datasets/geolifeclef2022.py +++ b/malpolon/data/datasets/geolifeclef2022.py @@ -3,7 +3,7 @@ This module has since been updated for GeoLifeCLEF2023 Author: Benjamin Deneu - Titouan Lorieul + Titouan Lorieul License: GPLv3 Python version: 3.8 @@ -347,9 +347,9 @@ def download(self): return try: - import kaggle + import kaggle # pylint: disable=C0415,W0611 # noqa: F401 except OSError as error: - raise OSError("Have you properly set up your Kaggle API token ? For more information, please refer to section 'Authentication' of the kaggle documentation : https://www.kaggle.com/docs/api"+msg) from error + raise OSError("Have you properly set up your Kaggle API token ? For more information, please refer to section 'Authentication' of the kaggle documentation : https://www.kaggle.com/docs/api") from error answer = input("You are about to download the GeoLifeClef2022 dataset which weighs ~62 GB. Do you want to continue ? 
[y/n]") if answer.lower() in ["y", "yes"]: diff --git a/malpolon/data/datasets/torchgeo_sentinel2.py b/malpolon/data/datasets/torchgeo_sentinel2.py index 463ff254..c34b3256 100644 --- a/malpolon/data/datasets/torchgeo_sentinel2.py +++ b/malpolon/data/datasets/torchgeo_sentinel2.py @@ -254,7 +254,7 @@ def plot( class RasterSentinel2GLC23(RasterSentinel2): - """Adaptation of RasterSentinel2 for new GLC23 observations""" + """Adaptation of RasterSentinel2 for new GLC23 observations.""" filename_glob = "*.tif" filename_regex = r"(?Pred|green|blue|nir)_2021" all_bands = ["red", "green", "blue", "nir"] diff --git a/malpolon/data/environmental_raster.py b/malpolon/data/environmental_raster.py index 2d9c1e1d..30bbe788 100644 --- a/malpolon/data/environmental_raster.py +++ b/malpolon/data/environmental_raster.py @@ -1,13 +1,18 @@ +"""Custom classes to handle environmental rasters without torchgeo. + +Author: Titouan Lorieul +""" + from __future__ import annotations + import warnings from pathlib import Path -from typing import Any, Optional, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional, Union import matplotlib.pyplot as plt import numpy as np import rasterio - if TYPE_CHECKING: import numpy.typing as npt @@ -30,8 +35,8 @@ # fmt: on -class Raster(object): - """Loads a GeoTIFF file and extract patches for a single environmental raster +class Raster(): + """Loads a GeoTIFF file and extract patches for a single environmental raster. 
Parameters ---------- @@ -57,11 +62,7 @@ def __init__( ): path = Path(path) if not path.exists(): - raise ValueError( - "path should be the path to a raster, given non-existant path: {}".format( - path - ) - ) + raise ValueError(f"path should be the path to a raster, given non-existant path: {path}") self.path = path self.name = path.name @@ -70,7 +71,7 @@ def __init__( self.nan = nan # Loading the raster - filename = path / "{}_{}.tif".format(self.name, country) + filename = path / f"{self.name}_{country}.tif" with rasterio.open(filename) as dataset: self.dataset = dataset raster = dataset.read(1, masked=True, out_dtype=np.float32) @@ -88,7 +89,8 @@ def __init__( self.shape = self.raster.shape def _extract_patch(self, coordinates: Coordinates) -> Patch: - """Extracts the patch around the given GPS coordinates. + """Extract the patch around the given GPS coordinates. + Avoid using this method directly. Parameter @@ -134,7 +136,9 @@ def _extract_patch(self, coordinates: Coordinates) -> Patch: return patch def __len__(self) -> int: - """Number of bands in the raster (should always be equal to 1). + """Return the number of bands in the raster. + + Should always be equal to 1. Returns ------- @@ -144,7 +148,7 @@ def __len__(self) -> int: return self.dataset.count def __getitem__(self, coordinates: Coordinates) -> Patch: - """Extracts the patch around the given GPS coordinates. + """Extract the patch around the given GPS coordinates. 
Parameters ---------- @@ -161,20 +165,17 @@ def __getitem__(self, coordinates: Coordinates) -> Patch: except IndexError as e: if self.out_of_bounds == "error": raise e - else: - if self.out_of_bounds == "warn": - warnings.warn( - "GPS coordinates ({}, {}) out of bounds".format(*coordinates) - ) + if self.out_of_bounds == "warn": + warnings.warn(f"GPS coordinates ({coordinates[0]}, {coordinates[1]}) out of bounds") - if self.size == 1: - patch = np.array([self.nan], dtype=np.float32) - else: - patch = np.full( - (1, self.size, self.size), fill_value=self.nan, dtype=np.float32 - ) + if self.size == 1: + patch = np.array([self.nan], dtype=np.float32) + else: + patch = np.full( + (1, self.size, self.size), fill_value=self.nan, dtype=np.float32 + ) - return patch + return patch def __repr__(self) -> str: return str(self) @@ -183,7 +184,7 @@ def __str__(self) -> str: return "name: " + self.name + "\n" -class PatchExtractor(object): +class PatchExtractor(): """Handles the loading and extraction of an environmental tensor from multiple rasters given GPS coordinates. Parameters @@ -197,11 +198,7 @@ class PatchExtractor(object): def __init__(self, root_path: Union[str, Path], size: int = 256): self.root_path = Path(root_path) if not self.root_path.exists(): - raise ValueError( - "root_path should be the directory containing the rasters, given a non-existant path: {}".format( - root_path - ) - ) + raise ValueError("root_path should be the directory containing the rasters, given a non-existant path: {root_path}") self.size = size @@ -209,7 +206,7 @@ def __init__(self, root_path: Union[str, Path], size: int = 256): self.rasters_us: list[Raster] = [] def add_all_rasters(self, **kwargs: Any) -> None: - """Add all variables (rasters) available + """Add all variables (rasters) available. 
Parameters ---------- @@ -220,7 +217,7 @@ def add_all_rasters(self, **kwargs: Any) -> None: self.append(raster_name, **kwargs) def add_all_bioclimatic_rasters(self, **kwargs: Any) -> None: - """Add all bioclimatic variables (rasters) available + """Add all bioclimatic variables (rasters) available. Parameters ---------- @@ -231,7 +228,7 @@ def add_all_bioclimatic_rasters(self, **kwargs: Any) -> None: self.append(raster_name, **kwargs) def add_all_pedologic_rasters(self, **kwargs: Any) -> None: - """Add all pedologic variables (rasters) available + """Add all pedologic variables (rasters) available. Parameters ---------- @@ -242,7 +239,7 @@ def add_all_pedologic_rasters(self, **kwargs: Any) -> None: self.append(raster_name, **kwargs) def append(self, raster_name: str, **kwargs: Any) -> None: - """Loads and appends a single raster to the rasters already loaded. + """Load and append a single raster to the rasters already loaded. Can be useful to load only a subset of rasters or to pass configurations specific to each raster. @@ -265,7 +262,7 @@ def clean(self) -> None: self.rasters_us = [] def _get_rasters_list(self, coordinates: Coordinates) -> list[Raster]: - """Returns the list of rasters from the appropriate country + """Return the list of rasters from the appropriate country. Parameters ---------- @@ -279,8 +276,7 @@ def _get_rasters_list(self, coordinates: Coordinates) -> list[Raster]: """ if coordinates[1] > -10.0: return self.rasters_fr - else: - return self.rasters_us + return self.rasters_us def __repr__(self) -> str: return str(self) @@ -296,7 +292,7 @@ def __str__(self) -> str: return result def __getitem__(self, coordinates: Coordinates) -> npt.NDArray[np.float32]: - """Extracts the patches around the given GPS coordinates for all the previously loaded rasters. + """Extract the patches around the given GPS coordinates for all the previously loaded rasters. 
Parameters ---------- @@ -312,7 +308,7 @@ def __getitem__(self, coordinates: Coordinates) -> npt.NDArray[np.float32]: return np.concatenate([r[coordinates] for r in rasters]) def __len__(self) -> int: - """Number of variables/rasters loaded. + """Return the number of variables/rasters loaded. Returns ------- @@ -329,7 +325,7 @@ def plot( fig: Optional[plt.Figure] = None, resolution: float = 1.0, ) -> Optional[plt.Figure]: - """Plot an environmental tensor (only works if size > 1) + """Plot an environmental tensor (only works if size > 1). Parameters ---------- @@ -389,7 +385,7 @@ def plot( ax.set_title(k[0], fontsize=20) fig.colorbar(im, ax=ax) - for ax in axes[len(metadata) :]: + for ax in axes[len(metadata):]: ax.axis("off") fig.tight_layout() diff --git a/malpolon/data/get_jpeg_patches_stats.py b/malpolon/data/get_jpeg_patches_stats.py index 8365c689..8b57da8d 100644 --- a/malpolon/data/get_jpeg_patches_stats.py +++ b/malpolon/data/get_jpeg_patches_stats.py @@ -3,6 +3,8 @@ When dealing with a large amount of files it should be run only once, and the statistics should be stored in a separate .csv for later use. + +Author: Theo Larcher """ import argparse diff --git a/malpolon/data/utils.py b/malpolon/data/utils.py index df342e96..9f94605c 100644 --- a/malpolon/data/utils.py +++ b/malpolon/data/utils.py @@ -1,4 +1,8 @@ -"""This file compiles useful functions related to data and file handling.""" +"""This file compiles useful functions related to data and file handling. + +Author: Theo Larcher +""" + from __future__ import annotations import os @@ -13,7 +17,7 @@ def is_bbox_contained( bbox1: Union[Iterable, BoundingBox], bbox2: Union[Iterable, BoundingBox], - method: str['shapely', 'manual', 'torchgeo'] = 'shapely' + method: str = 'shapely' ) -> bool: """Determine if a 2D bbox in included inside of another. 
@@ -58,7 +62,7 @@ def is_bbox_contained( def is_point_in_bbox( point: Iterable, bbox: Iterable, - method: str['shapely', 'manual'] = 'shapely' + method: str = 'shapely' ) -> bool: """Determine if a 2D point in included inside of a 2D bounding box. @@ -115,7 +119,7 @@ def to_one_hot_encoding( list One-hot encoded labels. """ - labels_predict = [labels_predict] if type(labels_predict) is int else labels_predict + labels_predict = [labels_predict] if isinstance(labels_predict, int) else labels_predict n_classes = len(labels_target) one_hot_labels = np.zeros(n_classes, dtype=np.float32) one_hot_labels[np.in1d(labels_target, labels_predict)] = 1 diff --git a/malpolon/logging.py b/malpolon/logging.py index 6a513f54..40215a9b 100644 --- a/malpolon/logging.py +++ b/malpolon/logging.py @@ -1,18 +1,26 @@ +"""This module defines custom methods for model logging purposes. + +Author: Titouan Lorieul +""" + from __future__ import annotations -from typing import TYPE_CHECKING import logging +from typing import TYPE_CHECKING from pytorch_lightning.callbacks import Callback if TYPE_CHECKING: from typing import Any + import pytorch_lightning as pl + from .models.standard_prediction_systems import GenericPredictionSystem def str_object(obj: Any) -> str: - """ + """Format an object to string. + Formats an object to printing by returning a string containing the class name and attributes (both name and values) @@ -24,7 +32,6 @@ class name and attributes (both name and values) ------- str: string containing class name and attributes.
""" - class_name = obj.__class__.__name__ attributes = obj.__dict__ @@ -38,16 +45,16 @@ class name and attributes (both name and values) filtered_attributes.append((key, val)) formatted_attributes = ", ".join( - map(lambda x: "{}={}".format(*x), filtered_attributes) + map(lambda x: f"{*x,}={filtered_attributes}") ) - return "{}(\n {}\n)".format(class_name, formatted_attributes) + return f"{class_name}(\n {formatted_attributes}\n)".format(class_name, formatted_attributes) class Summary(Callback): - """ + """Log model summary at the beginning of training. + FIXME handle multi validation data loaders, combined datasets """ - def __init__(self) -> None: self.logger = logging.getLogger("malpolon") @@ -59,47 +66,40 @@ def _log_data_loading_summary(self, data_loader, split: str) -> None: else: dataset = data_loader.dataset - from torch.utils.data import Subset + from torch.utils.data import Subset # pylint: disable=C0415 if isinstance(dataset, Subset): dataset = dataset.dataset - logger.info("{} dataset: {}".format(split, dataset)) - logger.info("{} set size: {}".format(split, len(dataset))) + logger.info("%s dataset: %s", split, dataset) + logger.info("%s set size: %s", split, len(dataset)) if split == "Train" and hasattr(dataset, "n_classes"): - logger.info("Number of classes: {}".format(dataset.n_classes)) + logger.info("Number of classes: %s", dataset.n_classes) if hasattr(dataset, "transform"): - logger.info("{} data transformations: {}".format(split, dataset.transform)) + logger.info("%s data transformations: %s", split, dataset.transform) if hasattr(dataset, "target_transform"): - logger.info( - "{} data target transformations: {}".format( - split, dataset.target_transform - ) - ) + logger.info("%s data target transformations: %s", split, dataset.target_transform) - logger.info( - "{} data sampler: {}".format(split, str_object(data_loader.sampler)) - ) + logger.info("%s data sampler: %s", split, str_object(data_loader.sampler)) if hasattr(data_loader, "loaders"): 
batch_sampler = data_loader.loaders.batch_sampler else: batch_sampler = data_loader.batch_sampler - logger.info( - "{} data batch sampler: {}".format(split, str_object(batch_sampler)) - ) + logger.info("%s data batch sampler: %s", split, str_object(batch_sampler)) - def on_train_start(self, trainer: pl.Trainer, model: GenericPredictionSystem) -> None: + def on_train_start(self, trainer: pl.Trainer, pl_module: GenericPredictionSystem) -> None: logger = self.logger + model = pl_module logger.info("\n# Model specification") logger.info(model.model) logger.info(model.loss) logger.info(model.optimizer) - logger.info("Metrics: {}".format(model.metrics)) + logger.info("Metrics: %s", model.metrics) logger.info("\n# Data loading information") logger.info("\n## Training data") diff --git a/malpolon/models/model_builder.py b/malpolon/models/model_builder.py index 05621ef4..a2aa0fe0 100644 --- a/malpolon/models/model_builder.py +++ b/malpolon/models/model_builder.py @@ -1,3 +1,14 @@ +"""This module provides classes to build your PyTorch models. + +Classes listed in this module allow to select a model from your +provider (timm, torchvision...), retrieve it with or without +pre-trained weights, and modify it by adding or removing layers. + +Author: Titouan Lorieul + Theo Larcher + +""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -14,6 +25,7 @@ class _ModelBuilder: + """General class to build models.""" providers: dict[str, Provider] = {} modifiers: dict[str, Modifier] = {} @@ -25,6 +37,28 @@ def build_model( model_kwargs: dict = {}, modifiers: dict[str, Optional[dict[str, Any]]] = {}, ) -> nn.Module: + """Return a built model with the given provider and modifiers. 
+ + Parameters + ---------- + provider_name : str + source of the model's provider, valid values are: + [`timm`, `torchvision`] + model_name : str + name of the model to retrieve from the provider + model_args : list, optional + model arguments to pass on when building it, by default [] + model_kwargs : dict, optional + model kwargs, by default {} + modifiers : dict[str, Optional[dict[str, Any]]], optional + modifiers to call on the model after it is built, + by default {} + + Returns + ------- + nn.Module + built and modified model + """ provider = self.providers[provider_name] model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None} model = provider(model_name, *model_args, **model_kwargs) @@ -37,24 +71,77 @@ def build_model( return model def register_provider(self, provider_name: str, provider: Provider) -> None: + """Register a provider to the model builder. + + Parameters + ---------- + provider_name : str + name of the provider, valid values are: + [`timm`, `torchvision`] + provider : Provider + callable provider function + """ self.providers[provider_name] = provider def register_modifier(self, modifier_name: str, modifier: Modifier) -> None: + """Register a modifier to the model builder. + + Parameters + ---------- + modifier_name : str + name of the modifier, valid values are: + [`change_first_convolutional_layer`, `change_last_layer`, `change_last_layer_to_identity`] + modifier : Modifier + modifier callable function + """ self.modifiers[modifier_name] = modifier def torchvision_model_provider( model_name: str, *model_args: Any, **model_kwargs: Any ) -> nn.Module: + """Return a model from torchvision's library. + + This method uses torchvision's API to retrieve a model from its + library.
+ + Parameters + ---------- + model_name : str + name of the model to retrieve from torchvision's library + Returns + ------- + nn.Module + model object + """ model = getattr(models, model_name) model = model(*model_args, **model_kwargs) return model + def timm_model_provider( model_name: str, *model_args: Any, **model_kwargs: Any ) -> nn.Module: + """Return a model from timm's library. + + This method uses timm's API to retrieve a model from its library. + + Parameters + ---------- + model_name : str + name of the model to retrieve from timm's library + + Returns + ------- + nn.Module + model object + Raises + ------ + ValueError + if the model name is not listed in TIMM's library + """ available_models = timm.list_models() if model_name in available_models: model = timm.create_model(model_name, *model_args, **model_kwargs) @@ -69,22 +156,43 @@ def timm_model_provider( def _find_module_of_type( module: nn.Module, module_type: type, order: str ) -> tuple[nn.Module, str]: + """Find the first or last module of a given type in a module. 
+ + Parameters + ---------- + module : nn.Module + torch module to search in (_e.g.: torch model_) + module_type : type + module type to search for (_e.g.: nn.Conv2d_) + order : str + order to search for the module, valid values are: + [`first`, `last`] + + Returns + ------- + tuple[nn.Module, str] + module and its name + + Raises + ------ + ValueError + if the order is not valid + """ if order == "first": modules = module.named_children() elif order == "last": modules = reversed(list(module.named_children())) else: raise ValueError( - "order must be either 'first' or 'last', given {}".format(order) + f"order must be either 'first' or 'last', given {order}" ) for child_name, child in modules: if isinstance(child, module_type): return module, child_name - else: - res = _find_module_of_type(child, module_type, order) - if res[1] != "": - return res + res = _find_module_of_type(child, module_type, order) + if res[1] != "": + return res return module, "" @@ -94,8 +202,7 @@ def change_first_convolutional_layer_modifier( num_input_channels: int, new_conv_layer_init_func: Optional[Callable[[nn.Conv2d, nn.Conv2d], None]] = None, ) -> nn.Module: - """ - Removes the first registered convolutional layer of a model and replaces it by a new convolutional layer with the provided number of input channels. + """Remove the first registered convolutional layer of a model and replaces it by a new convolutional layer with the provided number of input channels. Parameters ---------- @@ -175,8 +282,7 @@ def change_last_layer_modifier( def change_last_layer_to_identity_modifier(model: nn.Module) -> nn.Module: - """ - Removes the last linear layer of a model and replaces it by an nn.Identity layer. + """Remove the last linear layer of a model and replaces it by an nn.Identity layer. 
Parameters ---------- diff --git a/malpolon/models/multi_modal.py b/malpolon/models/multi_modal.py index 7f395ae6..8a0fdad9 100644 --- a/malpolon/models/multi_modal.py +++ b/malpolon/models/multi_modal.py @@ -1,10 +1,17 @@ +"""This module provides classes for advanced model building. + +Author: Titouan Lorieul + Theo Larcher +""" + from __future__ import annotations + from typing import TYPE_CHECKING import torch -from torch import nn from pytorch_lightning.strategies import SingleDeviceStrategy, StrategyRegistry from pytorch_lightning.utilities import move_data_to_device +from torch import nn from .utils import check_model @@ -13,11 +20,32 @@ class MultiModalModel(nn.Module): + """Base multi-modal model. + + This class builds an aggregation of multiple models from the passed + on config file values, one for each modality, splits the training + routine per modality and then aggregates the features from each + modality after each forward pass. + """ def __init__( self, modality_models: Union[nn.Module, Mapping], aggregator_model: Union[nn.Module, Mapping], ): + """Class constructor. + + Parameters + ---------- + modality_models : Union[nn.Module, Mapping] + dictionary of modality names and their respective models to + pass on to the model builder + aggregator_model : Union[nn.Module, Mapping] + Model strategy to aggregate the features from each modality. + Can either be a PyTorch module directly (in this case, the + module will be directly called), or a mapping in the same + fashion as for building the modality models, in which case + the model builder will be called again. + """ super().__init__() for modality_name, model in modality_models.items(): @@ -39,12 +67,29 @@ def forward(self, x: list[Any]) -> Any: class HomogeneousMultiModalModel(MultiModalModel): + """Straightforward multi-modal model.""" def __init__( self, modality_names: list, modalities_model: dict, aggregator_model: Union[nn.Module, Mapping], ): + """Class constructor.
+ + Parameters + ---------- + modality_names : list + list of modalities names + modalities_model : dict + dictionary of modality names and their respective models to + pass on to the model builder + aggregator_model : Union[nn.Module, Mapping] + Model strategy to aggregate the features from each modality. + Can either be a PyTorch module directly (in this case, the + module will be directly called), or a mapping in the same + fashion as for building the modality models, in which case + the model builder will be called again. + """ self.modality_names = modality_names self.modalities_model = modalities_model @@ -55,6 +100,10 @@ def __init__( class ParallelMultiModalModelStrategy(SingleDeviceStrategy): + """Model parallelism strategy for multi-modal models. + + WARNING: STILL UNDER DEVELOPMENT. + """ strategy_name = "parallel_multi_modal_model" def __init__( @@ -67,6 +116,7 @@ def __init__( super().__init__("cuda:0", accelerator, checkpoint_io, precision_plugin) def model_to_device(self) -> None: + """TODO: Docstring.""" model = self.model.model self.modalites_names = model.modalities_models.keys() num_modalities = len(self.modalities_names) @@ -88,6 +138,7 @@ def model_to_device(self) -> None: def batch_to_device( self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0 ) -> Any: + """TODO: Docstring.""" x, target = batch for modality_name in self.modalities_models: diff --git a/malpolon/models/standard_prediction_systems.py b/malpolon/models/standard_prediction_systems.py index dca873c4..664f0c57 100644 --- a/malpolon/models/standard_prediction_systems.py +++ b/malpolon/models/standard_prediction_systems.py @@ -1,8 +1,7 @@ """This module provides classes wrapping pytorchlightning training modules.
-Author: Titouan Lorieul +Author: Titouan Lorieul Theo Larcher - """ from __future__ import annotations @@ -82,7 +81,7 @@ def _step( x, y = batch y_hat = self(x) - loss = self.loss(y_hat, self._cast_type_to_loss(y)) # Shape mismatch for binary: need to 'y = y.unsqueeze(1)' (or use .reshape(2)) to cast from [2] to [2,1] and cast y to float with .float() + loss = self.loss(y_hat, self._cast_type_to_loss(y)) # Shape mismatch for binary: need to 'y = y.unsqueeze(1)' (or use .reshape(2)) to cast from [2] to [2,1] and cast y to float with .float() self.log(f"{split}_loss", loss, **log_kwargs) for metric_name, metric_func in self.metrics.items(): @@ -285,7 +284,6 @@ def __init__( if True performs preprocessing operations on the hyperparameters, by default True """ - if hparams_preprocess: task = task.split('classification_')[1] metrics = check_metric(metrics) diff --git a/malpolon/models/utils.py b/malpolon/models/utils.py index 4182837e..29134cbd 100644 --- a/malpolon/models/utils.py +++ b/malpolon/models/utils.py @@ -1,8 +1,13 @@ -"""This file compiles useful functions related to models.""" +"""This file compiles useful functions related to models. + +Author: Theo Larcher + Titouan Lorieul +""" from __future__ import annotations import signal +import sys from pathlib import Path from typing import Mapping, Union @@ -25,28 +30,42 @@ def __init__(self, trainer): signal.signal(signal.SIGINT, self.signal_handler) def save_checkpoint(self): + """Save the latest checkpoint.""" print("Saving lastest checkpoint...") self.trainer.save_checkpoint(self.ckpt_dir_path) def signal_handler(self, sig, frame): + """Attempt to save the latest checkpoint in case of crash.""" print(f"Received signal {sig}. Performing cleanup...") self.save_checkpoint() - exit(0) + sys.exit(0) + + +def check_metric(metrics: OmegaConf) -> OmegaConf: + """Ensure user's model metrics are valid. 
-def check_metric(metrics: OmegaConf) -> bool: - """_summary_ + Users can either choose from a list of predefined metrics or + define their own custom metrics. This function binds the user's + metrics with their corresponding callable function from + torchmetrics, by reading the values in `metrics` which is a + dict-like structure returned by hydra when reading the config + file. + If the user chose predefined metrics, the function will + automatically bind the corresponding callable function from + torchmetrics. + If the user chose custom metrics, the function checks that they + also provided the callable function to compute the metric. Parameters ---------- - metric_name : str - _description_ - metric_type : str, optional - _description_, by default 'classification' + metrics: OmegaConf + user's input metrics, read from the config file via hydra, in + a dict-like structure Returns ------- - bool - _description_ + OmegaConf + user's metrics with their corresponding callable function """ try: metrics = OmegaConf.to_container(metrics) diff --git a/malpolon/plot/history.py b/malpolon/plot/history.py index 3ca08c18..949ea827 100644 --- a/malpolon/plot/history.py +++ b/malpolon/plot/history.py @@ -1,4 +1,10 @@ +"""Utilities used for plotting purposes. + +Author: Titouan Lorieul +""" + from __future__ import annotations + from typing import Optional import matplotlib.pyplot as plt @@ -6,6 +12,7 @@ def escape_tex(s: str) -> str: + """Escape special characters for LaTeX rendering.""" if not plt.rcParams["text.usetex"]: return s @@ -16,7 +23,7 @@ def escape_tex(s: str) -> str: def plot_metric(df_metrics: pd.DataFrame, metric: str, ax: plt.Axis) -> plt.Axis: - """Plot specific metric monitored during model training history + """Plot specific metric monitored during model training history. 
Parameters ---------- @@ -70,7 +77,7 @@ def plot_history( fig: Optional[plt.Figure] = None, axes: Optional[list[plt.Axis]] = None, ) -> tuple[plt.Figure, list[plt.Axis]]: - """Plot model training history + """Plot model training history. Parameters ---------- @@ -96,7 +103,7 @@ def plot_history( axes = fig.subplots(nrows=nrows, ncols=ncols) - empty_axes = axes.ravel()[-(len(axes) - len(base_metrics) + 1) :] + empty_axes = axes.ravel()[-(len(axes) - len(base_metrics) + 1):] for ax in empty_axes: ax.axis("off") @@ -109,6 +116,7 @@ def plot_history( if __name__ == "__main__": import argparse + import pandas as pd parser = argparse.ArgumentParser(description="plots the training curves") @@ -134,10 +142,10 @@ def plot_history( title = args.title[i] if i < len(args.title) else "" df = pd.read_csv(filename, index_col=args.index_column) - fig, axes = plot_history(df) - fig.canvas.manager.set_window_title(filename) + fig_hist, _ = plot_history(df) + fig_hist.canvas.manager.set_window_title(filename) if title: - fig.suptitle(title) + fig_hist.suptitle(title) plt.show() diff --git a/malpolon/plot/map.py b/malpolon/plot/map.py index 314ef7b4..a42e5427 100644 --- a/malpolon/plot/map.py +++ b/malpolon/plot/map.py @@ -1,5 +1,12 @@ +"""Utilities for plotting maps. + +Author: Titouan Lorieul + +""" + from __future__ import annotations -from typing import Optional, TYPE_CHECKING + +from typing import TYPE_CHECKING, Optional import matplotlib.pyplot as plt @@ -13,7 +20,7 @@ def plot_map( extent: Optional[npt.ArrayLike] = None, ax: Optional[plt.Axes] = None, ) -> plt.Axes: - """Plots a map to show the observations on + """Plot a map on which to show the observations. 
Parameters ---------- @@ -36,8 +43,11 @@ def plot_map( elif region is None and extent is None: raise ValueError("Either region or extent must be set") - import cartopy.crs as ccrs - import cartopy.feature as cfeature + # Import outside toplevel to ease package management, especially + # when working on a computing cluster because cartopy requires + # binaries to be installed. + import cartopy.crs as ccrs # pylint: disable=C0415 + import cartopy.feature as cfeature # pylint: disable=C0415 if ax is None: ax = plt.axes(projection=ccrs.PlateCarree()) diff --git a/malpolon/tests/test_environmental_raster.py b/malpolon/tests/test_environmental_raster.py index 9719c981..9cfbf1a6 100644 --- a/malpolon/tests/test_environmental_raster.py +++ b/malpolon/tests/test_environmental_raster.py @@ -1,3 +1,8 @@ +"""This script tests the environmental raster module. + +Author: Titouan Lorieul +""" + from pathlib import Path import numpy as np diff --git a/malpolon/tests/test_geolifeclef2022_dataset.py b/malpolon/tests/test_geolifeclef2022_dataset.py index 054025da..3f045a02 100644 --- a/malpolon/tests/test_geolifeclef2022_dataset.py +++ b/malpolon/tests/test_geolifeclef2022_dataset.py @@ -1,11 +1,16 @@ +"""This script tests the GeoLifeCLEF2022 dataset module. + +Author: Titouan Lorieul +""" + from pathlib import Path import numpy as np import pytest +from malpolon.data.datasets.geolifeclef2022 import ( + GeoLifeCLEF2022Dataset, load_patch, visualize_observation_patch) from malpolon.data.environmental_raster import PatchExtractor -from malpolon.data.datasets.geolifeclef2022 import load_patch, GeoLifeCLEF2022Dataset, visualize_observation_patch - DATA_PATH = Path("malpolon/tests/data/glc22") diff --git a/malpolon/tests/test_models.py b/malpolon/tests/test_models.py index d9b9cab7..d621b36a 100644 --- a/malpolon/tests/test_models.py +++ b/malpolon/tests/test_models.py @@ -1,3 +1,9 @@ +"""This script tests the models module. 
+ +Author: Titouan Lorieul + Theo Larcher +""" + import timm import torch from torchvision import models diff --git a/malpolon/tests/test_standard_prediction_systems.py b/malpolon/tests/test_standard_prediction_systems.py index 33269770..e4e5cb71 100644 --- a/malpolon/tests/test_standard_prediction_systems.py +++ b/malpolon/tests/test_standard_prediction_systems.py @@ -1,3 +1,8 @@ +"""This script tests the standard prediction systems module. + +Author: Theo Larcher +""" + import numpy as np import timm diff --git a/malpolon/tests/test_torchgeo_datasets.py b/malpolon/tests/test_torchgeo_datasets.py index 3256fc8b..b0aa362f 100644 --- a/malpolon/tests/test_torchgeo_datasets.py +++ b/malpolon/tests/test_torchgeo_datasets.py @@ -1,3 +1,8 @@ +"""This script tests the torchgeo datasets module. + +Author: Theo Larcher +""" + from pathlib import Path import numpy as np diff --git a/setup.py b/setup.py index d0ff5888..fa467ccd 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ version="1.0.0", description="Malpolon v1.0.0", author="Theo Larcher, Titouan Lorieul", - author_email="theo.larcher@inria.fr, titouan.lorieul@inria.fr", + author_email="theo.larcher@inria.fr, titouan.lorieul@gmail.com", url="https://github.com/plantnet/malpolon", classifiers=[ "Development Status :: 3 - Alpha", From d034c300e92e9b89bfb74ad05ab9e03fafe27707 Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Mon, 26 Feb 2024 18:02:00 +0100 Subject: [PATCH 09/12] Fixed bug in malpolon.data.logging.str_object which crashed training loop. 
--- malpolon/logging.py | 2 +- setup.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/malpolon/logging.py b/malpolon/logging.py index 40215a9b..61dcdd75 100644 --- a/malpolon/logging.py +++ b/malpolon/logging.py @@ -45,7 +45,7 @@ class name and attributes (both name and values) filtered_attributes.append((key, val)) formatted_attributes = ", ".join( - map(lambda x: f"{*x,}={filtered_attributes}") + map(lambda x: f"{x[0]}={x[1]}", filtered_attributes) ) return f"{class_name}(\n {formatted_attributes}\n)".format(class_name, formatted_attributes) diff --git a/setup.py b/setup.py index fa467ccd..f9df413a 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,8 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup(name="malpolon", - version="1.0.0", - description="Malpolon v1.0.0", + version="1.0.2", + description="Malpolon v1.0.2", author="Theo Larcher, Titouan Lorieul", author_email="theo.larcher@inria.fr, titouan.lorieul@gmail.com", url="https://github.com/plantnet/malpolon", From 316b5949180f95a728f7ddbe2e54606be0b5cf8c Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Tue, 27 Feb 2024 10:49:13 +0100 Subject: [PATCH 10/12] Updated root README following Malpolon's first PyPi package distribution. --- README.md | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 77e2f600..84838fde 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ To install malpolon, you will first need to install **Python ≥ 3.10**, and sev
Click here to expand instructions -### 0. Requirements +### Requirements Before proceeding, please make sure the following packages are installed on your system: @@ -67,7 +67,18 @@ Before proceeding, please make sure the following packages are installed on your The following instructions show installation commands for Python 3.10, but can be adapted for any of the compatible Python versions metionned above by simply changing the version number. -### 1. Clone the repository +### Install from `PyPi` +The backend side of malpolon is distributed as a package on `PyPi`. To install it, simply run the following command: + +```script +pip install malpolon +``` + +However, versions available on PyPi are non-experimental and possibly behind the repository's `main` and `dev` branches. To know which version you want to download, please refer to the *tags* section of the repository and match it with PyPi. +Furthermore, the PyPi package does not include the examples and the documentation. If you want to install the full repository, follow the next steps. + +### Install from `GitHub` +#### 1. Clone the repository Clone the Malpolon repository using `git` in the directory of your choice: ```script @@ -76,7 +87,7 @@ git clone https://github.com/plantnet/malpolon.git --- -### 2. Create your virtual environment +#### 2. Create your virtual environment - **Via `virtualenv`** @@ -106,7 +117,7 @@ conda activate malpolon_3.10 --- -### 3. Install Malpolon as a python package +#### 3. Install Malpolon as a python package The malpolon repository can also be installed in your virtual environment as a package. This allows you to import `malpolon` anywhere in your scripts without having to worry about file paths. 
It can be installed via `pip` using: From 229ed7e0d5605b69ac8508564f7860448cf6b521 Mon Sep 17 00:00:00 2001 From: aerodynamic-sauce-pan Date: Tue, 27 Feb 2024 11:10:23 +0100 Subject: [PATCH 11/12] Updated root readme with pypi shield button and hard coded emojis --- README.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 84838fde..580c8ecb 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,11 @@

+ Python version + Python version GitHub issues GitHub pull requests GitHub contributors - GitHub forks - GitHub stars - GitHub watchers License

@@ -39,7 +40,7 @@ Here is a list of the currently available scenarios: - Custom dataset : I have my own dataset consisting of pre-extracted image patches and/or rasters and I want to train a model on it. - [**Inference**](examples/inference/) : I have an observations file (.csv) and I want to predict the presence of species on a given area using a model I trained previously and a selected dataset or a shapefile I would provide. -## :wrench: Installation +## 🔧 Installation To install malpolon, you will first need to install **Python ≥ 3.10**, and several python packages. To do so, it is best practice to create a virtual environment containing all these packages locally. @@ -151,7 +152,7 @@ git checkout dev
-## :page_facing_up: Documentation +## 📄 Documentation An online code documentation is available via GitHub pages at [this link](https://plantnet.github.io/malpolon/). This documentation is updated each time new content is pushed to the `main` branch. @@ -171,7 +172,7 @@ make -C docs html The result can be found in `docs/_build/html`. -## :train2: Roadmap +## 🚆 Roadmap This roadmap outlines the planned features and milestones for the project. Please note that the roadmap is subject to change and may be updated as the project progress. From af2d4e739dd7f1a439eb06766811aa38de8a42ec Mon Sep 17 00:00:00 2001 From: Theo Larcher <42494948+tlarcher@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:13:48 +0100 Subject: [PATCH 12/12] Update README.md --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 580c8ecb..0696a809 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,6 @@

- Python version - Python version + Python version + Python version GitHub issues GitHub pull requests GitHub contributors