diff --git a/src/lightning_utilities/__about__.py b/src/lightning_utilities/__about__.py index 781b80dd..f49ae695 100644 --- a/src/lightning_utilities/__about__.py +++ b/src/lightning_utilities/__about__.py @@ -1,6 +1,6 @@ import time -__version__ = "0.11.4" +__version__ = "0.11.5" __author__ = "Lightning AI et al." __author_email__ = "pytorch@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/lightning_utilities/core/imports.py b/src/lightning_utilities/core/imports.py index 5362845b..d8435966 100644 --- a/src/lightning_utilities/core/imports.py +++ b/src/lightning_utilities/core/imports.py @@ -7,7 +7,7 @@ import os import warnings from functools import lru_cache -from importlib.metadata import PackageNotFoundError +from importlib.metadata import PackageNotFoundError, distribution from importlib.metadata import version as _version from importlib.util import find_spec from types import ModuleType @@ -128,7 +128,9 @@ def _check_requirement(self) -> None: try: req = Requirement(self.requirement) pkg_version = Version(_version(req.name)) - self.available = req.specifier.contains(pkg_version) + self.available = req.specifier.contains(pkg_version) and ( + not req.extras or self._check_extras_available(req) + ) except (PackageNotFoundError, InvalidVersion) as ex: self.available = False self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`" @@ -143,6 +145,9 @@ def _check_requirement(self) -> None: self.available = module_available(module) if self.available: self.message = f"Module {module!r} available" + self.message = ( + f"Requirement {self.requirement!r} not met. HINT: Try running `pip install -U {self.requirement!r}`" + ) def _check_module(self) -> None: assert self.module # noqa: S101; needed for typing @@ -160,6 +165,34 @@ def _check_available(self) -> None: if getattr(self, "available", True) and self.module: self._check_module() + def _check_extras_available(self, requirement: Requirement) -> bool: + if not requirement.extras: + return True + + extra_requirements = self._get_extra_requirements(requirement) + + if not extra_requirements: + # The specified extra is not found in the package metadata + return False + + # Verify each extra requirement is installed + for extra_req in extra_requirements: + try: + extra_dist = distribution(extra_req.name) + extra_installed_version = Version(extra_dist.version) + if extra_req.specifier and not extra_req.specifier.contains(extra_installed_version): + return False + except importlib.metadata.PackageNotFoundError: + return False + + return True + + def _get_extra_requirements(self, requirement: Requirement) -> List[Requirement]: + dist = distribution(requirement.name) + # Get the required dependencies for the specified extras + extra_requirements = dist.metadata.get_all("Requires-Dist") or [] + return [Requirement(r) for r in extra_requirements if any(extra in r for extra in requirement.extras)] + def __bool__(self) -> bool: """Format as bool.""" self._check_available() diff --git a/tests/unittests/core/test_imports.py b/tests/unittests/core/test_imports.py index 7e3998f9..2c406feb 100644 --- a/tests/unittests/core/test_imports.py +++ b/tests/unittests/core/test_imports.py @@ -1,5 +1,7 @@ import operator import re +from unittest import mock +from unittest.mock import Mock import pytest from lightning_utilities.core.imports import ( @@ -61,6 +63,41 @@ def test_requirement_cache(): assert not cache assert "pip install -U 'this_module_is_not_installed" in str(cache) + cache = RequirementCache("pytest[not-valid-extra]") + assert not cache + assert "pip install -U 'pytest[not-valid-extra]" in str(cache) + + +@mock.patch("lightning_utilities.core.imports.Requirement") +@mock.patch("lightning_utilities.core.imports._version") +@mock.patch("lightning_utilities.core.imports.distribution") +def test_requirement_cache_with_extras(distribution_mock, version_mock, requirement_mock): + requirement_mock().specifier.contains.return_value = True + requirement_mock().name = "jsonargparse" + requirement_mock().extras = [] + version_mock.return_value = "1.0.0" + assert RequirementCache("jsonargparse>=1.0.0") + + with mock.patch("lightning_utilities.core.imports.RequirementCache._get_extra_requirements") as get_extra_req_mock: + get_extra_req_mock.return_value = [ + # Extra packages, all versions satisfied + Mock(name="extra_package1", specifier=Mock(contains=Mock(return_value=True))), + Mock(name="extra_package2", specifier=Mock(contains=Mock(return_value=True))), + ] + distribution_mock.return_value = Mock(version="0.10.0") + requirement_mock().extras = ["signatures"] + assert RequirementCache("jsonargparse[signatures]>=1.0.0") + + with mock.patch("lightning_utilities.core.imports.RequirementCache._get_extra_requirements") as get_extra_req_mock: + get_extra_req_mock.return_value = [ + # Extra packages, but not all versions are satisfied + Mock(name="extra_package1", specifier=Mock(contains=Mock(return_value=True))), + Mock(name="extra_package2", specifier=Mock(contains=Mock(return_value=False))), + ] + distribution_mock.return_value = Mock(version="0.10.0") + requirement_mock().extras = ["signatures"] + assert not RequirementCache("jsonargparse[signatures]>=1.0.0") + def test_module_available_cache(): assert RequirementCache(module="pytest")