Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

lint/doc updates #17310

Merged
merged 6 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions warehouse/cli/db/dbml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import json

from collections.abc import Iterator
from collections.abc import Iterable
from typing import Literal, NotRequired, TypedDict

import click
Expand Down Expand Up @@ -95,7 +95,7 @@ class TableInfo(TypedDict):
comment: NotRequired[str]


def generate_dbml_file(tables: Iterator[Table], _output: str | None) -> None:
def generate_dbml_file(tables: Iterable[Table], _output: str | None) -> None:
file = click.open_file(_output, "w") if _output else click.open_file("-", "w")

tables_info = {}
Expand Down
42 changes: 27 additions & 15 deletions warehouse/packaging/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,33 +333,45 @@ def __acl__(self):
acls.append((Allow, f"oidc:{publisher.id}", [Permissions.ProjectsUpload]))

# Get all of the users for this project.
query = session.query(Role).filter(Role.project == self)
query = query.options(orm.lazyload(Role.project))
query = query.options(orm.lazyload(Role.user))
user_query = (
session.query(Role)
.filter(Role.project == self)
.options(orm.lazyload(Role.project), orm.lazyload(Role.user))
)
permissions = {
(role.user_id, "Administer" if role.role_name == "Owner" else "Upload")
for role in query.all()
for role in user_query.all()
}

# Add all of the team members for this project.
query = session.query(TeamProjectRole).filter(TeamProjectRole.project == self)
query = query.options(orm.lazyload(TeamProjectRole.project))
query = query.options(orm.lazyload(TeamProjectRole.team))
for role in query.all():
team_query = (
session.query(TeamProjectRole)
.filter(TeamProjectRole.project == self)
.options(
orm.lazyload(TeamProjectRole.project),
orm.lazyload(TeamProjectRole.team),
)
)
for role in team_query.all():
permissions |= {
(user.id, "Administer" if role.role_name.value == "Owner" else "Upload")
for user in role.team.members
}

# Add all organization owners for this project.
if self.organization:
query = session.query(OrganizationRole).filter(
OrganizationRole.organization == self.organization,
OrganizationRole.role_name == OrganizationRoleType.Owner,
org_query = (
session.query(OrganizationRole)
.filter(
OrganizationRole.organization == self.organization,
OrganizationRole.role_name == OrganizationRoleType.Owner,
)
.options(
orm.lazyload(OrganizationRole.organization),
orm.lazyload(OrganizationRole.user),
)
)
query = query.options(orm.lazyload(OrganizationRole.organization))
query = query.options(orm.lazyload(OrganizationRole.user))
permissions |= {(role.user_id, "Administer") for role in query.all()}
permissions |= {(role.user_id, "Administer") for role in org_query.all()}

for user_id, permission_name in sorted(permissions, key=lambda x: (x[1], x[0])):
# Disallow Write permissions for Projects in quarantine, allow Upload
Expand Down Expand Up @@ -759,7 +771,7 @@ def urls_by_verify_status(self, *, verified: bool):
return _urls

def verified_user_name_and_repo_name(
self, domains: set[str], reserved_names: typing.Sequence[str] | None = None
self, domains: set[str], reserved_names: typing.Collection[str] | None = None
):
for _, url in self.urls_by_verify_status(verified=True).items():
try:
Expand Down
2 changes: 1 addition & 1 deletion warehouse/search/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
def _project_docs(db, project_name=None):
releases_list = (
select(Release.id)
.filter(Release.yanked.is_(False), Release.files)
.filter(Release.yanked.is_(False), Release.files.any())
.order_by(
Release.project_id,
Release.is_prerelease.nullslast(),
Expand Down
8 changes: 4 additions & 4 deletions warehouse/sitemap/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ def sitemap_index(request):
.group_by(User.sitemap_bucket)
.all()
)
buckets = {}
buckets: dict[str, datetime.datetime] = {}
for b in itertools.chain(projects, users):
current = buckets.setdefault(b.sitemap_bucket, b.modified)
if current is None or (b.modified is not None and b.modified > current):
buckets[b.sitemap_bucket] = b.modified
buckets = [Bucket(name=k, modified=v) for k, v in buckets.items()]
buckets.sort(key=lambda x: x.name)
bucket_list = [Bucket(name=k, modified=v) for k, v in buckets.items()]
bucket_list.sort(key=lambda x: x.name)

return {"buckets": buckets}
return {"buckets": bucket_list}


@view_config(
Expand Down
96 changes: 90 additions & 6 deletions warehouse/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import functools
import hashlib
import logging
import os
import time
import typing
import urllib.parse

import celery
Expand All @@ -32,6 +35,9 @@
from warehouse.config import Environment
from warehouse.metrics import IMetricsService

if typing.TYPE_CHECKING:
from pyramid.request import Request

# We need to trick Celery into supporting rediss:// URLs which is how redis-py
# signals that you should use Redis with TLS.
celery.app.backends.BACKEND_ALIASES["rediss"] = (
Expand All @@ -53,7 +59,19 @@ def _params_from_url(self, url, defaults):


class WarehouseTask(celery.Task):
def __new__(cls, *args, **kwargs):
"""
A custom Celery Task that integrates with Pyramid's transaction manager and
metrics service.
"""

__header__: typing.Callable
_wh_original_run: typing.Callable

def __new__(cls, *args, **kwargs) -> WarehouseTask:
"""
Override to wrap the `run` method of the task with a new method that
will handle exceptions from the task and retry them if they're retryable.
"""
obj = super().__new__(cls, *args, **kwargs)
if getattr(obj, "__header__", None) is not None:
obj.__header__ = functools.partial(obj.__header__, object())
Expand Down Expand Up @@ -82,16 +100,34 @@ def run(*args, **kwargs):
metrics.increment("warehouse.task.failed", tags=metric_tags)
raise

obj._wh_original_run, obj.run = obj.run, run
# Reassign the `run` method to the new one we've created.
obj._wh_original_run, obj.run = obj.run, run # type: ignore[method-assign]

return obj

def __call__(self, *args, **kwargs):
"""
Override to inject a faux request object into the task when it's called.
There's no WSGI request object available when a task is called, so we
create a fake one here. This is necessary as a lot of our code assumes
that there's a Pyramid request object available.
"""
return super().__call__(*(self.get_request(),) + args, **kwargs)

def get_request(self):
def get_request(self) -> Request:
"""
Get a request object to use for this task.

This will either return the request object that was injected into the
task when it was called, or it will create a new request object to use
for the task.

Note: The `type: ignore` comments are necessary because the `pyramid_env`
attribute is not defined on the request object, but we're adding it
dynamically.
"""
if not hasattr(self.request, "pyramid_env"):
registry = self.app.pyramid_config.registry
registry = self.app.pyramid_config.registry # type: ignore[attr-defined]
env = pyramid.scripting.prepare(registry=registry)
env["request"].tm = transaction.TransactionManager(explicit=True)
env["request"].timings = {"new_request_start": time.time() * 1000}
Expand All @@ -101,15 +137,29 @@ def get_request(self):
).hexdigest()
self.request.update(pyramid_env=env)

return self.request.pyramid_env["request"]
return self.request.pyramid_env["request"] # type: ignore[attr-defined]

def after_return(self, status, retval, task_id, args, kwargs, einfo):
"""
Called after the task has returned. This is where we'll clean up the
request object that we injected into the task.
"""
if hasattr(self.request, "pyramid_env"):
pyramid_env = self.request.pyramid_env
pyramid_env["request"]._process_finished_callbacks()
pyramid_env["closer"]()

def apply_async(self, *args, **kwargs):
"""
Override the apply_async method to add an after commit hook to the
transaction manager to send the task after the transaction has been
committed.

This is necessary because we want to ensure that the task is only sent
after the transaction has been committed. This is important because we
want to ensure that the task is only sent if the transaction was
successful.
"""
# The API design of Celery makes this threadlocal pretty impossible to
# avoid :(
request = get_current_request()
Expand Down Expand Up @@ -137,17 +187,51 @@ def apply_async(self, *args, **kwargs):
)

def retry(self, *args, **kwargs):
"""
Override the retry method to increment a metric when a task is retried.

This is necessary because the `retry` method is called when a task is
retried, and we want to track how many times a task has been retried.
"""
request = get_current_request()
metrics = request.find_service(IMetricsService, context=None)
metrics.increment("warehouse.task.retried", tags=[f"task:{self.name}"])
return super().retry(*args, **kwargs)

def _after_commit_hook(self, success, *args, **kwargs):
"""
This is the hook that gets called after the transaction has been
committed. We'll only send the task if the transaction was successful.
"""
if success:
super().apply_async(*args, **kwargs)


def task(**kwargs):
"""
A decorator that can be used to define a Celery task.

A thin wrapper around Celery's `task` decorator that allows us to attach
the task to the Celery app when the configuration is scanned during the
application startup.

This decorator also sets the `shared` option to `False` by default. This
means that the task will be created anew for each worker process that is
started. This is important because the `WarehouseTask` class that we use
for our tasks is not thread-safe, so we need to ensure that each worker
process has its own instance of the task.

This decorator also adds the task to the `warehouse` category in the
configuration scanner. This is important because we use this category to
find all the tasks that have been defined in the configuration.

Example usage:
```
@tasks.task(...)
def my_task(self, *args, **kwargs):
pass
```
"""
kwargs.setdefault("shared", False)

def deco(wrapped):
Expand Down Expand Up @@ -193,7 +277,7 @@ def add_task():
def includeme(config):
s = config.registry.settings

broker_transport_options = {}
broker_transport_options: dict[str, str | dict] = {}

broker_url = s.get("celery.broker_url")
if broker_url is None:
Expand Down