diff --git a/sqlmodel/main.py b/sqlmodel/main.py index d343c698e9..d456b88015 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -34,6 +34,7 @@ from sqlalchemy import Boolean, Column, Date, DateTime from sqlalchemy import Enum as sa_Enum from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta @@ -207,6 +208,7 @@ def Relationship( @__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): __sqlmodel_relationships__: Dict[str, RelationshipInfo] + __sqlalchemy_constructs__: Dict[str, Any] __config__: Type[BaseConfig] __fields__: Dict[str, ModelField] @@ -232,6 +234,7 @@ def __new__( **kwargs: Any, ) -> Any: relationships: Dict[str, RelationshipInfo] = {} + sqlalchemy_constructs = {} dict_for_pydantic = {} original_annotations = resolve_annotations( class_dict.get("__annotations__", {}), class_dict.get("__module__", None) @@ -241,6 +244,8 @@ def __new__( for k, v in class_dict.items(): if isinstance(v, RelationshipInfo): relationships[k] = v + elif isinstance(v, hybrid_property): + sqlalchemy_constructs[k] = v else: dict_for_pydantic[k] = v for k, v in original_annotations.items(): @@ -253,6 +258,7 @@ def __new__( "__weakref__": None, "__sqlmodel_relationships__": relationships, "__annotations__": pydantic_annotations, + "__sqlalchemy_constructs__": sqlalchemy_constructs, } # Duplicate logic from Pydantic to filter config kwargs because if they are # passed directly including the registry Pydantic will pass them over to the @@ -276,6 +282,9 @@ def __new__( **new_cls.__annotations__, } + for k, v in sqlalchemy_constructs.items(): + setattr(new_cls, k, v) + def get_config(name: str) -> Any: config_class_value = getattr(new_cls.__config__, name, Undefined) if config_class_value is not Undefined: @@ -290,8 +299,9 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config new_cls.__config__.table = config_table for k, v in new_cls.__fields__.items(): - col = get_column_from_field(v) - setattr(new_cls, k, col) + if k in sqlalchemy_constructs: + continue + setattr(new_cls, k, get_column_from_field(v)) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. # This could be done by reading new_cls.__config__.table in FastAPI, but @@ -326,6 +336,8 @@ def __init__( if getattr(cls.__config__, "table", False) and not base_is_table: dict_used = dict_.copy() for field_name, field_value in cls.__fields__.items(): + if field_name in cls.__sqlalchemy_constructs__: + continue dict_used[field_name] = get_column_from_field(field_value) for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: diff --git a/tests/conftest.py b/tests/conftest.py index cd66420c88..1701b2e032 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import pytest from pydantic import BaseModel -from sqlmodel import SQLModel +from sqlmodel import SQLModel, create_engine from sqlmodel.main import default_registry top_level_path = Path(__file__).resolve().parent.parent @@ -23,6 +23,13 @@ def clear_sqlmodel(): default_registry.dispose() +@pytest.fixture() +def in_memory_engine(clear_sqlmodel): + engine = create_engine("sqlite:///memory") + yield engine + SQLModel.metadata.drop_all(engine, checkfirst=True) + + @pytest.fixture() def cov_tmp_path(tmp_path: Path): yield tmp_path diff --git a/tests/test_sqlalchemy_properties.py b/tests/test_sqlalchemy_properties.py new file mode 100644 index 0000000000..5eada00c4c --- /dev/null +++ b/tests/test_sqlalchemy_properties.py @@ -0,0 +1,41 @@ +from typing import Optional + +from sqlalchemy import func +from sqlalchemy.ext.hybrid import hybrid_property +from sqlmodel import Field, Session, SQLModel, select + + +def test_hybrid_property(in_memory_engine): + class Interval(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + length: float + + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.expression + def radius(cls) -> float: + return func.abs(cls.length) / 2 + + class Config: + arbitrary_types_allowed = True + + SQLModel.metadata.create_all(in_memory_engine) + session = Session(in_memory_engine) + + interval = Interval(length=-2) + assert interval.radius == 1 + + session.add(interval) + session.commit() + interval_2 = session.exec(select(Interval)).all()[0] + assert interval_2.radius == 1 + + interval_3 = session.exec(select(Interval).where(Interval.radius == 1)).all()[0] + assert interval_3.radius == 1 + + intervals = session.exec(select(Interval).where(Interval.radius > 1)).all() + assert len(intervals) == 0 + + assert session.exec(select(Interval.radius + 1)).all()[0] == 2.0