diff --git a/warehouse/db.py b/warehouse/db.py index 13966f4a54a3..91489fa216ef 100644 --- a/warehouse/db.py +++ b/warehouse/db.py @@ -23,7 +23,7 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.declarative import declarative_base # type: ignore -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session as BaseSession, raiseload, sessionmaker from warehouse.metrics import IMetricsService from warehouse.utils.attrs import make_repr @@ -92,10 +92,18 @@ class Model(ModelBase): ) +# Custom Session to prevent lazy-loading +class StrictSession(BaseSession): + def query(self, *entities, **kwargs): + query = super().query(*entities) + query = query.options(raiseload("*")) + return query + + # Create our session class here, this will stay stateless as we'll bind the # engine to each new state we create instead of binding it to the session # class. -Session = sessionmaker() +Session = sessionmaker(class_=StrictSession) def listens_for(target, identifier, *args, **kwargs):