Skip to content

Commit

Permalink
refactor the search for direct-origin dependencies (#5904)
Browse files Browse the repository at this point in the history
* refactor the search for direct-origin dependencies

* clarify where the interface is

* fix treatment of dependency via cache
  • Loading branch information
dimbleby authored Jun 26, 2022
1 parent 8b64088 commit 63c86bf
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 99 deletions.
15 changes: 1 addition & 14 deletions src/poetry/console/commands/show.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import cast

from cleo.helpers import argument
from cleo.helpers import option
from poetry.core.packages.directory_dependency import DirectoryDependency
from poetry.core.packages.file_dependency import FileDependency
from poetry.core.packages.vcs_dependency import VCSDependency

from poetry.console.commands.group_command import GroupCommand
from poetry.utils.helpers import canonicalize_name
Expand Down Expand Up @@ -499,16 +495,7 @@ def find_latest_package(
for dep in requires:
if dep.name == package.name:
provider = Provider(root, self.poetry.pool, NullIO())

if dep.is_vcs():
dep = cast(VCSDependency, dep)
return provider.search_for_vcs(dep)[0]
if dep.is_file():
dep = cast(FileDependency, dep)
return provider.search_for_file(dep)[0]
if dep.is_directory():
dep = cast(DirectoryDependency, dep)
return provider.search_for_directory(dep)[0]
return provider.search_for_direct_origin_dependency(dep)

name = package.name
selector = VersionSelector(self.poetry.pool)
Expand Down
120 changes: 53 additions & 67 deletions src/poetry/puzzle/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,44 @@ def search_for_installed_packages(
)
return packages

def search_for_direct_origin_dependency(self, dependency: Dependency) -> Package:
package = self._deferred_cache.get(dependency)
if package is not None:
pass

elif dependency.is_vcs():
dependency = cast(VCSDependency, dependency)
package = self._search_for_vcs(dependency)

elif dependency.is_file():
dependency = cast(FileDependency, dependency)
package = self._search_for_file(dependency)

elif dependency.is_directory():
dependency = cast(DirectoryDependency, dependency)
package = self._search_for_directory(dependency)

elif dependency.is_url():
dependency = cast(URLDependency, dependency)
package = self._search_for_url(dependency)

else:
raise RuntimeError(
f"Unknown direct dependency type {dependency.source_type}"
)

if dependency.is_vcs():
dependency._source_reference = package.source_reference
dependency._source_resolved_reference = package.source_resolved_reference
dependency._source_subdirectory = package.source_subdirectory

dependency._constraint = package.version
dependency._pretty_constraint = package.version.text

self._deferred_cache[dependency] = package

return package

def search_for(self, dependency: Dependency) -> list[DependencyPackage]:
"""
Search for the specifications that match the given dependency.
Expand All @@ -231,18 +269,9 @@ def search_for(self, dependency: Dependency) -> list[DependencyPackage]:
if dependency.is_root:
return PackageCollection(dependency, [self._package])

if dependency.is_vcs():
dependency = cast(VCSDependency, dependency)
packages = self.search_for_vcs(dependency)
elif dependency.is_file():
dependency = cast(FileDependency, dependency)
packages = self.search_for_file(dependency)
elif dependency.is_directory():
dependency = cast(DirectoryDependency, dependency)
packages = self.search_for_directory(dependency)
elif dependency.is_url():
dependency = cast(URLDependency, dependency)
packages = self.search_for_url(dependency)
if dependency.is_direct_origin():
packages = [self.search_for_direct_origin_dependency(dependency)]

else:
packages = self._pool.find_packages(dependency)

Expand All @@ -259,7 +288,7 @@ def search_for(self, dependency: Dependency) -> list[DependencyPackage]:

return PackageCollection(dependency, packages)

def search_for_vcs(self, dependency: VCSDependency) -> list[Package]:
def _search_for_vcs(self, dependency: VCSDependency) -> Package:
"""
Search for the specifications that match the given VCS dependency.
Expand All @@ -281,16 +310,7 @@ def search_for_vcs(self, dependency: VCSDependency) -> list[Package]:

package.develop = dependency.develop

dependency._constraint = package.version
dependency._pretty_constraint = package.version.text

dependency._source_reference = package.source_reference
dependency._source_resolved_reference = package.source_resolved_reference
dependency._source_subdirectory = package.source_subdirectory

self._deferred_cache[dependency] = package

return [package]
return package

@staticmethod
def get_package_from_vcs(
Expand All @@ -314,18 +334,8 @@ def get_package_from_vcs(
source_root=source_root,
)

def search_for_file(self, dependency: FileDependency) -> list[Package]:
if dependency in self._deferred_cache:
_package = self._deferred_cache[dependency]

package = _package.clone()
else:
package = self.get_package_from_file(dependency.full_path)

dependency._constraint = package.version
dependency._pretty_constraint = package.version.text

self._deferred_cache[dependency] = package
def _search_for_file(self, dependency: FileDependency) -> Package:
package = self.get_package_from_file(dependency.full_path)

self.validate_package_for_dependency(dependency=dependency, package=package)

Expand All @@ -336,7 +346,7 @@ def search_for_file(self, dependency: FileDependency) -> list[Package]:
{"file": dependency.path.name, "hash": "sha256:" + dependency.hash()}
]

return [package]
return package

@classmethod
def get_package_from_file(cls, file_path: Path) -> Package:
Expand All @@ -351,18 +361,8 @@ def get_package_from_file(cls, file_path: Path) -> Package:

return package

def search_for_directory(self, dependency: DirectoryDependency) -> list[Package]:
if dependency in self._deferred_cache:
_package = self._deferred_cache[dependency]

package = _package.clone()
else:
package = self.get_package_from_directory(dependency.full_path)

dependency._constraint = package.version
dependency._pretty_constraint = package.version.text

self._deferred_cache[dependency] = package
def _search_for_directory(self, dependency: DirectoryDependency) -> Package:
package = self.get_package_from_directory(dependency.full_path)

self.validate_package_for_dependency(dependency=dependency, package=package)

Expand All @@ -371,16 +371,13 @@ def search_for_directory(self, dependency: DirectoryDependency) -> list[Package]
if dependency.base is not None:
package.root_dir = dependency.base

return [package]
return package

@classmethod
def get_package_from_directory(cls, directory: Path) -> Package:
return PackageInfo.from_directory(path=directory).to_package(root_dir=directory)

def search_for_url(self, dependency: URLDependency) -> list[Package]:
if dependency in self._deferred_cache:
return [self._deferred_cache[dependency]]

def _search_for_url(self, dependency: URLDependency) -> Package:
package = self.get_package_from_url(dependency.url)

self.validate_package_for_dependency(dependency=dependency, package=package)
Expand All @@ -393,12 +390,7 @@ def search_for_url(self, dependency: URLDependency) -> list[Package]:
for extra_dep in package.extras[extra]:
package.add_dependency(extra_dep)

dependency._constraint = package.version
dependency._pretty_constraint = package.version.text

self._deferred_cache[dependency] = package

return [package]
return package

@classmethod
def get_package_from_url(cls, url: str) -> Package:
Expand Down Expand Up @@ -538,14 +530,8 @@ def complete_package(self, package: DependencyPackage) -> DependencyPackage:
if self._load_deferred:
# Retrieving constraints for deferred dependencies
for r in requires:
if r.is_directory():
self.search_for_directory(r)
elif r.is_file():
self.search_for_file(r)
elif r.is_vcs():
self.search_for_vcs(r)
elif r.is_url():
self.search_for_url(r)
if r.is_direct_origin():
self.search_for_direct_origin_dependency(r)

optional_dependencies = []
_dependencies = []
Expand Down
36 changes: 18 additions & 18 deletions tests/puzzle/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def test_search_for_vcs_retains_develop_flag(provider: Provider, value: bool):
dependency = VCSDependency(
"demo", "git", "https://github.com/demo/demo.git", develop=value
)
package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.develop == value


def test_search_for_vcs_setup_egg_info(provider: Provider):
dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git")

package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -87,7 +87,7 @@ def test_search_for_vcs_setup_egg_info_with_extras(provider: Provider):
"demo", "git", "https://github.com/demo/demo.git", extras=["foo"]
)

package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -107,7 +107,7 @@ def test_search_for_vcs_read_setup(provider: Provider, mocker: MockerFixture):

dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git")

package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -131,7 +131,7 @@ def test_search_for_vcs_read_setup_with_extras(
"demo", "git", "https://github.com/demo/demo.git", extras=["foo"]
)

package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -153,7 +153,7 @@ def test_search_for_vcs_read_setup_raises_error_if_no_version(
dependency = VCSDependency("demo", "git", "https://github.com/demo/no-version.git")

with pytest.raises(RuntimeError):
provider.search_for_vcs(dependency)
provider.search_for_direct_origin_dependency(dependency)


@pytest.mark.parametrize("directory", ["demo", "non-canonical-name"])
Expand All @@ -168,7 +168,7 @@ def test_search_for_directory_setup_egg_info(provider: Provider, directory: str)
/ directory,
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -195,7 +195,7 @@ def test_search_for_directory_setup_egg_info_with_extras(provider: Provider):
extras=["foo"],
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand Down Expand Up @@ -228,7 +228,7 @@ def test_search_for_directory_setup_with_base(provider: Provider, directory: str
/ directory,
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand Down Expand Up @@ -266,7 +266,7 @@ def test_search_for_directory_setup_read_setup(
/ "demo",
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand Down Expand Up @@ -297,7 +297,7 @@ def test_search_for_directory_setup_read_setup_with_extras(
extras=["foo"],
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -323,7 +323,7 @@ def test_search_for_directory_setup_read_setup_with_no_dependencies(provider: Pr
/ "no-dependencies",
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -337,7 +337,7 @@ def test_search_for_directory_poetry(provider: Provider):
Path(__file__).parent.parent / "fixtures" / "project_with_extras",
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "project-with-extras"
assert package.version.text == "1.2.3"
Expand Down Expand Up @@ -366,7 +366,7 @@ def test_search_for_directory_poetry_with_extras(provider: Provider):
extras=["extras_a"],
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "project-with-extras"
assert package.version.text == "1.2.3"
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_search_for_file_sdist(provider: Provider):
/ "demo-0.1.0.tar.gz",
)

package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.0"
Expand Down Expand Up @@ -429,7 +429,7 @@ def test_search_for_file_sdist_with_extras(provider: Provider):
extras=["foo"],
)

package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.0"
Expand Down Expand Up @@ -460,7 +460,7 @@ def test_search_for_file_wheel(provider: Provider):
/ "demo-0.1.0-py2.py3-none-any.whl",
)

package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.0"
Expand Down Expand Up @@ -492,7 +492,7 @@ def test_search_for_file_wheel_with_extras(provider: Provider):
extras=["foo"],
)

package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.0"
Expand Down

0 comments on commit 63c86bf

Please sign in to comment.