Skip to content

Commit

Permalink
Add a connector for SQLAlchemy with Psycopg2
Browse files Browse the repository at this point in the history
  • Loading branch information
Éric Lemoine committed Sep 24, 2021
1 parent 8f15b9b commit 0bb0f7b
Show file tree
Hide file tree
Showing 9 changed files with 505 additions and 6 deletions.
24 changes: 23 additions & 1 deletion docs/howto/sync_defer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,31 @@ situation::


How does it work?
-----------------
~~~~~~~~~~~~~~~~~

The synchronous connector will use a ``psycopg2.pool.ThreadedConnectionPool`` (see
psycopg2 documentation__), which should fit most workflows.

.. __: https://www.psycopg.org/docs/pool.html#psycopg2.pool.ThreadedConnectionPool


``SQLAlchemyPsycopg2Connector``
-------------------------------

If you use SQLAlchemy in your synchronous application, you may want to use an
`SQLAlchemyPsycopg2Connector` instead. The advantage over using a `Psycopg2Connector` is
that Procrastinate can use the same SQLAchemy engine (and connection pool) as the rest
of your application, thereby minimizing the number of database connections.

::

from sqlalchemy import create_engine

import procrastinate

engine = create_engine("postgresql+psycopg2://", echo=True)

app = procrastinate.App(
connector=procrastinate.SQLAlchemyPsycopg2Connector()
)
app.open(engine)
2 changes: 2 additions & 0 deletions procrastinate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from procrastinate.job_context import JobContext
from procrastinate.psycopg2_connector import Psycopg2Connector
from procrastinate.retry import BaseRetryStrategy, RetryStrategy
from procrastinate.sqlalchemy_connector import SQLAlchemyPsycopg2Connector

__all__ = [
"App",
Expand All @@ -16,6 +17,7 @@
"AiopgConnector",
"Psycopg2Connector",
"RetryStrategy",
"SQLAlchemyPsycopg2Connector",
]


Expand Down
21 changes: 18 additions & 3 deletions procrastinate/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
import functools
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Union,
)

from procrastinate import connector as connector_module
from procrastinate import exceptions, jobs, manager, protocols
Expand Down Expand Up @@ -326,8 +336,13 @@ async def check_connection_async(self):
def schema_manager(self) -> schema.SchemaManager:
return schema.SchemaManager(connector=self.connector)

def open(self, pool: Optional[connector_module.Pool] = None) -> "App":
self.connector.open(pool)
def open(
self,
pool_or_engine: Optional[
Union[connector_module.Pool, connector_module.Engine]
] = None,
) -> "App":
self.connector.open(pool_or_engine)
return self

def close(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion procrastinate/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
QUEUEING_LOCK_CONSTRAINT = "procrastinate_jobs_queueing_lock_idx"


Pool = Any # The connection pool can be any pool object compatible with the database.
Pool = Any
Engine = Any


class BaseConnector:
Expand Down
176 changes: 176 additions & 0 deletions procrastinate/sqlalchemy_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import functools
import re
from typing import Any, Callable, Dict, Optional

import psycopg2.errors
import sqlalchemy
from psycopg2.extras import Json

from procrastinate import connector, exceptions


def wrap_exceptions(func: Callable) -> Callable:
"""
Wrap SQLAlchemy errors as connector exceptions.
"""

@functools.wraps(func)
def wrapped(*args, **kwargs):
try:
return func(*args, **kwargs)
except sqlalchemy.exc.IntegrityError as exc:
if isinstance(exc.orig, psycopg2.errors.UniqueViolation):
raise exceptions.UniqueViolation(
constraint_name=exc.orig.diag.constraint_name
)
except sqlalchemy.exc.SQLAlchemyError as exc:
raise exceptions.ConnectorException from exc

# Attaching a custom attribute to ease testability and make the
# decorator more introspectable
wrapped._exceptions_wrapped = True # type: ignore
return wrapped


def wrap_query_exceptions(func: Callable) -> Callable:
"""
Detect "admin shutdown" errors and retry once.
This is to handle the case where the database connection (obtained from the pool)
was actually closed by the server. In this case, SQLAlchemy raises a ``DBAPIError``
with ``connection_invalidated`` set to ``True``, and also invalidates the rest of
the connection pool. So we just retry once, to get a fresh connection.
"""

@functools.wraps(func)
def wrapped(*args, **kwargs):
try:
return func(*args, **kwargs)
except sqlalchemy.exc.DBAPIError as exc:
if exc.connection_invalidated:
return func(*args, **kwargs)
raise exc

return wrapped


PERCENT_PATTERN = re.compile(r"%(?![\(s])")


class SQLAlchemyPsycopg2Connector(connector.BaseConnector):
def __init__(
self,
*,
dsn: Optional[str] = None,
json_dumps: Optional[Callable] = None,
json_loads: Optional[Callable] = None,
**kwargs: Any,
):
"""
Synchronous connector based on SQLAlchemy with Psycopg2.
This is used if you want your ``.defer()`` calls to be purely synchronous, not
asynchronous with a sync wrapper. You may need this if your program is
multi-threaded and doen't handle async loops well
(see `discussion-sync-defer`).
All other arguments than ``dsn``, ``json_dumps``, and ``json_loads`` are passed
to `py:func:`create_engine` (see SQLAlchemy documentation__).
.. __: https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine
Parameters
----------
dsn : The dsn string or URL object passed to SQLAlchemy's ``create_engine``
function. Ignored if the engine is externally created and set into the
connector through the ``App.open`` method.
json_dumps :
The JSON dumps function to use for serializing job arguments. Defaults to
the function used by psycopg2. See the `psycopg2 doc`_.
json_loads :
The JSON loads function to use for deserializing job arguments. Defaults
Python's ``json.loads`` function.
"""
self.json_dumps = json_dumps
self.json_loads = json_loads
self._engine: Optional[sqlalchemy.engine.Engine] = None
self._engine_dsn = dsn
self._engine_args = kwargs
self._engine_externally_set = False

@wrap_exceptions
def open(self, engine: Optional[sqlalchemy.engine.Engine] = None) -> None:
"""
Create an SQLAlchemy engine for the connector.
Parameters
----------
engine :
Optional engine. Procrastinate can use an existing engine. If set the
engine dsn and arguments passed in the constructor will be ignored.
"""
if engine:
self._engine_externally_set = True
self._engine = engine
else:
self._engine = self._create_engine(self._engine_dsn, self._engine_args)

@staticmethod
def _create_engine(
dsn: str, engine_args: Dict[str, Any]
) -> sqlalchemy.engine.Engine:
"""
Create an SQLAlchemy engine.
"""
return sqlalchemy.create_engine(dsn, **engine_args)

@wrap_exceptions
def close(self) -> None:
"""
Dispose of the connection pool used by the SQLAlchemy engine.
"""
if not self._engine_externally_set:
self.engine.dispose()
self._engine = None

@property
def engine(self) -> sqlalchemy.engine.Engine:
if self._engine is None: # Set by open
raise exceptions.AppNotOpen
return self._engine

def _wrap_json(self, arguments: Dict[str, Any]):
return {
key: Json(value, dumps=self.json_dumps)
if isinstance(value, dict)
else value
for key, value in arguments.items()
}

@wrap_exceptions
@wrap_query_exceptions
def execute_query(self, query: str, **arguments: Any) -> None:
with self.engine.begin() as connection:
connection.exec_driver_sql(
PERCENT_PATTERN.sub("%%", query), self._wrap_json(arguments)
)

@wrap_exceptions
@wrap_query_exceptions
def execute_query_one(self, query: str, **arguments: Any) -> Dict[str, Any]:
with self.engine.begin() as connection:
cursor_result = connection.exec_driver_sql(
PERCENT_PATTERN.sub("%%", query), self._wrap_json(arguments)
)
cursor_result = cursor_result.mappings()
return cursor_result.fetchone()

@wrap_exceptions
@wrap_query_exceptions
def execute_query_all(self, query: str, **arguments: Any) -> Dict[str, Any]:
with self.engine.begin() as connection:
cursor_result = connection.exec_driver_sql(
PERCENT_PATTERN.sub("%%", query), self._wrap_json(arguments)
)
cursor_result = cursor_result.mappings()
return cursor_result.all()
25 changes: 24 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from procrastinate import app as app_module
from procrastinate import jobs
from procrastinate import psycopg2_connector as psycopg2_connector_module
from procrastinate import schema, testing
from procrastinate import schema
from procrastinate import sqlalchemy_connector as sqlalchemy_connector_module
from procrastinate import testing

# Just ensuring the tests are not polluted by environment
for key in os.environ:
Expand Down Expand Up @@ -100,6 +102,13 @@ def connection_params(setup_db, db_factory):
yield {"dsn": "", "dbname": "procrastinate_test"}


@pytest.fixture
def sqlalchemy_engine_dsn(setup_db, db_factory):
db_factory(dbname="procrastinate_test", template=setup_db)

yield "postgresql+psycopg2:///procrastinate_test"


@pytest.fixture
async def connection(connection_params):
async with aiopg.connect(**connection_params) as connection:
Expand All @@ -116,6 +125,13 @@ def not_opened_psycopg2_connector(connection_params):
yield psycopg2_connector_module.Psycopg2Connector(**connection_params)


@pytest.fixture
def not_opened_sqlalchemy_psycopg2_connector(sqlalchemy_engine_dsn):
yield sqlalchemy_connector_module.SQLAlchemyPsycopg2Connector(
dsn=sqlalchemy_engine_dsn, echo=True
)


@pytest.fixture
async def aiopg_connector(not_opened_aiopg_connector):
await not_opened_aiopg_connector.open_async()
Expand All @@ -130,6 +146,13 @@ def psycopg2_connector(not_opened_psycopg2_connector):
not_opened_psycopg2_connector.close()


@pytest.fixture
def sqlalchemy_psycopg2_connector(not_opened_sqlalchemy_psycopg2_connector):
not_opened_sqlalchemy_psycopg2_connector.open()
yield not_opened_sqlalchemy_psycopg2_connector
not_opened_sqlalchemy_psycopg2_connector.close()


@pytest.fixture
def kill_own_pid():
def f(signal=stdlib_signal.SIGTERM):
Expand Down
Loading

0 comments on commit 0bb0f7b

Please sign in to comment.