Skip to content

Commit

Permalink
Merge branch 'master' into feature/hinge_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Mar 25, 2021
2 parents ccb8839 + 26eae39 commit bdae664
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 7 deletions.
76 changes: 76 additions & 0 deletions .github/workflows/ci_test-conda.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
name: PyTorch & Conda

# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
on: # Trigger the workflow on push or pull request, but only for the master branch
push:
branches: [master, "release/*"]
pull_request:
branches: [master, "release/*"]

jobs:
conda:
runs-on: ubuntu-20.04
strategy:
fail-fast: false
matrix:
python-version: [3.7]
pytorch-version: [1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9]

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
steps:
- uses: actions/checkout@v2

- name: Cache conda
uses: actions/cache@v2
with:
path: ~/conda_pkgs_dir
key: conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-${{ hashFiles('environment.yml') }}
restore-keys: conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-

# Add another cache for Pip as not all packages lives in Conda env
- name: Cache pip
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: pip-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-${{ hashFiles('requirements/base.txt') }}
restore-keys: pip-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}-

# https://docs.conda.io/projects/conda/en/4.6.0/_downloads/52a95608c49671267e40c689e0bc00ca/conda-cheatsheet.pdf
# https://gist.github.com/mwouts/9842452d020c08faf9e84a3bba38a66f
- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2
with:
miniconda-version: "4.7.12"
python-version: ${{ matrix.python-version }}
channels: conda-forge,pytorch,pytorch-test,pytorch-nightly
channel-priority: true
auto-activate-base: true
# environment-file: ./environment.yml
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!

- name: Update Environment
run: |
conda info
conda install pytorch=${{ matrix.pytorch-version }} cpuonly
conda list
pip --version
pip install --requirement requirements.txt --upgrade-strategy only-if-needed --quiet
pip install --requirement requirements/test.txt --upgrade-strategy only-if-needed --quiet
pip list
python -c "import torch; assert torch.__version__[:3] == '${{ matrix.pytorch-version }}', torch.__version__"
shell: bash -l {0}

- name: Testing
run: |
# NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003
python -m pytest torchmetrics tests -v --durations=35 --junitxml=junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
shell: bash -l {0}

- name: Upload pytest test results
uses: actions/upload-artifact@master
with:
name: test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
path: junit/test-conda-py${{ matrix.python-version }}-pt${{ matrix.pytorch-version }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: failure()
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ prune notebook*
prune temp*
prune test*
prune benchmark*
prune integration*
2 changes: 1 addition & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,5 @@ jobs:
condition: succeededOrFailed()

- bash: |
python -m pytest integrations --durations=25
python -m pytest integrations -v --durations=25
displayName: 'Integrations'
3 changes: 3 additions & 0 deletions integrations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from torchmetrics.utilities.imports import _module_available

_PL_AVAILABLE = _module_available('pytorch_lightning')
1 change: 1 addition & 0 deletions integrations/lightning_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset
Expand Down
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ exclude_lines =

[flake8]
max-line-length = 120
exclude = .tox,*.egg,build,temp
exclude =
*.egg
build
temp
select = E,W,F
doctests = True
verbose = 2
Expand Down
10 changes: 5 additions & 5 deletions tests/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch import Tensor, tensor

from tests.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
from tests.classification.inputs import _input_multiclass_prob as _input_mccls_prob
from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.classification.inputs import _input_multilabel as _input_mcls
Expand Down Expand Up @@ -104,8 +104,8 @@ def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, i
["macro", None, None, _input_binary, None],
["micro", None, None, _input_mdmc_prob, None],
["micro", None, None, _input_binary_prob, 0],
["micro", None, None, _input_mccls_prob, NUM_CLASSES],
["micro", None, NUM_CLASSES, _input_mccls_prob, NUM_CLASSES],
["micro", None, None, _input_mcls_prob, NUM_CLASSES],
["micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES],
],
)
def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
Expand Down Expand Up @@ -141,8 +141,8 @@ def test_wrong_threshold():
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None),
(_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mccls_prob.preds, _input_mccls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2),
(_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None),
(_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None),
(
Expand Down
58 changes: 58 additions & 0 deletions torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,64 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
from importlib import import_module
from importlib.util import find_spec

import torch
from pkg_resources import DistributionNotFound


def _module_available(module_path: str) -> bool:
"""
Check if a path is available in your environment
>>> _module_available('os')
True
>>> _module_available('bla.bla')
False
"""
try:
return find_spec(module_path) is not None
except AttributeError:
# Python 3.6
return False
except ModuleNotFoundError:
# Python 3.7+
return False


def _compare_version(package: str, op, version) -> bool:
"""
Compare package version with some requirements
>>> import operator
>>> _compare_version("torch", operator.ge, "0.1")
True
"""
try:
pkg = import_module(package)
except (ModuleNotFoundError, DistributionNotFound):
return False
try:
pkg_version = LooseVersion(pkg.__version__)
except AttributeError:
return False
if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")):
# this is mock by sphinx, so it shall return True ro generate all summaries
return True
return op(pkg_version, LooseVersion(version))


_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0")
_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0")
Expand Down

0 comments on commit bdae664

Please sign in to comment.