Incorrect type hints for log_dict(MetricCollection())? #18641

Closed
adamjstewart opened this issue Sep 26, 2023 · 3 comments · Fixed by #18646
Labels
bug, code quality, good first issue, help wanted, ver: 2.0.x

Comments

@adamjstewart
Contributor

adamjstewart commented Sep 26, 2023

Bug description

I'm using TorchMetrics to compute various metrics in my LightningModule. However, the type hints of LightningModule.log_dict don't seem to allow passing a MetricCollection to it directly.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

```python
from lightning.pytorch import LightningModule
from torch import Tensor
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
class MyModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()  # required before assigning submodules such as self.metrics
        self.metrics = MetricCollection([MulticlassAccuracy(2), MulticlassJaccardIndex(2)])

    def training_step(self, batch: Tensor, batch_idx: int) -> Tensor:
        x, y = batch
        preds: Tensor = self(x)
        self.metrics(preds, y)
        self.log_dict(self.metrics)
        return preds
```

Error messages and logs

```
> mypy --strict test.py
test.py:14: error: Argument 1 to "log_dict" of "LightningModule" has incompatible type "MetricCollection"; expected "Mapping[str, Union[Metric, Tensor, Union[int, float]]]"  [arg-type]
Found 1 error in 1 file (checked 1 source file)
```
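Why mypy rejects the call: typing.Mapping is checked nominally, not structurally, and MetricCollection subclasses torch.nn.ModuleDict rather than Mapping. A dict-like object that provides items() and __getitem__ therefore still fails a Mapping[str, ...] annotation, even though iterating over it works at runtime. The stdlib-only sketch below illustrates this, assuming hypothetical stand-in classes StubMetric and StubCollection (not the real TorchMetrics types) and hypothetical functions log_dict_narrow and log_dict_wide (not Lightning's API).

```python
from collections.abc import Mapping as AbcMapping
from typing import Dict, Iterator, List, Mapping, Tuple, Union


class StubMetric:
    """Hypothetical stand-in for a TorchMetrics Metric."""

    def __init__(self, value: float) -> None:
        self.value = value


class StubCollection:
    """Hypothetical stand-in for MetricCollection: dict-like, but not a Mapping subclass."""

    def __init__(self, metrics: Dict[str, StubMetric]) -> None:
        self._metrics = dict(metrics)

    def __getitem__(self, key: str) -> StubMetric:
        return self._metrics[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._metrics)

    def __len__(self) -> int:
        return len(self._metrics)

    def items(self) -> List[Tuple[str, StubMetric]]:
        return list(self._metrics.items())


def log_dict_narrow(dictionary: Mapping[str, StubMetric]) -> List[str]:
    # Mapping-only annotation, as in the error message: mypy rejects a StubCollection here.
    return [name for name, _ in dictionary.items()]


def log_dict_wide(dictionary: Union[Mapping[str, StubMetric], StubCollection]) -> List[str]:
    # Widened annotation: plain mappings and the collection both type-check.
    return [name for name, _ in dictionary.items()]


collection = StubCollection({"acc": StubMetric(0.9), "iou": StubMetric(0.5)})
print(isinstance(collection, AbcMapping))  # False
# Runs fine (duck typing), but mypy reports [arg-type] without the ignore comment.
print(log_dict_narrow(collection))  # type: ignore[arg-type]
print(log_dict_wide(collection))  # ['acc', 'iou']
```

Both functions behave identically at runtime; only the widened annotation lets the collection pass the type check.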

Environment

Current environment
  • CUDA:
    • GPU: None
    • available: False
    • version: None
  • Lightning:
    • efficientnet-pytorch: 0.7.1
    • lightning: 2.0.9
    • lightning-cloud: 0.5.38
    • lightning-utilities: 0.8.0
    • pytorch-lightning: 2.0.0
    • pytorch-sphinx-theme: 0.0.24
    • segmentation-models-pytorch: 0.3.3
    • torch: 2.0.1
    • torchmetrics: 1.1.1
    • torchvision: 0.15.2
  • Packages:
    • absl-py: 1.4.0
    • aenum: 3.1.12
    • affine: 2.1.0
    • aiohttp: 3.8.4
    • aiosignal: 1.2.0
    • alabaster: 0.7.13
    • antlr4-python3-runtime: 4.9.3
    • anyio: 3.6.2
    • appdirs: 1.4.4
    • appnope: 0.1.3
    • argon2-cffi: 21.3.0
    • argon2-cffi-bindings: 21.2.0
    • arrow: 1.2.3
    • asttokens: 2.2.1
    • astunparse: 1.6.3
    • async-lru: 1.0.3
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • babel: 2.12.1
    • backcall: 0.2.0
    • backoff: 2.2.1
    • beautifulsoup4: 4.12.2
    • black: 23.9.1
    • bleach: 6.0.0
    • blessed: 1.19.0
    • bottleneck: 1.3.7
    • build: 1.0.3
    • cachetools: 5.2.0
    • cartopy: 0.22.0
    • certifi: 2023.5.7
    • cffi: 1.15.1
    • cftime: 1.0.3.4
    • charset-normalizer: 3.1.0
    • click: 8.1.3
    • click-plugins: 1.1.1
    • cligj: 0.7.2
    • cmocean: 2.0
    • colorama: 0.4.6
    • comm: 0.1.3
    • contourpy: 1.0.7
    • coverage: 7.2.6
    • croniter: 1.3.8
    • cycler: 0.11.0
    • cython: 0.29.36
    • dateutils: 0.6.12
    • debugpy: 1.6.7
    • decorator: 5.1.1
    • deepdiff: 6.3.0
    • defusedxml: 0.7.1
    • docstring-parser: 0.15
    • docutils: 0.18.1
    • editables: 0.3
    • efficientnet-pytorch: 0.7.1
    • einops: 0.6.1
    • et-xmlfile: 1.0.1
    • executing: 1.2.0
    • fastapi: 0.98.0
    • fastjsonschema: 2.16.3
    • filelock: 3.12.0
    • fiona: 1.9.4
    • flake8: 6.1.0
    • flit-core: 3.9.0
    • fonttools: 4.39.4
    • fqdn: 1.5.1
    • frozenlist: 1.3.1
    • fsspec: 2023.1.0
    • gdal: 3.7.2
    • geocube: 0.3.2
    • geopandas: 0.11.1
    • gevent: 23.7.0
    • google-auth: 2.20.0
    • google-auth-oauthlib: 0.5.2
    • greenlet: 2.0.2
    • grpcio: 1.52.0
    • h11: 0.13.0
    • h5py: 3.8.0
    • hatch-jupyter-builder: 0.8.3
    • hatchling: 1.17.0
    • huggingface-hub: 0.14.1
    • hydra-core: 1.3.1
    • idna: 3.4
    • imagesize: 1.4.1
    • importlib-metadata: 6.6.0
    • importlib-resources: 5.12.0
    • iniconfig: 2.0.0
    • inquirer: 3.1.3
    • ipykernel: 6.23.1
    • ipython: 8.14.0
    • ipywidgets: 8.0.2
    • isoduration: 20.11.0
    • isort: 5.12.0
    • itsdangerous: 2.1.2
    • jaraco.classes: 3.2.3
    • jedi: 0.18.2
    • jinja2: 3.0.3
    • joblib: 1.2.0
    • json5: 0.9.14
    • jsonargparse: 4.25.0
    • jsonpointer: 2.0
    • jsonschema: 4.17.3
    • jupyter-client: 8.2.0
    • jupyter-core: 5.3.0
    • jupyter-events: 0.6.3
    • jupyter-lsp: 2.2.0
    • jupyter-server: 2.6.0
    • jupyter-server-terminals: 0.4.4
    • jupyterlab: 4.0.1
    • jupyterlab-pygments: 0.2.2
    • jupyterlab-server: 2.22.1
    • jupyterlab-widgets: 3.0.3
    • keyring: 23.13.1
    • kiwisolver: 1.4.4
    • kornia: 0.7.0
    • laspy: 2.2.0
    • lightly: 1.4.18
    • lightly-utils: 0.0.2
    • lightning: 2.0.9
    • lightning-cloud: 0.5.38
    • lightning-utilities: 0.8.0
    • markdown: 3.4.1
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • matplotlib: 3.8.0
    • matplotlib-inline: 0.1.6
    • mccabe: 0.7.0
    • mdurl: 0.1.2
    • mistune: 2.0.5
    • more-itertools: 9.1.0
    • mpmath: 1.2.1
    • multidict: 6.0.4
    • munch: 2.5.0
    • mypy: 1.3.0
    • mypy-extensions: 1.0.0
    • nbclient: 0.6.7
    • nbconvert: 7.4.0
    • nbformat: 5.8.0
    • nbmake: 1.4.3
    • nbsphinx: 0.8.8
    • nest-asyncio: 1.5.6
    • netcdf4: 1.6.2
    • networkx: 3.1
    • notebook-shim: 0.2.3
    • numexpr: 2.8.4
    • numpy: 1.25.2
    • oauthlib: 3.2.1
    • odc-geo: 0.1.2
    • omegaconf: 2.3.0
    • openpyxl: 3.1.2
    • ordered-set: 4.0.2
    • overrides: 7.3.1
    • packaging: 23.1
    • pandas: 2.0.2
    • pandocfilters: 1.5.0
    • parso: 0.8.3
    • pathspec: 0.11.1
    • pexpect: 4.8.0
    • pickleshare: 0.7.5
    • pillow: 9.5.0
    • pip: 23.0
    • pkginfo: 1.9.6
    • planetary-computer: 0.4.9
    • platformdirs: 3.5.3
    • pluggy: 1.0.0
    • poetry-core: 1.6.1
    • pretrainedmodels: 0.7.4
    • prometheus-client: 0.17.0
    • prompt-toolkit: 3.0.38
    • protobuf: 3.20.3
    • psutil: 5.9.5
    • ptyprocess: 0.7.0
    • pure-eval: 0.2.2
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.8
    • pybind11: 2.10.1
    • pycocotools: 2.0.6
    • pycodestyle: 2.11.0
    • pycparser: 2.21
    • pydantic: 1.10.9
    • pydocstyle: 6.2.1
    • pyflakes: 3.1.0
    • pygeos: 0.10
    • pygments: 2.15.1
    • pyjwt: 2.4.0
    • pyparsing: 3.0.9
    • pyproj: 3.6.0
    • pyproject-hooks: 1.0.0
    • pyrsistent: 0.19.3
    • pyshp: 2.1.0
    • pystac: 1.4.0
    • pystac-client: 0.5.1
    • pytest: 7.3.2
    • pytest-cov: 4.0.0
    • python-dateutil: 2.8.2
    • python-dotenv: 0.19.2
    • python-editor: 1.0.4
    • python-json-logger: 2.0.7
    • python-multipart: 0.0.5
    • pytorch-lightning: 2.0.0
    • pytorch-sphinx-theme: 0.0.24
    • pytz: 2023.3
    • pyupgrade: 3.3.1
    • pyyaml: 6.0
    • pyzmq: 25.0.2
    • radiant-mlhub: 0.5.1
    • rarfile: 4.1
    • rasterio: 1.3.8
    • readchar: 4.0.5
    • readme-renderer: 37.3
    • requests: 2.31.0
    • requests-oauthlib: 1.3.1
    • requests-toolbelt: 1.0.0
    • rfc3339-validator: 0.1.4
    • rfc3986: 2.0.0
    • rfc3986-validator: 0.1.1
    • rich: 13.4.2
    • rioxarray: 0.4.1.post0
    • rsa: 4.9
    • rtree: 1.0.1
    • safetensors: 0.3.1
    • scikit-learn: 1.3.1
    • scipy: 1.10.1
    • segmentation-models-pytorch: 0.3.3
    • send2trash: 1.8.0
    • setuptools: 63.4.3
    • setuptools-scm: 7.1.0
    • shapely: 1.8.4
    • six: 1.16.0
    • sniffio: 1.3.0
    • snowballstemmer: 2.2.0
    • snuggs: 1.4.1
    • soupsieve: 2.4.1
    • sphinx: 5.3.0
    • sphinx-design: 0.4.1
    • sphinx-rtd-theme: 1.2.2
    • sphinxcontrib-applehelp: 1.0.2
    • sphinxcontrib-devhelp: 1.0.2
    • sphinxcontrib-htmlhelp: 2.0.0
    • sphinxcontrib-jquery: 4.1
    • sphinxcontrib-jsmath: 1.0.1
    • sphinxcontrib-programoutput: 0.15
    • sphinxcontrib-qthelp: 1.0.3
    • sphinxcontrib-serializinghtml: 1.1.9
    • stack-data: 0.6.2
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • sympy: 1.11.1
    • tensorboard: 2.13.0
    • tensorboard-data-server: 0.7.0
    • tensorboard-plugin-wit: 1.8.1
    • terminado: 0.17.1
    • threadpoolctl: 3.1.0
    • timm: 0.9.2
    • tinycss2: 1.1.1
    • tokenize-rt: 4.2.1
    • tomli: 2.0.1
    • torch: 2.0.1
    • torchmetrics: 1.1.1
    • torchvision: 0.15.2
    • tornado: 6.2
    • tqdm: 4.65.0
    • traitlets: 5.9.0
    • trove-classifiers: 2023.3.9
    • twine: 4.0.2
    • typeshed-client: 2.1.0
    • typing-extensions: 4.6.3
    • tzdata: 2023.3
    • uri-template: 1.2.0
    • urllib3: 1.26.12
    • uvicorn: 0.20.0
    • vermin: 1.5.2
    • wcwidth: 0.2.5
    • webcolors: 1.11.1
    • webencodings: 0.5.1
    • websocket-client: 1.5.1
    • websockets: 10.4
    • werkzeug: 2.3.4
    • wheel: 0.41.2
    • widgetsnbextension: 4.0.3
    • xarray: 2023.7.0
    • yarl: 1.8.1
    • zipfile-deflate64: 0.2.0
    • zipp: 3.8.1
    • zope.event: 4.6
    • zope.interface: 5.4.0
  • System:
    • OS: Darwin
    • architecture:
      • 64bit
    • processor: arm
    • python: 3.11.4
    • release: 22.6.0
    • version: Darwin Kernel Version 22.6.0: Wed Jul 5 22:21:53 PDT 2023; root:xnu-8796.141.3~6/RELEASE_ARM64_T6020

cc @Borda

@adamjstewart adamjstewart added the bug and needs triage labels Sep 26, 2023
@awaelchli
Contributor

@adamjstewart I think adding MetricCollection to the Union type of log_dict's input argument should fix the issue.

@awaelchli awaelchli added the help wanted, good first issue, and code quality labels and removed the needs triage label Sep 26, 2023
@adamjstewart
Contributor Author

Alright, I can submit a PR to do that. I was originally going to add it to _METRIC, but that would break other things like log. I agree it makes more sense to do Union[_METRIC, MetricCollection].

@adamjstewart
Contributor Author

Oh wait no, it would be Union[Mapping[...], MetricCollection].
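The two comments above contrast two ways of widening the annotation. The sketch below is a simplified, stdlib-only illustration of the difference, assuming hypothetical stand-in classes StubMetric and StubCollection (not the real TorchMetrics classes), a simplified _METRIC alias without Tensor, and a hypothetical alias name _LOG_DICT_INPUT. It is not Lightning's implementation or the exact change made in #18646.

```python
from typing import Dict, List, Mapping, Tuple, Union


class StubMetric:
    """Hypothetical stand-in for torchmetrics.Metric."""

    def __init__(self, value: float) -> None:
        self.value = value

    def compute(self) -> float:
        return self.value


class StubCollection:
    """Hypothetical stand-in for torchmetrics.MetricCollection (not a Mapping subclass)."""

    def __init__(self, metrics: Dict[str, StubMetric]) -> None:
        self._metrics = dict(metrics)

    def items(self) -> List[Tuple[str, StubMetric]]:
        return list(self._metrics.items())


# Simplified single-value alias (the one in the error message also includes Tensor).
_METRIC = Union[StubMetric, int, float]

# Rejected idea: adding the collection to _METRIC itself would also let
# log(name, collection) type-check, even though log handles one value at a time.
# Adopted idea: widen only log_dict's parameter (hypothetical alias name).
_LOG_DICT_INPUT = Union[Mapping[str, _METRIC], StubCollection]


def log(name: str, value: _METRIC, logged: Dict[str, float]) -> None:
    # Store one scalar under `name`; a collection is not a valid input here.
    logged[name] = value.compute() if isinstance(value, StubMetric) else float(value)


def log_dict(dictionary: _LOG_DICT_INPUT, logged: Dict[str, float]) -> None:
    # Both members of the Union provide .items(), so one loop handles either input.
    for name, value in dictionary.items():
        log(name, value, logged)


logged: Dict[str, float] = {}
log_dict(StubCollection({"acc": StubMetric(0.75)}), logged)
log_dict({"loss": 0.5}, logged)
print(logged)  # {'acc': 0.75, 'loss': 0.5}
```

Keeping _METRIC unchanged means log's signature stays strict, while log_dict accepts either a plain mapping or the collection.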
