Skip to content

Commit

Permalink
update to sqlalchemy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel-Faber committed Dec 8, 2023
1 parent 734df09 commit 7ccf874
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 40 deletions.
21 changes: 10 additions & 11 deletions src/allocation/adapters/orm.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
import logging
from sqlalchemy import (
Table,
MetaData,
Column,
Integer,
String,
Date,
ForeignKey,
event,
)
from sqlalchemy.orm import mapper, relationship
from sqlalchemy.orm import registry, relationship

from allocation.domain import model

logger = logging.getLogger(__name__)

metadata = MetaData()
mapper_registry = registry()

order_lines = Table(
"order_lines",
metadata,
mapper_registry.metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("sku", String(255)),
Column("qty", Integer, nullable=False),
Expand All @@ -28,14 +27,14 @@

products = Table(
"products",
metadata,
mapper_registry.metadata,
Column("sku", String(255), primary_key=True),
Column("version_number", Integer, nullable=False, server_default="0"),
)

batches = Table(
"batches",
metadata,
mapper_registry.metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("reference", String(255)),
Column("sku", ForeignKey("products.sku")),
Expand All @@ -45,15 +44,15 @@

allocations = Table(
"allocations",
metadata,
mapper_registry.metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("orderline_id", ForeignKey("order_lines.id")),
Column("batch_id", ForeignKey("batches.id")),
)

allocations_view = Table(
"allocations_view",
metadata,
mapper_registry.metadata,
Column("orderid", String(255)),
Column("sku", String(255)),
Column("batchref", String(255)),
Expand All @@ -62,8 +61,8 @@

def start_mappers():
logger.info("Starting mappers")
lines_mapper = mapper(model.OrderLine, order_lines)
batches_mapper = mapper(
lines_mapper = mapper_registry.map_imperatively(model.OrderLine, order_lines)
batches_mapper = mapper_registry.map_imperatively(
model.Batch,
batches,
properties={
Expand All @@ -74,7 +73,7 @@ def start_mappers():
)
},
)
mapper(
mapper_registry.map_imperatively(
model.Product,
products,
properties={"batches": relationship(batches_mapper)},
Expand Down
17 changes: 9 additions & 8 deletions src/allocation/service_layer/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Dict, Callable, Type, TYPE_CHECKING
from allocation.domain import commands, events, model
from allocation.domain.model import OrderLine
from sqlalchemy.sql import text

if TYPE_CHECKING:
from allocation.adapters import notifications
Expand Down Expand Up @@ -83,10 +84,10 @@ def add_allocation_to_read_model(
):
with uow:
uow.session.execute(
"""
INSERT INTO allocations_view (orderid, sku, batchref)
VALUES (:orderid, :sku, :batchref)
""",
text(
"INSERT INTO allocations_view (orderid, sku, batchref)"
" VALUES (:orderid, :sku, :batchref)"
),
dict(orderid=event.orderid, sku=event.sku, batchref=event.batchref),
)
uow.commit()
Expand All @@ -98,10 +99,10 @@ def remove_allocation_from_read_model(
):
with uow:
uow.session.execute(
"""
DELETE FROM allocations_view
WHERE orderid = :orderid AND sku = :sku
""",
text(
"DELETE FROM allocations_view"
" WHERE orderid = :orderid AND sku = :sku"
),
dict(orderid=event.orderid, sku=event.sku),
)
uow.commit()
Expand Down
9 changes: 5 additions & 4 deletions src/allocation/views.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from allocation.service_layer import unit_of_work
from sqlalchemy.sql import text


def allocations(orderid: str, uow: unit_of_work.SqlAlchemyUnitOfWork):
with uow:
results = uow.session.execute(
"""
SELECT sku, batchref FROM allocations_view WHERE orderid = :orderid
""",
text(
"SELECT sku, batchref FROM allocations_view WHERE orderid = :orderid"
),
dict(orderid=orderid),
)
return [dict(r) for r in results]
return [r._asdict() for r in results]
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sqlalchemy.orm import sessionmaker, clear_mappers
from tenacity import retry, stop_after_delay

from allocation.adapters.orm import metadata, start_mappers
from allocation.adapters.orm import mapper_registry, start_mappers
from allocation import config

pytest.register_assert_rewrite("tests.e2e.api_client")
Expand All @@ -20,7 +20,7 @@
@pytest.fixture
def in_memory_sqlite_db():
engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)
mapper_registry.metadata.create_all(engine)
return engine


Expand Down Expand Up @@ -56,7 +56,7 @@ def wait_for_redis_to_come_up():
def postgres_db():
engine = create_engine(config.get_postgres_uri(), isolation_level="SERIALIZABLE")
wait_for_postgres_to_come_up(engine)
metadata.create_all(engine)
mapper_registry.metadata.create_all(engine)
return engine


Expand Down
35 changes: 21 additions & 14 deletions tests/integration/test_uow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,35 @@
from allocation.domain import model
from allocation.service_layer import unit_of_work
from ..random_refs import random_sku, random_batchref, random_orderid
from sqlalchemy.sql import text

pytestmark = pytest.mark.usefixtures("mappers")


def insert_batch(session, ref, sku, qty, eta, product_version=1):
session.execute(
"INSERT INTO products (sku, version_number) VALUES (:sku, :version)",
text("INSERT INTO products (sku, version_number) VALUES (:sku, :version)"),
dict(sku=sku, version=product_version),
)
session.execute(
"INSERT INTO batches (reference, sku, _purchased_quantity, eta)"
" VALUES (:ref, :sku, :qty, :eta)",
text(
"INSERT INTO batches (reference, sku, _purchased_quantity, eta)"
" VALUES (:ref, :sku, :qty, :eta)"
),
dict(ref=ref, sku=sku, qty=qty, eta=eta),
)


def get_allocated_batch_ref(session, orderid, sku):
[[orderlineid]] = session.execute(
"SELECT id FROM order_lines WHERE orderid=:orderid AND sku=:sku",
text("SELECT id FROM order_lines WHERE orderid=:orderid AND sku=:sku"),
dict(orderid=orderid, sku=sku),
)
[[batchref]] = session.execute(
"SELECT b.reference FROM allocations JOIN batches AS b ON batch_id = b.id"
" WHERE orderline_id=:orderlineid",
text(
"SELECT b.reference FROM allocations JOIN batches AS b ON batch_id = b.id"
" WHERE orderline_id=:orderlineid"
),
dict(orderlineid=orderlineid),
)
return batchref
Expand Down Expand Up @@ -59,7 +64,7 @@ def test_rolls_back_uncommitted_work_by_default(sqlite_session_factory):
insert_batch(uow.session, "batch1", "MEDIUM-PLINTH", 100, None)

new_session = sqlite_session_factory()
rows = list(new_session.execute('SELECT * FROM "batches"'))
rows = list(new_session.execute(text('SELECT * FROM "batches"')))
assert rows == []


Expand All @@ -74,7 +79,7 @@ class MyException(Exception):
raise MyException()

new_session = sqlite_session_factory()
rows = list(new_session.execute('SELECT * FROM "batches"'))
rows = list(new_session.execute(text('SELECT * FROM "batches"')))
assert rows == []


Expand Down Expand Up @@ -113,20 +118,22 @@ def test_concurrent_updates_to_version_are_not_allowed(postgres_session_factory)
thread2.join()

[[version]] = session.execute(
"SELECT version_number FROM products WHERE sku=:sku",
text("SELECT version_number FROM products WHERE sku=:sku"),
dict(sku=sku),
)
assert version == 2
[exception] = exceptions
assert "could not serialize access due to concurrent update" in str(exception)

orders = session.execute(
"SELECT orderid FROM allocations"
" JOIN batches ON allocations.batch_id = batches.id"
" JOIN order_lines ON allocations.orderline_id = order_lines.id"
" WHERE order_lines.sku=:sku",
text(
"SELECT orderid FROM allocations"
" JOIN batches ON allocations.batch_id = batches.id"
" JOIN order_lines ON allocations.orderline_id = order_lines.id"
" WHERE order_lines.sku=:sku"
),
dict(sku=sku),
)
assert orders.rowcount == 1
with unit_of_work.SqlAlchemyUnitOfWork(postgres_session_factory) as uow:
uow.session.execute("select 1")
uow.session.execute(text("select 1"))

0 comments on commit 7ccf874

Please sign in to comment.