Skip to content

Commit

Permalink
Merge branch 'add_micro_average_IOU' of https://github.com/razmikmeli…
Browse files Browse the repository at this point in the history
…kbekyan/metrics into add_micro_average_IOU
  • Loading branch information
SkafteNicki committed May 25, 2022
2 parents 344c10c + 1008e88 commit 5435cec
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/docs-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ jobs:

- name: Install dependencies
run: |
pip install -r requirements.txt -U -f https://download.pytorch.org/whl/cpu/torch_stable.html -q --use-feature=2020-resolver
pip install -r requirements/docs.txt --use-feature=2020-resolver
pip install -r requirements.txt -U -f https://download.pytorch.org/whl/cpu/torch_stable.html -q
pip install -r requirements/docs.txt
python --version ; pip --version ; pip list
shell: bash

Expand Down Expand Up @@ -62,8 +62,8 @@ jobs:

- name: Install dependencies
run: |
pip install -r requirements.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q --use-feature=2020-resolver
pip install -r requirements/docs.txt --use-feature=2020-resolver
pip install -r requirements.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q
pip install -r requirements/docs.txt
# install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux
sudo apt-get update
sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures
Expand Down
18 changes: 10 additions & 8 deletions torchmetrics/functional/classification/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,23 @@ def _jaccard_from_confmat(
]
)
return scores
elif average == "macro":

if average == "macro":
scores = _jaccard_from_confmat(
confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score
)
return torch.mean(scores)
elif average == "micro":

if average == "micro":
intersection = torch.sum(torch.diag(confmat))
union = torch.sum(torch.sum(confmat, dim=1) + torch.sum(confmat, dim=0) - torch.diag(confmat))
return intersection.float() / union.float()
else:
weights = torch.sum(confmat, dim=1).float() / torch.sum(confmat).float()
scores = _jaccard_from_confmat(
confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score
)
return torch.sum(weights * scores)

weights = torch.sum(confmat, dim=1).float() / torch.sum(confmat).float()
scores = _jaccard_from_confmat(
confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score
)
return torch.sum(weights * scores)


def jaccard_index(
Expand Down

0 comments on commit 5435cec

Please sign in to comment.