From 38313bcbfb075039f5382e7a38fd67a2ae142e2b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 May 2023 16:04:10 +0000 Subject: [PATCH 01/22] Update torch requirement from ~=1.13.1 to ~=2.0.1 --- updated-dependencies: - dependency-name: torch dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6440af7b..06abcdf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ requires-python = ">=3.9" dependencies = [ - "torch ~= 1.13.1", + "torch ~= 2.0.1", "torchvision ~= 0.14.1", "pytorch-lightning ~= 1.6.5", "torchmetrics == 0.6.0", From 4faa3180ee93abf313147b2168dc1cc93c3a653f Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Jun 2023 11:57:56 -0700 Subject: [PATCH 02/22] Update to a newer torchvision that matches torch 2.0.1. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 06abcdf0..86580df1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ requires-python = ">=3.9" dependencies = [ "torch ~= 2.0.1", - "torchvision ~= 0.14.1", + "torchvision ~= 0.15.2", "pytorch-lightning ~= 1.6.5", "torchmetrics == 0.6.0", From 9df5f52c59d288fb9492c659905ecdf143f77d23 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Fri, 2 Jun 2023 12:02:02 -0700 Subject: [PATCH 03/22] Upgrade pytorch-lightning and torchmetrics. --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 86580df1..8be2788b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,8 +13,8 @@ requires-python = ">=3.9" dependencies = [ "torch ~= 2.0.1", "torchvision ~= 0.15.2", - "pytorch-lightning ~= 1.6.5", - "torchmetrics == 0.6.0", + "pytorch-lightning ~= 2.0.2", + "torchmetrics == 0.11.4", # --------- hydra --------- # "hydra-core ~= 1.2.0", From 098daf88d899bebe05adacd62a996bfc976fca3e Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 16:35:29 -0700 Subject: [PATCH 04/22] Update pre-commit. --- .pre-commit-config.yaml | 54 +++++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 31a624f2..3e0bfeb5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ default_language_version: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: # list of supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace @@ -19,7 +19,7 @@ repos: # python code formatting - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 23.1.0 hooks: - id: black args: [--line-length, "99"] @@ -33,55 +33,56 @@ repos: # python upgrading syntax to newer version - repo: https://github.com/asottile/pyupgrade - rev: v2.32.1 + rev: v3.3.1 hooks: - id: pyupgrade args: [--py38-plus] # python docstring formatting - repo: https://github.com/myint/docformatter - rev: v1.4 + rev: v1.5.1 hooks: - id: docformatter args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] # python check (PEP8), programming errors and code complexity - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 6.0.0 hooks: - id: flake8 # ignore E203 because black is used for formatting. 
args: [ - "--ignore", - "E203,E501,F401,F403,F841,W504", + "--extend-ignore", + "E203,E402,E501,F401,F403,F841", "--exclude", "logs/*,data/*", ] # python security linter - repo: https://github.com/PyCQA/bandit - rev: "1.7.1" + rev: "1.7.5" hooks: - id: bandit args: ["-s", "B101"] # yaml formatting - repo: https://github.com/pre-commit/mirrors-prettier - rev: v2.7.1 + rev: v3.0.0-alpha.6 hooks: - id: prettier types: [yaml] + exclude: "environment.yaml" - # jupyter notebook cell output clearing - - repo: https://github.com/kynan/nbstripout - rev: 0.5.0 + # shell scripts linter + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.9.0.2 hooks: - - id: nbstripout + - id: shellcheck # md formatting - repo: https://github.com/executablebooks/mdformat - rev: 0.7.14 + rev: 0.7.16 hooks: - id: mdformat args: ["--number"] @@ -94,9 +95,30 @@ repos: # word spelling linter - repo: https://github.com/codespell-project/codespell - rev: v2.1.0 + rev: v2.2.4 hooks: - id: codespell args: - - --skip=logs/**,data/** + - --skip=logs/**,data/**,*.ipynb - --ignore-words-list=abc,def,gard + + # jupyter notebook cell output clearing + - repo: https://github.com/kynan/nbstripout + rev: 0.6.1 + hooks: + - id: nbstripout + + # jupyter notebook linting + - repo: https://github.com/nbQA-dev/nbQA + rev: 1.6.3 + hooks: + - id: nbqa-black + args: ["--line-length=99"] + - id: nbqa-isort + args: ["--profile=black"] + - id: nbqa-flake8 + args: + [ + "--extend-ignore=E203,E402,E501,F401,F841", + "--exclude=logs/*,data/*", + ] From 05bbf0d79f1ca5a8150e9511179c6280e787e5c1 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 16:44:01 -0700 Subject: [PATCH 05/22] datamodule->data --- .../{datamodule => data}/armory_carla_over_objdet.yaml | 0 .../armory_carla_over_objdet_perturbable_mask.yaml | 0 mart/configs/{datamodule => data}/cifar10.yaml | 0 mart/configs/{datamodule => data}/coco.yaml | 8 ++++---- .../{datamodule => data}/coco_perturbable_mask.yaml | 0 mart/configs/{datamodule => data}/default.yaml | 2 +- .../{datamodule => data}/dummy_classification.yaml | 0 mart/configs/{datamodule => data}/imagenet.yaml | 0 mart/{datamodules => data}/__init__.py | 0 mart/{datamodules => data}/coco.py | 0 mart/{datamodules => data}/modular.py | 0 tests/test_configs.py | 6 +++--- 12 files changed, 8 insertions(+), 8 deletions(-) rename mart/configs/{datamodule => data}/armory_carla_over_objdet.yaml (100%) rename mart/configs/{datamodule => data}/armory_carla_over_objdet_perturbable_mask.yaml (100%) rename mart/configs/{datamodule => data}/cifar10.yaml (100%) rename mart/configs/{datamodule => data}/coco.yaml (91%) rename mart/configs/{datamodule => data}/coco_perturbable_mask.yaml (100%) rename mart/configs/{datamodule => data}/default.yaml (77%) rename mart/configs/{datamodule => data}/dummy_classification.yaml (100%) rename mart/configs/{datamodule => data}/imagenet.yaml (100%) rename mart/{datamodules => data}/__init__.py (100%) rename mart/{datamodules => data}/coco.py (100%) rename mart/{datamodules => data}/modular.py (100%) diff --git a/mart/configs/datamodule/armory_carla_over_objdet.yaml b/mart/configs/data/armory_carla_over_objdet.yaml similarity index 100% rename from mart/configs/datamodule/armory_carla_over_objdet.yaml rename to mart/configs/data/armory_carla_over_objdet.yaml diff --git a/mart/configs/datamodule/armory_carla_over_objdet_perturbable_mask.yaml b/mart/configs/data/armory_carla_over_objdet_perturbable_mask.yaml similarity index 
100% rename from mart/configs/datamodule/armory_carla_over_objdet_perturbable_mask.yaml rename to mart/configs/data/armory_carla_over_objdet_perturbable_mask.yaml diff --git a/mart/configs/datamodule/cifar10.yaml b/mart/configs/data/cifar10.yaml similarity index 100% rename from mart/configs/datamodule/cifar10.yaml rename to mart/configs/data/cifar10.yaml diff --git a/mart/configs/datamodule/coco.yaml b/mart/configs/data/coco.yaml similarity index 91% rename from mart/configs/datamodule/coco.yaml rename to mart/configs/data/coco.yaml index a4ec3403..978daaaf 100644 --- a/mart/configs/datamodule/coco.yaml +++ b/mart/configs/data/coco.yaml @@ -2,7 +2,7 @@ defaults: - default.yaml train_dataset: - _target_: mart.datamodules.coco.CocoDetection + _target_: mart.data.coco.CocoDetection root: ${paths.data_dir}/coco/train2017 annFile: ${paths.data_dir}/coco/annotations/instances_train2017.json transforms: @@ -24,7 +24,7 @@ train_dataset: quant_max: 255 val_dataset: - _target_: mart.datamodules.coco.CocoDetection + _target_: mart.data.coco.CocoDetection root: ${paths.data_dir}/coco/val2017 annFile: ${paths.data_dir}/coco/annotations/instances_val2017.json transforms: @@ -44,7 +44,7 @@ val_dataset: quant_max: 255 test_dataset: - _target_: mart.datamodules.coco.CocoDetection + _target_: mart.data.coco.CocoDetection root: ${paths.data_dir}/coco/val2017 annFile: ${paths.data_dir}/coco/annotations/instances_val2017.json transforms: @@ -66,4 +66,4 @@ test_dataset: num_workers: 2 collate_fn: _target_: hydra.utils.get_method - path: mart.datamodules.coco.collate_fn + path: mart.data.coco.collate_fn diff --git a/mart/configs/datamodule/coco_perturbable_mask.yaml b/mart/configs/data/coco_perturbable_mask.yaml similarity index 100% rename from mart/configs/datamodule/coco_perturbable_mask.yaml rename to mart/configs/data/coco_perturbable_mask.yaml diff --git a/mart/configs/datamodule/default.yaml b/mart/configs/data/default.yaml similarity index 77% rename from mart/configs/datamodule/default.yaml rename to mart/configs/data/default.yaml index d3da6d4c..a82df951 100644 --- a/mart/configs/datamodule/default.yaml +++ b/mart/configs/data/default.yaml @@ -1,4 +1,4 @@ -_target_: mart.datamodules.LitDataModule +_target_: mart.data.LitDataModule # _convert_: all train_dataset: ??? 
diff --git a/mart/configs/datamodule/dummy_classification.yaml b/mart/configs/data/dummy_classification.yaml similarity index 100% rename from mart/configs/datamodule/dummy_classification.yaml rename to mart/configs/data/dummy_classification.yaml diff --git a/mart/configs/datamodule/imagenet.yaml b/mart/configs/data/imagenet.yaml similarity index 100% rename from mart/configs/datamodule/imagenet.yaml rename to mart/configs/data/imagenet.yaml diff --git a/mart/datamodules/__init__.py b/mart/data/__init__.py similarity index 100% rename from mart/datamodules/__init__.py rename to mart/data/__init__.py diff --git a/mart/datamodules/coco.py b/mart/data/coco.py similarity index 100% rename from mart/datamodules/coco.py rename to mart/data/coco.py diff --git a/mart/datamodules/modular.py b/mart/data/modular.py similarity index 100% rename from mart/datamodules/modular.py rename to mart/data/modular.py diff --git a/tests/test_configs.py b/tests/test_configs.py index bcd9a4da..9817ea3a 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -13,13 +13,13 @@ @patch("torchvision.datasets.imagenet.ImageNet.__init__") -@patch("mart.datamodules.coco.CocoDetection_.__init__") +@patch("mart.data.coco.CocoDetection_.__init__") @patch("torchvision.datasets.CIFAR10.__init__") def test_experiment_config( mock_cifar10: Mock, mock_coco: Mock, mock_imagenet: Mock, cfg_experiment: DictConfig ): assert cfg_experiment - assert cfg_experiment.datamodule + assert cfg_experiment.data assert cfg_experiment.model assert cfg_experiment.trainer @@ -30,6 +30,6 @@ def test_experiment_config( HydraConfig().set_config(cfg_experiment) - hydra.utils.instantiate(cfg_experiment.datamodule) + hydra.utils.instantiate(cfg_experiment.data) hydra.utils.instantiate(cfg_experiment.model) hydra.utils.instantiate(cfg_experiment.trainer) From 395692d9b420ca4bd5947217a550525b5caa15e9 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 16:48:15 -0700 Subject: [PATCH 06/22] Add data.num_classes for the Accuracy metric of classification. --- mart/configs/data/cifar10.yaml | 3 ++ mart/configs/data/imagenet.yaml | 3 ++ mart/configs/metric/accuracy.yaml | 3 +- mart/data/modular.py | 2 ++ tests/test_experiments.py | 53 +++++++++++++++++++------------ 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/mart/configs/data/cifar10.yaml b/mart/configs/data/cifar10.yaml index ec2e795d..7f8b5b80 100644 --- a/mart/configs/data/cifar10.yaml +++ b/mart/configs/data/cifar10.yaml @@ -55,3 +55,6 @@ test_dataset: ${.val_dataset} num_workers: 4 collate_fn: null + +# The accuracy metric may require this value. +num_classes: 10 diff --git a/mart/configs/data/imagenet.yaml b/mart/configs/data/imagenet.yaml index d8253cbc..f0967743 100644 --- a/mart/configs/data/imagenet.yaml +++ b/mart/configs/data/imagenet.yaml @@ -46,3 +46,6 @@ val_dataset: quant_max: 255 test_dataset: ${.val_dataset} + +# The accuracy metric may require this value. 
+num_classes: 1000 diff --git a/mart/configs/metric/accuracy.yaml b/mart/configs/metric/accuracy.yaml index d1e3709f..a184aa34 100644 --- a/mart/configs/metric/accuracy.yaml +++ b/mart/configs/metric/accuracy.yaml @@ -5,7 +5,8 @@ training_metrics: _convert_: partial # metrics must be a dict metrics: acc: - _target_: torchmetrics.Accuracy + _target_: torchmetrics.classification.MulticlassAccuracy + num_classes: ${data.num_classes} compute_on_step: false validation_metrics: ${.training_metrics} diff --git a/mart/data/modular.py b/mart/data/modular.py index df99fc3e..65cbb888 100644 --- a/mart/data/modular.py +++ b/mart/data/modular.py @@ -29,6 +29,8 @@ def __init__( ims_per_batch=1, world_size=1, pin_memory=False, + # Classification metrics may require the value in config. + num_classes=None, ): super().__init__() diff --git a/tests/test_experiments.py b/tests/test_experiments.py index d128c1df..148c37f9 100644 --- a/tests/test_experiments.py +++ b/tests/test_experiments.py @@ -61,9 +61,9 @@ def classification_cfg() -> Dict: "++trainer.fast_dev_run=3", ], "datamodel": [ - "datamodule=dummy_classification", - "datamodule.ims_per_batch=2", - "datamodule.num_workers=0", + "data=dummy_classification", + "data.ims_per_batch=2", + "data.num_workers=0", ], } yield cfg @@ -84,7 +84,7 @@ def coco_cfg(tmp_path) -> Dict: ], "datamodel": [ "++paths.data_dir=" + str(tmp_path), - "datamodule.num_workers=0", + "data.num_workers=0", ], } yield cfg @@ -104,7 +104,7 @@ def carla_cfg(tmp_path) -> Dict: ], "datamodel": [ "++paths.data_dir=" + str(tmp_path), - "datamodule.num_workers=0", + "data.num_workers=0", ], } yield cfg @@ -123,8 +123,10 @@ def test_cifar10_cnn_adv_experiment(classification_cfg, tmp_path): "hydra.sweep.dir=" + str(tmp_path), "model.modules.input_adv_test.max_iters=10", "optimized_metric=training_metrics/acc", - "++datamodule.train_dataset.image_size=[3,32,32]", - "++datamodule.train_dataset.num_classes=10", + "++data.train_dataset.image_size=[3,32,32]", + "++data.train_dataset.num_classes=10", + # The accuracy metric may require data.num_classes. + "+data.num_classes=${data.train_dataset.num_classes}", ] + overrides run_sh_command(command) @@ -139,8 +141,10 @@ def test_cifar10_cnn_experiment(classification_cfg, tmp_path): "experiment=CIFAR10_CNN", "hydra.sweep.dir=" + str(tmp_path), "optimized_metric=training_metrics/acc", - "++datamodule.train_dataset.image_size=[3,32,32]", - "++datamodule.train_dataset.num_classes=10", + "++data.train_dataset.image_size=[3,32,32]", + "++data.train_dataset.num_classes=10", + # The accuracy metric may require data.num_classes. + "+data.num_classes=${data.train_dataset.num_classes}", ] + overrides run_sh_command(command) @@ -154,8 +158,10 @@ def test_cifar10_cnn_autoattack_experiment(classification_cfg, tmp_path): "-m", "experiment=CIFAR10_CNN", "hydra.sweep.dir=" + str(tmp_path), - "++datamodule.train_dataset.image_size=[3,32,32]", - "++datamodule.train_dataset.num_classes=10", + "++data.train_dataset.image_size=[3,32,32]", + "++data.train_dataset.num_classes=10", + # The accuracy metric may require data.num_classes. 
+ "+data.num_classes=${data.train_dataset.num_classes}", "fit=false", "+attack@model.modules.input_adv_test=classification_autoattack", '+model.modules.input_adv_test.adversary.partial.device="cpu"', @@ -175,8 +181,10 @@ def test_cifar10_robust_bench_experiment(classification_cfg, tmp_path): "hydra.sweep.dir=" + str(tmp_path), "+attack@model.modules.input_adv_test=classification_eps8_pgd10_step1", "optimized_metric=training_metrics/acc", - "++datamodule.train_dataset.image_size=[3,32,32]", - "++datamodule.train_dataset.num_classes=10", + "++data.train_dataset.image_size=[3,32,32]", + "++data.train_dataset.num_classes=10", + # The accuracy metric may require data.num_classes. + "+data.num_classes=${data.train_dataset.num_classes}", ] + overrides run_sh_command(command) @@ -193,8 +201,10 @@ def test_imagenet_timm_experiment(classification_cfg, tmp_path): "hydra.sweep.dir=" + str(tmp_path), "++trainer.precision=32", "optimized_metric=training_metrics/acc", - "++datamodule.train_dataset.image_size=[3,469,387]", - "++datamodule.train_dataset.num_classes=200", + "++data.train_dataset.image_size=[3,469,387]", + "++data.train_dataset.num_classes=1000", + # The accuracy metric may require data.num_classes. + "+data.num_classes=${data.train_dataset.num_classes}", ] + overrides run_sh_command(command) @@ -282,12 +292,13 @@ def test_resume(tmpdir): "\n".join( [ "- experiment=CIFAR10_RobustBench", - "- datamodule=dummy_classification", - "- datamodule.ims_per_batch=2", - "- datamodule.num_workers=0", - "- datamodule.train_dataset.size=2", - "- datamodule.train_dataset.image_size=[3,32,32]", - "- datamodule.train_dataset.num_classes=10", + "- data=dummy_classification", + "- data.ims_per_batch=2", + "- data.num_workers=0", + "- data.train_dataset.size=2", + "- data.train_dataset.image_size=[3,32,32]", + "- data.train_dataset.num_classes=10", + "- +data.num_classes=10", "- fit=false", # Don't train or test the model, because the checkpoint is invalid. "- test=false", "- optimized_metric=null", # No metric to retrieve. From 4c7c04bf15b62b05e19c7a904d8b60231892f659 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 16:50:49 -0700 Subject: [PATCH 07/22] datamodule->data extra. 
--- mart/__init__.py | 2 +- mart/__main__.py | 3 +-- .../ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml | 2 +- mart/configs/experiment/CIFAR10_CNN.yaml | 4 ++-- mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml | 4 ++-- mart/configs/experiment/COCO_TorchvisionFasterRCNN_Adv.yaml | 2 +- mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml | 4 ++-- mart/configs/experiment/ImageNet_Timm.yaml | 4 ++-- mart/configs/hparams_search/mnist_optuna.yaml | 2 +- mart/transforms/torchvision_ref.py | 2 +- 10 files changed, 14 insertions(+), 15 deletions(-) diff --git a/mart/__init__.py b/mart/__init__.py index 85181105..7ee4c290 100644 --- a/mart/__init__.py +++ b/mart/__init__.py @@ -1,7 +1,7 @@ import importlib from mart import attack as attack -from mart import datamodules as datamodules +from mart import data as data from mart import models as models from mart import nn as nn from mart import optim as optim diff --git a/mart/__main__.py b/mart/__main__.py index 42d71627..f8d19998 100644 --- a/mart/__main__.py +++ b/mart/__main__.py @@ -31,8 +31,7 @@ @hydra.main(version_base="1.2", config_path=config_path, config_name="lightning.yaml") def main(cfg: DictConfig) -> float: - - if cfg.resume is None and ("datamodule" not in cfg or "model" not in cfg): + if cfg.resume is None and ("data" not in cfg or "model" not in cfg): log.fatal("") log.fatal("Please specify an experiment to run, e.g.") log.fatal( diff --git a/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml index 8d512f7d..9c601e04 100644 --- a/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/ArmoryCarlaOverObjDet_TorchvisionFasterRCNN.yaml @@ -3,7 +3,7 @@ defaults: - COCO_TorchvisionFasterRCNN - override /model/modules@model.modules.preprocessor: tuple_tensorizer_normalizer - - override /datamodule: armory_carla_over_objdet_perturbable_mask + - override /data: armory_carla_over_objdet_perturbable_mask task_name: "ArmoryCarlaOverObjDet_TorchvisionFasterRCNN" tags: ["regular_training"] diff --git a/mart/configs/experiment/CIFAR10_CNN.yaml b/mart/configs/experiment/CIFAR10_CNN.yaml index 6885d90b..ea8e905d 100644 --- a/mart/configs/experiment/CIFAR10_CNN.yaml +++ b/mart/configs/experiment/CIFAR10_CNN.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - override /datamodule: cifar10 + - override /data: cifar10 - override /model: classifier_cifar10_cnn - override /metric: accuracy - override /optimization: super_convergence @@ -22,7 +22,7 @@ trainer: max_steps: 5850 precision: 32 -datamodule: +data: ims_per_batch: 128 world_size: 1 num_workers: 8 diff --git a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml index a1494d63..962bc5f7 100644 --- a/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml +++ b/mart/configs/experiment/COCO_TorchvisionFasterRCNN.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - override /datamodule: coco + - override /data: coco - override /model: torchvision_faster_rcnn - override /metric: average_precision - override /optimization: super_convergence @@ -23,7 +23,7 @@ trainer: # FIXME: "nms_kernel" not implemented for 'BFloat16', torch.ops.torchvision.nms(). 
precision: 32 -datamodule: +data: ims_per_batch: 2 world_size: 1 diff --git a/mart/configs/experiment/COCO_TorchvisionFasterRCNN_Adv.yaml b/mart/configs/experiment/COCO_TorchvisionFasterRCNN_Adv.yaml index 398394bf..ee54403a 100644 --- a/mart/configs/experiment/COCO_TorchvisionFasterRCNN_Adv.yaml +++ b/mart/configs/experiment/COCO_TorchvisionFasterRCNN_Adv.yaml @@ -3,7 +3,7 @@ defaults: - COCO_TorchvisionFasterRCNN - /attack@model.modules.input_adv_test: object_detection_mask_adversary - - override /datamodule: coco_perturbable_mask + - override /data: coco_perturbable_mask task_name: "COCO_TorchvisionFasterRCNN_Adv" tags: ["adv"] diff --git a/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml b/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml index 39ab8fbc..f2f278fa 100644 --- a/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml +++ b/mart/configs/experiment/COCO_TorchvisionRetinaNet.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - override /datamodule: coco + - override /data: coco - override /model: torchvision_retinanet - override /metric: average_precision - override /optimization: super_convergence @@ -22,7 +22,7 @@ trainer: max_steps: 351798 precision: 16 -datamodule: +data: ims_per_batch: 2 world_size: 1 diff --git a/mart/configs/experiment/ImageNet_Timm.yaml b/mart/configs/experiment/ImageNet_Timm.yaml index 9173d5a5..d1022df3 100644 --- a/mart/configs/experiment/ImageNet_Timm.yaml +++ b/mart/configs/experiment/ImageNet_Timm.yaml @@ -1,7 +1,7 @@ # @package _global_ defaults: - - override /datamodule: imagenet + - override /data: imagenet - override /model: classifier_timm - override /metric: accuracy - override /optimization: super_convergence @@ -20,7 +20,7 @@ trainer: max_steps: 140625 precision: 16 -datamodule: +data: ims_per_batch: 128 world_size: 1 num_workers: 8 diff --git a/mart/configs/hparams_search/mnist_optuna.yaml b/mart/configs/hparams_search/mnist_optuna.yaml index f328d07a..1391183e 100644 --- a/mart/configs/hparams_search/mnist_optuna.yaml +++ b/mart/configs/hparams_search/mnist_optuna.yaml @@ -46,7 +46,7 @@ hydra: # define hyperparameter search space params: model.optimizer.lr: interval(0.0001, 0.1) - datamodule.batch_size: choice(32, 64, 128, 256) + data.batch_size: choice(32, 64, 128, 256) model.net.lin1_size: choice(64, 128, 256) model.net.lin2_size: choice(64, 128, 256) model.net.lin3_size: choice(32, 64, 128, 256) diff --git a/mart/transforms/torchvision_ref.py b/mart/transforms/torchvision_ref.py index 0fa5eb17..1d96c785 100644 --- a/mart/transforms/torchvision_ref.py +++ b/mart/transforms/torchvision_ref.py @@ -35,7 +35,7 @@ def convert_coco_poly_to_mask(segmentations, height, width): # Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/coco_utils.py#L47 -# Adapted to mart.datamodules.coco.CocoDetection by adding the "file_name" field. +# Adapted to mart.data.coco.CocoDetection by adding the "file_name" field. 
class ConvertCocoPolysToMask: def __call__(self, image, target): w, h = F.get_image_size(image) From a0ba7ba59752829e7c6c73da2a0bbae9b1b1bb7c Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 16:55:26 -0700 Subject: [PATCH 08/22] pytorch_lightning->lightning.pytorch --- mart/attack/perturber.py | 2 +- mart/callbacks/eval_mode.py | 2 +- mart/callbacks/no_grad_mode.py | 2 +- mart/callbacks/visualizer.py | 2 +- mart/configs/callbacks/early_stopping.yaml | 6 ++---- mart/configs/callbacks/lr_monitor.yaml | 2 +- mart/configs/callbacks/model_checkpoint.yaml | 6 ++---- mart/configs/callbacks/model_summary.yaml | 6 ++---- mart/configs/callbacks/rich_progress_bar.yaml | 6 ++---- mart/configs/debug/profiler.yaml | 6 +++--- mart/configs/logger/comet.yaml | 2 +- mart/configs/logger/csv.yaml | 2 +- mart/configs/logger/mlflow.yaml | 2 +- mart/configs/logger/neptune.yaml | 2 +- mart/configs/logger/tensorboard.yaml | 2 +- mart/tasks/lightning.py | 12 ++++++------ mart/utils/pylogger.py | 2 +- mart/utils/silent.py | 5 ++++- tests/helpers/package_available.py | 4 ++-- tests/test_adversary.py | 5 +++-- tests/test_perturber.py | 2 +- 21 files changed, 38 insertions(+), 42 deletions(-) diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py index 29df3059..75c565ad 100644 --- a/mart/attack/perturber.py +++ b/mart/attack/perturber.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Iterable import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.exceptions import MisconfigurationException from .projector import Projector diff --git a/mart/callbacks/eval_mode.py b/mart/callbacks/eval_mode.py index be3b6397..44597929 100644 --- a/mart/callbacks/eval_mode.py +++ b/mart/callbacks/eval_mode.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback __all__ = ["AttackInEvalMode"] diff --git a/mart/callbacks/no_grad_mode.py b/mart/callbacks/no_grad_mode.py index cfb90ead..5b3b1446 100644 --- a/mart/callbacks/no_grad_mode.py +++ b/mart/callbacks/no_grad_mode.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback __all__ = ["ModelParamsNoGrad"] diff --git a/mart/callbacks/visualizer.py b/mart/callbacks/visualizer.py index 3354321e..39409143 100644 --- a/mart/callbacks/visualizer.py +++ b/mart/callbacks/visualizer.py @@ -6,7 +6,7 @@ import os -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback from torchvision.transforms import ToPILImage __all__ = ["PerturbedImageVisualizer"] diff --git a/mart/configs/callbacks/early_stopping.yaml b/mart/configs/callbacks/early_stopping.yaml index 20ed2671..c826c8d5 100644 --- a/mart/configs/callbacks/early_stopping.yaml +++ b/mart/configs/callbacks/early_stopping.yaml @@ -1,9 +1,7 @@ -# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html -# Monitor a metric and stop training when it stops improving. -# Look at the above link for more detailed information. early_stopping: - _target_: pytorch_lightning.callbacks.EarlyStopping + _target_: lightning.pytorch.callbacks.EarlyStopping monitor: ??? # quantity to be monitored, must be specified !!! min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement patience: 3 # number of checks with no improvement after which training will be stopped diff --git a/mart/configs/callbacks/lr_monitor.yaml b/mart/configs/callbacks/lr_monitor.yaml index a23b28ee..f4bad6f8 100644 --- a/mart/configs/callbacks/lr_monitor.yaml +++ b/mart/configs/callbacks/lr_monitor.yaml @@ -1,3 +1,3 @@ lr_monitor: - _target_: pytorch_lightning.callbacks.LearningRateMonitor + _target_: lightning.pytorch.callbacks.LearningRateMonitor logging_interval: "step" diff --git a/mart/configs/callbacks/model_checkpoint.yaml b/mart/configs/callbacks/model_checkpoint.yaml index 3d4503b5..40581994 100644 --- a/mart/configs/callbacks/model_checkpoint.yaml +++ b/mart/configs/callbacks/model_checkpoint.yaml @@ -1,9 +1,7 @@ -# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html -# Save the model periodically by monitoring a quantity. -# Look at the above link for more detailed information. model_checkpoint: - _target_: pytorch_lightning.callbacks.ModelCheckpoint + _target_: lightning.pytorch.callbacks.ModelCheckpoint dirpath: "${paths.output_dir}/checkpoints/" # directory to save the model file filename: "epoch_{epoch:03d}" # checkpoint filename monitor: ??? # name of the logged metric which determines when model is improving diff --git a/mart/configs/callbacks/model_summary.yaml b/mart/configs/callbacks/model_summary.yaml index 04da98d3..b75981d8 100644 --- a/mart/configs/callbacks/model_summary.yaml +++ b/mart/configs/callbacks/model_summary.yaml @@ -1,7 +1,5 @@ -# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html -# Generates a summary of all layers in a LightningModule with rich text formatting. -# Look at the above link for more detailed information. model_summary: - _target_: pytorch_lightning.callbacks.RichModelSummary + _target_: lightning.pytorch.callbacks.RichModelSummary max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/mart/configs/callbacks/rich_progress_bar.yaml b/mart/configs/callbacks/rich_progress_bar.yaml index b6be5b45..de6f1ccb 100644 --- a/mart/configs/callbacks/rich_progress_bar.yaml +++ b/mart/configs/callbacks/rich_progress_bar.yaml @@ -1,6 +1,4 @@ -# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html +# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html -# Create a progress bar with rich text formatting. -# Look at the above link for more detailed information. 
rich_progress_bar: - _target_: pytorch_lightning.callbacks.RichProgressBar + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/mart/configs/debug/profiler.yaml b/mart/configs/debug/profiler.yaml index 1e8ab021..8e79d296 100644 --- a/mart/configs/debug/profiler.yaml +++ b/mart/configs/debug/profiler.yaml @@ -8,8 +8,8 @@ defaults: trainer: max_epochs: 1 profiler: - _target_: pytorch_lightning.profiler.SimpleProfiler - # _target_: pytorch_lightning.profiler.AdvancedProfiler - # _target_: pytorch_lightning.profiler.PyTorchProfiler + _target_: lightning.pytorch.profiler.SimpleProfiler + # _target_: lightning.pytorch.profiler.AdvancedProfiler + # _target_: lightning.pytorch.profiler.PyTorchProfiler dirpath: ${paths.output_dir} filename: profiler_log diff --git a/mart/configs/logger/comet.yaml b/mart/configs/logger/comet.yaml index 423f41f6..e0789274 100644 --- a/mart/configs/logger/comet.yaml +++ b/mart/configs/logger/comet.yaml @@ -1,7 +1,7 @@ # https://www.comet.ml comet: - _target_: pytorch_lightning.loggers.comet.CometLogger + _target_: lightning.pytorch.loggers.comet.CometLogger api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable save_dir: "${paths.output_dir}" project_name: "lightning-hydra-template" diff --git a/mart/configs/logger/csv.yaml b/mart/configs/logger/csv.yaml index 844ec671..fa028e9c 100644 --- a/mart/configs/logger/csv.yaml +++ b/mart/configs/logger/csv.yaml @@ -1,7 +1,7 @@ # csv logger built in lightning csv: - _target_: pytorch_lightning.loggers.csv_logs.CSVLogger + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger save_dir: "${paths.output_dir}" name: "csv/" prefix: "" diff --git a/mart/configs/logger/mlflow.yaml b/mart/configs/logger/mlflow.yaml index 3b441a90..f8fb7e68 100644 --- a/mart/configs/logger/mlflow.yaml +++ b/mart/configs/logger/mlflow.yaml @@ -1,7 +1,7 @@ # https://mlflow.org mlflow: - _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger # experiment_name: "" # run_name: "" tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI diff --git a/mart/configs/logger/neptune.yaml b/mart/configs/logger/neptune.yaml index 5df1e342..8233c140 100644 --- a/mart/configs/logger/neptune.yaml +++ b/mart/configs/logger/neptune.yaml @@ -1,7 +1,7 @@ # https://neptune.ai neptune: - _target_: pytorch_lightning.loggers.neptune.NeptuneLogger + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable project: username/lightning-hydra-template # name: "" diff --git a/mart/configs/logger/tensorboard.yaml b/mart/configs/logger/tensorboard.yaml index 29067c90..2bd31f6d 100644 --- a/mart/configs/logger/tensorboard.yaml +++ b/mart/configs/logger/tensorboard.yaml @@ -1,7 +1,7 @@ # https://www.tensorflow.org/tensorboard/ tensorboard: - _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger save_dir: "${paths.output_dir}/tensorboard/" name: null log_graph: False diff --git a/mart/tasks/lightning.py b/mart/tasks/lightning.py index b02c32f1..30a462c4 100644 --- a/mart/tasks/lightning.py +++ b/mart/tasks/lightning.py @@ -1,15 +1,15 @@ from typing import Any, Dict, List, Tuple import hydra -from omegaconf import DictConfig -from pytorch_lightning import ( +from lightning.pytorch import ( Callback, LightningDataModule, LightningModule, Trainer, seed_everything, ) 
-from pytorch_lightning.loggers import LightningLoggerBase +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig from mart import utils @@ -32,8 +32,8 @@ def lightning(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: seed_everything(cfg.seed, workers=True) # Init lightning datamodule - log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) # Init lightning model log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = hydra.utils.instantiate(cfg.model) @@ -42,7 +42,7 @@ def lightning(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) # Init lightning loggers - logger: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.get("logger")) + logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) # Init lightning trainer log.info(f"Instantiating trainer <{cfg.trainer._target_}>") diff --git a/mart/utils/pylogger.py b/mart/utils/pylogger.py index eb84d233..86a723b9 100644 --- a/mart/utils/pylogger.py +++ b/mart/utils/pylogger.py @@ -1,6 +1,6 @@ import logging -from pytorch_lightning.utilities import rank_zero_only +from lightning.pytorch.utilities import rank_zero_only __all__ = ["get_pylogger"] diff --git a/mart/utils/silent.py b/mart/utils/silent.py index b9cbd1c3..bd240b0b 100644 --- a/mart/utils/silent.py +++ b/mart/utils/silent.py @@ -13,7 +13,10 @@ class silent(ContextDecorator): """Suppress logging.""" - DEFAULT_NAMES = ["pytorch_lightning.utilities.rank_zero", "pytorch_lightning.accelerators.gpu"] + DEFAULT_NAMES = [ + "lightning.pytorch.utilities.rank_zero", + "lightning.pytorch.accelerators.cuda", + ] def __init__(self, names=None): if names is None: diff --git a/tests/helpers/package_available.py b/tests/helpers/package_available.py index 15630f19..614778fe 100644 --- a/tests/helpers/package_available.py +++ b/tests/helpers/package_available.py @@ -1,7 +1,7 @@ import platform import pkg_resources -from pytorch_lightning.utilities.xla_device import XLADeviceUtils +from lightning.fabric.accelerators import TPUAccelerator def _package_available(package_name: str) -> bool: @@ -12,7 +12,7 @@ def _package_available(package_name: str) -> bool: return False -_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() +_TPU_AVAILABLE = TPUAccelerator.is_available() _IS_WINDOWS = platform.system() == "Windows" diff --git a/tests/test_adversary.py b/tests/test_adversary.py index 0d0777ae..13874ff7 100644 --- a/tests/test_adversary.py +++ b/tests/test_adversary.py @@ -9,7 +9,7 @@ import pytest import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.exceptions import MisconfigurationException from torch.optim import SGD import mart @@ -326,7 +326,8 @@ def test_configure_gradient_clipping(): # We need to mock a trainer since LightningModule does some checks adversary.trainer = Mock(gradient_clip_val=1.0, gradient_clip_algorithm="norm") - adversary.configure_gradient_clipping(optimizer, 0) + # Keep the gradient_clip_val the same in Trainer and configure_gradient_clipping() + adversary.configure_gradient_clipping(optimizer, 1.0) # Once for each parameter in the optimizer assert gradient_modifier.call_count == 2 diff --git a/tests/test_perturber.py 
b/tests/test_perturber.py index bb6c204b..af627359 100644 --- a/tests/test_perturber.py +++ b/tests/test_perturber.py @@ -9,7 +9,7 @@ import pytest import torch -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.exceptions import MisconfigurationException from torch.optim import SGD import mart From ff114a1b20196428dce65e8c3004dc62e0f4eeb9 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 16:55:49 -0700 Subject: [PATCH 09/22] Hydra 1.2->1.3 --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index bdae295c..a3ded7d0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,7 +31,7 @@ # Loads the configuration file from a given experiment def get_cfg(experiment): - with initialize(version_base="1.2", config_path="../mart/configs"): + with initialize(version_base="1.3", config_path="../mart/configs"): params = "experiment=" + experiment cfg = compose(config_name="lightning.yaml", return_hydra_config=True, overrides=[params]) return cfg From b5bdf31d79d1de4bff1ea3e09ec4734a80cac6a8 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 16:57:09 -0700 Subject: [PATCH 10/22] Split utils. --- mart/utils/__init__.py | 2 + mart/utils/instantiators.py | 52 +++++++++++ mart/utils/logging_utils.py | 52 +++++++++++ mart/utils/utils.py | 175 +++++++++--------------------------- 4 files changed, 146 insertions(+), 135 deletions(-) create mode 100644 mart/utils/instantiators.py create mode 100644 mart/utils/logging_utils.py diff --git a/mart/utils/__init__.py b/mart/utils/__init__.py index 50e71b3d..9b647b64 100644 --- a/mart/utils/__init__.py +++ b/mart/utils/__init__.py @@ -1,5 +1,7 @@ from .adapters import * from .export import * +from .instantiators import * +from .logging_utils import * from .monkey_patch import * from .pylogger import * from .rich_utils import * diff --git a/mart/utils/instantiators.py b/mart/utils/instantiators.py new file mode 100644 index 00000000..f175e954 --- /dev/null +++ b/mart/utils/instantiators.py @@ -0,0 +1,52 @@ +from typing import List + +import hydra +from lightning.pytorch import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from . import pylogger + +__all__ = ["instantiate_callbacks", "instantiate_loggers"] + +log = pylogger.get_pylogger(__name__) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config.""" + + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config.""" + + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! 
Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/mart/utils/logging_utils.py b/mart/utils/logging_utils.py new file mode 100644 index 00000000..841eebe4 --- /dev/null +++ b/mart/utils/logging_utils.py @@ -0,0 +1,52 @@ +from lightning.pytorch.utilities import rank_zero_only + +from . import pylogger + +__all__ = ["log_hyperparameters"] + +log = pylogger.get_pylogger(__name__) + + +@rank_zero_only +def log_hyperparameters(object_dict: dict) -> None: + """Controls which config parts are saved by lightning loggers. + + Additionally saves: + - Number of model parameters + """ + + hparams = {} + + cfg = object_dict["cfg"] + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/mart/utils/utils.py b/mart/utils/utils.py index 80609331..f49da98e 100644 --- a/mart/utils/utils.py +++ b/mart/utils/utils.py @@ -1,74 +1,26 @@ import os -import time import warnings from glob import glob from importlib.util import find_spec from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple +from typing import Callable, Tuple import hydra from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Callback -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.model_summary import summarize -from mart.utils import pylogger, rich_utils +from . import pylogger, rich_utils __all__ = [ - "close_loggers", "extras", "get_metric_value", "get_resume_checkpoint", - "instantiate_callbacks", - "instantiate_loggers", - "log_hyperparameters", - "save_file", "task_wrapper", ] log = pylogger.get_pylogger(__name__) -def task_wrapper(task_func: Callable) -> Callable: - """Optional decorator that wraps the task function in extra utilities. - - Makes multirun more resistant to failure. 
- - Utilities: - - Calling the `utils.extras()` before the task is started - - Calling the `utils.close_loggers()` after the task is finished - - Logging the exception if occurs - - Logging the task total execution time - - Logging the output dir - """ - - def wrap(cfg: DictConfig): - - # apply extra utilities - extras(cfg) - - # execute the task - try: - start_time = time.time() - metric_dict, object_dict = task_func(cfg=cfg) - except Exception as ex: - log.exception("") # save exception to `.log` file - raise ex - finally: - path = Path(cfg.paths.output_dir, "exec_time.log") - content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)" - save_file(path, content) # save task execution time (even if exception occurs) - close_loggers() # close loggers (even if exception occurs so multirun won't fail) - - log.info(f"Output dir: {cfg.paths.output_dir}") - - return metric_dict, object_dict - - return wrap - - def extras(cfg: DictConfig) -> None: """Applies optional utilities before the task is started. @@ -99,91 +51,57 @@ def extras(cfg: DictConfig) -> None: rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) -@rank_zero_only -def save_file(path, content) -> None: - """Save file in rank zero mode (only on one process in multi-GPU setup).""" - with open(path, "w+") as file: - file.write(content) - - -def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: - """Instantiates callbacks from config.""" - callbacks: List[Callback] = [] - - if not callbacks_cfg: - log.warning("Callbacks config is empty.") - return callbacks - - if not isinstance(callbacks_cfg, DictConfig): - raise TypeError("Callbacks config must be a DictConfig!") - - for _, cb_conf in callbacks_cfg.items(): - if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: - log.info(f"Instantiating callback <{cb_conf._target_}>") - callbacks.append(hydra.utils.instantiate(cb_conf)) - - return callbacks - - -def instantiate_loggers(logger_cfg: DictConfig) -> List[LightningLoggerBase]: - """Instantiates loggers from config.""" - logger: List[LightningLoggerBase] = [] - - if not logger_cfg: - log.warning("Logger config is empty.") - return logger - - if not isinstance(logger_cfg, DictConfig): - raise TypeError("Logger config must be a DictConfig!") - - for _, lg_conf in logger_cfg.items(): - if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: - log.info(f"Instantiating logger <{lg_conf._target_}>") - logger.append(hydra.utils.instantiate(lg_conf)) +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. - return logger + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[dict, dict]: -@rank_zero_only -def log_hyperparameters(object_dict: Dict[str, Any]) -> None: - """Controls which config parts are saved by lightning loggers. + ... - Additionally saves: - - Number of model parameters + return metric_dict, object_dict + ``` """ - hparams = {} - - cfg = object_dict["cfg"] - model = object_dict["model"] - trainer = object_dict["trainer"] - - if not trainer.logger: - log.warning("Logger not found! 
Skipping hyperparameter logging...") - return + def wrap(cfg: DictConfig): + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) - hparams["model"] = cfg["model"] + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") - # save number of model parameters - summary = summarize(model) + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex - hparams["model/params/total"] = summary.total_parameters - hparams["model/params/trainable"] = summary.trainable_parameters - hparams["model/params/non_trainable"] = summary.total_parameters - summary.trainable_parameters + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") - hparams["datamodule"] = cfg["datamodule"] - hparams["trainer"] = cfg["trainer"] + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") + if wandb.run: + log.info("Closing wandb!") + wandb.finish() - hparams["task_name"] = cfg.get("task_name") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") + return metric_dict, object_dict - # send hparams to all loggers - trainer.logger.log_hyperparams(hparams) + return wrap def get_metric_value(metric_dict: dict, metric_name: str) -> float: @@ -206,19 +124,6 @@ def get_metric_value(metric_dict: dict, metric_name: str) -> float: return metric_value -def close_loggers() -> None: - """Makes sure all loggers closed properly (prevents logging failure during multirun).""" - - log.info("Closing loggers...") - - if find_spec("wandb"): # if wandb is installed - import wandb - - if wandb.run: - log.info("Closing wandb!") - wandb.finish() - - def get_resume_checkpoint(config: DictConfig) -> Tuple[DictConfig]: """Resume a task from an existing checkpoint along with the config.""" From 31145b7b7368d4d942485820f651128c3dbefb71 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 16:59:48 -0700 Subject: [PATCH 11/22] Update to PL's new API. --- mart/attack/adversary.py | 13 +++++++------ mart/callbacks/progress_bar.py | 17 ++++------------- mart/configs/trainer/default.yaml | 8 +++++++- mart/models/modular.py | 23 +++++++++-------------- 4 files changed, 27 insertions(+), 34 deletions(-) diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py index 065a1a65..1d9d35f8 100644 --- a/mart/attack/adversary.py +++ b/mart/attack/adversary.py @@ -10,8 +10,8 @@ from itertools import cycle from typing import TYPE_CHECKING, Any, Callable -import pytorch_lightning as pl import torch +from lightning import pytorch as pl from mart.utils import silent @@ -126,15 +126,16 @@ def training_step(self, batch, batch_idx): if len(gain.shape) > 0: gain = gain.sum() + # Display gain on progress bar. + self.log("gain", gain, prog_bar=True) + return gain def configure_gradient_clipping( - self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None + self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None ): # Configuring gradient clipping in pl.Trainer is still useful, so use it. 
- super().configure_gradient_clipping( - optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm - ) + super().configure_gradient_clipping(optimizer, gradient_clip_val, gradient_clip_algorithm) if self.gradient_modifier: for group in optimizer.param_groups: @@ -184,7 +185,7 @@ def attacker(self): elif self.device.type == "cpu": accelerator = "cpu" - devices = None + devices = 1 else: raise NotImplementedError diff --git a/mart/callbacks/progress_bar.py b/mart/callbacks/progress_bar.py index f33811d7..4fc99d64 100644 --- a/mart/callbacks/progress_bar.py +++ b/mart/callbacks/progress_bar.py @@ -6,9 +6,9 @@ from typing import Any -import pytorch_lightning as pl -from pytorch_lightning.callbacks import TQDMProgressBar -from pytorch_lightning.utilities.rank_zero import rank_zero_only +from lightning import pytorch as pl +from lightning.pytorch.callbacks import TQDMProgressBar +from lightning.pytorch.utilities.rank_zero import rank_zero_only __all__ = ["ProgressBar"] @@ -44,15 +44,6 @@ def init_train_tqdm(self): def on_train_epoch_start(self, trainer: pl.Trainer, *_: Any) -> None: super().on_train_epoch_start(trainer) - # So that it does not display negative rate. - self.main_progress_bar.initial = 0 # So that it does not display Epoch n. rank_id = rank_zero_only.rank - self.main_progress_bar.set_description(f"Attack@rank{rank_id}") - - def get_metrics(self, *args, **kwargs): - """Rename metrics on progress bar status.""" - metrics = super().get_metrics(*args, **kwargs) - for old_name, new_name in self.rename_metrics.items(): - metrics[new_name] = metrics.pop(old_name) - return metrics + self.train_progress_bar.set_description(f"Attack@rank{rank_id}") diff --git a/mart/configs/trainer/default.yaml b/mart/configs/trainer/default.yaml index b3eeb7a1..ccff1885 100644 --- a/mart/configs/trainer/default.yaml +++ b/mart/configs/trainer/default.yaml @@ -1,4 +1,4 @@ -_target_: pytorch_lightning.Trainer +_target_: lightning.pytorch.trainer.Trainer default_root_dir: ${paths.output_dir} @@ -11,6 +11,12 @@ devices: 1 # mixed precision for extra speed-up # precision: 16 +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + # set True to to ensure deterministic results # makes training slower but gives more reproducibility than just setting seeds deterministic: False + +# Disable PyTorch inference mode in val/test/predict, because we may run adversarial backprop. +inference_mode: False diff --git a/mart/models/modular.py b/mart/models/modular.py index 45add818..cfb955a7 100644 --- a/mart/models/modular.py +++ b/mart/models/modular.py @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) import torch # noqa: E402 -from pytorch_lightning import LightningModule # noqa: E402 +from lightning.pytorch import LightningModule # noqa: E402 from ..nn import SequentialDict # noqa: E402 from ..optim import OptimizerFactory # noqa: E402 @@ -108,12 +108,11 @@ def training_step(self, batch, batch_idx): output = self(input=input, target=target, model=self.model, step="training") for name in self.training_step_log: - self.log(f"training/{name}", output[name]) + # Display loss on prog_bar. + self.log(f"training/{name}", output[name], prog_bar=True) assert "loss" in output - return output - def training_step_end(self, output): if self.training_metrics is not None: # Some models only return loss in the training mode. 
if "preds" not in output or "target" not in output: @@ -124,7 +123,7 @@ def training_step_end(self, output): loss = output.pop("loss") return loss - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): if self.training_metrics is not None: # Some models only return loss in the training mode. metrics = self.training_metrics.compute() @@ -144,15 +143,12 @@ def validation_step(self, batch, batch_idx): for name in self.validation_step_log: self.log(f"validation/{name}", output[name]) - return output - - def validation_step_end(self, output): self.validation_metrics(output["preds"], output["target"]) # I don't know why this is required to prevent CUDA memory leak in validaiton and test. (Not required in training.) output.clear() - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): metrics = self.validation_metrics.compute() metrics = self.flatten_metrics(metrics) self.validation_metrics.reset() @@ -170,15 +166,12 @@ def test_step(self, batch, batch_idx): for name in self.test_step_log: self.log(f"test/{name}", output[name]) - return output - - def test_step_end(self, output): self.test_metrics(output["preds"], output["target"]) # I don't know why this is required to prevent CUDA memory leak in validaiton and test. (Not required in training.) output.clear() - def test_epoch_end(self, outputs): + def on_test_epoch_end(self): metrics = self.test_metrics.compute() metrics = self.flatten_metrics(metrics) self.test_metrics.reset() @@ -224,4 +217,6 @@ def enumerate_metric(metric, name): enumerate_metric(metrics, prefix) - self.log_dict(metrics_dict, prog_bar=prog_bar) + # PyTorch Lightning: It is recommended to use `self.log('validation_metrics/acc', ..., sync_dist=True)` + # when logging on epoch level in distributed setting to accumulate the metric across devices. + self.log_dict(metrics_dict, prog_bar=prog_bar, sync_dist=True) From 03d4c4716e779e1b4cee874e8fc879852917162e Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 17:00:11 -0700 Subject: [PATCH 12/22] Add aim logger. 
--- .gitignore | 8 +++++--- mart/configs/logger/aim.yaml | 28 ++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) create mode 100644 mart/configs/logger/aim.yaml diff --git a/.gitignore b/.gitignore index 117a6931..04a06484 100644 --- a/.gitignore +++ b/.gitignore @@ -146,7 +146,9 @@ dmypy.json # Lightning-Hydra-Template configs/local/default.yaml -data/ -logs/ +/data/ +/logs/ .env -.autoenv + +# Aim logging +.aim diff --git a/mart/configs/logger/aim.yaml b/mart/configs/logger/aim.yaml new file mode 100644 index 00000000..8f9f6ada --- /dev/null +++ b/mart/configs/logger/aim.yaml @@ -0,0 +1,28 @@ +# https://aimstack.io/ + +# example usage in lightning module: +# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py + +# open the Aim UI with the following command (run in the folder containing the `.aim` folder): +# `aim up` + +aim: + _target_: aim.pytorch_lightning.AimLogger + repo: ${paths.root_dir} # .aim folder will be created here + # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# + + # aim allows to group runs under experiment name + experiment: null # any string, set to "default" if not specified + + train_metric_prefix: "train/" + val_metric_prefix: "val/" + test_metric_prefix: "test/" + + # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) + system_tracking_interval: 10 # set to null to disable system metrics tracking + + # enable/disable logging of system params such as installed packages, git info, env vars, etc. + log_system_params: true + + # enable/disable tracking console logs (default value is true) + capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 From a29a317f7942ac5db9e39825aede3d4c248a5a89 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 17:00:28 -0700 Subject: [PATCH 13/22] Add cpu trainer. --- mart/configs/trainer/cpu.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 mart/configs/trainer/cpu.yaml diff --git a/mart/configs/trainer/cpu.yaml b/mart/configs/trainer/cpu.yaml new file mode 100644 index 00000000..640f71d5 --- /dev/null +++ b/mart/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default.yaml + +accelerator: cpu +devices: 1 From 27cbe96ea97846c9122ff46488789ff59a39ca81 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Mon, 5 Jun 2023 17:01:40 -0700 Subject: [PATCH 14/22] Fix paths of lightning.pytorch and mart.data. 
---
 mart/configs/debug/default.yaml | 2 +-
 mart/configs/lightning.yaml | 2 +-
 mart/configs/logger/wandb.yaml | 2 +-
 mart/utils/rich_utils.py | 12 ++----------
 4 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/mart/configs/debug/default.yaml b/mart/configs/debug/default.yaml
index e30b19e2..1886902b 100644
--- a/mart/configs/debug/default.yaml
+++ b/mart/configs/debug/default.yaml
@@ -30,6 +30,6 @@ trainer:
   devices: 1 # debuggers don't like multiprocessing
   detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
 
-datamodule:
+data:
   num_workers: 0 # debuggers don't like multiprocessing
   pin_memory: False # disable gpu memory pin
diff --git a/mart/configs/lightning.yaml b/mart/configs/lightning.yaml
index 714250f3..88a762fb 100644
--- a/mart/configs/lightning.yaml
+++ b/mart/configs/lightning.yaml
@@ -3,7 +3,7 @@
 # specify here default training configuration
 defaults:
   - _self_
-  - datamodule: null
+  - data: null
   - model: null
   - metric: null
   - optimization: null
diff --git a/mart/configs/logger/wandb.yaml b/mart/configs/logger/wandb.yaml
index d6d20e0b..ece16588 100644
--- a/mart/configs/logger/wandb.yaml
+++ b/mart/configs/logger/wandb.yaml
@@ -1,7 +1,7 @@
 # https://wandb.ai
 
 wandb:
-  _target_: pytorch_lightning.loggers.wandb.WandbLogger
+  _target_: lightning.pytorch.loggers.wandb.WandbLogger
   # name: "" # name of the run (normally generated by wandb)
   save_dir: "${paths.output_dir}"
   offline: False
diff --git a/mart/utils/rich_utils.py b/mart/utils/rich_utils.py
index db7e75a5..b5c00130 100644
--- a/mart/utils/rich_utils.py
+++ b/mart/utils/rich_utils.py
@@ -5,8 +5,8 @@
 import rich.syntax
 import rich.tree
 from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import DictConfig, OmegaConf, open_dict
-from pytorch_lightning.utilities import rank_zero_only
 from rich.prompt import Prompt
 
 from mart.utils import pylogger
@@ -20,7 +20,7 @@
 def print_config_tree(
     cfg: DictConfig,
     print_order: Sequence[str] = (
-        "datamodule",
+        "data",
         "model",
         "callbacks",
         "logger",
@@ -97,11 +97,3 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
     if save_to_file:
         with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
             rich.print(cfg.tags, file=file)
-
-
-if __name__ == "__main__":
-    from hydra import compose, initialize
-
-    with initialize(version_base="1.2", config_path="../../configs"):
-        cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[])
-        print_config_tree(cfg, resolve=False, save_to_file=False)

From 5597103edcef743770fc90e0c716c32617ee4afb Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Mon, 5 Jun 2023 17:02:05 -0700
Subject: [PATCH 15/22] Update comment.

---
 mart/configs/paths/default.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mart/configs/paths/default.yaml b/mart/configs/paths/default.yaml
index b00b6df3..ec81db2d 100644
--- a/mart/configs/paths/default.yaml
+++ b/mart/configs/paths/default.yaml
@@ -1,6 +1,6 @@
 # path to root directory
 # this requires PROJECT_ROOT environment variable to exist
-# PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py`
+# you can replace it with "." if you want the root to be the current working directory
 root_dir: ${oc.env:PROJECT_ROOT}
 
 # path to data directory

From 5eec5e32239258ac05531d0f85591fa8c4eaa848 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Mon, 5 Jun 2023 17:02:45 -0700
Subject: [PATCH 16/22] Dependency: pytorch_lightning -> lightning

---
 pyproject.toml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8be2788b..0361c746 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,8 @@ requires-python = ">=3.9"
 dependencies = [
   "torch ~= 2.0.1",
   "torchvision ~= 0.15.2",
-  "pytorch-lightning ~= 2.0.2",
+  "lightning ~= 2.0.2",
+  # TODO: Manually downgrade to torchmetrics == 0.6.0.
   "torchmetrics == 0.11.4",
 
   # --------- hydra --------- #

From 1c83e34306c7d65b287bfc2272f0224fcba15c90 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Mon, 5 Jun 2023 17:03:21 -0700
Subject: [PATCH 17/22] Update to new API of mAP in torchmetrics.

---
 mart/configs/metric/average_precision.yaml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mart/configs/metric/average_precision.yaml b/mart/configs/metric/average_precision.yaml
index d41f9743..02025a54 100644
--- a/mart/configs/metric/average_precision.yaml
+++ b/mart/configs/metric/average_precision.yaml
@@ -1,11 +1,11 @@
 # @package model
 
 training_metrics:
-  _target_: torchmetrics.detection.MAP
+  _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision
   compute_on_step: false
 
 validation_metrics:
-  _target_: torchmetrics.detection.MAP
+  _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision
   compute_on_step: false
 
 test_metrics:
@@ -13,7 +13,7 @@ test_metrics:
   _convert_: partial
   metrics:
     map:
-      _target_: torchmetrics.detection.MAP
+      _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision
       compute_on_step: false
     json:
       _target_: mart.utils.export.CocoPredictionJSON

From cbd2fcf776f81450f27ff06e24498537c1f066ef Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Mon, 5 Jun 2023 17:14:45 -0700
Subject: [PATCH 18/22] Fix path of pytorch_lightning -> lightning.pytorch.

---
 mart/data/modular.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mart/data/modular.py b/mart/data/modular.py
index 65cbb888..fa2dc24d 100644
--- a/mart/data/modular.py
+++ b/mart/data/modular.py
@@ -8,8 +8,8 @@
 logger = logging.getLogger(__name__)
 
-import pytorch_lightning as pl  # noqa: E402
 from hydra.utils import instantiate  # noqa: E402
+from lightning import pytorch as pl  # noqa: E402
 from torch.utils.data import DataLoader, Dataset, Sampler  # noqa: E402
 
 __all__ = ["LitDataModule"]

From 1b3cda37fff9f30a783888a198afb2180dda447c Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Mon, 10 Jul 2023 10:09:14 -0700
Subject: [PATCH 19/22] Pin pydantic==1.10.11

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 5a2d0421..5ed8161b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@ dependencies = [
   "torch ~= 2.0.1",
   "torchvision ~= 0.15.2",
   "lightning ~= 2.0.2",
+  "pydantic == 1.10.11", # https://github.com/Lightning-AI/lightning/pull/18022/files
   "torchmetrics == 1.0.0",
   "numpy == 1.23.5", # https://github.com/pytorch/pytorch/issues/91516

From ba7aab3b97d17d385707cd9cbe69776123bcdfd4 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Mon, 10 Jul 2023 14:39:54 -0700
Subject: [PATCH 20/22] Hide some folders in examples.
---
 examples/carla_overhead_object_detection/.gitignore | 4 ++++
 examples/robust_bench/.gitignore | 4 ++++
 2 files changed, 8 insertions(+)
 create mode 100644 examples/carla_overhead_object_detection/.gitignore
 create mode 100644 examples/robust_bench/.gitignore

diff --git a/examples/carla_overhead_object_detection/.gitignore b/examples/carla_overhead_object_detection/.gitignore
new file mode 100644
index 00000000..8ba906dd
--- /dev/null
+++ b/examples/carla_overhead_object_detection/.gitignore
@@ -0,0 +1,4 @@
+# Default folder for downloading dataset
+data
+
+logs
diff --git a/examples/robust_bench/.gitignore b/examples/robust_bench/.gitignore
new file mode 100644
index 00000000..8ba906dd
--- /dev/null
+++ b/examples/robust_bench/.gitignore
@@ -0,0 +1,4 @@
+# Default folder for downloading dataset
+data
+
+logs

From eb8e9c6293b0e525421a4bebbd6c54b8dcb9d9a0 Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Mon, 10 Jul 2023 14:49:35 -0700
Subject: [PATCH 21/22] Fix test_resume.

---
 tests/test_experiments.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tests/test_experiments.py b/tests/test_experiments.py
index e69682ba..87e918e4 100644
--- a/tests/test_experiments.py
+++ b/tests/test_experiments.py
@@ -195,12 +195,13 @@ def test_resume(tmpdir):
         "\n".join(
             [
                 "- experiment=CIFAR10_CNN",
-                "- datamodule=dummy_classification",
-                "- datamodule.ims_per_batch=2",
-                "- datamodule.num_workers=0",
-                "- datamodule.train_dataset.size=2",
-                "- datamodule.train_dataset.image_size=[3,32,32]",
-                "- datamodule.train_dataset.num_classes=10",
+                "- data=dummy_classification",
+                "- data.ims_per_batch=2",
+                "- data.num_workers=0",
+                "- data.train_dataset.size=2",
+                "- data.train_dataset.image_size=[3,32,32]",
+                "- data.train_dataset.num_classes=10",
+                "- +data.num_classes=10",
                 "- fit=false",  # Don't train or test the model, because the checkpoint is invalid.
                 "- test=false",
                 "- optimized_metric=null",  # No metric to retrieve.

From 5b5a91fdd1a9fedbd982ad41858abe509bc4566c Mon Sep 17 00:00:00 2001
From: Weilin Xu
Date: Mon, 10 Jul 2023 14:50:27 -0700
Subject: [PATCH 22/22] Fix torchmetrics configs.
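
Reviewer note: a quick sketch of the torchmetrics API that the updated
accuracy.yaml maps onto; the tensors below are toy values and not part of
the patch.

    import torch
    import torchmetrics

    # Newer torchmetrics (the project now pins 1.0.0) requires an explicit task;
    # this mirrors `_target_: torchmetrics.Accuracy` with `task: multiclass`.
    acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    preds = torch.randint(0, 10, (8,))
    target = torch.randint(0, 10, (8,))
    acc.update(preds, target)
    print(acc.compute())  # compute() is called explicitly; compute_on_step is gone
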
---
 mart/configs/metric/accuracy.yaml | 4 ++--
 mart/configs/metric/average_precision.yaml | 3 ---
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/mart/configs/metric/accuracy.yaml b/mart/configs/metric/accuracy.yaml
index a184aa34..9c4a79f4 100644
--- a/mart/configs/metric/accuracy.yaml
+++ b/mart/configs/metric/accuracy.yaml
@@ -5,9 +5,9 @@ training_metrics:
   _convert_: partial # metrics must be a dict
   metrics:
     acc:
-      _target_: torchmetrics.classification.MulticlassAccuracy
+      _target_: torchmetrics.Accuracy
+      task: multiclass
       num_classes: ${data.num_classes}
-      compute_on_step: false
 
 validation_metrics: ${.training_metrics}
diff --git a/mart/configs/metric/average_precision.yaml b/mart/configs/metric/average_precision.yaml
index 02025a54..dc1d05f0 100644
--- a/mart/configs/metric/average_precision.yaml
+++ b/mart/configs/metric/average_precision.yaml
@@ -2,11 +2,9 @@
 
 training_metrics:
   _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision
-  compute_on_step: false
 
 validation_metrics:
   _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision
-  compute_on_step: false
 
 test_metrics:
   _target_: torchmetrics.collections.MetricCollection
   _convert_: partial
   metrics:
     map:
       _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision
-      compute_on_step: false
     json:
       _target_: mart.utils.export.CocoPredictionJSON
       prediction_file_name: ${paths.output_dir}/test_prediction.json
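
Reviewer note: a brief usage sketch, appended for reference rather than as part
of the patch, of torchmetrics' MeanAveragePrecision once the deprecated
`compute_on_step` argument is dropped; the boxes, scores, and labels are toy
values.

    import torch
    from torchmetrics.detection.mean_ap import MeanAveragePrecision

    # No compute_on_step kwarg in the torchmetrics release pinned in pyproject.toml.
    metric = MeanAveragePrecision()

    preds = [{
        "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
        "scores": torch.tensor([0.9]),
        "labels": torch.tensor([1]),
    }]
    target = [{
        "boxes": torch.tensor([[12.0, 11.0, 48.0, 52.0]]),
        "labels": torch.tensor([1]),
    }]
    metric.update(preds, target)
    print(metric.compute()["map"])  # aggregated mAP is computed explicitly
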