Skip to content

Commit

Permalink
Ensure that event data is attached to a database session when passed …
Browse files Browse the repository at this point in the history
…to analytics collector.
  • Loading branch information
dbernstein committed Dec 11, 2024
1 parent 55b14f4 commit 5c942dd
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 48 deletions.
89 changes: 46 additions & 43 deletions src/palace/manager/celery/tasks/opds_odl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
from dataclasses import dataclass
from typing import Any

from celery import shared_task
Expand All @@ -8,42 +9,34 @@

from palace.manager.api.odl.api import OPDS2WithODLApi
from palace.manager.celery.task import Task
from palace.manager.service.analytics.analytics import Analytics
from palace.manager.service.celery.celery import QueueNames
from palace.manager.service.redis.models.lock import RedisLock
from palace.manager.service.redis.redis import Redis
from palace.manager.sqlalchemy.model.circulationevent import CirculationEvent
from palace.manager.sqlalchemy.model.collection import Collection
from palace.manager.sqlalchemy.model.library import Library
from palace.manager.sqlalchemy.model.licensing import License, LicensePool
from palace.manager.sqlalchemy.model.patron import Hold
from palace.manager.sqlalchemy.model.patron import Hold, Patron
from palace.manager.util.datetime_helpers import utc_now


@dataclass
class CirculationEventData:
library: Library
license_pool: LicensePool
event_type: str
patron: Patron


def remove_expired_holds_for_collection(
db: Session,
collection_id: int,
) -> tuple[int, list[dict[str, Any]]]:
) -> list[CirculationEventData]:
"""
Remove expired holds from the database for this collection.
"""

# generate expiration events for expired holds before deleting them
# lock rows
lock_query = (
select(Hold.id)
.where(
Hold.position == 0,
Hold.end < utc_now(),
Hold.license_pool_id == LicensePool.id,
LicensePool.collection_id == collection_id,
)
.with_for_update()
)

db.execute(lock_query).all()

# a separate query is required to get around the
# "FOR UPDATE cannot be applied to the nullable side of an outer join" issue when trying to use with_for_update
# on the Hold object.
select_query = select(Hold).where(
Hold.position == 0,
Hold.end < utc_now(),
Expand All @@ -55,7 +48,7 @@ def remove_expired_holds_for_collection(
expired_hold_events: list[dict[str, Any]] = []
for hold in expired_holds:
expired_hold_events.append(
dict(
CirculationEventData(
library=hold.library,
license_pool=hold.license_pool,
event_type=CirculationEvent.CM_HOLD_EXPIRED,
Expand All @@ -66,21 +59,13 @@ def remove_expired_holds_for_collection(
# delete the holds
query = (
delete(Hold)
.where(
Hold.position == 0,
Hold.end < utc_now(),
Hold.license_pool_id == LicensePool.id,
LicensePool.collection_id == collection_id,
)
.where(Hold.id.in_(h.id for h in expired_holds))
.execution_options(synchronize_session="fetch")
)
result = db.execute(query)

# We need the type ignores here because result doesn't always have
# a rowcount, but the sqlalchemy docs swear it will in the case of
# a delete statement.
# https://docs.sqlalchemy.org/en/20/tutorial/data_update.html#getting-affected-row-count-from-update-delete
return result.rowcount, expired_hold_events # type: ignore
db.execute(query)

return expired_hold_events


def licensepool_ids_with_holds(
Expand Down Expand Up @@ -119,7 +104,7 @@ def lock_licenses(license_pool: LicensePool) -> None:
def recalculate_holds_for_licensepool(
license_pool: LicensePool,
reservation_period: datetime.timedelta,
) -> tuple[int, list[dict[str, Any]]]:
) -> tuple[int, list[CirculationEventData]]:
# We take out row level locks on all the licenses and holds for this license pool, so that
# everything is in a consistent state while we update the hold queue. This means we should be
# quickly committing the transaction, to avoid contention or deadlocks.
Expand All @@ -144,7 +129,7 @@ def recalculate_holds_for_licensepool(
hold.end = utc_now() + reservation_period
updated += 1
events.append(
dict(
CirculationEventData(
library=hold.library,
license_pool=hold.license_pool,
event_type=CirculationEvent.CM_HOLD_READY_FOR_CHECKOUT,
Expand All @@ -169,21 +154,20 @@ def remove_expired_holds_for_collection_task(task: Task, collection_id: int) ->
A shared task for removing expired holds from the database for a collection
"""
analytics = task.services.analytics.analytics()

with task.transaction() as session:
collection = Collection.by_id(session, collection_id)
removed, events = remove_expired_holds_for_collection(
events = remove_expired_holds_for_collection(
session,
collection_id,
)

collection_name = None if not collection else collection.name
task.log.info(
f"Removed {removed} expired holds for collection {collection_name} ({collection_id})."
f"Removed {len(events)} expired holds for collection {collection_name} ({collection_id})."
)

# publish events only after successful commit
for event in events:
analytics.collect_event(**event)
collect_events(task, events, analytics)


@shared_task(queue=QueueNames.default, bind=True)
Expand Down Expand Up @@ -225,6 +209,27 @@ def _redis_lock_recalculate_holds(client: Redis, collection_id: int) -> RedisLoc
)


def collect_events(
task: Task, events: list[CirculationEventData], analytics: Analytics
) -> None:
"""
Collect events after successful database is commit and any row locks are removed.
We perform this operation outside after completed the transaction to ensure that any row locks
are held for the shortest possible duration in case writing to the s3 analytics provider is slow.
"""
with task.session() as session:
for e in events:
session.refresh(e.library)
session.refresh(e.license_pool)
session.refresh(e.patron)
analytics.collect_event(
event_type=e.event_type,
library=e.library,
license_pool=e.license_pool,
patron=e.patron,
)


@shared_task(queue=QueueNames.default, bind=True)
def recalculate_hold_queue_collection(
task: Task, collection_id: int, batch_size: int = 100, after_id: int | None = None
Expand Down Expand Up @@ -288,9 +293,7 @@ def recalculate_hold_queue_collection(
f"{updated} holds out of date."
)

# fire events after successful database update
for event in events:
analytics.collect_event(**event)
collect_events(task, events, analytics)

if len(license_pool_ids) == batch_size:
# We are done this batch, but there is probably more work to do, we queue up the next batch.
Expand Down
9 changes: 4 additions & 5 deletions tests/manager/celery/tasks/test_opds_odl.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_remove_expired_holds_for_collection(

# Remove the expired holds
assert collection.id is not None
removed, events = remove_expired_holds_for_collection(
events = remove_expired_holds_for_collection(
db.session,
collection.id,
)
Expand All @@ -165,8 +165,6 @@ def test_remove_expired_holds_for_collection(
assert decoy_non_expired_holds.issubset(current_holds)
assert decoy_expired_holds.issubset(current_holds)

assert removed == 10

pools_after = db.session.scalars(
select(func.count()).select_from(LicensePool)
).one()
Expand All @@ -177,7 +175,8 @@ def test_remove_expired_holds_for_collection(
# verify that the correct analytics calls were made
assert len(events) == 10
for event in events:
assert event["event_type"] == CirculationEvent.CM_HOLD_EXPIRED
assert event.event_type == CirculationEvent.CM_HOLD_EXPIRED
assert event.library == db.default_library()


def test_licensepools_with_holds(
Expand Down Expand Up @@ -270,7 +269,7 @@ def test_recalculate_holds_for_licensepool(
# verify that the correct analytics events were returned
assert len(events) == 3
for event in events:
assert event["event_type"] == CirculationEvent.CM_HOLD_READY_FOR_CHECKOUT
assert event.event_type == CirculationEvent.CM_HOLD_READY_FOR_CHECKOUT


def test_remove_expired_holds_for_collection_task(
Expand Down

0 comments on commit 5c942dd

Please sign in to comment.