Skip to content

Commit

Permalink
Add type validation for foreign key and one to one model consistency (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Abdeldjalil-H authored Dec 20, 2024
1 parent a5bb80f commit 3829267
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 2 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Changelog
Added
^^^^^
- Implement savepoints for transactions (#1816)
- Added type validation for foreign key fields to ensure type safety. Now raises `ValidationError` when assigning foreign key values with incorrect model types (#1792)

Fixed
^^^^^
Expand Down Expand Up @@ -1498,4 +1499,4 @@ Docs/examples:
await Tournament.filter(
events__name__in=['1', '3']
).order_by('-events__participants__name').distinct()
).order_by('-events__participants__name').distinct()
71 changes: 70 additions & 1 deletion tests/fields/test_fk.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from tests import testmodels
from tortoise.contrib import test
from tortoise.exceptions import IntegrityError, NoValuesFetched, OperationalError
from tortoise.exceptions import (
IntegrityError,
NoValuesFetched,
OperationalError,
ValidationError,
)
from tortoise.queryset import QuerySet


class TestForeignKeyField(test.TestCase):
def assertRaisesWrongTypeException(self, relation_name: str):
return self.assertRaisesRegex(
ValidationError, f"Invalid type for relationship field '{relation_name}'"
)

async def test_empty(self):
with self.assertRaises(IntegrityError):
await testmodels.MinRelation.create()
Expand Down Expand Up @@ -151,6 +161,11 @@ async def test_minimal__instantiated_create(self):
tour = await testmodels.Tournament.create(name="Team1")
await testmodels.MinRelation.create(tournament=tour)

async def test_minimal__instantiated_create_wrong_type(self):
author = await testmodels.Author.create(name="Author1")
with self.assertRaisesWrongTypeException("tournament"):
await testmodels.MinRelation.create(tournament=author)

async def test_minimal__instantiated_iterate(self):
tour = await testmodels.Tournament.create(name="Team1")
async for _ in tour.minrelations:
Expand Down Expand Up @@ -229,3 +244,57 @@ async def test_event__offset(self):
event2 = await testmodels.Event.create(name="Event2", tournament=tour)
event3 = await testmodels.Event.create(name="Event3", tournament=tour)
self.assertEqual(await tour.events.offset(1).order_by("name"), [event2, event3])

async def test_fk_correct_type_assignment(self):
tour1 = await testmodels.Tournament.create(name="Team1")
tour2 = await testmodels.Tournament.create(name="Team2")
event = await testmodels.Event(name="Event1", tournament=tour1)

event.tournament = tour2
await event.save()
self.assertEqual(event.tournament_id, tour2.id)

async def test_fk_wrong_type_assignment(self):
tour = await testmodels.Tournament.create(name="Team1")
author = await testmodels.Author.create(name="Author")
rel = await testmodels.MinRelation.create(tournament=tour)

with self.assertRaisesWrongTypeException("tournament"):
rel.tournament = author

async def test_fk_none_assignment(self):
manager = await testmodels.Employee.create(name="Manager")
employee = await testmodels.Employee.create(name="Employee", manager=manager)

employee.manager = None
await employee.save()
self.assertIsNone(employee.manager)

async def test_fk_update_wrong_type(self):
tour = await testmodels.Tournament.create(name="Team1")
rel = await testmodels.MinRelation.create(tournament=tour)
author = await testmodels.Author.create(name="Author1")

with self.assertRaisesWrongTypeException("tournament"):
await testmodels.MinRelation.filter(id=rel.id).update(tournament=author)

async def test_fk_bulk_create_wrong_type(self):
author = await testmodels.Author.create(name="Author")
with self.assertRaisesWrongTypeException("tournament"):
await testmodels.MinRelation.bulk_create(
[testmodels.MinRelation(tournament=author) for _ in range(10)]
)

async def test_fk_bulk_update_wrong_type(self):
tour = await testmodels.Tournament.create(name="Team1")
await testmodels.MinRelation.bulk_create(
[testmodels.MinRelation(tournament=tour) for _ in range(1, 10)]
)
author = await testmodels.Author.create(name="Author")

with self.assertRaisesWrongTypeException("tournament"):
relations = await testmodels.MinRelation.all()
await testmodels.MinRelation.bulk_update(
[testmodels.MinRelation(id=rel.id, tournament=author) for rel in relations],
fields=["tournament"],
)
25 changes: 25 additions & 0 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@
from tortoise.exceptions import (
ConfigurationError,
DoesNotExist,
FieldError,
IncompleteInstanceError,
IntegrityError,
ObjectDoesNotExistError,
OperationalError,
ParamsError,
ValidationError,
)
from tortoise.expressions import Expression
from tortoise.fields.base import Field
Expand Down Expand Up @@ -685,6 +687,8 @@ def __setattr__(self, key, value) -> None:
# set field value override async default function
if hasattr(self, "_await_when_save"):
self._await_when_save.pop(key, None)
if key in self._meta.fk_fields or key in self._meta.o2o_fields:
self._validate_relation_type(key, value)
super().__setattr__(key, value)

def _set_kwargs(self, kwargs: dict) -> Set[str]:
Expand Down Expand Up @@ -806,6 +810,27 @@ def _set_pk_val(self, value: Any) -> None:
Can be used as a field name when doing filtering e.g. ``.filter(pk=...)`` etc...
"""

@classmethod
def _validate_relation_type(cls, field_key: str, value: Optional["Model"]) -> None:
if value is None:
return

field = cls._meta.fields_map[field_key]
if not isinstance(field, (OneToOneFieldInstance, ForeignKeyFieldInstance)):
raise FieldError(
f"Field '{field_key}' must be a OneToOne or ForeignKey relation, "
f"got {type(field).__name__}"
)

expected_model = field.related_model
received_model = type(value)
if received_model is not expected_model:
raise ValidationError(
f"Invalid type for relationship field '{field_key}'. "
f"Expected model type '{expected_model.__name__}', but got '{received_model.__name__}'. "
"Make sure you're using the correct model class for this relationship."
)

@classmethod
async def _getbypk(cls: Type[MODEL], key: Any) -> MODEL:
try:
Expand Down
1 change: 1 addition & 0 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,7 @@ def _make_query(self) -> None:
if field_object.pk:
raise IntegrityError(f"Field {key} is PK and can not be updated")
if isinstance(field_object, (ForeignKeyFieldInstance, OneToOneFieldInstance)):
self.model._validate_relation_type(key, value)
fk_field: str = field_object.source_field # type: ignore
db_field = self.model._meta.fields_map[fk_field].source_field
value = executor.column_map[fk_field](
Expand Down

0 comments on commit 3829267

Please sign in to comment.