From aa61407c89f60b4c3b73eaf557f61756ddd38fde Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Tue, 5 Nov 2024 16:23:35 -0500 Subject: [PATCH 1/4] lint/doc: annotate and document tasks.py Signed-off-by: Mike Fiedler --- warehouse/tasks.py | 96 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 90 insertions(+), 6 deletions(-) diff --git a/warehouse/tasks.py b/warehouse/tasks.py index 830906cb88a3..11ce07815809 100644 --- a/warehouse/tasks.py +++ b/warehouse/tasks.py @@ -10,11 +10,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import functools import hashlib import logging import os import time +import typing import urllib.parse import celery @@ -32,6 +35,9 @@ from warehouse.config import Environment from warehouse.metrics import IMetricsService +if typing.TYPE_CHECKING: + from pyramid.request import Request + # We need to trick Celery into supporting rediss:// URLs which is how redis-py # signals that you should use Redis with TLS. celery.app.backends.BACKEND_ALIASES["rediss"] = ( @@ -53,7 +59,19 @@ def _params_from_url(self, url, defaults): class WarehouseTask(celery.Task): - def __new__(cls, *args, **kwargs): + """ + A custom Celery Task that integrates with Pyramid's transaction manager and + metrics service. + """ + + __header__: typing.Callable + _wh_original_run: typing.Callable + + def __new__(cls, *args, **kwargs) -> WarehouseTask: + """ + Override to wrap the `run` method of the task with a new method that + will handle exceptions from the task and retry them if they're retryable. + """ obj = super().__new__(cls, *args, **kwargs) if getattr(obj, "__header__", None) is not None: obj.__header__ = functools.partial(obj.__header__, object()) @@ -82,16 +100,34 @@ def run(*args, **kwargs): metrics.increment("warehouse.task.failed", tags=metric_tags) raise - obj._wh_original_run, obj.run = obj.run, run + # Reassign the `run` method to the new one we've created. + obj._wh_original_run, obj.run = obj.run, run # type: ignore[method-assign] return obj def __call__(self, *args, **kwargs): + """ + Override to inject a faux request object into the task when it's called. + There's no WSGI request object available when a task is called, so we + create a fake one here. This is necessary as a lot of our code assumes + that there's a Pyramid request object available. + """ return super().__call__(*(self.get_request(),) + args, **kwargs) - def get_request(self): + def get_request(self) -> Request: + """ + Get a request object to use for this task. + + This will either return the request object that was injected into the + task when it was called, or it will create a new request object to use + for the task. + + Note: The `type: ignore` comments are necessary because the `pyramid_env` + attribute is not defined on the request object, but we're adding it + dynamically. + """ if not hasattr(self.request, "pyramid_env"): - registry = self.app.pyramid_config.registry + registry = self.app.pyramid_config.registry # type: ignore[attr-defined] env = pyramid.scripting.prepare(registry=registry) env["request"].tm = transaction.TransactionManager(explicit=True) env["request"].timings = {"new_request_start": time.time() * 1000} @@ -101,15 +137,29 @@ def get_request(self): ).hexdigest() self.request.update(pyramid_env=env) - return self.request.pyramid_env["request"] + return self.request.pyramid_env["request"] # type: ignore[attr-defined] def after_return(self, status, retval, task_id, args, kwargs, einfo): + """ + Called after the task has returned. This is where we'll clean up the + request object that we injected into the task. + """ if hasattr(self.request, "pyramid_env"): pyramid_env = self.request.pyramid_env pyramid_env["request"]._process_finished_callbacks() pyramid_env["closer"]() def apply_async(self, *args, **kwargs): + """ + Override the apply_async method to add an after commit hook to the + transaction manager to send the task after the transaction has been + committed. + + This is necessary because we want to ensure that the task is only sent + after the transaction has been committed. This is important because we + want to ensure that the task is only sent if the transaction was + successful. + """ # The API design of Celery makes this threadlocal pretty impossible to # avoid :( request = get_current_request() @@ -137,17 +187,51 @@ def apply_async(self, *args, **kwargs): ) def retry(self, *args, **kwargs): + """ + Override the retry method to increment a metric when a task is retried. + + This is necessary because the `retry` method is called when a task is + retried, and we want to track how many times a task has been retried. + """ request = get_current_request() metrics = request.find_service(IMetricsService, context=None) metrics.increment("warehouse.task.retried", tags=[f"task:{self.name}"]) return super().retry(*args, **kwargs) def _after_commit_hook(self, success, *args, **kwargs): + """ + This is the hook that gets called after the transaction has been + committed. We'll only send the task if the transaction was successful. + """ if success: super().apply_async(*args, **kwargs) def task(**kwargs): + """ + A decorator that can be used to define a Celery task. + + A thin wrapper around Celery's `task` decorator that allows us to attach + the task to the Celery app when the configuration is scanned during the + application startup. + + This decorator also sets the `shared` option to `False` by default. This + means that the task will be created anew for each worker process that is + started. This is important because the `WarehouseTask` class that we use + for our tasks is not thread-safe, so we need to ensure that each worker + process has its own instance of the task. + + This decorator also adds the task to the `warehouse` category in the + configuration scanner. This is important because we use this category to + find all the tasks that have been defined in the configuration. + + Example usage: + ``` + @tasks.task(...) + def my_task(self, *args, **kwargs): + pass + ``` + """ kwargs.setdefault("shared", False) def deco(wrapped): @@ -193,7 +277,7 @@ def add_task(): def includeme(config): s = config.registry.settings - broker_transport_options = {} + broker_transport_options: dict[str, str | dict] = {} broker_url = s.get("celery.broker_url") if broker_url is None: From d235a8f955fa6d5bc7efd6d3030dbe5fc36bbdd9 Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Fri, 20 Dec 2024 10:59:37 -0500 Subject: [PATCH 2/4] refactor: update queries for clarity Using the same variable name and reassigning it multiple times is not necessary. Signed-off-by: Mike Fiedler --- warehouse/packaging/models.py | 40 +++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/warehouse/packaging/models.py b/warehouse/packaging/models.py index e0c27e9bf0b2..4a409c3465fb 100644 --- a/warehouse/packaging/models.py +++ b/warehouse/packaging/models.py @@ -333,19 +333,26 @@ def __acl__(self): acls.append((Allow, f"oidc:{publisher.id}", [Permissions.ProjectsUpload])) # Get all of the users for this project. - query = session.query(Role).filter(Role.project == self) - query = query.options(orm.lazyload(Role.project)) - query = query.options(orm.lazyload(Role.user)) + user_query = ( + session.query(Role) + .filter(Role.project == self) + .options(orm.lazyload(Role.project), orm.lazyload(Role.user)) + ) permissions = { (role.user_id, "Administer" if role.role_name == "Owner" else "Upload") - for role in query.all() + for role in user_query.all() } # Add all of the team members for this project. - query = session.query(TeamProjectRole).filter(TeamProjectRole.project == self) - query = query.options(orm.lazyload(TeamProjectRole.project)) - query = query.options(orm.lazyload(TeamProjectRole.team)) - for role in query.all(): + team_query = ( + session.query(TeamProjectRole) + .filter(TeamProjectRole.project == self) + .options( + orm.lazyload(TeamProjectRole.project), + orm.lazyload(TeamProjectRole.team), + ) + ) + for role in team_query.all(): permissions |= { (user.id, "Administer" if role.role_name.value == "Owner" else "Upload") for user in role.team.members @@ -353,13 +360,18 @@ def __acl__(self): # Add all organization owners for this project. if self.organization: - query = session.query(OrganizationRole).filter( - OrganizationRole.organization == self.organization, - OrganizationRole.role_name == OrganizationRoleType.Owner, + org_query = ( + session.query(OrganizationRole) + .filter( + OrganizationRole.organization == self.organization, + OrganizationRole.role_name == OrganizationRoleType.Owner, + ) + .options( + orm.lazyload(OrganizationRole.organization), + orm.lazyload(OrganizationRole.user), + ) ) - query = query.options(orm.lazyload(OrganizationRole.organization)) - query = query.options(orm.lazyload(OrganizationRole.user)) - permissions |= {(role.user_id, "Administer") for role in query.all()} + permissions |= {(role.user_id, "Administer") for role in org_query.all()} for user_id, permission_name in sorted(permissions, key=lambda x: (x[1], x[0])): # Disallow Write permissions for Projects in quarantine, allow Upload From 168e18a2d4cddf0fbd02d29f474871094ab61126 Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Tue, 19 Nov 2024 10:54:10 -0500 Subject: [PATCH 3/4] lint: assign explicit types Signed-off-by: Mike Fiedler --- warehouse/sitemap/views.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/warehouse/sitemap/views.py b/warehouse/sitemap/views.py index 2573574a8564..92cc7764744e 100644 --- a/warehouse/sitemap/views.py +++ b/warehouse/sitemap/views.py @@ -88,15 +88,15 @@ def sitemap_index(request): .group_by(User.sitemap_bucket) .all() ) - buckets = {} + buckets: dict[str, datetime.datetime] = {} for b in itertools.chain(projects, users): current = buckets.setdefault(b.sitemap_bucket, b.modified) if current is None or (b.modified is not None and b.modified > current): buckets[b.sitemap_bucket] = b.modified - buckets = [Bucket(name=k, modified=v) for k, v in buckets.items()] - buckets.sort(key=lambda x: x.name) + bucket_list = [Bucket(name=k, modified=v) for k, v in buckets.items()] + bucket_list.sort(key=lambda x: x.name) - return {"buckets": buckets} + return {"buckets": bucket_list} @view_config( From cb9cdf1bbb748661d076a42227548ed01754d109 Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Fri, 15 Nov 2024 12:01:16 -0500 Subject: [PATCH 4/4] lint: correct types - The type returned from `.values()` is an iterable object, not an iterator. - The input `reserved_names` is either a set or list, not a sequence. - Adding `.any()` produces a boolean to match the filter condition's expectations, currently silently coerced. Signed-off-by: Mike Fiedler --- warehouse/cli/db/dbml.py | 4 ++-- warehouse/packaging/models.py | 2 +- warehouse/search/tasks.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/warehouse/cli/db/dbml.py b/warehouse/cli/db/dbml.py index 5b622dfed4e1..3312a74f2360 100644 --- a/warehouse/cli/db/dbml.py +++ b/warehouse/cli/db/dbml.py @@ -12,7 +12,7 @@ import json -from collections.abc import Iterator +from collections.abc import Iterable from typing import Literal, NotRequired, TypedDict import click @@ -95,7 +95,7 @@ class TableInfo(TypedDict): comment: NotRequired[str] -def generate_dbml_file(tables: Iterator[Table], _output: str | None) -> None: +def generate_dbml_file(tables: Iterable[Table], _output: str | None) -> None: file = click.open_file(_output, "w") if _output else click.open_file("-", "w") tables_info = {} diff --git a/warehouse/packaging/models.py b/warehouse/packaging/models.py index 4a409c3465fb..a023eafb738a 100644 --- a/warehouse/packaging/models.py +++ b/warehouse/packaging/models.py @@ -771,7 +771,7 @@ def urls_by_verify_status(self, *, verified: bool): return _urls def verified_user_name_and_repo_name( - self, domains: set[str], reserved_names: typing.Sequence[str] | None = None + self, domains: set[str], reserved_names: typing.Collection[str] | None = None ): for _, url in self.urls_by_verify_status(verified=True).items(): try: diff --git a/warehouse/search/tasks.py b/warehouse/search/tasks.py index 3b3d05999cfd..a675ee929393 100644 --- a/warehouse/search/tasks.py +++ b/warehouse/search/tasks.py @@ -42,7 +42,7 @@ def _project_docs(db, project_name=None): releases_list = ( select(Release.id) - .filter(Release.yanked.is_(False), Release.files) + .filter(Release.yanked.is_(False), Release.files.any()) .order_by( Release.project_id, Release.is_prerelease.nullslast(),