Skip to content

Commit

Permalink
add Bare SQLAlchemy mutation tracking, improve typing
Browse files Browse the repository at this point in the history
  • Loading branch information
benedikt-bartscher committed Jul 6, 2024
1 parent 109adc2 commit 2999e7f
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 31 deletions.
3 changes: 2 additions & 1 deletion reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)

import dill
from sqlalchemy.orm import DeclarativeBase

try:
import pydantic.v1 as pydantic
Expand Down Expand Up @@ -2919,7 +2920,7 @@ class MutableProxy(wrapt.ObjectProxy):
pydantic.BaseModel.__dict__
)

__mutable_types__ = (list, dict, set, Base)
__mutable_types__ = (list, dict, set, Base, DeclarativeBase)

def __init__(self, wrapped: Any, state: BaseState, field_name: str):
"""Create a proxy for a mutable object that tracks changes.
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def tmp_working_dir(tmp_path):


@pytest.fixture
def mutable_state():
def mutable_state() -> MutableTestState:
"""Create a Test state containing mutable types.
Returns:
Expand Down
50 changes: 48 additions & 2 deletions tests/states/mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

from typing import Dict, List, Set, Union

from sqlalchemy import ARRAY, JSON, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

import reflex as rx
from reflex.state import BaseState
from reflex.utils.serializers import serializer


class DictMutationTestState(BaseState):
Expand Down Expand Up @@ -145,22 +149,59 @@ class CustomVar(rx.Base):
custom: OtherBase = OtherBase()


class MutableSQLABase(DeclarativeBase):
"""SQLAlchemy base model for mutable vars."""

pass


class MutableSQLAModel(MutableSQLABase):
"""SQLAlchemy model for mutable vars."""

__tablename__: str = "mutable_test_state"

id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
strlist: Mapped[List[str]] = mapped_column(ARRAY(String))
hashmap: Mapped[Dict[str, str]] = mapped_column(JSON)
test_set: Mapped[Set[str]] = mapped_column(ARRAY(String))


@serializer
def serialize_mutable_sqla_model(
model: MutableSQLAModel,
) -> Dict[str, Union[List[str], Dict[str, str]]]:
"""Serialize the MutableSQLAModel.
Args:
model: The MutableSQLAModel instance to serialize.
Returns:
The serialized model.
"""
return {"strlist": model.strlist, "hashmap": model.hashmap}


class MutableTestState(BaseState):
"""A test state."""

array: List[Union[str, List, Dict[str, str]]] = [
array: List[Union[str, int, List, Dict[str, str]]] = [
"value",
[1, 2, 3],
{"key": "value"},
]
hashmap: Dict[str, Union[List, str, Dict[str, str]]] = {
hashmap: Dict[str, Union[List, str, Dict[str, Union[str, Dict]]]] = {
"key": ["list", "of", "values"],
"another_key": "another_value",
"third_key": {"key": "value"},
}
test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"}
custom: CustomVar = CustomVar()
_be_custom: CustomVar = CustomVar()
sqla_model: MutableSQLAModel = MutableSQLAModel(
strlist=["a", "b", "c"],
hashmap={"key": "value"},
test_set={"one", "two", "three"},
)

def reassign_mutables(self):
"""Assign mutable fields to different values."""
Expand All @@ -171,3 +212,8 @@ def reassign_mutables(self):
"mod_third_key": {"key": "value"},
}
self.test_set = {1, 2, 3, 4, "five"}
self.sqla_model = MutableSQLAModel(
strlist=["d", "e", "f"],
hashmap={"key": "value"},
test_set={"one", "two", "three"},
)
96 changes: 69 additions & 27 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import sys
from textwrap import dedent
from typing import Any, Dict, Generator, List, Optional, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Union
from unittest.mock import AsyncMock, Mock

import dill
Expand Down Expand Up @@ -40,6 +40,7 @@
from reflex.utils import format, prerequisites, types
from reflex.utils.format import json_dumps
from reflex.vars import BaseVar, ComputedVar
from tests.states.mutation import MutableSQLAModel, MutableTestState

from .states import GenState

Expand Down Expand Up @@ -1389,7 +1390,7 @@ def handler(self):
assert bms._be_method()


def test_setattr_of_mutable_types(mutable_state):
def test_setattr_of_mutable_types(mutable_state: MutableTestState):
"""Test that mutable types are converted to corresponding Reflex wrappers.
Args:
Expand All @@ -1398,6 +1399,7 @@ def test_setattr_of_mutable_types(mutable_state):
array = mutable_state.array
hashmap = mutable_state.hashmap
test_set = mutable_state.test_set
sqla_model = mutable_state.sqla_model

assert isinstance(array, MutableProxy)
assert isinstance(array, list)
Expand Down Expand Up @@ -1425,11 +1427,21 @@ def test_setattr_of_mutable_types(mutable_state):
assert isinstance(mutable_state.custom.test_set, set)
assert isinstance(mutable_state.custom.custom, MutableProxy)

assert isinstance(sqla_model, MutableProxy)
assert isinstance(sqla_model, MutableSQLAModel)
assert isinstance(sqla_model.strlist, MutableProxy)
assert isinstance(sqla_model.strlist, list)
assert isinstance(sqla_model.hashmap, MutableProxy)
assert isinstance(sqla_model.hashmap, dict)
assert isinstance(sqla_model.test_set, MutableProxy)
assert isinstance(sqla_model.test_set, set)

mutable_state.reassign_mutables()

array = mutable_state.array
hashmap = mutable_state.hashmap
test_set = mutable_state.test_set
sqla_model = mutable_state.sqla_model

assert isinstance(array, MutableProxy)
assert isinstance(array, list)
Expand All @@ -1448,6 +1460,15 @@ def test_setattr_of_mutable_types(mutable_state):
assert isinstance(test_set, MutableProxy)
assert isinstance(test_set, set)

assert isinstance(sqla_model, MutableProxy)
assert isinstance(sqla_model, MutableSQLAModel)
assert isinstance(sqla_model.strlist, MutableProxy)
assert isinstance(sqla_model.strlist, list)
assert isinstance(sqla_model.hashmap, MutableProxy)
assert isinstance(sqla_model.hashmap, dict)
assert isinstance(sqla_model.test_set, MutableProxy)
assert isinstance(sqla_model.test_set, set)


def test_error_on_state_method_shadow():
"""Test that an error is thrown when an event handler shadows a state method."""
Expand Down Expand Up @@ -2089,7 +2110,7 @@ async def test_background_task_no_chain():
await bts.bad_chain2()


def test_mutable_list(mutable_state):
def test_mutable_list(mutable_state: MutableTestState):
"""Test that mutable lists are tracked correctly.
Args:
Expand Down Expand Up @@ -2119,7 +2140,7 @@ def assert_array_dirty():
assert_array_dirty()
mutable_state.array.reverse()
assert_array_dirty()
mutable_state.array.sort()
mutable_state.array.sort() # type: ignore[reportCallIssue,reportUnknownMemberType]
assert_array_dirty()
mutable_state.array[0] = 666
assert_array_dirty()
Expand All @@ -2143,7 +2164,7 @@ def assert_array_dirty():
assert_array_dirty()


def test_mutable_dict(mutable_state):
def test_mutable_dict(mutable_state: MutableTestState):
"""Test that mutable dicts are tracked correctly.
Args:
Expand All @@ -2157,40 +2178,40 @@ def assert_hashmap_dirty():
assert not mutable_state.dirty_vars

# Test all dict operations
mutable_state.hashmap.update({"new_key": 43})
mutable_state.hashmap.update({"new_key": "43"})
assert_hashmap_dirty()
assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value"
assert mutable_state.hashmap.setdefault("another_key", "66") == "another_value"
assert_hashmap_dirty()
assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67
assert mutable_state.hashmap.setdefault("setdefault_key", "67") == "67"
assert_hashmap_dirty()
assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67
assert mutable_state.hashmap.setdefault("setdefault_key", "68") == "67"
assert_hashmap_dirty()
assert mutable_state.hashmap.pop("new_key") == 43
assert mutable_state.hashmap.pop("new_key") == "43"
assert_hashmap_dirty()
mutable_state.hashmap.popitem()
assert_hashmap_dirty()
mutable_state.hashmap.clear()
assert_hashmap_dirty()
mutable_state.hashmap["new_key"] = 42
mutable_state.hashmap["new_key"] = "42"
assert_hashmap_dirty()
del mutable_state.hashmap["new_key"]
assert_hashmap_dirty()
if sys.version_info >= (3, 9):
mutable_state.hashmap |= {"new_key": 44}
mutable_state.hashmap |= {"new_key": "44"}
assert_hashmap_dirty()

# Test nested dict operations
mutable_state.hashmap["array"] = []
assert_hashmap_dirty()
mutable_state.hashmap["array"].append(1)
mutable_state.hashmap["array"].append("1")
assert_hashmap_dirty()
mutable_state.hashmap["dict"] = {}
assert_hashmap_dirty()
mutable_state.hashmap["dict"]["key"] = 42
mutable_state.hashmap["dict"]["key"] = "42"
assert_hashmap_dirty()
mutable_state.hashmap["dict"]["dict"] = {}
assert_hashmap_dirty()
mutable_state.hashmap["dict"]["dict"]["key"] = 43
mutable_state.hashmap["dict"]["dict"]["key"] = "43"
assert_hashmap_dirty()

# Test proxy returned from `setdefault` and `get`
Expand All @@ -2212,14 +2233,14 @@ def assert_hashmap_dirty():
mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key")
assert not isinstance(mutable_value_third_ref, MutableProxy)
assert_hashmap_dirty()
mutable_value_third_ref.append("baz")
mutable_value_third_ref.append("baz") # type: ignore[reportUnknownMemberType,reportAttributeAccessIssue,reportUnusedCallResult]
assert not mutable_state.dirty_vars
# Unfortunately previous refs still will mark the state dirty... nothing doing about that
assert mutable_value.pop()
assert_hashmap_dirty()


def test_mutable_set(mutable_state):
def test_mutable_set(mutable_state: MutableTestState):
"""Test that mutable sets are tracked correctly.
Args:
Expand Down Expand Up @@ -2261,7 +2282,7 @@ def assert_set_dirty():
assert_set_dirty()


def test_mutable_custom(mutable_state):
def test_mutable_custom(mutable_state: MutableTestState):
"""Test that mutable custom types derived from Base are tracked correctly.
Args:
Expand All @@ -2276,17 +2297,38 @@ def assert_custom_dirty():

mutable_state.custom.foo = "bar"
assert_custom_dirty()
mutable_state.custom.array.append(42)
mutable_state.custom.array.append("42")
assert_custom_dirty()
mutable_state.custom.hashmap["key"] = 68
mutable_state.custom.hashmap["key"] = "value"
assert_custom_dirty()
mutable_state.custom.test_set.add(42)
mutable_state.custom.test_set.add("foo")
assert_custom_dirty()
mutable_state.custom.custom.bar = "baz"
assert_custom_dirty()


def test_mutable_backend(mutable_state):
def test_mutable_sqla_model(mutable_state: MutableTestState):
"""Test that mutable SQLA models are tracked correctly.
Args:
mutable_state: A test state.
"""
assert not mutable_state.dirty_vars

def assert_sqla_model_dirty():
assert mutable_state.dirty_vars == {"sqla_model"}
mutable_state._clean()
assert not mutable_state.dirty_vars

mutable_state.sqla_model.strlist.append("foo")
assert_sqla_model_dirty()
mutable_state.sqla_model.hashmap["key"] = "value"
assert_sqla_model_dirty()
mutable_state.sqla_model.test_set.add("bar")
assert_sqla_model_dirty()


def test_mutable_backend(mutable_state: MutableTestState):
"""Test that mutable backend vars are tracked correctly.
Args:
Expand All @@ -2301,11 +2343,11 @@ def assert_custom_dirty():

mutable_state._be_custom.foo = "bar"
assert_custom_dirty()
mutable_state._be_custom.array.append(42)
mutable_state._be_custom.array.append("baz")
assert_custom_dirty()
mutable_state._be_custom.hashmap["key"] = 68
mutable_state._be_custom.hashmap["key"] = "value"
assert_custom_dirty()
mutable_state._be_custom.test_set.add(42)
mutable_state._be_custom.test_set.add("foo")
assert_custom_dirty()
mutable_state._be_custom.custom.bar = "baz"
assert_custom_dirty()
Expand All @@ -2318,7 +2360,7 @@ def assert_custom_dirty():
(copy.deepcopy,),
],
)
def test_mutable_copy(mutable_state, copy_func):
def test_mutable_copy(mutable_state: MutableTestState, copy_func: Callable):
"""Test that mutable types are copied correctly.
Args:
Expand All @@ -2345,7 +2387,7 @@ def test_mutable_copy(mutable_state, copy_func):
(copy.deepcopy,),
],
)
def test_mutable_copy_vars(mutable_state, copy_func):
def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable):
"""Test that mutable types are copied correctly.
Args:
Expand Down

0 comments on commit 2999e7f

Please sign in to comment.