Skip to content

Commit

Permalink
DRY models
Browse files Browse the repository at this point in the history
  • Loading branch information
yedpodtrzitko committed Sep 8, 2024
1 parent deeaef7 commit 262b532
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 51 deletions.
91 changes: 41 additions & 50 deletions tagstudio/src/core/library/alchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,67 @@
from typing import Union, Any, TYPE_CHECKING

from sqlalchemy import ForeignKey, ForeignKeyConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column, relationship, declared_attr

from .db import Base
from .enums import FieldTypeEnum

if TYPE_CHECKING:
from .models import Entry, Tag, LibraryField

# TODO - replace with field bound to BaseField
Field = Union["TextField", "TagBoxField", "DatetimeField"]


class BooleanField(Base):
__tablename__ = "boolean_fields"
class BaseField(Base):
__abstract__ = True

id: Mapped[int] = mapped_column(primary_key=True)
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)
@declared_attr
def id(cls) -> Mapped[int]:
return mapped_column(primary_key=True, autoincrement=True)

entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
entry: Mapped[Entry] = relationship()
@declared_attr
def type_key(cls) -> Mapped[str]:
return mapped_column(ForeignKey("library_fields.key"))

value: Mapped[bool]
position: Mapped[int]
@declared_attr
def type(cls) -> Mapped[LibraryField]:
return relationship(foreign_keys=[cls.type_key], lazy=False) # type: ignore

def __key(self):
return (self.type, self.value)
@declared_attr
def entry_id(cls) -> Mapped[int]:
return mapped_column(ForeignKey("entries.id"))

@declared_attr
def entry(cls) -> Mapped[Entry]:
return relationship(foreign_keys=[cls.entry_id]) # type: ignore

@declared_attr
def position(cls) -> Mapped[int]:
return mapped_column()

def __hash__(self):
return hash(self.__key())

def __key(self):
raise NotImplementedError


class BooleanField(BaseField):
__tablename__ = "boolean_fields"

value: Mapped[bool]

def __key(self):
return (self.type, self.value)

def __eq__(self, value) -> bool:
if isinstance(value, BooleanField):
return self.__key() == value.__key()
raise NotImplementedError


class TextField(Base):
class TextField(BaseField):
__tablename__ = "text_fields"
# constrain for combination of: entry_id, type_key and position
__table_args__ = (
Expand All @@ -51,21 +75,10 @@ class TextField(Base):
),
)

id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)

entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
entry: Mapped[Entry] = relationship(foreign_keys=[entry_id])

value: Mapped[str | None]
position: Mapped[int]

def __key(self):
return (self.type, self.value)

def __hash__(self):
return hash(self.__key())
def __key(self) -> tuple:
return self.type, self.value

def __eq__(self, value) -> bool:
if isinstance(value, TextField):
Expand All @@ -75,18 +88,10 @@ def __eq__(self, value) -> bool:
raise NotImplementedError


class TagBoxField(Base):
class TagBoxField(BaseField):
__tablename__ = "tag_box_fields"

id: Mapped[int] = mapped_column(primary_key=True)
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)

entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
entry: Mapped[Entry] = relationship(foreign_keys=[entry_id])

tags: Mapped[set[Tag]] = relationship(secondary="tag_fields")
position: Mapped[int]

def __key(self):
return (
Expand All @@ -99,34 +104,20 @@ def value(self) -> None:
"""For interface compatibility with other field types."""
return None

def __hash__(self):
return hash(self.__key())

def __eq__(self, value) -> bool:
if isinstance(value, TagBoxField):
return self.__key() == value.__key()
raise NotImplementedError


class DatetimeField(Base):
class DatetimeField(BaseField):
__tablename__ = "datetime_fields"

id: Mapped[int] = mapped_column(primary_key=True)
type_key: Mapped[str] = mapped_column(ForeignKey("library_fields.key"))
type: Mapped[LibraryField] = relationship(foreign_keys=[type_key], lazy=False)

entry_id: Mapped[int] = mapped_column(ForeignKey("entries.id"))
entry: Mapped[Entry] = relationship(foreign_keys=[entry_id])

value: Mapped[str | None]
position: Mapped[int]

def __key(self):
return (self.type, self.value)

def __hash__(self):
return hash(self.__key())

def __eq__(self, value) -> bool:
if isinstance(value, DatetimeField):
return self.__key() == value.__key()
Expand Down
2 changes: 1 addition & 1 deletion tagstudio/src/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def update_field_position(

# Reassign `order` starting from 0
for index, row in enumerate(rows):
row.position = index # type: ignore
row.position = index
session.add(row)
session.flush()
if rows:
Expand Down

0 comments on commit 262b532

Please sign in to comment.