From b96eb5c172932528d1361ec821e29741aecc66bc Mon Sep 17 00:00:00 2001 From: Will Chen Date: Wed, 25 Sep 2024 01:25:02 -0700 Subject: [PATCH 1/5] Support pydantic BaseModel classes in state --- mesop/dataclass_utils/BUILD | 4 +- mesop/dataclass_utils/dataclass_utils.py | 49 +++++++++++++++---- mesop/dataclass_utils/dataclass_utils_test.py | 22 +++++++++ mesop/dataclass_utils/diff_state_test.py | 30 +++++++++++- mesop/examples/__init__.py | 1 + mesop/examples/pydantic_state.py | 27 ++++++++++ mesop/tests/e2e/pydantic_state_test.ts | 14 ++++++ mesop/web/src/utils/diff.ts | 4 +- 8 files changed, 136 insertions(+), 15 deletions(-) create mode 100644 mesop/examples/pydantic_state.py create mode 100644 mesop/tests/e2e/pydantic_state_test.ts diff --git a/mesop/dataclass_utils/BUILD b/mesop/dataclass_utils/BUILD index f86151989..b4d09e62e 100644 --- a/mesop/dataclass_utils/BUILD +++ b/mesop/dataclass_utils/BUILD @@ -1,4 +1,4 @@ -load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test") +load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYDANTIC", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test") package( default_visibility = ["//build_defs:mesop_internal"], @@ -13,7 +13,7 @@ py_library( deps = [ "//mesop/components/uploader:uploaded_file", "//mesop/exceptions", - ] + THIRD_PARTY_PY_DEEPDIFF, + ] + THIRD_PARTY_PY_DEEPDIFF + THIRD_PARTY_PY_PYDANTIC, ) py_test( diff --git a/mesop/dataclass_utils/dataclass_utils.py b/mesop/dataclass_utils/dataclass_utils.py index fc6108ee3..1a01b363a 100644 --- a/mesop/dataclass_utils/dataclass_utils.py +++ b/mesop/dataclass_utils/dataclass_utils.py @@ -9,17 +9,19 @@ from deepdiff import DeepDiff, Delta from deepdiff.operator import BaseOperator from deepdiff.path import parse_path +from pydantic import BaseModel from mesop.components.uploader.uploaded_file import UploadedFile from mesop.exceptions import MesopDeveloperException, MesopException _PANDAS_OBJECT_KEY = "__pandas.DataFrame__" +_PYDANTIC_OBJECT_KEY = "__pydantic.BaseModel__" _DATETIME_OBJECT_KEY = "__datetime.datetime__" _BYTES_OBJECT_KEY = "__python.bytes__" _SET_OBJECT_KEY = "__python.set__" _UPLOADED_FILE_OBJECT_KEY = "__mesop.UploadedFile__" _DIFF_ACTION_DATA_FRAME_CHANGED = "data_frame_changed" -_DIFF_ACTION_UPLOADED_FILE_CHANGED = "mesop_uploaded_file_changed" +_DIFF_ACTION_EQUALITY_CHANGED = "mesop_equality_changed" C = TypeVar("C") @@ -36,6 +38,8 @@ def _check_has_pandas(): _has_pandas = _check_has_pandas() +pydantic_model_cache = {} + def dataclass_with_defaults(cls: Type[C]) -> Type[C]: """ @@ -79,6 +83,10 @@ def dataclass_with_defaults(cls: Type[C]) -> Type[C]: setattr(cls, name, field(default_factory=dict)) elif isinstance(type_hint, type): if has_parent(type_hint): + if issubclass(type_hint, BaseModel): + pydantic_model_cache[ + (type_hint.__module__, type_hint.__qualname__) + ] = type_hint # If this isn't a simple class (i.e. it inherits from another class) # then we will preserve its semantics (not try to set default values # because it's not a dataclass) and instantiate it with each new instance @@ -187,6 +195,15 @@ def default(self, obj): } } + if isinstance(obj, BaseModel): + return { + _PYDANTIC_OBJECT_KEY: { + "json": obj.model_dump_json(), + "module": obj.__class__.__module__, + "qualname": obj.__class__.__qualname__, + } + } + if isinstance(obj, datetime): return {_DATETIME_OBJECT_KEY: obj.isoformat()} @@ -221,6 +238,18 @@ def decode_mesop_json_state_hook(dct): if _PANDAS_OBJECT_KEY in dct: return pd.read_json(StringIO(dct[_PANDAS_OBJECT_KEY]), orient="table") + if _PYDANTIC_OBJECT_KEY in dct: + cache_key = ( + dct[_PYDANTIC_OBJECT_KEY]["module"], + dct[_PYDANTIC_OBJECT_KEY]["qualname"], + ) + if cache_key not in pydantic_model_cache: + raise MesopException( + f"Tried to deserialize Pydantic model, but it's not in the cache: {cache_key}" + ) + model_class = pydantic_model_cache[cache_key] + return model_class.model_validate_json(dct[_PYDANTIC_OBJECT_KEY]["json"]) + if _DATETIME_OBJECT_KEY in dct: return datetime.fromisoformat(dct[_DATETIME_OBJECT_KEY]) @@ -269,7 +298,7 @@ def give_up_diffing(self, level, diff_instance) -> bool: return True -class UploadedFileOperator(BaseOperator): +class EqualityOperator(BaseOperator): """Custom operator to detect changes in UploadedFile class. DeepDiff does not diff the UploadedFile class correctly, so we will just use a normal @@ -280,14 +309,14 @@ class UploadedFileOperator(BaseOperator): """ def match(self, level) -> bool: - return isinstance(level.t1, UploadedFile) and isinstance( - level.t2, UploadedFile + return isinstance(level.t1, (UploadedFile, BaseModel)) and isinstance( + level.t2, (UploadedFile, BaseModel) ) def give_up_diffing(self, level, diff_instance) -> bool: if level.t1 != level.t2: diff_instance.custom_report_result( - _DIFF_ACTION_UPLOADED_FILE_CHANGED, level, {"value": level.t2} + _DIFF_ACTION_EQUALITY_CHANGED, level, {"value": level.t2} ) return True @@ -306,7 +335,7 @@ def diff_state(state1: Any, state2: Any) -> str: raise MesopException("Tried to diff state which was not a dataclass") custom_actions = [] - custom_operators = [UploadedFileOperator()] + custom_operators = [EqualityOperator()] # Only use the `DataFrameOperator` if pandas exists. if _has_pandas: differences = DeepDiff( @@ -328,15 +357,15 @@ def diff_state(state1: Any, state2: Any) -> str: else: differences = DeepDiff(state1, state2, custom_operators=custom_operators) - # Manually format UploadedFile diffs to flat dict format. - if _DIFF_ACTION_UPLOADED_FILE_CHANGED in differences: + # Manually format diffs to flat dict format. + if _DIFF_ACTION_EQUALITY_CHANGED in differences: custom_actions = [ { "path": parse_path(path), - "action": _DIFF_ACTION_UPLOADED_FILE_CHANGED, + "action": _DIFF_ACTION_EQUALITY_CHANGED, **diff, } - for path, diff in differences[_DIFF_ACTION_UPLOADED_FILE_CHANGED].items() + for path, diff in differences[_DIFF_ACTION_EQUALITY_CHANGED].items() ] # Handle the set case which will have a modified path after being JSON encoded. diff --git a/mesop/dataclass_utils/dataclass_utils_test.py b/mesop/dataclass_utils/dataclass_utils_test.py index 7fadc3301..88bc838ee 100644 --- a/mesop/dataclass_utils/dataclass_utils_test.py +++ b/mesop/dataclass_utils/dataclass_utils_test.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import pytest +from pydantic import BaseModel import mesop.protos.ui_pb2 as pb from mesop.components.uploader.uploaded_file import UploadedFile @@ -49,6 +50,16 @@ class WithUploadedFile: data: UploadedFile = field(default_factory=UploadedFile) +class PydanticModel(BaseModel): + name: str = "World" + counter: int = 0 + + +@dataclass +class WithPydanticModel: + data: PydanticModel = field(default_factory=PydanticModel) + + JSON_STR = """{"b": {"c": {"val": ""}}, "list_b": [ {"c": {"val": "1"}}, @@ -180,6 +191,17 @@ def test_serialize_uploaded_file(): ) +def test_serialize_pydantic_model(): + serialized_dataclass = serialize_dataclass( + WithPydanticModel(data=PydanticModel(name="Hello", counter=1)) + ) + print("SERIALIZED", serialized_dataclass, type(serialized_dataclass)) + assert ( + serialized_dataclass + == """{"data": {"__pydantic.BaseModel__": {"json": "{\\"name\\":\\"Hello\\",\\"counter\\":1}", "module": "dataclass_utils.dataclass_utils_test", "qualname": "PydanticModel"}}}""" + ) + + @pytest.mark.parametrize( "input_bytes, expected_json", [ diff --git a/mesop/dataclass_utils/diff_state_test.py b/mesop/dataclass_utils/diff_state_test.py index fa52a0fc4..1dd8f06e9 100644 --- a/mesop/dataclass_utils/diff_state_test.py +++ b/mesop/dataclass_utils/diff_state_test.py @@ -5,6 +5,7 @@ import pandas as pd import pytest +from pydantic import BaseModel from mesop.components.uploader.uploaded_file import UploadedFile from mesop.dataclass_utils.dataclass_utils import diff_state @@ -409,7 +410,7 @@ class C: assert json.loads(diff_state(s1, s2)) == [ { "path": ["data"], - "action": "mesop_uploaded_file_changed", + "action": "mesop_equality_changed", "value": { "__mesop.UploadedFile__": { "contents": "ZGF0YQ==", @@ -422,6 +423,33 @@ class C: ] +def test_diff_pydantic_model(): + class PydanticModel(BaseModel): + name: str = "World" + counter: int = 0 + + @dataclass + class C: + data: PydanticModel + + s1 = C(data=PydanticModel()) + s2 = C(data=PydanticModel(name="Hello", counter=1)) + + assert json.loads(diff_state(s1, s2)) == [ + { + "path": ["data"], + "action": "mesop_equality_changed", + "value": { + "__pydantic.BaseModel__": { + "json": '{"name":"Hello","counter":1}', + "module": "dataclass_utils.diff_state_test", + "qualname": "test_diff_pydantic_model..PydanticModel", + }, + }, + } + ] + + def test_diff_uploaded_file_same_no_diff(): @dataclass class C: diff --git a/mesop/examples/__init__.py b/mesop/examples/__init__.py index 21109712f..bd0bb87c7 100644 --- a/mesop/examples/__init__.py +++ b/mesop/examples/__init__.py @@ -35,6 +35,7 @@ from mesop.examples import on_load_generator as on_load_generator from mesop.examples import playground as playground from mesop.examples import playground_critic as playground_critic +from mesop.examples import pydantic_state as pydantic_state from mesop.examples import query_params as query_params from mesop.examples import readme_app as readme_app from mesop.examples import responsive_layout as responsive_layout diff --git a/mesop/examples/pydantic_state.py b/mesop/examples/pydantic_state.py new file mode 100644 index 000000000..53c3943b6 --- /dev/null +++ b/mesop/examples/pydantic_state.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel + +import mesop as me + + +class PydanticModel(BaseModel): + name: str = "World" + counter: int = 0 + + +@me.stateclass +class State: + model: PydanticModel + + +@me.page(path="/pydantic_state") +def main(): + state = me.state(State) + me.text(f"Name: {state.model.name}") + me.text(f"Counter: {state.model.counter}") + + me.button("Increment Counter", on_click=on_click) + + +def on_click(e: me.ClickEvent): + state = me.state(State) + state.model.counter += 1 diff --git a/mesop/tests/e2e/pydantic_state_test.ts b/mesop/tests/e2e/pydantic_state_test.ts new file mode 100644 index 000000000..900ab3159 --- /dev/null +++ b/mesop/tests/e2e/pydantic_state_test.ts @@ -0,0 +1,14 @@ +import {test, expect} from '@playwright/test'; + +test('pydantic state is serialized and deserialized properly', async ({ + page, +}) => { + await page.goto('/pydantic_state'); + + await expect(page.getByText('Name: world')).toBeVisible(); + await expect(page.getByText('Counter: 0')).toBeVisible(); + await page.getByRole('button', {name: 'Increment Counter'}).click(); + await expect(page.getByText('Counter: 1')).toBeVisible(); + // await page.getByRole('button', {name: 'Increment Counter'}).click(); + // await expect(page.getByText('Counter: 2')).toBeVisible(); +}); diff --git a/mesop/web/src/utils/diff.ts b/mesop/web/src/utils/diff.ts index c479eca66..6f77e904c 100644 --- a/mesop/web/src/utils/diff.ts +++ b/mesop/web/src/utils/diff.ts @@ -78,7 +78,7 @@ export function applyComponentDiff(component: Component, diff: ComponentDiff) { const STATE_DIFF_VALUES_CHANGED = 'values_changed'; const STATE_DIFF_TYPE_CHANGES = 'type_changes'; const STATE_DIFF_DATA_FRAME_CHANGED = 'data_frame_changed'; -const STATE_DIFF_UPLOADED_FILE_CHANGED = 'mesop_uploaded_file_changed'; +const STATE_DIFF_EQUALITY_CHANGED = 'mesop_equality_changed'; const STATE_DIFF_ITERABLE_ITEM_REMOVED = 'iterable_item_removed'; const STATE_DIFF_ITERABLE_ITEM_ADDED = 'iterable_item_added'; const STATE_DIFF_SET_ITEM_REMOVED = 'set_item_removed'; @@ -118,7 +118,7 @@ export function applyStateDiff(stateJson: string, diffJson: string): string { row.action === STATE_DIFF_VALUES_CHANGED || row.action === STATE_DIFF_TYPE_CHANGES || row.action === STATE_DIFF_DATA_FRAME_CHANGED || - row.action === STATE_DIFF_UPLOADED_FILE_CHANGED + row.action === STATE_DIFF_EQUALITY_CHANGED ) { updateValue(root, row.path, row.value); } else if (row.action === STATE_DIFF_DICT_ITEM_ADDED) { From 4575212d02d6ee28f0ab473ac90378ed34f5bec0 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Wed, 25 Sep 2024 01:27:05 -0700 Subject: [PATCH 2/5] fix --- mesop/dataclass_utils/dataclass_utils.py | 5 +---- mesop/web/src/utils/diff_state_spec.ts | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/mesop/dataclass_utils/dataclass_utils.py b/mesop/dataclass_utils/dataclass_utils.py index 1a01b363a..afb23df37 100644 --- a/mesop/dataclass_utils/dataclass_utils.py +++ b/mesop/dataclass_utils/dataclass_utils.py @@ -299,13 +299,10 @@ def give_up_diffing(self, level, diff_instance) -> bool: class EqualityOperator(BaseOperator): - """Custom operator to detect changes in UploadedFile class. + """Custom operator to detect changes with direct equality. DeepDiff does not diff the UploadedFile class correctly, so we will just use a normal equality check, rather than diffing further into the io.BytesIO parent class. - - This class could probably be made more generic to handle other classes where we want - to diff using equality checks. """ def match(self, level) -> bool: diff --git a/mesop/web/src/utils/diff_state_spec.ts b/mesop/web/src/utils/diff_state_spec.ts index 608ea3ad3..7aac12d40 100644 --- a/mesop/web/src/utils/diff_state_spec.ts +++ b/mesop/web/src/utils/diff_state_spec.ts @@ -388,7 +388,7 @@ describe('applyStateDiff functionality', () => { const diff = JSON.stringify([ { path: ['data'], - action: 'mesop_uploaded_file_changed', + action: 'mesop_equality_changed', value: { '__mesop.UploadedFile__': { 'contents': 'data', From 4f8b1131141b190273452e31d78fd6bc1cab8827 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Wed, 25 Sep 2024 14:51:11 -0700 Subject: [PATCH 3/5] fix --- mesop/dataclass_utils/dataclass_utils_test.py | 1 - mesop/tests/e2e/pydantic_state_test.ts | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mesop/dataclass_utils/dataclass_utils_test.py b/mesop/dataclass_utils/dataclass_utils_test.py index 88bc838ee..bec86ded9 100644 --- a/mesop/dataclass_utils/dataclass_utils_test.py +++ b/mesop/dataclass_utils/dataclass_utils_test.py @@ -195,7 +195,6 @@ def test_serialize_pydantic_model(): serialized_dataclass = serialize_dataclass( WithPydanticModel(data=PydanticModel(name="Hello", counter=1)) ) - print("SERIALIZED", serialized_dataclass, type(serialized_dataclass)) assert ( serialized_dataclass == """{"data": {"__pydantic.BaseModel__": {"json": "{\\"name\\":\\"Hello\\",\\"counter\\":1}", "module": "dataclass_utils.dataclass_utils_test", "qualname": "PydanticModel"}}}""" diff --git a/mesop/tests/e2e/pydantic_state_test.ts b/mesop/tests/e2e/pydantic_state_test.ts index 900ab3159..dd5a7b589 100644 --- a/mesop/tests/e2e/pydantic_state_test.ts +++ b/mesop/tests/e2e/pydantic_state_test.ts @@ -9,6 +9,6 @@ test('pydantic state is serialized and deserialized properly', async ({ await expect(page.getByText('Counter: 0')).toBeVisible(); await page.getByRole('button', {name: 'Increment Counter'}).click(); await expect(page.getByText('Counter: 1')).toBeVisible(); - // await page.getByRole('button', {name: 'Increment Counter'}).click(); - // await expect(page.getByText('Counter: 2')).toBeVisible(); + await page.getByRole('button', {name: 'Increment Counter'}).click(); + await expect(page.getByText('Counter: 2')).toBeVisible(); }); From 1e4453f1eed04b63d6c2dc70b9b1acac2902565f Mon Sep 17 00:00:00 2001 From: Will Chen Date: Wed, 25 Sep 2024 21:22:46 -0700 Subject: [PATCH 4/5] pr --- mesop/dataclass_utils/dataclass_utils.py | 12 ++-- mesop/dataclass_utils/dataclass_utils_test.py | 68 ++++++++++++++++--- 2 files changed, 68 insertions(+), 12 deletions(-) diff --git a/mesop/dataclass_utils/dataclass_utils.py b/mesop/dataclass_utils/dataclass_utils.py index afb23df37..fd01ebc8b 100644 --- a/mesop/dataclass_utils/dataclass_utils.py +++ b/mesop/dataclass_utils/dataclass_utils.py @@ -68,6 +68,14 @@ def dataclass_with_defaults(cls: Type[C]) -> Type[C]: annotations = get_type_hints(cls) for name, type_hint in annotations.items(): + if ( + isinstance(type_hint, type) + and has_parent(type_hint) + and issubclass(type_hint, BaseModel) + ): + pydantic_model_cache[(type_hint.__module__, type_hint.__qualname__)] = ( + type_hint + ) if name not in cls.__dict__: # Skip if default already set if type_hint == int: setattr(cls, name, field(default=0)) @@ -83,10 +91,6 @@ def dataclass_with_defaults(cls: Type[C]) -> Type[C]: setattr(cls, name, field(default_factory=dict)) elif isinstance(type_hint, type): if has_parent(type_hint): - if issubclass(type_hint, BaseModel): - pydantic_model_cache[ - (type_hint.__module__, type_hint.__qualname__) - ] = type_hint # If this isn't a simple class (i.e. it inherits from another class) # then we will preserve its semantics (not try to set default values # because it's not a dataclass) and instantiate it with each new instance diff --git a/mesop/dataclass_utils/dataclass_utils_test.py b/mesop/dataclass_utils/dataclass_utils_test.py index bec86ded9..936dadcd5 100644 --- a/mesop/dataclass_utils/dataclass_utils_test.py +++ b/mesop/dataclass_utils/dataclass_utils_test.py @@ -50,14 +50,32 @@ class WithUploadedFile: data: UploadedFile = field(default_factory=UploadedFile) +class NestedPydanticModel(BaseModel): + default_value: str = "default" + no_default_value: str + + class PydanticModel(BaseModel): name: str = "World" counter: int = 0 + list_models: list[NestedPydanticModel] = field(default_factory=lambda: []) + nested: NestedPydanticModel = field( + default_factory=lambda: NestedPydanticModel( + no_default_value="" + ) + ) + optional_value: str | None = None + union_value: str | int = 0 -@dataclass +@dataclass_with_defaults class WithPydanticModel: - data: PydanticModel = field(default_factory=PydanticModel) + data: PydanticModel + + +@dataclass_with_defaults +class WithPydanticModelDefaultFactory: + default_factory: PydanticModel = field(default_factory=PydanticModel) JSON_STR = """{"b": {"c": {"val": ""}}, @@ -191,14 +209,48 @@ def test_serialize_uploaded_file(): ) -def test_serialize_pydantic_model(): - serialized_dataclass = serialize_dataclass( - WithPydanticModel(data=PydanticModel(name="Hello", counter=1)) +def test_serialize_deserialize_pydantic_model(): + state = WithPydanticModel() + state.data.name = "Hello" + state.data.counter = 1 + state.data.nested = NestedPydanticModel(no_default_value="no_default") + state.data.list_models.append( + NestedPydanticModel(no_default_value="no_default_list_model_val_1") ) - assert ( - serialized_dataclass - == """{"data": {"__pydantic.BaseModel__": {"json": "{\\"name\\":\\"Hello\\",\\"counter\\":1}", "module": "dataclass_utils.dataclass_utils_test", "qualname": "PydanticModel"}}}""" + state.data.list_models.append( + NestedPydanticModel(no_default_value="no_default_list_model_val_2") ) + new_state = WithPydanticModel() + update_dataclass_from_json(new_state, serialize_dataclass(state)) + assert new_state == state + + +def test_serialize_deserialize_pydantic_model_set_optional_value(): + state = WithPydanticModel() + state.data.optional_value = "optional" + new_state = WithPydanticModel() + update_dataclass_from_json(new_state, serialize_dataclass(state)) + assert new_state == state + + +def test_serialize_deserialize_pydantic_model_set_union_value(): + state = WithPydanticModel() + state.data.union_value = "union_value" + new_state = WithPydanticModel() + update_dataclass_from_json(new_state, serialize_dataclass(state)) + assert new_state == state + + +def test_serialize_deserialize_pydantic_model_default_factory(): + state = WithPydanticModelDefaultFactory() + state.default_factory.name = "Hello" + state.default_factory.counter = 1 + state.default_factory.nested = NestedPydanticModel( + no_default_value="no_default" + ) + new_state = WithPydanticModelDefaultFactory() + update_dataclass_from_json(new_state, serialize_dataclass(state)) + assert new_state == state @pytest.mark.parametrize( From fc4f3cf5380a0aed78f331ca83700d9336db9dc7 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Wed, 25 Sep 2024 21:23:52 -0700 Subject: [PATCH 5/5] done --- mesop/dataclass_utils/dataclass_utils_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mesop/dataclass_utils/dataclass_utils_test.py b/mesop/dataclass_utils/dataclass_utils_test.py index 936dadcd5..2cd5bd2b9 100644 --- a/mesop/dataclass_utils/dataclass_utils_test.py +++ b/mesop/dataclass_utils/dataclass_utils_test.py @@ -66,6 +66,7 @@ class PydanticModel(BaseModel): ) optional_value: str | None = None union_value: str | int = 0 + tuple_value: tuple[str, int] = ("a", 1) @dataclass_with_defaults @@ -241,6 +242,14 @@ def test_serialize_deserialize_pydantic_model_set_union_value(): assert new_state == state +def test_serialize_deserialize_pydantic_model_set_tuple_value(): + state = WithPydanticModel() + state.data.tuple_value = ("tuple_value", 1) + new_state = WithPydanticModel() + update_dataclass_from_json(new_state, serialize_dataclass(state)) + assert new_state == state + + def test_serialize_deserialize_pydantic_model_default_factory(): state = WithPydanticModelDefaultFactory() state.default_factory.name = "Hello"