Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for case insensitive regexp + add support for REGEXP SQLite module #1737

Merged
merged 17 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ test: deps
test_sqlite:
$(py_warn) TORTOISE_TEST_DB=sqlite://:memory: pytest --cov-report= $(pytest_opts)

test_sqlite_regexp:
$(py_warn) TORTOISE_TEST_DB=sqlite://:memory:?install_regexp_functions=True pytest --cov-report= $(pytest_opts)

test_postgres_asyncpg:
python -V | grep PyPy || $(py_warn) TORTOISE_TEST_DB="asyncpg://postgres:$(TORTOISE_POSTGRES_PASS)@127.0.0.1:5432/test_\{\}" pytest $(pytest_opts) --cov-append --cov-report=

Expand Down
6 changes: 4 additions & 2 deletions docs/query.rst
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,9 @@ The ``filter`` option allows you to filter the JSON object by its keys and value
obj5 = await JSONModel.filter(data__filter={"owner__name__isnull": True}).first()
obj6 = await JSONModel.filter(data__filter={"owner__last__not_isnull": False}).first()

In PostgreSQL and MySQL, you can use ``postgres_posix_regex`` to make comparisons using POSIX regular expressions:
In PostgreSQL, this is done with the ``~`` operator, while in MySQL the ``REGEXP`` operator is used.
In PostgreSQL, MySQL, and SQLite, you can use ``posix_regex`` to make comparisons using POSIX regular expressions.
On PostgreSQL, this uses the ``~`` operator; on MySQL and SQLite, it uses the ``REGEXP`` operator.
PostgreSQL and SQLite also support ``iposix_regex``, which makes case-insensitive comparisons.


.. code-block:: python3
Expand All @@ -281,6 +282,7 @@ In PostgreSQL, this is done with the ``~`` operator, while in MySQL the ``REGEXP

await DemoModel.create(demo_text="Hello World")
obj = await DemoModel.filter(demo_text__posix_regex="^Hello World$").first()
obj = await DemoModel.filter(demo_text__iposix_regex="^hello world$").first()


In PostgreSQL, ``filter`` supports additional lookup types:
Expand Down
15 changes: 11 additions & 4 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,21 @@ def test_clear_storage(self, mocked_get_storage: Mock):
@patch("tortoise.connection.importlib.import_module")
def test_discover_client_class_proper_impl(self, mocked_import_module: Mock):
mocked_import_module.return_value = Mock(client_class="some_class")
client_class = self.conn_handler._discover_client_class("blah")
del mocked_import_module.return_value.get_client_class
client_class = self.conn_handler._discover_client_class({"engine": "blah"})

mocked_import_module.assert_called_once_with("blah")
self.assertEqual(client_class, "some_class")

@patch("tortoise.connection.importlib.import_module")
def test_discover_client_class_improper_impl(self, mocked_import_module: Mock):
del mocked_import_module.return_value.client_class
del mocked_import_module.return_value.get_client_class
engine = "some_engine"
with self.assertRaises(
ConfigurationError, msg=f'Backend for engine "{engine}" does not implement db client'
):
_ = self.conn_handler._discover_client_class(engine)
_ = self.conn_handler._discover_client_class({"engine": engine})

@patch("tortoise.connection.ConnectionHandler.db_config", new_callable=PropertyMock)
def test_get_db_info_present(self, mocked_db_config: Mock):
Expand Down Expand Up @@ -156,7 +159,9 @@ def test_create_connection_db_info_str(

mocked_get_db_info.assert_called_once_with(alias)
mocked_expand_db_url.assert_called_once_with("some_db_url")
mocked_discover_client_class.assert_called_once_with("some_engine")
mocked_discover_client_class.assert_called_once_with(
{"engine": "some_engine", "credentials": {"cred_key": "some_val"}}
)
expected_client_class.assert_called_once_with(**expected_db_params)
self.assertEqual(ret_val, "some_connection")

Expand All @@ -182,7 +187,9 @@ def test_create_connection_db_info_not_str(

mocked_get_db_info.assert_called_once_with(alias)
mocked_expand_db_url.assert_not_called()
mocked_discover_client_class.assert_called_once_with("some_engine")
mocked_discover_client_class.assert_called_once_with(
{"engine": "some_engine", "credentials": {"cred_key": "some_val"}}
)
expected_client_class.assert_called_once_with(**expected_db_params)
self.assertEqual(ret_val, "some_connection")

Expand Down
60 changes: 58 additions & 2 deletions tests/test_posix_regex_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
from tortoise.contrib import test


class RegexTestCase(test.TestCase):
async def asyncSetUp(self) -> None:
await super().asyncSetUp()


class TestPosixRegexFilter(test.TestCase):

@test.requireCapability(dialect="mysql")
@test.requireCapability(dialect="postgres")
@test.requireCapability(support_for_posix_regex_queries=True)
async def test_regex_filter(self):
author = await testmodels.Author.create(name="Johann Wolfgang von Goethe")
self.assertEqual(
Expand All @@ -16,3 +20,55 @@ async def test_regex_filter(self):
),
{author.name},
)

@test.requireCapability(dialect="postgres", support_for_posix_regex_queries=True)
async def test_regex_filter_works_with_null_field_postgres(self):
    """A ``posix_regex`` filter on an unset field must match no rows on PostgreSQL."""
    # ``desc`` is not passed to ``create``; presumably it is stored as NULL
    # (per the test name) -- confirm against the Tournament model.
    await testmodels.Tournament.create(name="Test")
    matched_names = await testmodels.Tournament.filter(desc__posix_regex="^test$").values_list(
        "name", flat=True
    )
    self.assertEqual(set(matched_names), set())

@test.requireCapability(dialect="sqlite", support_for_posix_regex_queries=True)
async def test_regex_filter_works_with_null_field_sqlite(self):
    """A ``posix_regex`` filter on an unset field must match no rows on SQLite."""
    # ``desc`` is not passed to ``create``; presumably it is stored as NULL
    # (per the test name) -- confirm against the Tournament model.
    await testmodels.Tournament.create(name="Test")
    matched_names = await testmodels.Tournament.filter(desc__posix_regex="^test$").values_list(
        "name", flat=True
    )
    self.assertEqual(set(matched_names), set())


class TestCaseInsensitivePosixRegexFilter(test.TestCase):
    """Tests for the ``iposix_regex`` lookup on backends that declare regex support."""

    @test.requireCapability(dialect="postgres", support_for_posix_regex_queries=True)
    async def test_case_insensitive_regex_filter_postgres(self):
        """The pattern differs from the stored name only in letter case and must still match."""
        created = await testmodels.Author.create(name="Johann Wolfgang von Goethe")
        queryset = testmodels.Author.filter(name__iposix_regex="^johann [a-zA-Z]+ Von goethe$")
        found_names = await queryset.values_list("name", flat=True)
        self.assertEqual(set(found_names), {created.name})

    @test.requireCapability(dialect="sqlite", support_for_posix_regex_queries=True)
    async def test_case_insensitive_regex_filter_sqlite(self):
        """The pattern differs from the stored name only in letter case and must still match."""
        created = await testmodels.Author.create(name="Johann Wolfgang von Goethe")
        queryset = testmodels.Author.filter(name__iposix_regex="^johann [a-zA-Z]+ Von goethe$")
        found_names = await queryset.values_list("name", flat=True)
        self.assertEqual(set(found_names), {created.name})
3 changes: 3 additions & 0 deletions tortoise/backends/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Capabilities:
:param support_for_update: Indicates that this DB supports SELECT ... FOR UPDATE SQL statement.
:param support_index_hint: Support force index or use index.
:param support_update_limit_order_by: support update/delete with limit and order by.
:param support_for_posix_regex_queries: Indicates that this DB supports POSIX regex queries.
"""

def __init__(
Expand All @@ -63,6 +64,7 @@ def __init__(
support_index_hint: bool = False,
# support update/delete with limit and order by
support_update_limit_order_by: bool = True,
support_for_posix_regex_queries: bool = False,
) -> None:
super().__setattr__("_mutable", True)

Expand All @@ -74,6 +76,7 @@ def __init__(
self.support_for_update = support_for_update
self.support_index_hint = support_index_hint
self.support_update_limit_order_by = support_update_limit_order_by
self.support_for_posix_regex_queries = support_for_posix_regex_queries
super().__setattr__("_mutable", False)

def __setattr__(self, attr: str, value: Any) -> None:
Expand Down
5 changes: 4 additions & 1 deletion tortoise/backends/base/config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@
"skip_first_char": False,
"vmap": {"path": "file_path"},
"defaults": {"journal_mode": "WAL", "journal_size_limit": 16384},
"cast": {"journal_size_limit": int},
"cast": {
"journal_size_limit": int,
"install_regexp_functions": bool,
},
},
"mysql": {
"engine": "tortoise.backends.mysql",
Expand Down
4 changes: 3 additions & 1 deletion tortoise/backends/base_postgres/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ class BasePostgresClient(BaseDBAsyncClient, abc.ABC):
query_class: Type[PostgreSQLQuery] = PostgreSQLQuery
executor_class: Type[BasePostgresExecutor] = BasePostgresExecutor
schema_generator: Type[BasePostgresSchemaGenerator] = BasePostgresSchemaGenerator
capabilities = Capabilities("postgres", support_update_limit_order_by=False)
capabilities = Capabilities(
"postgres", support_update_limit_order_by=False, support_for_posix_regex_queries=True
)
connection_class: "Optional[Union[AsyncConnection, Connection]]" = None
loop: Optional[AbstractEventLoop] = None
_pool: Optional[Any] = None
Expand Down
7 changes: 6 additions & 1 deletion tortoise/backends/base_postgres/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@
postgres_json_contains,
postgres_json_filter,
)
from tortoise.contrib.postgres.regex import postgres_posix_regex
from tortoise.contrib.postgres.regex import (
postgres_insensitive_posix_regex,
postgres_posix_regex,
)
from tortoise.contrib.postgres.search import SearchCriterion
from tortoise.filters import (
insensitive_posix_regex,
json_contained_by,
json_contains,
json_filter,
Expand All @@ -35,6 +39,7 @@ class BasePostgresExecutor(BaseExecutor):
json_contained_by: postgres_json_contained_by,
json_filter: postgres_json_filter,
posix_regex: postgres_posix_regex,
insensitive_posix_regex: postgres_insensitive_posix_regex,
}

def _prepare_insert_statement(
Expand Down
6 changes: 5 additions & 1 deletion tortoise/backends/mysql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ class MySQLClient(BaseDBAsyncClient):
executor_class = MySQLExecutor
schema_generator = MySQLSchemaGenerator
capabilities = Capabilities(
"mysql", requires_limit=True, inline_comment=True, support_index_hint=True
"mysql",
requires_limit=True,
inline_comment=True,
support_index_hint=True,
support_for_posix_regex_queries=True,
)

def __init__(
Expand Down
11 changes: 10 additions & 1 deletion tortoise/backends/mysql/executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import enum

from pypika_tortoise import functions
from pypika_tortoise.enums import SqlTypes
from pypika_tortoise.functions import Cast, Coalesce
from pypika_tortoise.terms import BasicCriterion, Criterion
from pypika_tortoise.utils import format_quotes

Expand Down Expand Up @@ -31,6 +34,10 @@
)


class MySQLRegexpComparators(enum.Enum):
    """Comparator tokens used when building MySQL regular-expression criteria."""

    # Value includes the surrounding spaces of the infix ``REGEXP`` operator.
    REGEXP = " REGEXP "


class StrWrapper(ValueWrapper):
"""
Naive str wrapper that doesn't use the monkey-patched pypika ValueWrapper for MySQL
Expand Down Expand Up @@ -97,7 +104,9 @@ def mysql_search(field: Term, value: str) -> SearchCriterion:


def mysql_posix_regex(field: Term, value: str) -> BasicCriterion:
return BasicCriterion(" REGEXP ", field, StrWrapper(value)) # type:ignore[arg-type]
return BasicCriterion(
MySQLRegexpComparators.REGEXP, Coalesce(Cast(field, SqlTypes.CHAR)), StrWrapper(value)
)


class MySQLExecutor(BaseExecutor):
Expand Down
9 changes: 8 additions & 1 deletion tortoise/backends/sqlite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from .client import SqliteClient
from .client import SqliteClient, SqliteClientWithRegexpSupport

client_class = SqliteClient


def get_client_class(db_info: dict):
    """Select the SQLite client class for a connection configuration.

    :param db_info: Connection configuration; its optional ``"credentials"`` mapping
        is checked for a truthy ``"install_regexp_functions"`` entry.
    :return: ``SqliteClientWithRegexpSupport`` when that entry is truthy,
        otherwise ``SqliteClient``.
    """
    credentials = db_info.get("credentials", {})
    wants_regexp = credentials.get("install_regexp_functions")
    return SqliteClientWithRegexpSupport if wants_regexp else SqliteClient
20 changes: 20 additions & 0 deletions tortoise/backends/sqlite/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from tortoise.backends.sqlite.executor import SqliteExecutor
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.connection import connections
from tortoise.contrib.sqlite.regex import (
install_regexp_functions as install_regexp_functions_to_db,
)
from tortoise.exceptions import (
IntegrityError,
OperationalError,
Expand Down Expand Up @@ -212,6 +215,7 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:

class SqliteTransactionWrapper(SqliteClient, TransactionalDBClient):
def __init__(self, connection: SqliteClient) -> None:
self.capabilities = connection.capabilities
self.connection_name = connection.connection_name
self._connection: aiosqlite.Connection = cast(aiosqlite.Connection, connection._connection)
self._lock = asyncio.Lock()
Expand Down Expand Up @@ -272,3 +276,19 @@ async def release_savepoint(self) -> None:

def _gen_savepoint_name(_c=count()) -> str:
    """Return the next savepoint name.

    The default ``_c`` is evaluated once, when the function is defined, so calls
    that omit it all draw from the same iterator (presumably ``itertools.count``).

    :param _c: Iterator supplying the numeric suffix; not meant to be passed by callers.
    :return: ``"tortoise_savepoint_<n>"`` where ``<n>`` is the next value from ``_c``.
    """
    sequence_number = next(_c)
    return f"tortoise_savepoint_{sequence_number}"


class SqliteClientWithRegexpSupport(SqliteClient):
    """SQLite client that declares POSIX regex support and installs the regexp
    functions on every connection it creates."""

    capabilities = Capabilities(
        "sqlite",
        support_for_posix_regex_queries=True,
        support_for_update=False,
        inline_comment=True,
        requires_limit=True,
        daemon=False,
    )

    async def create_connection(self, with_db: bool) -> None:
        """Create the connection via the parent class, then install the regexp functions.

        :param with_db: Forwarded unchanged to ``SqliteClient.create_connection``.
        """
        await super().create_connection(with_db)
        # Nothing to install on if the parent did not leave a usable connection.
        if not self._connection:
            return
        await install_regexp_functions_to_db(self._connection)
9 changes: 9 additions & 0 deletions tortoise/backends/sqlite/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@

from tortoise import Model, fields, timezone
from tortoise.backends.base.executor import BaseExecutor
from tortoise.contrib.sqlite.regex import (
insensitive_posix_sqlite_regexp,
posix_sqlite_regexp,
)
from tortoise.fields import (
BigIntField,
BooleanField,
Expand All @@ -16,6 +20,7 @@
SmallIntField,
TimeField,
)
from tortoise.filters import insensitive_posix_regex, posix_regex


def to_db_bool(
Expand Down Expand Up @@ -91,6 +96,10 @@ class SqliteExecutor(BaseExecutor):
}
EXPLAIN_PREFIX = "EXPLAIN QUERY PLAN"
DB_NATIVE = {bytes, str, int, float}
FILTER_FUNC_OVERRIDE = {
posix_regex: posix_sqlite_regexp,
insensitive_posix_regex: insensitive_posix_sqlite_regexp,
}

async def _process_insert_result(self, instance: Model, results: int) -> None:
pk_field_object = self.model._meta.pk
Expand Down
16 changes: 11 additions & 5 deletions tortoise/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,19 @@ def _copy_storage(self) -> Dict[str, "BaseDBAsyncClient"]:
def _clear_storage(self) -> None:
self._get_storage().clear()

def _discover_client_class(self, db_info: dict) -> Type["BaseDBAsyncClient"]:
    """Resolve the DB client class for a connection configuration.

    The backend module named by ``db_info["engine"]`` is imported. If it exposes a
    ``get_client_class`` factory, that factory is called with ``db_info``; otherwise
    the module-level ``client_class`` attribute is used.

    :param db_info: Connection configuration containing at least the ``"engine"`` key.
    :return: The client class to instantiate for this connection.
    :raises ConfigurationError: If the backend module provides neither
        ``get_client_class`` nor ``client_class``.
    """
    # Let exception bubble up for transparency
    engine_str = db_info.get("engine", "")
    engine_module = importlib.import_module(engine_str)

    # The factory runs outside any ``except AttributeError`` so that an
    # AttributeError raised *inside* the backend's own code surfaces as-is instead
    # of being misreported as "does not implement db client".
    if hasattr(engine_module, "get_client_class"):
        return engine_module.get_client_class(db_info)

    try:
        return engine_module.client_class
    except AttributeError as exc:
        raise ConfigurationError(
            f'Backend for engine "{engine_str}" does not implement db client'
        ) from exc

def _get_db_info(self, conn_alias: str) -> Union[str, Dict]:
Expand All @@ -93,7 +99,7 @@ def _create_connection(self, conn_alias: str) -> "BaseDBAsyncClient":
db_info = self._get_db_info(conn_alias)
if isinstance(db_info, str):
db_info = expand_db_url(db_info)
client_class = self._discover_client_class(db_info.get("engine", ""))
client_class = self._discover_client_class(db_info)
db_params = db_info["credentials"].copy()
db_params.update({"connection_name": conn_alias})
connection: "BaseDBAsyncClient" = client_class(**db_params)
Expand Down
Loading