Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF-3225] implement __format__ for immutable vars #3617

17 changes: 16 additions & 1 deletion reflex/experimental/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import sys
from typing import Any, Optional, Type

from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.utils import serializers, types
from reflex.utils.exceptions import VarTypeError
from reflex.vars import Var, VarData, _extract_var_data
from reflex.vars import Var, VarData, _extract_var_data, _global_vars


@dataclasses.dataclass(
Expand Down Expand Up @@ -156,3 +157,17 @@ def create(
_var_type=type_,
_var_data=_var_data,
)

def __format__(self, format_spec: str) -> str:
"""Format the var into a Javascript equivalent to an f-string.

Args:
format_spec: The format specifier (Ignored for now).

Returns:
The formatted var.
"""
_global_vars[hash(self)] = self

# Encode the _var_data into the formatted output for tracking purposes.
return f"{REFLEX_VAR_OPENING_TAG}{hash(self)}{REFLEX_VAR_CLOSING_TAG}{self._var_name}"
36 changes: 27 additions & 9 deletions reflex/vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ def _encode_var(value: Var) -> str:
)
_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL)

# Defined global immutable vars.
_global_vars: Dict[int, Var] = {}


def _decode_var(value: str) -> tuple[VarData | None, str]:
"""Decode the state name from a formatted var.
Expand Down Expand Up @@ -294,17 +297,32 @@ def json_loads(s):
start, end = m.span()
value = value[:start] + value[end:]

# Read the JSON, pull out the string length, parse the rest as VarData.
data = json_loads(m.group(1))
string_length = data.pop("string_length", None)
var_data = VarData.parse_obj(data)
serialized_data = m.group(1)

if serialized_data[1:].isnumeric():
# This is a global immutable var.
var = _global_vars[int(serialized_data)]
var_data = var._var_data

if var_data is not None:
realstart = start + offset
var_data.interpolations = [
(realstart, realstart + len(var._var_name))
]

var_datas.append(var_data)
else:
# Read the JSON, pull out the string length, parse the rest as VarData.
data = json_loads(serialized_data)
string_length = data.pop("string_length", None)
var_data = VarData.parse_obj(data)

# Use string length to compute positions of interpolations.
if string_length is not None:
realstart = start + offset
var_data.interpolations = [(realstart, realstart + string_length)]
# Use string length to compute positions of interpolations.
if string_length is not None:
realstart = start + offset
var_data.interpolations = [(realstart, realstart + string_length)]

var_datas.append(var_data)
var_datas.append(var_data)
offset += end - start

return VarData.merge(*var_datas) if var_datas else None, value
Expand Down
3 changes: 3 additions & 0 deletions reflex/vars.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ USED_VARIABLES: Incomplete

def get_unique_variable_name() -> str: ...
def _encode_var(value: Var) -> str: ...

_global_vars: Dict[int, Var] = {}
adhami3310 marked this conversation as resolved.
Show resolved Hide resolved

def _decode_var(value: str) -> tuple[VarData, str]: ...
def _extract_var_data(value: Iterable) -> list[VarData | None]: ...

Expand Down
28 changes: 28 additions & 0 deletions tests/test_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
from pandas import DataFrame

from reflex.base import Base
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.experimental.vars.base import ImmutableVar
from reflex.state import BaseState
from reflex.utils.imports import ImportVar
from reflex.vars import (
BaseVar,
ComputedVar,
Var,
VarData,
computed_var,
)

Expand Down Expand Up @@ -835,6 +839,30 @@ def test_state_with_initial_computed_var(
assert runtime_dict[var_name] == expected_runtime


def test_retrival():
var_without_data = ImmutableVar.create("test")
assert var_without_data is not None

original_var_data = VarData(
state="Test",
imports={"react": [ImportVar(tag="useRef")]},
hooks={"const state = useContext(StateContexts.state)": None},
)

var_with_data = var_without_data._replace(merge_var_data=original_var_data)

f_string = f"foo{var_with_data}bar"

assert REFLEX_VAR_OPENING_TAG in f_string
assert REFLEX_VAR_CLOSING_TAG in f_string

result_var_data = Var.create_safe(f"foo{var_with_data}bar")._var_data
adhami3310 marked this conversation as resolved.
Show resolved Hide resolved
assert result_var_data is not None
assert result_var_data.state == original_var_data.state
assert result_var_data.imports == original_var_data.imports
assert result_var_data.hooks == original_var_data.hooks


@pytest.mark.parametrize(
"out, expected",
[
Expand Down
Loading