Skip to content

Commit

Permalink
Merge pull request #8 from tanbro/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
tanbro authored Nov 30, 2024
2 parents 755655c + 03d00f9 commit 38a930b
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 115 deletions.
16 changes: 8 additions & 8 deletions src/sqlalchemy_dlock/asyncio/factory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from importlib import import_module
from typing import Type, Union
from string import Template
from typing import Any, Mapping, Type, Union

from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncConnection

from ..utils import pascal_case, safe_name
from .lock.base import BaseAsyncSadLock
from .types import TAsyncConnectionOrSession

Expand All @@ -22,10 +22,10 @@ def create_async_sadlock(
engine = bind.engine
else:
engine = bind
engine_name = safe_name(engine.name)
try:
mod = import_module(f"..lock.{engine_name}", __name__)
except ImportError as exception: # pragma: no cover
raise NotImplementedError(f"{engine_name}: {exception}")
clz: Type[BaseAsyncSadLock] = getattr(mod, f"{pascal_case(engine_name)}AsyncSadLock")
conf: Mapping[str, Any] = getattr(import_module(".registry", __package__), "REGISTRY")[engine.name]
package: Union[str, None] = conf.get("package")
if package:
package = Template(package).safe_substitute(package=__package__)
mod = import_module(conf["module"], package)
clz: Type[BaseAsyncSadLock] = getattr(mod, conf["class"])
return clz(connection_or_session, key, contextual_timeout=contextual_timeout, **kwargs)
8 changes: 4 additions & 4 deletions src/sqlalchemy_dlock/asyncio/lock/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

from ..types import TAsyncConnectionOrSession

TKey = TypeVar("TKey")
KT = TypeVar("KT")


class BaseAsyncSadLock(Generic[TKey]):
class BaseAsyncSadLock(Generic[KT]):
def __init__(
self,
connection_or_session: TAsyncConnectionOrSession,
key: TKey,
key: KT,
/,
contextual_timeout: Union[float, int, None] = None,
**kwargs,
Expand Down Expand Up @@ -49,7 +49,7 @@ def connection_or_session(self) -> TAsyncConnectionOrSession:
return self._connection_or_session

@property
def key(self) -> TKey:
def key(self) -> KT:
return self._key

@property
Expand Down
26 changes: 14 additions & 12 deletions src/sqlalchemy_dlock/asyncio/lock/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,6 @@ def __init__(self, connection_or_session: TAsyncConnectionOrSession, key, **kwar
PostgresqlSadLockMixin.__init__(self, key=key, **kwargs)
BaseAsyncSadLock.__init__(self, connection_or_session, self._actual_key, **kwargs)

@override
async def __aexit__(self, exc_type, exc_value, exc_tb):
if sys.version_info < (3, 11):
with catch_warnings():
return await super().__aexit__(exc_type, exc_value, exc_tb)
else:
with catch_warnings(category=RuntimeWarning):
return await super().__aexit__(exc_type, exc_value, exc_tb)

@override
async def acquire(
self,
Expand Down Expand Up @@ -73,17 +64,28 @@ async def acquire(

@override
async def release(self):
if not self._acquired:
raise ValueError("invoked on an unlocked lock")
if self._stmt_unlock is None:
warn(
"PostgreSQL transaction level advisory locks are held until the current transaction ends; there is no provision for manual release.",
"PostgreSQL transaction level advisory locks are held until the current transaction ends; "
"there is no provision for manual release.",
RuntimeWarning,
)
return
if not self._acquired:
raise ValueError("invoked on an unlocked lock")
ret_val = (await self.connection_or_session.execute(self._stmt_unlock)).scalar_one()
if ret_val:
self._acquired = False
else: # pragma: no cover
self._acquired = False
raise SqlAlchemyDLockDatabaseError(f"The advisory lock {self.key!r} was not held.")

@override
async def close(self):
if self._acquired:
if sys.version_info < (3, 11):
with catch_warnings():
return await self.release()
else:
with catch_warnings(category=RuntimeWarning):
return await self.release()
12 changes: 12 additions & 0 deletions src/sqlalchemy_dlock/asyncio/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
REGISTRY = {
"mysql": {
"module": ".lock.mysql",
"package": "${package}", # module name relative to the package
"class": "MysqlAsyncSadLock",
},
"postgresql": {
"module": ".lock.postgresql",
"package": "${package}", # module name relative to the package
"class": "PostgresqlAsyncSadLock",
},
}
17 changes: 9 additions & 8 deletions src/sqlalchemy_dlock/factory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from importlib import import_module
from typing import Type, Union
from string import Template
from typing import Any, Mapping, Type, Union

from sqlalchemy.engine import Connection

from .lock.base import BaseSadLock
from .types import TConnectionOrSession
from .utils import pascal_case, safe_name

__all__ = ["create_sadlock"]

Expand Down Expand Up @@ -47,10 +47,11 @@ def create_sadlock(
engine = bind.engine
else:
engine = bind
engine_name = safe_name(engine.name)
try:
mod = import_module(f"..lock.{engine_name}", __name__)
except ImportError as exception: # pragma: no cover
raise NotImplementedError(f"{engine_name}: {exception}")
clz: Type[BaseSadLock] = getattr(mod, f"{pascal_case(engine_name)}SadLock")

conf: Mapping[str, Any] = getattr(import_module(".registry", __package__), "REGISTRY")[engine.name]
package: Union[str, None] = conf.get("package")
if package:
package = Template(package).safe_substitute(package=__package__)
mod = import_module(conf["module"], package)
clz: Type[BaseSadLock] = getattr(mod, conf["class"])
return clz(connection_or_session, key, contextual_timeout=contextual_timeout, **kwargs)
8 changes: 4 additions & 4 deletions src/sqlalchemy_dlock/lock/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

from ..types import TConnectionOrSession

TKey = TypeVar("TKey")
KT = TypeVar("KT")


class BaseSadLock(Generic[TKey], local):
class BaseSadLock(Generic[KT], local):
"""Base class of database lock implementation
Note:
Expand Down Expand Up @@ -43,7 +43,7 @@ class BaseSadLock(Generic[TKey], local):
def __init__(
self,
connection_or_session: TConnectionOrSession,
key: TKey,
key: KT,
/,
contextual_timeout: Union[float, int, None] = None,
**kwargs,
Expand Down Expand Up @@ -108,7 +108,7 @@ def connection_or_session(self) -> TConnectionOrSession:
return self._connection_or_session

@property
def key(self) -> TKey:
def key(self) -> KT:
"""ID or name of the SQL locking function
It returns ``key`` parameter of the class's constructor"""
Expand Down
5 changes: 5 additions & 0 deletions src/sqlalchemy_dlock/lock/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def convert(value) -> str:
if len(self._actual_key) > MYSQL_LOCK_NAME_MAX_LENGTH:
raise ValueError(f"MySQL enforces a maximum length on lock names of {MYSQL_LOCK_NAME_MAX_LENGTH} characters.")

@property
def actual_key(self) -> str:
"""The actual key used in MySQL named lock"""
return self._actual_key


class MysqlSadLock(MysqlSadLockMixin, BaseSadLock[str]):
"""A distributed lock implemented by MySQL named-lock
Expand Down
33 changes: 20 additions & 13 deletions src/sqlalchemy_dlock/lock/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def xact(self) -> bool:
"""Is the advisory lock transaction level or session level"""
return self._xact

@property
def actual_key(self) -> int:
"""The actual key used in PostgreSQL advisory lock"""
return self._actual_key


class PostgresqlSadLock(PostgresqlSadLockMixin, BaseSadLock[int]):
"""A distributed lock implemented by PostgreSQL advisory lock
Expand All @@ -114,16 +119,7 @@ def __init__(self, connection_or_session: TConnectionOrSession, key, **kwargs):
**kwargs: other named parameters pass to :class:`.BaseSadLock` and :class:`.PostgresqlSadLockMixin`
""" # noqa: E501
PostgresqlSadLockMixin.__init__(self, key=key, **kwargs)
BaseSadLock.__init__(self, connection_or_session, self._actual_key, **kwargs)

@override
def __exit__(self, exc_type, exc_value, exc_tb):
if sys.version_info < (3, 11):
with catch_warnings():
return super().__exit__(exc_type, exc_value, exc_tb)
else:
with catch_warnings(category=RuntimeWarning):
return super().__exit__(exc_type, exc_value, exc_tb)
BaseSadLock.__init__(self, connection_or_session, self.actual_key, **kwargs)

@override
def acquire(
Expand Down Expand Up @@ -181,17 +177,28 @@ def acquire(

@override
def release(self):
if not self._acquired:
raise ValueError("invoked on an unlocked lock")
if self._stmt_unlock is None:
warn(
"PostgreSQL transaction level advisory locks are held until the current transaction ends; there is no provision for manual release.",
"PostgreSQL transaction level advisory locks are held until the current transaction ends; "
"there is no provision for manual release.",
RuntimeWarning,
)
return
if not self._acquired:
raise ValueError("invoked on an unlocked lock")
ret_val = self.connection_or_session.execute(self._stmt_unlock).scalar_one()
if ret_val:
self._acquired = False
else: # pragma: no cover
self._acquired = False
raise SqlAlchemyDLockDatabaseError(f"The advisory lock {self.key!r} was not held.")

@override
def close(self):
if self._acquired:
if sys.version_info < (3, 11):
with catch_warnings():
return self.release()
else:
with catch_warnings(category=RuntimeWarning):
return self.release()
12 changes: 12 additions & 0 deletions src/sqlalchemy_dlock/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
REGISTRY = {
"mysql": {
"module": ".lock.mysql",
"package": "${package}", # module name relative to the package
"class": "MysqlSadLock",
},
"postgresql": {
"module": ".lock.postgresql",
"package": "${package}", # module name relative to the package
"class": "PostgresqlSadLock",
},
}
66 changes: 0 additions & 66 deletions src/sqlalchemy_dlock/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,69 +53,3 @@ def ensure_int64(i: int) -> int:
if i < -0x8000_0000_0000_0000:
raise OverflowError("int too small")
return i


def camel_case(s: str) -> str:
"""Convert string into camel case.
Args:
s: String to convert
Returns:
Camel case string.
"""
s = re.sub(r"\w[\s\W]+\w", "", s)
if not s:
return s
return lower_case(s[0]) + re.sub(r"[\-_\.\s]([a-z])", lambda x: upper_case(str(x.group(1))), s[1:])


def lower_case(s: str) -> str:
"""Convert string into lower case.
Args:
s: String to convert
Returns:
Lowercase case string.
"""
return s.lower()


def upper_case(s: str) -> str:
"""Convert string into upper case.
Args:
s: String to convert
Returns:
Uppercase case string.
"""
return s.upper()


def capital_case(s: str) -> str:
"""Convert string into capital case.
First letters will be uppercase.
Args:
s: String to convert
Returns:
Capital case string.
"""
if not s:
return s
return upper_case(s[0]) + s[1:]


def pascal_case(s: str) -> str:
"""Convert string into pascal case.
Args:
s: String to convert
Returns:
Pascal case string.
"""
return capital_case(camel_case(s))

0 comments on commit 38a930b

Please sign in to comment.