Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jonyscathe/execute mutate #342

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
295 changes: 147 additions & 148 deletions poetry.lock

Large diffs are not rendered by default.

18 changes: 15 additions & 3 deletions src/mock_alchemy/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
from typing import Optional
from unittest import mock

import sqlalchemy
from packaging import version
from sqlalchemy import __version__ as sqlalchemy_version
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.sql.expression import column
from sqlalchemy.sql.expression import or_
from sqlalchemy.sql.expression import table

from .utils import match_type

Expand All @@ -31,9 +35,17 @@
ALCHEMY_FUNC_TYPE,
ALCHEMY_LABEL_TYPE,
)
if version.parse(sqlalchemy.__version__) >= version.parse("1.4.0"):
if version.parse(sqlalchemy_version) >= version.parse("1.4.0"):
ALCHEMY_SELECT_TYPE = type(select(column("")))
ALCHEMY_TYPES += (ALCHEMY_SELECT_TYPE,)
ALCHEMY_UPDATE_TYPE = type(update(table("")))
ALCHEMY_DELETE_TYPE = type(delete(table("")))
ALCHEMY_INSERT_TYPE = type(insert(table("")))
ALCHEMY_TYPES += (
ALCHEMY_SELECT_TYPE,
ALCHEMY_UPDATE_TYPE,
ALCHEMY_DELETE_TYPE,
ALCHEMY_INSERT_TYPE,
)


class PrettyExpression(object):
Expand Down
166 changes: 165 additions & 1 deletion src/mock_alchemy/mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@
from typing import overload
from unittest import mock

from packaging import version
from sqlalchemy import __version__ as sqlalchemy_version
from sqlalchemy import select
from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql.dml import Delete
from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.dml import Update

from .comparison import ExpressionMatcher
from .utils import build_identity_map
Expand Down Expand Up @@ -454,7 +460,6 @@ class UnifiedAlchemyMagicMock(AlchemyMagicMock):
unify: Dict[str, Optional[UnorderedCall]] = {
"add_columns": None,
"distinct": None,
"execute": None,
"filter": UnorderedCall,
"filter_by": UnorderedCall,
"group_by": None,
Expand All @@ -470,6 +475,8 @@ class UnifiedAlchemyMagicMock(AlchemyMagicMock):

mutate: Set[str] = {"add", "add_all", "delete"}

execute_statement: Set[str] = {"execute"}

@overload
def __init__(
self,
Expand Down Expand Up @@ -518,6 +525,16 @@ def __init__(self, *args, **kwargs) -> None:
}
)

kwargs.update(
{
k: AlchemyMagicMock(
return_value=self,
side_effect=partial(self._execute_statement, _mock_name=k),
)
for k in self.execute_statement
}
)

super(UnifiedAlchemyMagicMock, self).__init__(*args, **kwargs)

def _get_previous_calls(self, calls: Sequence[Call]) -> Iterator:
Expand Down Expand Up @@ -644,6 +661,22 @@ def _mutate_data(self, *args: Any, **kwargs: Any) -> Optional[int]:
else:
_mock_data.append(([query_call], [to_add]))

if version.parse(sqlalchemy_version) >= version.parse("1.4.0"):
execute_call = mock.call.execute(select(type(to_add)))

execute_mocked_data = next(
iter(
filter(
lambda i: i[0] == [ExpressionMatcher(execute_call)],
_mock_data,
)
),
None,
)
if execute_mocked_data:
execute_mocked_data[1].append(to_add)
else:
_mock_data.append(([execute_call], [to_add]))
elif _mock_name == "add_all":
to_add = args[0]
_kwargs = kwargs.copy()
Expand Down Expand Up @@ -685,3 +718,134 @@ def _mutate_data(self, *args: Any, **kwargs: Any) -> Optional[int]:
temp_mock_data.append((calls, result))
self._mock_data = temp_mock_data
return num_deleted

def _execute_insert(
self, execute_statement: Insert, *args: Any, **kwargs: Any
) -> Any:
"""Insert data from execute statement."""
_kwargs = kwargs.copy()
execute_statement = args[0]
_kwargs["_mock_name"] = "add"
table_type = execute_statement.entity_description["type"]
# Values should either be a list of dictionaries as arg[1] or a list of
# dictionaries as values.
if len(args) > 1:
for i in args[1]:
self._mutate_data(table_type(**i), **_kwargs)
else:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an opportunity here to break to down this new logic into a new function?

# Values will be stored within _multi_values list
values = execute_statement._multi_values[0]
for i in values:
self._mutate_data(
table_type(**{k.name: v for k, v in i.items()}), **_kwargs
)
# insert a boundary so that this is no longer part of a unified call.
self.all()
# Start a new unify if the insert statement is returning
if execute_statement._returning:
return self.execute(select(execute_statement._returning[0]))
return None

def _execute_delete(self, execute_statement: Delete, *args: Any) -> mock.Mock:
"""Delete data according to execute statement."""
execute_statement = args[0]
# Create equivalent select statement as an Expression Matcher
select_statement = (
[
ExpressionMatcher(
mock.call.execute(
select(execute_statement.table).where(
execute_statement.whereclause
)
)
)
]
if execute_statement.whereclause is not None
else [ExpressionMatcher(mock.call.execute(select(execute_statement.table)))]
)
_mock_data = self._mock_data = self._mock_data or []
sorted_mock_data = sorted(_mock_data, key=lambda x: len(x[0]), reverse=True)
temp_mock_data = list()
found_query = False
num_deleted = 0
for calls, result in sorted_mock_data:
calls = [
sqlalchemy_call(
i,
with_name=True,
base_call=self.unify.get(i[0]) or Call,
)
for i in calls
]
if all(c in select_statement for c in calls) and not found_query:
num_deleted = len(result)
temp_mock_data.append((calls, []))
found_query = True
else:
temp_mock_data.append((calls, result))
self._mock_data = temp_mock_data
delete_result = mock.Mock()
delete_result.rowcount = num_deleted
# insert a boundary so that this is no longer part of a unified call.
self.all()
return delete_result

def _execute_update(self, execute_statement: Update) -> mock.Mock:
"""Update data according to execute statement."""
# Create equivalent select statement as an Expression Matcher
select_statement = (
[
ExpressionMatcher(
mock.call.execute(
select(execute_statement.table).where(
execute_statement.whereclause
)
)
)
]
if execute_statement.whereclause is not None
else [ExpressionMatcher(mock.call.execute(select(execute_statement.table)))]
)
_mock_data = self._mock_data = self._mock_data or []
sorted_mock_data = sorted(_mock_data, key=lambda x: len(x[0]), reverse=True)
temp_mock_data = list()
found_query = False
num_updated = 0
for calls, result in sorted_mock_data:
calls = [
sqlalchemy_call(
i,
with_name=True,
base_call=self.unify.get(i[0]) or Call,
)
for i in calls
]
if all(c in select_statement for c in calls) and not found_query:
num_updated = len(result)
for r in result:
for k, v in execute_statement._values.items():
setattr(r, k.name, v.value)
temp_mock_data.append((calls, result))
found_query = True
else:
temp_mock_data.append((calls, result))
self._mock_data = temp_mock_data
update_result = mock.Mock()
update_result.rowcount = num_updated
# insert a boundary so that this is no longer part of a unified call.
self.all()
return update_result

def _execute_statement(self, *args: Any, **kwargs: Any) -> Any:
"""Depending on statement being executed, update data and/or unify statement."""
# Need to check if the execute was an insert, update or delete.
execute_statement = args[0]
if isinstance(execute_statement, Insert):
return self._execute_insert(execute_statement, *args, **kwargs)
elif isinstance(execute_statement, Delete):
return self._execute_delete(execute_statement, *args)
elif isinstance(execute_statement, Update):
return self._execute_update(execute_statement)
else:
# assume any other execute types need to unify
return self._unify(self, *args, **kwargs)
4 changes: 2 additions & 2 deletions src/mock_alchemy/sql_alchemy_imports.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""A module for importing SQLAlchemy sessions and calls."""
import sqlalchemy
from packaging import version
from sqlalchemy import __version__ as sqlalchemy_version

if version.parse(sqlalchemy.__version__) >= version.parse("1.4.0"):
if version.parse(sqlalchemy_version) >= version.parse("1.4.0"):
from sqlalchemy.orm import declarative_base
else:
from sqlalchemy.ext.declarative import declarative_base
4 changes: 2 additions & 2 deletions tests/test_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from unittest import mock

import pytest
import sqlalchemy
from packaging import version
from sqlalchemy import __version__ as sqlalchemy_version
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.sql.expression import column
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_expression_matcher() -> None:


@pytest.mark.skipif(
version.parse(sqlalchemy.__version__) < version.parse("1.4.0"),
version.parse(sqlalchemy_version) < version.parse("1.4.0"),
reason="requires sqlalchemy 1.4.0 or higher to run",
)
def test_expression_matcher_select() -> None:
Expand Down
Loading