diff --git a/README.md b/README.md index e5343e9d..00d0bc2a 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,9 @@ _TorchUncertainty_ is a package designed to help you leverage [uncertainty quant :construction: _TorchUncertainty_ is in early development :construction: - expect changes, but reach out and contribute if you are interested in the project! **Please raise an issue if you have any bugs or difficulties and join the [discord server](https://discord.gg/HMCawt5MJu).** -Our webpage and documentation is available here: [torch-uncertainty.github.io](https://torch-uncertainty.github.io). +:books: Our webpage and documentation is available here: [torch-uncertainty.github.io](https://torch-uncertainty.github.io). :books: + +TorchUncertainty contains the *official implementations* of multiple papers from *major machine-learning and computer vision conferences* and was/will be featured in tutorials at **WACV 2024** and **ECCV 2024**. --- @@ -47,7 +49,14 @@ We make a quickstart available at [torch-uncertainty.github.io/quickstart](https ## :books: Implemented methods -TorchUncertainty currently supports **Classification**, **probabilistic** and pointwise **Regression** and **Segmentation**. +TorchUncertainty currently supports **classification**, **probabilistic** and pointwise **regression**, **segmentation** and **pixelwise regression** (such as monocular depth estimation). It includes the official codes of the following papers: + +- *A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors* - [ICLR 2024](https://arxiv.org/abs/2310.08287) +- *LP-BNN: Encoding the latent posterior of Bayesian Neural Networks for uncertainty quantification* - [IEEE TPAMI](https://arxiv.org/abs/2012.02818) +- *Packed-Ensembles for Efficient Uncertainty Estimation* - [ICLR 2023](https://arxiv.org/abs/2210.09184) - [Tutorial](https://torch-uncertainty.github.io/auto_tutorials/tutorial_pe_cifar10.html) +- *MUAD: Multiple Uncertainties for Autonomous Driving, a benchmark for multiple uncertainty types and tasks* - [BMVC 2022](https://arxiv.org/abs/2203.01437) + +We also provide the following methods: ### Baselines @@ -86,18 +95,3 @@ Our documentation contains the following tutorials: - [Deep Evidential Regression on a Toy Example](https://torch-uncertainty.github.io/auto_tutorials/tutorial_der_cubic.html) - [Training a LeNet with Monte-Carlo Dropout](https://torch-uncertainty.github.io/auto_tutorials/tutorial_mc_dropout.html) - [Training a LeNet with Deep Evidential Classification](https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html) - -## Other References - -This package also contains the official implementation of Packed-Ensembles. - -If you find the corresponding models interesting, please consider citing our [paper](https://arxiv.org/abs/2210.09184): - -```text -@inproceedings{laurent2023packed, - title={Packed-Ensembles for Efficient Uncertainty Estimation}, - author={Laurent, Olivier and Lafage, Adrien and Tartaglione, Enzo and Daniel, Geoffrey and Martinez, Jean-Marc and Bursuc, Andrei and Franchi, Gianni}, - booktitle={ICLR}, - year={2023} -} -``` diff --git a/auto_tutorials_source/tutorial_corruptions.py b/auto_tutorials_source/tutorial_corruptions.py index e71b1223..d20e4f19 100644 --- a/auto_tutorials_source/tutorial_corruptions.py +++ b/auto_tutorials_source/tutorial_corruptions.py @@ -12,11 +12,9 @@ torchvision and matplotlib. """ -import torch from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, ToTensor, Resize -from torchvision.utils import make_grid import matplotlib.pyplot as plt ds = CIFAR10("./data", train=False, download=True) diff --git a/auto_tutorials_source/tutorial_mc_batch_norm.py b/auto_tutorials_source/tutorial_mc_batch_norm.py index 9d08e7dc..12781e2b 100644 --- a/auto_tutorials_source/tutorial_mc_batch_norm.py +++ b/auto_tutorials_source/tutorial_mc_batch_norm.py @@ -72,9 +72,11 @@ # %% # 4. Gathering Everything and Training the Model # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# You can also save the results in a variable by saving the output of +# `trainer.test`. trainer.fit(model=routine, datamodule=datamodule) -trainer.test(model=routine, datamodule=datamodule) +trainer.test(model=routine, datamodule=datamodule); # %% # 5. Wrapping the Model in a MCBatchNorm @@ -88,10 +90,10 @@ # to highlight the effect of stochasticity on the predictions. routine.model = MCBatchNorm( - routine.model, num_estimators=8, convert=True, mc_batch_size=4 + routine.model, num_estimators=8, convert=True, mc_batch_size=16 ) routine.model.fit(datamodule.train) -routine.eval() +routine.eval(); # %% # 6. Testing the Model @@ -118,17 +120,17 @@ def imshow(img): dataiter = iter(datamodule.val_dataloader()) images, labels = next(dataiter) -# print images -imshow(torchvision.utils.make_grid(images[:4, ...])) -print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) - routine.eval() logits = routine(images).reshape(8, 128, 10) probs = torch.nn.functional.softmax(logits, dim=-1) +most_uncertain = sorted(probs.var(0).sum(-1).topk(4).indices) +# print images +imshow(torchvision.utils.make_grid(images[most_uncertain, ...])) +print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4))) -for j in sorted(probs.var(0).sum(-1).topk(4).indices): +for j in most_uncertain: values, predicted = torch.max(probs[:, j], 1) print( f"Predicted digits for the image {j}: ", diff --git a/auto_tutorials_source/tutorial_scaler.py b/auto_tutorials_source/tutorial_scaler.py index 2d927b10..e9a31969 100644 --- a/auto_tutorials_source/tutorial_scaler.py +++ b/auto_tutorials_source/tutorial_scaler.py @@ -17,19 +17,17 @@ In this tutorial, we will need: -- torch for its objects -- the "calibration error" metric to compute the ECE and evaluate the top-label calibration +- TorchUncertainty's Calibration Error metric to compute to evaluate the top-label calibration with ECE and plot the reliability diagrams - the CIFAR-100 datamodule to handle the data - a ResNet 18 as starting model - the temperature scaler to improve the top-label calibration -- a utility to download hf models easily -- the calibration plot to visualize the calibration. +- a utility function to download HF models easily -If you use the classification routine, the plots will be automatically available in the tensorboard logs. +If you use the classification routine, the plots will be automatically available in the tensorboard logs if you use the `log_plots` flag. """ from torch_uncertainty.datamodules import CIFAR100DataModule -from torch_uncertainty.metrics import CE +from torch_uncertainty.metrics import CalibrationError from torch_uncertainty.models.resnet import resnet18 from torch_uncertainty.post_processing import TemperatureScaler from torch_uncertainty.utils import load_hf @@ -88,7 +86,7 @@ test_dataloader = DataLoader(test_dataset, batch_size=32) # Initialize the ECE -ece = CE(task="multiclass", num_classes=100) +ece = CalibrationError(task="multiclass", num_classes=100) # Iterate on the calibration dataloader for sample, target in test_dataloader: diff --git a/docs/source/api.rst b/docs/source/api.rst index 24abb1ef..1a415798 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -6,7 +6,7 @@ API Reference Routines -------- -The routine are the main building blocks of the library. They define the framework +The routine are the main building blocks of the library. They define the framework in which the models are trained and evaluated. They allow for easy computation of different metrics crucial for uncertainty estimation in different contexts, namely classification, regression and segmentation. @@ -42,10 +42,20 @@ Segmentation SegmentationRoutine +Pixelwise Regression +^^^^^^^^^^^^^^^^^^^^ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + PixelRegressionRoutine + Baselines --------- -TorchUncertainty provide lightning-based models that can be easily trained and evaluated. +TorchUncertainty provide lightning-based models that can be easily trained and evaluated. These models inherit from the routines and are specifically designed to benchmark different methods in similar settings, here with constant architectures. @@ -85,8 +95,19 @@ Segmentation :nosignatures: :template: class.rst + DeepLabBaseline SegFormerBaseline +Monocular Depth Estimation +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: class.rst + + BTSBaseline + Layers ------ @@ -122,6 +143,8 @@ Bayesian layers BayesConv1d BayesConv2d BayesConv3d + LPBNNLinear + LPBNNConv2d Models ------ @@ -158,9 +181,12 @@ Metrics :template: class.rst AUSE + AURC + AdaptiveCalibrationError BrierScore CategoricalNLL - CE + CalibrationError + CovAt5Risk, Disagreement DistributionNLL Entropy @@ -169,6 +195,7 @@ Metrics MeanGTRelativeAbsoluteError MeanGTRelativeSquaredError MutualInformation + RiskAt80Cov, SILog ThresholdAccuracy diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index fc2dc687..11781df0 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -20,8 +20,7 @@ The scope of TorchUncertainty TorchUncertainty can host any method - if possible linked to a paper - and roughly contained in the following fields: -* Uncertainty quantification in general, including Bayesian deep learning, - Monte Carlo dropout, ensemble methods, etc. +* Uncertainty quantification in general, including Bayesian deep learning, Monte Carlo dropout, ensemble methods, etc. * Out-of-distribution detection methods * Applications (e.g. object detection, segmentation, etc.) @@ -54,7 +53,7 @@ group: Then navigate to ``./docs`` and build the documentation with: .. parsed-literal:: - + make html Optionally, specify ``html-noplot`` instead of ``html`` to avoid running the tutorials. @@ -73,7 +72,7 @@ PR. This will avoid multiplying the number featureless commits. To do this, run, at the root of the folder: .. parsed-literal:: - + python3 -m pytest tests Try to include an emoji at the start of each commit message following the suggestions @@ -118,4 +117,4 @@ License If you feel that the current license is an obstacle to your contribution, let us know, and we may reconsider. However, the models’ weights hosted on Hugging -Face are likely to stay Apache 2.0. +Face are likely to remain Apache 2.0. diff --git a/docs/source/index.rst b/docs/source/index.rst index 63b2ec8c..b0af32c5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,10 +11,10 @@ Welcome to Torch Uncertainty Welcome to the documentation of TorchUncertainty. -This website contains the documentation for +This website contains the documentation for `installing `_ -and `contributing `_ to TorchUncertainty, -details on the `API `_, and a +and `contributing `_ to TorchUncertainty, +details on the `API `_, and a `comprehensive list of the references `_ of the models and metrics implemented. @@ -29,12 +29,36 @@ Installation To install TorchUncertainty with contribution in mind, check the `contribution page `_. +----- + +Official Implementations +^^^^^^^^^^^^^^^^^^^^^^^^ + +TorchUncertainty also houses multiple official implementations of papers from major conferences & journals. + +**A Symmetry-Aware Exploration of Bayesian Neural Network Posteriors** + +* Authors: *Olivier Laurent, Emanuel Aldea, and Gianni Franchi* +* Paper: `ICLR 2024 `_. + +**Encoding the latent posterior of Bayesian Neural Networks for uncertainty quantification** + +* Authors: *Gianni Franchi, Andrei Bursuc, Emanuel Aldea, Severine Dubuisson, and Isabelle Bloch* +* Paper: `IEEE TPAMI `_. + +**Packed-Ensembles for Efficient Uncertainty Estimation** + +* Authors: *Olivier Laurent, Adrien Lafage, Enzo Tartaglione, Geoffrey Daniel, Jean-Marc Martinez, Andrei Bursuc, and Gianni Franchi* +* Paper: `ICLR 2023 `_. + +**MUAD: Multiple Uncertainties for Autonomous Driving, a benchmark for multiple uncertainty types and tasks** + +* Authors: *Gianni Franchi, Xuanlong Yu, Andrei Bursuc, Angel Tena, Rémi Kazmierczak, Séverine Dubuisson, Emanuel Aldea, David Filliat* +* Paper: `BMVC 2022 `_. + Packed-Ensembles ^^^^^^^^^^^^^^^^ -Finally, TorchUncertainty also includes the official PyTorch implementation for -the following paper: - **Packed-Ensembles for Efficient Uncertainty Estimation** * Authors: *Olivier Laurent, Adrien Lafage, Enzo Tartaglione, Geoffrey Daniel, Jean-Marc Martinez, Andrei Bursuc, and Gianni Franchi* diff --git a/docs/source/installation.rst b/docs/source/installation.rst index a05ae32e..e1e3348f 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -14,7 +14,7 @@ From PyPI --------- Check that you have PyTorch (cpu or gpu) installed on your system. Then, install -the package via pip: +the package via pip: .. parsed-literal:: @@ -24,7 +24,7 @@ To update the package, run: .. parsed-literal:: - pip install -U torch-uncertainty + pip install -U torch-uncertainty From source ----------- diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index bd8efd44..68fc7fd0 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -11,9 +11,9 @@ These routines make it very easy to: - compute and monitor uncertainty metrics: calibration, out-of-distribution detection, proper scores, grouping loss, etc. - leverage calibration methods automatically during evaluation -Yet, we take account that their will be as many different uses of TorchUncertainty as there are of users. +Yet, we take account that their will be as many different uses of TorchUncertainty as there are of users. This page provides ideas on how to benefit from TorchUncertainty at all levels: from ready-to-train lightning-based models to using only specific -PyTorch layers. +PyTorch layers. .. figure:: _static/images/structure_torch_uncertainty.jpg :alt: TorchUncertainty structure @@ -26,39 +26,39 @@ PyTorch layers. Training with TorchUncertainty's Uncertainty-aware Routines ----------------------------------------------------------- -TorchUncertainty provides a set of Ligthning training and evaluation routines that wrap PyTorch models. Let's have a look at the +TorchUncertainty provides a set of Ligthning training and evaluation routines that wrap PyTorch models. Let's have a look at the `Classification routine `_ -and its parameters. +and its parameters. .. code:: python - from lightning.pytorch import LightningModule - - class ClassificationRoutine(LightningModule): - def __init__( - self, - model: nn.Module, - num_classes: int, - loss: nn.Module, - num_estimators: int = 1, - format_batch_fn: nn.Module | None = None, - optim_recipe: dict | Optimizer | None = None, - # ... - eval_ood: bool = False, - eval_grouping_loss: bool = False, - ood_criterion: Literal[ - "msp", "logit", "energy", "entropy", "mi", "vr" - ] = "msp", - log_plots: bool = False, - save_in_csv: bool = False, - calibration_set: Literal["val", "test"] | None = None, - ) -> None: - ... + from lightning.pytorch import LightningModule + + class ClassificationRoutine(LightningModule): + def __init__( + self, + model: nn.Module, + num_classes: int, + loss: nn.Module, + num_estimators: int = 1, + format_batch_fn: nn.Module | None = None, + optim_recipe: dict | Optimizer | None = None, + # ... + eval_ood: bool = False, + eval_grouping_loss: bool = False, + ood_criterion: Literal[ + "msp", "logit", "energy", "entropy", "mi", "vr" + ] = "msp", + log_plots: bool = False, + save_in_csv: bool = False, + calibration_set: Literal["val", "test"] | None = None, + ) -> None: + ... Building your First Routine ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -This routine is a wrapper of any custom or TorchUncertainty classification model. To use it, +This routine is a wrapper of any custom or TorchUncertainty classification model. To use it, just build your model and pass it to the routine as argument along with an optimization recipe and the loss as well as the number of classes that we use for torch metrics. @@ -80,13 +80,13 @@ Training with the Routine To train with this routine, you will first need to create a lightning Trainer and have either a lightning datamodule or PyTorch dataloaders. When benchmarking models, we advise to use lightning datamodules that will automatically handle -train/val/test splits, out-of-distribution detection and dataset shift. For this example, let us use TorchUncertainty's -CIFAR10 datamodule. Please keep in mind that you could use your own datamodule or dataloaders. +train/val/test splits, out-of-distribution detection and dataset shift. For this example, let us use TorchUncertainty's +CIFAR10 datamodule. .. code:: python from torch_uncertainty.datamodules import CIFAR10DataModule - from pytorch_lightning import Trainer + from lightning.pytorch import Trainer dm = CIFAR10DataModule(root="data", batch_size=32) trainer = Trainer(gpus=1, max_epochs=100) @@ -94,14 +94,15 @@ CIFAR10 datamodule. Please keep in mind that you could use your own datamodule o trainer.test(routine, dm) Here it is, you have trained your first model with TorchUncertainty! As a result, you will get access to various metrics -measuring the ability of your model to handle uncertainty. +measuring the ability of your model to handle uncertainty. You can get other examples of training with lightning Trainers +looking at the `Tutorials `_. More metrics ^^^^^^^^^^^^ With TorchUncertainty datamodules, you can easily test models on out-of-distribution datasets, by setting the ``eval_ood`` parameter to ``True``. You can also evaluate the grouping loss by setting ``eval_grouping_loss`` to ``True``. -Finally, you can calibrate your model using the ``calibration_set`` parameter. In this case, you will get +Finally, you can calibrate your model using the ``calibration_set`` parameter. In this case, you will get metrics for but the uncalibrated and calibrated models: the metrics corresponding to the temperature scaled model will begin with ``ts_``. diff --git a/docs/source/references.rst b/docs/source/references.rst index 674b8b7b..bd4467c9 100644 --- a/docs/source/references.rst +++ b/docs/source/references.rst @@ -63,6 +63,17 @@ For Deep Ensembles, consider citing: * Paper: `NeurIPS 2017 `__. +Monte-Carlo Dropout +^^^^^^^^^^^^^^^^^^^ + +For Monte-Carlo Dropout, consider citing: + +**Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning** + +* Authors: *Yarin Gal and Zoubin Ghahramani* +* Paper: `ICML 2016 `__. + + BatchEnsemble ^^^^^^^^^^^^^ @@ -104,15 +115,17 @@ For Packed-Ensembles, consider citing: * Authors: *Olivier Laurent, Adrien Lafage, Enzo Tartaglione, Geoffrey Daniel, Jean-Marc Martinez, Andrei Bursuc, and Gianni Franchi* * Paper: `ICLR 2023 `__. -Monte-Carlo Dropout -^^^^^^^^^^^^^^^^^^^ -For Monte-Carlo Dropout, consider citing: +LPBNN +^^^^^ -**Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning** +For LPBNN, consider citing: + +**Encoding the latent posterior of Bayesian Neural Networks for uncertainty quantification** + +* Authors: *Gianni Franchi, Andrei Bursuc, Emanuel Aldea, Severine Dubuisson, Isabelle Bloch* +* Paper: `IEEE TPAMI 2024 `__. -* Authors: *Yarin Gal and Zoubin Ghahramani* -* Paper: `ICML 2016 `__. Data Augmentation Methods ------------------------- @@ -145,7 +158,7 @@ For MixupIO, consider citing: **On the Pitfall of Mixup for Uncertainty Calibration** * Authors: *Deng-Bao Wang, Lanqing Li, Peilin Zhao, Pheng-Ann Heng, and Min-Ling Zhang* -* Paper: `CVPR 2023 ` +* Paper: `CVPR 2023 __` Warping Mixup ^^^^^^^^^^^^^ @@ -195,6 +208,26 @@ For the expected calibration error, consider citing: * Authors: *Mahdi Pakdaman Naeini, Gregory F. Cooper, and Milos Hauskrecht* * Paper: `AAAI 2015 `__. +Adaptive Calibration Error +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For the adaptive calibration error, consider citing: + +**Measuring Calibration in Deep Learning** + +* Authors: Jeremy Nixon, Mike Dusenberry, Ghassen Jerfel, Timothy Nguyen, Jeremiah Liu, Linchuan Zhang, Dustin Tran +* Paper: `CVPRW 2019 `__. + +Area Under the Risk-Coverage curve +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For the area under the risk-coverage curve, consider citing: + +**Selective classification for deep neural networks** + +* Authors: Yonatan Geifman, Ran El-Yaniv +* Paper: `NeurIPS 2017 `__. + Grouping Loss ^^^^^^^^^^^^^ @@ -214,7 +247,7 @@ The following datasets are used/implemented. MNIST ^^^^^ -**Gradient-based learning applied to document recognition** +**Gradient-based learning applied to document recognition** * Authors: *Yann LeCun, Leon Bottou, Yoshua Bengio, and Patrick Haffner* * Paper: `Proceedings of the IEEE 1998 `__. @@ -328,7 +361,7 @@ MUAD **MUAD: Multiple Uncertainties for Autonomous Driving Dataset** * Authors: Gianni Franchi, Xuanlong Yu, Andrei Bursuc, et al.* -* Paper: `BMVC 2022 ` +* Paper: `BMVC 2022 __` Architectures ------------- diff --git a/experiments/classification/cifar10/configs/resnet.yaml b/experiments/classification/cifar10/configs/resnet.yaml index aa053391..feb656c8 100644 --- a/experiments/classification/cifar10/configs/resnet.yaml +++ b/experiments/classification/cifar10/configs/resnet.yaml @@ -13,7 +13,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/batched.yaml b/experiments/classification/cifar10/configs/resnet18/batched.yaml index e71130f9..69f1fea2 100644 --- a/experiments/classification/cifar10/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet18/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/masked.yaml b/experiments/classification/cifar10/configs/resnet18/masked.yaml index 202ba0c4..a989dc2d 100644 --- a/experiments/classification/cifar10/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet18/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/mimo.yaml b/experiments/classification/cifar10/configs/resnet18/mimo.yaml index e45988db..187ec011 100644 --- a/experiments/classification/cifar10/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet18/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/packed.yaml b/experiments/classification/cifar10/configs/resnet18/packed.yaml index 79bd47f3..3e1e1dbe 100644 --- a/experiments/classification/cifar10/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet18/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet18/standard.yaml b/experiments/classification/cifar10/configs/resnet18/standard.yaml index b5406a28..2eb2586b 100644 --- a/experiments/classification/cifar10/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet18/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/batched.yaml b/experiments/classification/cifar10/configs/resnet50/batched.yaml index 7133cc5f..fc0cfeae 100644 --- a/experiments/classification/cifar10/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar10/configs/resnet50/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/masked.yaml b/experiments/classification/cifar10/configs/resnet50/masked.yaml index 00eaf9c3..41ea41a3 100644 --- a/experiments/classification/cifar10/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar10/configs/resnet50/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/mimo.yaml b/experiments/classification/cifar10/configs/resnet50/mimo.yaml index d7d23ccd..766b7371 100644 --- a/experiments/classification/cifar10/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar10/configs/resnet50/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/packed.yaml b/experiments/classification/cifar10/configs/resnet50/packed.yaml index 2ecc4e6a..9ffd0a90 100644 --- a/experiments/classification/cifar10/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar10/configs/resnet50/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/resnet50/standard.yaml b/experiments/classification/cifar10/configs/resnet50/standard.yaml index 1797df73..39b076e1 100644 --- a/experiments/classification/cifar10/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar10/configs/resnet50/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10.yaml b/experiments/classification/cifar10/configs/wideresnet28x10.yaml index fb1bea00..3cb97464 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10.yaml @@ -14,7 +14,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -22,7 +22,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml index f4010902..6ad00b9a 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml index ae31197b..3fecaf27 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml index 31a09775..b71c670f 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml index a46c6fac..cd45736c 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml index c5cd566f..65616694 100644 --- a/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml +++ b/experiments/classification/cifar10/configs/wideresnet28x10/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: @@ -32,6 +32,7 @@ model: loss: CrossEntropyLoss version: std style: cifar + dropout_rate: 0.3 data: root: ./data batch_size: 128 diff --git a/experiments/classification/cifar100/configs/resnet.yaml b/experiments/classification/cifar100/configs/resnet.yaml index d72a2c2b..f61f467b 100644 --- a/experiments/classification/cifar100/configs/resnet.yaml +++ b/experiments/classification/cifar100/configs/resnet.yaml @@ -13,7 +13,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -21,7 +21,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/batched.yaml b/experiments/classification/cifar100/configs/resnet18/batched.yaml index 61393563..ce2057dd 100644 --- a/experiments/classification/cifar100/configs/resnet18/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet18/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/masked.yaml b/experiments/classification/cifar100/configs/resnet18/masked.yaml index 31f6e2a8..36048d65 100644 --- a/experiments/classification/cifar100/configs/resnet18/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet18/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/mimo.yaml b/experiments/classification/cifar100/configs/resnet18/mimo.yaml index 7a3aec17..ddd474c9 100644 --- a/experiments/classification/cifar100/configs/resnet18/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet18/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/packed.yaml b/experiments/classification/cifar100/configs/resnet18/packed.yaml index 4e14cce9..6cf74dc5 100644 --- a/experiments/classification/cifar100/configs/resnet18/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet18/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet18/standard.yaml b/experiments/classification/cifar100/configs/resnet18/standard.yaml index f8e9b821..15fb4eae 100644 --- a/experiments/classification/cifar100/configs/resnet18/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet18/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/batched.yaml b/experiments/classification/cifar100/configs/resnet50/batched.yaml index 69259b96..1884c845 100644 --- a/experiments/classification/cifar100/configs/resnet50/batched.yaml +++ b/experiments/classification/cifar100/configs/resnet50/batched.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/masked.yaml b/experiments/classification/cifar100/configs/resnet50/masked.yaml index a1707666..a58f4453 100644 --- a/experiments/classification/cifar100/configs/resnet50/masked.yaml +++ b/experiments/classification/cifar100/configs/resnet50/masked.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/mimo.yaml b/experiments/classification/cifar100/configs/resnet50/mimo.yaml index 987a632d..9acb534a 100644 --- a/experiments/classification/cifar100/configs/resnet50/mimo.yaml +++ b/experiments/classification/cifar100/configs/resnet50/mimo.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/packed.yaml b/experiments/classification/cifar100/configs/resnet50/packed.yaml index 954caf11..0e1f9185 100644 --- a/experiments/classification/cifar100/configs/resnet50/packed.yaml +++ b/experiments/classification/cifar100/configs/resnet50/packed.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: diff --git a/experiments/classification/cifar100/configs/resnet50/standard.yaml b/experiments/classification/cifar100/configs/resnet50/standard.yaml index 575b6e6f..a1f10fab 100644 --- a/experiments/classification/cifar100/configs/resnet50/standard.yaml +++ b/experiments/classification/cifar100/configs/resnet50/standard.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc mode: max save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,14 +23,14 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: cls_val/Acc + monitor: val/cls/Acc patience: 1000 check_finite: true model: num_classes: 100 in_channels: 3 loss: CrossEntropyLoss - version: standard + version: std arch: 50 style: cifar data: diff --git a/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml b/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml new file mode 100644 index 00000000..44ccba6d --- /dev/null +++ b/experiments/classification/cifar100/configs/wideresnet28x10/standard.yaml @@ -0,0 +1,50 @@ +# lightning.pytorch==2.1.3 +seed_everything: false +eval_after_fit: true +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/wideresnet28x10 + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/cls/Acc + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val/cls/Acc + patience: 1000 + check_finite: true +model: + num_classes: 100 + in_channels: 3 + loss: CrossEntropyLoss + version: std + style: cifar + dropout_rate: 0.3 +data: + root: ./data + batch_size: 128 + auto_augment: rand-m9-n2-mstd1 +optimizer: + lr: 0.1 + momentum: 0.9 + weight_decay: 5e-4 + nesterov: true +lr_scheduler: + milestones: + - 60 + - 120 + - 160 + gamma: 0.2 diff --git a/experiments/depth/kitti/bts.py b/experiments/depth/kitti/bts.py new file mode 100644 index 00000000..456784e3 --- /dev/null +++ b/experiments/depth/kitti/bts.py @@ -0,0 +1,28 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser + +from torch_uncertainty.baselines.depth import BTSBaseline +from torch_uncertainty.datamodules.depth import KITTIDataModule +from torch_uncertainty.utils import TULightningCLI +from torch_uncertainty.utils.learning_rate import PolyLR + + +class BTSCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.AdamW) + parser.add_lr_scheduler_args(PolyLR) + + +def cli_main() -> BTSCLI: + return BTSCLI(BTSBaseline, KITTIDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/depth/kitti/configs/bts.yaml b/experiments/depth/kitti/configs/bts.yaml new file mode 100644 index 00000000..89de3232 --- /dev/null +++ b/experiments/depth/kitti/configs/bts.yaml @@ -0,0 +1,48 @@ +# lightning.pytorch==2.2.0 +eval_after_fit: true +seed_everything: false +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 50 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/bts + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/SILog + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step +model: + loss: + class_path: torch_uncertainty.metrics.SILog + init_args: + sqrt: true + version: std + arch: 50 + max_depth: 80.0 + num_estimators: 1 + pretrained_backbone: true +data: + root: ./data + batch_size: 4 + crop_size: + - 352 + - 704 + inference_size: + - 352 + - 1216 + num_workers: 4 +optimizer: + lr: 1e-4 +lr_scheduler: + power: 0.9 + total_iters: 50 diff --git a/experiments/depth/nyu/bts.py b/experiments/depth/nyu/bts.py new file mode 100644 index 00000000..20cc0330 --- /dev/null +++ b/experiments/depth/nyu/bts.py @@ -0,0 +1,28 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser + +from torch_uncertainty.baselines.depth import BTSBaseline +from torch_uncertainty.datamodules.depth import NYUv2DataModule +from torch_uncertainty.utils import TULightningCLI +from torch_uncertainty.utils.learning_rate import PolyLR + + +class BTSCLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.AdamW) + parser.add_lr_scheduler_args(PolyLR) + + +def cli_main() -> BTSCLI: + return BTSCLI(BTSBaseline, NYUv2DataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/depth/nyu/configs/bts.yaml b/experiments/depth/nyu/configs/bts.yaml new file mode 100644 index 00000000..8a9d0957 --- /dev/null +++ b/experiments/depth/nyu/configs/bts.yaml @@ -0,0 +1,52 @@ +# lightning.pytorch==2.2.0 +eval_after_fit: true +seed_everything: false +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 100 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/bts + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/SILog + mode: min + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step +model: + loss: + class_path: torch_uncertainty.metrics.SILog + init_args: + sqrt: true + version: std + arch: 50 + max_depth: 10.0 + num_estimators: 1 + pretrained_backbone: true +data: + root: ./data + batch_size: 8 + crop_size: + - 416 + - 544 + inference_size: + - 480 + - 640 + num_workers: 8 + max_depth: 10.0 + min_depth: 1e-3 +optimizer: + lr: 1e-4 + weight_decay: 1e-2 + eps: 1e-3 +lr_scheduler: + power: 0.9 + total_iters: 100 diff --git a/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml index 2e9b056d..9d6e17ae 100644 --- a/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/gaussian_mlp_kin8nm.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: reg_val/NLL + monitor: val/NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: reg_val/NLL + monitor: val/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml index d95e09a1..c906150c 100644 --- a/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/laplace_mlp_kin8nm.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: reg_val/NLL + monitor: val/NLL mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: reg_val/NLL + monitor: val/NLL patience: 1000 check_finite: true model: diff --git a/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml b/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml index b6ce9fad..ca09ac4a 100644 --- a/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml +++ b/experiments/regression/uci_datasets/configs/pw_mlp_kin8nm.yaml @@ -15,7 +15,7 @@ trainer: callbacks: - class_path: lightning.pytorch.callbacks.ModelCheckpoint init_args: - monitor: reg_val/MSE + monitor: val/MSE mode: min save_last: true - class_path: lightning.pytorch.callbacks.LearningRateMonitor @@ -23,7 +23,7 @@ trainer: logging_interval: step - class_path: lightning.pytorch.callbacks.EarlyStopping init_args: - monitor: reg_val/MSE + monitor: val/MSE patience: 1000 check_finite: true model: diff --git a/experiments/segmentation/cityscapes/configs/deeplab.yaml b/experiments/segmentation/cityscapes/configs/deeplab.yaml new file mode 100644 index 00000000..babefa1c --- /dev/null +++ b/experiments/segmentation/cityscapes/configs/deeplab.yaml @@ -0,0 +1,46 @@ +# lightning.pytorch==2.2.0 +eval_after_fit: true +seed_everything: false +trainer: + accelerator: gpu + devices: 1 + precision: 16-mixed + max_epochs: 200 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/deeplab + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/seg/mIoU + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step +model: + num_classes: 19 + loss: CrossEntropyLoss + version: std + arch: 50 + style: v3+ + output_stride: 16 + separable: false + num_estimators: 1 +data: + root: ./data/Cityscapes + batch_size: 8 + crop_size: 768 + inference_size: + - 1024 + - 2048 + num_workers: 8 +optimizer: + lr: 1e-2 + weight_decay: 1e-4 + momentum: 0.9 +lr_scheduler: + total_iters: 200 diff --git a/experiments/segmentation/cityscapes/configs/segformer.yaml b/experiments/segmentation/cityscapes/configs/segformer.yaml index b2abf11e..0ae0c212 100644 --- a/experiments/segmentation/cityscapes/configs/segformer.yaml +++ b/experiments/segmentation/cityscapes/configs/segformer.yaml @@ -5,6 +5,21 @@ trainer: accelerator: gpu devices: 1 max_steps: 160000 + logger: + class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: logs/segformer + name: standard + default_hp_metric: false + callbacks: + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val/seg/mIoU + mode: max + save_last: true + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step model: num_classes: 19 loss: CrossEntropyLoss @@ -12,13 +27,13 @@ model: arch: 0 num_estimators: 1 data: - root: ./data + root: ./data/Cityscapes batch_size: 8 crop_size: 1024 inference_size: - 1024 - 2048 - num_workers: 30 + num_workers: 8 optimizer: lr: 6e-5 lr_scheduler: diff --git a/experiments/segmentation/cityscapes/deeplab.py b/experiments/segmentation/cityscapes/deeplab.py new file mode 100644 index 00000000..ce064b05 --- /dev/null +++ b/experiments/segmentation/cityscapes/deeplab.py @@ -0,0 +1,28 @@ +import torch +from lightning.pytorch.cli import LightningArgumentParser + +from torch_uncertainty.baselines.segmentation import DeepLabBaseline +from torch_uncertainty.datamodules.segmentation import CityscapesDataModule +from torch_uncertainty.utils import TULightningCLI +from torch_uncertainty.utils.learning_rate import PolyLR + + +class DeepLabV3CLI(TULightningCLI): + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + parser.add_optimizer_args(torch.optim.SGD) + parser.add_lr_scheduler_args(PolyLR) + + +def cli_main() -> DeepLabV3CLI: + return DeepLabV3CLI(DeepLabBaseline, CityscapesDataModule) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("medium") + cli = cli_main() + if ( + (not cli.trainer.fast_dev_run) + and cli.subcommand == "fit" + and cli._get(cli.config, "eval_after_fit") + ): + cli.trainer.test(datamodule=cli.datamodule, ckpt_path="best") diff --git a/experiments/segmentation/readme.md b/experiments/segmentation/readme.md index e8ef0698..99c55bd9 100644 --- a/experiments/segmentation/readme.md +++ b/experiments/segmentation/readme.md @@ -1 +1,3 @@ # Segmentation Benchmarks + +Note: Optimize the number of `data.workers` to your computer to gain speed and avoid pauses. diff --git a/pyproject.toml b/pyproject.toml index 4d003ce5..0b11a230 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,11 +41,15 @@ dependencies = [ "huggingface-hub", "scikit-learn", "matplotlib", + "opencv-python", "glest==0.0.1a0", ] [project.optional-dependencies] +image = ["scikit-image", "h5py",] +tabular = ["pandas"] dev = [ + "torch_uncertainty[image]", "ruff==0.3.4", "pytest-cov", "pre-commit", @@ -59,8 +63,6 @@ docs = [ "sphinx-design", "sphinx-codeautolink", ] -image = ["scikit-image", "opencv-python"] -tabular = ["pandas"] all = ["torch_uncertainty[dev,docs,image,tabular]"] [project.urls] @@ -114,6 +116,7 @@ lint.ignore = [ "D205", "D206", "ISC001", + "N818", "N812", "RUF012", "S101", diff --git a/tests/_dummies/__init__.py b/tests/_dummies/__init__.py index ac5d4d0d..d942a5ae 100644 --- a/tests/_dummies/__init__.py +++ b/tests/_dummies/__init__.py @@ -1,16 +1,19 @@ # ruff: noqa: F401 from .baseline import ( DummyClassificationBaseline, + DummyDepthBaseline, DummyRegressionBaseline, DummySegmentationBaseline, ) from .datamodule import ( DummyClassificationDataModule, + DummyDepthDataModule, DummyRegressionDataModule, DummySegmentationDataModule, ) from .dataset import ( DummyClassificationDataset, + DummyDepthDataset, DummyRegressionDataset, DummySegmentationDataset, ) diff --git a/tests/_dummies/baseline.py b/tests/_dummies/baseline.py index b650f180..c43b444c 100644 --- a/tests/_dummies/baseline.py +++ b/tests/_dummies/baseline.py @@ -1,6 +1,5 @@ import copy -from pytorch_lightning import LightningModule from torch import nn from torch_uncertainty.layers.distributions import ( @@ -11,6 +10,7 @@ from torch_uncertainty.models.deep_ensembles import deep_ensembles from torch_uncertainty.routines import ( ClassificationRoutine, + PixelRegressionRoutine, RegressionRoutine, SegmentationRoutine, ) @@ -41,7 +41,7 @@ def __new__( kernel_tau_std: float = 0.5, mixup_alpha: float = 0, cutmix_alpha: float = 0, - ) -> LightningModule: + ) -> ClassificationRoutine: model = dummy_model( in_channels=in_channels, num_classes=num_classes, @@ -102,7 +102,7 @@ def __new__( baseline_type: str = "single", optim_recipe=None, dist_type: str = "normal", - ) -> LightningModule: + ) -> RegressionRoutine: if probabilistic: if dist_type == "normal": last_layer = NormalLayer(output_dim) @@ -157,7 +157,9 @@ def __new__( loss: type[nn.Module], baseline_type: str = "single", optim_recipe=None, - ) -> LightningModule: + metric_subsampling_rate: float = 1, + log_plots: bool = False, + ) -> SegmentationRoutine: model = dummy_segmentation_model( in_channels=in_channels, num_classes=num_classes, @@ -172,6 +174,8 @@ def __new__( format_batch_fn=None, num_estimators=1, optim_recipe=optim_recipe(model), + metric_subsampling_rate=metric_subsampling_rate, + log_plots=log_plots, ) # baseline_type == "ensemble": @@ -186,4 +190,50 @@ def __new__( format_batch_fn=RepeatTarget(2), num_estimators=2, optim_recipe=optim_recipe(model), + metric_subsampling_rate=metric_subsampling_rate, + log_plots=log_plots, + ) + + +class DummyDepthBaseline: + def __new__( + cls, + in_channels: int, + output_dim: int, + image_size: int, + loss: type[nn.Module], + baseline_type: str = "single", + optim_recipe=None, + ) -> PixelRegressionRoutine: + model = dummy_segmentation_model( + num_classes=output_dim, + in_channels=in_channels, + image_size=image_size, + ) + + if baseline_type == "single": + return PixelRegressionRoutine( + output_dim=output_dim, + probabilistic=False, + model=model, + loss=loss, + format_batch_fn=None, + num_estimators=1, + optim_recipe=optim_recipe(model), + ) + + # baseline_type == "ensemble": + model = deep_ensembles( + [model, copy.deepcopy(model)], + task="pixel_regression", + probabilistic=False, + ) + return PixelRegressionRoutine( + output_dim=output_dim, + probabilistic=False, + model=model, + loss=loss, + format_batch_fn=RepeatTarget(2), + num_estimators=2, + optim_recipe=optim_recipe(model), ) diff --git a/tests/_dummies/datamodule.py b/tests/_dummies/datamodule.py index 9cf0ab77..51c769dd 100644 --- a/tests/_dummies/datamodule.py +++ b/tests/_dummies/datamodule.py @@ -11,6 +11,7 @@ from .dataset import ( DummyClassificationDataset, + DummyDepthDataset, DummyRegressionDataset, DummySegmentationDataset, ) @@ -246,3 +247,92 @@ def _get_train_data(self) -> ArrayLike: def _get_train_targets(self) -> ArrayLike: return np.array(self.train.targets) + + +class DummyDepthDataModule(AbstractDataModule): + num_channels = 3 + training_task = "pixel_regression" + + def __init__( + self, + root: str | Path, + batch_size: int, + output_dim: int = 2, + num_workers: int = 1, + image_size: int = 4, + pin_memory: bool = True, + persistent_workers: bool = True, + num_images: int = 2, + ) -> None: + super().__init__( + root=root, + val_split=None, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + self.output_dim = output_dim + self.num_channels = 3 + self.num_images = num_images + self.image_size = image_size + + self.dataset = DummyDepthDataset + + self.train_transform = T.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.float32, + "others": None, + }, + scale=True, + ) + self.test_transform = T.ToDtype( + dtype={ + tv_tensors.Image: torch.float32, + tv_tensors.Mask: torch.float32, + "others": None, + }, + scale=True, + ) + + def prepare_data(self) -> None: + pass + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + self.train = self.dataset( + self.root, + num_channels=self.num_channels, + output_dim=self.output_dim, + image_size=self.image_size, + transforms=self.train_transform, + num_images=self.num_images, + ) + self.val = self.dataset( + self.root, + num_channels=self.num_channels, + output_dim=self.output_dim, + image_size=self.image_size, + transforms=self.test_transform, + num_images=self.num_images, + ) + elif stage == "test": + self.test = self.dataset( + self.root, + num_channels=self.num_channels, + output_dim=self.output_dim, + image_size=self.image_size, + transforms=self.test_transform, + num_images=self.num_images, + ) + + def test_dataloader(self) -> DataLoader | list[DataLoader]: + return [self._data_loader(self.test)] + + def _get_train_data(self) -> ArrayLike: + return self.train.data + + def _get_train_targets(self) -> ArrayLike: + return np.array(self.train.targets) diff --git a/tests/_dummies/dataset.py b/tests/_dummies/dataset.py index 3e5e4024..1ab0c66b 100644 --- a/tests/_dummies/dataset.py +++ b/tests/_dummies/dataset.py @@ -220,9 +220,10 @@ def __init__( root: Path, split: str = "train", transforms: Callable[..., Any] | None = None, - num_channels: int = 3, + input_channels: int = 3, image_size: int = 4, num_images: int = 2, + output_dim: int = 1, **args, ) -> None: super().__init__() @@ -234,12 +235,15 @@ def __init__( self.data: Any = [] self.targets = [] - if num_channels == 1: + if input_channels == 1: img_shape = (num_images, image_size, image_size) else: - img_shape = (num_images, num_channels, image_size, image_size) + img_shape = (num_images, input_channels, image_size, image_size) - smnt_shape = (num_images, 1, image_size, image_size) + if output_dim == 1: + smnt_shape = (num_images, image_size, image_size) + else: + smnt_shape = (num_images, output_dim, image_size, image_size) self.data = np.random.randint( low=0, diff --git a/tests/_dummies/test_dummy_dataset.py b/tests/_dummies/test_dummy_dataset.py deleted file mode 100644 index 80a5f939..00000000 --- a/tests/_dummies/test_dummy_dataset.py +++ /dev/null @@ -1,38 +0,0 @@ -from torchvision.transforms import ToTensor - -from .dataset import DummyClassificationDataset, DummyRegressionDataset -from .transform import DummyTransform - - -class TestDummyClassificationDataset: - """Testing the Dummy dataset class.""" - - def test_dataset(self): - dataset = DummyClassificationDataset( - "./.data", transform=ToTensor(), target_transform=DummyTransform() - ) - _ = len(dataset) - _, _ = dataset[0] - - def test_dataset_notransform(self): - dataset = DummyClassificationDataset("./.data") - _ = len(dataset) - _, _ = dataset[0] - - -class TestDummyRegressionDataset: - """Testing the Dummy dataset class.""" - - def test_dataset(self): - dataset = DummyRegressionDataset( - "./.data", - transform=DummyTransform(), - target_transform=DummyTransform(), - ) - _ = len(dataset) - _, _ = dataset[0] - - def test_dataset_notransform(self): - dataset = DummyRegressionDataset("./.data") - _ = len(dataset) - _, _ = dataset[0] diff --git a/tests/baselines/test_standard.py b/tests/baselines/test_standard.py index 77cb948f..6db1c03b 100644 --- a/tests/baselines/test_standard.py +++ b/tests/baselines/test_standard.py @@ -9,7 +9,10 @@ WideResNetBaseline, ) from torch_uncertainty.baselines.regression import MLPBaseline -from torch_uncertainty.baselines.segmentation import SegFormerBaseline +from torch_uncertainty.baselines.segmentation import ( + DeepLabBaseline, + SegFormerBaseline, +) class TestStandardBaseline: @@ -151,3 +154,32 @@ def test_errors(self): version="test", arch=0, ) + + +class TestStandardDeepLabBaseline: + """Testing the DeepLab baseline class.""" + + def test_standard(self): + net = DeepLabBaseline( + num_classes=10, + loss=nn.CrossEntropyLoss(), + version="std", + style="v3", + output_stride=16, + arch=50, + separable=True, + ).eval() + summary(net) + _ = net(torch.rand(1, 3, 32, 32)) + + def test_errors(self): + with pytest.raises(ValueError): + DeepLabBaseline( + num_classes=10, + loss=nn.CrossEntropyLoss(), + version="test", + style="v3", + output_stride=16, + arch=50, + separable=True, + ) diff --git a/tests/datamodules/depth_estimation/test_muad.py b/tests/datamodules/depth_estimation/test_muad.py deleted file mode 100644 index cce2088a..00000000 --- a/tests/datamodules/depth_estimation/test_muad.py +++ /dev/null @@ -1,37 +0,0 @@ -import pytest - -from tests._dummies.dataset import DummyDepthDataset -from torch_uncertainty.datamodules.depth_estimation import MUADDataModule -from torch_uncertainty.datasets import MUAD - - -class TestMUADDataModule: - """Testing the MUADDataModule datamodule.""" - - def test_camvid_main(self): - dm = MUADDataModule(root="./data/", batch_size=128) - - assert dm.dataset == MUAD - - dm.dataset = DummyDepthDataset - - dm.prepare_data() - dm.setup() - - with pytest.raises(ValueError): - dm.setup("xxx") - - # test abstract methods - dm.get_train_set() - dm.get_val_set() - dm.get_test_set() - - dm.train_dataloader() - dm.val_dataloader() - dm.test_dataloader() - - dm.val_split = 0.1 - dm.prepare_data() - dm.setup() - dm.train_dataloader() - dm.val_dataloader() diff --git a/tests/datamodules/test_depth.py b/tests/datamodules/test_depth.py new file mode 100644 index 00000000..bee19975 --- /dev/null +++ b/tests/datamodules/test_depth.py @@ -0,0 +1,79 @@ +import pytest + +from tests._dummies.dataset import DummyDepthDataset +from torch_uncertainty.datamodules.depth import ( + KITTIDataModule, + MUADDataModule, + NYUv2DataModule, +) +from torch_uncertainty.datasets import MUAD, KITTIDepth, NYUv2 + + +class TestMUADDataModule: + """Testing the MUADDataModule datamodule.""" + + def test_muad_main(self): + dm = MUADDataModule( + root="./data/", min_depth=0, max_depth=100, batch_size=128 + ) + + assert dm.dataset == MUAD + + dm.dataset = DummyDepthDataset + + dm.prepare_data() + dm.setup() + + with pytest.raises(ValueError): + dm.setup("xxx") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + + dm.val_split = 0.1 + dm.prepare_data() + dm.setup() + dm.train_dataloader() + dm.val_dataloader() + + +class TestNYUDataModule: + """Testing the MUADDataModule datamodule.""" + + def test_nyu_main(self): + dm = NYUv2DataModule(root="./data/", max_depth=100, batch_size=128) + + assert dm.dataset == NYUv2 + + dm.dataset = DummyDepthDataset + + dm.prepare_data() + dm.setup() + + with pytest.raises(ValueError): + dm.setup("xxx") + + # test abstract methods + dm.get_train_set() + dm.get_val_set() + dm.get_test_set() + + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + + dm.val_split = 0.1 + dm.prepare_data() + dm.setup() + dm.train_dataloader() + dm.val_dataloader() + + def test_kitti_main(self): + dm = KITTIDataModule(root="./data/", max_depth=100, batch_size=128) + assert dm.dataset == KITTIDepth diff --git a/tests/datasets/test_kitti.py b/tests/datasets/test_kitti.py new file mode 100644 index 00000000..9afd4f4f --- /dev/null +++ b/tests/datasets/test_kitti.py @@ -0,0 +1,11 @@ +import pytest + +from torch_uncertainty.datasets import KITTIDepth + + +class TestKITTIDepth: + """Testing the KITTIDepth dataset class.""" + + def test_nodataset(self): + with pytest.raises(FileNotFoundError): + _ = KITTIDepth("./.data", split="train") diff --git a/tests/layers/test_bayesian.py b/tests/layers/test_bayesian.py index d5e52128..420da049 100644 --- a/tests/layers/test_bayesian.py +++ b/tests/layers/test_bayesian.py @@ -6,6 +6,8 @@ BayesConv2d, BayesConv3d, BayesLinear, + LPBNNConv2d, + LPBNNLinear, ) from torch_uncertainty.layers.bayesian.sampler import TrainableDistribution @@ -157,3 +159,60 @@ def test_log_posterior(self) -> None: sampler = TrainableDistribution(torch.ones(1), torch.ones(1)) with pytest.raises(ValueError): sampler.log_posterior() + + +class TestLPBNNLinear: + """Testing the LPBNNLinear layer class.""" + + def test_linear(self, feat_input_odd: torch.Tensor) -> None: + layer = LPBNNLinear(10, 2, num_estimators=4) + print(layer) + out = layer(feat_input_odd.repeat(4, 1)) + assert out.shape == torch.Size([5 * 4, 2]) + + layer = LPBNNLinear(10, 2, num_estimators=4, bias=False) + layer = layer.eval() + out = layer(feat_input_odd.repeat(4, 1)) + assert out.shape == torch.Size([5 * 4, 2]) + + def test_linear_even(self, feat_input_even: torch.Tensor) -> None: + layer = LPBNNLinear(10, 2, num_estimators=4) + out = layer(feat_input_even.repeat(4, 1)) + assert out.shape == torch.Size([8 * 4, 2]) + + out = layer(feat_input_even) + + +class TestLPBNNConv2d: + """Testing the LPBNNConv2d layer class.""" + + def test_conv2(self, img_input_odd: torch.Tensor) -> None: + layer = LPBNNConv2d(10, 2, kernel_size=1, num_estimators=4) + print(layer) + out = layer(img_input_odd.repeat(4, 1, 1, 1)) + assert out.shape == torch.Size([5 * 4, 2, 3, 3]) + + layer = LPBNNConv2d( + 10, 2, kernel_size=1, num_estimators=4, bias=False, gamma=False + ) + layer = layer.eval() + out = layer(img_input_odd.repeat(4, 1, 1, 1)) + assert out.shape == torch.Size([5 * 4, 2, 3, 3]) + + def test_conv2_even(self, img_input_even: torch.Tensor) -> None: + layer = LPBNNConv2d( + 10, 2, kernel_size=1, num_estimators=4, padding_mode="reflect" + ) + print(layer) + out = layer(img_input_even.repeat(4, 1, 1, 1)) + assert out.shape == torch.Size([8 * 4, 2, 3, 3]) + + out = layer(img_input_even) + + def test_errors(self) -> None: + with pytest.raises(ValueError, match="std_factor must be"): + LPBNNConv2d(10, 2, kernel_size=1, num_estimators=1, std_factor=-1) + with pytest.raises(ValueError, match="num_estimators must be"): + LPBNNConv2d(10, 2, kernel_size=1, num_estimators=-1) + with pytest.raises(ValueError, match="hidden_size must be"): + LPBNNConv2d(10, 2, kernel_size=1, num_estimators=1, hidden_size=-1) diff --git a/tests/metrics/classification/test_calibration.py b/tests/metrics/classification/test_calibration.py index fb3c2035..ee8ab224 100644 --- a/tests/metrics/classification/test_calibration.py +++ b/tests/metrics/classification/test_calibration.py @@ -2,14 +2,14 @@ import pytest import torch -from torch_uncertainty.metrics import CE +from torch_uncertainty.metrics import AdaptiveCalibrationError, CalibrationError -class TestCE: - """Testing the CE metric class.""" +class TestCalibrationError: + """Testing the CalibrationError metric class.""" def test_plot_binary(self) -> None: - metric = CE(task="binary", n_bins=2, norm="l1") + metric = CalibrationError(task="binary", n_bins=2, norm="l1") metric.update( torch.as_tensor([0.25, 0.25, 0.55, 0.75, 0.75]), torch.as_tensor([0, 0, 1, 1, 1]), @@ -24,7 +24,9 @@ def test_plot_binary(self) -> None: def test_plot_multiclass( self, ) -> None: - metric = CE(task="multiclass", n_bins=3, norm="l1", num_classes=3) + metric = CalibrationError( + task="multiclass", n_bins=3, norm="l1", num_classes=3 + ) metric.update( torch.as_tensor( [ @@ -44,8 +46,73 @@ def test_plot_multiclass( plt.close(fig) def test_errors(self) -> None: - with pytest.raises(ValueError): - _ = CE(task="geometric_mean") + with pytest.raises(TypeError, match="is expected to be `int`"): + CalibrationError(task="multiclass", num_classes=None) + + +class TestAdaptiveCalibrationError: + """Testing the AdaptiveCalibrationError metric class.""" + + def test_main(self) -> None: + ace = AdaptiveCalibrationError( + task="binary", num_bins=2, norm="l1", validate_args=True + ) + ace = AdaptiveCalibrationError( + task="binary", num_bins=2, norm="l1", validate_args=False + ) + ece = CalibrationError(task="binary", num_bins=2, norm="l1") + ace.update( + torch.as_tensor([0.35, 0.35, 0.75, 0.75]), + torch.as_tensor([0, 0, 1, 1]), + ) + ece.update( + torch.as_tensor([0.35, 0.35, 0.75, 0.75]), + torch.as_tensor([0, 0, 1, 1]), + ) + assert ace.compute().item() == ece.compute().item() + + ace.reset() + ace.update( + torch.as_tensor([0.3, 0.24, 0.25, 0.2, 0.8]), + torch.as_tensor([0, 0, 0, 1, 1]), + ) + assert ace.compute().item() == pytest.approx( + 3 / 5 * (1 - 1 / 3 * (0.7 + 0.76 + 0.75)) + 2 / 5 * (0.8 - 0.5) + ) - with pytest.raises(ValueError): - _ = CE(task="multiclass", num_classes=1.5) + ace = AdaptiveCalibrationError( + task="multiclass", + num_classes=2, + num_bins=2, + norm="l2", + validate_args=True, + ) + ace.update( + torch.as_tensor( + [[0.7, 0.3], [0.76, 0.24], [0.75, 0.25], [0.2, 0.8], [0.8, 0.2]] + ), + torch.as_tensor([0, 0, 0, 1, 1]), + ) + assert ace.compute().item() ** 2 == pytest.approx( + 3 / 5 * (1 - 1 / 3 * (0.7 + 0.76 + 0.75)) ** 2 + + 2 / 5 * (0.8 - 0.5) ** 2 + ) + + ace = AdaptiveCalibrationError( + task="multiclass", + num_classes=2, + num_bins=2, + norm="max", + validate_args=False, + ) + ace.update( + torch.as_tensor( + [[0.7, 0.3], [0.76, 0.24], [0.75, 0.25], [0.2, 0.8], [0.8, 0.2]] + ), + torch.as_tensor([0, 0, 0, 1, 1]), + ) + assert ace.compute().item() ** 2 == pytest.approx((0.8 - 0.5) ** 2) + + def test_errors(self) -> None: + with pytest.raises(TypeError, match="is expected to be `int`"): + AdaptiveCalibrationError(task="multiclass", num_classes=None) diff --git a/tests/metrics/classification/test_fpr95.py b/tests/metrics/classification/test_fpr95.py index e94e785c..99bb0dc3 100644 --- a/tests/metrics/classification/test_fpr95.py +++ b/tests/metrics/classification/test_fpr95.py @@ -1,7 +1,7 @@ import pytest import torch -from torch_uncertainty.metrics.classification.fpr95 import FPR95, FPRx +from torch_uncertainty.metrics.classification import FPR95, FPRx class TestFPR95: diff --git a/tests/metrics/classification/test_risk_coverage.py b/tests/metrics/classification/test_risk_coverage.py new file mode 100644 index 00000000..3506479f --- /dev/null +++ b/tests/metrics/classification/test_risk_coverage.py @@ -0,0 +1,129 @@ +import matplotlib.pyplot as plt +import pytest +import torch + +from torch_uncertainty.metrics.classification import ( + AURC, + CovAtxRisk, + RiskAtxCov, +) + + +class TestAURC: + """Testing the AURC metric class.""" + + def test_compute_binary(self) -> None: + probs = torch.as_tensor([0.1, 0.2, 0.3, 0.4, 0.2]) + targets = torch.as_tensor([1, 1, 1, 1, 1]) + metric = AURC() + assert metric(probs, targets).item() == pytest.approx(1) + targets = torch.as_tensor([0, 0, 0, 0, 0]) + metric = AURC() + assert metric(probs, targets).item() == pytest.approx(0) + targets = torch.as_tensor([0, 0, 1, 1, 0]) + metric = AURC() + value = (0 * 0.4 + 0.25 * 0.2 / 2 + 0.25 * 0.2 + 0.15 * 0.2 / 2) / 0.8 + assert metric(probs, targets).item() == pytest.approx(value) + + def test_compute_multiclass(self) -> None: + probs = torch.as_tensor( + [[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6], [0.2, 0.8]] + ) + targets = torch.as_tensor([1, 1, 1, 1, 1]).long() + metric = AURC() + assert metric(probs, targets).item() == pytest.approx(0) + targets = torch.as_tensor([0, 0, 0, 0, 0]) + metric = AURC() + assert metric(probs, targets).item() == pytest.approx(1) + targets = torch.as_tensor([1, 1, 0, 0, 1]) + metric = AURC() + value = (0 * 0.4 + 0.25 * 0.2 / 2 + 0.25 * 0.2 + 0.15 * 0.2 / 2) / 0.8 + assert metric(probs, targets).item() == pytest.approx(value) + + def test_plot(self) -> None: + scores = torch.as_tensor([0.2, 0.1, 0.5, 0.3, 0.4]) + values = torch.as_tensor([0.1, 0.2, 0.3, 0.4, 0.5]) + metric = AURC() + metric.update(scores, values) + fig, ax = metric.plot() + assert isinstance(fig, plt.Figure) + assert isinstance(ax, plt.Axes) + assert ax.get_xlabel() == "Coverage (%)" + assert ax.get_ylabel() == "Risk - Error Rate (%)" + plt.close(fig) + + metric = AURC() + metric.update(scores, values) + fig, ax = metric.plot(plot_value=False) + assert isinstance(fig, plt.Figure) + assert isinstance(ax, plt.Axes) + assert ax.get_xlabel() == "Coverage (%)" + assert ax.get_ylabel() == "Risk - Error Rate (%)" + plt.close(fig) + + +class TestCovAtxRisk: + """Testing the CovAtxRisk metric class.""" + + def test_compute_zero(self) -> None: + probs = torch.as_tensor( + [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.8, 0.2]] + ) + targets = torch.as_tensor([1, 1, 1, 1, 1]) + metric = CovAtxRisk(risk_threshold=0.5) + # no cov for given risk + assert torch.isnan(metric(probs, targets)) + + probs = torch.as_tensor( + [0.1, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.48, 0.49] + ) + targets = torch.as_tensor([1, 0, 1, 1, 1, 0, 0, 0, 1]) + metric = CovAtxRisk(risk_threshold=0.55) + # multiple cov for given risk + assert metric(probs, targets) == pytest.approx(8 / 9) + + probs = torch.as_tensor([0.1, 0.2, 0.3, 0.4, 0.2]) + targets = torch.as_tensor([0, 0, 1, 1, 1]) + metric = CovAtxRisk(risk_threshold=0.5) + assert metric(probs, targets) == pytest.approx(4 / 5) + + targets = torch.as_tensor([0, 0, 1, 1, 0]) + metric = CovAtxRisk(risk_threshold=0.5) + assert metric(probs, targets) == 1 + + def test_errors(self): + with pytest.raises( + TypeError, match="Expected threshold to be of type float" + ): + CovAtxRisk(risk_threshold="0.5") + with pytest.raises( + ValueError, match="Threshold should be in the range" + ): + CovAtxRisk(risk_threshold=-0.5) + + +class TestRiskAtxCov: + """Testing the RiskAtxCov metric class.""" + + def test_compute_zero(self) -> None: + probs = torch.as_tensor( + [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4], [0.8, 0.2]] + ) + targets = torch.as_tensor([1, 1, 1, 1, 1]) + metric = RiskAtxCov(cov_threshold=0.5) + assert metric(probs, targets) == 1 + + probs = torch.as_tensor([0.1, 0.2, 0.3, 0.4, 0.2]) + targets = torch.as_tensor([0, 0, 1, 1, 1]) + metric = RiskAtxCov(cov_threshold=0.5) + assert metric(probs, targets) == pytest.approx(1 / 3) + + probs = torch.as_tensor([0.1, 0.19, 0.3, 0.15, 0.4, 0.2]) + targets = torch.as_tensor([0, 0, 1, 0, 1, 1]) + metric = RiskAtxCov(cov_threshold=0.5) + assert metric(probs, targets) == 0 + + probs = torch.as_tensor([0.1, 0.2, 0.3, 0.15, 0.4, 0.2]) + targets = torch.as_tensor([0, 0, 1, 0, 1, 1]) + metric = RiskAtxCov(cov_threshold=0.55) + assert metric(probs, targets) == 1 / 4 diff --git a/tests/metrics/classification/test_sparsification.py b/tests/metrics/classification/test_sparsification.py index b5dcd9eb..e89df5fe 100644 --- a/tests/metrics/classification/test_sparsification.py +++ b/tests/metrics/classification/test_sparsification.py @@ -1,34 +1,23 @@ import matplotlib.pyplot as plt -import pytest import torch from torch_uncertainty.metrics import AUSE -@pytest.fixture -def uncertainty_scores() -> torch.Tensor: - return torch.as_tensor([0.2, 0.1, 0.5, 0.3, 0.4]) - - -@pytest.fixture -def error_values() -> torch.Tensor: - return torch.as_tensor([0.1, 0.2, 0.3, 0.4, 0.5]) - - class TestAUSE: """Testing the AUSE metric class.""" - def test_compute_zero(self, error_values: torch.Tensor) -> None: + def test_compute_zero(self) -> None: + values = torch.as_tensor([0.1, 0.2, 0.3, 0.4, 0.5]) metric = AUSE() - metric.update(error_values, error_values) - res = metric.compute() - assert res == 0 + metric.update(values, values) + assert metric.compute() == 0 - def test_plot( - self, uncertainty_scores: torch.Tensor, error_values: torch.Tensor - ) -> None: + def test_plot(self) -> None: + scores = torch.as_tensor([0.2, 0.1, 0.5, 0.3, 0.4]) + values = torch.as_tensor([0.1, 0.2, 0.3, 0.4, 0.5]) metric = AUSE() - metric.update(uncertainty_scores, error_values) + metric.update(scores, values) fig, ax = metric.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) @@ -37,7 +26,7 @@ def test_plot( plt.close(fig) metric = AUSE() - metric.update(uncertainty_scores, error_values) + metric.update(scores, values) fig, ax = metric.plot(plot_oracle=False, plot_value=False) assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) diff --git a/tests/metrics/regression/test_depth_estimation_metrics.py b/tests/metrics/regression/test_depth.py similarity index 85% rename from tests/metrics/regression/test_depth_estimation_metrics.py rename to tests/metrics/regression/test_depth.py index 0c1fbea0..4281ec5f 100644 --- a/tests/metrics/regression/test_depth_estimation_metrics.py +++ b/tests/metrics/regression/test_depth.py @@ -16,12 +16,12 @@ class TestLog10: def test_main(self): metric = Log10() - preds = torch.rand((10, 2)).double() - targets = torch.rand((10, 2)).double() + preds = torch.rand((10, 2)).double() + 0.01 + targets = torch.rand((10, 2)).double() + 0.01 metric.update(preds[:, 0], targets[:, 0]) metric.update(preds[:, 1], targets[:, 1]) assert torch.mean( - preds.log10().flatten() - targets.log10().flatten() + torch.abs(preds.log10().flatten() - targets.log10().flatten()) ) == pytest.approx(metric.compute()) @@ -70,6 +70,19 @@ def test_main(self): ** 2 ) == pytest.approx(metric.compute()) + metric = SILog(sqrt=True) + preds = torch.rand((10, 2)).double() + targets = torch.rand((10, 2)).double() + metric.update(preds[:, 0], targets[:, 0]) + metric.update(preds[:, 1], targets[:, 1]) + mean_log_dists = torch.mean( + targets.flatten().log() - preds.flatten().log() + ) + assert torch.mean( + (preds.flatten().log() - targets.flatten().log() + mean_log_dists) + ** 2 + ) ** 0.5 == pytest.approx(metric.compute()) + class TestThresholdAccuracy: """Testing the ThresholdAccuracy metric.""" diff --git a/tests/metrics/regression/test_inverse.py b/tests/metrics/regression/test_inverse.py new file mode 100644 index 00000000..7c56a2b9 --- /dev/null +++ b/tests/metrics/regression/test_inverse.py @@ -0,0 +1,46 @@ +import pytest +import torch + +from torch_uncertainty.metrics import ( + MeanAbsoluteErrorInverse, + MeanSquaredErrorInverse, +) + + +class TestMeanAbsoluteErrorInverse: + """Test the MeanAbsoluteErrorInverse metric.""" + + def test_main(self): + preds = torch.tensor([1, 1 / 2, 1 / 3]) + target = torch.tensor([1, 1 / 2, 1 / 3]) + metric = MeanAbsoluteErrorInverse(unit="m") + metric.update(preds, target) + assert metric.compute() == pytest.approx(0) + + metric.reset() + target = torch.tensor([1, 1, 1]) + metric.update(preds, target) + assert metric.compute() == pytest.approx(1) + + MeanAbsoluteErrorInverse(unit="mm") + MeanAbsoluteErrorInverse(unit="km") + + def test_error(self): + with pytest.raises(ValueError, match="unit must be one of 'mm'"): + MeanAbsoluteErrorInverse(unit="cm") + + +class TestMeanSquaredErrorInverse: + """Test the MeanSquaredErrorInverse metric.""" + + def test_main(self): + preds = torch.tensor([1, 1 / 2, 1 / 3]) + target = torch.tensor([1, 1 / 2, 1 / 3]) + metric = MeanSquaredErrorInverse(unit="m") + metric.update(preds, target) + assert metric.compute() == pytest.approx(0) + + metric.reset() + target = torch.tensor([1, 1, 1]) + metric.update(preds, target) + assert metric.compute() == pytest.approx(5 / 3) diff --git a/tests/models/test_deeplab.py b/tests/models/test_deeplab.py new file mode 100644 index 00000000..54729c27 --- /dev/null +++ b/tests/models/test_deeplab.py @@ -0,0 +1,29 @@ +import pytest +import torch + +from torch_uncertainty.models.segmentation.deeplab import ( + _DeepLabV3, + deep_lab_v3_resnet50, + deep_lab_v3_resnet101, +) + + +class TestDeeplab: + """Testing the Deeplab class.""" + + @torch.no_grad() + def test_main(self): + model = deep_lab_v3_resnet50(10, "v3", 16, True, False).eval() + model(torch.randn(1, 3, 32, 32)) + model = deep_lab_v3_resnet50(10, "v3", 16, False, False).eval() + model = deep_lab_v3_resnet101(10, "v3+", 8, True, False).eval() + model(torch.randn(1, 3, 32, 32)) + model = deep_lab_v3_resnet101(10, "v3+", 8, False, False).eval() + + def test_errors(self): + with pytest.raises(ValueError, match="Unknown backbone:"): + _DeepLabV3(10, "other", "v3", 16, True, False) + with pytest.raises(ValueError, match="output_stride: "): + deep_lab_v3_resnet50(10, "v3", 15, True, False) + with pytest.raises(ValueError, match="Unknown style: "): + deep_lab_v3_resnet50(10, "v2", 16, True, False) diff --git a/tests/models/test_resnets.py b/tests/models/test_resnets.py index 561e2394..44c2cd3c 100644 --- a/tests/models/test_resnets.py +++ b/tests/models/test_resnets.py @@ -1,76 +1,43 @@ import pytest import torch -from torch_uncertainty.models.resnet.batched import ( - batched_resnet20, - batched_resnet34, - batched_resnet101, - batched_resnet152, -) -from torch_uncertainty.models.resnet.masked import ( - masked_resnet20, - masked_resnet34, - masked_resnet101, -) -from torch_uncertainty.models.resnet.mimo import ( - mimo_resnet20, - mimo_resnet34, - mimo_resnet101, - mimo_resnet152, -) -from torch_uncertainty.models.resnet.packed import ( - packed_resnet20, - packed_resnet34, - packed_resnet101, - packed_resnet152, -) -from torch_uncertainty.models.resnet.std import ( - resnet20, - resnet34, - resnet50, - resnet101, - resnet152, +from torch_uncertainty.models.resnet import ( + batched_resnet, + lpbnn_resnet, + masked_resnet, + mimo_resnet, + packed_resnet, + resnet, ) -class TestStdResnet: - """Testing the ResNet std class.""" +class TestResnet: + """Testing the ResNet classes.""" def test_main(self): - resnet20(1, 10, conv_bias=True, style="cifar") - resnet34(1, 10, conv_bias=False, style="cifar") - resnet101(1, 10, style="cifar") - resnet152(1, 10) - - model = resnet50(1, 10, style="cifar") + resnet(1, 10, arch=18, conv_bias=True, style="cifar") + model = resnet(1, 10, arch=50, style="cifar") with torch.no_grad(): model(torch.randn(1, 1, 32, 32)) model.feats_forward(torch.randn(1, 1, 32, 32)) def test_mc_dropout(self): - resnet20(1, 10, conv_bias=True, style="cifar") - resnet34(1, 10, conv_bias=False, style="cifar") - resnet101(1, 10, style="cifar") - resnet152(1, 10) - - model = resnet50(1, 10, style="cifar") + resnet(1, 10, arch=20, conv_bias=False, style="cifar") + model = resnet(1, 10, arch=50).eval() with torch.no_grad(): model(torch.randn(1, 1, 32, 32)) def test_error(self): with pytest.raises(ValueError): - resnet20(1, 10, style="test") + resnet(1, 10, arch=20, style="test") class TestPackedResnet: """Testing the ResNet packed class.""" def test_main(self): - packed_resnet20(1, 10, 2, 2, 1, conv_bias=True) - packed_resnet34(1, 10, 2, 2, 1, conv_bias=False) - packed_resnet101(1, 10, 2, 2, 1) - model = packed_resnet152(1, 10, 2, 2, 1) - + model = packed_resnet(1, 10, 20, 2, 2, 1) + model = packed_resnet(1, 10, 152, 2, 2, 1) assert model.check_config( {"alpha": 2, "gamma": 1, "groups": 1, "num_estimators": 2} ) @@ -80,47 +47,63 @@ def test_main(self): def test_error(self): with pytest.raises(ValueError): - packed_resnet20(1, 10, 2, 2, 1, style="test") + packed_resnet(1, 10, 20, 2, 2, 1, style="test") class TestMaskedResnet: """Testing the ResNet masked class.""" def test_main(self): - masked_resnet20(1, 10, 2, 2, conv_bias=True) - masked_resnet34(1, 10, 2, 2, conv_bias=False) - masked_resnet101(1, 10, 2, 2) + model = masked_resnet(1, 10, 20, 2, 2) + with torch.no_grad(): + model(torch.randn(1, 1, 32, 32)) def test_error(self): with pytest.raises(ValueError): - masked_resnet20(1, 10, 2, 2, style="test") + masked_resnet(1, 10, 20, 2, 2, style="test") class TestBatchedResnet: """Testing the ResNet batched class.""" def test_main(self): - batched_resnet20(1, 10, 2, conv_bias=True) - batched_resnet34(1, 10, 2, conv_bias=False) - batched_resnet101(1, 10, 2) - batched_resnet152(1, 10, 2) + model = batched_resnet(1, 10, 20, 2, conv_bias=True) + with torch.no_grad(): + model(torch.randn(1, 1, 32, 32)) + + def test_error(self): + with pytest.raises(ValueError): + batched_resnet(1, 10, 20, 2, style="test") + + +class TestLPBNNResnet: + """Testing the ResNet LPBNN class.""" + + def test_main(self): + model = lpbnn_resnet(1, 10, 20, 2, conv_bias=True) + with torch.no_grad(): + model(torch.randn(1, 1, 32, 32)) + model = lpbnn_resnet(1, 10, 50, 2, conv_bias=False, style="cifar") + with torch.no_grad(): + model(torch.randn(1, 1, 32, 32)) def test_error(self): with pytest.raises(ValueError): - batched_resnet20(1, 10, 2, style="test") + lpbnn_resnet(1, 10, 20, 2, style="test") + with pytest.raises( + ValueError, match="Unknown ResNet architecture. Got" + ): + lpbnn_resnet(1, 10, 42, 2, style="test") class TestMIMOResnet: """Testing the ResNet MIMO class.""" def test_main(self): - model = mimo_resnet20(1, 10, 2, style="cifar", conv_bias=True) - model = mimo_resnet34(1, 10, 2, style="cifar", conv_bias=False) + model = mimo_resnet(1, 10, 34, 2, style="cifar", conv_bias=False) model.train() model(torch.rand((2, 1, 28, 28))) - mimo_resnet101(1, 10, 2) - mimo_resnet152(1, 10, 2) def test_error(self): with pytest.raises(ValueError): - mimo_resnet101(1, 10, 2, style="test") + mimo_resnet(1, 10, 101, 2, style="test") diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index f69439d8..a9fbe0f2 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -1,25 +1,19 @@ import torch from torch_uncertainty.models.segmentation.segformer import ( - seg_former_b0, - seg_former_b1, - seg_former_b2, - seg_former_b3, - seg_former_b4, - seg_former_b5, + seg_former, ) class TestSegformer: """Testing the Segformer class.""" + @torch.no_grad() def test_main(self): - seg_former_b1(10) - seg_former_b2(10) - seg_former_b3(10) - seg_former_b4(10) - seg_former_b5(10) - - model = seg_former_b0(10) - with torch.no_grad(): - model(torch.randn(1, 3, 32, 32)) + model = seg_former(10, 0) + seg_former(10, 1) + seg_former(10, 2) + seg_former(10, 3) + seg_former(10, 4) + seg_former(10, 5) + model(torch.randn(1, 3, 32, 32)) diff --git a/tests/models/test_vggs.py b/tests/models/test_vggs.py index f517d8c7..e281d2d9 100644 --- a/tests/models/test_vggs.py +++ b/tests/models/test_vggs.py @@ -1,24 +1,32 @@ -from torch_uncertainty.models.vgg.packed import packed_vgg11 -from torch_uncertainty.models.vgg.std import vgg11 +import pytest +from torch_uncertainty.models.vgg.packed import packed_vgg +from torch_uncertainty.models.vgg.std import vgg -class TestStdVGG: + +class TestVGGs: """Testing the VGG std class.""" def test_main(self): - vgg11(1, 10, style="cifar") - - def test_mc_dropout(self): - vgg11( - in_channels=1, + vgg(in_channels=1, num_classes=10, arch=11, style="cifar") + packed_vgg( + in_channels=2, num_classes=10, - style="cifar", - num_estimators=3, + arch=11, + alpha=2, + num_estimators=2, + gamma=1, ) - -class TestPackedVGG: - """Testing the VGG packed class.""" - - def test_main(self): - packed_vgg11(2, 10, 2, 2, 1) + def test_errors(self): + with pytest.raises(ValueError, match="Unknown VGG arch"): + vgg(in_channels=1, num_classes=10, arch=12, style="cifar") + with pytest.raises(ValueError, match="Unknown VGG arch"): + packed_vgg( + in_channels=2, + num_classes=10, + arch=12, + alpha=2, + num_estimators=2, + gamma=1, + ) diff --git a/tests/routines/test_classification.py b/tests/routines/test_classification.py index 9b22d898..22f6cae6 100644 --- a/tests/routines/test_classification.py +++ b/tests/routines/test_classification.py @@ -370,6 +370,16 @@ def test_classification_failures(self): num_classes=10, model=nn.Module(), loss=None, cutmix_alpha=-1 ) + with pytest.raises( + ValueError, match="num_calibration_bins must be at least 2, got" + ): + ClassificationRoutine( + model=nn.Identity(), + num_classes=2, + loss=nn.CrossEntropyLoss(), + num_calibration_bins=0, + ) + with pytest.raises(ValueError): ClassificationRoutine( num_classes=10, diff --git a/tests/routines/test_depth.py b/tests/routines/test_depth.py new file mode 100644 index 00000000..e404ca80 --- /dev/null +++ b/tests/routines/test_depth.py @@ -0,0 +1,75 @@ +from pathlib import Path + +import pytest +from torch import nn + +from tests._dummies import ( + DummyDepthBaseline, + DummyDepthDataModule, +) +from torch_uncertainty.optim_recipes import optim_cifar10_resnet18 +from torch_uncertainty.routines import PixelRegressionRoutine +from torch_uncertainty.utils import TUTrainer + + +class TestDepth: + def test_one_estimator_two_classes(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummyDepthDataModule(root=root, batch_size=4, output_dim=2) + + model = DummyDepthBaseline( + in_channels=dm.num_channels, + output_dim=dm.output_dim, + image_size=dm.image_size, + loss=nn.MSELoss(), + baseline_type="single", + optim_recipe=optim_cifar10_resnet18, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_two_estimators_one_class(self): + trainer = TUTrainer(accelerator="cpu", fast_dev_run=True) + + root = Path(__file__).parent.absolute().parents[0] / "data" + dm = DummyDepthDataModule(root=root, batch_size=4, output_dim=1) + + model = DummyDepthBaseline( + in_channels=dm.num_channels, + output_dim=dm.output_dim, + image_size=dm.image_size, + loss=nn.MSELoss(), + baseline_type="ensemble", + optim_recipe=optim_cifar10_resnet18, + ) + + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + model(dm.get_test_set()[0][0]) + + def test_depth_errors(self): + with pytest.raises( + ValueError, match="num_estimators must be positive, got" + ): + PixelRegressionRoutine( + model=nn.Identity(), + output_dim=2, + loss=nn.MSELoss(), + num_estimators=0, + probabilistic=False, + ) + + with pytest.raises(ValueError, match="output_dim must be positive"): + PixelRegressionRoutine( + model=nn.Identity(), + output_dim=0, + loss=nn.MSELoss(), + num_estimators=1, + probabilistic=False, + ) diff --git a/tests/routines/test_segmentation.py b/tests/routines/test_segmentation.py index 7e03b673..cb2e41a3 100644 --- a/tests/routines/test_segmentation.py +++ b/tests/routines/test_segmentation.py @@ -26,6 +26,7 @@ def test_one_estimator_two_classes(self): loss=nn.CrossEntropyLoss(), baseline_type="single", optim_recipe=optim_cifar10_resnet18, + log_plots=True, ) trainer.fit(model, dm) @@ -53,15 +54,40 @@ def test_two_estimators_two_classes(self): trainer.test(model, dm) model(dm.get_test_set()[0][0]) - def test_segmentation_failures(self): - with pytest.raises(ValueError): + def test_segmentation_errors(self): + with pytest.raises( + ValueError, match="num_estimators must be positive, got" + ): SegmentationRoutine( model=nn.Identity(), num_classes=2, loss=nn.CrossEntropyLoss(), num_estimators=0, ) - with pytest.raises(ValueError): + + with pytest.raises( + ValueError, match="num_classes must be at least 2, got" + ): SegmentationRoutine( model=nn.Identity(), num_classes=1, loss=nn.CrossEntropyLoss() ) + + with pytest.raises( + ValueError, match="metric_subsampling_rate must be in" + ): + SegmentationRoutine( + model=nn.Identity(), + num_classes=2, + loss=nn.CrossEntropyLoss(), + metric_subsampling_rate=-1, + ) + + with pytest.raises( + ValueError, match="num_calibration_bins must be at least 2, got" + ): + SegmentationRoutine( + model=nn.Identity(), + num_classes=2, + loss=nn.CrossEntropyLoss(), + num_calibration_bins=0, + ) diff --git a/tests/test_cli.py b/tests/test_cli.py index edce26d4..8683a523 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -27,7 +27,7 @@ def test_cli_init(self): "--data.batch_size", "4", "--trainer.callbacks+=ModelCheckpoint", - "--trainer.callbacks.monitor=cls_val/acc", + "--trainer.callbacks.monitor=val/cls/Acc", "--trainer.callbacks.mode=max", ] cli = TULightningCLI(ResNetBaseline, CIFAR10DataModule, run=False) diff --git a/tests/test_optim_recipes.py b/tests/test_optim_recipes.py index 48fcd06f..b71ac43f 100644 --- a/tests/test_optim_recipes.py +++ b/tests/test_optim_recipes.py @@ -1,9 +1,7 @@ # ruff: noqa: F401 import pytest +import torch -from torch_uncertainty.models.resnet import resnet18, resnet34, resnet50 -from torch_uncertainty.models.vgg import vgg16 -from torch_uncertainty.models.wideresnet import wideresnet28x10 from torch_uncertainty.optim_recipes import ( get_procedure, ) @@ -11,65 +9,30 @@ class TestOptProcedures: def test_optim_cifar10(self): - procedure = get_procedure("resnet18", "cifar10", "standard") - model = resnet18(in_channels=3, num_classes=10) - procedure(model) - - procedure = get_procedure("resnet34", "cifar10", "masked") - model = resnet34(in_channels=3, num_classes=100) - procedure(model) - - procedure = get_procedure("resnet50", "cifar10", "packed") - model = resnet50(in_channels=3, num_classes=10) - procedure(model) - - procedure = get_procedure("wideresnet28x10", "cifar10", "batched") - model = wideresnet28x10(in_channels=3, num_classes=10) - procedure(model) - - procedure = get_procedure("vgg16", "cifar10", "standard") - model = vgg16(in_channels=3, num_classes=10) - procedure(model) + model = torch.nn.Linear(1, 1) + get_procedure("resnet18", "cifar10", "standard")(model) + get_procedure("resnet34", "cifar10", "masked")(model) + get_procedure("resnet50", "cifar10", "packed")(model) + get_procedure("wideresnet28x10", "cifar10", "batched")(model) + get_procedure("vgg16", "cifar10", "standard")(model) def test_optim_cifar100(self): - procedure = get_procedure("resnet18", "cifar100", "masked") - model = resnet18(in_channels=3, num_classes=100) - procedure(model) - - procedure = get_procedure("resnet34", "cifar100", "masked") - model = resnet34(in_channels=3, num_classes=100) - procedure(model) - - procedure = get_procedure("resnet50", "cifar100") - model = resnet50(in_channels=3, num_classes=100) - procedure(model) - - procedure = get_procedure("wideresnet28x10", "cifar100") - model = wideresnet28x10(in_channels=3, num_classes=100) - procedure(model) - - procedure = get_procedure("vgg16", "cifar100", "standard") - model = vgg16(in_channels=3, num_classes=100) - procedure(model) + model = torch.nn.Linear(1, 1) + get_procedure("resnet18", "cifar100", "masked")(model) + get_procedure("resnet34", "cifar100", "masked")(model) + get_procedure("resnet50", "cifar100")(model) + get_procedure("wideresnet28x10", "cifar100")(model) + get_procedure("vgg16", "cifar100", "standard")(model) def test_optim_tinyimagenet(self): - procedure = get_procedure("resnet34", "tiny-imagenet", "standard") - model = resnet34(in_channels=3, num_classes=1000) - procedure(model) - - procedure = get_procedure("resnet50", "tiny-imagenet", "standard") - model = resnet50(in_channels=3, num_classes=1000) - procedure(model) + model = torch.nn.Linear(1, 1) + get_procedure("resnet34", "tiny-imagenet", "standard")(model) + get_procedure("resnet50", "tiny-imagenet", "standard")(model) def test_optim_imagenet_resnet50(self): - procedure = get_procedure("resnet50", "imagenet", "standard", "A3") - model = resnet50(in_channels=3, num_classes=1000) - procedure(model, effective_batch_size=64) - procedure(model) - - procedure = get_procedure("resnet50", "imagenet", "standard") - model = resnet50(in_channels=3, num_classes=1000) - procedure(model) + model = torch.nn.Linear(1, 1) + get_procedure("resnet50", "imagenet", "standard", "A3")(model) + get_procedure("resnet50", "imagenet", "standard")(model) def test_optim_unknown(self): with pytest.raises(NotImplementedError): diff --git a/tests/test_utils.py b/tests/test_utils.py index 0ce5c482..69c17b1d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,10 @@ import pytest import torch -from huggingface_hub.utils._errors import RepositoryNotFoundError +from huggingface_hub.utils._errors import ( + HfHubHTTPError, + RepositoryNotFoundError, +) from torch.distributions import Laplace, Normal from torch_uncertainty.utils import ( @@ -37,10 +40,10 @@ def test_hub_exists(self): hub.load_hf("test", version=2) def test_hub_notexists(self): - with pytest.raises(RepositoryNotFoundError): + with pytest.raises((RepositoryNotFoundError, HfHubHTTPError)): hub.load_hf("tests") - with pytest.raises(ValueError): + with pytest.raises((ValueError, HfHubHTTPError)): hub.load_hf("test", version=42) diff --git a/torch_uncertainty/baselines/classification/resnet.py b/torch_uncertainty/baselines/classification/resnet.py index 989bb9d8..ff051b48 100644 --- a/torch_uncertainty/baselines/classification/resnet.py +++ b/torch_uncertainty/baselines/classification/resnet.py @@ -4,36 +4,12 @@ from torch_uncertainty.models import mc_dropout from torch_uncertainty.models.resnet import ( - batched_resnet18, - batched_resnet20, - batched_resnet34, - batched_resnet50, - batched_resnet101, - batched_resnet152, - masked_resnet18, - masked_resnet20, - masked_resnet34, - masked_resnet50, - masked_resnet101, - masked_resnet152, - mimo_resnet18, - mimo_resnet20, - mimo_resnet34, - mimo_resnet50, - mimo_resnet101, - mimo_resnet152, - packed_resnet18, - packed_resnet20, - packed_resnet34, - packed_resnet50, - packed_resnet101, - packed_resnet152, - resnet18, - resnet20, - resnet34, - resnet50, - resnet101, - resnet152, + batched_resnet, + lpbnn_resnet, + masked_resnet, + mimo_resnet, + packed_resnet, + resnet, ) from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import MIMOBatchFormat, RepeatTarget @@ -41,56 +17,15 @@ class ResNetBaseline(ClassificationRoutine): single = ["std"] - ensemble = ["packed", "batched", "masked", "mc-dropout", "mimo"] + ensemble = ["packed", "batched", "lpbnn", "masked", "mc-dropout", "mimo"] versions = { - "std": [ - resnet18, - resnet20, - resnet34, - resnet50, - resnet101, - resnet152, - ], - "packed": [ - packed_resnet18, - packed_resnet20, - packed_resnet34, - packed_resnet50, - packed_resnet101, - packed_resnet152, - ], - "batched": [ - batched_resnet18, - batched_resnet20, - batched_resnet34, - batched_resnet50, - batched_resnet101, - batched_resnet152, - ], - "masked": [ - masked_resnet18, - masked_resnet20, - masked_resnet34, - masked_resnet50, - masked_resnet101, - masked_resnet152, - ], - "mimo": [ - mimo_resnet18, - mimo_resnet20, - mimo_resnet34, - mimo_resnet50, - mimo_resnet101, - mimo_resnet152, - ], - "mc-dropout": [ - resnet18, - resnet20, - resnet34, - resnet50, - resnet101, - resnet152, - ], + "std": resnet, + "packed": packed_resnet, + "batched": batched_resnet, + "lpbnn": lpbnn_resnet, + "masked": masked_resnet, + "mimo": mimo_resnet, + "mc-dropout": resnet, } archs = [18, 20, 34, 50, 101, 152] @@ -104,6 +39,7 @@ def __init__( "mc-dropout", "packed", "batched", + "lpbnn", "masked", "mimo", ], @@ -133,6 +69,7 @@ def __init__( calibration_set: Literal["val", "test"] | None = None, eval_ood: bool = False, eval_grouping_loss: bool = False, + num_calibration_bins: int = 15, pretrained: bool = False, ) -> None: r"""ResNet backbone baseline for classification providing support for @@ -219,6 +156,8 @@ def __init__( OOD detection or not. Defaults to ``False``. eval_grouping_loss (bool, optional): Indicates whether to evaluate the grouping loss or not. Defaults to ``False``. + num_calibration_bins (int, optional): Number of calibration bins. + Defaults to ``15``. pretrained (bool, optional): Indicates whether to use the pretrained weights or not. Only used if :attr:`version` is ``"packed"``. Defaults to ``False``. @@ -231,6 +170,7 @@ def __init__( LightningModule: ResNet baseline ready for training and evaluation. """ params = { + "arch": arch, "conv_bias": False, "dropout_rate": dropout_rate, "groups": groups, @@ -274,7 +214,7 @@ def __init__( if version == "mc-dropout": # std ResNets don't have `num_estimators` del params["num_estimators"] - model = self.versions[version][self.archs.index(arch)](**params) + model = self.versions[version](**params) if version == "mc-dropout": model = mc_dropout( model=model, @@ -301,5 +241,6 @@ def __init__( log_plots=log_plots, save_in_csv=save_in_csv, calibration_set=calibration_set, + num_calibration_bins=num_calibration_bins, ) self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/classification/vgg.py b/torch_uncertainty/baselines/classification/vgg.py index 9c429ea1..fc4f5256 100644 --- a/torch_uncertainty/baselines/classification/vgg.py +++ b/torch_uncertainty/baselines/classification/vgg.py @@ -4,14 +4,8 @@ from torch_uncertainty.models import mc_dropout from torch_uncertainty.models.vgg import ( - packed_vgg11, - packed_vgg13, - packed_vgg16, - packed_vgg19, - vgg11, - vgg13, - vgg16, - vgg19, + packed_vgg, + vgg, ) from torch_uncertainty.routines.classification import ClassificationRoutine from torch_uncertainty.transforms import RepeatTarget @@ -21,14 +15,9 @@ class VGGBaseline(ClassificationRoutine): single = ["std"] ensemble = ["mc-dropout", "packed"] versions = { - "std": [vgg11, vgg13, vgg16, vgg19], - "mc-dropout": [vgg11, vgg13, vgg16, vgg19], - "packed": [ - packed_vgg11, - packed_vgg13, - packed_vgg16, - packed_vgg19, - ], + "std": vgg, + "mc-dropout": vgg, + "packed": packed_vgg, } archs = [11, 13, 16, 19] @@ -139,6 +128,7 @@ def __init__( "num_classes": num_classes, "style": style, "groups": groups, + "arch": arch, } if version not in self.versions: @@ -174,7 +164,7 @@ def __init__( if version == "mc-dropout": # std VGGs don't have `num_estimators` del params["num_estimators"] - model = self.versions[version][self.archs.index(arch)](**params) + model = self.versions[version](**params) if version == "mc-dropout": model = mc_dropout( model=model, diff --git a/torch_uncertainty/baselines/depth/__init__.py b/torch_uncertainty/baselines/depth/__init__.py new file mode 100644 index 00000000..6643eab0 --- /dev/null +++ b/torch_uncertainty/baselines/depth/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa: F401 +from .bts import BTSBaseline diff --git a/torch_uncertainty/baselines/depth/bts.py b/torch_uncertainty/baselines/depth/bts.py new file mode 100644 index 00000000..2f05e18b --- /dev/null +++ b/torch_uncertainty/baselines/depth/bts.py @@ -0,0 +1,48 @@ +from typing import Literal + +from torch import nn + +from torch_uncertainty.models.depth.bts import bts_resnet50, bts_resnet101 +from torch_uncertainty.routines import PixelRegressionRoutine + + +class BTSBaseline(PixelRegressionRoutine): + single = ["std"] + versions = { + "std": [ + bts_resnet50, + bts_resnet101, + ] + } + archs = [50, 101] + + def __init__( + self, + loss: nn.Module, + version: Literal["std"], + arch: int, + max_depth: float, + num_estimators: int = 1, + pretrained_backbone: bool = True, + ) -> None: + params = { + "dist_layer": nn.Identity, + "max_depth": max_depth, + "pretrained_backbone": pretrained_backbone, + } + + format_batch_fn = nn.Identity() + + if version not in self.versions: + raise ValueError(f"Unknown version {version}") + + model = self.versions[version][self.archs.index(arch)](**params) + super().__init__( + output_dim=1, + probabilistic=False, + model=model, + loss=loss, + num_estimators=num_estimators, + format_batch_fn=format_batch_fn, + ) + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/segmentation/__init__.py b/torch_uncertainty/baselines/segmentation/__init__.py index fe2488e4..3dbaae4a 100644 --- a/torch_uncertainty/baselines/segmentation/__init__.py +++ b/torch_uncertainty/baselines/segmentation/__init__.py @@ -1,2 +1,3 @@ # ruff: noqa: F401 +from .deeplab import DeepLabBaseline from .segformer import SegFormerBaseline diff --git a/torch_uncertainty/baselines/segmentation/deeplab.py b/torch_uncertainty/baselines/segmentation/deeplab.py new file mode 100644 index 00000000..01575f1f --- /dev/null +++ b/torch_uncertainty/baselines/segmentation/deeplab.py @@ -0,0 +1,61 @@ +from typing import Literal + +from torch import nn + +from torch_uncertainty.models.segmentation.deeplab import ( + deep_lab_v3_resnet50, + deep_lab_v3_resnet101, +) +from torch_uncertainty.routines.segmentation import SegmentationRoutine + + +class DeepLabBaseline(SegmentationRoutine): + single = ["std"] + versions = { + "std": [ + deep_lab_v3_resnet50, + deep_lab_v3_resnet101, + ] + } + archs = [50, 101] + + def __init__( + self, + num_classes: int, + loss: nn.Module, + version: Literal["std"], + arch: int, + style: Literal["v3", "v3+"], + output_stride: int, + separable: bool, + num_estimators: int = 1, + metric_subsampling_rate: float = 1e-2, + log_plots: bool = False, + num_calibration_bins: int = 15, + pretrained_backbone: bool = True, + ) -> None: + params = { + "num_classes": num_classes, + "style": style, + "output_stride": output_stride, + "separable": separable, + "pretrained_backbone": pretrained_backbone, + } + + format_batch_fn = nn.Identity() + + if version not in self.versions: + raise ValueError(f"Unknown version {version}") + + model = self.versions[version][self.archs.index(arch)](**params) + super().__init__( + num_classes=num_classes, + model=model, + loss=loss, + num_estimators=num_estimators, + format_batch_fn=format_batch_fn, + metric_subsampling_rate=metric_subsampling_rate, + log_plots=log_plots, + num_calibration_bins=num_calibration_bins, + ) + self.save_hyperparameters(ignore=["loss"]) diff --git a/torch_uncertainty/baselines/segmentation/segformer.py b/torch_uncertainty/baselines/segmentation/segformer.py index 1e8185a1..97d98a3b 100644 --- a/torch_uncertainty/baselines/segmentation/segformer.py +++ b/torch_uncertainty/baselines/segmentation/segformer.py @@ -3,12 +3,7 @@ from torch import nn from torch_uncertainty.models.segmentation.segformer import ( - seg_former_b0, - seg_former_b1, - seg_former_b2, - seg_former_b3, - seg_former_b4, - seg_former_b5, + seg_former, ) from torch_uncertainty.routines.segmentation import SegmentationRoutine @@ -16,14 +11,7 @@ class SegFormerBaseline(SegmentationRoutine): single = ["std"] versions = { - "std": [ - seg_former_b0, - seg_former_b1, - seg_former_b2, - seg_former_b3, - seg_former_b4, - seg_former_b5, - ] + "std": seg_former, } archs = [0, 1, 2, 3, 4, 5] @@ -61,6 +49,7 @@ def __init__( """ params = { "num_classes": num_classes, + "arch": arch, } format_batch_fn = nn.Identity() @@ -68,7 +57,7 @@ def __init__( if version not in self.versions: raise ValueError(f"Unknown version {version}") - model = self.versions[version][self.archs.index(arch)](**params) + model = self.versions[version](**params) super().__init__( num_classes=num_classes, diff --git a/torch_uncertainty/datamodules/classification/imagenet.py b/torch_uncertainty/datamodules/classification/imagenet.py index 8f89a23b..d215a79f 100644 --- a/torch_uncertainty/datamodules/classification/imagenet.py +++ b/torch_uncertainty/datamodules/classification/imagenet.py @@ -17,7 +17,10 @@ ImageNetR, OpenImageO, ) -from torch_uncertainty.utils.misc import create_train_val_split +from torch_uncertainty.utils import ( + create_train_val_split, + interpolation_modes_from_str, +) class ImageNetDataModule(AbstractDataModule): @@ -45,6 +48,7 @@ def __init__( test_alt: str | None = None, procedure: str | None = None, train_size: int = 224, + interpolation: str = "bilinear", rand_augment_opt: str | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -65,6 +69,8 @@ def __init__( test_alt (str): Which test set to use. Defaults to ``None``. procedure (str): Which procedure to use. Defaults to ``None``. train_size (int): Size of training images. Defaults to ``224``. + interpolation (str): Interpolation method for the Resize Crops. + Defaults to ``"bilinear"``. rand_augment_opt (str): Which RandAugment to use. Defaults to ``None``. num_workers (int): Number of workers to use for data loading. Defaults to ``1``. @@ -89,6 +95,7 @@ def __init__( self.val_split = val_split self.ood_ds = ood_ds self.test_alt = test_alt + self.interpolation = interpolation_modes_from_str(interpolation) if test_alt is None: self.dataset = ImageNet @@ -137,7 +144,9 @@ def __init__( self.train_transform = T.Compose( [ - T.RandomResizedCrop(train_size), + T.RandomResizedCrop( + train_size, interpolation=self.interpolation + ), T.RandomHorizontalFlip(), main_transform, T.ToTensor(), @@ -147,7 +156,7 @@ def __init__( self.test_transform = T.Compose( [ - T.Resize(256), + T.Resize(256, interpolation=self.interpolation), T.CenterCrop(224), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), diff --git a/torch_uncertainty/datamodules/classification/tiny_imagenet.py b/torch_uncertainty/datamodules/classification/tiny_imagenet.py index 5430264d..25c62f31 100644 --- a/torch_uncertainty/datamodules/classification/tiny_imagenet.py +++ b/torch_uncertainty/datamodules/classification/tiny_imagenet.py @@ -11,7 +11,10 @@ from torch_uncertainty.datamodules.abstract import AbstractDataModule from torch_uncertainty.datasets.classification import ImageNetO, TinyImageNet -from torch_uncertainty.utils import create_train_val_split +from torch_uncertainty.utils import ( + create_train_val_split, + interpolation_modes_from_str, +) class TinyImageNetDataModule(AbstractDataModule): @@ -26,6 +29,7 @@ def __init__( eval_ood: bool = False, val_split: float | None = None, ood_ds: str = "svhn", + interpolation: str = "bilinear", rand_augment_opt: str | None = None, num_workers: int = 1, pin_memory: bool = True, @@ -43,6 +47,7 @@ def __init__( self.eval_ood = eval_ood self.ood_ds = ood_ds + self.interpolation = interpolation_modes_from_str(interpolation) self.dataset = TinyImageNet @@ -74,7 +79,7 @@ def __init__( self.test_transform = T.Compose( [ - T.Resize(64), + T.Resize(64, interpolation=self.interpolation), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ] diff --git a/torch_uncertainty/datamodules/depth/__init__.py b/torch_uncertainty/datamodules/depth/__init__.py new file mode 100644 index 00000000..39d81e56 --- /dev/null +++ b/torch_uncertainty/datamodules/depth/__init__.py @@ -0,0 +1,4 @@ +# ruff: noqa: F401 +from .kitti import KITTIDataModule +from .muad import MUADDataModule +from .nyu import NYUv2DataModule diff --git a/torch_uncertainty/datamodules/depth_estimation/muad.py b/torch_uncertainty/datamodules/depth/base.py similarity index 74% rename from torch_uncertainty/datamodules/depth_estimation/muad.py rename to torch_uncertainty/datamodules/depth/base.py index a751926e..47e8cf73 100644 --- a/torch_uncertainty/datamodules/depth_estimation/muad.py +++ b/torch_uncertainty/datamodules/depth/base.py @@ -4,51 +4,54 @@ from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair from torchvision import tv_tensors +from torchvision.datasets import VisionDataset from torchvision.transforms import v2 from torch_uncertainty.datamodules.abstract import AbstractDataModule -from torch_uncertainty.datasets import MUAD from torch_uncertainty.transforms import RandomRescale from torch_uncertainty.utils.misc import create_train_val_split -class MUADDataModule(AbstractDataModule): +class DepthDataModule(AbstractDataModule): def __init__( self, + dataset: type[VisionDataset], root: str | Path, batch_size: int, - crop_size: _size_2_t = 1024, - inference_size: _size_2_t = (1024, 2048), + min_depth: float, + max_depth: float, + crop_size: _size_2_t, + inference_size: _size_2_t, val_split: float | None = None, num_workers: int = 1, pin_memory: bool = True, persistent_workers: bool = True, ) -> None: - r"""Segmentation DataModule for the MUAD dataset. + r"""Base depth datamodule. Args: + dataset (type[VisionDataset]): Dataset class to use. root (str or Path): Root directory of the datasets. batch_size (int): Number of samples per batch. + min_depth (float, optional): Minimum depth value for evaluation. + max_depth (float, optional): Maximum depth value for training and + evaluation. crop_size (sequence or int, optional): Desired input image and - segmentation mask sizes during training. If :attr:`crop_size` is an + depth mask sizes during training. If :attr:`crop_size` is an int instead of sequence like :math:`(H, W)`, a square crop :math:`(\text{size},\text{size})` is made. If provided a sequence of length :math:`1`, it will be interpreted as - :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + :math:`(\text{size[0]},\text{size[1]})`. inference_size (sequence or int, optional): Desired input image and - segmentation mask sizes during inference. If size is an int, + depth mask sizes during inference. If size is an int, smaller edge of the images will be matched to this number, i.e., :math:`\text{height}>\text{width}`, then image will be rescaled to :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. - Defaults to ``(1024,2048)``. val_split (float or None, optional): Share of training samples to use - for validation. Defaults to ``None``. - num_workers (int, optional): Number of dataloaders to use. Defaults to - ``1``. - pin_memory (bool, optional): Whether to pin memory. Defaults to - ``True``. + for validation. + num_workers (int, optional): Number of dataloaders to use. + pin_memory (bool, optional): Whether to pin memory. persistent_workers (bool, optional): Whether to use persistent workers. - Defaults to ``True``. """ super().__init__( root=root, @@ -59,17 +62,19 @@ def __init__( persistent_workers=persistent_workers, ) - self.dataset = MUAD + self.dataset = dataset + self.min_depth = min_depth + self.max_depth = max_depth self.crop_size = _pair(crop_size) self.inference_size = _pair(inference_size) self.train_transform = v2.Compose( [ - RandomRescale(min_scale=0.5, max_scale=2.0, antialias=True), + RandomRescale(min_scale=0.5, max_scale=2.0), v2.RandomCrop( size=self.crop_size, pad_if_needed=True, - fill={tv_tensors.Image: 0, tv_tensors.Mask: -float("inf")}, + fill={tv_tensors.Image: 0, tv_tensors.Mask: float("nan")}, ), v2.RandomHorizontalFlip(), v2.ToDtype( @@ -86,7 +91,7 @@ def __init__( ) self.test_transform = v2.Compose( [ - v2.Resize(size=self.inference_size, antialias=True), + v2.Resize(size=self.inference_size), v2.ToDtype( dtype={ tv_tensors.Image: torch.float32, @@ -102,18 +107,21 @@ def __init__( def prepare_data(self) -> None: # coverage: ignore self.dataset( - root=self.root, split="train", target_type="depth", download=True + root=self.root, + split="train", + max_depth=self.max_depth, + download=True, ) self.dataset( - root=self.root, split="val", target_type="depth", download=True + root=self.root, split="val", max_depth=self.max_depth, download=True ) def setup(self, stage: str | None = None) -> None: if stage == "fit" or stage is None: full = self.dataset( root=self.root, + max_depth=self.max_depth, split="train", - target_type="depth", transforms=self.train_transform, ) @@ -123,20 +131,23 @@ def setup(self, stage: str | None = None) -> None: self.val_split, self.test_transform, ) + self.val.min_depth = self.min_depth else: self.train = full self.val = self.dataset( root=self.root, + min_depth=self.min_depth, + max_depth=self.max_depth, split="val", - target_type="depth", transforms=self.test_transform, ) if stage == "test" or stage is None: self.test = self.dataset( root=self.root, + min_depth=self.min_depth, + max_depth=self.max_depth, split="val", - target_type="depth", transforms=self.test_transform, ) diff --git a/torch_uncertainty/datamodules/depth/kitti.py b/torch_uncertainty/datamodules/depth/kitti.py new file mode 100644 index 00000000..55f30296 --- /dev/null +++ b/torch_uncertainty/datamodules/depth/kitti.py @@ -0,0 +1,66 @@ +from pathlib import Path + +from torch.nn.common_types import _size_2_t + +from torch_uncertainty.datasets import KITTIDepth + +from .base import DepthDataModule + + +class KITTIDataModule(DepthDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + min_depth: float = 1e-3, + max_depth: float = 80.0, + crop_size: _size_2_t = (352, 704), + inference_size: _size_2_t = (375, 1242), + val_split: float | None = None, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + r"""Depth DataModule for the KITTI-Depth dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + min_depth (float, optional): Minimum depth value for evaluation. + Defaults to ``1e-3``. + max_depth (float, optional): Maximum depth value for training and + evaluation. Defaults to ``80.0``. + crop_size (sequence or int, optional): Desired input image and + depth mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``(375, 1242)``. + inference_size (sequence or int, optional): Desired input image and + depth mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(375, 1242)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + """ + super().__init__( + dataset=KITTIDepth, + root=root, + batch_size=batch_size, + min_depth=min_depth, + max_depth=max_depth, + crop_size=crop_size, + inference_size=inference_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) diff --git a/torch_uncertainty/datamodules/depth/muad.py b/torch_uncertainty/datamodules/depth/muad.py new file mode 100644 index 00000000..5ca8643b --- /dev/null +++ b/torch_uncertainty/datamodules/depth/muad.py @@ -0,0 +1,124 @@ +from pathlib import Path + +from torch.nn.common_types import _size_2_t + +from torch_uncertainty.datasets import MUAD +from torch_uncertainty.utils.misc import create_train_val_split + +from .base import DepthDataModule + + +class MUADDataModule(DepthDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + min_depth: float, + max_depth: float, + crop_size: _size_2_t = 1024, + inference_size: _size_2_t = (1024, 2048), + val_split: float | None = None, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + r"""Depth DataModule for the MUAD dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + min_depth (float, optional): Minimum depth value for evaluation + max_depth (float, optional): Maximum depth value for training and + evaluation. + crop_size (sequence or int, optional): Desired input image and + depth mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``1024``. + inference_size (sequence or int, optional): Desired input image and + depth mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(1024,2048)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + """ + super().__init__( + dataset=MUAD, + root=root, + batch_size=batch_size, + min_depth=min_depth, + max_depth=max_depth, + crop_size=crop_size, + inference_size=inference_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + + def prepare_data(self) -> None: # coverage: ignore + self.dataset( + root=self.root, + split="train", + max_depth=self.max_depth, + target_type="depth", + download=True, + ) + self.dataset( + root=self.root, + split="val", + min_depth=self.min_depth, + max_depth=self.max_depth, + target_type="depth", + download=True, + ) + + def setup(self, stage: str | None = None) -> None: + if stage == "fit" or stage is None: + full = self.dataset( + root=self.root, + split="train", + max_depth=self.max_depth, + target_type="depth", + transforms=self.train_transform, + ) + + if self.val_split is not None: + self.train, self.val = create_train_val_split( + full, + self.val_split, + self.test_transform, + ) + self.val.min_depth = self.min_depth + else: + self.train = full + self.val = self.dataset( + root=self.root, + split="val", + min_depth=self.min_depth, + max_depth=self.max_depth, + target_type="depth", + transforms=self.test_transform, + ) + + if stage == "test" or stage is None: + self.test = self.dataset( + root=self.root, + split="val", + min_depth=self.min_depth, + max_depth=self.max_depth, + target_type="depth", + transforms=self.test_transform, + ) + + if stage not in ["fit", "test", None]: + raise ValueError(f"Stage {stage} is not supported.") diff --git a/torch_uncertainty/datamodules/depth/nyu.py b/torch_uncertainty/datamodules/depth/nyu.py new file mode 100644 index 00000000..c421c044 --- /dev/null +++ b/torch_uncertainty/datamodules/depth/nyu.py @@ -0,0 +1,66 @@ +from pathlib import Path + +from torch.nn.common_types import _size_2_t + +from torch_uncertainty.datasets import NYUv2 + +from .base import DepthDataModule + + +class NYUv2DataModule(DepthDataModule): + def __init__( + self, + root: str | Path, + batch_size: int, + min_depth: float = 1e-3, + max_depth: float = 10.0, + crop_size: _size_2_t = (416, 544), + inference_size: _size_2_t = (480, 640), + val_split: float | None = None, + num_workers: int = 1, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + r"""Depth DataModule for the NYUv2 dataset. + + Args: + root (str or Path): Root directory of the datasets. + batch_size (int): Number of samples per batch. + min_depth (float, optional): Minimum depth value for evaluation. + Defaults to ``1e-3``. + max_depth (float, optional): Maximum depth value for training and + evaluation. Defaults to ``10.0``. + crop_size (sequence or int, optional): Desired input image and + depth mask sizes during training. If :attr:`crop_size` is an + int instead of sequence like :math:`(H, W)`, a square crop + :math:`(\text{size},\text{size})` is made. If provided a sequence + of length :math:`1`, it will be interpreted as + :math:`(\text{size[0]},\text{size[1]})`. Defaults to ``(416, 544)``. + inference_size (sequence or int, optional): Desired input image and + depth mask sizes during inference. If size is an int, + smaller edge of the images will be matched to this number, i.e., + :math:`\text{height}>\text{width}`, then image will be rescaled to + :math:`(\text{size}\times\text{height}/\text{width},\text{size})`. + Defaults to ``(480, 640)``. + val_split (float or None, optional): Share of training samples to use + for validation. Defaults to ``None``. + num_workers (int, optional): Number of dataloaders to use. Defaults to + ``1``. + pin_memory (bool, optional): Whether to pin memory. Defaults to + ``True``. + persistent_workers (bool, optional): Whether to use persistent workers. + Defaults to ``True``. + """ + super().__init__( + dataset=NYUv2, + root=root, + batch_size=batch_size, + min_depth=min_depth, + max_depth=max_depth, + crop_size=crop_size, + inference_size=inference_size, + val_split=val_split, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) diff --git a/torch_uncertainty/datamodules/depth_estimation/__init__.py b/torch_uncertainty/datamodules/depth_estimation/__init__.py deleted file mode 100644 index dc94a8cb..00000000 --- a/torch_uncertainty/datamodules/depth_estimation/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# ruff: noqa: F401 -from .muad import MUADDataModule diff --git a/torch_uncertainty/datasets/__init__.py b/torch_uncertainty/datasets/__init__.py index 732334a0..d6a02df0 100644 --- a/torch_uncertainty/datasets/__init__.py +++ b/torch_uncertainty/datasets/__init__.py @@ -1,4 +1,6 @@ # ruff: noqa: F401 from .aggregated_dataset import AggregatedDataset from .frost import FrostImages +from .kitti import KITTIDepth from .muad import MUAD +from .nyu import NYUv2 diff --git a/torch_uncertainty/datasets/kitti.py b/torch_uncertainty/datasets/kitti.py new file mode 100644 index 00000000..f2b2a35f --- /dev/null +++ b/torch_uncertainty/datasets/kitti.py @@ -0,0 +1,265 @@ +import json +import shutil +from collections.abc import Callable +from pathlib import Path +from typing import Literal + +from PIL import Image +from torchvision import tv_tensors +from torchvision.datasets import VisionDataset +from torchvision.datasets.utils import ( + download_and_extract_archive, + download_url, +) +from torchvision.transforms import functional as F +from tqdm import tqdm + + +class KITTIDepth(VisionDataset): + root: Path + depth_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_annotated.zip" + depth_md5 = "7d1ce32633dc2f43d9d1656a1f875e47" + raw_url = "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/" + raw_filenames_url = "https://raw.githubusercontent.com/torch-uncertainty/dataset-metadata/main/download/kitti/raw_filenames.json" + raw_filenames_md5 = "e5b7fad5ecd059488ef6c02dc9e444c1" + _num_samples = { + "train": 42949, + "val": 3426, + "test": ..., + } + + def __init__( + self, + root: str | Path, + split: Literal["train", "val"], + min_depth: float = 0.0, + max_depth: float = 80.0, + transforms: Callable | None = None, + download: bool = False, + remove_unused: bool = False, + ) -> None: + print( + "KITTIDepth is copyrighted by the Karlsruhe Institute of Technology " + "(KIT) and the Toyota Technological Institute at Chicago (TTIC). " + "By using KITTIDepth, you agree to the terms and conditions of the " + "Creative Commons Attribution-NonCommercial-ShareAlike 3.0 License. " + "This means that you must attribute the work in the manner specified " + "by the authors, you may not use this work for commercial purposes " + "and if you alter, transform, or build upon this work, you may " + "distribute the resulting work only under the same license." + ) + + super().__init__( + root=Path(root) / "KITTIDepth", + transforms=transforms, + ) + self.min_depth = min_depth + self.max_depth = max_depth + + if split not in ["train", "val"]: + raise ValueError( + f"split must be one of ['train', 'val']. Got {split}." + ) + + self.split = split + + if not self.check_split_integrity("leftDepth"): + if download: + self._download_depth() + else: + raise FileNotFoundError( + f"KITTI {split} split not found or incomplete. Set download=True to download it." + ) + + if not self.check_split_integrity("leftImg8bit"): + if download: + self._download_raw(remove_unused) + else: + raise FileNotFoundError( + f"KITTI {split} split not found or incomplete. Set download=True to download it." + ) + + self._make_dataset() + + def check_split_integrity(self, folder: str) -> bool: + split_path = self.root / self.split + return ( + split_path.is_dir() + and len(list((split_path / folder).glob("*.png"))) + == self._num_samples[self.split] + ) + + def __getitem__( + self, index: int + ) -> tuple[tv_tensors.Image, tv_tensors.Mask]: + """Get the sample at the given index. + + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is a depth map. + """ + image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB")) + target = tv_tensors.Mask( + F.pil_to_tensor(Image.open(self.targets[index])).squeeze(0) / 256.0 + ) + target[(target <= self.min_depth) | (target > self.max_depth)] = float( + "nan" + ) + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def __len__(self) -> int: + """The number of samples in the dataset.""" + return self._num_samples[self.split] + + def _make_dataset(self) -> None: + self.samples = sorted( + (self.root / self.split / "leftImg8bit").glob("*.png") + ) + self.targets = sorted( + (self.root / self.split / "leftDepth").glob("*.png") + ) + + def _download_depth(self) -> None: + """Download and extract the depth annotation dataset.""" + if not (self.root / "tmp").exists(): + download_and_extract_archive( + self.depth_url, + download_root=self.root, + extract_root=self.root / "tmp", + md5=self.depth_md5, + ) + + print("Re-structuring the depth annotations...") + + if (self.root / "train" / "leftDepth").exists(): + shutil.rmtree(self.root / "train" / "leftDepth") + + (self.root / "train" / "leftDepth").mkdir(parents=True, exist_ok=False) + + depth_files = list((self.root).glob("**/tmp/train/**/image_02/*.png")) + print("Train files:") + for file in tqdm(depth_files): + exp_code = file.parents[3].name.split("_") + filecode = "_".join( + [exp_code[0], exp_code[1], exp_code[2], exp_code[4], file.name] + ) + shutil.copy(file, self.root / "train" / "leftDepth" / filecode) + + if (self.root / "val" / "leftDepth").exists(): + shutil.rmtree(self.root / "val" / "leftDepth") + + (self.root / "val" / "leftDepth").mkdir(parents=True, exist_ok=False) + + depth_files = list((self.root).glob("**/tmp/val/**/image_02/*.png")) + print("Validation files:") + for file in tqdm(depth_files): + exp_code = file.parents[3].name.split("_") + filecode = "_".join( + [exp_code[0], exp_code[1], exp_code[2], exp_code[4], file.name] + ) + shutil.copy(file, self.root / "val" / "leftDepth" / filecode) + + shutil.rmtree(self.root / "tmp") + + def _download_raw(self, remove_unused: bool) -> None: + """Download and extract the raw dataset.""" + download_url( + self.raw_filenames_url, + self.root, + "raw_filenames.json", + self.raw_filenames_md5, + ) + with (self.root / "raw_filenames.json").open() as file: + raw_filenames = json.load(file) + + for filename in tqdm(raw_filenames): + print(self.raw_url + filename) + download_and_extract_archive( + self.raw_url + filename, + download_root=self.root, + extract_root=self.root / "raw", + md5=None, + ) + + print("Re-structuring the raw data...") + + samples_to_keep = list( + (self.root / "train" / "leftDepth").glob("*.png") + ) + + if (self.root / "train" / "leftImg8bit").exists(): + shutil.rmtree(self.root / "train" / "leftImg8bit") + + (self.root / "train" / "leftImg8bit").mkdir( + parents=True, exist_ok=False + ) + + print("Train files:") + for sample in tqdm(samples_to_keep): + filecode = sample.name.split("_") + first_level = "_".join([filecode[0], filecode[1], filecode[2]]) + second_level = "_".join( + [ + filecode[0], + filecode[1], + filecode[2], + "drive", + filecode[3], + "sync", + ] + ) + raw_path = ( + self.root + / "raw" + / first_level + / second_level + / "image_02" + / "data" + / filecode[4] + ) + shutil.copy( + raw_path, self.root / "train" / "leftImg8bit" / sample.name + ) + + samples_to_keep = list((self.root / "val" / "leftDepth").glob("*.png")) + + if (self.root / "val" / "leftImg8bit").exists(): + shutil.rmtree(self.root / "val" / "leftImg8bit") + + (self.root / "val" / "leftImg8bit").mkdir(parents=True, exist_ok=False) + + print("Validation files:") + for sample in tqdm(samples_to_keep): + filecode = sample.name.split("_") + first_level = "_".join([filecode[0], filecode[1], filecode[2]]) + second_level = "_".join( + [ + filecode[0], + filecode[1], + filecode[2], + "drive", + filecode[3], + "sync", + ] + ) + raw_path = ( + self.root + / "raw" + / first_level + / second_level + / "image_02" + / "data" + / filecode[4] + ) + shutil.copy( + raw_path, self.root / "val" / "leftImg8bit" / sample.name + ) + + if remove_unused: + shutil.rmtree(self.root / "raw") diff --git a/torch_uncertainty/datasets/muad.py b/torch_uncertainty/datasets/muad.py index ffe842e8..9cde371a 100644 --- a/torch_uncertainty/datasets/muad.py +++ b/torch_uncertainty/datasets/muad.py @@ -3,7 +3,7 @@ import shutil from collections.abc import Callable from pathlib import Path -from typing import Any, Literal +from typing import Literal import cv2 import numpy as np @@ -42,6 +42,8 @@ def __init__( self, root: str | Path, split: Literal["train", "val"], + min_depth: float | None = None, + max_depth: float | None = None, target_type: Literal["semantic", "depth"] = "semantic", transforms: Callable | None = None, download: bool = False, @@ -52,6 +54,10 @@ def __init__( root (str): Root directory of dataset where directory 'leftImg8bit' and 'leftLabel' or 'leftDepth' are located. split (str, optional): The image split to use, 'train' or 'val'. + min_depth (float, optional): The maximum depth value to use if + target_type is 'depth'. Defaults to None. + max_depth (float, optional): The maximum depth value to use if + target_type is 'depth'. Defaults to None. target_type (str, optional): The type of target to use, 'semantic' or 'depth'. transforms (callable, optional): A function/transform that takes in @@ -75,6 +81,8 @@ def __init__( root=Path(root) / "MUAD", transforms=transforms, ) + self.min_depth = min_depth + self.max_depth = max_depth if split not in ["train", "val"]: raise ValueError( @@ -108,7 +116,7 @@ def __init__( ): if download: self._download(split=f"{split}_depth") - # FIXME: Depth target for train are in a different folder + # Depth target for train are in a different folder # thus we move them to the correct folder if split == "train": shutil.move( @@ -169,7 +177,9 @@ def decode_target(self, target: Image.Image) -> np.ndarray: target[target == 255] = 19 return self.train_id_to_color[target] - def __getitem__(self, index: int) -> tuple[Any, Any]: + def __getitem__( + self, index: int + ) -> tuple[tv_tensors.Image, tv_tensors.Mask]: """Get the sample at the given index. Args: @@ -192,10 +202,13 @@ def __getitem__(self, index: int) -> tuple[Any, Any]: cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH, ) ) - # TODO: in the long tun it would be better to use a custom + # TODO: in the long run it would be better to use a custom # tv_tensor for depth maps (e.g. tv_tensors.DepthMap) target = np.asarray(target, np.float32) target = tv_tensors.Mask(400 * (1 - target)) # convert to meters + target[(target <= self.min_depth) | (target > self.max_depth)] = ( + float("nan") + ) if self.transforms is not None: image, target = self.transforms(image, target) @@ -222,7 +235,7 @@ def _make_dataset(self, path: Path) -> None: """ if "depth" in path.name: raise NotImplementedError( - "Depth regression mode is not implemented yet. Raise an issue " + "Depth mode is not implemented yet. Raise an issue " "if you need it." ) self.samples = sorted((path / "leftImg8bit/").glob("**/*")) diff --git a/torch_uncertainty/datasets/nyu.py b/torch_uncertainty/datasets/nyu.py new file mode 100644 index 00000000..90c5736e --- /dev/null +++ b/torch_uncertainty/datasets/nyu.py @@ -0,0 +1,141 @@ +from collections.abc import Callable +from pathlib import Path +from typing import Literal + +import cv2 +import h5py +import numpy as np +from PIL import Image +from torchvision import tv_tensors +from torchvision.datasets import VisionDataset +from torchvision.datasets.utils import ( + check_integrity, + download_and_extract_archive, + download_url, +) + + +class NYUv2(VisionDataset): + root: Path + rgb_urls = { + "train": "http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz", + "val": "http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz", + } + rgb_md5 = { + "train": "ad124bbde47e371359caa4642a8a4611", + "val": "f47f7c7c8a20d1210db7941c4f153b06", + } + depth_url = "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat" + depth_md5 = "520609c519fba3ba5ac58c8fefcc3530" + + def __init__( + self, + root: Path | str, + split: Literal["train", "val"], + transforms: Callable | None = None, + min_depth: float = 0.0, + max_depth: float = 10.0, + download: bool = False, + ): + """NYUv2 depth dataset. + + Args: + root (Path | str): Root directory where dataset is stored. + split (Literal["train", "val"]): Dataset split. + transforms (Callable | None): Transform to apply to samples & targets. + Defaults to None. + min_depth (float): Minimum depth value. Defaults to 1e-3. + max_depth (float): Maximum depth value. Defaults to 10. + download (bool): Download dataset if not found. Defaults to False. + """ + super().__init__(Path(root) / "NYUv2", transforms=transforms) + self.min_depth = min_depth + self.max_depth = max_depth + + if split not in ["train", "val"]: + raise ValueError( + f"split must be one of ['train', 'val']. Got {split}." + ) + self.split = split + + if not self._check_integrity(): + if download: + self._download() + else: + raise FileNotFoundError( + f"NYUv2 {split} split not found or incomplete. Set download=True to download it." + ) + + # make dataset + path = self.root / self.split + self.samples = sorted((path / "rgb_img").glob("**/*")) + self.targets = sorted((path / "depth").glob("**/*")) + + def __getitem__(self, index: int): + """Return image and target at index. + + Args: + index (int): Index of the sample. + """ + image = tv_tensors.Image(Image.open(self.samples[index]).convert("RGB")) + target = Image.fromarray( + cv2.imread( + str(self.targets[index]), + cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH, + ) + ) + target = np.asarray(target, np.uint16) + target = tv_tensors.Mask(target / 1e4) # convert to meters + target[(target <= self.min_depth) | (target > self.max_depth)] = float( + "nan" + ) + if self.transforms is not None: + image, target = self.transforms(image, target) + return image, target + + def __len__(self): + """Return number of samples in dataset.""" + return len(self.samples) + + def _check_integrity(self) -> bool: + """Check if dataset is present and complete.""" + return ( + check_integrity( + self.root / f"nyu_{self.split}_rgb.tgz", + self.rgb_md5[self.split], + ) + and check_integrity(self.root / "depth.mat", self.depth_md5) + and (self.root / self.split / "rgb_img").exists() + and (self.root / self.split / "depth").exists() + ) + + def _download(self): + """Download and extract dataset.""" + download_and_extract_archive( + self.rgb_urls[self.split], + self.root, + extract_root=self.root / self.split / "rgb_img", + filename=f"nyu_{self.split}_rgb.tgz", + md5=self.rgb_md5[self.split], + ) + if not check_integrity(self.root / "depth.mat", self.depth_md5): + download_url( + NYUv2.depth_url, self.root, "depth.mat", self.depth_md5 + ) + self._create_depth_files() + + def _create_depth_files(self): + """Create depth images from the depth.mat file.""" + path = self.root / self.split + (path / "depth").mkdir() + samples = sorted((path / "rgb_img").glob("**/*")) + ids = [int(p.stem.split("_")[-1]) for p in samples] + file = h5py.File(self.root / "depth.mat", "r") + depths = file["depths"] + for i in range(len(depths)): + img_id = i + 1 + if img_id in ids: + img = (depths[i] * 1e4).astype(np.uint16).T + Image.fromarray(img).save( + path / "depth" / f"nyu_depth_{str(img_id).zfill(4)}.png" + ) diff --git a/torch_uncertainty/layers/batch_ensemble.py b/torch_uncertainty/layers/batch_ensemble.py index 6022f40b..ac641413 100644 --- a/torch_uncertainty/layers/batch_ensemble.py +++ b/torch_uncertainty/layers/batch_ensemble.py @@ -1,7 +1,7 @@ import math import torch -from torch import nn +from torch import Tensor, nn from torch.nn.common_types import _size_2_t from torch.nn.modules.utils import _pair @@ -10,11 +10,10 @@ class BatchLinear(nn.Module): __constants__ = ["in_features", "out_features", "num_estimators"] in_features: int out_features: int - n_estimator: int - weight: torch.Tensor - r_group: torch.Tensor - s_group: torch.Tensor - bias: torch.Tensor | None + num_estimators: int + r_group: Tensor + s_group: Tensor + bias: Tensor | None def __init__( self, @@ -27,17 +26,17 @@ def __init__( ) -> None: r"""BatchEnsemble-style Linear layer. - Applies a linear transformation using BatchEnsemble method to the incoming + Apply a linear transformation using BatchEnsemble method to the incoming data. .. math:: y=(x\circ \widehat{r_{group}})W^{T}\circ \widehat{s_{group}} + \widehat{b} Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - num_estimators (int): number of estimators in the ensemble referred as - :math:`M` here. + in_features (int): Number of input features.. + out_features (int): Number of output features. + num_estimators (int): number of estimators in the ensemble, referred as + :math:`M`. bias (bool, optional): if ``True``, adds a learnable bias to the output. Defaults to ``True``. device (Any, optional): device to use for the parameters and @@ -97,7 +96,10 @@ def __init__( self.num_estimators = num_estimators self.linear = nn.Linear( - in_features=in_features, out_features=out_features, bias=False + in_features=in_features, + out_features=out_features, + bias=False, + **factory_kwargs, ) self.r_group = nn.Parameter( @@ -124,7 +126,7 @@ def reset_parameters(self) -> None: bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) - def forward(self, inputs: torch.Tensor) -> torch.Tensor: + def forward(self, inputs: Tensor) -> Tensor: batch_size = inputs.size(0) examples_per_estimator = torch.tensor( batch_size // self.num_estimators, device=inputs.device @@ -143,7 +145,6 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: s_group = torch.cat( [s_group, s_group[:extra]], dim=0 ) # .unsqueeze(-1).unsqueeze(-1) - bias: torch.Tensor | None if self.bias is not None: bias = torch.repeat_interleave( self.bias, @@ -181,15 +182,15 @@ class BatchConv2d(nn.Module): in_channels: int out_channels: int kernel_size: tuple[int, ...] - n_estimator: int + num_estimators: int stride: tuple[int, ...] padding: str | tuple[int, ...] dilation: tuple[int, ...] groups: int - weight: torch.Tensor - r_group: torch.Tensor - s_group: torch.Tensor - bias: torch.Tensor | None + weight: Tensor + r_group: Tensor + s_group: Tensor + bias: Tensor | None def __init__( self, @@ -232,7 +233,7 @@ def __init__( `_. Args: - in_channels (int): number of channels in the input image. + in_channels (int): number of channels in the input images. out_channels (int): number of channels produced by the convolution. kernel_size (int or tuple): size of the convolving kernel. num_estimators (int): number of estimators in the ensemble referred as @@ -321,6 +322,7 @@ def __init__( dilation=dilation, groups=groups, bias=False, + **factory_kwargs, ) self.r_group = nn.Parameter( torch.empty((num_estimators, in_channels), **factory_kwargs) @@ -345,7 +347,7 @@ def reset_parameters(self) -> None: bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.bias, -bound, bound) - def forward(self, inputs: torch.Tensor) -> torch.Tensor: + def forward(self, inputs: Tensor) -> Tensor: batch_size = inputs.size(0) examples_per_estimator = batch_size // self.num_estimators extra = batch_size % self.num_estimators @@ -381,7 +383,6 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: ) s_group = torch.cat([s_group, s_group[:extra]], dim=0) # - bias: torch.Tensor | None if self.bias is not None: bias = ( torch.repeat_interleave( @@ -406,8 +407,10 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: ) def extra_repr(self) -> str: - s = ( - "{in_channels}, {out_channels}, kernel_size={kernel_size}" - ", num_estimators={num_estimators}, stride={stride}" + return ( + f"in_channels={self.in_channels}," + f" out_channels={self.out_channels}," + f" kernel_size={self.kernel_size}," + f" num_estimators={self.num_estimators}," + f" stride={self.stride}" ) - return s.format(**self.__dict__) diff --git a/torch_uncertainty/layers/bayesian/__init__.py b/torch_uncertainty/layers/bayesian/__init__.py index 15aa4644..e650638a 100644 --- a/torch_uncertainty/layers/bayesian/__init__.py +++ b/torch_uncertainty/layers/bayesian/__init__.py @@ -1,5 +1,13 @@ # ruff: noqa: F401 from .bayes_conv import BayesConv1d, BayesConv2d, BayesConv3d from .bayes_linear import BayesLinear +from .lpbnn import LPBNNConv2d, LPBNNLinear -bayesian_modules = (BayesConv1d, BayesConv2d, BayesConv3d, BayesLinear) +bayesian_modules = ( + BayesConv1d, + BayesConv2d, + BayesConv3d, + BayesLinear, + LPBNNLinear, + LPBNNConv2d, +) diff --git a/torch_uncertainty/layers/bayesian/lpbnn.py b/torch_uncertainty/layers/bayesian/lpbnn.py new file mode 100644 index 00000000..b2585749 --- /dev/null +++ b/torch_uncertainty/layers/bayesian/lpbnn.py @@ -0,0 +1,352 @@ +"""These layers are still work in progress.""" + +import math + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.common_types import _size_2_t + + +def check_lpbnn_parameters_consistency( + hidden_size: int, std_factor: float, num_estimators: int +) -> None: + if hidden_size < 1: + raise ValueError( + f"hidden_size must be greater than 0. Got {hidden_size}." + ) + if std_factor < 0: + raise ValueError( + f"std_factor must be greater than 0. Got {std_factor}." + ) + if num_estimators < 1: + raise ValueError( + f"num_estimators must be greater than 0. Got {num_estimators}." + ) + + +def _sample(mu: Tensor, logvar: Tensor, std_factor: float) -> Tensor: + eps = torch.randn_like(mu) + return eps * std_factor * torch.exp(logvar * 0.5) + mu + + +class LPBNNLinear(nn.Module): + __constants__ = [ + "in_features", + "out_features", + "num_estimators", + "hidden_size", + ] + in_features: int + out_features: int + num_estimators: int + bias: torch.Tensor | None + + def __init__( + self, + in_features: int, + out_features: int, + num_estimators: int, + hidden_size: int = 32, + std_factor: float = 1e-2, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + """LPBNN-style linear layer. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + num_estimators (int): Number of models to sample from. + hidden_size (int): Size of the hidden layer. Defaults to 32. + std_factor (float): Factor to multiply the standard deviation of the + latent noise. Defaults to 1e-2. + bias (bool): If ``True``, adds a learnable bias to the output. + Defaults to ``True``. + device (torch.device): Device on which the layer is stored. + Defaults to ``None``. + dtype (torch.dtype): Data type of the layer. Defaults to ``None``. + + Reference: + `Encoding the latent posterior of Bayesian Neural Networks for + uncertainty quantification `_. + """ + check_lpbnn_parameters_consistency( + hidden_size, std_factor, num_estimators + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.hidden_size = hidden_size + self.std_factor = std_factor + self.num_estimators = num_estimators + + # for the KL Loss + self.lprior = 0 + + self.linear = nn.Linear( + in_features, out_features, bias=False, **factory_kwargs + ) + self.alpha = nn.Parameter( + torch.empty((num_estimators, in_features), **factory_kwargs), + requires_grad=False, + ) + self.gamma = nn.Parameter( + torch.empty((num_estimators, out_features), **factory_kwargs) + ) + self.encoder = nn.Linear( + in_features, self.hidden_size, **factory_kwargs + ) + self.latent_mean = nn.Linear( + self.hidden_size, self.hidden_size, **factory_kwargs + ) + self.latent_logvar = nn.Linear( + self.hidden_size, self.hidden_size, **factory_kwargs + ) + self.decoder = nn.Linear( + self.hidden_size, in_features, **factory_kwargs + ) + self.latent_loss = torch.zeros(1, **factory_kwargs) + if bias: + self.bias = nn.Parameter( + torch.empty((num_estimators, out_features), **factory_kwargs) + ) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + nn.init.normal_(self.alpha, mean=1.0, std=0.1) + nn.init.normal_(self.gamma, mean=1.0, std=0.1) + self.linear.reset_parameters() + self.encoder.reset_parameters() + self.decoder.reset_parameters() + self.latent_mean.reset_parameters() + self.latent_logvar.reset_parameters() + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out( + self.linear.weight + ) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, x: Tensor) -> Tensor: + # Draw a sample from the dist generated by the noise self.alpha + latent = F.relu(self.encoder(self.alpha)) + latent_mean, latent_logvar = ( + self.latent_mean(latent), + self.latent_logvar(latent), + ) + z_latent = _sample(latent_mean, latent_logvar, self.std_factor) + + # one sample per "model" with as many features as x + alpha_sample = self.decoder(z_latent) + + # Compute the latent loss + if self.training: + mse = F.mse_loss(alpha_sample, self.alpha) + kld = -0.5 * torch.sum( + 1 + latent_logvar - latent_mean**2 - torch.exp(latent_logvar) + ) + # For the KL Loss + self.lvposterior = mse + kld + + # Compute the output + num_examples_per_model = int(x.size(0) / self.num_estimators) + alpha = alpha_sample.repeat((num_examples_per_model, 1)) + gamma = self.gamma.repeat((num_examples_per_model, 1)) + out = self.linear(x * alpha) * gamma + + if self.bias is not None: + bias = self.bias.repeat((num_examples_per_model, 1)) + out += bias + return out + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, " + f"out_features={self.out_features}, " + f"num_estimators={self.num_estimators}, " + f"hidden_size={self.hidden_size}, bias={self.bias is not None}" + ) + + +class LPBNNConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_estimators: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: str | _size_2_t = 0, + groups: int = 1, + hidden_size: int = 32, + std_factor: float = 1e-2, + gamma: bool = True, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ): + """LPBNN-style 2D convolutional layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + num_estimators (int): Number of models to sample from. + kernel_size (int or tuple): Size of the convolving kernel. + stride (int or tuple, optional): Stride of the convolution. Default: 1. + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0. + groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1. + hidden_size (int): Size of the hidden layer. Defaults to 32. + std_factor (float): Factor to multiply the standard deviation of the + latent noise. Defaults to 1e-2. + gamma (bool): If ``True``, adds a learnable gamma to the output. + Defaults to ``True``. + bias (bool): If ``True``, adds a learnable bias to the output. + Defaults to ``True``. + padding_mode (str): 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'. + device (torch.device): Device on which the layer is stored. + Defaults to ``None``. + dtype (torch.dtype): Data type of the layer. Defaults to ``None``. + + Reference: + `Encoding the latent posterior of Bayesian Neural Networks for + uncertainty quantification `_. + """ + check_lpbnn_parameters_consistency( + hidden_size, std_factor, num_estimators + ) + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_size = hidden_size + self.std_factor = std_factor + self.num_estimators = num_estimators + + # for the KL Loss + self.lprior = 0 + + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False, + padding_mode=padding_mode, + ) + self.alpha = nn.Parameter( + torch.empty(num_estimators, in_channels, **factory_kwargs), + requires_grad=False, + ) + + self.encoder = nn.Linear( + in_channels, self.hidden_size, **factory_kwargs + ) + self.decoder = nn.Linear( + self.hidden_size, in_channels, **factory_kwargs + ) + self.latent_mean = nn.Linear( + self.hidden_size, self.hidden_size, **factory_kwargs + ) + self.latent_logvar = nn.Linear( + self.hidden_size, self.hidden_size, **factory_kwargs + ) + + self.latent_loss = torch.zeros(1, **factory_kwargs) + if gamma: + self.gamma = nn.Parameter( + torch.empty((num_estimators, out_channels), **factory_kwargs) + ) + else: + self.register_parameter("gamma", None) + + if bias: + self.bias = nn.Parameter( + torch.empty((num_estimators, out_channels), **factory_kwargs) + ) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + nn.init.normal_(self.alpha, mean=1.0, std=0.1) + if self.gamma is not None: + nn.init.normal_(self.gamma, mean=1.0, std=0.1) + self.conv.reset_parameters() + self.encoder.reset_parameters() + self.decoder.reset_parameters() + self.latent_mean.reset_parameters() + self.latent_logvar.reset_parameters() + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.conv.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, x: Tensor) -> Tensor: + # Draw a sample from the dist generated by the latent noise self.alpha + latent = F.relu(self.encoder(self.alpha)) + latent_mean, latent_logvar = ( + self.latent_mean(latent), + self.latent_logvar(latent), + ) + z_latent = _sample(latent_mean, latent_logvar, self.std_factor) + + # one sample per "model" with as many features as x + alpha_sample = self.decoder(z_latent) + + # Compute the latent loss + if self.training: + mse = F.mse_loss(alpha_sample, self.alpha) + kld = -0.5 * torch.sum( + 1 + latent_logvar - latent_mean.pow(2) - latent_logvar.exp() + ) + # for the KL Loss + self.lvposterior = mse + kld + + num_examples_per_model = int(x.size(0) / self.num_estimators) + + # Compute the output + alpha = ( + alpha_sample.repeat((num_examples_per_model, 1)) + .unsqueeze(-1) + .unsqueeze(-1) + ) + if self.gamma is not None: + gamma = ( + self.gamma.repeat((num_examples_per_model, 1)) + .unsqueeze(-1) + .unsqueeze(-1) + ) + out = self.conv(x * alpha) * gamma + else: + out = self.conv(x * alpha) + + if self.bias is not None: + bias = ( + self.bias.repeat((num_examples_per_model, 1)) + .unsqueeze(-1) + .unsqueeze(-1) + ) + out += bias + + return out + + def extra_repr(self) -> str: + return ( + f"in_channels={self.in_channels}, " + f"out_channels={self.out_channels}, " + f"num_estimators={self.num_estimators}, " + f"hidden_size={self.hidden_size}, " + f"gamma={self.gamma is not None}, " + f"bias={self.bias is not None}" + ) diff --git a/torch_uncertainty/layers/bayesian/sampler.py b/torch_uncertainty/layers/bayesian/sampler.py index 6c7b7977..a512fad7 100644 --- a/torch_uncertainty/layers/bayesian/sampler.py +++ b/torch_uncertainty/layers/bayesian/sampler.py @@ -28,7 +28,7 @@ def sample(self) -> Tensor: def log_posterior(self, weight: Tensor | None = None) -> Tensor: if self.weight is None or self.sigma is None: raise ValueError( - "Sample the weights before asking for the log posterior." + "Sample the weights before querying the log posterior." ) if weight is None: # coverage: ignore diff --git a/torch_uncertainty/layers/distributions.py b/torch_uncertainty/layers/distributions.py index 108cc8c4..341cf5b9 100644 --- a/torch_uncertainty/layers/distributions.py +++ b/torch_uncertainty/layers/distributions.py @@ -66,13 +66,13 @@ def forward(self, x: Tensor) -> Laplace: r"""Forward pass of the Laplace distribution layer. Args: - x (Tensor): A tensor of shape (:attr:`dim` :math:`\times`2). + x (Tensor): A tensor of shape (..., :attr:`dim` :math:`\times`2). Returns: Laplace: The output Laplace distribution. """ - loc = x[:, : self.dim] - scale = F.softplus(x[:, self.dim :]) + self.eps + loc = x[..., : self.dim] + scale = F.softplus(x[..., self.dim :]) + self.eps return Laplace(loc, scale) @@ -100,8 +100,8 @@ def forward(self, x: Tensor) -> NormalInverseGamma: Returns: NormalInverseGamma: The output NormalInverseGamma distribution. """ - loc = x[:, : self.dim] - lmbda = F.softplus(x[:, self.dim : 2 * self.dim]) + self.eps - alpha = 1 + F.softplus(x[:, 2 * self.dim : 3 * self.dim]) + self.eps - beta = F.softplus(x[:, 3 * self.dim :]) + self.eps + loc = x[..., : self.dim] + lmbda = F.softplus(x[..., self.dim : 2 * self.dim]) + self.eps + alpha = 1 + F.softplus(x[..., 2 * self.dim : 3 * self.dim]) + self.eps + beta = F.softplus(x[..., 3 * self.dim :]) + self.eps return NormalInverseGamma(loc, lmbda, alpha, beta) diff --git a/torch_uncertainty/layers/mc_batch_norm.py b/torch_uncertainty/layers/mc_batch_norm.py index 9a68e633..916dc6f8 100644 --- a/torch_uncertainty/layers/mc_batch_norm.py +++ b/torch_uncertainty/layers/mc_batch_norm.py @@ -83,7 +83,7 @@ def reset_mc_statistics(self) -> None: class MCBatchNorm1d(_MCBatchNorm): - """Applies Monte Carlo Batch Normalization over a 2D or 3D input. + """Monte Carlo Batch Normalization over a 2D or 3D (batched) input. Args: num_features (int): Number of features. @@ -96,7 +96,7 @@ class MCBatchNorm1d(_MCBatchNorm): device (optional): Device. Defaults to None. dtype (optional): Data type. Defaults to None. - Note: + Warning: This layer should not be used out of the corresponding wrapper. Check MCBatchNorm in torch_uncertainty/post_processing/. """ @@ -109,7 +109,7 @@ def _check_input_dim(self, inputs) -> None: class MCBatchNorm2d(_MCBatchNorm): - """Applies Monte Carlo Batch Normalization over a 4D input. + """Monte Carlo Batch Normalization over a 3D or 4D (batched) input. Args: num_features (int): Number of features. @@ -122,7 +122,7 @@ class MCBatchNorm2d(_MCBatchNorm): device (optional): Device. Defaults to None. dtype (optional): Data type. Defaults to None. - Note: + Warning: This layer should not be used out of the corresponding wrapper. Check MCBatchNorm in torch_uncertainty/post_processing/. """ @@ -135,7 +135,7 @@ def _check_input_dim(self, inputs) -> None: class MCBatchNorm3d(_MCBatchNorm): - """Applies Monte Carlo Batch Normalization over a 5D input. + """Monte Carlo Batch Normalization over a 4D or 5D (batched) input. Args: num_features (int): Number of features. @@ -148,7 +148,7 @@ class MCBatchNorm3d(_MCBatchNorm): device (optional): Device. Defaults to None. dtype (optional): Data type. Defaults to None. - Note: + Warning: This layer should not be used out of the corresponding wrapper. Check MCBatchNorm in torch_uncertainty/post_processing/. """ diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py index 336c5576..6f742b17 100644 --- a/torch_uncertainty/layers/packed.py +++ b/torch_uncertainty/layers/packed.py @@ -6,14 +6,14 @@ def check_packed_parameters_consistency( - alpha: float, num_estimators: int, gamma: int + alpha: float, gamma: int, num_estimators: int ) -> None: """Check the consistency of the parameters of the Packed-Ensembles layers. Args: alpha (float): The width multiplier of the layer. - num_estimators (int): The number of estimators in the ensemble. gamma (int): The number of groups in the ensemble. + num_estimators (int): The number of estimators in the ensemble. """ if alpha is None: raise ValueError("You must specify the value of the arg. `alpha`") @@ -21,6 +21,13 @@ def check_packed_parameters_consistency( if alpha <= 0: raise ValueError(f"Attribute `alpha` should be > 0, not {alpha}") + if not isinstance(gamma, int): + raise TypeError( + f"Attribute `gamma` should be an int, not {type(gamma)}" + ) + if gamma <= 0: + raise ValueError(f"Attribute `gamma` should be >= 1, not {gamma}") + if num_estimators is None: raise ValueError( "You must specify the value of the arg. `num_estimators`" @@ -36,13 +43,6 @@ def check_packed_parameters_consistency( f"{num_estimators}" ) - if not isinstance(gamma, int): - raise TypeError( - f"Attribute `gamma` should be an int, not {type(gamma)}" - ) - if gamma <= 0: - raise ValueError(f"Attribute `gamma` should be >= 1, not {gamma}") - class PackedLinear(nn.Module): def __init__( @@ -103,11 +103,10 @@ def __init__( 1). The (often) necessary rearrange operation is executed by default. """ + check_packed_parameters_consistency(alpha, gamma, num_estimators) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - check_packed_parameters_consistency(alpha, num_estimators, gamma) - self.first = first self.num_estimators = num_estimators self.rearrange = rearrange @@ -237,11 +236,10 @@ def __init__( :attr:`groups`. However, the number of input and output channels will be changed to comply with this constraint. """ + check_packed_parameters_consistency(alpha, gamma, num_estimators) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - check_packed_parameters_consistency(alpha, num_estimators, gamma) - self.num_estimators = num_estimators # Define the number of channels of the underlying convolution @@ -366,11 +364,10 @@ def __init__( :attr:`groups`. However, the number of input and output channels will be changed to comply with this constraint. """ + check_packed_parameters_consistency(alpha, gamma, num_estimators) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - check_packed_parameters_consistency(alpha, num_estimators, gamma) - self.num_estimators = num_estimators # Define the number of channels of the underlying convolution @@ -497,8 +494,7 @@ def __init__( """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() - - check_packed_parameters_consistency(alpha, num_estimators, gamma) + check_packed_parameters_consistency(alpha, gamma, num_estimators) self.num_estimators = num_estimators diff --git a/torch_uncertainty/losses.py b/torch_uncertainty/losses.py index 55aeb91a..b0d6e1b8 100644 --- a/torch_uncertainty/losses.py +++ b/torch_uncertainty/losses.py @@ -54,13 +54,15 @@ def forward(self) -> Tensor: def _kl_div(self) -> Tensor: """Gathers pre-computed KL-Divergences from :attr:`model`.""" kl_divergence = torch.zeros(1) + count = 0 for module in self.model.modules(): if isinstance(module, bayesian_modules): kl_divergence = kl_divergence.to( device=module.lvposterior.device ) kl_divergence += module.lvposterior - module.lprior - return kl_divergence + count += 1 + return kl_divergence / count class ELBOLoss(nn.Module): @@ -112,7 +114,7 @@ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor: aggregated_elbo += self.kl_weight * self._kl_div() return aggregated_elbo / self.num_samples - def set_model(self, model: nn.Module) -> None: + def set_model(self, model: nn.Module | None) -> None: self.model = model if model is not None: self._kl_div = KLDiv(model) diff --git a/torch_uncertainty/metrics/__init__.py b/torch_uncertainty/metrics/__init__.py index 207d0c9b..ee1a63b9 100644 --- a/torch_uncertainty/metrics/__init__.py +++ b/torch_uncertainty/metrics/__init__.py @@ -1,22 +1,28 @@ # ruff: noqa: F401 from .classification import ( + AURC, AUSE, - CE, FPR95, + AdaptiveCalibrationError, BrierScore, + CalibrationError, CategoricalNLL, + CovAt5Risk, Disagreement, Entropy, GroupingLoss, MeanIntersectionOverUnion, MutualInformation, + RiskAt80Cov, VariationRatio, ) from .regression import ( DistributionNLL, Log10, + MeanAbsoluteErrorInverse, MeanGTRelativeAbsoluteError, MeanGTRelativeSquaredError, + MeanSquaredErrorInverse, MeanSquaredLogError, SILog, ThresholdAccuracy, diff --git a/torch_uncertainty/metrics/classification/__init__.py b/torch_uncertainty/metrics/classification/__init__.py index df6078c9..de375588 100644 --- a/torch_uncertainty/metrics/classification/__init__.py +++ b/torch_uncertainty/metrics/classification/__init__.py @@ -1,12 +1,14 @@ # ruff: noqa: F401 +from .adaptive_calibration_error import AdaptiveCalibrationError from .brier_score import BrierScore -from .calibration import CE +from .calibration_error import CalibrationError +from .categorical_nll import CategoricalNLL from .disagreement import Disagreement from .entropy import Entropy -from .fpr95 import FPR95 +from .fpr95 import FPR95, FPRx from .grouping_loss import GroupingLoss from .mean_iou import MeanIntersectionOverUnion from .mutual_information import MutualInformation -from .nll import CategoricalNLL +from .risk_coverage import AURC, CovAt5Risk, CovAtxRisk, RiskAt80Cov, RiskAtxCov from .sparsification import AUSE from .variation_ratio import VariationRatio diff --git a/torch_uncertainty/metrics/classification/adaptive_calibration_error.py b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py new file mode 100644 index 00000000..8c5de1b1 --- /dev/null +++ b/torch_uncertainty/metrics/classification/adaptive_calibration_error.py @@ -0,0 +1,226 @@ +from typing import Any, Literal + +import torch +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence +from torchmetrics.classification.calibration_error import ( + _binary_calibration_error_arg_validation, + _multiclass_calibration_error_arg_validation, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel + + +def _equal_binning_bucketize( + confidences: Tensor, accuracies: Tensor, num_bins: int +) -> tuple[Tensor, Tensor, Tensor]: + """Compute bins for the adaptive calibration error. + + Args: + confidences: The confidence (i.e. predicted prob) of the top1 + prediction. + accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise. + num_bins: Number of bins to use when computing adaptive calibration + error. + + Returns: + tuple with binned accuracy, binned confidence and binned probabilities + """ + confidences, indices = torch.sort(confidences) + accuracies = accuracies[indices] + acc_bin, conf_bin = ( + accuracies.tensor_split(num_bins), + confidences.tensor_split(num_bins), + ) + count_bin = torch.as_tensor( + [len(cb) for cb in conf_bin], + dtype=confidences.dtype, + device=confidences.device, + ) + return ( + pad_sequence(acc_bin, batch_first=True).sum(1) / count_bin, + pad_sequence(conf_bin, batch_first=True).sum(1) / count_bin, + torch.as_tensor(count_bin) / len(confidences), + ) + + +def _ace_compute( + confidences: Tensor, + accuracies: Tensor, + num_bins: int, + norm: Literal["l1", "l2", "max"] = "l1", + debias: bool = False, +) -> Tensor: + """Compute the adaptive calibration error given the provided number of bins + and norm. + + Args: + confidences: The confidence (i.e. predicted prob) of the top1 + prediction. + accuracies: 1.0 if the top-1 prediction was correct, 0.0 otherwise. + num_bins: Number of bins to use when computing adaptive calibration + error. + norm: Norm function to use when computing calibration error. Defaults + to "l1". + debias: Apply debiasing to L2 norm computation as in + `Verified Uncertainty Calibration`_. Defaults to False. + + Returns: + Tensor: Adaptive Calibration error scalar. + """ + with torch.no_grad(): + acc_bin, conf_bin, prop_bin = _equal_binning_bucketize( + confidences, accuracies, num_bins + ) + + if norm == "l1": + return torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin) + if norm == "max": + ace = torch.max(torch.abs(acc_bin - conf_bin)) + if norm == "l2": + ace = torch.sum(torch.pow(acc_bin - conf_bin, 2) * prop_bin) + if debias: # coverage: ignore + debias_bins = (acc_bin * (acc_bin - 1) * prop_bin) / ( + prop_bin * accuracies.size()[0] - 1 + ) + ace += torch.sum( + torch.nan_to_num(debias_bins) + ) # replace nans with zeros if nothing appeared in a bin + return torch.sqrt(ace) if ace > 0 else torch.tensor(0) + return ace + + +class BinaryAdaptiveCalibrationError(Metric): + r"""`Adaptive Top-label Calibration Error` for binary tasks.""" + + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + confidences: list[Tensor] + accuracies: list[Tensor] + + def __init__( + self, + n_bins: int = 10, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: int | None = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if ignore_index is not None: # coverage: ignore + raise ValueError( + "ignore_index is not supported for multiclass tasks." + ) + + if validate_args: + _binary_calibration_error_arg_validation(n_bins, norm, ignore_index) + self.n_bins = n_bins + self.norm = norm + + self.add_state("confidences", [], dist_reduce_fx="cat") + self.add_state("accuracies", [], dist_reduce_fx="cat") + + def update(self, probs: Tensor, targets: Tensor) -> None: + """Update metric states with predictions and targets.""" + confidences, preds = torch.max(probs, 1 - probs), torch.round(probs) + accuracies = preds == targets + self.confidences.append(confidences.float()) + self.accuracies.append(accuracies.float()) + + def compute(self) -> Tensor: + """Compute metric.""" + confidences = dim_zero_cat(self.confidences) + accuracies = dim_zero_cat(self.accuracies) + return _ace_compute( + confidences, accuracies, self.n_bins, norm=self.norm + ) + + +class MulticlassAdaptiveCalibrationError(Metric): + r"""`Adaptive Top-label Calibration Error` for multiclass tasks.""" + + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + confidences: list[Tensor] + accuracies: list[Tensor] + + def __init__( + self, + num_classes: int, + n_bins: int = 10, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: int | None = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if ignore_index is not None: # coverage: ignore + raise ValueError( + "ignore_index is not supported for multiclass tasks." + ) + + if validate_args: + _multiclass_calibration_error_arg_validation( + num_classes, n_bins, norm, ignore_index + ) + self.n_bins = n_bins + self.norm = norm + + self.add_state("confidences", [], dist_reduce_fx="cat") + self.add_state("accuracies", [], dist_reduce_fx="cat") + + def update(self, probs: Tensor, targets: Tensor) -> None: + """Update metric states with predictions and targets.""" + confidences, preds = torch.max(probs, 1) + accuracies = preds == targets + self.confidences.append(confidences.float()) + self.accuracies.append(accuracies.float()) + + def compute(self) -> Tensor: + """Compute metric.""" + confidences = dim_zero_cat(self.confidences) + accuracies = dim_zero_cat(self.accuracies) + return _ace_compute( + confidences, accuracies, self.n_bins, norm=self.norm + ) + + +class AdaptiveCalibrationError: + """`Adaptive Top-label Calibration Error`. + + Reference: + Nixon et al. Measuring calibration in deep learning. In CVPRW, 2019. + """ + + def __new__( + cls, + task: Literal["binary", "multiclass"], + num_bins: int = 10, + norm: Literal["l1", "l2", "max"] = "l1", + num_classes: int | None = None, + ignore_index: int | None = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + task = ClassificationTaskNoMultilabel.from_str(task) + kwargs.update( + { + "n_bins": num_bins, + "norm": norm, + "ignore_index": ignore_index, + "validate_args": validate_args, + } + ) + if task == ClassificationTaskNoMultilabel.BINARY: + return BinaryAdaptiveCalibrationError(**kwargs) + # task is ClassificationTaskNoMultilabel.MULTICLASS + if not isinstance(num_classes, int): + raise TypeError( + f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`" + ) + return MulticlassAdaptiveCalibrationError(num_classes, **kwargs) diff --git a/torch_uncertainty/metrics/classification/calibration.py b/torch_uncertainty/metrics/classification/calibration.py deleted file mode 100644 index c32787f4..00000000 --- a/torch_uncertainty/metrics/classification/calibration.py +++ /dev/null @@ -1,172 +0,0 @@ -from typing import Any, Literal - -import matplotlib.pyplot as plt -import torch -from torchmetrics.classification.calibration_error import ( - BinaryCalibrationError, - MulticlassCalibrationError, -) -from torchmetrics.metric import Metric -from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel -from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE - - -class BinaryCE(BinaryCalibrationError): # noqa: N818 - def plot(self, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: - fig, ax = plt.subplots() if ax is None else (None, ax) - - conf = dim_zero_cat(self.confidences) - acc = dim_zero_cat(self.accuracies) - - bin_width = 1 / self.n_bins - - bin_ids = torch.round( - torch.clamp(conf * self.n_bins, 1e-5, self.n_bins - 1 - 1e-5) - ) - val, inverse, counts = bin_ids.unique( - return_inverse=True, return_counts=True - ) - val_oh = torch.nn.functional.one_hot( - val.long(), num_classes=self.n_bins - ) - - # add 1e-6 to avoid division NaNs - values = ( - val_oh.T.float() - @ torch.sum( - acc.unsqueeze(1) * torch.nn.functional.one_hot(inverse).float(), - 0, - ) - / (val_oh.T @ counts + 1e-6).float() - ) - counts_all = (val_oh.T @ counts).float() - total = torch.sum(counts) - - plt.rc("axes", axisbelow=True) - ax.hist( - x=[bin_width * i * 100 for i in range(self.n_bins)], - weights=values * 100, - bins=[bin_width * i * 100 for i in range(self.n_bins + 1)], - alpha=0.7, - linewidth=1, - edgecolor="#0d559f", - color="#1f77b4", - ) - for i, count in enumerate(counts_all): - ax.text( - 3.0 + 9.9 * i, - 1, - f"{int(count/total*100)}%", - fontsize=8, - ) - - ax.plot([0, 100], [0, 100], "--", color="#0d559f") - plt.grid(True, linestyle="--", alpha=0.7, zorder=0) - ax.set_xlabel("Top-class Confidence (%)", fontsize=16) - ax.set_ylabel("Success Rate (%)", fontsize=16) - ax.set_xlim(0, 100) - ax.set_ylim(0, 100) - ax.set_aspect("equal", "box") - fig.tight_layout() - return fig, ax - - -class MulticlassCE(MulticlassCalibrationError): # noqa: N818 - def plot(self, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: - fig, ax = plt.subplots() if ax is None else (None, ax) - - conf = dim_zero_cat(self.confidences).cpu() - acc = dim_zero_cat(self.accuracies).cpu() - - bin_width = 1 / self.n_bins - - bin_ids = torch.round( - torch.clamp(conf * self.n_bins, 1e-5, self.n_bins - 1 - 1e-5) - ) - val, inverse, counts = bin_ids.unique( - return_inverse=True, return_counts=True - ) - val_oh = torch.nn.functional.one_hot( - val.long(), num_classes=self.n_bins - ) - - # add 1e-6 to avoid division NaNs - values = ( - val_oh.T.float() - @ torch.sum( - acc.unsqueeze(1) * torch.nn.functional.one_hot(inverse).float(), - 0, - ) - / (val_oh.T.float() @ counts.float() + 1e-6) - ) - counts_all = val_oh.T.float() @ counts.float() - total = torch.sum(counts) - - plt.rc("axes", axisbelow=True) - ax.hist( - x=[bin_width * i * 100 for i in range(self.n_bins)], - weights=values * 100, - bins=[bin_width * i * 100 for i in range(self.n_bins + 1)], - alpha=0.7, - linewidth=1, - edgecolor="#0d559f", - color="#1f77b4", - ) - for i, count in enumerate(counts_all): - ax.text( - 3.0 + 9.9 * i, - 1, - f"{int(count/total*100)}%", - fontsize=8, - ) - - ax.plot([0, 100], [0, 100], "--", color="#0d559f") - plt.grid(True, linestyle="--", alpha=0.7, zorder=0) - ax.set_xlabel("Top-class Confidence (%)", fontsize=16) - ax.set_ylabel("Success Rate (%)", fontsize=16) - ax.set_xlim(0, 100) - ax.set_ylim(0, 100) - ax.set_aspect("equal", "box") - fig.tight_layout() - return fig, ax - - -class CE: - r"""`Top-label Calibration Error `_. - - See - `CalibrationError `_ - for details. Our version of the metric is a wrapper around the original - metric providing a plotting functionality. - """ - - def __new__( # type: ignore[misc] - cls, - task: Literal["binary", "multiclass"], - n_bins: int = 10, - norm: Literal["l1", "l2", "max"] = "l1", - num_classes: int | None = None, - ignore_index: int | None = None, - validate_args: bool = True, - **kwargs: Any, - ) -> Metric: - """Initialize task metric.""" - task = ClassificationTaskNoMultilabel.from_str(task) - kwargs.update( - { - "n_bins": n_bins, - "norm": norm, - "ignore_index": ignore_index, - "validate_args": validate_args, - } - ) - if task == ClassificationTaskNoMultilabel.BINARY: - return BinaryCE(**kwargs) - if task == ClassificationTaskNoMultilabel.MULTICLASS: - if not isinstance(num_classes, int): - raise ValueError( - f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`" - ) - return MulticlassCE(num_classes, **kwargs) - raise ValueError(f"Not handled value: {task}") # coverage: ignore diff --git a/torch_uncertainty/metrics/classification/calibration_error.py b/torch_uncertainty/metrics/classification/calibration_error.py new file mode 100644 index 00000000..7577b740 --- /dev/null +++ b/torch_uncertainty/metrics/classification/calibration_error.py @@ -0,0 +1,123 @@ +from typing import Any, Literal + +import matplotlib.pyplot as plt +import torch +from torchmetrics.classification.calibration_error import ( + BinaryCalibrationError, + MulticlassCalibrationError, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +from .adaptive_calibration_error import AdaptiveCalibrationError + + +def _ce_plot(self, ax: _AX_TYPE | None = None) -> _PLOT_OUT_TYPE: + fig, ax = plt.subplots(figsize=(6, 6)) if ax is None else (None, ax) + + conf = dim_zero_cat(self.confidences) + acc = dim_zero_cat(self.accuracies) + bin_width = 1 / self.n_bins + + bin_ids = torch.round( + torch.clamp(conf * self.n_bins, 1e-5, self.n_bins - 1 - 1e-5) + ) + val, inverse, counts = bin_ids.unique( + return_inverse=True, return_counts=True + ) + counts = counts.float() + val_oh = torch.nn.functional.one_hot( + val.long(), num_classes=self.n_bins + ).float() + + # add 1e-6 to avoid division NaNs + values = ( + val_oh.T + @ torch.sum( + acc.unsqueeze(1) * torch.nn.functional.one_hot(inverse).float(), + 0, + ) + / (val_oh.T @ counts + 1e-6) + ) + + plt.rc("axes", axisbelow=True) + ax.hist( + x=[bin_width * i * 100 for i in range(self.n_bins)], + weights=values.cpu() * 100, + bins=[bin_width * i * 100 for i in range(self.n_bins + 1)], + alpha=0.7, + linewidth=1, + edgecolor="#0d559f", + color="#1f77b4", + ) + + ax.plot([0, 100], [0, 100], "--", color="#0d559f") + plt.grid(True, linestyle="--", alpha=0.7, zorder=0) + ax.set_xlabel("Top-class Confidence (%)", fontsize=16) + ax.set_ylabel("Success Rate (%)", fontsize=16) + ax.set_xlim(0, 100) + ax.set_ylim(0, 100) + ax.set_aspect("equal", "box") + fig.tight_layout() + return fig, ax + + +# overwrite the plot method of the original metrics +BinaryCalibrationError.plot = _ce_plot +MulticlassCalibrationError.plot = _ce_plot + + +class CalibrationError: + r"""`Top-label Calibration Error`_. + + See + `CalibrationError `_ + for details. Our version of the metric is a wrapper around the original + metric providing a plotting functionality. + + Reference: + Naeini et al. "Obtaining well calibrated probabilities using Bayesian + binning." In AAAI, 2015. + """ + + def __new__( # type: ignore[misc] + cls, + task: Literal["binary", "multiclass"], + adaptive: bool = False, + num_bins: int = 10, + norm: Literal["l1", "l2", "max"] = "l1", + num_classes: int | None = None, + ignore_index: int | None = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + """Initialize task metric.""" + if adaptive: + return AdaptiveCalibrationError( + task=task, + num_bins=num_bins, + norm=norm, + num_classes=num_classes, + ignore_index=ignore_index, + validate_args=validate_args, + **kwargs, + ) + task = ClassificationTaskNoMultilabel.from_str(task) + kwargs.update( + { + "n_bins": num_bins, + "norm": norm, + "ignore_index": ignore_index, + "validate_args": validate_args, + } + ) + if task == ClassificationTaskNoMultilabel.BINARY: + return BinaryCalibrationError(**kwargs) + # task is ClassificationTaskNoMultilabel.MULTICLASS + if not isinstance(num_classes, int): + raise TypeError( + f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`" + ) + return MulticlassCalibrationError(num_classes, **kwargs) diff --git a/torch_uncertainty/metrics/classification/nll.py b/torch_uncertainty/metrics/classification/categorical_nll.py similarity index 100% rename from torch_uncertainty/metrics/classification/nll.py rename to torch_uncertainty/metrics/classification/categorical_nll.py diff --git a/torch_uncertainty/metrics/classification/mean_iou.py b/torch_uncertainty/metrics/classification/mean_iou.py index 95c5b8a0..54dd5a0b 100644 --- a/torch_uncertainty/metrics/classification/mean_iou.py +++ b/torch_uncertainty/metrics/classification/mean_iou.py @@ -1,3 +1,5 @@ +from typing import Literal + from torch import Tensor from torchmetrics.classification.stat_scores import MulticlassStatScores from torchmetrics.utilities.compute import _safe_divide @@ -10,6 +12,25 @@ class MeanIntersectionOverUnion(MulticlassStatScores): higher_is_better: bool = True full_state_update: bool = False + def __init__( + self, + num_classes: int, + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: int | None = None, + validate_args: bool = True, + **kwargs, + ) -> None: + super().__init__( + num_classes, + top_k, + "macro", + multidim_average, + ignore_index, + validate_args, + **kwargs, + ) + def compute(self) -> Tensor: """Compute the Means Intersection over Union (MIoU) based on saved inputs.""" tp, fp, _, fn = self._final_state() diff --git a/torch_uncertainty/metrics/classification/risk_coverage.py b/torch_uncertainty/metrics/classification/risk_coverage.py new file mode 100644 index 00000000..f10fcdc5 --- /dev/null +++ b/torch_uncertainty/metrics/classification/risk_coverage.py @@ -0,0 +1,304 @@ +import math + +import matplotlib.pyplot as plt +import numpy as np +import torch +from sklearn.metrics import auc +from torch import Tensor +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.plot import _AX_TYPE + + +class AURC(Metric): + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + scores: list[Tensor] + errors: list[Tensor] + + def __init__(self, **kwargs) -> None: + r"""`Area Under the Risk-Coverage curve`_. + + The Area Under the Risk-Coverage curve (AURC) is the main metric for + Selective Classification (SC) performance assessment. It evaluates the + quality of uncertainty estimates by measuring the ability to + discriminate between correct and incorrect predictions based on their + rank (and not their values in contrast with calibration). + + As input to ``forward`` and ``update`` the metric accepts the following + input: + + - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape + ``(N, ...)`` containing probabilities for each observation. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape + ``(N, ...)`` containing ground-truth labels. + + As output to ``forward`` and ``compute`` the metric returns the + following output: + - ``aurc`` (:class:`~torch.Tensor`): A scalar tensor containing the + area under the risk-coverage curve + + Args: + kwargs: Additional keyword arguments. + + Reference: + Geifman & El-Yaniv. "Selective classification for deep neural + networks." In NeurIPS, 2017. + """ + super().__init__(**kwargs) + self.add_state("scores", default=[], dist_reduce_fx="cat") + self.add_state("errors", default=[], dist_reduce_fx="cat") + + def update(self, probs: Tensor, targets: Tensor) -> None: + """Store the scores and their associated errors for later computation. + + Args: + probs (Tensor): The predicted probabilities of shape :math:`(N, C)`. + targets (Tensor): The ground truth labels of shape :math:`(N,)`. + """ + if probs.ndim == 1: + probs = torch.stack([1 - probs, probs], dim=-1) + self.scores.append(probs.max(-1).values) + self.errors.append((probs.argmax(-1) != targets) * 1.0) + + def partial_compute(self) -> Tensor: + """Compute the error and optimal error rates for the RC curve. + + Returns: + Tensor: The error rates and the optimal/oracle error + rates. + """ + scores = dim_zero_cat(self.scores) + errors = dim_zero_cat(self.errors) + return _aurc_rejection_rate_compute(scores, errors) + + def compute(self) -> Tensor: + """Compute the Area Under the Risk-Coverage curve (AURC). + + Normalize the AURC as if its support was between 0 and 1. This has an + impact on the AURC when the number of samples is small. + + Returns: + Tensor: The AURC. + """ + error_rates = self.partial_compute().cpu() + num_samples = error_rates.size(0) + x = torch.arange(1, num_samples + 1, device="cpu") / num_samples + return torch.tensor([auc(x, error_rates)], device=self.device) / ( + 1 - 1 / num_samples + ) + + def plot( + self, + ax: _AX_TYPE | None = None, + plot_value: bool = True, + name: str | None = None, + ) -> tuple[plt.Figure | None, plt.Axes]: + """Plot the risk-cov. curve corresponding to the inputs passed to + ``update``, and the oracle risk-cov. curve. + + Args: + ax (Axes | None, optional): An matplotlib axis object. If provided + will add plot to this axis. Defaults to None. + plot_value (bool, optional): Whether to print the AURC value on the + plot. Defaults to True. + name (str | None, optional): Name of the model. Defaults to None. + + Returns: + tuple[[Figure | None], Axes]: Figure object and Axes object + """ + fig, ax = plt.subplots(figsize=(6, 6)) if ax is None else (None, ax) + + # Computation of AUSEC + error_rates = self.partial_compute().cpu().flip(0) + num_samples = error_rates.size(0) + rejection_rates = (np.arange(num_samples) / num_samples) * 100 + + x = np.arange(num_samples) / num_samples + aurc = auc(x, error_rates) + + # reduce plot size + plot_xs = np.arange(0.01, 100 + 0.01, 0.01) + xs = np.arange(start=1, stop=num_samples + 1, step=1) / num_samples + rejection_rates = np.interp(plot_xs, xs, rejection_rates) + error_rates = np.interp(plot_xs, xs, error_rates) + + # plot + ax.plot( + 100 - rejection_rates, + error_rates * 100, + label="Model" if name is None else name, + ) + + if plot_value: + ax.text( + 0.02, + 0.95, + f"AUSEC={aurc:.3%}", + color="black", + ha="left", + va="bottom", + transform=ax.transAxes, + ) + plt.grid(True, linestyle="--", alpha=0.7, zorder=0) + ax.set_xlabel("Coverage (%)", fontsize=16) + ax.set_ylabel("Risk - Error Rate (%)", fontsize=16) + ax.set_xlim(0, 100) + ax.set_ylim(0, 100) + ax.set_aspect("equal", "box") + ax.legend(loc="upper right") + fig.tight_layout() + return fig, ax + + +def _aurc_rejection_rate_compute( + scores: Tensor, + errors: Tensor, +) -> Tensor: + """Compute the cumulative error rates for a given set of scores and errors. + + Args: + scores (Tensor): uncertainty scores of shape :math:`(B,)` + errors (Tensor): binary errors of shape :math:`(B,)` + """ + num_samples = scores.size(0) + errors = errors[scores.argsort(descending=True)] + return errors.cumsum(dim=-1) / torch.arange( + 1, num_samples + 1, dtype=scores.dtype, device=scores.device + ) + + +class CovAtxRisk(Metric): + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + scores: list[Tensor] + errors: list[Tensor] + + def __init__(self, risk_threshold: float, **kwargs) -> None: + r"""`Coverage at x Risk`_. + + If there are multiple coverage values corresponding to the given risk, + i.e., the risk(coverage) is not monotonic, the coverage at x risk is + the maximum coverage value corresponding to the given risk. If no + there is no coverage value corresponding to the given risk, return + float("nan"). + + Args: + risk_threshold (float): The risk threshold at which to compute the + coverage. + kwargs: Additional arguments to pass to the metric class. + """ + super().__init__(**kwargs) + self.add_state("scores", default=[], dist_reduce_fx="cat") + self.add_state("errors", default=[], dist_reduce_fx="cat") + _risk_coverage_checks(risk_threshold) + self.risk_threshold = risk_threshold + + def update(self, probs: Tensor, targets: Tensor) -> None: + """Store the scores and their associated errors for later computation. + + Args: + probs (Tensor): The predicted probabilities of shape :math:`(N, C)`. + targets (Tensor): The ground truth labels of shape :math:`(N,)`. + """ + if probs.ndim == 1: + probs = torch.stack([1 - probs, probs], dim=-1) + self.scores.append(probs.max(-1).values) + self.errors.append((probs.argmax(-1) != targets) * 1.0) + + def compute(self) -> Tensor: + """Compute the coverage at x Risk. + + Returns: + Tensor: The coverage at x risk. + """ + scores = dim_zero_cat(self.scores) + errors = dim_zero_cat(self.errors) + num_samples = scores.size(0) + error_rates = _aurc_rejection_rate_compute(scores, errors) + admissible_risks = (error_rates > self.risk_threshold) * 1 + max_cov_at_risk = admissible_risks.flip(0).argmin() + # check if max_cov_at_risk is really admissible, if not return nan + risk = admissible_risks[max_cov_at_risk] + if risk > self.risk_threshold: + return torch.tensor([float("nan")]) + return 1 - max_cov_at_risk / num_samples + + +class CovAt5Risk(CovAtxRisk): + def __init__(self, **kwargs) -> None: + r"""`Coverage at 5% Risk`_. + + If there are multiple coverage values corresponding to 5% risk, the + coverage at 5% risk is the maximum coverage value corresponding to 5% + risk. If no there is no coverage value corresponding to the given risk, + this metric returns float("nan"). + """ + super().__init__(risk_threshold=0.05, **kwargs) + + +class RiskAtxCov(Metric): + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + scores: list[Tensor] + errors: list[Tensor] + + def __init__(self, cov_threshold: float, **kwargs) -> None: + r"""`Risk at x Coverage`_. + + Args: + cov_threshold (float): The coverage threshold at which to compute + the risk. + kwargs: Additional arguments to pass to the metric class. + """ + super().__init__(**kwargs) + self.add_state("scores", default=[], dist_reduce_fx="cat") + self.add_state("errors", default=[], dist_reduce_fx="cat") + _risk_coverage_checks(cov_threshold) + self.cov_threshold = cov_threshold + + def update(self, probs: Tensor, targets: Tensor) -> None: + """Store the scores and their associated errors for later computation. + + Args: + probs (Tensor): The predicted probabilities of shape :math:`(N, C)`. + targets (Tensor): The ground truth labels of shape :math:`(N,)`. + """ + if probs.ndim == 1: + probs = torch.stack([1 - probs, probs], dim=-1) + self.scores.append(probs.max(-1).values) + self.errors.append((probs.argmax(-1) != targets) * 1.0) + + def compute(self) -> Tensor: + """Compute the risk at x coverage. + + Returns: + Tensor: The risk at x coverage. + """ + scores = dim_zero_cat(self.scores) + errors = dim_zero_cat(self.errors) + error_rates = _aurc_rejection_rate_compute(scores, errors) + return error_rates[math.ceil(scores.size(0) * self.cov_threshold) - 1] + + +class RiskAt80Cov(RiskAtxCov): + def __init__(self, **kwargs) -> None: + r"""`Risk at 80% Coverage`_.""" + super().__init__(cov_threshold=0.8, **kwargs) + + +def _risk_coverage_checks(threshold: float) -> None: + if not isinstance(threshold, float): + raise TypeError( + f"Expected threshold to be of type float, but got {type(threshold)}" + ) + if threshold < 0 or threshold > 1: + raise ValueError( + f"Threshold should be in the range [0, 1], but got {threshold}." + ) diff --git a/torch_uncertainty/metrics/classification/sparsification.py b/torch_uncertainty/metrics/classification/sparsification.py index c843f442..82fe41f8 100644 --- a/torch_uncertainty/metrics/classification/sparsification.py +++ b/torch_uncertainty/metrics/classification/sparsification.py @@ -20,7 +20,7 @@ class AUSE(Metric): errors: list[Tensor] def __init__(self, **kwargs) -> None: - """The Area Under the Sparsification Error curve (AUSE) metric to estimate + r"""The Area Under the Sparsification Error curve (AUSE) metric to estimate the quality of the uncertainty estimates, i.e., how much they coincide with the true errors. @@ -57,6 +57,13 @@ def update(self, scores: Tensor, errors: Tensor) -> None: self.scores.append(scores) self.errors.append(errors) + def partial_compute(self) -> tuple[Tensor, Tensor]: + scores = dim_zero_cat(self.scores) + errors = dim_zero_cat(self.errors) + error_rates = _ause_rejection_rate_compute(scores, errors) + optimal_error_rates = _ause_rejection_rate_compute(errors, errors) + return error_rates.cpu(), optimal_error_rates.cpu() + def compute(self) -> Tensor: """Compute the Area Under the Sparsification Error curve (AUSE) based on inputs passed to ``update``. @@ -64,16 +71,10 @@ def compute(self) -> Tensor: Returns: Tensor: The AUSE. """ - scores = dim_zero_cat(self.scores) - errors = dim_zero_cat(self.errors) - computed_error_rates = _rejection_rate_compute(scores, errors) - computed_optimal_error_rates = _rejection_rate_compute(errors, errors) - - x = np.arange(computed_error_rates.size(0)) / computed_error_rates.size( - 0 - ) - y = (computed_error_rates - computed_optimal_error_rates).numpy() - + error_rates, optimal_error_rates = self.partial_compute() + num_samples = error_rates.size(0) + x = np.arange(1, num_samples + 1) / num_samples + y = (error_rates - optimal_error_rates).numpy() return torch.tensor([auc(x, y)]) def plot( @@ -99,32 +100,24 @@ def plot( fig, ax = plt.subplots() if ax is None else (None, ax) # Computation of AUSEC - scores = dim_zero_cat(self.scores) - errors = dim_zero_cat(self.errors) - computed_error_rates = _rejection_rate_compute(scores, errors) - computed_optimal_error_rates = _rejection_rate_compute(errors, errors) - - x = np.arange(computed_error_rates.size(0)) / computed_error_rates.size( - 0 - ) - y = (computed_error_rates - computed_optimal_error_rates).numpy() + error_rates, optimal_error_rates = self.partial_compute() + num_samples = error_rates.size(0) + x = np.arange(num_samples) / num_samples + y = (error_rates - optimal_error_rates).numpy() ausec = auc(x, y) - rejection_rates = ( - np.arange(computed_error_rates.size(0)) - / computed_error_rates.size(0) - ) * 100 + rejection_rates = (np.arange(num_samples) / num_samples) * 100 ax.plot( rejection_rates, - computed_error_rates * 100, + error_rates * 100, label="Model", ) if plot_oracle: ax.plot( rejection_rates, - computed_optimal_error_rates * 100, + optimal_error_rates * 100, label="Oracle", ) @@ -148,7 +141,7 @@ def plot( return fig, ax -def _rejection_rate_compute( +def _ause_rejection_rate_compute( scores: Tensor, errors: Tensor, ) -> Tensor: @@ -166,5 +159,4 @@ def _rejection_rate_compute( error_rates = torch.zeros(num_samples + 1) error_rates[0] = errors.sum() error_rates[1:] = errors.cumsum(dim=-1).flip(0) - return error_rates / error_rates[0] diff --git a/torch_uncertainty/metrics/classification/variation_ratio.py b/torch_uncertainty/metrics/classification/variation_ratio.py index cdf05c89..a4e7609e 100644 --- a/torch_uncertainty/metrics/classification/variation_ratio.py +++ b/torch_uncertainty/metrics/classification/variation_ratio.py @@ -19,7 +19,6 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - allowed_reduction = ("sum", "mean", "none", None) if reduction not in allowed_reduction: raise ValueError( diff --git a/torch_uncertainty/metrics/regression/__init__.py b/torch_uncertainty/metrics/regression/__init__.py index 50f26c74..262641f9 100644 --- a/torch_uncertainty/metrics/regression/__init__.py +++ b/torch_uncertainty/metrics/regression/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 +from .inverse import MeanAbsoluteErrorInverse, MeanSquaredErrorInverse from .log10 import Log10 from .mse_log import MeanSquaredLogError from .nll import DistributionNLL diff --git a/torch_uncertainty/metrics/regression/inverse.py b/torch_uncertainty/metrics/regression/inverse.py new file mode 100644 index 00000000..d80a730c --- /dev/null +++ b/torch_uncertainty/metrics/regression/inverse.py @@ -0,0 +1,108 @@ +from typing import Literal + +from torch import Tensor +from torchmetrics import MeanAbsoluteError, MeanSquaredError + + +def _unit_to_factor(unit: Literal["mm", "m", "km"]) -> float: + """Convert a unit to a factor for scaling. + + Args: + unit: Unit for the computation of the metric. Must be one of 'mm', 'm', + 'km'. + """ + if unit == "km": + return 1e-3 + if unit == "m": + return 1.0 + if unit == "mm": + return 1e3 + raise ValueError(f"unit must be one of 'mm', 'm', 'km'. Got {unit}.") + + +class MeanSquaredErrorInverse(MeanSquaredError): + r"""Compute the `Mean Squared Error of the inverse predictions`_ (iMSE). + + .. math:: \text{iMSE} = \frac{1}{N}\sum_i^N(\frac{1}{y_i} - \frac{1}{\hat{y_i}})^2 + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + Both are scaled by a factor of :attr:`unit_factor` depending on the + :attr:`unit` given. + + As input to ``forward`` and ``update`` the metric accepts the following + input: + + - ``preds`` (:class:`~Tensor`): Predictions from model + - ``target`` (:class:`~Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following + output: + + - ``mean_squared_error`` (:class:`~Tensor`): A tensor with the mean + squared error + + Args: + squared: If True returns MSE value, if False returns RMSE value. + num_outputs: Number of outputs in multioutput setting. + unit: Unit for the computation of the metric. Must be one of 'mm', 'm', + 'km'. Defauts to 'km'. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more + info. + """ + + def __init__( + self, + squared: bool = True, + num_outputs: int = 1, + unit: str = "km", + **kwargs, + ) -> None: + super().__init__(squared, num_outputs, **kwargs) + self.unit_factor = _unit_to_factor(unit) + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + super().update( + 1 / (preds * self.unit_factor), 1 / (target * self.unit_factor) + ) + + +class MeanAbsoluteErrorInverse(MeanAbsoluteError): + r"""`Compute the Mean Absolute Error of the inverse predictions`_ (iMAE). + + .. math:: \text{iMAE} = \frac{1}{N}\sum_i^N | \frac{1}{y_i} - \frac{1}{\hat{y_i}} | + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + Both are scaled by a factor of :attr:`unit_factor` depending on the + :attr:`unit` given. + + As input to ``forward`` and ``update`` the metric accepts the following + input: + + - ``preds`` (:class:`~Tensor`): Predictions from model + - ``target`` (:class:`~Tensor`): Ground truth values + + As output of ``forward`` and ``compute`` the metric returns the following + output: + + - ``mean_absolute_inverse_error`` (:class:`~Tensor`): A tensor with the + mean absolute error over the state + + Args: + unit: Unit for the computation of the metric. Must be one of 'mm', 'm', + 'km'. Defauts to 'km'. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for + more info. + """ + + def __init__(self, unit: str = "km", **kwargs) -> None: + super().__init__(**kwargs) + self.unit_factor = _unit_to_factor(unit) + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + super().update( + 1 / (preds * self.unit_factor), 1 / (target * self.unit_factor) + ) diff --git a/torch_uncertainty/metrics/regression/log10.py b/torch_uncertainty/metrics/regression/log10.py index e93ce571..2885da79 100644 --- a/torch_uncertainty/metrics/regression/log10.py +++ b/torch_uncertainty/metrics/regression/log10.py @@ -1,16 +1,16 @@ import torch from torch import Tensor -from torchmetrics import Metric -from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics import MeanAbsoluteError -class Log10(Metric): +class Log10(MeanAbsoluteError): def __init__(self, **kwargs) -> None: r"""The Log10 metric. - .. math:: \text{Log10} = \frac{1}{N} \sum_{i=1}^{N} \log_{10}(y_i) - \log_{10}(\hat{y_i}) + .. math:: \text{Log10} = \frac{1}{N} \sum_{i=1}^{N} |\log_{10}(y_i) - \log_{10}(\hat{y_i})| - where :math:`N` is the batch size, :math:`y_i` is a tensor of target values and :math:`\hat{y_i}` is a tensor of prediction. + where :math:`N` is the batch size, :math:`y_i` is a tensor of target + values and :math:`\hat{y_i}` is a tensor of prediction. Inputs: - :attr:`preds`: :math:`(N)` @@ -28,10 +28,4 @@ def __init__(self, **kwargs) -> None: def update(self, pred: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" - self.values += torch.sum(pred.log10() - target.log10()) - self.total += target.size(0) - - def compute(self) -> Tensor: - """Compute the Log10 metric.""" - values = dim_zero_cat(self.values) - return values / self.total + return super().update(pred.log10(), target.log10()) diff --git a/torch_uncertainty/metrics/regression/mse_log.py b/torch_uncertainty/metrics/regression/mse_log.py index 6aca1208..c182c5cf 100644 --- a/torch_uncertainty/metrics/regression/mse_log.py +++ b/torch_uncertainty/metrics/regression/mse_log.py @@ -8,25 +8,30 @@ def __init__(self, squared: bool = True, **kwargs) -> None: .. math:: \text{MSELog} = \frac{1}{N}\sum_i^N (\log \hat{y_i} - \log y_i)^2 - where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. - As input to ``forward`` and ``update`` the metric accepts the following input: + As input to ``forward`` and ``update`` the metric accepts the following + input: - ``preds`` (:class:`~torch.Tensor`): Predictions from model - ``target`` (:class:`~torch.Tensor`): Ground truth values - As output of ``forward`` and ``compute`` the metric returns the following output: + As output of ``forward`` and ``compute`` the metric returns the + following output: - ``mse_log`` (:class:`~torch.Tensor`): A tensor with the relative mean absolute error over the state Args: - squared: If True returns MSELog value, if False returns EMSELog value. + squared: If True returns MSELog value, if False returns EMSELog + value. kwargs: Additional keyword arguments, see `Advanced metric settings `_. Reference: - As in e.g. From big to small: Multi-scale local planar guidance for monocular depth estimation + As in e.g. From big to small: Multi-scale local planar guidance for + monocular depth estimation """ super().__init__(squared, **kwargs) diff --git a/torch_uncertainty/metrics/regression/relative_error.py b/torch_uncertainty/metrics/regression/relative_error.py index af428450..9362013a 100644 --- a/torch_uncertainty/metrics/regression/relative_error.py +++ b/torch_uncertainty/metrics/regression/relative_error.py @@ -5,28 +5,33 @@ class MeanGTRelativeAbsoluteError(MeanAbsoluteError): def __init__(self, **kwargs) -> None: - r"""Compute Mean Absolute Error relative to the Ground Truth (MAErel or ARE). + r"""Compute Mean Absolute Error relative to the Ground Truth (MAErel + or ARErel). .. math:: \text{MAErel} = \frac{1}{N}\sum_i^N \frac{| y_i - \hat{y_i} |}{y_i} - where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. - As input to ``forward`` and ``update`` the metric accepts the following input: + As input to ``forward`` and ``update`` the metric accepts the following + input: - ``preds`` (:class:`~torch.Tensor`): Predictions from model - ``target`` (:class:`~torch.Tensor`): Ground truth values - As output of ``forward`` and ``compute`` the metric returns the following output: + As output of ``forward`` and ``compute`` the metric returns the + following output: - - ``rel_mean_absolute_error`` (:class:`~torch.Tensor`): A tensor with the - relative mean absolute error over the state + - ``rel_mean_absolute_error`` (:class:`~torch.Tensor`): A tensor with + the relative mean absolute error over the state Args: kwargs: Additional keyword arguments, see `Advanced metric settings `_. Reference: - As in e.g. From big to small: Multi-scale local planar guidance for monocular depth estimation + As in e.g. From big to small: Multi-scale local planar guidance for + monocular depth estimation """ super().__init__(**kwargs) @@ -43,25 +48,31 @@ def __init__( .. math:: \text{MSErel} = \frac{1}{N}\sum_i^N \frac{(y_i - \hat{y_i})^2}{y_i} - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. - As input to ``forward`` and ``update`` the metric accepts the following input: + As input to ``forward`` and ``update`` the metric accepts the following + input: - ``preds`` (:class:`~torch.Tensor`): Predictions from model - ``target`` (:class:`~torch.Tensor`): Ground truth values - As output of ``forward`` and ``compute`` the metric returns the following output: + As output of ``forward`` and ``compute`` the metric returns the + following output: - - ``rel_mean_squared_error`` (:class:`~torch.Tensor`): A tensor with the relative mean squared error + - ``rel_mean_squared_error`` (:class:`~torch.Tensor`): A tensor with + the relative mean squared error Args: - squared: If True returns MSErel value, if False returns RMSErel value. + squared: If True returns MSErel value, if False returns RMSErel + value. num_outputs: Number of outputs in multioutput setting kwargs: Additional keyword arguments, see `Advanced metric settings `_. Reference: - As in e.g. From big to small: Multi-scale local planar guidance for monocular depth estimation + As in e.g. From big to small: Multi-scale local planar guidance for + monocular depth estimation """ super().__init__(squared, num_outputs, **kwargs) diff --git a/torch_uncertainty/metrics/regression/silog.py b/torch_uncertainty/metrics/regression/silog.py index f8e71000..a0ac3152 100644 --- a/torch_uncertainty/metrics/regression/silog.py +++ b/torch_uncertainty/metrics/regression/silog.py @@ -7,40 +7,66 @@ class SILog(Metric): - def __init__(self, lmbda: float = 1, **kwargs: Any) -> None: + def __init__( + self, sqrt: bool = False, lmbda: float = 1.0, **kwargs: Any + ) -> None: r"""The Scale-Invariant Logarithmic Loss metric. - .. math:: \text{SILog} = \frac{1}{N} \sum_{i=1}^{N} \left(\log(y_i) - \log(\hat{y_i})\right)^2 - \left(\frac{1}{N} \sum_{i=1}^{N} \log(y_i) \right)^2 + .. math:: \text{SILog} = \frac{1}{N} \sum_{i=1}^{N} \left(\log(y_i) - \log(\hat{y_i})\right)^2 - \left(\frac{1}{N} \sum_{i=1}^{N} \log(y_i) \right)^2, - where :math:`N` is the batch size, :math:`y_i` is a tensor of target values and :math:`\hat{y_i}` is a tensor of prediction. + where :math:`N` is the batch size, :math:`y_i` is a tensor of target + values and :math:`\hat{y_i}` is a tensor of prediction. + Return the square root of SILog by setting :attr:`sqrt` to `True`. Inputs: - :attr:`pred`: :math:`(N)` - :attr:`target`: :math:`(N)` Args: - lmbda: The regularization parameter on the variance of error (default 1). + sqrt: If `True`, return the square root of the metric. Defaults to + False. + lmbda: The regularization parameter on the variance of error. + Defaults to 1.0. kwargs: Additional keyword arguments, see `Advanced metric settings `_. Reference: - Depth Map Prediction from a Single Image using a Multi-Scale Deep Network. + Depth Map Prediction from a Single Image using a Multi-Scale Deep + Network. David Eigen, Christian Puhrsch, Rob Fergus. NeurIPS 2014. - From Big to Small: Multi-Scale Local Planar Guidance for Monocular Depth Estimation. - Jin Han Lee, Myung-Kyu Han, Dong Wook Ko and Il Hong Suh. For the lambda parameter. + From Big to Small: Multi-Scale Local Planar Guidance for Monocular + Depth Estimation. + Jin Han Lee, Myung-Kyu Han, Dong Wook Ko and Il Hong Suh. (For + :attr:`lmbda`) """ super().__init__(**kwargs) + self.sqrt = sqrt self.lmbda = lmbda - self.add_state("log_dists", default=[], dist_reduce_fx="cat") + self.add_state( + "log_dists", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.add_state( + "sq_log_dists", + default=torch.tensor(0.0), + dist_reduce_fx="sum", + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, pred: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" - self.log_dists.append(torch.flatten(pred.log() - target.log())) + self.log_dists += torch.sum(pred.log() - target.log()) + self.sq_log_dists += torch.sum((pred.log() - target.log()) ** 2) + self.total += target.size(0) def compute(self) -> Tensor: """Compute the Scale-Invariant Logarithmic Loss.""" log_dists = dim_zero_cat(self.log_dists) - num_samples = log_dists.size(0) - return torch.mean(log_dists**2) - self.lmbda * torch.sum( - log_dists - ) ** 2 / (num_samples * num_samples) + sq_log_dists = dim_zero_cat(self.sq_log_dists) + out = sq_log_dists / self.total - self.lmbda * log_dists**2 / ( + self.total * self.total + ) + if self.sqrt: + return torch.sqrt(out) + return out diff --git a/torch_uncertainty/models/deep_ensembles.py b/torch_uncertainty/models/deep_ensembles.py index 637e43d6..49640108 100644 --- a/torch_uncertainty/models/deep_ensembles.py +++ b/torch_uncertainty/models/deep_ensembles.py @@ -60,7 +60,7 @@ def deep_ensembles( models: list[nn.Module] | nn.Module, num_estimators: int | None = None, task: Literal[ - "classification", "regression", "segmentation" + "classification", "regression", "segmentation", "pixel_regression" ] = "classification", probabilistic: bool | None = None, reset_model_parameters: bool = False, @@ -70,7 +70,8 @@ def deep_ensembles( Args: models (list[nn.Module] | nn.Module): The model to be ensembled. num_estimators (int | None): The number of estimators in the ensemble. - task (Literal["classification", "regression"]): The model task. + task (Literal["classification", "regression", "segmentation", "pixel_regression"]): The model task. + Defaults to "classification". probabilistic (bool): Whether the regression model is probabilistic. reset_model_parameters (bool): Whether to reset the model parameters when :attr:models is a module or a list of length 1. @@ -125,7 +126,7 @@ def deep_ensembles( if task in ("classification", "segmentation"): return _DeepEnsembles(models=models) - if task == "regression": + if task in ("regression", "pixel_regression"): if probabilistic is None: raise ValueError( "probabilistic must be specified for regression models." diff --git a/tests/datamodules/depth_estimation/__init__.py b/torch_uncertainty/models/depth/__init__.py similarity index 100% rename from tests/datamodules/depth_estimation/__init__.py rename to torch_uncertainty/models/depth/__init__.py diff --git a/torch_uncertainty/models/depth/bts.py b/torch_uncertainty/models/depth/bts.py new file mode 100644 index 00000000..3284b69a --- /dev/null +++ b/torch_uncertainty/models/depth/bts.py @@ -0,0 +1,660 @@ +import math +from typing import Literal + +import torch +import torchvision.models as tv_models +from torch import Tensor, nn +from torch.distributions import Distribution +from torch.nn import functional as F +from torchvision.models.densenet import DenseNet121_Weights, DenseNet161_Weights +from torchvision.models.resnet import ( + ResNet50_Weights, + ResNet101_Weights, + ResNeXt50_32X4D_Weights, + ResNeXt101_32X8D_Weights, +) + +from torch_uncertainty.layers.distributions import LaplaceLayer, NormalLayer +from torch_uncertainty.models.utils import Backbone + +resnet_feat_out_channels = [64, 256, 512, 1024, 2048] +resnet_feat_names = ["relu", "layer1", "layer2", "layer3", "layer4"] +densenet_feat_names = [ + "relu0", + "pool0", + "transition1", + "transition2", + "norm5", +] + +bts_backbones = [ + "densenet121", + "densenet161", + "resnet50", + "resnet101", + "resnext50", + "resnext101", +] + + +class AtrousBlock2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dilation: int, + norm_first: bool = True, + norm_momentum: float = 0.1, + **factory_kwargs, + ): + """Atrous block with 1x1 and 3x3 convolutions. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + dilation (int): Dilation rate for the 3x3 convolution. + norm_first (bool): Whether to apply normalization before the 1x1 convolution. + Defaults to True. + norm_momentum (float): Momentum for the normalization layer. Defaults to 0.1. + factory_kwargs: Additional arguments for the PyTorch layers. + """ + super().__init__() + + self.norm_first = norm_first + if norm_first: + self.first_norm = nn.BatchNorm2d( + in_channels, momentum=norm_momentum, **factory_kwargs + ) + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels * 2, + bias=False, + kernel_size=1, + stride=1, + padding=0, + **factory_kwargs, + ) + self.norm = nn.BatchNorm2d( + out_channels * 2, momentum=norm_momentum, **factory_kwargs + ) + self.conv2 = nn.Conv2d( + in_channels=out_channels * 2, + out_channels=out_channels, + bias=False, + kernel_size=3, + stride=1, + padding=(dilation, dilation), + dilation=dilation, + **factory_kwargs, + ) + + def forward(self, x: Tensor) -> Tensor: + if self.norm_first: + x = self.first_norm(x) + out = F.relu(self.conv1(x)) + out = F.relu(self.norm(out)) + return self.conv2(out) + + +class UpConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ratio: int = 2, + **factory_kwargs, + ): + """Upsampling convolution. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + ratio (int): Upsampling ratio. + factory_kwargs: Additional arguments for the convolution layer. + """ + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + bias=False, + kernel_size=3, + stride=1, + padding=1, + **factory_kwargs, + ) + self.ratio = ratio + + def forward(self, x: Tensor) -> Tensor: + out = F.interpolate(x, scale_factor=self.ratio, mode="nearest") + return F.elu(self.conv(out)) + + +class Reduction1x1(nn.Module): + def __init__( + self, + num_in_filters: int, + num_out_filters: int, + max_depth: float, + is_final: bool = False, + **factory_kwargs, + ): + super().__init__() + self.max_depth = max_depth + self.is_final = is_final + self.reduc = torch.nn.Sequential() + + while num_out_filters >= 4: + if num_out_filters < 8: + if self.is_final: + self.reduc.add_module( + "final", + torch.nn.Sequential( + nn.Conv2d( + num_in_filters, + out_channels=1, + bias=False, + kernel_size=1, + stride=1, + padding=0, + **factory_kwargs, + ), + nn.Sigmoid(), + ), + ) + else: + self.reduc.add_module( + "plane_params", + torch.nn.Conv2d( + num_in_filters, + out_channels=3, + bias=False, + kernel_size=1, + stride=1, + padding=0, + **factory_kwargs, + ), + ) + break + + self.reduc.add_module( + f"inter_{num_in_filters}_{num_out_filters}", + torch.nn.Sequential( + nn.Conv2d( + in_channels=num_in_filters, + out_channels=num_out_filters, + bias=False, + kernel_size=1, + stride=1, + padding=0, + **factory_kwargs, + ), + nn.ELU(), + ), + ) + + num_in_filters = num_out_filters + num_out_filters = num_out_filters // 2 + + def forward(self, x: Tensor) -> Tensor: + x = self.reduc.forward(x) + if not self.is_final: + theta = F.sigmoid(x[:, 0, :, :]) * math.pi / 3 + phi = F.sigmoid(x[:, 1, :, :]) * math.pi * 2 + dist = F.sigmoid(x[:, 2, :, :]) * self.max_depth + x = torch.cat( + [ + torch.mul(torch.sin(theta), torch.cos(phi)).unsqueeze(1), + torch.mul(torch.sin(theta), torch.sin(phi)).unsqueeze(1), + torch.cos(theta).unsqueeze(1), + dist.unsqueeze(1), + ], + dim=1, + ) + return x + + +class LocalPlanarGuidance(nn.Module): + def __init__(self, up_ratio: int) -> None: + super().__init__() + self.register_buffer( + "u", torch.arange(up_ratio).reshape([1, 1, up_ratio]) + ) + self.register_buffer( + "v", torch.arange(up_ratio).reshape([1, up_ratio, 1]) + ) + self.up_ratio = up_ratio + + def forward(self, x: Tensor) -> Tensor: + x_expanded = torch.repeat_interleave( + torch.repeat_interleave(x, self.up_ratio, 2), self.up_ratio, 3 + ) + + u = self.u.repeat( + x.size(0), + x.size(2) * self.up_ratio, + x.size(3), + ) + u = (u - (self.up_ratio - 1) * 0.5) / self.up_ratio + + v = self.v.repeat( + x.size(0), + x.size(2), + x.size(3) * self.up_ratio, + ) + v = (v - (self.up_ratio - 1) * 0.5) / self.up_ratio + + return x_expanded[:, 3, :, :] / ( + x_expanded[:, 0, :, :] * u + + x_expanded[:, 1, :, :] * v + + x_expanded[:, 2, :, :] + ) + + +class BTSBackbone(Backbone): + def __init__(self, backbone_name: str, pretrained: bool) -> None: + """BTS backbone. + + Args: + backbone_name (str): Name of the backbone. + pretrained (bool): Use a pretrained backbone. + """ + if backbone_name == "densenet121": + model = tv_models.densenet121( + weights=DenseNet121_Weights.DEFAULT if pretrained else None + ).features + feat_names = densenet_feat_names + self.feat_out_channels = [64, 64, 128, 256, 1024] + elif backbone_name == "densenet161": + model = tv_models.densenet161( + weights=DenseNet161_Weights.DEFAULT if pretrained else None + ).features + feat_names = densenet_feat_names + self.feat_out_channels = [96, 96, 192, 384, 2208] + elif backbone_name == "resnet50": + model = tv_models.resnet50( + weights=ResNet50_Weights.IMAGENET1K_V2 if pretrained else None + ) + elif backbone_name == "resnet101": + model = tv_models.resnet101( + weights=ResNet101_Weights.IMAGENET1K_V2 if pretrained else None + ) + elif backbone_name == "resnext50": + model = tv_models.resnext50_32x4d( + weights=ResNeXt50_32X4D_Weights.IMAGENET1K_V2 + if pretrained + else None + ) + else: # backbone_name == "resnext101": + model = tv_models.resnext101_32x8d( + weights=ResNeXt101_32X8D_Weights.IMAGENET1K_V2 + if pretrained + else None + ) + if "res" in backbone_name: # remove classification heads from ResNets + feat_names = resnet_feat_names + self.feat_out_channels = resnet_feat_out_channels + model.avgpool = nn.Identity() + model.fc = nn.Identity() + super().__init__(model=model, feat_names=feat_names) + + +class BTSDecoder(nn.Module): + def __init__( + self, + max_depth: float, + feat_out_channels: list[int], + num_features: int, + dist_layer: type[nn.Module], + ): + super().__init__() + self.max_depth = max_depth + + self.upconv5 = UpConv2d( + in_channels=feat_out_channels[4], out_channels=num_features + ) + self.bn5 = nn.BatchNorm2d(num_features, momentum=0.01, affine=True) + + self.conv5 = nn.Conv2d( + num_features + feat_out_channels[3], + num_features, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.upconv4 = UpConv2d( + in_channels=num_features, out_channels=num_features // 2 + ) + self.bn4 = nn.BatchNorm2d(num_features // 2, momentum=0.01, affine=True) + self.conv4 = nn.Conv2d( + num_features // 2 + feat_out_channels[2], + num_features // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.bn4_2 = nn.BatchNorm2d( + num_features // 2, momentum=0.01, affine=True + ) + + self.daspp_3 = AtrousBlock2d( + num_features // 2, + num_features // 4, + 3, + norm_first=False, + norm_momentum=0.01, + ) + self.daspp_6 = AtrousBlock2d( + num_features // 2 + num_features // 4 + feat_out_channels[2], + num_features // 4, + 6, + norm_momentum=0.01, + ) + self.daspp_12 = AtrousBlock2d( + num_features + feat_out_channels[2], + num_features // 4, + 12, + norm_momentum=0.01, + ) + self.daspp_18 = AtrousBlock2d( + num_features + num_features // 4 + feat_out_channels[2], + num_features // 4, + 18, + norm_momentum=0.01, + ) + self.daspp_24 = AtrousBlock2d( + num_features + num_features // 2 + feat_out_channels[2], + num_features // 4, + 24, + norm_momentum=0.01, + ) + self.daspp_conv = torch.nn.Sequential( + nn.Conv2d( + num_features + num_features // 2 + num_features // 4, + num_features // 4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.ELU(), + ) + self.reduc8x8 = Reduction1x1( + num_features // 4, num_features // 4, self.max_depth + ) + self.lpg8x8 = LocalPlanarGuidance(8) + + self.upconv3 = UpConv2d( + in_channels=num_features // 4, out_channels=num_features // 4 + ) + self.bn3 = nn.BatchNorm2d(num_features // 4, momentum=0.01, affine=True) + self.conv3 = nn.Conv2d( + num_features // 4 + feat_out_channels[1] + 1, + num_features // 4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.reduc4x4 = Reduction1x1( + num_features // 4, num_features // 8, self.max_depth + ) + self.lpg4x4 = LocalPlanarGuidance(4) + + self.upconv2 = UpConv2d( + in_channels=num_features // 4, out_channels=num_features // 8 + ) + self.bn2 = nn.BatchNorm2d(num_features // 8, momentum=0.01, affine=True) + self.conv2 = nn.Conv2d( + num_features // 8 + feat_out_channels[0] + 1, + num_features // 8, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + + self.reduc2x2 = Reduction1x1( + num_features // 8, num_features // 16, self.max_depth + ) + self.lpg2x2 = LocalPlanarGuidance(2) + + self.upconv1 = UpConv2d( + in_channels=num_features // 8, out_channels=num_features // 16 + ) + self.reduc1x1 = Reduction1x1( + num_features // 16, + num_features // 32, + self.max_depth, + is_final=True, + ) + self.conv1 = nn.Conv2d( + num_features // 16 + 4, num_features // 16, 3, 1, 1, bias=False + ) + self.output_channels = 1 + if dist_layer in (NormalLayer, LaplaceLayer): + self.output_channels = 2 + elif dist_layer != nn.Identity: + raise ValueError( + f"Unsupported distribution layer. Got {dist_layer}." + ) + self.depth = nn.Conv2d( + num_features // 16, self.output_channels, 3, 1, 1, bias=False + ) + self.dist_layer = dist_layer(dim=1) + + def feat_forward(self, features: list[Tensor]) -> Tensor: + dense_features = F.relu(features[4]) + upconv5 = self.bn5(self.upconv5(dense_features)) # H/16 + iconv5 = F.elu(self.conv5(torch.cat([upconv5, features[3]], dim=1))) + + upconv4 = self.bn4(self.upconv4(iconv5)) # H/8 + concat4 = torch.cat([upconv4, features[2]], dim=1) + iconv4 = self.bn4_2(F.elu(self.conv4(concat4))) + + daspp_3 = self.daspp_3(iconv4) + concat4_2 = torch.cat([concat4, daspp_3], dim=1) + daspp_6 = self.daspp_6(concat4_2) + concat4_3 = torch.cat([concat4_2, daspp_6], dim=1) + daspp_12 = self.daspp_12(concat4_3) + concat4_4 = torch.cat([concat4_3, daspp_12], dim=1) + daspp_18 = self.daspp_18(concat4_4) + daspp_24 = self.daspp_24(torch.cat([concat4_4, daspp_18], dim=1)) + concat4_daspp = torch.cat( + [iconv4, daspp_3, daspp_6, daspp_12, daspp_18, daspp_24], dim=1 + ) + daspp_feat = self.daspp_conv(concat4_daspp) + + reduc8x8 = self.reduc8x8(daspp_feat) + plane_normal_8x8 = reduc8x8[:, :3, :, :] + plane_normal_8x8 = F.normalize(plane_normal_8x8, p=2, dim=1) + plane_dist_8x8 = reduc8x8[:, 3, :, :] + plane_eq_8x8 = torch.cat( + [plane_normal_8x8, plane_dist_8x8.unsqueeze(1)], 1 + ) + depth_8x8 = self.lpg8x8(plane_eq_8x8) + depth_8x8_scaled = depth_8x8.unsqueeze(1) / self.max_depth + depth_8x8_scaled_ds = F.interpolate( + depth_8x8_scaled, scale_factor=0.25, mode="nearest" + ) + + upconv3 = self.bn3(self.upconv3(daspp_feat)) # H/4 + concat3 = torch.cat([upconv3, features[1], depth_8x8_scaled_ds], dim=1) + iconv3 = F.elu(self.conv3(concat3)) + + reduc4x4 = self.reduc4x4(iconv3) + plane_normal_4x4 = reduc4x4[:, :3, :, :] + plane_normal_4x4 = F.normalize(plane_normal_4x4, p=2, dim=1) + plane_dist_4x4 = reduc4x4[:, 3, :, :] + plane_eq_4x4 = torch.cat( + [plane_normal_4x4, plane_dist_4x4.unsqueeze(1)], 1 + ) + depth_4x4 = self.lpg4x4(plane_eq_4x4) + depth_4x4_scaled = depth_4x4.unsqueeze(1) / self.max_depth + depth_4x4_scaled_ds = F.interpolate( + depth_4x4_scaled, scale_factor=0.5, mode="nearest" + ) + + upconv2 = self.bn2(self.upconv2(iconv3)) # H/2 + iconv2 = F.elu( + self.conv2( + torch.cat([upconv2, features[0], depth_4x4_scaled_ds], dim=1) + ) + ) + + reduc2x2 = self.reduc2x2(iconv2) + plane_normal_2x2 = reduc2x2[:, :3, :, :] + plane_normal_2x2 = F.normalize(plane_normal_2x2, p=2, dim=1) + plane_dist_2x2 = reduc2x2[:, 3, :, :] + plane_eq_2x2 = torch.cat( + [plane_normal_2x2, plane_dist_2x2.unsqueeze(1)], 1 + ) + depth_2x2 = self.lpg2x2(plane_eq_2x2) + depth_2x2_scaled = depth_2x2.unsqueeze(1) / self.max_depth + + upconv1 = self.upconv1(iconv2) + reduc1x1 = self.reduc1x1(upconv1) + concat1 = torch.cat( + [ + upconv1, + reduc1x1, + depth_2x2_scaled, + depth_4x4_scaled, + depth_8x8_scaled, + ], + dim=1, + ) + return F.elu(self.conv1(concat1)) + + def forward(self, features: list[Tensor]) -> Tensor | Distribution: + """Forward pass. + + Args: + features (list[Tensor]): List of the features from the backbone. + + Note: + Depending of the :attr:`dist_layer` of the backbone, the output can + be a distribution or a single tensor. + """ + # TODO: handle focal + out = self.depth(self.feat_forward(features)) + if self.output_channels != 1: + loc = self.max_depth * F.sigmoid(out[:, 0, :, :]) + scale = self.max_depth * out[:, 1, :, :] + out = self.dist_layer(torch.stack([loc, scale], -1)) + else: + out = self.max_depth * F.sigmoid(out) + return out + + +class _BTS(nn.Module): + def __init__( + self, + backbone_name: Literal[ + "densenet121", + "densenet161", + "resnet50", + "resnet101", + "resnext50", + "resnext101", + ], + max_depth: float, + bts_size: int = 512, + dist_layer: type[nn.Module] = nn.Identity, + pretrained_backbone: bool = True, + ) -> None: + """BTS model. + + Args: + backbone_name (str): Name of the encoding backbone. + max_depth (float): Maximum predicted depth. + bts_size (int): BTS feature size. Defaults to 512. + dist_layer (nn.Module): Distribution layer for probabilistic depth + estimation. Defaults to nn.Identity. + pretrained_backbone (bool): Use a pretrained backbone. Defaults to True. + + Reference: + From Big to Small: Multi-Scale Local Planar Guidance for Monocular Depth Estimation. + Jin Han Lee, Myung-Kyu Han, Dong Wook Ko, Il Hong Suh. ArXiv. + """ + super().__init__() + self.max_depth = max_depth + + self.backbone = BTSBackbone(backbone_name, pretrained_backbone) + self.decoder = BTSDecoder( + max_depth, self.backbone.feat_out_channels, bts_size, dist_layer + ) + + # TODO: Handle focal + def forward(self, x: Tensor, focal: float | None = None) -> Tensor: + """Forward pass. + + Args: + x (Tensor): Input tensor. + focal (float): Focal length for API consistency. + """ + return self.decoder(self.backbone(x)) + + +def _bts( + backbone_name: str, + max_depth: float, + bts_size: int = 512, + dist_layer: type[nn.Module] = nn.Identity, + pretrained_backbone: bool = True, +) -> _BTS: + if backbone_name not in bts_backbones: + raise ValueError(f"Unsupported backbone. Got {backbone_name}.") + return _BTS( + backbone_name, max_depth, bts_size, dist_layer, pretrained_backbone + ) + + +def bts_resnet50( + max_depth: float, + bts_size: int = 512, + dist_layer: type[nn.Module] = nn.Identity, + pretrained_backbone: bool = True, +) -> _BTS: + """BTS model with ResNet-50 backbone. + + Args: + max_depth (float): Maximum predicted depth. + bts_size (int): BTS feature size. Defaults to 512. + dist_layer (nn.Module): Distribution layer for probabilistic depth + estimation. Defaults to nn.Identity. + pretrained_backbone (bool): Use a pretrained backbone. Defaults to True. + """ + return _bts( + "resnet50", + max_depth, + bts_size=bts_size, + dist_layer=dist_layer, + pretrained_backbone=pretrained_backbone, + ) + + +def bts_resnet101( + max_depth: float, + bts_size: int = 512, + dist_layer: type[nn.Module] = nn.Identity, + pretrained_backbone: bool = True, +) -> _BTS: + """BTS model with ResNet-101 backbone. + + Args: + max_depth (float): Maximum predicted depth. + bts_size (int): BTS feature size. Defaults to 512. + dist_layer (nn.Module): Distribution layer for probabilistic depth + estimation. Defaults to nn.Identity. + pretrained_backbone (bool): Use a pretrained backbone. Defaults to True. + """ + return _bts( + "resnet101", + max_depth, + bts_size=bts_size, + dist_layer=dist_layer, + pretrained_backbone=pretrained_backbone, + ) diff --git a/torch_uncertainty/models/lenet.py b/torch_uncertainty/models/lenet.py index fcb9663e..b18fa488 100644 --- a/torch_uncertainty/models/lenet.py +++ b/torch_uncertainty/models/lenet.py @@ -124,7 +124,7 @@ def lenet( last_layer_dropout: bool = False, ) -> _LeNet: return _lenet( - False, + stochastic=False, in_channels=in_channels, num_classes=num_classes, linear_layer=nn.Linear, diff --git a/torch_uncertainty/models/mc_dropout.py b/torch_uncertainty/models/mc_dropout.py index 355fe43e..24a545b3 100644 --- a/torch_uncertainty/models/mc_dropout.py +++ b/torch_uncertainty/models/mc_dropout.py @@ -5,13 +5,17 @@ class _MCDropout(nn.Module): def __init__( self, model: nn.Module, num_estimators: int, last_layer: bool ) -> None: - """MC Dropout wrapper for a model. + """MC Dropout wrapper for a model containing nn.Dropout modules. Args: model (nn.Module): model to wrap num_estimators (int): number of estimators to use last_layer (bool): whether to apply dropout to the last layer only. + Warning: + Apply dropout using modules and not functional for this wrapper to + work as intended. + Warning: The underlying models must have a non-zero :attr:`dropout_rate` attribute. diff --git a/torch_uncertainty/models/mlp.py b/torch_uncertainty/models/mlp.py index a822343d..1a50524f 100644 --- a/torch_uncertainty/models/mlp.py +++ b/torch_uncertainty/models/mlp.py @@ -19,7 +19,7 @@ def __init__( layer: type[nn.Module], activation: Callable, layer_args: dict, - final_layer: nn.Module, + final_layer: type[nn.Module], final_layer_args: dict, dropout_rate: float, ) -> None: @@ -97,7 +97,7 @@ def _mlp( layer_args: dict | None = None, layer: type[nn.Module] = nn.Linear, activation: Callable = F.relu, - final_layer: nn.Module = nn.Identity, + final_layer: type[nn.Module] = nn.Identity, final_layer_args: dict | None = None, dropout_rate: float = 0.0, ) -> _MLP | _StochasticMLP: @@ -125,7 +125,7 @@ def mlp( hidden_dims: list[int], layer: type[nn.Module] = nn.Linear, activation: Callable = F.relu, - final_layer: nn.Module = nn.Identity, + final_layer: type[nn.Module] = nn.Identity, final_layer_args: dict | None = None, dropout_rate: float = 0.0, ) -> _MLP: @@ -167,7 +167,7 @@ def packed_mlp( alpha: float = 2, gamma: float = 1, activation: Callable = F.relu, - final_layer: nn.Module = nn.Identity, + final_layer: type[nn.Module] = nn.Identity, final_layer_args: dict | None = None, dropout_rate: float = 0.0, ) -> _MLP: @@ -195,7 +195,7 @@ def bayesian_mlp( num_outputs: int, hidden_dims: list[int], activation: Callable = F.relu, - final_layer: nn.Module = nn.Identity, + final_layer: type[nn.Module] = nn.Identity, final_layer_args: dict | None = None, dropout_rate: float = 0.0, ) -> _StochasticMLP: diff --git a/torch_uncertainty/models/resnet/__init__.py b/torch_uncertainty/models/resnet/__init__.py index 883b98db..cdff770e 100644 --- a/torch_uncertainty/models/resnet/__init__.py +++ b/torch_uncertainty/models/resnet/__init__.py @@ -1,6 +1,7 @@ # ruff: noqa: F401, F403 -from .batched import * -from .masked import * -from .mimo import * -from .packed import * -from .std import * +from .batched import batched_resnet +from .lpbnn import lpbnn_resnet +from .masked import masked_resnet +from .mimo import mimo_resnet +from .packed import packed_resnet +from .std import resnet diff --git a/torch_uncertainty/models/resnet/batched.py b/torch_uncertainty/models/resnet/batched.py index 52b3fc1f..4b32d3f1 100644 --- a/torch_uncertainty/models/resnet/batched.py +++ b/torch_uncertainty/models/resnet/batched.py @@ -1,10 +1,3 @@ -"""_BatchedResNet in PyTorch. - -Reference: -[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun - Deep Residual Learning for Image Recognition. arXiv:1512.03385 -""" - from typing import Literal import torch.nn.functional as F @@ -12,13 +5,10 @@ from torch_uncertainty.layers import BatchConv2d, BatchLinear +from .utils import get_resnet_num_blocks + __all__ = [ - "batched_resnet18", - "batched_resnet20", - "batched_resnet34", - "batched_resnet50", - "batched_resnet101", - "batched_resnet152", + "batched_resnet", ] @@ -34,7 +24,7 @@ def __init__( conv_bias: bool, dropout_rate: float, groups: int, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], ) -> None: super().__init__() self.conv1 = BatchConv2d( @@ -48,7 +38,6 @@ def __init__( bias=conv_bias, ) self.bn1 = normalization_layer(planes) - self.dropout = nn.Dropout2d(p=dropout_rate) self.conv2 = BatchConv2d( planes, @@ -61,7 +50,6 @@ def __init__( bias=conv_bias, ) self.bn2 = normalization_layer(planes) - self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( @@ -95,7 +83,7 @@ def __init__( conv_bias: bool, dropout_rate: float, groups: int, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], ) -> None: super().__init__() self.conv1 = BatchConv2d( @@ -166,7 +154,7 @@ def __init__( width_multiplier: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", in_planes: int = 64, - normalization_layer: nn.Module = nn.BatchNorm2d, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> None: super().__init__() self.in_planes = in_planes * width_multiplier @@ -280,7 +268,7 @@ def _make_layer( conv_bias: bool, dropout_rate: float, groups: int, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], ) -> nn.Module: strides = [stride] + [1] * (num_blocks - 1) layers = [] @@ -313,250 +301,39 @@ def forward(self, x: Tensor) -> Tensor: return self.linear(out) -def batched_resnet18( - in_channels: int, - num_classes: int, - num_estimators: int, - conv_bias: bool = True, - dropout_rate: float = 0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _BatchedResNet: - """BatchEnsemble of ResNet-18. - - Args: - in_channels (int): Number of input channels. - num_estimators (int): Number of estimators in the ensemble. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - groups (int): Number of groups within each estimator. - num_classes (int): Number of classes to predict. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _BatchedResNet: A BatchEnsemble-style ResNet-18. - """ - return _BatchedResNet( - _BasicBlock, - [2, 2, 2, 2], - in_channels=in_channels, - num_classes=num_classes, - num_estimators=num_estimators, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def batched_resnet20( - in_channels: int, - num_classes: int, - num_estimators: int, - conv_bias: bool = True, - dropout_rate: float = 0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _BatchedResNet: - """BatchEnsemble of ResNet-20. - - Args: - in_channels (int): Number of input channels. - num_estimators (int): Number of estimators in the ensemble. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - groups (int): Number of groups within each estimator. - num_classes (int): Number of classes to predict. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _BatchedResNet: A BatchEnsemble-style ResNet-20. - """ - return _BatchedResNet( - _BasicBlock, - [3, 3, 3], - in_channels=in_channels, - num_classes=num_classes, - num_estimators=num_estimators, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=16, - normalization_layer=normalization_layer, - ) - - -def batched_resnet34( - in_channels: int, - num_classes: int, - num_estimators: int, - conv_bias: bool = True, - dropout_rate: float = 0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _BatchedResNet: - """BatchEnsemble of ResNet-34. - - Args: - in_channels (int): Number of input channels. - num_estimators (int): Number of estimators in the ensemble. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - groups (int): Number of groups within each estimator. - num_classes (int): Number of classes to predict. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _BatchedResNet: A BatchEnsemble-style ResNet-34. - """ - return _BatchedResNet( - _BasicBlock, - [3, 4, 6, 3], - in_channels=in_channels, - num_classes=num_classes, - num_estimators=num_estimators, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def batched_resnet50( - in_channels: int, - num_classes: int, - num_estimators: int, - conv_bias: bool = True, - dropout_rate: float = 0, - groups: int = 1, - width_multiplier: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _BatchedResNet: - """BatchEnsemble of ResNet-50. - - Args: - in_channels (int): Number of input channels. - num_estimators (int): Number of estimators in the ensemble. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - groups (int): Number of groups within each estimator. - num_classes (int): Number of classes to predict. - width_multiplier (int, optional): Expansion factor affecting the width - of the estimators. Defaults to ``1``. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _BatchedResNet: A BatchEnsemble-style ResNet-50. - """ - return _BatchedResNet( - _Bottleneck, - [3, 4, 6, 3], - in_channels=in_channels, - num_classes=num_classes, - num_estimators=num_estimators, - width_multiplier=width_multiplier, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def batched_resnet101( +def batched_resnet( in_channels: int, num_classes: int, + arch: int, num_estimators: int, conv_bias: bool = True, dropout_rate: float = 0, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _BatchedResNet: - """BatchEnsemble of ResNet-101. + """BatchEnsemble of ResNet. Args: in_channels (int): Number of input channels. - num_estimators (int): Number of estimators in the ensemble. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - groups (int): Number of groups within each estimator. num_classes (int): Number of classes to predict. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _BatchedResNet: A BatchEnsemble-style ResNet-101. - """ - return _BatchedResNet( - _Bottleneck, - [3, 4, 23, 3], - in_channels=in_channels, - num_classes=num_classes, - num_estimators=num_estimators, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def batched_resnet152( - in_channels: int, - num_classes: int, - num_estimators: int, - conv_bias: bool = True, - dropout_rate: float = 0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _BatchedResNet: - """BatchEnsemble of ResNet-152. - - Args: - in_channels (int): Number of input channels. + arch (int): The architecture of the ResNet. num_estimators (int): Number of estimators in the ensemble. conv_bias (bool): Whether to use bias in convolutions. Defaults to ``True``. dropout_rate (float): Dropout rate. Defaults to 0. groups (int): Number of groups within each estimator. - num_classes (int): Number of classes to predict. style (bool, optional): Whether to use the ImageNet structure. Defaults to ``True``. normalization_layer (nn.Module, optional): Normalization layer. Returns: - _BatchedResNet: A BatchEnsemble-style ResNet-152. + _BatchedResNet: A BatchEnsemble-style ResNet. """ + block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck return _BatchedResNet( - _Bottleneck, - [3, 8, 36, 3], + block=block, + num_blocks=get_resnet_num_blocks(arch), in_channels=in_channels, num_classes=num_classes, num_estimators=num_estimators, diff --git a/torch_uncertainty/models/resnet/lpbnn.py b/torch_uncertainty/models/resnet/lpbnn.py new file mode 100644 index 00000000..83f22f58 --- /dev/null +++ b/torch_uncertainty/models/resnet/lpbnn.py @@ -0,0 +1,339 @@ +from collections.abc import Callable +from typing import Literal + +from torch import Tensor, nn, relu + +from torch_uncertainty.layers.bayesian.lpbnn import LPBNNConv2d, LPBNNLinear + +from .utils import get_resnet_num_blocks + +__all__ = [ + "lpbnn_resnet", +] + + +class _BasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, + in_planes: int, + planes: int, + stride: int, + dropout_rate: float, + num_estimators: int, + groups: int, + activation_fn: Callable, + normalization_layer: type[nn.Module], + conv_bias: bool, + ) -> None: + super().__init__() + self.activation_fn = activation_fn + + self.conv1 = LPBNNConv2d( + in_planes, + planes, + kernel_size=3, + num_estimators=num_estimators, + groups=groups, + stride=stride, + padding=1, + bias=conv_bias, + ) + self.bn1 = normalization_layer(planes) + self.dropout = nn.Dropout2d(p=dropout_rate) + self.conv2 = LPBNNConv2d( + planes, + planes, + kernel_size=3, + num_estimators=num_estimators, + groups=groups, + stride=1, + padding=1, + bias=conv_bias, + ) + self.bn2 = normalization_layer(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_planes, + self.expansion * planes, + groups=groups, + kernel_size=1, + stride=stride, + bias=conv_bias, + ), + normalization_layer(self.expansion * planes), + ) + + def forward(self, inputs: Tensor) -> Tensor: + out = self.activation_fn(self.dropout(self.bn1(self.conv1(inputs)))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(inputs) + return self.activation_fn(out) + + +class _Bottleneck(nn.Module): + expansion = 4 + + def __init__( + self, + in_planes: int, + planes: int, + stride: int, + num_estimators: int, + dropout_rate: float, + groups: int, + activation_fn: Callable, + normalization_layer: type[nn.Module], + conv_bias: bool, + ) -> None: + super().__init__() + self.activation_fn = activation_fn + + self.conv1 = LPBNNConv2d( + in_planes, + planes, + kernel_size=1, + num_estimators=num_estimators, + groups=groups, + bias=conv_bias, + ) + self.bn1 = normalization_layer(planes) + self.conv2 = LPBNNConv2d( + planes, + planes, + kernel_size=3, + num_estimators=num_estimators, + groups=groups, + stride=stride, + padding=1, + bias=conv_bias, + ) + self.bn2 = normalization_layer(planes) + self.dropout = nn.Dropout2d(p=dropout_rate) + self.conv3 = LPBNNConv2d( + planes, + self.expansion * planes, + num_estimators=num_estimators, + groups=groups, + kernel_size=1, + bias=conv_bias, + ) + self.bn3 = normalization_layer(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + LPBNNConv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + num_estimators=num_estimators, + groups=groups, + stride=stride, + bias=conv_bias, + ), + normalization_layer(self.expansion * planes), + ) + + def forward(self, x: Tensor) -> Tensor: + out = self.activation_fn(self.bn1(self.conv1(x))) + out = self.activation_fn(self.dropout(self.bn2(self.conv2(out)))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + return self.activation_fn(out) + + +class _LPBNNResNet(nn.Module): + def __init__( + self, + block: type[_BasicBlock | _Bottleneck], + num_blocks: list[int], + in_channels: int, + num_estimators: int, + num_classes: int, + conv_bias: bool, + dropout_rate: float, + groups: int, + style: Literal["imagenet", "cifar"] = "imagenet", + in_planes: int = 64, + activation_fn: Callable = relu, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, + ): + super().__init__() + self.in_planes = in_planes + block_planes = in_planes + self.dropout_rate = dropout_rate + self.activation_fn = activation_fn + self.num_estimators = num_estimators + + if style == "imagenet": + self.conv1 = LPBNNConv2d( + in_channels, + block_planes, + kernel_size=7, + stride=2, + padding=3, + num_estimators=num_estimators, + groups=groups, + bias=conv_bias, + ) + elif style == "cifar": + self.conv1 = LPBNNConv2d( + in_channels, + block_planes, + kernel_size=3, + stride=1, + padding=1, + num_estimators=num_estimators, + groups=groups, + bias=conv_bias, + ) + else: + raise ValueError(f"Unknown style. Got {style}.") + + self.bn1 = normalization_layer(block_planes) + + if style == "imagenet": + self.optional_pool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1 + ) + else: + self.optional_pool = nn.Identity() + + self.layer1 = self._make_layer( + block, + block_planes, + num_blocks[0], + stride=1, + dropout_rate=dropout_rate, + groups=groups, + activation_fn=activation_fn, + normalization_layer=normalization_layer, + conv_bias=conv_bias, + num_estimators=num_estimators, + ) + self.layer2 = self._make_layer( + block, + block_planes * 2, + num_blocks[1], + stride=2, + dropout_rate=dropout_rate, + groups=groups, + activation_fn=activation_fn, + normalization_layer=normalization_layer, + conv_bias=conv_bias, + num_estimators=num_estimators, + ) + self.layer3 = self._make_layer( + block, + block_planes * 4, + num_blocks[2], + stride=2, + dropout_rate=dropout_rate, + groups=groups, + activation_fn=activation_fn, + normalization_layer=normalization_layer, + conv_bias=conv_bias, + num_estimators=num_estimators, + ) + if len(num_blocks) == 4: + self.layer4 = self._make_layer( + block, + block_planes * 8, + num_blocks[3], + stride=2, + dropout_rate=dropout_rate, + groups=groups, + activation_fn=activation_fn, + normalization_layer=normalization_layer, + conv_bias=conv_bias, + num_estimators=num_estimators, + ) + linear_multiplier = 8 + else: + self.layer4 = nn.Identity() + linear_multiplier = 4 + + self.dropout = nn.Dropout(p=dropout_rate) + self.pool = nn.AdaptiveAvgPool2d(output_size=1) + self.flatten = nn.Flatten(1) + + self.linear = LPBNNLinear( + block_planes * linear_multiplier * block.expansion, + num_classes, + num_estimators=num_estimators, + ) + + def _make_layer( + self, + block: type[_BasicBlock | _Bottleneck], + planes: int, + num_blocks: int, + stride: int, + num_estimators: int, + dropout_rate: float, + groups: int, + activation_fn: Callable, + normalization_layer: type[nn.Module], + conv_bias: bool, + ): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for stride in strides: + layers.append( + block( + in_planes=self.in_planes, + planes=planes, + stride=stride, + dropout_rate=dropout_rate, + groups=groups, + activation_fn=activation_fn, + normalization_layer=normalization_layer, + conv_bias=conv_bias, + num_estimators=num_estimators, + ) + ) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def feats_forward(self, x: Tensor) -> Tensor: + out = x.repeat(self.num_estimators, 1, 1, 1) + out = self.activation_fn(self.bn1(self.conv1(out))) + out = self.optional_pool(out) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = self.pool(out) + return self.dropout(self.flatten(out)) + + def forward(self, x: Tensor) -> Tensor: + return self.linear(self.feats_forward(x)) + + +def lpbnn_resnet( + in_channels: int, + num_classes: int, + arch: int, + num_estimators: int, + dropout_rate: float = 0, + conv_bias: bool = True, + groups: int = 1, + style: Literal["imagenet", "cifar"] = "imagenet", +) -> _LPBNNResNet: + block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck + return _LPBNNResNet( + block=block, + num_blocks=get_resnet_num_blocks(arch), + in_channels=in_channels, + num_estimators=num_estimators, + num_classes=num_classes, + dropout_rate=dropout_rate, + conv_bias=conv_bias, + groups=groups, + style=style, + ) diff --git a/torch_uncertainty/models/resnet/masked.py b/torch_uncertainty/models/resnet/masked.py index ea61606d..6af599df 100644 --- a/torch_uncertainty/models/resnet/masked.py +++ b/torch_uncertainty/models/resnet/masked.py @@ -5,13 +5,10 @@ from torch_uncertainty.layers import MaskedConv2d, MaskedLinear +from .utils import get_resnet_num_blocks + __all__ = [ - "masked_resnet18", - "masked_resnet20", - "masked_resnet34", - "masked_resnet50", - "masked_resnet101", - "masked_resnet152", + "masked_resnet", ] @@ -28,7 +25,7 @@ def __init__( conv_bias: bool, dropout_rate: float, groups: int, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], ) -> None: super().__init__() @@ -81,7 +78,7 @@ def forward(self, x: Tensor) -> Tensor: return F.relu(out) -class Bottleneck(nn.Module): +class _Bottleneck(nn.Module): expansion = 4 def __init__( @@ -94,7 +91,7 @@ def __init__( conv_bias: bool, dropout_rate: float, groups: int, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], ) -> None: super().__init__() @@ -159,7 +156,7 @@ def forward(self, x: Tensor) -> Tensor: class _MaskedResNet(nn.Module): def __init__( self, - block: type[_BasicBlock | Bottleneck], + block: type[_BasicBlock | _Bottleneck], num_blocks: list[int], in_channels: int, num_classes: int, @@ -170,7 +167,7 @@ def __init__( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", in_planes: int = 64, - normalization_layer: nn.Module = nn.BatchNorm2d, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> None: super().__init__() self.in_channels = in_channels @@ -278,7 +275,7 @@ def __init__( def _make_layer( self, - block: type[_BasicBlock | Bottleneck], + block: type[_BasicBlock | _Bottleneck], planes: int, num_blocks: int, stride: int, @@ -287,7 +284,7 @@ def _make_layer( dropout_rate: float, scale: float, groups: int, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], ) -> nn.Module: strides = [stride] + [1] * (num_blocks - 1) layers = [] @@ -322,242 +319,24 @@ def forward(self, x: Tensor) -> Tensor: return self.linear(out) -def masked_resnet18( - in_channels: int, - num_classes: int, - num_estimators: int, - scale: float, - groups: int = 1, - conv_bias: bool = True, - dropout_rate: float = 0, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _MaskedResNet: - """Masksembles of ResNet-18. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - scale (float): The scale of the mask. - groups (int): Number of groups within each estimator. Defaults to 1. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - style (str, optional): The style of the model. Defaults to "imagenet". - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _MaskedResNet: A Masksembles-style ResNet-18. - """ - return _MaskedResNet( - num_classes=num_classes, - block=_BasicBlock, - num_blocks=[2, 2, 2, 2], - in_channels=in_channels, - num_estimators=num_estimators, - scale=scale, - groups=groups, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def masked_resnet20( +def masked_resnet( in_channels: int, num_classes: int, + arch: int, num_estimators: int, scale: float, groups: int = 1, conv_bias: bool = True, dropout_rate: float = 0, style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _MaskedResNet: - """Masksembles of ResNet-20. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - scale (float): The scale of the mask. - groups (int): Number of groups within each estimator. Defaults to 1. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - style (str, optional): The style of the model. Defaults to "imagenet". - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _MaskedResNet: A Masksembles-style ResNet-20. - """ - return _MaskedResNet( - num_classes=num_classes, - block=_BasicBlock, - num_blocks=[3, 3, 3], - in_channels=in_channels, - num_estimators=num_estimators, - scale=scale, - groups=groups, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - style=style, - in_planes=16, - normalization_layer=normalization_layer, - ) - - -def masked_resnet34( - in_channels: int, - num_classes: int, - num_estimators: int, - scale: float, - groups: int = 1, - conv_bias: bool = True, - dropout_rate: float = 0, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _MaskedResNet: - """Masksembles of ResNet-34. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - scale (float): The scale of the mask. - groups (int): Number of groups within each estimator. Defaults to 1. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - style (str, optional): The style of the model. Defaults to "imagenet". - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _MaskedResNet: A Masksembles-style ResNet-34. - """ - return _MaskedResNet( - num_classes=num_classes, - block=_BasicBlock, - num_blocks=[3, 4, 6, 3], - in_channels=in_channels, - num_estimators=num_estimators, - scale=scale, - groups=groups, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def masked_resnet50( - in_channels: int, - num_classes: int, - num_estimators: int, - scale: float, - groups: int = 1, - conv_bias: bool = True, - dropout_rate: float = 0, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _MaskedResNet: - """Masksembles of ResNet-50. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - scale (float): The scale of the mask. - groups (int): Number of groups within each estimator. Defaults to 1. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - style (str, optional): The style of the model. Defaults to "imagenet". - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _MaskedResNet: A Masksembles-style ResNet-50. - """ - return _MaskedResNet( - num_classes=num_classes, - block=Bottleneck, - num_blocks=[3, 4, 6, 3], - in_channels=in_channels, - num_estimators=num_estimators, - scale=scale, - groups=groups, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def masked_resnet101( - in_channels: int, - num_classes: int, - num_estimators: int, - scale: float, - groups: int = 1, - conv_bias: bool = True, - dropout_rate: float = 0, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _MaskedResNet: - """Masksembles of ResNet-101. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - scale (float): The scale of the mask. - groups (int): Number of groups within each estimator. Defaults to 1. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - style (str, optional): The style of the model. Defaults to "imagenet". - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _MaskedResNet: A Masksembles-style ResNet-101. - """ - return _MaskedResNet( - num_classes=num_classes, - block=Bottleneck, - num_blocks=[3, 4, 23, 3], - in_channels=in_channels, - num_estimators=num_estimators, - scale=scale, - groups=groups, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def masked_resnet152( - in_channels: int, - num_classes: int, - num_estimators: int, - scale: float, - groups: int = 1, - conv_bias: bool = True, - dropout_rate: float = 0, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _MaskedResNet: # coverage: ignore - """Masksembles of ResNet-152. + """Masksembles of ResNet. Args: in_channels (int): Number of input channels. num_classes (int): Number of classes to predict. + arch (int): The architecture of the ResNet. num_estimators (int): Number of estimators in the ensemble. scale (float): The scale of the mask. groups (int): Number of groups within each estimator. Defaults to 1. @@ -568,12 +347,13 @@ def masked_resnet152( normalization_layer (nn.Module, optional): Normalization layer. Returns: - _MaskedResNet: A Masksembles-style ResNet-152. + _MaskedResNet: A Masksembles-style ResNet. """ + block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck return _MaskedResNet( + block=block, + num_blocks=get_resnet_num_blocks(arch), num_classes=num_classes, - block=Bottleneck, - num_blocks=[3, 8, 36, 3], in_channels=in_channels, num_estimators=num_estimators, scale=scale, diff --git a/torch_uncertainty/models/resnet/mimo.py b/torch_uncertainty/models/resnet/mimo.py index b533a61c..05a25e14 100644 --- a/torch_uncertainty/models/resnet/mimo.py +++ b/torch_uncertainty/models/resnet/mimo.py @@ -5,14 +5,10 @@ from torch import nn from .std import _BasicBlock, _Bottleneck, _ResNet +from .utils import get_resnet_num_blocks __all__ = [ - "mimo_resnet18", - "mimo_resnet20", - "mimo_resnet34", - "mimo_resnet50", - "mimo_resnet101", - "mimo_resnet152", + "mimo_resnet", ] @@ -29,7 +25,7 @@ def __init__( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", in_planes: int = 64, - normalization_layer: nn.Module = nn.BatchNorm2d, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> None: super().__init__( block=block, @@ -49,150 +45,26 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.training: x = x.repeat(self.num_estimators, 1, 1, 1) - out = rearrange(x, "(m b) c h w -> b (m c) h w", m=self.num_estimators) out = super().forward(out) return rearrange(out, "b (m d) -> (m b) d", m=self.num_estimators) -def mimo_resnet18( - in_channels: int, - num_classes: int, - num_estimators: int, - conv_bias: bool = True, - dropout_rate: float = 0.0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _MIMOResNet: - return _MIMOResNet( - block=_BasicBlock, - num_blocks=[2, 2, 2, 2], - in_channels=in_channels, - num_classes=num_classes, - num_estimators=num_estimators, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def mimo_resnet20( - in_channels: int, - num_classes: int, - num_estimators: int, - conv_bias: bool = True, - dropout_rate: float = 0.0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _MIMOResNet: - return _MIMOResNet( - block=_BasicBlock, - num_blocks=[3, 3, 3], - in_channels=in_channels, - num_classes=num_classes, - num_estimators=num_estimators, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=16, - normalization_layer=normalization_layer, - ) - - -def mimo_resnet34( - in_channels: int, - num_classes: int, - num_estimators: int, - conv_bias: bool = True, - dropout_rate: float = 0.0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _MIMOResNet: - return _MIMOResNet( - block=_BasicBlock, - num_blocks=[3, 4, 6, 3], - in_channels=in_channels, - num_classes=num_classes, - num_estimators=num_estimators, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def mimo_resnet50( - in_channels: int, - num_classes: int, - num_estimators: int, - conv_bias: bool = True, - dropout_rate: float = 0.0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _MIMOResNet: - return _MIMOResNet( - block=_Bottleneck, - num_blocks=[3, 4, 6, 3], - in_channels=in_channels, - num_classes=num_classes, - num_estimators=num_estimators, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def mimo_resnet101( - in_channels: int, - num_classes: int, - num_estimators: int, - conv_bias: bool = True, - dropout_rate: float = 0.0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _MIMOResNet: - return _MIMOResNet( - block=_Bottleneck, - num_blocks=[3, 4, 23, 3], - in_channels=in_channels, - num_classes=num_classes, - num_estimators=num_estimators, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - - -def mimo_resnet152( +def mimo_resnet( in_channels: int, num_classes: int, + arch: int, num_estimators: int, conv_bias: bool = True, dropout_rate: float = 0.0, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _MIMOResNet: + block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck return _MIMOResNet( - block=_Bottleneck, - num_blocks=[3, 8, 36, 3], + block=block, + num_blocks=get_resnet_num_blocks(arch), in_channels=in_channels, num_classes=num_classes, num_estimators=num_estimators, diff --git a/torch_uncertainty/models/resnet/packed.py b/torch_uncertainty/models/resnet/packed.py index 9aa5f658..fc9787d5 100644 --- a/torch_uncertainty/models/resnet/packed.py +++ b/torch_uncertainty/models/resnet/packed.py @@ -7,13 +7,10 @@ from torch_uncertainty.layers import PackedConv2d, PackedLinear from torch_uncertainty.utils import load_hf +from .utils import get_resnet_num_blocks + __all__ = [ - "packed_resnet18", - "packed_resnet20", - "packed_resnet34", - "packed_resnet50", - "packed_resnet101", - "packed_resnet152", + "packed_resnet", ] weight_ids = { @@ -55,7 +52,7 @@ def __init__( conv_bias: bool, dropout_rate: float, groups: int, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], ) -> None: super().__init__() @@ -125,7 +122,7 @@ def __init__( conv_bias: bool, dropout_rate: float, groups: int, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], ) -> None: super().__init__() @@ -207,7 +204,7 @@ def __init__( groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", in_planes: int = 64, - normalization_layer: nn.Module = nn.BatchNorm2d, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> None: super().__init__() @@ -342,7 +339,7 @@ def _make_layer( dropout_rate: float, gamma: int, groups: int, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], ) -> nn.Module: strides = [stride] + [1] * (num_blocks - 1) layers = [] @@ -390,9 +387,10 @@ def check_config(self, config: dict[str, Any]) -> bool: ) -def packed_resnet18( +def packed_resnet( in_channels: int, num_classes: int, + arch: int, num_estimators: int, alpha: int, gamma: int, @@ -400,14 +398,15 @@ def packed_resnet18( groups: int = 1, dropout_rate: float = 0, style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, pretrained: bool = False, ) -> _PackedResNet: - """Packed-Ensembles of ResNet-18. + """Packed-Ensembles of ResNet. Args: in_channels (int): Number of input channels. num_classes (int): Number of classes to predict. + arch (int): The architecture of the ResNet. conv_bias (bool): Whether to use bias in convolutions. Defaults to ``True``. dropout_rate (float): Dropout rate. Defaults to 0. @@ -422,335 +421,26 @@ def packed_resnet18( Defaults to ``False``. Returns: - _PackedResNet: A Packed-Ensembles ResNet-18. - """ - net = _PackedResNet( - block=_BasicBlock, - num_blocks=[2, 2, 2, 2], - in_channels=in_channels, - num_estimators=num_estimators, - alpha=alpha, - gamma=gamma, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - num_classes=num_classes, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - if pretrained: # coverage: ignore - weights = weight_ids[str(num_classes)]["18"] - if weights is None: - raise ValueError("No pretrained weights for this configuration") - state_dict, config = load_hf(weights) - if not net.check_config(config): - raise ValueError( - "Pretrained weights do not match current configuration." - ) - net.load_state_dict(state_dict) - return net - - -def packed_resnet20( - in_channels: int, - num_classes: int, - num_estimators: int, - alpha: int, - gamma: int, - groups: int = 1, - conv_bias: bool = True, - dropout_rate: float = 0, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, - pretrained: bool = False, -) -> _PackedResNet: - """Packed-Ensembles of ResNet-20. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - alpha (int): Expansion factor affecting the width of the estimators. - gamma (int): Number of groups within each estimator. - groups (int): Number of groups within each estimator. Defaults to 1. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - normalization_layer (nn.Module, optional): Normalization layer. - pretrained (bool, optional): Whether to load pretrained weights. - Defaults to ``False``. - - Returns: - _PackedResNet: A Packed-Ensembles ResNet-20. - """ - net = _PackedResNet( - block=_BasicBlock, - num_blocks=[3, 3, 3], - in_channels=in_channels, - num_estimators=num_estimators, - alpha=alpha, - gamma=gamma, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - num_classes=num_classes, - style=style, - in_planes=16, - normalization_layer=normalization_layer, - ) - if pretrained: # coverage: ignore - weights = weight_ids[str(num_classes)]["18"] - if weights is None: - raise ValueError("No pretrained weights for this configuration") - state_dict, config = load_hf(weights) - if not net.check_config(config): - raise ValueError( - "Pretrained weights do not match current configuration." - ) - net.load_state_dict(state_dict) - return net - - -def packed_resnet34( - in_channels: int, - num_classes: int, - num_estimators: int, - alpha: int, - gamma: int, - groups: int = 1, - conv_bias: bool = True, - dropout_rate: float = 0, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, - pretrained: bool = False, -) -> _PackedResNet: - """Packed-Ensembles of ResNet-34. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - alpha (int): Expansion factor affecting the width of the estimators. - gamma (int): Number of groups within each estimator. - groups (int): Number of groups within each estimator. Defaults to 1. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - normalization_layer (nn.Module, optional): Normalization layer. - pretrained (bool, optional): Whether to load pretrained weights. - Defaults to ``False``. - - Returns: - _PackedResNet: A Packed-Ensembles ResNet-34. - """ - net = _PackedResNet( - block=_BasicBlock, - num_blocks=[3, 4, 6, 3], - in_channels=in_channels, - num_estimators=num_estimators, - alpha=alpha, - gamma=gamma, - groups=groups, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - num_classes=num_classes, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - if pretrained: # coverage: ignore - weights = weight_ids[str(num_classes)]["34"] - if weights is None: - raise ValueError("No pretrained weights for this configuration") - state_dict, config = load_hf(weights) - if not net.check_config(config): - raise ValueError( - "Pretrained weights do not match current configuration." - ) - net.load_state_dict(state_dict) - return net - - -def packed_resnet50( - in_channels: int, - num_classes: int, - num_estimators: int, - alpha: int, - gamma: int, - groups: int = 1, - conv_bias: bool = True, - dropout_rate: float = 0, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, - pretrained: bool = False, -) -> _PackedResNet: - """Packed-Ensembles of ResNet-50. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - alpha (int): Expansion factor affecting the width of the estimators. - gamma (int): Number of groups within each estimator. - groups (int): Number of groups within each estimator. Defaults to 1. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - normalization_layer (nn.Module, optional): Normalization layer. - pretrained (bool, optional): Whether to load pretrained weights. - Defaults to ``False``. - - Returns: - _PackedResNet: A Packed-Ensembles ResNet-50. + _PackedResNet: A Packed-Ensembles ResNet. """ + block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck net = _PackedResNet( - block=_Bottleneck, - num_blocks=[3, 4, 6, 3], + block=block, + num_blocks=get_resnet_num_blocks(arch), in_channels=in_channels, num_estimators=num_estimators, alpha=alpha, gamma=gamma, - groups=groups, conv_bias=conv_bias, dropout_rate=dropout_rate, - num_classes=num_classes, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - if pretrained: # coverage: ignore - weights = weight_ids[str(num_classes)]["50"] - if weights is None: - raise ValueError("No pretrained weights for this configuration") - state_dict, config = load_hf(weights) - if not net.check_config(config): - raise ValueError( - "Pretrained weights do not match current configuration." - ) - net.load_state_dict(state_dict) - return net - - -def packed_resnet101( - in_channels: int, - num_classes: int, - num_estimators: int, - alpha: int, - gamma: int, - groups: int = 1, - conv_bias: bool = True, - dropout_rate: float = 0, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, - pretrained: bool = False, -) -> _PackedResNet: - """Packed-Ensembles of ResNet-101. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - alpha (int): Expansion factor affecting the width of the estimators. - gamma (int): Number of groups within each estimator. - groups (int): Number of groups within each estimator. Defaults to 1. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - normalization_layer (nn.Module, optional): Normalization layer. - pretrained (bool, optional): Whether to load pretrained weights. - Defaults to ``False``. - - Returns: - _PackedResNet: A Packed-Ensembles ResNet-101. - """ - net = _PackedResNet( - block=_Bottleneck, - num_blocks=[3, 4, 23, 3], - in_channels=in_channels, - num_estimators=num_estimators, - alpha=alpha, - gamma=gamma, groups=groups, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - num_classes=num_classes, - style=style, - in_planes=64, - normalization_layer=normalization_layer, - ) - if pretrained: # coverage: ignore - weights = weight_ids[str(num_classes)]["101"] - if weights is None: - raise ValueError("No pretrained weights for this configuration") - state_dict, config = load_hf(weights) - if not net.check_config(config): - raise ValueError( - "Pretrained weights do not match current configuration." - ) - net.load_state_dict(state_dict) - return net - - -def packed_resnet152( - in_channels: int, - num_classes: int, - num_estimators: int, - alpha: int, - gamma: int, - groups: int = 1, - conv_bias: bool = True, - dropout_rate: float = 0, - style: Literal["imagenet", "cifar"] = "imagenet", - normalization_layer: nn.Module = nn.BatchNorm2d, - pretrained: bool = False, -) -> _PackedResNet: - """Packed-Ensembles of ResNet-152. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - num_estimators (int): Number of estimators in the ensemble. - alpha (int): Expansion factor affecting the width of the estimators. - gamma (int): Number of groups within each estimator. - groups (int): Number of groups within each estimator. Defaults to 1. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - normalization_layer (nn.Module, optional): Normalization layer. - pretrained (bool, optional): Whether to load pretrained weights. - Defaults to ``False``. - - Returns: - _PackedResNet: A Packed-Ensembles ResNet-152. - """ - net = _PackedResNet( - block=_Bottleneck, - num_blocks=[3, 8, 36, 3], - in_channels=in_channels, - num_estimators=num_estimators, - alpha=alpha, - gamma=gamma, - groups=groups, - conv_bias=conv_bias, - dropout_rate=dropout_rate, num_classes=num_classes, style=style, in_planes=64, normalization_layer=normalization_layer, ) if pretrained: # coverage: ignore - weights = weight_ids[str(num_classes)]["152"] + weights = weight_ids[str(num_classes)][str(arch)] if weights is None: raise ValueError("No pretrained weights for this configuration") state_dict, config = load_hf(weights) diff --git a/torch_uncertainty/models/resnet/std.py b/torch_uncertainty/models/resnet/std.py index 0eeea7ba..0e643da7 100644 --- a/torch_uncertainty/models/resnet/std.py +++ b/torch_uncertainty/models/resnet/std.py @@ -4,14 +4,9 @@ from torch import Tensor, nn from torch.nn.functional import relu -__all__ = [ - "resnet18", - "resnet20", - "resnet34", - "resnet50", - "resnet101", - "resnet152", -] +from .utils import get_resnet_num_blocks + +__all__ = ["resnet"] class _BasicBlock(nn.Module): @@ -25,7 +20,7 @@ def __init__( dropout_rate: float, groups: int, activation_fn: Callable, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], conv_bias: bool, ) -> None: super().__init__() @@ -87,7 +82,7 @@ def __init__( dropout_rate: float, groups: int, activation_fn: Callable, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], conv_bias: bool, ) -> None: super().__init__() @@ -207,11 +202,10 @@ def __init__( style: Literal["imagenet", "cifar"] = "imagenet", in_planes: int = 64, activation_fn: Callable = relu, - normalization_layer: nn.Module = nn.BatchNorm2d, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> None: """ResNet from `Deep Residual Learning for Image Recognition`.""" super().__init__() - self.in_planes = in_planes block_planes = in_planes self.dropout_rate = dropout_rate @@ -317,7 +311,7 @@ def _make_layer( dropout_rate: float, groups: int, activation_fn: Callable, - normalization_layer: nn.Module, + normalization_layer: type[nn.Module], conv_bias: bool, ) -> nn.Module: strides = [stride] + [1] * (num_blocks - 1) @@ -352,21 +346,23 @@ def forward(self, x: Tensor) -> Tensor: return self.linear(self.feats_forward(x)) -def resnet18( +def resnet( in_channels: int, num_classes: int, + arch: int, conv_bias: bool = True, dropout_rate: float = 0.0, groups: int = 1, style: Literal["imagenet", "cifar"] = "imagenet", activation_fn: Callable = relu, - normalization_layer: nn.Module = nn.BatchNorm2d, + normalization_layer: type[nn.Module] = nn.BatchNorm2d, ) -> _ResNet: """ResNet-18 model. Args: in_channels (int): Number of input channels. num_classes (int): Number of classes to predict. + arch (int): The architecture of the ResNet. conv_bias (bool): Whether to use bias in convolutions. Defaults to ``True``. conv_bias (bool): Whether to use bias in convolutions. Defaults to @@ -379,222 +375,12 @@ def resnet18( normalization_layer (nn.Module, optional): Normalization layer. Returns: - _ResNet: A ResNet-18. - """ - return _ResNet( - block=_BasicBlock, - num_blocks=[2, 2, 2, 2], - in_channels=in_channels, - num_classes=num_classes, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - activation_fn=activation_fn, - normalization_layer=normalization_layer, - ) - - -def resnet20( - in_channels: int, - num_classes: int, - conv_bias: bool = True, - dropout_rate: float = 0.0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - activation_fn: Callable = relu, - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _ResNet: - """ResNet-18 model. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - groups (int): Number of groups in convolutions. Defaults to 1. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - activation_fn (Callable, optional): Activation function. - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _ResNet: A ResNet-20. - """ - return _ResNet( - block=_BasicBlock, - num_blocks=[3, 3, 3], - in_channels=in_channels, - num_classes=num_classes, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=16, - activation_fn=activation_fn, - normalization_layer=normalization_layer, - ) - - -def resnet34( - in_channels: int, - num_classes: int, - conv_bias: bool = True, - dropout_rate: float = 0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - activation_fn: Callable = relu, - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _ResNet: - """ResNet-34 model. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - groups (int): Number of groups in convolutions. Defaults to 1. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - activation_fn (Callable, optional): Activation function. - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _ResNet: A ResNet-34. - """ - return _ResNet( - block=_BasicBlock, - num_blocks=[3, 4, 6, 3], - in_channels=in_channels, - num_classes=num_classes, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - activation_fn=activation_fn, - normalization_layer=normalization_layer, - ) - - -def resnet50( - in_channels: int, - num_classes: int, - conv_bias: bool = True, - dropout_rate: float = 0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - activation_fn: Callable = relu, - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _ResNet: - """ResNet-50 model. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - groups (int): Number of groups in convolutions. Defaults to 1. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - activation_fn (Callable, optional): Activation function. - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _ResNet: A ResNet-50. - """ - return _ResNet( - block=_Bottleneck, - num_blocks=[3, 4, 6, 3], - in_channels=in_channels, - num_classes=num_classes, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - activation_fn=activation_fn, - normalization_layer=normalization_layer, - ) - - -def resnet101( - in_channels: int, - num_classes: int, - conv_bias: bool = True, - dropout_rate: float = 0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - activation_fn: Callable = relu, - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _ResNet: - """ResNet-101 model. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - groups (int): Number of groups in convolutions. Defaults to 1. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - activation_fn (Callable, optional): Activation function. - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _ResNet: A ResNet-101. - """ - return _ResNet( - block=_Bottleneck, - num_blocks=[3, 4, 23, 3], - in_channels=in_channels, - num_classes=num_classes, - conv_bias=conv_bias, - dropout_rate=dropout_rate, - groups=groups, - style=style, - in_planes=64, - activation_fn=activation_fn, - normalization_layer=normalization_layer, - ) - - -def resnet152( - in_channels: int, - num_classes: int, - conv_bias: bool = True, - dropout_rate: float = 0, - groups: int = 1, - style: Literal["imagenet", "cifar"] = "imagenet", - activation_fn: Callable = relu, - normalization_layer: nn.Module = nn.BatchNorm2d, -) -> _ResNet: - """ResNet-152 model. - - Args: - in_channels (int): Number of input channels. - num_classes (int): Number of classes to predict. - conv_bias (bool): Whether to use bias in convolutions. Defaults to - ``True``. - dropout_rate (float): Dropout rate. Defaults to 0. - groups (int, optional): Number of groups in convolutions. Defaults to - ``1``. - style (bool, optional): Whether to use the ImageNet - structure. Defaults to ``True``. - activation_fn (Callable, optional): Activation function. - normalization_layer (nn.Module, optional): Normalization layer. - - Returns: - _ResNet: A ResNet-152. + _ResNet: The ResNet model. """ + block = _BasicBlock if arch in [18, 20, 34] else _Bottleneck return _ResNet( - block=_Bottleneck, - num_blocks=[3, 8, 36, 3], + block=block, + num_blocks=get_resnet_num_blocks(arch), in_channels=in_channels, num_classes=num_classes, conv_bias=conv_bias, diff --git a/torch_uncertainty/models/resnet/utils.py b/torch_uncertainty/models/resnet/utils.py new file mode 100644 index 00000000..0e082509 --- /dev/null +++ b/torch_uncertainty/models/resnet/utils.py @@ -0,0 +1,14 @@ +def get_resnet_num_blocks(arch: int) -> list[int]: + if arch == 18: + num_blocks = [2, 2, 2, 2] + elif arch == 20: + num_blocks = [3, 3, 3] + elif arch == 34 or arch == 50: + num_blocks = [3, 4, 6, 3] + elif arch == 101: + num_blocks = [3, 4, 23, 3] + elif arch == 152: + num_blocks = [3, 8, 36, 3] + else: + raise ValueError(f"Unknown ResNet architecture. Got {arch}.") + return num_blocks diff --git a/torch_uncertainty/models/segmentation/deeplab.py b/torch_uncertainty/models/segmentation/deeplab.py new file mode 100644 index 00000000..7029b4bf --- /dev/null +++ b/torch_uncertainty/models/segmentation/deeplab.py @@ -0,0 +1,419 @@ +from typing import Literal + +import torch +import torchvision.models as tv_models +from torch import Tensor, nn +from torch.nn import functional as F +from torch.nn.common_types import _size_2_t +from torchvision.models.resnet import ResNet50_Weights, ResNet101_Weights + +from torch_uncertainty.models.utils import Backbone, set_bn_momentum + + +class SeparableConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: _size_2_t = 0, + dilation: _size_2_t = 1, + bias=True, + ) -> None: + """Separable Convolution with dilation. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (_size_2_t): Kernel size. + stride (_size_2_t, optional): Stride. Defaults to 1. + padding (_size_2_t, optional): Padding. Defaults to 0. + dilation (_size_2_t, optional): Dilation. Defaults to 1. + bias (bool, optional): Use biases. Defaults to True. + """ + super().__init__() + self.separable = nn.Conv2d( + in_channels, + in_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + groups=in_channels, + ) + self.pointwise = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + def forward(self, x: Tensor) -> Tensor: + return self.pointwise(self.separable(x)) + + +class InnerConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dilation: _size_2_t, + separable: bool, + ) -> None: + """Inner convolution block. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + dilation (_size_2_t): Dilation. + separable (bool): Use separable convolutions to reduce the number + of parameters. + """ + super().__init__() + if not separable: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + ) + else: + self.conv = SeparableConv2d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=False, + ) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x: Tensor) -> Tensor: + return F.relu(self.bn(self.conv(x))) + + +class InnerPooling(nn.Module): + def __init__(self, in_channels: int, out_channels: int) -> None: + """Inner pooling block. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + """ + super().__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=1, bias=False + ) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x: Tensor) -> Tensor: + size = x.shape[-2:] + x = F.relu(self.bn(self.conv(self.pool(x)))) + return F.interpolate(x, size=size, mode="bilinear", align_corners=False) + + +class ASPP(nn.Module): + def __init__( + self, + in_channels: int, + atrous_rates: list[int], + separable: bool, + dropout_rate: float, + ) -> None: + """Atrous Spatial Pyramid Pooling. + + Args: + in_channels (int): Number of input channels. + atrous_rates (list[int]): Atrous rates for the ASPP module. + separable (bool): Use separable convolutions to reduce the number + of parameters. + dropout_rate (float): Dropout rate of the ASPP. + """ + super().__init__() + out_channels = 256 + modules = [] + modules.append( + nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + ) + modules += [ + InnerConv(in_channels, out_channels, dilation, separable) + for dilation in atrous_rates + ] + modules.append(InnerPooling(in_channels, out_channels)) + self.convs = nn.ModuleList(modules) + + self.projection = nn.Sequential( + nn.Conv2d( + 5 * out_channels, out_channels, kernel_size=1, bias=False + ), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Dropout(dropout_rate), + ) + + def forward(self, x: Tensor) -> Tensor: + res = torch.cat([conv(x) for conv in self.convs], dim=1) + return self.projection(res) + + +class DeepLabV3Backbone(Backbone): + def __init__( + self, + backbone_name: Literal["resnet50", "resnet101"], + style: str, + pretrained: bool, + norm_momentum: float, + ) -> None: + """DeepLab V3(+) backbone. + + Args: + backbone_name (str): Backbone name. + style (str): Whether to use a DeepLab V3 or V3+ model. + pretrained (bool): Use pretrained backbone. + norm_momentum (float): BatchNorm momentum. + """ + # TODO: handle dilations + if backbone_name == "resnet50": + base_model = tv_models.resnet50( + weights=ResNet50_Weights.DEFAULT if pretrained else None + ) + elif backbone_name == "resnet101": + base_model = tv_models.resnet101( + weights=ResNet101_Weights.DEFAULT if pretrained else None + ) + else: + raise ValueError(f"Unknown backbone: {backbone_name}.") + base_model.avgpool = nn.Identity() + base_model.fc = nn.Identity() + set_bn_momentum(base_model, norm_momentum) + + feat_names = ["layer1", "layer4"] if style == "v3+" else ["layer4"] + super().__init__(base_model, feat_names) + + +class DeepLabV3Decoder(nn.Module): + """Decoder for the DeepLabV3 model. + + Args: + in_channels (int): Number of channels of the input latent space. + num_classes (int): Number of classes. + aspp_dilate (list[int], optional): Atrous rates for the ASPP module. + separable (bool, optional): Use separable convolutions to reduce the number + of parameters. Defaults to False. + dropout_rate (float, optional): Dropout rate of the ASPP. Defaults to 0.1. + """ + + def __init__( + self, + in_channels: int, + num_classes: int, + aspp_dilate: list[int], + separable: bool = False, + dropout_rate: float = 0.1, + ) -> None: + super().__init__() + self.aspp = ASPP(in_channels, aspp_dilate, separable, dropout_rate) + if not separable: + self.conv = nn.Conv2d(256, 256, 3, padding=1, bias=False) + else: + self.conv = SeparableConv2d(256, 256, 3, padding=1, bias=False) + self.bn = nn.BatchNorm2d(256) + self.classifier = nn.Conv2d(256, num_classes, kernel_size=1) + + def forward(self, features: list[Tensor]) -> Tensor: + out = F.relu(self.bn(self.conv(self.aspp(features[0])))) + return self.classifier(out) + + +class DeepLabV3PlusDecoder(nn.Module): + def __init__( + self, + in_channels: int, + low_level_channels: int, + num_classes: int, + aspp_dilate: list[int], + separable: bool, + dropout_rate: float = 0.1, + ) -> None: + """Decoder for the DeepLabV3+ model. + + Args: + in_channels (int): Number of channels of the input latent space. + low_level_channels (int): Number of low-level features channels. + num_classes (int): Number of classes. + aspp_dilate (list[int]): Atrous rates for the ASPP module. + separable (bool): Use separable convolutions to reduce the number + of parameters. + dropout_rate (float, optional): Dropout rate of the ASPP. Defaults + to 0.1. + """ + super().__init__() + self.project = nn.Sequential( + nn.Conv2d(low_level_channels, 48, kernel_size=1, bias=False), + nn.BatchNorm2d(48), + nn.ReLU(inplace=True), + ) + self.atrous_spatial_pyramid_pool = ASPP( + in_channels, aspp_dilate, separable, dropout_rate + ) + if separable: + self.conv = SeparableConv2d(304, 256, 3, padding=1, bias=False) + else: + self.conv = nn.Conv2d(304, 256, 3, padding=1, bias=False) + self.bn = nn.BatchNorm2d(256) + self.classifier = nn.Conv2d(256, num_classes, kernel_size=1) + + def forward(self, features: list[Tensor]) -> Tensor: + low_level_features = self.project(features[0]) + output_features = self.atrous_spatial_pyramid_pool(features[1]) + output_features = F.interpolate( + output_features, + size=low_level_features.shape[2:], + mode="bilinear", + align_corners=False, + ) + output_features = torch.cat( + [low_level_features, output_features], dim=1 + ) + out = F.relu(self.bn(self.conv(output_features))) + return self.classifier(out) + + +class _DeepLabV3(nn.Module): + def __init__( + self, + num_classes: int, + backbone_name: str, + style: Literal["v3", "v3+"], + output_stride: int = 16, + separable: bool = False, + pretrained_backbone: bool = True, + norm_momentum: float = 0.01, + ) -> None: + """DeepLab V3(+) model. + + Args: + num_classes (int): Number of classes. + backbone_name (str): Backbone name. + style (Literal["v3", "v3+"]): Whether to use a DeepLab V3 or + V3+ model. + output_stride (int, optional): Output stride. Defaults to 16. + separable (bool, optional): Use separable convolutions. Defaults + to False. + pretrained_backbone (bool, optional): Use pretrained backbone. + Defaults to True. + norm_momentum (float, optional): BatchNorm momentum. Defaults to + 0.01. + + References: + - Rethinking atrous convolution for semantic image segmentation. + Chen, L. C., Papandreou, G., Schroff, F., & Adam, H. (2018). + - Encoder-decoder with atrous separable convolution for semantic image segmentation. + Chen, L. C., Zhu, Y., Papandreou, G., Schroff, F., & Adam, H. In ECCV 2018. + """ + super().__init__() + if output_stride == 16: + dilations = [6, 12, 18] + elif output_stride == 8: + dilations = [12, 24, 36] + else: + raise ValueError( + f"output_stride: {output_stride} is not supported." + ) + + self.backbone = DeepLabV3Backbone( + backbone_name, style, pretrained_backbone, norm_momentum + ) + if style == "v3": + self.decoder = DeepLabV3Decoder( + in_channels=2048, + num_classes=num_classes, + aspp_dilate=dilations, + separable=separable, + dropout_rate=0.1, + ) + elif style == "v3+": + self.decoder = DeepLabV3PlusDecoder( + in_channels=2048, + low_level_channels=256, + num_classes=num_classes, + aspp_dilate=dilations, + separable=separable, + dropout_rate=0.1, + ) + else: + raise ValueError(f"Unknown style: {style}.") + + def forward(self, x: Tensor) -> Tensor: + input_shape = x.shape[-2:] + return F.interpolate( + self.decoder(self.backbone(x)), + size=input_shape, + mode="bilinear", + align_corners=False, + ) + + +def deep_lab_v3_resnet50( + num_classes: int, + style: Literal["v3", "v3+"], + output_stride: int = 16, + separable: bool = False, + pretrained_backbone: bool = True, +) -> _DeepLabV3: + """DeepLab V3(+) model with ResNet-50 backbone. + + Args: + num_classes (int): Number of classes. + style (Literal["v3", "v3+"]): Whether to use a DeepLab V3 or V3+ model. + output_stride (int, optional): Output stride. Defaults to 16. + separable (bool, optional): Use separable convolutions. Defaults to + False. + pretrained_backbone (bool, optional): Use pretrained backbone. Defaults + to True. + """ + return _DeepLabV3( + num_classes, + "resnet50", + style, + output_stride=output_stride, + separable=separable, + pretrained_backbone=pretrained_backbone, + ) + + +def deep_lab_v3_resnet101( + num_classes: int, + style: Literal["v3", "v3+"], + output_stride: int = 16, + separable: bool = False, + pretrained_backbone: bool = True, +) -> _DeepLabV3: + """DeepLab V3(+) model with ResNet-50 backbone. + + Args: + num_classes (int): Number of classes. + style (Literal["v3", "v3+"]): Whether to use a DeepLab V3 or V3+ model. + output_stride (int, optional): Output stride. Defaults to 16. + separable (bool, optional): Use separable convolutions. Defaults to False. + pretrained_backbone (bool, optional): Use pretrained backbone. Defaults to True. + """ + return _DeepLabV3( + num_classes, + "resnet101", + style, + output_stride=output_stride, + separable=separable, + pretrained_backbone=pretrained_backbone, + ) diff --git a/torch_uncertainty/models/segmentation/segformer/std.py b/torch_uncertainty/models/segmentation/segformer.py similarity index 66% rename from torch_uncertainty/models/segmentation/segformer/std.py rename to torch_uncertainty/models/segmentation/segformer.py index 3881e055..763aea71 100644 --- a/torch_uncertainty/models/segmentation/segformer/std.py +++ b/torch_uncertainty/models/segmentation/segformer.py @@ -1,39 +1,31 @@ -# --------------------------------------------------------------- -# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. -# -# This work is licensed under the NVIDIA Source Code License -# --------------------------------------------------------------- - import math -import warnings from functools import partial import torch -import torch.nn as nn import torch.nn.functional as F from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from torch import Tensor, nn class DWConv(nn.Module): - def __init__(self, dim=768): + def __init__(self, dim: int = 768) -> None: super().__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) - def forward(self, x, h, w): - b, _, c = x.shape - x = x.transpose(1, 2).view(b, c, h, w) - x = self.dwconv(x) - return x.flatten(2).transpose(1, 2) + def forward(self, inputs: Tensor, h: int, w: int) -> Tensor: + b, _, c = inputs.shape + inputs = self.dwconv(inputs.transpose(1, 2).view(b, c, h, w)) + return inputs.flatten(2).transpose(1, 2) -class Mlp(nn.Module): +class MLP(nn.Module): def __init__( self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0.0, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: type[nn.Module] = nn.GELU, + dropout_rate: float = 0.0, ): super().__init__() out_features = out_features or in_features @@ -42,7 +34,7 @@ def __init__( self.dwconv = DWConv(hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) + self.dropout = nn.Dropout(dropout_rate) self.apply(self._init_weights) @@ -65,9 +57,9 @@ def forward(self, x, h, w): x = self.fc1(x) x = self.dwconv(x, h, w) x = self.act(x) - x = self.drop(x) + x = self.dropout(x) x = self.fc2(x) - return self.drop(x) + return self.dropout(x) class Attention(nn.Module): @@ -161,7 +153,7 @@ def __init__( mlp_ratio=4.0, qkv_bias=False, qk_scale=None, - drop=0.0, + dropout=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, @@ -176,7 +168,7 @@ def __init__( qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop, + proj_drop=dropout, sr_ratio=sr_ratio, ) # NOTE: drop path for stochastic depth, we shall see if this is better @@ -186,11 +178,11 @@ def __init__( ) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( + self.mlp = MLP( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, - drop=drop, + dropout_rate=dropout, ) self.apply(self._init_weights) @@ -270,32 +262,21 @@ def forward(self, x): class MixVisionTransformer(nn.Module): def __init__( self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - embed_dims=None, - num_heads=None, - mlp_ratios=None, - qkv_bias=False, - qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - drop_path_rate=0.0, - norm_layer=nn.LayerNorm, - depths=None, - sr_ratios=None, + img_size, + in_channels, + num_classes, + embed_dims, + num_heads, + mlp_ratios, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + drop_path_rate, + norm_layer, + depths, + sr_ratios, ): - if sr_ratios is None: - sr_ratios = [8, 4, 2, 1] - if depths is None: - depths = [3, 4, 6, 3] - if mlp_ratios is None: - mlp_ratios = [4, 4, 4, 4] - if num_heads is None: - num_heads = [1, 2, 4, 8] - if embed_dims is None: - embed_dims = [64, 128, 256, 512] super().__init__() self.num_classes = num_classes self.depths = depths @@ -305,7 +286,7 @@ def __init__( img_size=img_size, patch_size=7, stride=4, - in_chans=in_chans, + in_chans=in_channels, embed_dim=embed_dims[0], ) self.patch_embed2 = OverlapPatchEmbed( @@ -343,7 +324,7 @@ def __init__( mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, + dropout=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, @@ -363,7 +344,7 @@ def __init__( mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, + dropout=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, @@ -383,7 +364,7 @@ def __init__( mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, + dropout=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, @@ -403,7 +384,7 @@ def __init__( mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, + dropout=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, @@ -437,7 +418,7 @@ def forward_features(self, x): # stage 1 x, h, w = self.patch_embed1(x) - for _i, blk in enumerate(self.block1): + for blk in self.block1: x = blk(x, h, w) x = self.norm1(x) x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() @@ -445,7 +426,7 @@ def forward_features(self, x): # stage 2 x, h, w = self.patch_embed2(x) - for _i, blk in enumerate(self.block2): + for blk in self.block2: x = blk(x, h, w) x = self.norm2(x) x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() @@ -453,7 +434,7 @@ def forward_features(self, x): # stage 3 x, h, w = self.patch_embed3(x) - for _i, blk in enumerate(self.block3): + for blk in self.block3: x = blk(x, h, w) x = self.norm3(x) x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() @@ -461,7 +442,7 @@ def forward_features(self, x): # stage 4 x, h, w = self.patch_embed4(x) - for _i, blk in enumerate(self.block4): + for blk in self.block4: x = blk(x, h, w) x = self.norm4(x) x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous() @@ -473,123 +454,67 @@ def forward(self, x): return self.forward_features(x) -class MitB0(MixVisionTransformer): - def __init__(self): - super().__init__( - patch_size=4, - embed_dims=[32, 64, 160, 256], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - depths=[2, 2, 2, 2], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ) - - -class MitB1(MixVisionTransformer): - def __init__(self): - super().__init__( - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - depths=[2, 2, 2, 2], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ) - - -class MitB2(MixVisionTransformer): - def __init__(self): - super().__init__( - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - depths=[3, 4, 6, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ) +def _get_embed_dims(arch: int) -> list[int]: + if arch == 0: + return [32, 64, 160, 256] + return [64, 128, 320, 512] -class MitB3(MixVisionTransformer): - def __init__(self): - super().__init__( - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - depths=[3, 4, 18, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ) - - -class MitB4(MixVisionTransformer): - def __init__(self): - super().__init__( - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - depths=[3, 8, 27, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ) +def _get_depths(arch: int) -> list[int]: + if arch == 0 or arch == 1: + return [2, 2, 2, 2] + if arch == 2: + return [3, 4, 6, 3] + if arch == 3: + return [3, 4, 18, 3] + if arch == 4: + return [3, 8, 27, 3] + # arch == 5: + return [3, 6, 40, 3] -class MitB5(MixVisionTransformer): - def __init__(self): +class Mit(MixVisionTransformer): + def __init__(self, arch: int): + embed_dims = _get_embed_dims(arch) + depths = _get_depths(arch) super().__init__( - patch_size=4, - embed_dims=[64, 128, 320, 512], + img_size=224, + in_channels=3, + num_classes=1000, + qk_scale=None, + embed_dims=embed_dims, num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), - depths=[3, 6, 40, 3], + depths=depths, sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1, + attn_drop_rate=0.0, ) class MLPHead(nn.Module): - """Linear Embedding.""" + """Linear Embedding with transposition.""" - def __init__(self, input_dim=2048, embed_dim=768): + def __init__(self, input_dim: int = 2048, embed_dim: int = 768) -> None: super().__init__() self.proj = nn.Linear(input_dim, embed_dim) - def forward(self, x): - x = x.flatten(2).transpose(1, 2) - return self.proj(x) + def forward(self, inputs: Tensor) -> Tensor: + return self.proj(inputs.flatten(2).transpose(1, 2)) def resize( - inputs, - size=None, + inputs: Tensor, + size: tuple[int] | torch.Size | None = None, scale_factor=None, - mode="nearest", - align_corners=None, - warning=True, -): - if warning and size is not None and align_corners: + mode: str = "nearest", + align_corners: bool | None = None, + warning: bool = True, +) -> Tensor: + if warning and size is not None and align_corners: # coverage: ignore input_h, input_w = tuple(int(x) for x in inputs.shape[2:]) output_h, output_w = tuple(int(x) for x in size) if (output_h > input_h or output_w > output_h) and ( @@ -597,12 +522,11 @@ def resize( and (output_h - 1) % (input_h - 1) and (output_w - 1) % (input_w - 1) ): - warnings.warn( + print( f"When align_corners={align_corners}, " "the output would more aligned if " f"input size {(input_h, input_w)} is `x+1` and " f"out size {(output_h, output_w)} is `nx+1`", - stacklevel=2, ) if isinstance(size, torch.Size): size = tuple(int(x) for x in size) @@ -610,66 +534,40 @@ def resize( class SegFormerHead(nn.Module): - """SegFormer: Simple and Efficient Design for Semantic Segmentation with - Transformers. + """Head for SegFormer. + + Reference: + SegFormer: Simple and Efficient Design for Semantic Segmentation with + Transformers. """ def __init__( self, - in_channels, - feature_strides, - decoder_params, - num_classes, - dropout_ratio=0.1, + in_channels: list[int], + feature_strides: list[int], + embed_dim: int, + num_classes: int, + dropout_ratio: float = 0.1, ): super().__init__() self.in_channels = in_channels - assert len(feature_strides) == len(self.in_channels) + assert len(feature_strides) == len(in_channels) assert min(feature_strides) == feature_strides[0] - self.feature_strides = feature_strides - self.num_classes = num_classes - # --- self in_index [0, 1, 2, 3] - - ( - c1_in_channels, - c2_in_channels, - c3_in_channels, - c4_in_channels, - ) = self.in_channels - embedding_dim = decoder_params["embed_dim"] - - self.linear_c4 = MLPHead( - input_dim=c4_in_channels, embed_dim=embedding_dim - ) - self.linear_c3 = MLPHead( - input_dim=c3_in_channels, embed_dim=embedding_dim - ) - self.linear_c2 = MLPHead( - input_dim=c2_in_channels, embed_dim=embedding_dim - ) - self.linear_c1 = MLPHead( - input_dim=c1_in_channels, embed_dim=embedding_dim - ) + self.linear_c4 = MLPHead(input_dim=in_channels[3], embed_dim=embed_dim) + self.linear_c3 = MLPHead(input_dim=in_channels[2], embed_dim=embed_dim) + self.linear_c2 = MLPHead(input_dim=in_channels[1], embed_dim=embed_dim) + self.linear_c1 = MLPHead(input_dim=in_channels[0], embed_dim=embed_dim) self.fuse = nn.Sequential( - nn.Conv2d( - embedding_dim * 4, embedding_dim, kernel_size=1, bias=False - ), + nn.Conv2d(embed_dim * 4, embed_dim, kernel_size=1, bias=False), nn.ReLU(), - nn.BatchNorm2d(embedding_dim), + nn.BatchNorm2d(embed_dim), ) + self.classifier = nn.Conv2d(embed_dim, num_classes, kernel_size=1) + self.dropout = nn.Dropout2d(dropout_ratio) - self.linear_pred = nn.Conv2d( - embedding_dim, self.num_classes, kernel_size=1 - ) - - if dropout_ratio > 0: - self.dropout = nn.Dropout2d(dropout_ratio) - else: - self.dropout = None - - def forward(self, inputs): + def forward(self, inputs: Tensor) -> Tensor: # x [inputs[i] for i in self.in_index] # len=4, 1/4,1/8,1/16,1/32 c1, c2, c3, c4 = inputs[0], inputs[1], inputs[2], inputs[3] @@ -711,96 +609,42 @@ def forward(self, inputs): _c = self.fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) x = self.dropout(_c) - return self.linear_pred(x) + return self.classifier(x) class _SegFormer(nn.Module): def __init__( self, - in_channels, - feature_strides, - decoder_params, - num_classes, - dropout_ratio, + in_channels: list[int], + feature_strides: list[int], + embed_dim: int, + num_classes: int, + dropout_ratio: float, mit: nn.Module, ): super().__init__() - self.encoder = mit() + self.encoder = mit self.head = SegFormerHead( in_channels, feature_strides, - decoder_params, + embed_dim, num_classes, dropout_ratio, ) - def forward(self, x): - features = self.encoder(x) + def forward(self, inputs: Tensor) -> Tensor: + features = self.encoder(inputs) return self.head(features) -def seg_former_b0(num_classes: int): - return _SegFormer( - in_channels=[32, 64, 160, 256], - feature_strides=[4, 8, 16, 32], - decoder_params={"embed_dim": 256}, - num_classes=num_classes, - dropout_ratio=0.1, - mit=MitB0, - ) - - -def seg_former_b1(num_classes: int): - return _SegFormer( - in_channels=[64, 128, 320, 512], - feature_strides=[4, 8, 16, 32], - decoder_params={"embed_dim": 512}, - num_classes=num_classes, - dropout_ratio=0.1, - mit=MitB1, - ) - - -def seg_former_b2(num_classes: int): - return _SegFormer( - in_channels=[64, 128, 320, 512], - feature_strides=[4, 8, 16, 32], - decoder_params={"embed_dim": 512}, - num_classes=num_classes, - dropout_ratio=0.1, - mit=MitB2, - ) - - -def seg_former_b3(num_classes: int): - return _SegFormer( - in_channels=[64, 128, 320, 512], - feature_strides=[4, 8, 16, 32], - decoder_params={"embed_dim": 512}, - num_classes=num_classes, - dropout_ratio=0.1, - mit=MitB3, - ) - - -def seg_former_b4(num_classes: int): - return _SegFormer( - in_channels=[64, 128, 320, 512], - feature_strides=[4, 8, 16, 32], - decoder_params={"embed_dim": 512}, - num_classes=num_classes, - dropout_ratio=0.1, - mit=MitB4, - ) - - -def seg_former_b5(num_classes: int): +def seg_former(num_classes: int, arch: int) -> _SegFormer: + in_channels = _get_embed_dims(arch) return _SegFormer( - in_channels=[64, 128, 320, 512], + in_channels=in_channels, feature_strides=[4, 8, 16, 32], - decoder_params={"embed_dim": 512}, + embed_dim=256 if arch == 0 else 512, num_classes=num_classes, dropout_ratio=0.1, - mit=MitB5, + mit=Mit(arch), ) diff --git a/torch_uncertainty/models/segmentation/segformer/__init__.py b/torch_uncertainty/models/segmentation/segformer/__init__.py deleted file mode 100644 index dc3fb2ee..00000000 --- a/torch_uncertainty/models/segmentation/segformer/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# ruff: noqa: F401, F403 -from .std import * diff --git a/torch_uncertainty/models/utils.py b/torch_uncertainty/models/utils.py index 58851399..cb0bcc02 100644 --- a/torch_uncertainty/models/utils.py +++ b/torch_uncertainty/models/utils.py @@ -1,4 +1,4 @@ -from torch import nn +from torch import Tensor, nn from torch_uncertainty.layers.bayesian import bayesian_modules @@ -56,3 +56,42 @@ def unfreeze(self) -> None: model.unfreeze = unfreeze return model + + +class Backbone(nn.Module): + def __init__(self, model: nn.Module, feat_names: list[str]) -> None: + """Encoder backbone. + + Return the skip features of the :attr:`model` corresponding to the + :attr:`feat_names`. + + Args: + model (nn.Module): Base model. + feat_names (list[str]): List of the feature names. + """ + super().__init__() + self.model = model + self.feat_names = feat_names + + def forward(self, x: Tensor) -> list[Tensor]: + """Encoder forward pass. + + Args: + x (Tensor): Input tensor. + + Returns: + list[Tensor]: List of the features. + """ + feature = x + features = [] + for k, v in self.model._modules.items(): + feature = v(feature) + if k in self.feat_names: + features.append(feature) + return features + + +def set_bn_momentum(model: nn.Module, momentum: float) -> None: + for m in model.modules(): + if isinstance(m, nn.BatchNorm2d): + m.momentum = momentum diff --git a/torch_uncertainty/models/vgg/base.py b/torch_uncertainty/models/vgg/base.py index 2bb45e5c..7a633b5b 100644 --- a/torch_uncertainty/models/vgg/base.py +++ b/torch_uncertainty/models/vgg/base.py @@ -130,7 +130,7 @@ def forward(self, x: Tensor) -> Tensor: def _vgg( - vgg_cfg: dict[str, list[str | int]], + vgg_cfg: list[str | int], in_channels: int, num_classes: int, linear_layer: type[nn.Module] = nn.Linear, diff --git a/torch_uncertainty/models/vgg/packed.py b/torch_uncertainty/models/vgg/packed.py index 2e76d0d0..f4466e31 100644 --- a/torch_uncertainty/models/vgg/packed.py +++ b/torch_uncertainty/models/vgg/packed.py @@ -6,16 +6,14 @@ from .configs import cfgs __all__ = [ - "packed_vgg11", - "packed_vgg13", - "packed_vgg16", - "packed_vgg19", + "packed_vgg", ] -def packed_vgg11( +def packed_vgg( in_channels: int, num_classes: int, + arch: int, alpha: int, num_estimators: int, gamma: int, @@ -24,89 +22,18 @@ def packed_vgg11( dropout_rate: float = 0.5, style: str = "imagenet", ) -> VGG: + if arch == 11: + config = cfgs["A"] + elif arch == 13: # coverage: ignore + config = cfgs["B"] + elif arch == 16: # coverage: ignore + config = cfgs["D"] + elif arch == 19: # coverage: ignore + config = cfgs["E"] + else: + raise ValueError(f"Unknown VGG arch {arch}.") return _vgg( - cfgs["A"], - in_channels=in_channels, - num_classes=num_classes, - linear_layer=PackedLinear, - conv2d_layer=PackedConv2d, - norm=norm, - groups=groups, - dropout_rate=dropout_rate, - style=style, - alpha=alpha, - num_estimators=num_estimators, - gamma=gamma, - ) - - -def packed_vgg13( - in_channels: int, - num_classes: int, - alpha: int, - num_estimators: int, - gamma: int, - norm: type[nn.Module] = nn.Identity, - groups: int = 1, - dropout_rate: float = 0.5, - style: str = "imagenet", -) -> VGG: # coverage: ignore - return _vgg( - cfgs["B"], - in_channels=in_channels, - num_classes=num_classes, - linear_layer=PackedLinear, - conv2d_layer=PackedConv2d, - norm=norm, - groups=groups, - dropout_rate=dropout_rate, - style=style, - alpha=alpha, - num_estimators=num_estimators, - gamma=gamma, - ) - - -def packed_vgg16( - in_channels: int, - num_classes: int, - alpha: int, - num_estimators: int, - gamma: int, - norm: type[nn.Module] = nn.Identity, - groups: int = 1, - dropout_rate: float = 0.5, - style: str = "imagenet", -) -> VGG: # coverage: ignore - return _vgg( - cfgs["D"], - in_channels=in_channels, - num_classes=num_classes, - linear_layer=PackedLinear, - conv2d_layer=PackedConv2d, - norm=norm, - groups=groups, - dropout_rate=dropout_rate, - style=style, - alpha=alpha, - num_estimators=num_estimators, - gamma=gamma, - ) - - -def packed_vgg19( - in_channels: int, - num_classes: int, - alpha: int, - num_estimators: int, - gamma: int, - norm: type[nn.Module] = nn.Identity, - groups: int = 1, - dropout_rate: float = 0.5, - style: str = "imagenet", -) -> VGG: # coverage: ignore - return _vgg( - cfgs["E"], + vgg_cfg=config, in_channels=in_channels, num_classes=num_classes, linear_layer=PackedLinear, diff --git a/torch_uncertainty/models/vgg/std.py b/torch_uncertainty/models/vgg/std.py index 635e29b7..c41a3b90 100644 --- a/torch_uncertainty/models/vgg/std.py +++ b/torch_uncertainty/models/vgg/std.py @@ -3,83 +3,31 @@ from .base import VGG, _vgg from .configs import cfgs -__all__ = ["vgg11", "vgg13", "vgg16", "vgg19"] +__all__ = ["vgg"] -def vgg11( +def vgg( in_channels: int, num_classes: int, + arch: int, norm: type[nn.Module] = nn.Identity, groups: int = 1, dropout_rate: float = 0.5, style: str = "imagenet", num_estimators: int | None = None, ) -> VGG: + if arch == 11: + config = cfgs["A"] + elif arch == 13: # coverage: ignore + config = cfgs["B"] + elif arch == 16: # coverage: ignore + config = cfgs["D"] + elif arch == 19: # coverage: ignore + config = cfgs["E"] + else: + raise ValueError(f"Unknown VGG arch {arch}.") return _vgg( - cfgs["A"], - in_channels=in_channels, - num_classes=num_classes, - norm=norm, - groups=groups, - dropout_rate=dropout_rate, - style=style, - num_estimators=num_estimators, - ) - - -def vgg13( - in_channels: int, - num_classes: int, - norm: type[nn.Module] = nn.Identity, - groups: int = 1, - dropout_rate: float = 0.5, - style: str = "imagenet", - num_estimators: int | None = None, -) -> VGG: # coverage: ignore - return _vgg( - cfgs["B"], - in_channels=in_channels, - num_classes=num_classes, - norm=norm, - groups=groups, - dropout_rate=dropout_rate, - style=style, - num_estimators=num_estimators, - ) - - -def vgg16( - in_channels: int, - num_classes: int, - norm: type[nn.Module] = nn.Identity, - groups: int = 1, - dropout_rate: float = 0.5, - style: str = "imagenet", - num_estimators: int | None = None, -) -> VGG: # coverage: ignore - return _vgg( - cfgs["D"], - in_channels=in_channels, - num_classes=num_classes, - norm=norm, - groups=groups, - dropout_rate=dropout_rate, - style=style, - num_estimators=num_estimators, - ) - - -def vgg19( - in_channels: int, - num_classes: int, - norm: type[nn.Module] = nn.Identity, - groups: int = 1, - dropout_rate: float = 0.5, - style: str = "imagenet", - num_estimators: int | None = None, -) -> VGG: # coverage: ignore - return _vgg( - cfgs["E"], + vgg_cfg=config, in_channels=in_channels, num_classes=num_classes, norm=norm, diff --git a/torch_uncertainty/models/wideresnet/std.py b/torch_uncertainty/models/wideresnet/std.py index 3e14b2c8..bd3d6a76 100644 --- a/torch_uncertainty/models/wideresnet/std.py +++ b/torch_uncertainty/models/wideresnet/std.py @@ -84,7 +84,7 @@ def __init__( self.dropout_rate = dropout_rate if (depth - 4) % 6 != 0: - raise ValueError("Wide-resnet depth should be 6n+4.") + raise ValueError(f"Wide-resnet depth should be 6n+4. Got {depth}.") num_blocks = int((depth - 4) / 6) k = widen_factor @@ -189,7 +189,7 @@ def feats_forward(self, x: Tensor) -> Tensor: out = self.layer2(out) out = self.layer3(out) out = self.pool(out) - return self.dropout(self.flatten(out)) + return self.flatten(out) def forward(self, x: Tensor) -> Tensor: return self.linear(self.feats_forward(x)) diff --git a/torch_uncertainty/post_processing/calibration/scaler.py b/torch_uncertainty/post_processing/calibration/scaler.py index f4141214..d87730b9 100644 --- a/torch_uncertainty/post_processing/calibration/scaler.py +++ b/torch_uncertainty/post_processing/calibration/scaler.py @@ -68,8 +68,7 @@ def fit( ) with torch.no_grad(): for inputs, labels in tqdm(calibration_dl, disable=not progress): - inputs = inputs.to(self.device) - logits = self.model(inputs) + logits = self.model(inputs.to(self.device)) logits_list.append(logits) labels_list.append(labels) all_logits = torch.cat(logits_list).detach().to(self.device) diff --git a/torch_uncertainty/routines/__init__.py b/torch_uncertainty/routines/__init__.py index 41b7ea80..4c44c49a 100644 --- a/torch_uncertainty/routines/__init__.py +++ b/torch_uncertainty/routines/__init__.py @@ -1,4 +1,5 @@ # ruff: noqa: F401 from .classification import ClassificationRoutine +from .pixel_regression import PixelRegressionRoutine from .regression import RegressionRoutine from .segmentation import SegmentationRoutine diff --git a/torch_uncertainty/routines/classification.py b/torch_uncertainty/routines/classification.py index 29019873..0c07b8ab 100644 --- a/torch_uncertainty/routines/classification.py +++ b/torch_uncertainty/routines/classification.py @@ -20,14 +20,17 @@ from torch_uncertainty.layers import Identity from torch_uncertainty.losses import DECLoss, ELBOLoss from torch_uncertainty.metrics import ( - CE, + AURC, FPR95, BrierScore, + CalibrationError, CategoricalNLL, + CovAt5Risk, Disagreement, Entropy, GroupingLoss, MutualInformation, + RiskAt80Cov, VariationRatio, ) from torch_uncertainty.post_processing import TemperatureScaler @@ -59,9 +62,9 @@ def __init__( log_plots: bool = False, save_in_csv: bool = False, calibration_set: Literal["val", "test"] | None = None, + num_calibration_bins: int = 15, ) -> None: - r"""Routine for efficient training and testing on **classification tasks** - using LightningModule. + r"""Routine for training & testing on **classification tasks**. Args: model (torch.nn.Module): Model to train. @@ -100,10 +103,12 @@ def __init__( metrics. Defaults to ``False``. save_in_csv(bool, optional): Save the results in csv. Defaults to ``False``. - calibration_set (str, optional): The calibration dataset to use for - scaling. If not ``None``, it uses either the validation set when - set to ``"val"`` or the test set when set to ``"test"``. - Defaults to ``None``. + calibration_set (str, optional): The post-hoc calibration dataset to + use for scaling. If not ``None``, it uses either the validation + set when set to ``"val"`` or the test set when set to ``"test"``. + Defaults to ``None``. Else, no post-hoc calibration. + num_calibration_bins (int, optional): Number of bins to compute calibration + metrics. Defaults to ``15``. Warning: You must define :attr:`optim_recipe` if you do not use the CLI. @@ -120,6 +125,7 @@ def __init__( num_estimators=num_estimators, ood_criterion=ood_criterion, eval_grouping_loss=eval_grouping_loss, + num_calibration_bins=num_calibration_bins, ) if format_batch_fn is None: @@ -141,35 +147,44 @@ def __init__( self.optim_recipe = optim_recipe # metrics - if self.binary_cls: - cls_metrics = MetricCollection( - { - "Acc": Accuracy(task="binary"), - "ECE": CE(task="binary"), - "Brier": BrierScore(num_classes=1), - }, - compute_groups=False, - ) - else: - cls_metrics = MetricCollection( - { - "NLL": CategoricalNLL(), - "Acc": Accuracy( - task="multiclass", num_classes=self.num_classes - ), - "ECE": CE(task="multiclass", num_classes=self.num_classes), - "Brier": BrierScore(num_classes=self.num_classes), - }, - compute_groups=False, - ) + task = "binary" if self.binary_cls else "multiclass" + + cls_metrics = MetricCollection( + { + "cls/Acc": Accuracy(task=task, num_classes=num_classes), + "cls/Brier": BrierScore(num_classes=num_classes), + "cls/NLL": CategoricalNLL(), + "cal/ECE": CalibrationError( + task=task, + num_bins=num_calibration_bins, + num_classes=num_classes, + ), + "cal/aECE": CalibrationError( + task=task, + adaptive=True, + num_bins=num_calibration_bins, + num_classes=num_classes, + ), + "sc/AURC": AURC(), + "sc/CovAt5Risk": CovAt5Risk(), + "sc/RiskAt80Cov": RiskAt80Cov(), + }, + compute_groups=[ + ["cls/Acc"], + ["cls/Brier"], + ["cls/NLL"], + ["cal/ECE", "cal/aECE"], + ["sc/AURC", "sc/CovAt5Risk", "sc/RiskAt80Cov"], + ], + ) - self.val_cls_metrics = cls_metrics.clone(prefix="cls_val/") - self.test_cls_metrics = cls_metrics.clone(prefix="cls_test/") + self.val_cls_metrics = cls_metrics.clone(prefix="val/") + self.test_cls_metrics = cls_metrics.clone(prefix="test/") if self.calibration_set is not None: - self.ts_cls_metrics = cls_metrics.clone(prefix="cls_test/ts_") + self.ts_cls_metrics = cls_metrics.clone(prefix="test/ts_") - self.test_entropy_id = Entropy() + self.test_id_entropy = Entropy() if self.eval_ood: ood_metrics = MetricCollection( @@ -181,8 +196,24 @@ def __init__( compute_groups=[["AUROC", "AUPR"], ["FPR95"]], ) self.test_ood_metrics = ood_metrics.clone(prefix="ood/") - self.test_entropy_ood = Entropy() + self.test_ood_entropy = Entropy() + # metrics for ensembles only + if self.num_estimators > 1: + ens_metrics = MetricCollection( + { + "Disagreement": Disagreement(), + "MI": MutualInformation(), + "Entropy": Entropy(), + } + ) + + self.test_id_ens_metrics = ens_metrics.clone(prefix="test/ens_") + + if self.eval_ood: + self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") + + # Mixup self.mixtype = mixtype self.mixmode = mixmode self.dist_sim = dist_sim @@ -199,33 +230,16 @@ def __init__( if self.eval_grouping_loss: grouping_loss = MetricCollection( - {"grouping_loss": GroupingLoss()} - ) - self.val_grouping_loss = grouping_loss.clone(prefix="gpl/val_") - self.test_grouping_loss = grouping_loss.clone( - prefix="gpl/test_" + {"cls/grouping_loss": GroupingLoss()} ) + self.val_grouping_loss = grouping_loss.clone(prefix="val/") + self.test_grouping_loss = grouping_loss.clone(prefix="test/") self.is_elbo = isinstance(self.loss, ELBOLoss) if self.is_elbo: self.loss.set_model(self.model) self.is_dec = isinstance(self.loss, DECLoss) - # metrics for ensembles only - if self.num_estimators > 1: - ens_metrics = MetricCollection( - { - "Disagreement": Disagreement(), - "MI": MutualInformation(), - "Entropy": Entropy(), - } - ) - - self.test_id_ens_metrics = ens_metrics.clone(prefix="cls_test/ens_") - - if self.eval_ood: - self.test_ood_ens_metrics = ens_metrics.clone(prefix="ood/ens_") - self.id_logit_storage = None self.ood_logit_storage = None @@ -276,13 +290,9 @@ def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe def on_train_start(self) -> None: - init_metrics = dict.fromkeys(self.val_cls_metrics, 0) - init_metrics.update(dict.fromkeys(self.test_cls_metrics, 0)) - if self.logger is not None: # coverage: ignore self.logger.log_hyperparams( self.hparams, - init_metrics, ) def on_test_start(self) -> None: @@ -290,15 +300,15 @@ def on_test_start(self) -> None: "val", "test", ]: - dataset = ( + calibration_dataset = ( self.trainer.datamodule.val_dataloader().dataset if self.calibration_set == "val" - else self.trainer.datamodule.test_dataloader().dataset + else self.trainer.datamodule.test_dataloader()[0].dataset ) with torch.inference_mode(False): self.cal_model = TemperatureScaler( model=self.model, device=self.device - ).fit(calibration_set=dataset) + ).fit(calibration_dataset) else: self.cal_model = None @@ -344,21 +354,21 @@ def training_step( else: batch = self.mixup(*batch) - inputs, targets = self.format_batch_fn(batch) + inputs, target = self.format_batch_fn(batch) if self.is_elbo: - loss = self.loss(inputs, targets) + loss = self.loss(inputs, target) else: logits = self.forward(inputs) - # BCEWithLogitsLoss expects float targets + # BCEWithLogitsLoss expects float target if self.binary_cls and isinstance(self.loss, nn.BCEWithLogitsLoss): logits = logits.squeeze(-1) - targets = targets.float() + target = target.float() if not self.is_dec: - loss = self.loss(logits, targets) + loss = self.loss(logits, target) else: - loss = self.loss(logits, targets, self.current_epoch) + loss = self.loss(logits, target, self.current_epoch) self.log("train_loss", loss) return loss @@ -366,7 +376,7 @@ def training_step( def validation_step( self, batch: tuple[Tensor, Tensor], batch_idx: int ) -> None: - inputs, targets = batch + inputs, target = batch logits = self.forward( inputs, save_feats=self.eval_grouping_loss ) # (m*b, c) @@ -378,10 +388,10 @@ def validation_step( probs_per_est = F.softmax(logits, dim=-1) probs = probs_per_est.mean(dim=1) - self.val_cls_metrics.update(probs, targets) + self.val_cls_metrics.update(probs, target) if self.eval_grouping_loss: - self.val_grouping_loss.update(probs, targets, self.features) + self.val_grouping_loss.update(probs, target, self.features) def test_step( self, @@ -389,7 +399,7 @@ def test_step( batch_idx: int, dataloader_idx: int = 0, ) -> None: - inputs, targets = batch + inputs, target = batch logits = self.forward( inputs, save_feats=self.eval_grouping_loss ) # (m*b, c) @@ -429,31 +439,27 @@ def test_step( ood_scores = -confs # Scaling for single models - if ( - self.num_estimators == 1 - and self.calibration_set is not None - and self.cal_model is not None - ): + if self.num_estimators == 1 and self.cal_model is not None: cal_logits = self.cal_model(inputs) cal_probs = F.softmax(cal_logits, dim=-1) - self.ts_cls_metrics.update(cal_probs, targets) + self.ts_cls_metrics.update(cal_probs, target) if dataloader_idx == 0: # squeeze if binary classification only for binary metrics self.test_cls_metrics.update( probs.squeeze(-1) if self.binary_cls else probs, - targets, + target, ) if self.eval_grouping_loss: - self.test_grouping_loss.update(probs, targets, self.features) + self.test_grouping_loss.update(probs, target, self.features) self.log_dict( self.test_cls_metrics, on_epoch=True, add_dataloader_idx=False ) - self.test_entropy_id(probs) + self.test_id_entropy(probs) self.log( - "cls_test/entropy", - self.test_entropy_id, + "test/cls/entropy", + self.test_id_entropy, on_epoch=True, add_dataloader_idx=False, ) @@ -463,18 +469,18 @@ def test_step( if self.eval_ood: self.test_ood_metrics.update( - ood_scores, torch.zeros_like(targets) + ood_scores, torch.zeros_like(target) ) if self.id_logit_storage is not None: self.id_logit_storage.append(logits.detach().cpu()) elif self.eval_ood and dataloader_idx == 1: - self.test_ood_metrics.update(ood_scores, torch.ones_like(targets)) - self.test_entropy_ood(probs) + self.test_ood_metrics.update(ood_scores, torch.ones_like(target)) + self.test_ood_entropy(probs) self.log( - "ood/entropy", - self.test_entropy_ood, + "ood/Entropy", + self.test_ood_entropy, on_epoch=True, add_dataloader_idx=False, ) @@ -485,11 +491,11 @@ def test_step( self.ood_logit_storage.append(logits.detach().cpu()) def on_validation_epoch_end(self) -> None: - self.log_dict(self.val_cls_metrics.compute()) + self.log_dict(self.val_cls_metrics.compute(), sync_dist=True) self.val_cls_metrics.reset() if self.eval_grouping_loss: - self.log_dict(self.val_grouping_loss.compute()) + self.log_dict(self.val_grouping_loss.compute(), sync_dist=True) self.val_grouping_loss.reset() def on_test_epoch_end(self) -> None: @@ -497,7 +503,9 @@ def on_test_epoch_end(self) -> None: result_dict = self.test_cls_metrics.compute() # already logged - result_dict.update({"cls_test/entropy": self.test_entropy_id.compute()}) + result_dict.update( + {"test/Entropy": self.test_id_entropy.compute()}, sync_dist=True + ) if ( self.num_estimators == 1 @@ -505,41 +513,48 @@ def on_test_epoch_end(self) -> None: and self.cal_model is not None ): tmp_metrics = self.ts_cls_metrics.compute() - self.log_dict(tmp_metrics) + self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) - self.ts_cls_metrics.reset() if self.eval_grouping_loss: self.log_dict( self.test_grouping_loss.compute(), + sync_dist=True, ) if self.num_estimators > 1: tmp_metrics = self.test_id_ens_metrics.compute() - self.log_dict(tmp_metrics) + self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) - self.test_id_ens_metrics.reset() if self.eval_ood: tmp_metrics = self.test_ood_metrics.compute() - self.log_dict(tmp_metrics) + self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) - self.test_ood_metrics.reset() # already logged - result_dict.update({"ood/entropy": self.test_entropy_ood.compute()}) + result_dict.update({"ood/Entropy": self.test_ood_entropy.compute()}) if self.num_estimators > 1: tmp_metrics = self.test_ood_ens_metrics.compute() - self.log_dict(tmp_metrics) + self.log_dict(tmp_metrics, sync_dist=True) result_dict.update(tmp_metrics) - self.test_ood_ens_metrics.reset() if isinstance(self.logger, Logger) and self.log_plots: self.logger.experiment.add_figure( - "Calibration Plot", self.test_cls_metrics["ECE"].plot()[0] + "Reliabity diagram", self.test_cls_metrics["cal/ECE"].plot()[0] + ) + self.logger.experiment.add_figure( + "Risk-Coverage curve", + self.test_cls_metrics["sc/AURC"].plot()[0], ) + if self.cal_model is not None: + self.logger.experiment.add_figure( + "Reliabity diagram after calibration", + self.ts_cls_metrics["cal/ECE"].plot()[0], + ) + # plot histograms of logits and likelihoods if self.eval_ood: id_logits = torch.cat(self.id_logit_storage, dim=0) @@ -586,6 +601,7 @@ def _classification_routine_checks( num_estimators: int, ood_criterion: str, eval_grouping_loss: bool, + num_calibration_bins: int, ) -> None: if not isinstance(num_estimators, int) or num_estimators < 1: raise ValueError( @@ -636,3 +652,8 @@ def _classification_routine_checks( "Your model must have a `classification_head` or `linear` " "attribute to compute the grouping loss." ) + + if num_calibration_bins < 2: + raise ValueError( + f"num_calibration_bins must be at least 2, got {num_calibration_bins}." + ) diff --git a/torch_uncertainty/routines/pixel_regression.py b/torch_uncertainty/routines/pixel_regression.py new file mode 100644 index 00000000..d729e11b --- /dev/null +++ b/torch_uncertainty/routines/pixel_regression.py @@ -0,0 +1,321 @@ +from typing import Literal + +import matplotlib.cm as cm +import torch +from einops import rearrange +from lightning.pytorch import LightningModule +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import Tensor, nn +from torch.distributions import ( + Categorical, + Distribution, + Independent, + MixtureSameFamily, +) +from torch.optim import Optimizer +from torchmetrics import MeanSquaredError, MetricCollection +from torchvision.transforms.v2 import functional as F +from torchvision.utils import make_grid + +from torch_uncertainty.metrics import ( + DistributionNLL, + Log10, + MeanAbsoluteErrorInverse, + MeanGTRelativeAbsoluteError, + MeanGTRelativeSquaredError, + MeanSquaredErrorInverse, + MeanSquaredLogError, + SILog, + ThresholdAccuracy, +) +from torch_uncertainty.utils.distributions import dist_rearrange, squeeze_dist + + +class PixelRegressionRoutine(LightningModule): + inv_norm_params = { + "mean": [-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.255], + "std": [1 / 0.229, 1 / 0.224, 1 / 0.255], + } + + def __init__( + self, + model: nn.Module, + output_dim: int, + probabilistic: bool, + loss: nn.Module, + num_estimators: int = 1, + optim_recipe: dict | Optimizer | None = None, + format_batch_fn: nn.Module | None = None, + num_image_plot: int = 4, + ) -> None: + super().__init__() + _depth_routine_checks(num_estimators, output_dim) + + self.model = model + self.output_dim = output_dim + self.one_dim_depth = output_dim == 1 + self.probabilistic = probabilistic + self.loss = loss + self.num_estimators = num_estimators + self.num_image_plot = num_image_plot + + if format_batch_fn is None: + format_batch_fn = nn.Identity() + + self.optim_recipe = optim_recipe + self.format_batch_fn = format_batch_fn + + depth_metrics = MetricCollection( + { + "SILog": SILog(), + "log10": Log10(), + "ARE": MeanGTRelativeAbsoluteError(), + "RSRE": MeanGTRelativeSquaredError(squared=False), + "RMSE": MeanSquaredError(squared=False), + "RMSELog": MeanSquaredLogError(squared=False), + "iMAE": MeanAbsoluteErrorInverse(), + "iRMSE": MeanSquaredErrorInverse(squared=False), + "d1": ThresholdAccuracy(power=1), + "d2": ThresholdAccuracy(power=2), + "d3": ThresholdAccuracy(power=3), + }, + compute_groups=False, + ) + + self.val_metrics = depth_metrics.clone(prefix="val/") + self.test_metrics = depth_metrics.clone(prefix="test/") + + if self.probabilistic: + depth_prob_metrics = MetricCollection( + {"NLL": DistributionNLL(reduction="mean")} + ) + self.val_prob_metrics = depth_prob_metrics.clone(prefix="val/") + self.test_prob_metrics = depth_prob_metrics.clone(prefix="test/") + + def configure_optimizers(self) -> Optimizer | dict: + return self.optim_recipe + + def on_train_start(self) -> None: + if self.logger is not None: # coverage: ignore + self.logger.log_hyperparams( + self.hparams, + ) + + def forward(self, inputs: Tensor) -> Tensor | Distribution: + """Forward pass of the routine. + + The forward pass automatically squeezes the output if the regression + is one-dimensional and if the routine contains a single model. + + Args: + inputs (Tensor): The input tensor. + + Returns: + Tensor: The output tensor. + """ + pred = self.model(inputs) + if self.probabilistic: + if self.num_estimators == 1: + pred = squeeze_dist(pred, -1) + else: + if self.num_estimators == 1: + pred = pred.squeeze(-1) + return pred + + def training_step( + self, batch: tuple[Tensor, Tensor], batch_idx: int + ) -> STEP_OUTPUT: + inputs, target = self.format_batch_fn(batch) + if self.one_dim_depth: + target = target.unsqueeze(1) + + dists = self.model(inputs) + target = F.resize( + target, dists.shape[-2:], interpolation=F.InterpolationMode.NEAREST + ) + valid_mask = ~torch.isnan(target) + loss = self.loss(dists[valid_mask], target[valid_mask]) + self.log("train_loss", loss) + return loss + + def validation_step( + self, batch: tuple[Tensor, Tensor], batch_idx: int + ) -> None: + inputs, target = batch + if self.one_dim_depth: + target = target.unsqueeze(1) + preds = self.model(inputs) + + if self.probabilistic: + ens_dist = Independent( + dist_rearrange( + preds, "(m b) c h w -> b m c h w", m=self.num_estimators + ), + 1, + ) + mix = Categorical( + torch.ones(self.num_estimators, device=self.device) + ) + mixture = MixtureSameFamily(mix, ens_dist) + preds = mixture.mean + else: + preds = rearrange( + preds, "(m b) c h w -> b m c h w", m=self.num_estimators + ) + preds = preds.mean(dim=1) + + if batch_idx == 0: + self._plot_depth( + inputs[: self.num_image_plot, ...], + preds[: self.num_image_plot, ...], + target[: self.num_image_plot, ...], + stage="val", + ) + + valid_mask = ~torch.isnan(target) + self.val_metrics.update(preds[valid_mask], target[valid_mask]) + if self.probabilistic: + self.val_prob_metrics.update( + mixture[valid_mask], target[valid_mask] + ) + + def test_step( + self, + batch: tuple[Tensor, Tensor], + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + if dataloader_idx != 0: + raise NotImplementedError( + "Depth OOD detection not implemented yet. Raise an issue " + "if needed." + ) + + inputs, target = batch + if self.one_dim_depth: + target = target.unsqueeze(1) + preds = self.model(inputs) + + if self.probabilistic: + ens_dist = dist_rearrange( + preds, "(m b) c h w -> b m c h w", m=self.num_estimators + ) + mix = Categorical( + torch.ones(self.num_estimators, device=self.device) + ) + mixture = MixtureSameFamily(mix, ens_dist) + self.test_metrics.nll.update(mixture, target) + preds = mixture.mean + else: + preds = rearrange( + preds, "(m b) c h w -> b m c h w", m=self.num_estimators + ) + preds = preds.mean(dim=1) + + if batch_idx == 0: + num_images = ( + self.num_image_plot + if self.num_image_plot < inputs.size(0) + else inputs.size(0) + ) + self._plot_depth( + inputs[:num_images, ...], + preds[:num_images, ...], + target[:num_images, ...], + stage="test", + ) + + valid_mask = ~torch.isnan(target) + self.test_metrics.update(preds[valid_mask], target[valid_mask]) + if self.probabilistic: + self.test_prob_metrics.update( + mixture[valid_mask], target[valid_mask] + ) + + def on_validation_epoch_end(self) -> None: + self.log_dict(self.val_metrics.compute(), sync_dist=True) + self.val_metrics.reset() + if self.probabilistic: + self.log_dict( + self.val_prob_metrics.compute(), + sync_dist=True, + ) + self.val_prob_metrics.reset() + + def on_test_epoch_end(self) -> None: + self.log_dict( + self.test_metrics.compute(), + sync_dist=True, + ) + self.test_metrics.reset() + if self.probabilistic: + self.log_dict( + self.test_prob_metrics.compute(), + sync_dist=True, + ) + self.test_prob_metrics.reset() + + def _plot_depth( + self, + inputs: Tensor, + preds: Tensor, + target: Tensor, + stage: Literal["val", "test"], + ) -> None: + if ( + self.logger is not None + and isinstance(self.logger, TensorBoardLogger) + and self.one_dim_depth + ): + all_imgs = [] + for i in range(inputs.size(0)): + img = F.normalize(inputs[i, ...].cpu(), **self.inv_norm_params) + pred = colorize( + preds[i, 0, ...].cpu(), vmin=0, vmax=self.model.max_depth + ) + tgt = colorize( + target[i, 0, ...].cpu(), vmin=0, vmax=self.model.max_depth + ) + all_imgs.extend([img, pred, tgt]) + + self.logger.experiment.add_image( + f"{stage}/samples", + make_grid(torch.stack(all_imgs, dim=0), nrow=3), + self.current_epoch, + ) + + +def colorize( + value: Tensor, + vmin: float | None = None, + vmax: float | None = None, + cmap: str = "magma", +): + """Colorize a tensor of depth values. + + Args: + value (Tensor): The tensor of depth values. + vmin (float, optional): The minimum depth value. Defaults to None. + vmax (float, optional): The maximum depth value. Defaults to None. + cmap (str, optional): The colormap to use. Defaults to 'magma'. + """ + vmin = value.min().item() if vmin is None else vmin + vmax = value.max().item() if vmax is None else vmax + if vmin == vmax: + return torch.zeros_like(value) + value = (value - vmin) / (vmax - vmin) + cmapper = cm.get_cmap(cmap) + value = cmapper(value.numpy(), bytes=True) + img = value[:, :, :3] + return torch.as_tensor(img).permute(2, 0, 1).float() / 255.0 + + +def _depth_routine_checks(num_estimators: int, output_dim: int) -> None: + if num_estimators < 1: + raise ValueError( + f"num_estimators must be positive, got {num_estimators}." + ) + + if output_dim < 1: + raise ValueError(f"output_dim must be positive, got {output_dim}.") diff --git a/torch_uncertainty/routines/regression.py b/torch_uncertainty/routines/regression.py index 3124856d..55998518 100644 --- a/torch_uncertainty/routines/regression.py +++ b/torch_uncertainty/routines/regression.py @@ -5,13 +5,16 @@ from torch import Tensor, nn from torch.distributions import ( Categorical, + Distribution, Independent, MixtureSameFamily, ) from torch.optim import Optimizer from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection -from torch_uncertainty.metrics.regression.nll import DistributionNLL +from torch_uncertainty.metrics import ( + DistributionNLL, +) from torch_uncertainty.utils.distributions import dist_rearrange, squeeze_dist @@ -26,8 +29,7 @@ def __init__( optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, ) -> None: - r"""Routine for efficient training and testing on **regression tasks** - using LightningModule. + r"""Routine for training & testing on **regression tasks**. Args: model (torch.nn.Module): Model to train. @@ -79,15 +81,15 @@ def __init__( compute_groups=True, ) - self.val_metrics = reg_metrics.clone(prefix="reg_val/") - self.test_metrics = reg_metrics.clone(prefix="reg_test/") + self.val_metrics = reg_metrics.clone(prefix="val/") + self.test_metrics = reg_metrics.clone(prefix="test/") if self.probabilistic: reg_prob_metrics = MetricCollection( {"NLL": DistributionNLL(reduction="mean")} ) - self.val_prob_metrics = reg_prob_metrics.clone(prefix="reg_val/") - self.test_prob_metrics = reg_prob_metrics.clone(prefix="reg_test/") + self.val_prob_metrics = reg_prob_metrics.clone(prefix="val/") + self.test_prob_metrics = reg_prob_metrics.clone(prefix="test/") self.one_dim_regression = output_dim == 1 @@ -95,19 +97,12 @@ def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe def on_train_start(self) -> None: - init_metrics = dict.fromkeys(self.val_metrics, 0) - init_metrics.update(dict.fromkeys(self.test_metrics, 0)) - if self.probabilistic: - init_metrics.update(dict.fromkeys(self.val_prob_metrics, 0)) - init_metrics.update(dict.fromkeys(self.test_prob_metrics, 0)) - if self.logger is not None: # coverage: ignore self.logger.log_hyperparams( self.hparams, - init_metrics, ) - def forward(self, inputs: Tensor) -> Tensor: + def forward(self, inputs: Tensor) -> Tensor | Distribution: """Forward pass of the routine. The forward pass automatically squeezes the output if the regression @@ -173,15 +168,6 @@ def validation_step( if self.probabilistic: self.val_prob_metrics.update(mixture, targets) - def on_validation_epoch_end(self) -> None: - self.log_dict(self.val_metrics.compute()) - self.val_metrics.reset() - if self.probabilistic: - self.log_dict( - self.val_prob_metrics.compute(), - ) - self.val_prob_metrics.reset() - def test_step( self, batch: tuple[Tensor, Tensor], @@ -219,6 +205,13 @@ def test_step( if self.probabilistic: self.test_prob_metrics.update(mixture, targets) + def on_validation_epoch_end(self) -> None: + self.log_dict(self.val_metrics.compute(), sync_dist=True) + self.val_metrics.reset() + if self.probabilistic: + self.log_dict(self.val_prob_metrics.compute(), sync_dist=True) + self.val_prob_metrics.reset() + def on_test_epoch_end(self) -> None: self.log_dict( self.test_metrics.compute(), diff --git a/torch_uncertainty/routines/segmentation.py b/torch_uncertainty/routines/segmentation.py index a3227dcf..c4fc02f2 100644 --- a/torch_uncertainty/routines/segmentation.py +++ b/torch_uncertainty/routines/segmentation.py @@ -1,5 +1,7 @@ +import torch from einops import rearrange from lightning.pytorch import LightningModule +from lightning.pytorch.loggers import Logger from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import Tensor, nn from torch.optim import Optimizer @@ -7,8 +9,9 @@ from torchvision.transforms.v2 import functional as F from torch_uncertainty.metrics import ( - CE, + AURC, BrierScore, + CalibrationError, CategoricalNLL, MeanIntersectionOverUnion, ) @@ -23,9 +26,11 @@ def __init__( num_estimators: int = 1, optim_recipe: dict | Optimizer | None = None, format_batch_fn: nn.Module | None = None, + metric_subsampling_rate: float = 1e-2, + log_plots: bool = False, + num_calibration_bins: int = 15, ) -> None: - """Routine for efficient training and testing on **segmentation tasks** - using LightningModule. + """Routine for training & testing on **segmentation tasks**. Args: model (torch.nn.Module): Model to train. @@ -37,10 +42,15 @@ def __init__( optionally the scheduler to use. Defaults to ``None``. format_batch_fn (torch.nn.Module, optional): The function to format the batch. Defaults to ``None``. + metric_subsampling_rate (float, optional): The rate of subsampling for the + memory consuming metrics. Defaults to ``1e-2``. + log_plots (bool, optional): Indicates whether to log plots from + metrics. Defaults to ``False` + num_calibration_bins (int, optional): Number of bins to compute calibration + metrics. Defaults to ``15``. Warning: - You must define :attr:`optim_recipe` if you do not use - the CLI. + You must define :attr:`optim_recipe` if you do not use the CLI. Note: :attr:`optim_recipe` can be anything that can be returned by @@ -48,7 +58,12 @@ def __init__( `here `_. """ super().__init__() - _segmentation_routine_checks(num_estimators, num_classes) + _segmentation_routine_checks( + num_estimators, + num_classes, + metric_subsampling_rate, + num_calibration_bins, + ) self.model = model self.num_classes = num_classes @@ -60,33 +75,61 @@ def __init__( self.optim_recipe = optim_recipe self.format_batch_fn = format_batch_fn + self.metric_subsampling_rate = metric_subsampling_rate + self.log_plots = log_plots # metrics seg_metrics = MetricCollection( { - "Acc": Accuracy(task="multiclass", num_classes=num_classes), - "ECE": CE(task="multiclass", num_classes=num_classes), - "mIoU": MeanIntersectionOverUnion(num_classes=num_classes), - "Brier": BrierScore(num_classes=num_classes), - "NLL": CategoricalNLL(), + "seg/mIoU": MeanIntersectionOverUnion(num_classes=num_classes), + }, + compute_groups=False, + ) + sbsmpl_seg_metrics = MetricCollection( + { + "seg/mAcc": Accuracy( + task="multiclass", average="macro", num_classes=num_classes + ), + "seg/Brier": BrierScore(num_classes=num_classes), + "seg/NLL": CategoricalNLL(), + "seg/pixAcc": Accuracy( + task="multiclass", num_classes=num_classes + ), + "cal/ECE": CalibrationError( + task="multiclass", + num_classes=num_classes, + num_bins=num_calibration_bins, + ), + "cal/aECE": CalibrationError( + task="multiclass", + adaptive=True, + num_bins=num_calibration_bins, + num_classes=num_classes, + ), + "sc/AURC": AURC(), }, - compute_groups=[["Acc", "mIoU"], ["ECE"], ["Brier"], ["NLL"]], + compute_groups=False, ) - self.val_seg_metrics = seg_metrics.clone(prefix="seg_val/") - self.test_seg_metrics = seg_metrics.clone(prefix="seg_test/") + self.val_seg_metrics = seg_metrics.clone(prefix="val/") + self.val_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="val/") + self.test_seg_metrics = seg_metrics.clone(prefix="test/") + self.test_sbsmpl_seg_metrics = sbsmpl_seg_metrics.clone(prefix="test/") def configure_optimizers(self) -> Optimizer | dict: return self.optim_recipe - def forward(self, img: Tensor) -> Tensor: - return self.model(img) + def forward(self, inputs: Tensor) -> Tensor: + """Forward pass of the model. - def on_train_start(self) -> None: - init_metrics = dict.fromkeys(self.val_seg_metrics, 0) - init_metrics.update(dict.fromkeys(self.test_seg_metrics, 0)) + Args: + inputs (torch.Tensor): Input tensor. + """ + return self.model(inputs) - self.logger.log_hyperparams(self.hparams, init_metrics) + def on_train_start(self) -> None: + if self.logger is not None: # coverage: ignore + self.logger.log_hyperparams(self.hparams) def training_step( self, batch: tuple[Tensor, Tensor], batch_idx: int @@ -119,7 +162,9 @@ def validation_step( probs = probs_per_est.mean(dim=1) target = target.flatten() valid_mask = target != 255 - self.val_seg_metrics.update(probs[valid_mask], target[valid_mask]) + probs, target = probs[valid_mask], target[valid_mask] + self.val_seg_metrics.update(probs, target) + self.val_sbsmpl_seg_metrics.update(*self.subsample(probs, target)) def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: img, target = batch @@ -134,18 +179,42 @@ def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> None: probs = probs_per_est.mean(dim=1) target = target.flatten() valid_mask = target != 255 - self.test_seg_metrics.update(probs[valid_mask], target[valid_mask]) + probs, target = probs[valid_mask], target[valid_mask] + self.test_seg_metrics.update(probs, target) + self.test_sbsmpl_seg_metrics.update(*self.subsample(probs, target)) def on_validation_epoch_end(self) -> None: - self.log_dict(self.val_seg_metrics.compute()) + self.log_dict(self.val_seg_metrics.compute(), sync_dist=True) + self.log_dict(self.val_sbsmpl_seg_metrics.compute(), sync_dist=True) self.val_seg_metrics.reset() + self.val_sbsmpl_seg_metrics.reset() def on_test_epoch_end(self) -> None: - self.log_dict(self.test_seg_metrics.compute()) - self.test_seg_metrics.reset() - - -def _segmentation_routine_checks(num_estimators: int, num_classes: int) -> None: + self.log_dict(self.test_seg_metrics.compute(), sync_dist=True) + self.log_dict(self.test_sbsmpl_seg_metrics.compute(), sync_dist=True) + if isinstance(self.logger, Logger) and self.log_plots: + self.logger.experiment.add_figure( + "Reliabity diagram", + self.test_sbsmpl_seg_metrics["cal/ECE"].plot()[0], + ) + self.logger.experiment.add_figure( + "Risk-Coverage curve", + self.test_sbsmpl_seg_metrics["sc/AURC"].plot()[0], + ) + + def subsample(self, pred: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: + total_size = target.size(0) + num_samples = max(1, int(total_size * self.metric_subsampling_rate)) + indices = torch.randperm(total_size, device=pred.device)[:num_samples] + return pred[indices], target[indices] + + +def _segmentation_routine_checks( + num_estimators: int, + num_classes: int, + metric_subsampling_rate: float, + num_calibration_bins: int, +) -> None: if num_estimators < 1: raise ValueError( f"num_estimators must be positive, got {num_estimators}." @@ -153,3 +222,13 @@ def _segmentation_routine_checks(num_estimators: int, num_classes: int) -> None: if num_classes < 2: raise ValueError(f"num_classes must be at least 2, got {num_classes}.") + + if not 0 < metric_subsampling_rate <= 1: + raise ValueError( + f"metric_subsampling_rate must be in the range (0, 1], got {metric_subsampling_rate}." + ) + + if num_calibration_bins < 2: + raise ValueError( + f"num_calibration_bins must be at least 2, got {num_calibration_bins}." + ) diff --git a/torch_uncertainty/transforms/batch.py b/torch_uncertainty/transforms/batch.py index 600cea3d..dd96bba7 100644 --- a/torch_uncertainty/transforms/batch.py +++ b/torch_uncertainty/transforms/batch.py @@ -25,7 +25,9 @@ def __init__(self, num_repeats: int) -> None: def forward(self, batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]: inputs, targets = batch - return inputs, targets.repeat_interleave(self.num_repeats, dim=0) + return inputs, targets.repeat( + self.num_repeats, *[1] * (targets.ndim - 1) + ) class MIMOBatchFormat(nn.Module): diff --git a/torch_uncertainty/utils/__init__.py b/torch_uncertainty/utils/__init__.py index de0547c7..885f2dd0 100644 --- a/torch_uncertainty/utils/__init__.py +++ b/torch_uncertainty/utils/__init__.py @@ -4,3 +4,4 @@ from .hub import load_hf from .misc import create_train_val_split, csv_writer, plot_hist from .trainer import TUTrainer +from .transforms import interpolation_modes_from_str diff --git a/torch_uncertainty/utils/distributions.py b/torch_uncertainty/utils/distributions.py index ab2336bc..58f8532b 100644 --- a/torch_uncertainty/utils/distributions.py +++ b/torch_uncertainty/utils/distributions.py @@ -62,11 +62,11 @@ def squeeze_dist(distribution: Distribution, dim: int) -> Distribution: Distribution: The squeezed distribution. """ dist_type = type(distribution) - if isinstance(distribution, Normal | Laplace): + if dist_type in (Normal, Laplace): loc = distribution.loc.squeeze(dim) scale = distribution.scale.squeeze(dim) return dist_type(loc=loc, scale=scale) - if isinstance(distribution, NormalInverseGamma): + if dist_type == NormalInverseGamma: loc = distribution.loc.squeeze(dim) lmbda = distribution.lmbda.squeeze(dim) alpha = distribution.alpha.squeeze(dim) @@ -82,11 +82,11 @@ def dist_rearrange( distribution: Distribution, pattern: str, **axes_lengths: int ) -> Distribution: dist_type = type(distribution) - if isinstance(distribution, Normal | Laplace): + if dist_type in (Normal, Laplace): loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) scale = rearrange(distribution.scale, pattern=pattern, **axes_lengths) return dist_type(loc=loc, scale=scale) - if isinstance(distribution, NormalInverseGamma): + if dist_type == NormalInverseGamma: loc = rearrange(distribution.loc, pattern=pattern, **axes_lengths) lmbda = rearrange(distribution.lmbda, pattern=pattern, **axes_lengths) alpha = rearrange(distribution.alpha, pattern=pattern, **axes_lengths) @@ -141,17 +141,17 @@ def mean(self) -> Tensor: def mode(self) -> None: raise NotImplementedError( - "Mode is not meaningful for the NormalInverseGamma distribution" + "NormalInverseGamma distribution has no mode." ) def stddev(self) -> None: raise NotImplementedError( - "Standard deviation is not meaningful for the NormalInverseGamma distribution" + "NormalInverseGamma distribution has no stddev." ) def variance(self) -> None: raise NotImplementedError( - "Variance is not meaningful for the NormalInverseGamma distribution" + "NormalInverseGamma distribution has no variance." ) @property diff --git a/torch_uncertainty/utils/hub.py b/torch_uncertainty/utils/hub.py index 83ef289e..b48bc324 100644 --- a/torch_uncertainty/utils/hub.py +++ b/torch_uncertainty/utils/hub.py @@ -51,5 +51,4 @@ def load_hf(weight_id: str, version: int = 0) -> tuple[torch.Tensor, dict]: # Load the config config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml") config = yaml.safe_load(Path(config_path).read_text()) - return weight, config diff --git a/torch_uncertainty/utils/learning_rate.py b/torch_uncertainty/utils/learning_rate.py new file mode 100644 index 00000000..1e6aef85 --- /dev/null +++ b/torch_uncertainty/utils/learning_rate.py @@ -0,0 +1,30 @@ +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler + + +class PolyLR(LRScheduler): + def __init__( + self, + optimizer: Optimizer, + total_iters: int, + power: float = 0.9, + last_epoch: int = -1, + min_lr: float = 1e-6, + ) -> None: + self.power = power + self.total_iters = total_iters + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> list[float]: + return self._get_closed_form_lr() + + def _get_closed_form_lr(self) -> list[float]: + return [ + max( + base_lr + * (1 - self.last_epoch / self.total_iters) ** self.power, + self.min_lr, + ) + for base_lr in self.base_lrs + ] diff --git a/torch_uncertainty/utils/transforms.py b/torch_uncertainty/utils/transforms.py new file mode 100644 index 00000000..6c755aff --- /dev/null +++ b/torch_uncertainty/utils/transforms.py @@ -0,0 +1,14 @@ +from torchvision.transforms import InterpolationMode + + +def interpolation_modes_from_str(val: str) -> InterpolationMode: + val = val.lower() + inverse_modes_mapping = { + "nearest": InterpolationMode.NEAREST, + "bilinear": InterpolationMode.BILINEAR, + "bicubic": InterpolationMode.BICUBIC, + "box": InterpolationMode.BOX, + "hamming": InterpolationMode.HAMMING, + "lanczos": InterpolationMode.LANCZOS, + } + return inverse_modes_mapping[val]