diff --git a/.gitignore b/.gitignore index 117a6931..04a06484 100644 --- a/.gitignore +++ b/.gitignore @@ -146,7 +146,9 @@ dmypy.json # Lightning-Hydra-Template configs/local/default.yaml -data/ -logs/ +/data/ +/logs/ .env -.autoenv + +# Aim logging +.aim diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 31a624f2..3e0bfeb5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ default_language_version: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: # list of supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace @@ -19,7 +19,7 @@ repos: # python code formatting - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 23.1.0 hooks: - id: black args: [--line-length, "99"] @@ -33,55 +33,56 @@ repos: # python upgrading syntax to newer version - repo: https://github.com/asottile/pyupgrade - rev: v2.32.1 + rev: v3.3.1 hooks: - id: pyupgrade args: [--py38-plus] # python docstring formatting - repo: https://github.com/myint/docformatter - rev: v1.4 + rev: v1.5.1 hooks: - id: docformatter args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] # python check (PEP8), programming errors and code complexity - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 6.0.0 hooks: - id: flake8 # ignore E203 because black is used for formatting. args: [ - "--ignore", - "E203,E501,F401,F403,F841,W504", + "--extend-ignore", + "E203,E402,E501,F401,F403,F841", "--exclude", "logs/*,data/*", ] # python security linter - repo: https://github.com/PyCQA/bandit - rev: "1.7.1" + rev: "1.7.5" hooks: - id: bandit args: ["-s", "B101"] # yaml formatting - repo: https://github.com/pre-commit/mirrors-prettier - rev: v2.7.1 + rev: v3.0.0-alpha.6 hooks: - id: prettier types: [yaml] + exclude: "environment.yaml" - # jupyter notebook cell output clearing - - repo: https://github.com/kynan/nbstripout - rev: 0.5.0 + # shell scripts linter + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.9.0.2 hooks: - - id: nbstripout + - id: shellcheck # md formatting - repo: https://github.com/executablebooks/mdformat - rev: 0.7.14 + rev: 0.7.16 hooks: - id: mdformat args: ["--number"] @@ -94,9 +95,30 @@ repos: # word spelling linter - repo: https://github.com/codespell-project/codespell - rev: v2.1.0 + rev: v2.2.4 hooks: - id: codespell args: - - --skip=logs/**,data/** + - --skip=logs/**,data/**,*.ipynb - --ignore-words-list=abc,def,gard + + # jupyter notebook cell output clearing + - repo: https://github.com/kynan/nbstripout + rev: 0.6.1 + hooks: + - id: nbstripout + + # jupyter notebook linting + - repo: https://github.com/nbQA-dev/nbQA + rev: 1.6.3 + hooks: + - id: nbqa-black + args: ["--line-length=99"] + - id: nbqa-isort + args: ["--profile=black"] + - id: nbqa-flake8 + args: + [ + "--extend-ignore=E203,E402,E501,F401,F841", + "--exclude=logs/*,data/*", + ] diff --git a/examples/carla_overhead_object_detection/.gitignore b/examples/carla_overhead_object_detection/.gitignore new file mode 100644 index 00000000..8ba906dd --- /dev/null +++ b/examples/carla_overhead_object_detection/.gitignore @@ -0,0 +1,4 @@ +# Default folder for downloading dataset +data + +logs diff --git a/examples/carla_overhead_object_detection/configs/datamodule/armory_carla_over_objdet.yaml b/examples/carla_overhead_object_detection/configs/data/armory_carla_over_objdet.yaml similarity index 100% rename from examples/carla_overhead_object_detection/configs/datamodule/armory_carla_over_objdet.yaml rename to examples/carla_overhead_object_detection/configs/data/armory_carla_over_objdet.yaml diff --git a/examples/carla_overhead_object_detection/configs/datamodule/armory_carla_over_objdet_perturbable_mask.yaml b/examples/carla_overhead_object_detection/configs/data/armory_carla_over_objdet_perturbable_mask.yaml similarity index 100% rename from examples/carla_overhead_object_detection/configs/datamodule/armory_carla_over_objdet_perturbable_mask.yaml rename to examples/carla_overhead_object_detection/configs/data/armory_carla_over_objdet_perturbable_mask.yaml diff --git a/examples/carla_overhead_object_detection/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml b/examples/carla_overhead_object_detection/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml index 991a87e7..2042b972 100644 --- a/examples/carla_overhead_object_detection/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml +++ b/examples/carla_overhead_object_detection/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml @@ -3,7 +3,7 @@ defaults: - COCO_TorchvisionFasterRCNN - override /model/modules@model.modules.preprocessor: tuple_tensorizer_normalizer - - override /datamodule: armory_carla_over_objdet_perturbable_mask + - override /data: armory_carla_over_objdet_perturbable_mask task_name: "ArmoryCarlaOverObjDet_TorchvisionFasterRCNN" tags: ["regular_training"] diff --git a/examples/robust_bench/.gitignore b/examples/robust_bench/.gitignore new file mode 100644 index 00000000..8ba906dd --- /dev/null +++ b/examples/robust_bench/.gitignore @@ -0,0 +1,4 @@ +# Default folder for downloading dataset +data + +logs diff --git a/mart/__init__.py b/mart/__init__.py index 85181105..7ee4c290 100644 --- a/mart/__init__.py +++ b/mart/__init__.py @@ -1,7 +1,7 @@ import importlib from mart import attack as attack -from mart import datamodules as datamodules +from mart import data as data from mart import models as models from mart import nn as nn from mart import optim as optim diff --git a/mart/__main__.py b/mart/__main__.py index 42d71627..f8d19998 100644 --- a/mart/__main__.py +++ b/mart/__main__.py @@ -31,8 +31,7 @@ @hydra.main(version_base="1.2", config_path=config_path, config_name="lightning.yaml") def main(cfg: DictConfig) -> float: - - if cfg.resume is None and ("datamodule" not in cfg or "model" not in cfg): + if cfg.resume is None and ("data" not in cfg or "model" not in cfg): log.fatal("") log.fatal("Please specify an experiment to run, e.g.") log.fatal( diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py index 8c5513d2..be7af943 100644 --- a/mart/attack/adversary.py +++ b/mart/attack/adversary.py @@ -10,8 +10,8 @@ from itertools import cycle from typing import TYPE_CHECKING, Any, Callable -import pytorch_lightning as pl import torch +from lightning import pytorch as pl from mart.utils import silent @@ -140,12 +140,10 @@ def training_step(self, batch, batch_idx): return gain def configure_gradient_clipping( - self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None + self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None ): # Configuring gradient clipping in pl.Trainer is still useful, so use it. - super().configure_gradient_clipping( - optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm - ) + super().configure_gradient_clipping(optimizer, gradient_clip_val, gradient_clip_algorithm) if self.gradient_modifier: for group in optimizer.param_groups: @@ -195,7 +193,7 @@ def attacker(self): elif self.device.type == "cpu": accelerator = "cpu" - devices = None + devices = 1 else: raise NotImplementedError diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py index 29df3059..75c565ad 100644 --- a/mart/attack/perturber.py +++ b/mart/attack/perturber.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Iterable import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.exceptions import MisconfigurationException from .projector import Projector diff --git a/mart/callbacks/eval_mode.py b/mart/callbacks/eval_mode.py index be3b6397..44597929 100644 --- a/mart/callbacks/eval_mode.py +++ b/mart/callbacks/eval_mode.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback __all__ = ["AttackInEvalMode"] diff --git a/mart/callbacks/no_grad_mode.py b/mart/callbacks/no_grad_mode.py index cfb90ead..5b3b1446 100644 --- a/mart/callbacks/no_grad_mode.py +++ b/mart/callbacks/no_grad_mode.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback __all__ = ["ModelParamsNoGrad"] diff --git a/mart/callbacks/progress_bar.py b/mart/callbacks/progress_bar.py index 5ee131b9..627dfb9f 100644 --- a/mart/callbacks/progress_bar.py +++ b/mart/callbacks/progress_bar.py @@ -6,9 +6,9 @@ from typing import Any -import pytorch_lightning as pl -from pytorch_lightning.callbacks import TQDMProgressBar -from pytorch_lightning.utilities.rank_zero import rank_zero_only +from lightning import pytorch as pl +from lightning.pytorch.callbacks import TQDMProgressBar +from lightning.pytorch.utilities.rank_zero import rank_zero_only __all__ = ["ProgressBar"] @@ -41,8 +41,6 @@ def init_train_tqdm(self): def on_train_epoch_start(self, trainer: pl.Trainer, *_: Any) -> None: super().on_train_epoch_start(trainer) - # So that it does not display negative rate. - self.main_progress_bar.initial = 0 # So that it does not display Epoch n. rank_id = rank_zero_only.rank self.main_progress_bar.set_description(f"Attack@rank{rank_id}") diff --git a/mart/callbacks/visualizer.py b/mart/callbacks/visualizer.py index 3354321e..39409143 100644 --- a/mart/callbacks/visualizer.py +++ b/mart/callbacks/visualizer.py @@ -6,7 +6,7 @@ import os -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback from torchvision.transforms import ToPILImage __all__ = ["PerturbedImageVisualizer"] diff --git a/mart/configs/callbacks/early_stopping.yaml b/mart/configs/callbacks/early_stopping.yaml index 20ed2671..c826c8d5 100644 --- a/mart/configs/callbacks/early_stopping.yaml +++ b/mart/configs/callbacks/early_stopping.yaml @@ -1,9 +1,7 @@ -# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html -# Monitor a metric and stop training when it stops improving. -# Look at the above link for more detailed information. early_stopping: - _target_: pytorch_lightning.callbacks.EarlyStopping + _target_: lightning.pytorch.callbacks.EarlyStopping monitor: ??? # quantity to be monitored, must be specified !!! min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement patience: 3 # number of checks with no improvement after which training will be stopped diff --git a/mart/configs/callbacks/lr_monitor.yaml b/mart/configs/callbacks/lr_monitor.yaml index a23b28ee..f4bad6f8 100644 --- a/mart/configs/callbacks/lr_monitor.yaml +++ b/mart/configs/callbacks/lr_monitor.yaml @@ -1,3 +1,3 @@ lr_monitor: - _target_: pytorch_lightning.callbacks.LearningRateMonitor + _target_: lightning.pytorch.callbacks.LearningRateMonitor logging_interval: "step" diff --git a/mart/configs/callbacks/model_checkpoint.yaml b/mart/configs/callbacks/model_checkpoint.yaml index 3d4503b5..40581994 100644 --- a/mart/configs/callbacks/model_checkpoint.yaml +++ b/mart/configs/callbacks/model_checkpoint.yaml @@ -1,9 +1,7 @@ -# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html -# Save the model periodically by monitoring a quantity. -# Look at the above link for more detailed information. model_checkpoint: - _target_: pytorch_lightning.callbacks.ModelCheckpoint + _target_: lightning.pytorch.callbacks.ModelCheckpoint dirpath: "${paths.output_dir}/checkpoints/" # directory to save the model file filename: "epoch_{epoch:03d}" # checkpoint filename monitor: ??? # name of the logged metric which determines when model is improving diff --git a/mart/configs/callbacks/model_summary.yaml b/mart/configs/callbacks/model_summary.yaml index 04da98d3..b75981d8 100644 --- a/mart/configs/callbacks/model_summary.yaml +++ b/mart/configs/callbacks/model_summary.yaml @@ -1,7 +1,5 @@ -# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html -# Generates a summary of all layers in a LightningModule with rich text formatting. -# Look at the above link for more detailed information. model_summary: - _target_: pytorch_lightning.callbacks.RichModelSummary + _target_: lightning.pytorch.callbacks.RichModelSummary max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/mart/configs/callbacks/rich_progress_bar.yaml b/mart/configs/callbacks/rich_progress_bar.yaml index b6be5b45..de6f1ccb 100644 --- a/mart/configs/callbacks/rich_progress_bar.yaml +++ b/mart/configs/callbacks/rich_progress_bar.yaml @@ -1,6 +1,4 @@ -# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html +# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html -# Create a progress bar with rich text formatting. -# Look at the above link for more detailed information. rich_progress_bar: - _target_: pytorch_lightning.callbacks.RichProgressBar + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/mart/configs/datamodule/cifar10.yaml b/mart/configs/data/cifar10.yaml similarity index 96% rename from mart/configs/datamodule/cifar10.yaml rename to mart/configs/data/cifar10.yaml index ec2e795d..7f8b5b80 100644 --- a/mart/configs/datamodule/cifar10.yaml +++ b/mart/configs/data/cifar10.yaml @@ -55,3 +55,6 @@ test_dataset: ${.val_dataset} num_workers: 4 collate_fn: null + +# The accuracy metric may require this value. +num_classes: 10 diff --git a/mart/configs/datamodule/coco.yaml b/mart/configs/data/coco.yaml similarity index 91% rename from mart/configs/datamodule/coco.yaml rename to mart/configs/data/coco.yaml index a4ec3403..978daaaf 100644 --- a/mart/configs/datamodule/coco.yaml +++ b/mart/configs/data/coco.yaml @@ -2,7 +2,7 @@ defaults: - default.yaml train_dataset: - _target_: mart.datamodules.coco.CocoDetection + _target_: mart.data.coco.CocoDetection root: ${paths.data_dir}/coco/train2017 annFile: ${paths.data_dir}/coco/annotations/instances_train2017.json transforms: @@ -24,7 +24,7 @@ train_dataset: quant_max: 255 val_dataset: - _target_: mart.datamodules.coco.CocoDetection + _target_: mart.data.coco.CocoDetection root: ${paths.data_dir}/coco/val2017 annFile: ${paths.data_dir}/coco/annotations/instances_val2017.json transforms: @@ -44,7 +44,7 @@ val_dataset: quant_max: 255 test_dataset: - _target_: mart.datamodules.coco.CocoDetection + _target_: mart.data.coco.CocoDetection root: ${paths.data_dir}/coco/val2017 annFile: ${paths.data_dir}/coco/annotations/instances_val2017.json transforms: @@ -66,4 +66,4 @@ test_dataset: num_workers: 2 collate_fn: _target_: hydra.utils.get_method - path: mart.datamodules.coco.collate_fn + path: mart.data.coco.collate_fn diff --git a/mart/configs/datamodule/coco_perturbable_mask.yaml b/mart/configs/data/coco_perturbable_mask.yaml similarity index 100% rename from mart/configs/datamodule/coco_perturbable_mask.yaml rename to mart/configs/data/coco_perturbable_mask.yaml diff --git a/mart/configs/datamodule/default.yaml b/mart/configs/data/default.yaml similarity index 77% rename from mart/configs/datamodule/default.yaml rename to mart/configs/data/default.yaml index d3da6d4c..a82df951 100644 --- a/mart/configs/datamodule/default.yaml +++ b/mart/configs/data/default.yaml @@ -1,4 +1,4 @@ -_target_: mart.datamodules.LitDataModule +_target_: mart.data.LitDataModule # _convert_: all train_dataset: ??? diff --git a/mart/configs/datamodule/dummy_classification.yaml b/mart/configs/data/dummy_classification.yaml similarity index 100% rename from mart/configs/datamodule/dummy_classification.yaml rename to mart/configs/data/dummy_classification.yaml diff --git a/mart/configs/datamodule/imagenet.yaml b/mart/configs/data/imagenet.yaml similarity index 95% rename from mart/configs/datamodule/imagenet.yaml rename to mart/configs/data/imagenet.yaml index d8253cbc..f0967743 100644 --- a/mart/configs/datamodule/imagenet.yaml +++ b/mart/configs/data/imagenet.yaml @@ -46,3 +46,6 @@ val_dataset: quant_max: 255 test_dataset: ${.val_dataset} + +# The accuracy metric may require this value. +num_classes: 1000 diff --git a/mart/configs/debug/default.yaml b/mart/configs/debug/default.yaml index e30b19e2..1886902b 100644 --- a/mart/configs/debug/default.yaml +++ b/mart/configs/debug/default.yaml @@ -30,6 +30,6 @@ trainer: devices: 1 # debuggers don't like multiprocessing detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor -datamodule: +data: num_workers: 0 # debuggers don't like multiprocessing pin_memory: False # disable gpu memory pin diff --git a/mart/configs/debug/profiler.yaml b/mart/configs/debug/profiler.yaml index 1e8ab021..8e79d296 100644 --- a/mart/configs/debug/profiler.yaml +++ b/mart/configs/debug/profiler.yaml @@ -8,8 +8,8 @@ defaults: trainer: max_epochs: 1 profiler: - _target_: pytorch_lightning.profiler.SimpleProfiler - # _target_: pytorch_lightning.profiler.AdvancedProfiler - # _target_: pytorch_lightning.profiler.PyTorchProfiler + _target_: lightning.pytorch.profiler.SimpleProfiler + # _target_: lightning.pytorch.profiler.AdvancedProfiler + # _target_: lightning.pytorch.profiler.PyTorchProfiler dirpath: ${paths.output_dir} filename: profiler_log diff --git a/mart/configs/experiment/CIFAR10_CNN.yaml b/mart/configs/experiment/CIFAR10_CNN.yaml index cfb6eafa..7a823ede 100644 --- a/mart/configs/experiment/CIFAR10_CNN.yaml +++ b/mart/configs/experiment/CIFAR10_CNN.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - override /datamodule: cifar10 + - override /data: cifar10 - override /model: classifier_cifar10_cnn - override /metric: accuracy - override /optimization: super_convergence @@ -22,7 +22,7 @@ trainer: max_steps: 5850 precision: 32 -datamodule: +data: ims_per_batch: 128 world_size: 1 num_workers: 8 diff --git a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml index 9259de2c..53497fcd 100644 --- a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - override /datamodule: coco + - override /data: coco - override /model: torchvision_faster_rcnn - override /metric: average_precision - override /optimization: super_convergence @@ -23,7 +23,7 @@ trainer: # FIXME: "nms_kernel" not implemented for 'BFloat16', torch.ops.torchvision.nms(). precision: 32 -datamodule: +data: ims_per_batch: 2 world_size: 1 diff --git a/mart/configs/experiment/COCO_TorchvisionFasterRCNN_Adv.yaml b/mart/configs/experiment/COCO_TorchvisionFasterRCNN_Adv.yaml index 398394bf..ee54403a 100644 --- a/mart/configs/experiment/COCO_TorchvisionFasterRCNN_Adv.yaml +++ b/mart/configs/experiment/COCO_TorchvisionFasterRCNN_Adv.yaml @@ -3,7 +3,7 @@ defaults: - COCO_TorchvisionFasterRCNN - /attack@model.modules.input_adv_test: object_detection_mask_adversary - - override /datamodule: coco_perturbable_mask + - override /data: coco_perturbable_mask task_name: "COCO_TorchvisionFasterRCNN_Adv" tags: ["adv"] diff --git a/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml b/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml index dbd4541f..f8232ea0 100644 --- a/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml +++ b/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - override /datamodule: coco + - override /data: coco - override /model: torchvision_retinanet - override /metric: average_precision - override /optimization: super_convergence @@ -22,7 +22,7 @@ trainer: max_steps: 351798 precision: 16 -datamodule: +data: ims_per_batch: 2 world_size: 1 diff --git a/mart/configs/experiment/ImageNet_Timm.yaml b/mart/configs/experiment/ImageNet_Timm.yaml index 4c86bb45..c7b0a56c 100644 --- a/mart/configs/experiment/ImageNet_Timm.yaml +++ b/mart/configs/experiment/ImageNet_Timm.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - override /datamodule: imagenet + - override /data: imagenet - override /model: classifier_timm - override /metric: accuracy - override /optimization: super_convergence @@ -20,7 +20,7 @@ trainer: max_steps: 140625 precision: 16 -datamodule: +data: ims_per_batch: 128 world_size: 1 num_workers: 8 diff --git a/mart/configs/hparams_search/mnist_optuna.yaml b/mart/configs/hparams_search/mnist_optuna.yaml index f328d07a..1391183e 100644 --- a/mart/configs/hparams_search/mnist_optuna.yaml +++ b/mart/configs/hparams_search/mnist_optuna.yaml @@ -46,7 +46,7 @@ hydra: # define hyperparameter search space params: model.optimizer.lr: interval(0.0001, 0.1) - datamodule.batch_size: choice(32, 64, 128, 256) + data.batch_size: choice(32, 64, 128, 256) model.net.lin1_size: choice(64, 128, 256) model.net.lin2_size: choice(64, 128, 256) model.net.lin3_size: choice(32, 64, 128, 256) diff --git a/mart/configs/lightning.yaml b/mart/configs/lightning.yaml index 714250f3..88a762fb 100644 --- a/mart/configs/lightning.yaml +++ b/mart/configs/lightning.yaml @@ -3,7 +3,7 @@ # specify here default training configuration defaults: - _self_ - - datamodule: null + - data: null - model: null - metric: null - optimization: null diff --git a/mart/configs/logger/aim.yaml b/mart/configs/logger/aim.yaml new file mode 100644 index 00000000..8f9f6ada --- /dev/null +++ b/mart/configs/logger/aim.yaml @@ -0,0 +1,28 @@ +# https://aimstack.io/ + +# example usage in lightning module: +# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py + +# open the Aim UI with the following command (run in the folder containing the `.aim` folder): +# `aim up` + +aim: + _target_: aim.pytorch_lightning.AimLogger + repo: ${paths.root_dir} # .aim folder will be created here + # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# + + # aim allows to group runs under experiment name + experiment: null # any string, set to "default" if not specified + + train_metric_prefix: "train/" + val_metric_prefix: "val/" + test_metric_prefix: "test/" + + # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) + system_tracking_interval: 10 # set to null to disable system metrics tracking + + # enable/disable logging of system params such as installed packages, git info, env vars, etc. + log_system_params: true + + # enable/disable tracking console logs (default value is true) + capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 diff --git a/mart/configs/logger/comet.yaml b/mart/configs/logger/comet.yaml index 423f41f6..e0789274 100644 --- a/mart/configs/logger/comet.yaml +++ b/mart/configs/logger/comet.yaml @@ -1,7 +1,7 @@ # https://www.comet.ml comet: - _target_: pytorch_lightning.loggers.comet.CometLogger + _target_: lightning.pytorch.loggers.comet.CometLogger api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable save_dir: "${paths.output_dir}" project_name: "lightning-hydra-template" diff --git a/mart/configs/logger/csv.yaml b/mart/configs/logger/csv.yaml index 844ec671..fa028e9c 100644 --- a/mart/configs/logger/csv.yaml +++ b/mart/configs/logger/csv.yaml @@ -1,7 +1,7 @@ # csv logger built in lightning csv: - _target_: pytorch_lightning.loggers.csv_logs.CSVLogger + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger save_dir: "${paths.output_dir}" name: "csv/" prefix: "" diff --git a/mart/configs/logger/mlflow.yaml b/mart/configs/logger/mlflow.yaml index 3b441a90..f8fb7e68 100644 --- a/mart/configs/logger/mlflow.yaml +++ b/mart/configs/logger/mlflow.yaml @@ -1,7 +1,7 @@ # https://mlflow.org mlflow: - _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger # experiment_name: "" # run_name: "" tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI diff --git a/mart/configs/logger/neptune.yaml b/mart/configs/logger/neptune.yaml index 5df1e342..8233c140 100644 --- a/mart/configs/logger/neptune.yaml +++ b/mart/configs/logger/neptune.yaml @@ -1,7 +1,7 @@ # https://neptune.ai neptune: - _target_: pytorch_lightning.loggers.neptune.NeptuneLogger + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable project: username/lightning-hydra-template # name: "" diff --git a/mart/configs/logger/tensorboard.yaml b/mart/configs/logger/tensorboard.yaml index 29067c90..2bd31f6d 100644 --- a/mart/configs/logger/tensorboard.yaml +++ b/mart/configs/logger/tensorboard.yaml @@ -1,7 +1,7 @@ # https://www.tensorflow.org/tensorboard/ tensorboard: - _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger save_dir: "${paths.output_dir}/tensorboard/" name: null log_graph: False diff --git a/mart/configs/logger/wandb.yaml b/mart/configs/logger/wandb.yaml index d6d20e0b..ece16588 100644 --- a/mart/configs/logger/wandb.yaml +++ b/mart/configs/logger/wandb.yaml @@ -1,7 +1,7 @@ # https://wandb.ai wandb: - _target_: pytorch_lightning.loggers.wandb.WandbLogger + _target_: lightning.pytorch.loggers.wandb.WandbLogger # name: "" # name of the run (normally generated by wandb) save_dir: "${paths.output_dir}" offline: False diff --git a/mart/configs/metric/accuracy.yaml b/mart/configs/metric/accuracy.yaml index d1e3709f..9c4a79f4 100644 --- a/mart/configs/metric/accuracy.yaml +++ b/mart/configs/metric/accuracy.yaml @@ -6,7 +6,8 @@ training_metrics: metrics: acc: _target_: torchmetrics.Accuracy - compute_on_step: false + task: multiclass + num_classes: ${data.num_classes} validation_metrics: ${.training_metrics} diff --git a/mart/configs/metric/average_precision.yaml b/mart/configs/metric/average_precision.yaml index d41f9743..dc1d05f0 100644 --- a/mart/configs/metric/average_precision.yaml +++ b/mart/configs/metric/average_precision.yaml @@ -1,20 +1,17 @@ # @package model training_metrics: - _target_: torchmetrics.detection.MAP - compute_on_step: false + _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision validation_metrics: - _target_: torchmetrics.detection.MAP - compute_on_step: false + _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision test_metrics: _target_: torchmetrics.collections.MetricCollection _convert_: partial metrics: map: - _target_: torchmetrics.detection.MAP - compute_on_step: false + _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision json: _target_: mart.utils.export.CocoPredictionJSON prediction_file_name: ${paths.output_dir}/test_prediction.json diff --git a/mart/configs/paths/default.yaml b/mart/configs/paths/default.yaml index b00b6df3..ec81db2d 100644 --- a/mart/configs/paths/default.yaml +++ b/mart/configs/paths/default.yaml @@ -1,6 +1,6 @@ # path to root directory # this requires PROJECT_ROOT environment variable to exist -# PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` +# you can replace it with "." if you want the root to be the current working directory root_dir: ${oc.env:PROJECT_ROOT} # path to data directory diff --git a/mart/configs/trainer/cpu.yaml b/mart/configs/trainer/cpu.yaml new file mode 100644 index 00000000..640f71d5 --- /dev/null +++ b/mart/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default.yaml + +accelerator: cpu +devices: 1 diff --git a/mart/configs/trainer/default.yaml b/mart/configs/trainer/default.yaml index b3eeb7a1..ccff1885 100644 --- a/mart/configs/trainer/default.yaml +++ b/mart/configs/trainer/default.yaml @@ -1,4 +1,4 @@ -_target_: pytorch_lightning.Trainer +_target_: lightning.pytorch.trainer.Trainer default_root_dir: ${paths.output_dir} @@ -11,6 +11,12 @@ devices: 1 # mixed precision for extra speed-up # precision: 16 +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + # set True to to ensure deterministic results # makes training slower but gives more reproducibility than just setting seeds deterministic: False + +# Disable PyTorch inference mode in val/test/predict, because we may run adversarial backprop. +inference_mode: False diff --git a/mart/datamodules/__init__.py b/mart/data/__init__.py similarity index 100% rename from mart/datamodules/__init__.py rename to mart/data/__init__.py diff --git a/mart/datamodules/coco.py b/mart/data/coco.py similarity index 100% rename from mart/datamodules/coco.py rename to mart/data/coco.py diff --git a/mart/datamodules/modular.py b/mart/data/modular.py similarity index 96% rename from mart/datamodules/modular.py rename to mart/data/modular.py index df99fc3e..fa2dc24d 100644 --- a/mart/datamodules/modular.py +++ b/mart/data/modular.py @@ -8,8 +8,8 @@ logger = logging.getLogger(__name__) -import pytorch_lightning as pl # noqa: E402 from hydra.utils import instantiate # noqa: E402 +from lightning import pytorch as pl # noqa: E402 from torch.utils.data import DataLoader, Dataset, Sampler # noqa: E402 __all__ = ["LitDataModule"] @@ -29,6 +29,8 @@ def __init__( ims_per_batch=1, world_size=1, pin_memory=False, + # Classification metrics may require the value in config. + num_classes=None, ): super().__init__() diff --git a/mart/models/modular.py b/mart/models/modular.py index 70df9b39..075e4f75 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -8,7 +8,7 @@ from operator import attrgetter import torch -from pytorch_lightning import LightningModule +from lightning.pytorch import LightningModule from ..nn import SequentialDict from ..optim import OptimizerFactory @@ -147,7 +147,7 @@ def training_step(self, batch, batch_idx): return output[self.output_loss_key] - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): if self.training_metrics is not None: # Some models only return loss in the training mode. metrics = self.training_metrics.compute() @@ -171,7 +171,7 @@ def validation_step(self, batch, batch_idx): return None - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): metrics = self.validation_metrics.compute() metrics = self.flatten_metrics(metrics) self.validation_metrics.reset() @@ -193,7 +193,7 @@ def test_step(self, batch, batch_idx): return None - def test_epoch_end(self, outputs): + def on_test_epoch_end(self): metrics = self.test_metrics.compute() metrics = self.flatten_metrics(metrics) self.test_metrics.reset() @@ -239,4 +239,6 @@ def enumerate_metric(metric, name): enumerate_metric(metrics, prefix) - self.log_dict(metrics_dict, prog_bar=prog_bar) + # PyTorch Lightning: It is recommended to use `self.log('validation_metrics/acc', ..., sync_dist=True)` + # when logging on epoch level in distributed setting to accumulate the metric across devices. + self.log_dict(metrics_dict, prog_bar=prog_bar, sync_dist=True) diff --git a/mart/tasks/lightning.py b/mart/tasks/lightning.py index b02c32f1..30a462c4 100644 --- a/mart/tasks/lightning.py +++ b/mart/tasks/lightning.py @@ -1,15 +1,15 @@ from typing import Any, Dict, List, Tuple import hydra -from omegaconf import DictConfig -from pytorch_lightning import ( +from lightning.pytorch import ( Callback, LightningDataModule, LightningModule, Trainer, seed_everything, ) -from pytorch_lightning.loggers import LightningLoggerBase +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig from mart import utils @@ -32,8 +32,8 @@ def lightning(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: seed_everything(cfg.seed, workers=True) # Init lightning datamodule - log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) # Init lightning model log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = hydra.utils.instantiate(cfg.model) @@ -42,7 +42,7 @@ def lightning(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) # Init lightning loggers - logger: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.get("logger")) + logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) # Init lightning trainer log.info(f"Instantiating trainer <{cfg.trainer._target_}>") diff --git a/mart/transforms/torchvision_ref.py b/mart/transforms/torchvision_ref.py index 0fa5eb17..1d96c785 100644 --- a/mart/transforms/torchvision_ref.py +++ b/mart/transforms/torchvision_ref.py @@ -35,7 +35,7 @@ def convert_coco_poly_to_mask(segmentations, height, width): # Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/coco_utils.py#L47 -# Adapted to mart.datamodules.coco.CocoDetection by adding the "file_name" field. +# Adapted to mart.data.coco.CocoDetection by adding the "file_name" field. class ConvertCocoPolysToMask: def __call__(self, image, target): w, h = F.get_image_size(image) diff --git a/mart/utils/__init__.py b/mart/utils/__init__.py index 50e71b3d..9b647b64 100644 --- a/mart/utils/__init__.py +++ b/mart/utils/__init__.py @@ -1,5 +1,7 @@ from .adapters import * from .export import * +from .instantiators import * +from .logging_utils import * from .monkey_patch import * from .pylogger import * from .rich_utils import * diff --git a/mart/utils/instantiators.py b/mart/utils/instantiators.py new file mode 100644 index 00000000..f175e954 --- /dev/null +++ b/mart/utils/instantiators.py @@ -0,0 +1,52 @@ +from typing import List + +import hydra +from lightning.pytorch import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from . import pylogger + +__all__ = ["instantiate_callbacks", "instantiate_loggers"] + +log = pylogger.get_pylogger(__name__) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/mart/utils/logging_utils.py b/mart/utils/logging_utils.py new file mode 100644 index 00000000..841eebe4 --- /dev/null +++ b/mart/utils/logging_utils.py @@ -0,0 +1,52 @@ +from lightning.pytorch.utilities import rank_zero_only + +from . import pylogger + +__all__ = ["log_hyperparameters"] + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/mart/utils/pylogger.py b/mart/utils/pylogger.py index eb84d233..86a723b9 100644 --- a/mart/utils/pylogger.py +++ b/mart/utils/pylogger.py @@ -1,6 +1,6 @@ import logging -from pytorch_lightning.utilities import rank_zero_only +from lightning.pytorch.utilities import rank_zero_only __all__ = ["get_pylogger"] diff --git a/mart/utils/rich_utils.py b/mart/utils/rich_utils.py index db7e75a5..b5c00130 100644 --- a/mart/utils/rich_utils.py +++ b/mart/utils/rich_utils.py @@ -5,8 +5,8 @@ import rich.syntax import rich.tree from hydra.core.hydra_config import HydraConfig +from lightning.pytorch.utilities import rank_zero_only from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning.utilities import rank_zero_only from rich.prompt import Prompt from mart.utils import pylogger @@ -20,7 +20,7 @@ def print_config_tree( cfg: DictConfig, print_order: Sequence[str] = ( - "datamodule", + "data", "model", "callbacks", "logger", @@ -97,11 +97,3 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: if save_to_file: with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: rich.print(cfg.tags, file=file) - - -if __name__ == "__main__": - from hydra import compose, initialize - - with initialize(version_base="1.2", config_path="../../configs"): - cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) - print_config_tree(cfg, resolve=False, save_to_file=False) diff --git a/mart/utils/silent.py b/mart/utils/silent.py index b9cbd1c3..bd240b0b 100644 --- a/mart/utils/silent.py +++ b/mart/utils/silent.py @@ -13,7 +13,10 @@ class silent(ContextDecorator): """Suppress logging.""" - DEFAULT_NAMES = ["pytorch_lightning.utilities.rank_zero", "pytorch_lightning.accelerators.gpu"] + DEFAULT_NAMES = [ + "lightning.pytorch.utilities.rank_zero", + "lightning.pytorch.accelerators.cuda", + ] def __init__(self, names=None): if names is None: diff --git a/mart/utils/utils.py b/mart/utils/utils.py index 777005ae..ea400b40 100644 --- a/mart/utils/utils.py +++ b/mart/utils/utils.py @@ -1,5 +1,4 @@ import os -import time import warnings from glob import glob from importlib.util import find_spec @@ -9,22 +8,13 @@ import hydra from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Callback -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.model_summary import summarize -from mart.utils import pylogger, rich_utils +from . import pylogger, rich_utils __all__ = [ - "close_loggers", "extras", "get_metric_value", "get_resume_checkpoint", - "instantiate_callbacks", - "instantiate_loggers", - "log_hyperparameters", - "save_file", "task_wrapper", "flatten_dict", ] @@ -32,44 +22,6 @@ log = pylogger.get_pylogger(__name__) -def task_wrapper(task_func: Callable) -> Callable: - """Optional decorator that wraps the task function in extra utilities. - - Makes multirun more resistant to failure. - - Utilities: - - Calling the `utils.extras()` before the task is started - - Calling the `utils.close_loggers()` after the task is finished - - Logging the exception if occurs - - Logging the task total execution time - - Logging the output dir - """ - - def wrap(cfg: DictConfig): - - # apply extra utilities - extras(cfg) - - # execute the task - try: - start_time = time.time() - metric_dict, object_dict = task_func(cfg=cfg) - except Exception as ex: - log.exception("") # save exception to `.log` file - raise ex - finally: - path = Path(cfg.paths.output_dir, "exec_time.log") - content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)" - save_file(path, content) # save task execution time (even if exception occurs) - close_loggers() # close loggers (even if exception occurs so multirun won't fail) - - log.info(f"Output dir: {cfg.paths.output_dir}") - - return metric_dict, object_dict - - return wrap - - def extras(cfg: DictConfig) -> None: """Applies optional utilities before the task is started. @@ -100,91 +52,57 @@ def extras(cfg: DictConfig) -> None: rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) -@rank_zero_only -def save_file(path, content) -> None: - """Save file in rank zero mode (only on one process in multi-GPU setup).""" - with open(path, "w+") as file: - file.write(content) - - -def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: - """Instantiates callbacks from config.""" - callbacks: List[Callback] = [] - - if not callbacks_cfg: - log.warning("Callbacks config is empty.") - return callbacks - - if not isinstance(callbacks_cfg, DictConfig): - raise TypeError("Callbacks config must be a DictConfig!") - - for _, cb_conf in callbacks_cfg.items(): - if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info(f"Instantiating callback <{cb_conf._target_}>") - callbacks.append(hydra.utils.instantiate(cb_conf)) - - return callbacks - - -def instantiate_loggers(logger_cfg: DictConfig) -> List[LightningLoggerBase]: - """Instantiates loggers from config.""" - logger: List[LightningLoggerBase] = [] - - if not logger_cfg: - log.warning("Logger config is empty.") - return logger - - if not isinstance(logger_cfg, DictConfig): - raise TypeError("Logger config must be a DictConfig!") - - for _, lg_conf in logger_cfg.items(): - if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info(f"Instantiating logger <{lg_conf._target_}>") - logger.append(hydra.utils.instantiate(lg_conf)) +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. - return logger + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[dict, dict]: -@rank_zero_only -def log_hyperparameters(object_dict: Dict[str, Any]) -> None: - """Controls which config parts are saved by lightning loggers. + ... - Additionally saves: - - Number of model parameters + return metric_dict, object_dict + ``` """ - hparams = {} - - cfg = object_dict["cfg"] - model = object_dict["model"] - trainer = object_dict["trainer"] - - if not trainer.logger: - log.warning("Logger not found! Skipping hyperparameter logging...") - return + def wrap(cfg: DictConfig): + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) - hparams["model"] = cfg["model"] + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") - # save number of model parameters - summary = summarize(model) + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex - hparams["model/params/total"] = summary.total_parameters - hparams["model/params/trainable"] = summary.trainable_parameters - hparams["model/params/non_trainable"] = summary.total_parameters - summary.trainable_parameters + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") - hparams["datamodule"] = cfg["datamodule"] - hparams["trainer"] = cfg["trainer"] + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") + if wandb.run: + log.info("Closing wandb!") + wandb.finish() - hparams["task_name"] = cfg.get("task_name") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") + return metric_dict, object_dict - # send hparams to all loggers - trainer.logger.log_hyperparams(hparams) + return wrap def get_metric_value(metric_dict: dict, metric_name: str) -> float: @@ -207,19 +125,6 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float: return metric_value -def close_loggers() -> None: - """Makes sure all loggers closed properly (prevents logging failure during multirun).""" - - log.info("Closing loggers...") - - if find_spec("wandb"): # if wandb is installed - import wandb - - if wandb.run: - log.info("Closing wandb!") - wandb.finish() - - def get_resume_checkpoint(config: DictConfig) -> Tuple[DictConfig]: """Resume a task from an existing checkpoint along with the config.""" diff --git a/pyproject.toml b/pyproject.toml index 941807b0..5ed8161b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,10 +11,11 @@ authors = [ requires-python = ">=3.9" dependencies = [ - "torch ~= 1.13.1", - "torchvision ~= 0.14.1", - "pytorch-lightning ~= 1.6.5", - "torchmetrics == 0.6.0", + "torch ~= 2.0.1", + "torchvision ~= 0.15.2", + "lightning ~= 2.0.2", + "pydantic == 1.10.11", # https://github.com/Lightning-AI/lightning/pull/18022/files + "torchmetrics == 1.0.0", "numpy == 1.23.5", # https://github.com/pytorch/pytorch/issues/91516 # --------- hydra --------- # diff --git a/tests/conftest.py b/tests/conftest.py index 220f8734..22c42b22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,7 +29,7 @@ # Loads the configuration file from a given experiment def get_cfg(experiment): - with initialize(version_base="1.2", config_path="../mart/configs"): + with initialize(version_base="1.3", config_path="../mart/configs"): params = "experiment=" + experiment cfg = compose(config_name="lightning.yaml", return_hydra_config=True, overrides=[params]) return cfg diff --git a/tests/helpers/package_available.py b/tests/helpers/package_available.py index 15630f19..614778fe 100644 --- a/tests/helpers/package_available.py +++ b/tests/helpers/package_available.py @@ -1,7 +1,7 @@ import platform import pkg_resources -from pytorch_lightning.utilities.xla_device import XLADeviceUtils +from lightning.fabric.accelerators import TPUAccelerator def _package_available(package_name: str) -> bool: @@ -12,7 +12,7 @@ def _package_available(package_name: str) -> bool: return False -_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() +_TPU_AVAILABLE = TPUAccelerator.is_available() _IS_WINDOWS = platform.system() == "Windows" diff --git a/tests/test_adversary.py b/tests/test_adversary.py index bf9ff9c0..4302631b 100644 --- a/tests/test_adversary.py +++ b/tests/test_adversary.py @@ -9,7 +9,7 @@ import pytest import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.exceptions import MisconfigurationException from torch.optim import SGD import mart @@ -355,7 +355,8 @@ def test_configure_gradient_clipping(): # We need to mock a trainer since LightningModule does some checks adversary.trainer = Mock(gradient_clip_val=1.0, gradient_clip_algorithm="norm") - adversary.configure_gradient_clipping(optimizer, 0) + # Keep the gradient_clip_val the same in Trainer and configure_gradient_clipping() + adversary.configure_gradient_clipping(optimizer, 1.0) # Once for each parameter in the optimizer assert gradient_modifier.call_count == 2 diff --git a/tests/test_configs.py b/tests/test_configs.py index bcd9a4da..9817ea3a 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -13,13 +13,13 @@ @patch("torchvision.datasets.imagenet.ImageNet.__init__") -@patch("mart.datamodules.coco.CocoDetection_.__init__") +@patch("mart.data.coco.CocoDetection_.__init__") @patch("torchvision.datasets.CIFAR10.__init__") def test_experiment_config( mock_cifar10: Mock, mock_coco: Mock, mock_imagenet: Mock, cfg_experiment: DictConfig ): assert cfg_experiment - assert cfg_experiment.datamodule + assert cfg_experiment.data assert cfg_experiment.model assert cfg_experiment.trainer @@ -30,6 +30,6 @@ def test_experiment_config( HydraConfig().set_config(cfg_experiment) - hydra.utils.instantiate(cfg_experiment.datamodule) + hydra.utils.instantiate(cfg_experiment.data) hydra.utils.instantiate(cfg_experiment.model) hydra.utils.instantiate(cfg_experiment.trainer) diff --git a/tests/test_experiments.py b/tests/test_experiments.py index 59a058c8..87e918e4 100644 --- a/tests/test_experiments.py +++ b/tests/test_experiments.py @@ -40,9 +40,9 @@ def classification_cfg() -> Dict: "++trainer.fast_dev_run=3", ], "datamodel": [ - "datamodule=dummy_classification", - "datamodule.ims_per_batch=2", - "datamodule.num_workers=0", + "data=dummy_classification", + "data.ims_per_batch=2", + "data.num_workers=0", ], } yield cfg @@ -63,7 +63,7 @@ def coco_cfg(tmp_path) -> Dict: ], "datamodel": [ "++paths.data_dir=" + str(tmp_path), - "datamodule.num_workers=0", + "data.num_workers=0", ], } yield cfg @@ -82,8 +82,10 @@ def test_cifar10_cnn_adv_experiment(classification_cfg, tmp_path): "hydra.sweep.dir=" + str(tmp_path), "model.modules.input_adv_test.max_iters=10", "optimized_metric=training_metrics/acc", - "++datamodule.train_dataset.image_size=[3,32,32]", - "++datamodule.train_dataset.num_classes=10", + "++data.train_dataset.image_size=[3,32,32]", + "++data.train_dataset.num_classes=10", + # The accuracy metric may require data.num_classes. + "+data.num_classes=${data.train_dataset.num_classes}", ] + overrides run_sh_command(command) @@ -98,8 +100,10 @@ def test_cifar10_cnn_experiment(classification_cfg, tmp_path): "experiment=CIFAR10_CNN", "hydra.sweep.dir=" + str(tmp_path), "optimized_metric=training_metrics/acc", - "++datamodule.train_dataset.image_size=[3,32,32]", - "++datamodule.train_dataset.num_classes=10", + "++data.train_dataset.image_size=[3,32,32]", + "++data.train_dataset.num_classes=10", + # The accuracy metric may require data.num_classes. + "+data.num_classes=${data.train_dataset.num_classes}", ] + overrides run_sh_command(command) @@ -116,8 +120,10 @@ def test_imagenet_timm_experiment(classification_cfg, tmp_path): "hydra.sweep.dir=" + str(tmp_path), "++trainer.precision=32", "optimized_metric=training_metrics/acc", - "++datamodule.train_dataset.image_size=[3,469,387]", - "++datamodule.train_dataset.num_classes=200", + "++data.train_dataset.image_size=[3,469,387]", + "++data.train_dataset.num_classes=1000", + # The accuracy metric may require data.num_classes. + "+data.num_classes=${data.train_dataset.num_classes}", ] + overrides run_sh_command(command) @@ -189,12 +195,13 @@ def test_resume(tmpdir): "\n".join( [ "- experiment=CIFAR10_CNN", - "- datamodule=dummy_classification", - "- datamodule.ims_per_batch=2", - "- datamodule.num_workers=0", - "- datamodule.train_dataset.size=2", - "- datamodule.train_dataset.image_size=[3,32,32]", - "- datamodule.train_dataset.num_classes=10", + "- data=dummy_classification", + "- data.ims_per_batch=2", + "- data.num_workers=0", + "- data.train_dataset.size=2", + "- data.train_dataset.image_size=[3,32,32]", + "- data.train_dataset.num_classes=10", + "- +data.num_classes=10", "- fit=false", # Don't train or test the model, because the checkpoint is invalid. "- test=false", "- optimized_metric=null", # No metric to retrieve. diff --git a/tests/test_perturber.py b/tests/test_perturber.py index bb6c204b..af627359 100644 --- a/tests/test_perturber.py +++ b/tests/test_perturber.py @@ -9,7 +9,7 @@ import pytest import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.exceptions import MisconfigurationException from torch.optim import SGD import mart