Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pydantic BaseModel classes in state #983

Merged
merged 5 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mesop/dataclass_utils/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test")
load("//build_defs:defaults.bzl", "THIRD_PARTY_PY_DEEPDIFF", "THIRD_PARTY_PY_PANDAS", "THIRD_PARTY_PY_PYDANTIC", "THIRD_PARTY_PY_PYTEST", "py_library", "py_test")

package(
default_visibility = ["//build_defs:mesop_internal"],
Expand All @@ -13,7 +13,7 @@ py_library(
deps = [
"//mesop/components/uploader:uploaded_file",
"//mesop/exceptions",
] + THIRD_PARTY_PY_DEEPDIFF,
] + THIRD_PARTY_PY_DEEPDIFF + THIRD_PARTY_PY_PYDANTIC,
)

py_test(
Expand Down
58 changes: 44 additions & 14 deletions mesop/dataclass_utils/dataclass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@
from deepdiff import DeepDiff, Delta
from deepdiff.operator import BaseOperator
from deepdiff.path import parse_path
from pydantic import BaseModel

from mesop.components.uploader.uploaded_file import UploadedFile
from mesop.exceptions import MesopDeveloperException, MesopException

_PANDAS_OBJECT_KEY = "__pandas.DataFrame__"
_PYDANTIC_OBJECT_KEY = "__pydantic.BaseModel__"
_DATETIME_OBJECT_KEY = "__datetime.datetime__"
_BYTES_OBJECT_KEY = "__python.bytes__"
_SET_OBJECT_KEY = "__python.set__"
_UPLOADED_FILE_OBJECT_KEY = "__mesop.UploadedFile__"
_DIFF_ACTION_DATA_FRAME_CHANGED = "data_frame_changed"
_DIFF_ACTION_UPLOADED_FILE_CHANGED = "mesop_uploaded_file_changed"
_DIFF_ACTION_EQUALITY_CHANGED = "mesop_equality_changed"

C = TypeVar("C")

Expand All @@ -36,6 +38,8 @@ def _check_has_pandas():

_has_pandas = _check_has_pandas()

pydantic_model_cache = {}


def dataclass_with_defaults(cls: Type[C]) -> Type[C]:
"""
Expand Down Expand Up @@ -64,6 +68,14 @@ def dataclass_with_defaults(cls: Type[C]) -> Type[C]:

annotations = get_type_hints(cls)
for name, type_hint in annotations.items():
if (
isinstance(type_hint, type)
and has_parent(type_hint)
and issubclass(type_hint, BaseModel)
):
pydantic_model_cache[(type_hint.__module__, type_hint.__qualname__)] = (
type_hint
)
if name not in cls.__dict__: # Skip if default already set
if type_hint == int:
setattr(cls, name, field(default=0))
Expand Down Expand Up @@ -187,6 +199,15 @@ def default(self, obj):
}
}

if isinstance(obj, BaseModel):
return {
_PYDANTIC_OBJECT_KEY: {
"json": obj.model_dump_json(),
"module": obj.__class__.__module__,
"qualname": obj.__class__.__qualname__,
}
}

if isinstance(obj, datetime):
return {_DATETIME_OBJECT_KEY: obj.isoformat()}

Expand Down Expand Up @@ -221,6 +242,18 @@ def decode_mesop_json_state_hook(dct):
if _PANDAS_OBJECT_KEY in dct:
return pd.read_json(StringIO(dct[_PANDAS_OBJECT_KEY]), orient="table")

if _PYDANTIC_OBJECT_KEY in dct:
cache_key = (
dct[_PYDANTIC_OBJECT_KEY]["module"],
dct[_PYDANTIC_OBJECT_KEY]["qualname"],
)
if cache_key not in pydantic_model_cache:
raise MesopException(
f"Tried to deserialize Pydantic model, but it's not in the cache: {cache_key}"
)
model_class = pydantic_model_cache[cache_key]
return model_class.model_validate_json(dct[_PYDANTIC_OBJECT_KEY]["json"])
wwwillchen marked this conversation as resolved.
Show resolved Hide resolved

if _DATETIME_OBJECT_KEY in dct:
return datetime.fromisoformat(dct[_DATETIME_OBJECT_KEY])

Expand Down Expand Up @@ -269,25 +302,22 @@ def give_up_diffing(self, level, diff_instance) -> bool:
return True


class UploadedFileOperator(BaseOperator):
"""Custom operator to detect changes in UploadedFile class.
class EqualityOperator(BaseOperator):
"""Custom operator to detect changes with direct equality.

DeepDiff does not diff the UploadedFile class correctly, so we will just use a normal
equality check, rather than diffing further into the io.BytesIO parent class.

This class could probably be made more generic to handle other classes where we want
to diff using equality checks.
"""

def match(self, level) -> bool:
return isinstance(level.t1, UploadedFile) and isinstance(
level.t2, UploadedFile
return isinstance(level.t1, (UploadedFile, BaseModel)) and isinstance(
level.t2, (UploadedFile, BaseModel)
)

def give_up_diffing(self, level, diff_instance) -> bool:
if level.t1 != level.t2:
diff_instance.custom_report_result(
_DIFF_ACTION_UPLOADED_FILE_CHANGED, level, {"value": level.t2}
_DIFF_ACTION_EQUALITY_CHANGED, level, {"value": level.t2}
)
return True

Expand All @@ -306,7 +336,7 @@ def diff_state(state1: Any, state2: Any) -> str:
raise MesopException("Tried to diff state which was not a dataclass")

custom_actions = []
custom_operators = [UploadedFileOperator()]
custom_operators = [EqualityOperator()]
# Only use the `DataFrameOperator` if pandas exists.
if _has_pandas:
differences = DeepDiff(
Expand All @@ -328,15 +358,15 @@ def diff_state(state1: Any, state2: Any) -> str:
else:
differences = DeepDiff(state1, state2, custom_operators=custom_operators)

# Manually format UploadedFile diffs to flat dict format.
if _DIFF_ACTION_UPLOADED_FILE_CHANGED in differences:
# Manually format diffs to flat dict format.
if _DIFF_ACTION_EQUALITY_CHANGED in differences:
custom_actions = [
{
"path": parse_path(path),
"action": _DIFF_ACTION_UPLOADED_FILE_CHANGED,
"action": _DIFF_ACTION_EQUALITY_CHANGED,
**diff,
}
for path, diff in differences[_DIFF_ACTION_UPLOADED_FILE_CHANGED].items()
for path, diff in differences[_DIFF_ACTION_EQUALITY_CHANGED].items()
]

# Handle the set case which will have a modified path after being JSON encoded.
Expand Down
82 changes: 82 additions & 0 deletions mesop/dataclass_utils/dataclass_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pandas as pd
import pytest
from pydantic import BaseModel

import mesop.protos.ui_pb2 as pb
from mesop.components.uploader.uploaded_file import UploadedFile
Expand Down Expand Up @@ -49,6 +50,35 @@ class WithUploadedFile:
data: UploadedFile = field(default_factory=UploadedFile)


class NestedPydanticModel(BaseModel):
default_value: str = "default"
no_default_value: str


class PydanticModel(BaseModel):
wwwillchen marked this conversation as resolved.
Show resolved Hide resolved
name: str = "World"
counter: int = 0
list_models: list[NestedPydanticModel] = field(default_factory=lambda: [])
nested: NestedPydanticModel = field(
default_factory=lambda: NestedPydanticModel(
no_default_value="<no_default_factory>"
)
)
optional_value: str | None = None
union_value: str | int = 0
tuple_value: tuple[str, int] = ("a", 1)


@dataclass_with_defaults
class WithPydanticModel:
data: PydanticModel


@dataclass_with_defaults
class WithPydanticModelDefaultFactory:
default_factory: PydanticModel = field(default_factory=PydanticModel)


JSON_STR = """{"b": {"c": {"val": "<init>"}},
"list_b": [
{"c": {"val": "1"}},
Expand Down Expand Up @@ -180,6 +210,58 @@ def test_serialize_uploaded_file():
)


def test_serialize_deserialize_pydantic_model():
state = WithPydanticModel()
state.data.name = "Hello"
state.data.counter = 1
state.data.nested = NestedPydanticModel(no_default_value="no_default")
state.data.list_models.append(
NestedPydanticModel(no_default_value="no_default_list_model_val_1")
)
state.data.list_models.append(
NestedPydanticModel(no_default_value="no_default_list_model_val_2")
)
new_state = WithPydanticModel()
update_dataclass_from_json(new_state, serialize_dataclass(state))
assert new_state == state


def test_serialize_deserialize_pydantic_model_set_optional_value():
state = WithPydanticModel()
state.data.optional_value = "optional"
new_state = WithPydanticModel()
update_dataclass_from_json(new_state, serialize_dataclass(state))
assert new_state == state


def test_serialize_deserialize_pydantic_model_set_union_value():
state = WithPydanticModel()
state.data.union_value = "union_value"
new_state = WithPydanticModel()
update_dataclass_from_json(new_state, serialize_dataclass(state))
assert new_state == state


def test_serialize_deserialize_pydantic_model_set_tuple_value():
state = WithPydanticModel()
state.data.tuple_value = ("tuple_value", 1)
new_state = WithPydanticModel()
update_dataclass_from_json(new_state, serialize_dataclass(state))
assert new_state == state


def test_serialize_deserialize_pydantic_model_default_factory():
state = WithPydanticModelDefaultFactory()
state.default_factory.name = "Hello"
state.default_factory.counter = 1
state.default_factory.nested = NestedPydanticModel(
no_default_value="no_default"
)
new_state = WithPydanticModelDefaultFactory()
update_dataclass_from_json(new_state, serialize_dataclass(state))
assert new_state == state


@pytest.mark.parametrize(
"input_bytes, expected_json",
[
Expand Down
30 changes: 29 additions & 1 deletion mesop/dataclass_utils/diff_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pandas as pd
import pytest
from pydantic import BaseModel

from mesop.components.uploader.uploaded_file import UploadedFile
from mesop.dataclass_utils.dataclass_utils import diff_state
Expand Down Expand Up @@ -409,7 +410,7 @@ class C:
assert json.loads(diff_state(s1, s2)) == [
{
"path": ["data"],
"action": "mesop_uploaded_file_changed",
"action": "mesop_equality_changed",
"value": {
"__mesop.UploadedFile__": {
"contents": "ZGF0YQ==",
Expand All @@ -422,6 +423,33 @@ class C:
]


def test_diff_pydantic_model():
class PydanticModel(BaseModel):
name: str = "World"
counter: int = 0

@dataclass
class C:
data: PydanticModel

s1 = C(data=PydanticModel())
s2 = C(data=PydanticModel(name="Hello", counter=1))

assert json.loads(diff_state(s1, s2)) == [
{
"path": ["data"],
"action": "mesop_equality_changed",
"value": {
"__pydantic.BaseModel__": {
"json": '{"name":"Hello","counter":1}',
"module": "dataclass_utils.diff_state_test",
"qualname": "test_diff_pydantic_model.<locals>.PydanticModel",
},
},
}
]


def test_diff_uploaded_file_same_no_diff():
@dataclass
class C:
Expand Down
1 change: 1 addition & 0 deletions mesop/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from mesop.examples import on_load_generator as on_load_generator
from mesop.examples import playground as playground
from mesop.examples import playground_critic as playground_critic
from mesop.examples import pydantic_state as pydantic_state
from mesop.examples import query_params as query_params
from mesop.examples import readme_app as readme_app
from mesop.examples import responsive_layout as responsive_layout
Expand Down
27 changes: 27 additions & 0 deletions mesop/examples/pydantic_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pydantic import BaseModel

import mesop as me


class PydanticModel(BaseModel):
name: str = "World"
counter: int = 0


@me.stateclass
class State:
model: PydanticModel


@me.page(path="/pydantic_state")
def main():
state = me.state(State)
me.text(f"Name: {state.model.name}")
me.text(f"Counter: {state.model.counter}")

me.button("Increment Counter", on_click=on_click)


def on_click(e: me.ClickEvent):
state = me.state(State)
state.model.counter += 1
14 changes: 14 additions & 0 deletions mesop/tests/e2e/pydantic_state_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import {test, expect} from '@playwright/test';

test('pydantic state is serialized and deserialized properly', async ({
page,
}) => {
await page.goto('/pydantic_state');

await expect(page.getByText('Name: world')).toBeVisible();
await expect(page.getByText('Counter: 0')).toBeVisible();
await page.getByRole('button', {name: 'Increment Counter'}).click();
await expect(page.getByText('Counter: 1')).toBeVisible();
await page.getByRole('button', {name: 'Increment Counter'}).click();
await expect(page.getByText('Counter: 2')).toBeVisible();
});
4 changes: 2 additions & 2 deletions mesop/web/src/utils/diff.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export function applyComponentDiff(component: Component, diff: ComponentDiff) {
const STATE_DIFF_VALUES_CHANGED = 'values_changed';
const STATE_DIFF_TYPE_CHANGES = 'type_changes';
const STATE_DIFF_DATA_FRAME_CHANGED = 'data_frame_changed';
const STATE_DIFF_UPLOADED_FILE_CHANGED = 'mesop_uploaded_file_changed';
const STATE_DIFF_EQUALITY_CHANGED = 'mesop_equality_changed';
const STATE_DIFF_ITERABLE_ITEM_REMOVED = 'iterable_item_removed';
const STATE_DIFF_ITERABLE_ITEM_ADDED = 'iterable_item_added';
const STATE_DIFF_SET_ITEM_REMOVED = 'set_item_removed';
Expand Down Expand Up @@ -118,7 +118,7 @@ export function applyStateDiff(stateJson: string, diffJson: string): string {
row.action === STATE_DIFF_VALUES_CHANGED ||
row.action === STATE_DIFF_TYPE_CHANGES ||
row.action === STATE_DIFF_DATA_FRAME_CHANGED ||
row.action === STATE_DIFF_UPLOADED_FILE_CHANGED
row.action === STATE_DIFF_EQUALITY_CHANGED
) {
updateValue(root, row.path, row.value);
} else if (row.action === STATE_DIFF_DICT_ITEM_ADDED) {
Expand Down
2 changes: 1 addition & 1 deletion mesop/web/src/utils/diff_state_spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ describe('applyStateDiff functionality', () => {
const diff = JSON.stringify([
{
path: ['data'],
action: 'mesop_uploaded_file_changed',
action: 'mesop_equality_changed',
value: {
'__mesop.UploadedFile__': {
'contents': 'data',
Expand Down
Loading