From 63c86bf952172cc479c3ce1b2e0d202e2d1f4b57 Mon Sep 17 00:00:00 2001 From: David Hotham Date: Sun, 26 Jun 2022 17:40:13 +0100 Subject: [PATCH] refactor the search for direct-origin dependencies (#5904) * refactor the search for direct-origin dependencies * clarify where the interface is * fix treatment of dependency via cache --- src/poetry/console/commands/show.py | 15 +--- src/poetry/puzzle/provider.py | 120 ++++++++++++---------------- tests/puzzle/test_provider.py | 36 ++++----- 3 files changed, 72 insertions(+), 99 deletions(-) diff --git a/src/poetry/console/commands/show.py b/src/poetry/console/commands/show.py index 1218af67255..14dbe6bda9f 100644 --- a/src/poetry/console/commands/show.py +++ b/src/poetry/console/commands/show.py @@ -1,13 +1,9 @@ from __future__ import annotations from typing import TYPE_CHECKING -from typing import cast from cleo.helpers import argument from cleo.helpers import option -from poetry.core.packages.directory_dependency import DirectoryDependency -from poetry.core.packages.file_dependency import FileDependency -from poetry.core.packages.vcs_dependency import VCSDependency from poetry.console.commands.group_command import GroupCommand from poetry.utils.helpers import canonicalize_name @@ -499,16 +495,7 @@ def find_latest_package( for dep in requires: if dep.name == package.name: provider = Provider(root, self.poetry.pool, NullIO()) - - if dep.is_vcs(): - dep = cast(VCSDependency, dep) - return provider.search_for_vcs(dep)[0] - if dep.is_file(): - dep = cast(FileDependency, dep) - return provider.search_for_file(dep)[0] - if dep.is_directory(): - dep = cast(DirectoryDependency, dep) - return provider.search_for_directory(dep)[0] + return provider.search_for_direct_origin_dependency(dep) name = package.name selector = VersionSelector(self.poetry.pool) diff --git a/src/poetry/puzzle/provider.py b/src/poetry/puzzle/provider.py index 7b75211cd84..bdaeb1fb607 100644 --- a/src/poetry/puzzle/provider.py +++ b/src/poetry/puzzle/provider.py @@ -221,6 +221,44 @@ def search_for_installed_packages( ) return packages + def search_for_direct_origin_dependency(self, dependency: Dependency) -> Package: + package = self._deferred_cache.get(dependency) + if package is not None: + pass + + elif dependency.is_vcs(): + dependency = cast(VCSDependency, dependency) + package = self._search_for_vcs(dependency) + + elif dependency.is_file(): + dependency = cast(FileDependency, dependency) + package = self._search_for_file(dependency) + + elif dependency.is_directory(): + dependency = cast(DirectoryDependency, dependency) + package = self._search_for_directory(dependency) + + elif dependency.is_url(): + dependency = cast(URLDependency, dependency) + package = self._search_for_url(dependency) + + else: + raise RuntimeError( + f"Unknown direct dependency type {dependency.source_type}" + ) + + if dependency.is_vcs(): + dependency._source_reference = package.source_reference + dependency._source_resolved_reference = package.source_resolved_reference + dependency._source_subdirectory = package.source_subdirectory + + dependency._constraint = package.version + dependency._pretty_constraint = package.version.text + + self._deferred_cache[dependency] = package + + return package + def search_for(self, dependency: Dependency) -> list[DependencyPackage]: """ Search for the specifications that match the given dependency. @@ -231,18 +269,9 @@ def search_for(self, dependency: Dependency) -> list[DependencyPackage]: if dependency.is_root: return PackageCollection(dependency, [self._package]) - if dependency.is_vcs(): - dependency = cast(VCSDependency, dependency) - packages = self.search_for_vcs(dependency) - elif dependency.is_file(): - dependency = cast(FileDependency, dependency) - packages = self.search_for_file(dependency) - elif dependency.is_directory(): - dependency = cast(DirectoryDependency, dependency) - packages = self.search_for_directory(dependency) - elif dependency.is_url(): - dependency = cast(URLDependency, dependency) - packages = self.search_for_url(dependency) + if dependency.is_direct_origin(): + packages = [self.search_for_direct_origin_dependency(dependency)] + else: packages = self._pool.find_packages(dependency) @@ -259,7 +288,7 @@ def search_for(self, dependency: Dependency) -> list[DependencyPackage]: return PackageCollection(dependency, packages) - def search_for_vcs(self, dependency: VCSDependency) -> list[Package]: + def _search_for_vcs(self, dependency: VCSDependency) -> Package: """ Search for the specifications that match the given VCS dependency. @@ -281,16 +310,7 @@ def search_for_vcs(self, dependency: VCSDependency) -> list[Package]: package.develop = dependency.develop - dependency._constraint = package.version - dependency._pretty_constraint = package.version.text - - dependency._source_reference = package.source_reference - dependency._source_resolved_reference = package.source_resolved_reference - dependency._source_subdirectory = package.source_subdirectory - - self._deferred_cache[dependency] = package - - return [package] + return package @staticmethod def get_package_from_vcs( @@ -314,18 +334,8 @@ def get_package_from_vcs( source_root=source_root, ) - def search_for_file(self, dependency: FileDependency) -> list[Package]: - if dependency in self._deferred_cache: - _package = self._deferred_cache[dependency] - - package = _package.clone() - else: - package = self.get_package_from_file(dependency.full_path) - - dependency._constraint = package.version - dependency._pretty_constraint = package.version.text - - self._deferred_cache[dependency] = package + def _search_for_file(self, dependency: FileDependency) -> Package: + package = self.get_package_from_file(dependency.full_path) self.validate_package_for_dependency(dependency=dependency, package=package) @@ -336,7 +346,7 @@ def search_for_file(self, dependency: FileDependency) -> list[Package]: {"file": dependency.path.name, "hash": "sha256:" + dependency.hash()} ] - return [package] + return package @classmethod def get_package_from_file(cls, file_path: Path) -> Package: @@ -351,18 +361,8 @@ def get_package_from_file(cls, file_path: Path) -> Package: return package - def search_for_directory(self, dependency: DirectoryDependency) -> list[Package]: - if dependency in self._deferred_cache: - _package = self._deferred_cache[dependency] - - package = _package.clone() - else: - package = self.get_package_from_directory(dependency.full_path) - - dependency._constraint = package.version - dependency._pretty_constraint = package.version.text - - self._deferred_cache[dependency] = package + def _search_for_directory(self, dependency: DirectoryDependency) -> Package: + package = self.get_package_from_directory(dependency.full_path) self.validate_package_for_dependency(dependency=dependency, package=package) @@ -371,16 +371,13 @@ def search_for_directory(self, dependency: DirectoryDependency) -> list[Package] if dependency.base is not None: package.root_dir = dependency.base - return [package] + return package @classmethod def get_package_from_directory(cls, directory: Path) -> Package: return PackageInfo.from_directory(path=directory).to_package(root_dir=directory) - def search_for_url(self, dependency: URLDependency) -> list[Package]: - if dependency in self._deferred_cache: - return [self._deferred_cache[dependency]] - + def _search_for_url(self, dependency: URLDependency) -> Package: package = self.get_package_from_url(dependency.url) self.validate_package_for_dependency(dependency=dependency, package=package) @@ -393,12 +390,7 @@ def search_for_url(self, dependency: URLDependency) -> list[Package]: for extra_dep in package.extras[extra]: package.add_dependency(extra_dep) - dependency._constraint = package.version - dependency._pretty_constraint = package.version.text - - self._deferred_cache[dependency] = package - - return [package] + return package @classmethod def get_package_from_url(cls, url: str) -> Package: @@ -538,14 +530,8 @@ def complete_package(self, package: DependencyPackage) -> DependencyPackage: if self._load_deferred: # Retrieving constraints for deferred dependencies for r in requires: - if r.is_directory(): - self.search_for_directory(r) - elif r.is_file(): - self.search_for_file(r) - elif r.is_vcs(): - self.search_for_vcs(r) - elif r.is_url(): - self.search_for_url(r) + if r.is_direct_origin(): + self.search_for_direct_origin_dependency(r) optional_dependencies = [] _dependencies = [] diff --git a/tests/puzzle/test_provider.py b/tests/puzzle/test_provider.py index a31c6d82dad..a6d9d784baf 100644 --- a/tests/puzzle/test_provider.py +++ b/tests/puzzle/test_provider.py @@ -60,14 +60,14 @@ def test_search_for_vcs_retains_develop_flag(provider: Provider, value: bool): dependency = VCSDependency( "demo", "git", "https://github.com/demo/demo.git", develop=value ) - package = provider.search_for_vcs(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.develop == value def test_search_for_vcs_setup_egg_info(provider: Provider): dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git") - package = provider.search_for_vcs(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.2" @@ -87,7 +87,7 @@ def test_search_for_vcs_setup_egg_info_with_extras(provider: Provider): "demo", "git", "https://github.com/demo/demo.git", extras=["foo"] ) - package = provider.search_for_vcs(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.2" @@ -107,7 +107,7 @@ def test_search_for_vcs_read_setup(provider: Provider, mocker: MockerFixture): dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git") - package = provider.search_for_vcs(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.2" @@ -131,7 +131,7 @@ def test_search_for_vcs_read_setup_with_extras( "demo", "git", "https://github.com/demo/demo.git", extras=["foo"] ) - package = provider.search_for_vcs(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.2" @@ -153,7 +153,7 @@ def test_search_for_vcs_read_setup_raises_error_if_no_version( dependency = VCSDependency("demo", "git", "https://github.com/demo/no-version.git") with pytest.raises(RuntimeError): - provider.search_for_vcs(dependency) + provider.search_for_direct_origin_dependency(dependency) @pytest.mark.parametrize("directory", ["demo", "non-canonical-name"]) @@ -168,7 +168,7 @@ def test_search_for_directory_setup_egg_info(provider: Provider, directory: str) / directory, ) - package = provider.search_for_directory(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.2" @@ -195,7 +195,7 @@ def test_search_for_directory_setup_egg_info_with_extras(provider: Provider): extras=["foo"], ) - package = provider.search_for_directory(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.2" @@ -228,7 +228,7 @@ def test_search_for_directory_setup_with_base(provider: Provider, directory: str / directory, ) - package = provider.search_for_directory(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.2" @@ -266,7 +266,7 @@ def test_search_for_directory_setup_read_setup( / "demo", ) - package = provider.search_for_directory(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.2" @@ -297,7 +297,7 @@ def test_search_for_directory_setup_read_setup_with_extras( extras=["foo"], ) - package = provider.search_for_directory(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.2" @@ -323,7 +323,7 @@ def test_search_for_directory_setup_read_setup_with_no_dependencies(provider: Pr / "no-dependencies", ) - package = provider.search_for_directory(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.2" @@ -337,7 +337,7 @@ def test_search_for_directory_poetry(provider: Provider): Path(__file__).parent.parent / "fixtures" / "project_with_extras", ) - package = provider.search_for_directory(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "project-with-extras" assert package.version.text == "1.2.3" @@ -366,7 +366,7 @@ def test_search_for_directory_poetry_with_extras(provider: Provider): extras=["extras_a"], ) - package = provider.search_for_directory(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "project-with-extras" assert package.version.text == "1.2.3" @@ -397,7 +397,7 @@ def test_search_for_file_sdist(provider: Provider): / "demo-0.1.0.tar.gz", ) - package = provider.search_for_file(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.0" @@ -429,7 +429,7 @@ def test_search_for_file_sdist_with_extras(provider: Provider): extras=["foo"], ) - package = provider.search_for_file(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.0" @@ -460,7 +460,7 @@ def test_search_for_file_wheel(provider: Provider): / "demo-0.1.0-py2.py3-none-any.whl", ) - package = provider.search_for_file(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.0" @@ -492,7 +492,7 @@ def test_search_for_file_wheel_with_extras(provider: Provider): extras=["foo"], ) - package = provider.search_for_file(dependency)[0] + package = provider.search_for_direct_origin_dependency(dependency) assert package.name == "demo" assert package.version.text == "0.1.0"