Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add support for hybrid_property #482

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
✨ Add support for hybrid_property
van51 committed Nov 2, 2022
commit 1359098a00f0ae57c5b8166c5f07269602929e89
18 changes: 16 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@
from sqlalchemy import Boolean, Column, Date, DateTime
from sqlalchemy import Enum as sa_Enum
from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
@@ -207,6 +208,7 @@ def Relationship(
@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta):
__sqlmodel_relationships__: Dict[str, RelationshipInfo]
__sqlalchemy_constructs__: Dict[str, Any]
__config__: Type[BaseConfig]
__fields__: Dict[str, ModelField]

@@ -232,6 +234,7 @@ def __new__(
**kwargs: Any,
) -> Any:
relationships: Dict[str, RelationshipInfo] = {}
sqlalchemy_constructs = {}
dict_for_pydantic = {}
original_annotations = resolve_annotations(
class_dict.get("__annotations__", {}), class_dict.get("__module__", None)
@@ -241,6 +244,8 @@ def __new__(
for k, v in class_dict.items():
if isinstance(v, RelationshipInfo):
relationships[k] = v
elif isinstance(v, hybrid_property):
sqlalchemy_constructs[k] = v
else:
dict_for_pydantic[k] = v
for k, v in original_annotations.items():
@@ -253,6 +258,7 @@ def __new__(
"__weakref__": None,
"__sqlmodel_relationships__": relationships,
"__annotations__": pydantic_annotations,
"__sqlalchemy_constructs__": sqlalchemy_constructs,
}
# Duplicate logic from Pydantic to filter config kwargs because if they are
# passed directly including the registry Pydantic will pass them over to the
@@ -276,6 +282,11 @@ def __new__(
**new_cls.__annotations__,
}

# We did not provide the sqlalchemy constructs to Pydantic's new function above
# so that they wouldn't be modified. Instead we set them directly to the class below:
for k, v in sqlalchemy_constructs.items():
setattr(new_cls, k, v)

def get_config(name: str) -> Any:
config_class_value = getattr(new_cls.__config__, name, Undefined)
if config_class_value is not Undefined:
@@ -290,8 +301,9 @@ def get_config(name: str) -> Any:
# If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.table = config_table
for k, v in new_cls.__fields__.items():
col = get_column_from_field(v)
setattr(new_cls, k, col)
if k in sqlalchemy_constructs:
continue
setattr(new_cls, k, get_column_from_field(v))
# Set a config flag to tell FastAPI that this should be read with a field
# in orm_mode instead of preemptively converting it to a dict.
# This could be done by reading new_cls.__config__.table in FastAPI, but
@@ -326,6 +338,8 @@ def __init__(
if getattr(cls.__config__, "table", False) and not base_is_table:
dict_used = dict_.copy()
for field_name, field_value in cls.__fields__.items():
if field_name in cls.__sqlalchemy_constructs__:
continue
dict_used[field_name] = get_column_from_field(field_value)
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@

import pytest
from pydantic import BaseModel
from sqlmodel import SQLModel
from sqlmodel import SQLModel, create_engine
from sqlmodel.main import default_registry

top_level_path = Path(__file__).resolve().parent.parent
@@ -23,6 +23,13 @@ def clear_sqlmodel():
default_registry.dispose()


@pytest.fixture()
def in_memory_engine(clear_sqlmodel):
engine = create_engine("sqlite:///memory")
yield engine
SQLModel.metadata.drop_all(engine, checkfirst=True)


@pytest.fixture()
def cov_tmp_path(tmp_path: Path):
yield tmp_path
41 changes: 41 additions & 0 deletions tests/test_sqlalchemy_properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Optional

from sqlalchemy import func
from sqlalchemy.ext.hybrid import hybrid_property
from sqlmodel import Field, Session, SQLModel, select


def test_hybrid_property(in_memory_engine):
class Interval(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
length: float

@hybrid_property
def radius(self) -> float:
return abs(self.length) / 2

@radius.expression
def radius(cls) -> float:
return func.abs(cls.length) / 2

class Config:
arbitrary_types_allowed = True

SQLModel.metadata.create_all(in_memory_engine)
session = Session(in_memory_engine)

interval = Interval(length=-2)
assert interval.radius == 1

session.add(interval)
session.commit()
interval_2 = session.exec(select(Interval)).all()[0]
assert interval_2.radius == 1

interval_3 = session.exec(select(Interval).where(Interval.radius == 1)).all()[0]
assert interval_3.radius == 1

intervals = session.exec(select(Interval).where(Interval.radius > 1)).all()
assert len(intervals) == 0

assert session.exec(select(Interval.radius + 1)).all()[0] == 2.0