Skip to content

Commit

Permalink
Merge branch 'master' into psnrb
Browse files Browse the repository at this point in the history
  • Loading branch information
soma2000-lang authored Jan 4, 2023
2 parents f487bad + 107dbfd commit cdd5b91
Show file tree
Hide file tree
Showing 40 changed files with 222 additions and 194 deletions.
16 changes: 16 additions & 0 deletions .azure/gpu-pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ jobs:

variables:
DEVICES: $( python -c 'name = "$(Agent.Name)" ; gpus = name.split("_")[-1] if "_" in name else "0,1"; print(gpus)' )
HF_CACHE_DIR: "$(Pipeline.Workspace)/ci-cache_huggingface"

container:
image: "$(docker-image)"
Expand Down Expand Up @@ -95,6 +96,19 @@ jobs:
python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'"
displayName: 'Sanity check'
- task: Cache@2
inputs:
key: transformers | "$(Agent.OS)"
restoreKeys: transformers
path: $(HF_CACHE_DIR)
cacheHitVar: HF_CACHE_RESTORED
- bash: |
printf "cache location: $(HF_CACHE_DIR)\n"
printf "hit the HF cache: $(variables.HF_CACHE_RESTORED)\n"
mkdir -p $(HF_CACHE_DIR) # in case cache was void
ls -lh $(HF_CACHE_DIR) # show what was restored...
displayName: 'Show HF cache'
- bash: python -m pytest torchmetrics --cov=torchmetrics --timeout=120 --durations=50
workingDirectory: src
displayName: 'DocTesting'
Expand All @@ -120,6 +134,8 @@ jobs:
python -m codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) --flags=gpu,unittest --name="GPU-coverage" --env=linux,azure
ls -l
workingDirectory: tests
env:
TRANSFORMERS_CACHE: $(HF_CACHE_DIR)
displayName: 'Statistics'
- task: PublishTestResults@2
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/publish-pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ jobs:
# We do this, since failures on test.pypi aren't that bad
- name: Publish to Test PyPI
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
uses: pypa/gh-action-pypi-publish@v1.5.2
uses: pypa/gh-action-pypi-publish@v1.6.4
with:
user: __token__
password: ${{ secrets.test_pypi_password }}
Expand All @@ -82,7 +82,7 @@ jobs:

- name: Publish distribution 📦 to PyPI
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
uses: pypa/gh-action-pypi-publish@v1.5.2
uses: pypa/gh-action-pypi-publish@v1.6.4
with:
user: __token__
password: ${{ secrets.pypi_password }}
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ repos:
name: Upgrade code

- repo: https://github.com/PyCQA/docformatter
rev: v1.5.0
rev: v1.5.1
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]

- repo: https://github.com/PyCQA/isort
rev: 5.11.2
rev: 5.11.4
hooks:
- id: isort
name: imports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

############################################
Error Relative Global Dim. Synthesis (ERGAS)
############################################
Expand All @@ -12,6 +14,8 @@ ________________

.. autoclass:: torchmetrics.image.ergas.ErrorRelativeGlobalDimensionlessSynthesis
:noindex:
:exclude-members: update, compute


Functional Interface
____________________
Expand Down
1 change: 1 addition & 0 deletions docs/source/image/frechet_inception_distance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ ________________

.. autoclass:: torchmetrics.image.fid.FrechetInceptionDistance
:noindex:
:exclude-members: update, compute
3 changes: 3 additions & 0 deletions docs/source/image/inception_score.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

###############
Inception Score
###############
Expand All @@ -12,3 +14,4 @@ ________________

.. autoclass:: torchmetrics.image.inception.InceptionScore
:noindex:
:exclude-members: update, compute
1 change: 1 addition & 0 deletions docs/source/image/kernel_inception_distance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ ________________

.. autoclass:: torchmetrics.image.kid.KernelInceptionDistance
:noindex:
:exclude-members: update, compute
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ ________________

.. autoclass:: torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity
:noindex:
:exclude-members: update, compute
1 change: 1 addition & 0 deletions docs/source/image/multi_scale_structural_similarity.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.MultiScaleStructuralSimilarityIndexMeasure
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
1 change: 1 addition & 0 deletions docs/source/image/peak_signal_noise_ratio.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.PeakSignalNoiseRatio
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
1 change: 1 addition & 0 deletions docs/source/image/spectral_angle_mapper.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ________________

.. autoclass:: torchmetrics.SpectralAngleMapper
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
1 change: 1 addition & 0 deletions docs/source/image/spectral_distortion_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.SpectralDistortionIndex
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
1 change: 1 addition & 0 deletions docs/source/image/structural_similarity.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.StructuralSimilarityIndexMeasure
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
1 change: 1 addition & 0 deletions docs/source/image/total_variation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.TotalVariation
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
1 change: 1 addition & 0 deletions docs/source/image/universal_image_quality_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ ________________

.. autoclass:: torchmetrics.UniversalImageQualityIndex
:noindex:
:exclude-members: update, compute

Functional Interface
____________________
Expand Down
10 changes: 10 additions & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,15 @@
.. _Scale-invariant signal-to-noise ratio: https://arxiv.org/abs/1711.00541
.. _Signal-to-noise ratio: https://arxiv.org/abs/1811.02508
.. _Permutation invariant training: https://arxiv.org/abs/1607.00325
.. _ranking ref1: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
.. _Spectral Distortion Index: https://www.ingentaconnect.com/content/asprs/pers/2008/00000074/00000002/art00003;jsessionid=nzjnb3v9xxr1.x-ic-live-03
.. _Relative dimensionless global error synthesis: https://ieeexplore.ieee.org/document/4317530
.. _fid ref1: https://arxiv.org/abs/1512.00567
.. _fid ref2: https://arxiv.org/abs/1706.08500
.. _inception ref1: https://arxiv.org/abs/1606.03498
.. _inception ref2: https://arxiv.org/abs/1706.08500
.. _kid ref1: https://arxiv.org/abs/1801.01401
.. _kid ref2: https://arxiv.org/abs/1706.08500
.. _Spectral Angle Mapper: https://ntrs.nasa.gov/citations/19940012238
.. _Multilabel coverage error: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
.. _Peak Signal to Noise Ratio with Blocked Effect:https://ieeexplore.ieee.org/abstract/document/5535179
22 changes: 11 additions & 11 deletions docs/source/pages/lightning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ The example below shows how to use a metric in your `LightningModule <https://py

class MyModel(LightningModule):

def __init__(self):
def __init__(self, num_classes):
...
self.accuracy = torchmetrics.Accuracy(task='multiclass')
self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

def training_step(self, batch, batch_idx):
x, y = batch
Expand Down Expand Up @@ -78,10 +78,10 @@ value by calling ``.compute()``.

class MyModule(LightningModule):

def __init__(self):
def __init__(self, num_classes):
...
self.train_acc = torchmetrics.Accuracy(task='multiclass')
self.valid_acc = torchmetrics.Accuracy(task='multiclass')
self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

def training_step(self, batch, batch_idx):
x, y = batch
Expand All @@ -105,8 +105,8 @@ of the metrics.

def __init__(self):
...
self.train_acc = torchmetrics.Accuracy(task='multiclass')
self.valid_acc = torchmetrics.Accuracy(task='multiclass')
self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

def training_step(self, batch, batch_idx):
x, y = batch
Expand Down Expand Up @@ -141,9 +141,9 @@ mixed as it can lead to wrong results.

class MyModule(LightningModule):

def __init__(self):
def __init__(self, num_classes):
...
self.valid_acc = torchmetrics.Accuracy(task='multiclass')
self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

def validation_step(self, batch, batch_idx):
logits = self(x)
Expand Down Expand Up @@ -185,9 +185,9 @@ The following contains a list of pitfalls to be aware of:

class MyModule(LightningModule):

def __init__(self):
def __init__(self, num_classes):
...
self.val_acc = nn.ModuleList([torchmetrics.Accuracy(task='multiclass') for _ in range(2)])
self.val_acc = nn.ModuleList([torchmetrics.Accuracy(task="multiclass", num_classes=num_classes) for _ in range(2)])

def val_dataloader(self):
return [DataLoader(...), DataLoader(...)]
Expand Down
4 changes: 2 additions & 2 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ inside your LightningModule. In most cases we just have to replace ``self.log``
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall

class MyModule(LightningModule):
def __init__(self):
def __init__(self, num_classes):
metrics = MetricCollection([
MulticlassAccuracy(), MulticlassPrecision(), MulticlassRecall()
MulticlassAccuracy(num_classes), MulticlassPrecision(num_classes), MulticlassRecall(num_classes)
])
self.train_metrics = metrics.clone(prefix='train_')
self.valid_metrics = metrics.clone(prefix='val_')
Expand Down
4 changes: 2 additions & 2 deletions docs/source/pages/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ The code-snippet below shows a simple example for calculating the accuracy using
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target, task='multiclass', num_classes=5)
acc = torchmetrics.functional.accuracy(preds, target, task="multiclass", num_classes=5)

Module metrics
~~~~~~~~~~~~~~
Expand All @@ -86,7 +86,7 @@ The code below shows how to use the class-based interface:
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy(task='multiclass', num_classes=5)
metric = torchmetrics.Accuracy(task="multiclass", num_classes=5)

n_batches = 10
for i in range(n_batches):
Expand Down
1 change: 0 additions & 1 deletion src/torchmetrics/detection/mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,6 @@ def _move_list_states_to_cpu(self) -> None:
setattr(self, key, current_to_cpu)

def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]:

if self.iou_type == "bbox":
boxes = _fix_empty_tensors(item["boxes"])
if boxes.numel() > 0:
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/image/d_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def spectral_distortion_index(
p: int = 1,
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Tensor:
"""Spectral Distortion Index (SpectralDistortionIndex_) also now as D_lambda is used to compare the spectral
distortion between two images.
"""Calculates `Spectral Distortion Index`_ (SpectralDistortionIndex_) also known as D_lambda that is used to
compare the spectral distortion between two images.
Args:
preds: Low resolution multispectral image
Expand Down
24 changes: 11 additions & 13 deletions src/torchmetrics/image/d_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ class SpectralDistortionIndex(Metric):
"""Computes Spectral Distortion Index (SpectralDistortionIndex_) also now as D_lambda is used to compare the
spectral distortion between two images.
As input to ``forward`` and ``update`` the metric accepts the following input
- ``preds`` (:class:`~torch.Tensor`): Low resolution multispectral image of shape ``(N,C,H,W)``
- ``target``(:class:`~torch.Tensor`): High resolution fused image of shape ``(N,C,H,W)``
As output of `forward` and `compute` the metric returns the following output
- ``sdi`` (:class:`~torch.Tensor`): if ``reduction!='none'`` returns float scalar tensor with average SDI value
over sample else returns tensor of shape ``(N,)`` with SDI values per sample
Args:
p: Large spectral differences
reduction: a method to reduce metric score over labels.
Expand All @@ -36,7 +46,6 @@ class SpectralDistortionIndex(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> _ = torch.manual_seed(42)
Expand All @@ -46,12 +55,6 @@ class SpectralDistortionIndex(Metric):
>>> sdi = SpectralDistortionIndex()
>>> sdi(preds, target)
tensor(0.0234)
References:
[1] Alparone, Luciano & Aiazzi, Bruno & Baronti, Stefano & Garzelli, Andrea & Nencini,
Filippo & Selva, Massimo. (2008). Multispectral and Panchromatic Data Fusion
Assessment Without Reference. ASPRS Journal of Photogrammetric Engineering
and Remote Sensing. 74. 193-200. 10.14358/PERS.74.2.193.
"""

higher_is_better: bool = True
Expand Down Expand Up @@ -82,12 +85,7 @@ def __init__(
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with preds and target.
Args:
preds: Low resolution multispectral image
target: High resolution fused image
"""
"""Update state with preds and target."""
preds, target = _spectral_distortion_index_update(preds, target)
self.preds.append(preds)
self.target.append(target)
Expand Down
Loading

0 comments on commit cdd5b91

Please sign in to comment.