diff --git a/reflex/state.py b/reflex/state.py index 88eb7ae7364..a3961d6042c 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -28,6 +28,7 @@ ) import dill +from sqlalchemy.orm import DeclarativeBase try: import pydantic.v1 as pydantic @@ -2919,7 +2920,7 @@ class MutableProxy(wrapt.ObjectProxy): pydantic.BaseModel.__dict__ ) - __mutable_types__ = (list, dict, set, Base) + __mutable_types__ = (list, dict, set, Base, DeclarativeBase) def __init__(self, wrapped: Any, state: BaseState, field_name: str): """Create a proxy for a mutable object that tracks changes. diff --git a/tests/conftest.py b/tests/conftest.py index 71815ca9ad4..589d35cd71b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -231,7 +231,7 @@ def tmp_working_dir(tmp_path): @pytest.fixture -def mutable_state(): +def mutable_state() -> MutableTestState: """Create a Test state containing mutable types. Returns: diff --git a/tests/states/mutation.py b/tests/states/mutation.py index 5825b6d12bd..b05f558a1ac 100644 --- a/tests/states/mutation.py +++ b/tests/states/mutation.py @@ -2,8 +2,12 @@ from typing import Dict, List, Set, Union +from sqlalchemy import ARRAY, JSON, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + import reflex as rx from reflex.state import BaseState +from reflex.utils.serializers import serializer class DictMutationTestState(BaseState): @@ -145,15 +149,47 @@ class CustomVar(rx.Base): custom: OtherBase = OtherBase() +class MutableSQLABase(DeclarativeBase): + """SQLAlchemy base model for mutable vars.""" + + pass + + +class MutableSQLAModel(MutableSQLABase): + """SQLAlchemy model for mutable vars.""" + + __tablename__: str = "mutable_test_state" + + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + strlist: Mapped[List[str]] = mapped_column(ARRAY(String)) + hashmap: Mapped[Dict[str, str]] = mapped_column(JSON) + test_set: Mapped[Set[str]] = mapped_column(ARRAY(String)) + + +@serializer +def serialize_mutable_sqla_model( + model: MutableSQLAModel, +) -> Dict[str, Union[List[str], Dict[str, str]]]: + """Serialize the MutableSQLAModel. + + Args: + model: The MutableSQLAModel instance to serialize. + + Returns: + The serialized model. + """ + return {"strlist": model.strlist, "hashmap": model.hashmap} + + class MutableTestState(BaseState): """A test state.""" - array: List[Union[str, List, Dict[str, str]]] = [ + array: List[Union[str, int, List, Dict[str, str]]] = [ "value", [1, 2, 3], {"key": "value"}, ] - hashmap: Dict[str, Union[List, str, Dict[str, str]]] = { + hashmap: Dict[str, Union[List, str, Dict[str, Union[str, Dict]]]] = { "key": ["list", "of", "values"], "another_key": "another_value", "third_key": {"key": "value"}, @@ -161,6 +197,11 @@ class MutableTestState(BaseState): test_set: Set[Union[str, int]] = {1, 2, 3, 4, "five"} custom: CustomVar = CustomVar() _be_custom: CustomVar = CustomVar() + sqla_model: MutableSQLAModel = MutableSQLAModel( + strlist=["a", "b", "c"], + hashmap={"key": "value"}, + test_set={"one", "two", "three"}, + ) def reassign_mutables(self): """Assign mutable fields to different values.""" @@ -171,3 +212,8 @@ def reassign_mutables(self): "mod_third_key": {"key": "value"}, } self.test_set = {1, 2, 3, 4, "five"} + self.sqla_model = MutableSQLAModel( + strlist=["d", "e", "f"], + hashmap={"key": "value"}, + test_set={"one", "two", "three"}, + ) diff --git a/tests/test_state.py b/tests/test_state.py index 211ca297023..d3cd700c47c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -7,7 +7,7 @@ import os import sys from textwrap import dedent -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Union from unittest.mock import AsyncMock, Mock import dill @@ -40,6 +40,7 @@ from reflex.utils import format, prerequisites, types from reflex.utils.format import json_dumps from reflex.vars import BaseVar, ComputedVar +from tests.states.mutation import MutableSQLAModel, MutableTestState from .states import GenState @@ -1389,7 +1390,7 @@ def handler(self): assert bms._be_method() -def test_setattr_of_mutable_types(mutable_state): +def test_setattr_of_mutable_types(mutable_state: MutableTestState): """Test that mutable types are converted to corresponding Reflex wrappers. Args: @@ -1398,6 +1399,7 @@ def test_setattr_of_mutable_types(mutable_state): array = mutable_state.array hashmap = mutable_state.hashmap test_set = mutable_state.test_set + sqla_model = mutable_state.sqla_model assert isinstance(array, MutableProxy) assert isinstance(array, list) @@ -1425,11 +1427,21 @@ def test_setattr_of_mutable_types(mutable_state): assert isinstance(mutable_state.custom.test_set, set) assert isinstance(mutable_state.custom.custom, MutableProxy) + assert isinstance(sqla_model, MutableProxy) + assert isinstance(sqla_model, MutableSQLAModel) + assert isinstance(sqla_model.strlist, MutableProxy) + assert isinstance(sqla_model.strlist, list) + assert isinstance(sqla_model.hashmap, MutableProxy) + assert isinstance(sqla_model.hashmap, dict) + assert isinstance(sqla_model.test_set, MutableProxy) + assert isinstance(sqla_model.test_set, set) + mutable_state.reassign_mutables() array = mutable_state.array hashmap = mutable_state.hashmap test_set = mutable_state.test_set + sqla_model = mutable_state.sqla_model assert isinstance(array, MutableProxy) assert isinstance(array, list) @@ -1448,6 +1460,15 @@ def test_setattr_of_mutable_types(mutable_state): assert isinstance(test_set, MutableProxy) assert isinstance(test_set, set) + assert isinstance(sqla_model, MutableProxy) + assert isinstance(sqla_model, MutableSQLAModel) + assert isinstance(sqla_model.strlist, MutableProxy) + assert isinstance(sqla_model.strlist, list) + assert isinstance(sqla_model.hashmap, MutableProxy) + assert isinstance(sqla_model.hashmap, dict) + assert isinstance(sqla_model.test_set, MutableProxy) + assert isinstance(sqla_model.test_set, set) + def test_error_on_state_method_shadow(): """Test that an error is thrown when an event handler shadows a state method.""" @@ -2089,7 +2110,7 @@ async def test_background_task_no_chain(): await bts.bad_chain2() -def test_mutable_list(mutable_state): +def test_mutable_list(mutable_state: MutableTestState): """Test that mutable lists are tracked correctly. Args: @@ -2119,7 +2140,7 @@ def assert_array_dirty(): assert_array_dirty() mutable_state.array.reverse() assert_array_dirty() - mutable_state.array.sort() + mutable_state.array.sort() # type: ignore[reportCallIssue,reportUnknownMemberType] assert_array_dirty() mutable_state.array[0] = 666 assert_array_dirty() @@ -2143,7 +2164,7 @@ def assert_array_dirty(): assert_array_dirty() -def test_mutable_dict(mutable_state): +def test_mutable_dict(mutable_state: MutableTestState): """Test that mutable dicts are tracked correctly. Args: @@ -2157,40 +2178,40 @@ def assert_hashmap_dirty(): assert not mutable_state.dirty_vars # Test all dict operations - mutable_state.hashmap.update({"new_key": 43}) + mutable_state.hashmap.update({"new_key": "43"}) assert_hashmap_dirty() - assert mutable_state.hashmap.setdefault("another_key", 66) == "another_value" + assert mutable_state.hashmap.setdefault("another_key", "66") == "another_value" assert_hashmap_dirty() - assert mutable_state.hashmap.setdefault("setdefault_key", 67) == 67 + assert mutable_state.hashmap.setdefault("setdefault_key", "67") == "67" assert_hashmap_dirty() - assert mutable_state.hashmap.setdefault("setdefault_key", 68) == 67 + assert mutable_state.hashmap.setdefault("setdefault_key", "68") == "67" assert_hashmap_dirty() - assert mutable_state.hashmap.pop("new_key") == 43 + assert mutable_state.hashmap.pop("new_key") == "43" assert_hashmap_dirty() mutable_state.hashmap.popitem() assert_hashmap_dirty() mutable_state.hashmap.clear() assert_hashmap_dirty() - mutable_state.hashmap["new_key"] = 42 + mutable_state.hashmap["new_key"] = "42" assert_hashmap_dirty() del mutable_state.hashmap["new_key"] assert_hashmap_dirty() if sys.version_info >= (3, 9): - mutable_state.hashmap |= {"new_key": 44} + mutable_state.hashmap |= {"new_key": "44"} assert_hashmap_dirty() # Test nested dict operations mutable_state.hashmap["array"] = [] assert_hashmap_dirty() - mutable_state.hashmap["array"].append(1) + mutable_state.hashmap["array"].append("1") assert_hashmap_dirty() mutable_state.hashmap["dict"] = {} assert_hashmap_dirty() - mutable_state.hashmap["dict"]["key"] = 42 + mutable_state.hashmap["dict"]["key"] = "42" assert_hashmap_dirty() mutable_state.hashmap["dict"]["dict"] = {} assert_hashmap_dirty() - mutable_state.hashmap["dict"]["dict"]["key"] = 43 + mutable_state.hashmap["dict"]["dict"]["key"] = "43" assert_hashmap_dirty() # Test proxy returned from `setdefault` and `get` @@ -2212,14 +2233,14 @@ def assert_hashmap_dirty(): mutable_value_third_ref = mutable_state.hashmap.pop("setdefault_mutable_key") assert not isinstance(mutable_value_third_ref, MutableProxy) assert_hashmap_dirty() - mutable_value_third_ref.append("baz") + mutable_value_third_ref.append("baz") # type: ignore[reportUnknownMemberType,reportAttributeAccessIssue,reportUnusedCallResult] assert not mutable_state.dirty_vars # Unfortunately previous refs still will mark the state dirty... nothing doing about that assert mutable_value.pop() assert_hashmap_dirty() -def test_mutable_set(mutable_state): +def test_mutable_set(mutable_state: MutableTestState): """Test that mutable sets are tracked correctly. Args: @@ -2261,7 +2282,7 @@ def assert_set_dirty(): assert_set_dirty() -def test_mutable_custom(mutable_state): +def test_mutable_custom(mutable_state: MutableTestState): """Test that mutable custom types derived from Base are tracked correctly. Args: @@ -2276,17 +2297,38 @@ def assert_custom_dirty(): mutable_state.custom.foo = "bar" assert_custom_dirty() - mutable_state.custom.array.append(42) + mutable_state.custom.array.append("42") assert_custom_dirty() - mutable_state.custom.hashmap["key"] = 68 + mutable_state.custom.hashmap["key"] = "value" assert_custom_dirty() - mutable_state.custom.test_set.add(42) + mutable_state.custom.test_set.add("foo") assert_custom_dirty() mutable_state.custom.custom.bar = "baz" assert_custom_dirty() -def test_mutable_backend(mutable_state): +def test_mutable_sqla_model(mutable_state: MutableTestState): + """Test that mutable SQLA models are tracked correctly. + + Args: + mutable_state: A test state. + """ + assert not mutable_state.dirty_vars + + def assert_sqla_model_dirty(): + assert mutable_state.dirty_vars == {"sqla_model"} + mutable_state._clean() + assert not mutable_state.dirty_vars + + mutable_state.sqla_model.strlist.append("foo") + assert_sqla_model_dirty() + mutable_state.sqla_model.hashmap["key"] = "value" + assert_sqla_model_dirty() + mutable_state.sqla_model.test_set.add("bar") + assert_sqla_model_dirty() + + +def test_mutable_backend(mutable_state: MutableTestState): """Test that mutable backend vars are tracked correctly. Args: @@ -2301,11 +2343,11 @@ def assert_custom_dirty(): mutable_state._be_custom.foo = "bar" assert_custom_dirty() - mutable_state._be_custom.array.append(42) + mutable_state._be_custom.array.append("baz") assert_custom_dirty() - mutable_state._be_custom.hashmap["key"] = 68 + mutable_state._be_custom.hashmap["key"] = "value" assert_custom_dirty() - mutable_state._be_custom.test_set.add(42) + mutable_state._be_custom.test_set.add("foo") assert_custom_dirty() mutable_state._be_custom.custom.bar = "baz" assert_custom_dirty() @@ -2318,7 +2360,7 @@ def assert_custom_dirty(): (copy.deepcopy,), ], ) -def test_mutable_copy(mutable_state, copy_func): +def test_mutable_copy(mutable_state: MutableTestState, copy_func: Callable): """Test that mutable types are copied correctly. Args: @@ -2345,7 +2387,7 @@ def test_mutable_copy(mutable_state, copy_func): (copy.deepcopy,), ], ) -def test_mutable_copy_vars(mutable_state, copy_func): +def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable): """Test that mutable types are copied correctly. Args: