Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor the search for direct-origin dependencies #5904

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions src/poetry/console/commands/show.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import cast

from cleo.helpers import argument
from cleo.helpers import option
from poetry.core.packages.directory_dependency import DirectoryDependency
from poetry.core.packages.file_dependency import FileDependency
from poetry.core.packages.vcs_dependency import VCSDependency

from poetry.console.commands.group_command import GroupCommand
from poetry.utils.helpers import canonicalize_name
Expand Down Expand Up @@ -499,16 +495,7 @@ def find_latest_package(
for dep in requires:
if dep.name == package.name:
provider = Provider(root, self.poetry.pool, NullIO())

if dep.is_vcs():
dep = cast(VCSDependency, dep)
return provider.search_for_vcs(dep)[0]
if dep.is_file():
dep = cast(FileDependency, dep)
return provider.search_for_file(dep)[0]
if dep.is_directory():
dep = cast(DirectoryDependency, dep)
return provider.search_for_directory(dep)[0]
return provider.search_for_direct_origin_dependency(dep)

name = package.name
selector = VersionSelector(self.poetry.pool)
Expand Down
120 changes: 53 additions & 67 deletions src/poetry/puzzle/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,44 @@ def search_for_installed_packages(
)
return packages

def search_for_direct_origin_dependency(self, dependency: Dependency) -> Package:
package = self._deferred_cache.get(dependency)
if package is not None:
pass

elif dependency.is_vcs():
dependency = cast(VCSDependency, dependency)
package = self._search_for_vcs(dependency)

elif dependency.is_file():
dependency = cast(FileDependency, dependency)
package = self._search_for_file(dependency)

elif dependency.is_directory():
dependency = cast(DirectoryDependency, dependency)
package = self._search_for_directory(dependency)

elif dependency.is_url():
dependency = cast(URLDependency, dependency)
package = self._search_for_url(dependency)

else:
raise RuntimeError(
f"Unknown direct dependency type {dependency.source_type}"
)

if dependency.is_vcs():
dependency._source_reference = package.source_reference
dependency._source_resolved_reference = package.source_resolved_reference
dependency._source_subdirectory = package.source_subdirectory

dependency._constraint = package.version
dependency._pretty_constraint = package.version.text

self._deferred_cache[dependency] = package

return package

def search_for(self, dependency: Dependency) -> list[DependencyPackage]:
"""
Search for the specifications that match the given dependency.
Expand All @@ -231,18 +269,9 @@ def search_for(self, dependency: Dependency) -> list[DependencyPackage]:
if dependency.is_root:
return PackageCollection(dependency, [self._package])

if dependency.is_vcs():
dependency = cast(VCSDependency, dependency)
packages = self.search_for_vcs(dependency)
elif dependency.is_file():
dependency = cast(FileDependency, dependency)
packages = self.search_for_file(dependency)
elif dependency.is_directory():
dependency = cast(DirectoryDependency, dependency)
packages = self.search_for_directory(dependency)
elif dependency.is_url():
dependency = cast(URLDependency, dependency)
packages = self.search_for_url(dependency)
if dependency.is_direct_origin():
packages = [self.search_for_direct_origin_dependency(dependency)]

else:
packages = self._pool.find_packages(dependency)

Expand All @@ -259,7 +288,7 @@ def search_for(self, dependency: Dependency) -> list[DependencyPackage]:

return PackageCollection(dependency, packages)

def search_for_vcs(self, dependency: VCSDependency) -> list[Package]:
def _search_for_vcs(self, dependency: VCSDependency) -> Package:
"""
Search for the specifications that match the given VCS dependency.

Expand All @@ -281,16 +310,7 @@ def search_for_vcs(self, dependency: VCSDependency) -> list[Package]:

package.develop = dependency.develop

dependency._constraint = package.version
dependency._pretty_constraint = package.version.text

dependency._source_reference = package.source_reference
dependency._source_resolved_reference = package.source_resolved_reference
dependency._source_subdirectory = package.source_subdirectory

self._deferred_cache[dependency] = package

return [package]
return package

@staticmethod
def get_package_from_vcs(
Expand All @@ -314,18 +334,8 @@ def get_package_from_vcs(
source_root=source_root,
)

def search_for_file(self, dependency: FileDependency) -> list[Package]:
if dependency in self._deferred_cache:
_package = self._deferred_cache[dependency]

package = _package.clone()
else:
package = self.get_package_from_file(dependency.full_path)

dependency._constraint = package.version
dependency._pretty_constraint = package.version.text

self._deferred_cache[dependency] = package
def _search_for_file(self, dependency: FileDependency) -> Package:
package = self.get_package_from_file(dependency.full_path)

self.validate_package_for_dependency(dependency=dependency, package=package)

Expand All @@ -336,7 +346,7 @@ def search_for_file(self, dependency: FileDependency) -> list[Package]:
{"file": dependency.path.name, "hash": "sha256:" + dependency.hash()}
]

return [package]
return package

@classmethod
def get_package_from_file(cls, file_path: Path) -> Package:
Expand All @@ -351,18 +361,8 @@ def get_package_from_file(cls, file_path: Path) -> Package:

return package

def search_for_directory(self, dependency: DirectoryDependency) -> list[Package]:
if dependency in self._deferred_cache:
_package = self._deferred_cache[dependency]

package = _package.clone()
else:
package = self.get_package_from_directory(dependency.full_path)

dependency._constraint = package.version
dependency._pretty_constraint = package.version.text

self._deferred_cache[dependency] = package
def _search_for_directory(self, dependency: DirectoryDependency) -> Package:
package = self.get_package_from_directory(dependency.full_path)

self.validate_package_for_dependency(dependency=dependency, package=package)

Expand All @@ -371,16 +371,13 @@ def search_for_directory(self, dependency: DirectoryDependency) -> list[Package]
if dependency.base is not None:
package.root_dir = dependency.base

return [package]
return package

@classmethod
def get_package_from_directory(cls, directory: Path) -> Package:
return PackageInfo.from_directory(path=directory).to_package(root_dir=directory)

def search_for_url(self, dependency: URLDependency) -> list[Package]:
if dependency in self._deferred_cache:
return [self._deferred_cache[dependency]]

def _search_for_url(self, dependency: URLDependency) -> Package:
package = self.get_package_from_url(dependency.url)

self.validate_package_for_dependency(dependency=dependency, package=package)
Expand All @@ -393,12 +390,7 @@ def search_for_url(self, dependency: URLDependency) -> list[Package]:
for extra_dep in package.extras[extra]:
package.add_dependency(extra_dep)

dependency._constraint = package.version
dependency._pretty_constraint = package.version.text

self._deferred_cache[dependency] = package

return [package]
return package

@classmethod
def get_package_from_url(cls, url: str) -> Package:
Expand Down Expand Up @@ -538,14 +530,8 @@ def complete_package(self, package: DependencyPackage) -> DependencyPackage:
if self._load_deferred:
# Retrieving constraints for deferred dependencies
for r in requires:
if r.is_directory():
self.search_for_directory(r)
elif r.is_file():
self.search_for_file(r)
elif r.is_vcs():
self.search_for_vcs(r)
elif r.is_url():
self.search_for_url(r)
if r.is_direct_origin():
self.search_for_direct_origin_dependency(r)

optional_dependencies = []
_dependencies = []
Expand Down
36 changes: 18 additions & 18 deletions tests/puzzle/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def test_search_for_vcs_retains_develop_flag(provider: Provider, value: bool):
dependency = VCSDependency(
"demo", "git", "https://github.com/demo/demo.git", develop=value
)
package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)
assert package.develop == value


def test_search_for_vcs_setup_egg_info(provider: Provider):
dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git")

package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -87,7 +87,7 @@ def test_search_for_vcs_setup_egg_info_with_extras(provider: Provider):
"demo", "git", "https://github.com/demo/demo.git", extras=["foo"]
)

package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -107,7 +107,7 @@ def test_search_for_vcs_read_setup(provider: Provider, mocker: MockerFixture):

dependency = VCSDependency("demo", "git", "https://github.com/demo/demo.git")

package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -131,7 +131,7 @@ def test_search_for_vcs_read_setup_with_extras(
"demo", "git", "https://github.com/demo/demo.git", extras=["foo"]
)

package = provider.search_for_vcs(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -153,7 +153,7 @@ def test_search_for_vcs_read_setup_raises_error_if_no_version(
dependency = VCSDependency("demo", "git", "https://github.com/demo/no-version.git")

with pytest.raises(RuntimeError):
provider.search_for_vcs(dependency)
provider.search_for_direct_origin_dependency(dependency)


@pytest.mark.parametrize("directory", ["demo", "non-canonical-name"])
Expand All @@ -168,7 +168,7 @@ def test_search_for_directory_setup_egg_info(provider: Provider, directory: str)
/ directory,
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -195,7 +195,7 @@ def test_search_for_directory_setup_egg_info_with_extras(provider: Provider):
extras=["foo"],
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand Down Expand Up @@ -228,7 +228,7 @@ def test_search_for_directory_setup_with_base(provider: Provider, directory: str
/ directory,
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand Down Expand Up @@ -266,7 +266,7 @@ def test_search_for_directory_setup_read_setup(
/ "demo",
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand Down Expand Up @@ -297,7 +297,7 @@ def test_search_for_directory_setup_read_setup_with_extras(
extras=["foo"],
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -323,7 +323,7 @@ def test_search_for_directory_setup_read_setup_with_no_dependencies(provider: Pr
/ "no-dependencies",
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.2"
Expand All @@ -337,7 +337,7 @@ def test_search_for_directory_poetry(provider: Provider):
Path(__file__).parent.parent / "fixtures" / "project_with_extras",
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "project-with-extras"
assert package.version.text == "1.2.3"
Expand Down Expand Up @@ -366,7 +366,7 @@ def test_search_for_directory_poetry_with_extras(provider: Provider):
extras=["extras_a"],
)

package = provider.search_for_directory(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "project-with-extras"
assert package.version.text == "1.2.3"
Expand Down Expand Up @@ -397,7 +397,7 @@ def test_search_for_file_sdist(provider: Provider):
/ "demo-0.1.0.tar.gz",
)

package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.0"
Expand Down Expand Up @@ -429,7 +429,7 @@ def test_search_for_file_sdist_with_extras(provider: Provider):
extras=["foo"],
)

package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.0"
Expand Down Expand Up @@ -460,7 +460,7 @@ def test_search_for_file_wheel(provider: Provider):
/ "demo-0.1.0-py2.py3-none-any.whl",
)

package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.0"
Expand Down Expand Up @@ -492,7 +492,7 @@ def test_search_for_file_wheel_with_extras(provider: Provider):
extras=["foo"],
)

package = provider.search_for_file(dependency)[0]
package = provider.search_for_direct_origin_dependency(dependency)

assert package.name == "demo"
assert package.version.text == "0.1.0"
Expand Down