Skip to content

Commit

Permalink
✨ Add support for hybrid_property
Browse files Browse the repository at this point in the history
  • Loading branch information
van51 committed Nov 2, 2022
1 parent 75ce455 commit 32b125c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 3 deletions.
16 changes: 14 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import Enum as sa_Enum
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
Expand Down Expand Up @@ -207,6 +208,7 @@ def Relationship(
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
__sqlalchemy_constructs__: Dict[str, Any]
__config__: Type[BaseConfig]
__fields__: Dict[str, ModelField]

Expand All @@ -232,6 +234,7 @@ def __new__(
**kwargs: Any,
) -> Any:
relationships: Dict[str, RelationshipInfo] = {}
sqlalchemy_constructs = {}
dict_for_pydantic = {}
original_annotations = resolve_annotations(
class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
Expand All @@ -241,6 +244,8 @@ def __new__(
for k, v in class_dict.items():
if isinstance(v, RelationshipInfo):
relationships[k] = v
elif isinstance(v, hybrid_property):
sqlalchemy_constructs[k] = v
else:
dict_for_pydantic[k] = v
for k, v in original_annotations.items():
Expand All @@ -253,6 +258,7 @@ def __new__(
"__weakref__": None,
"__sqlmodel_relationships__": relationships,
"__annotations__": pydantic_annotations,
"__sqlalchemy_constructs__": sqlalchemy_constructs,
}
# Duplicate logic from Pydantic to filter config kwargs because if they are
# passed directly including the registry Pydantic will pass them over to the
Expand All @@ -276,6 +282,9 @@ def __new__(
**new_cls.__annotations__,
}

for k, v in sqlalchemy_constructs.items():
setattr(new_cls, k, v)

def get_config(name: str) -> Any:
config_class_value = getattr(new_cls.__config__, name, Undefined)
if config_class_value is not Undefined:
Expand All @@ -290,8 +299,9 @@ def get_config(name: str) -> Any:
# If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.table = config_table
for k, v in new_cls.__fields__.items():
col = get_column_from_field(v)
setattr(new_cls, k, col)
if k in sqlalchemy_constructs:
continue
setattr(new_cls, k, get_column_from_field(v))
# Set a config flag to tell FastAPI that this should be read with a field
# in orm_mode instead of preemptively converting it to a dict.
# This could be done by reading new_cls.__config__.table in FastAPI, but
Expand Down Expand Up @@ -326,6 +336,8 @@ def __init__(
if getattr(cls.__config__, "table", False) and not base_is_table:
dict_used = dict_.copy()
for field_name, field_value in cls.__fields__.items():
if field_name in cls.__sqlalchemy_constructs__:
continue
dict_used[field_name] = get_column_from_field(field_value)
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
Expand Down
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest
from pydantic import BaseModel
from sqlmodel import SQLModel
from sqlmodel import SQLModel, create_engine
from sqlmodel.main import default_registry

top_level_path = Path(__file__).resolve().parent.parent
Expand All @@ -23,6 +23,13 @@ def clear_sqlmodel():
default_registry.dispose()


@pytest.fixture()
def in_memory_engine(clear_sqlmodel):
engine = create_engine("sqlite:///memory")
yield engine
SQLModel.metadata.drop_all(engine, checkfirst=True)


@pytest.fixture()
def cov_tmp_path(tmp_path: Path):
yield tmp_path
Expand Down
41 changes: 41 additions & 0 deletions tests/test_sqlalchemy_properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Optional

from sqlalchemy import func
from sqlalchemy.ext.hybrid import hybrid_property
from sqlmodel import Field, Session, SQLModel, select


def test_hybrid_property(in_memory_engine):
class Interval(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
length: float

@hybrid_property
def radius(self) -> float:
return abs(self.length) / 2

@radius.expression
def radius(cls) -> float:
return func.abs(cls.length) / 2

class Config:
arbitrary_types_allowed = True

SQLModel.metadata.create_all(in_memory_engine)
session = Session(in_memory_engine)

interval = Interval(length=-2)
assert interval.radius == 1

session.add(interval)
session.commit()
interval_2 = session.exec(select(Interval)).all()[0]
assert interval_2.radius == 1

interval_3 = session.exec(select(Interval).where(Interval.radius == 1)).all()[0]
assert interval_3.radius == 1

intervals = session.exec(select(Interval).where(Interval.radius > 1)).all()
assert len(intervals) == 0

assert session.exec(select(Interval.radius + 1)).all()[0] == 2.0

0 comments on commit 32b125c

Please sign in to comment.