Skip to content

Commit

Permalink
feat: validate int value range for IntField
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed Jan 23, 2025
1 parent 948ccdb commit af59d3c
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 13 deletions.
43 changes: 31 additions & 12 deletions tests/fields/test_int.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,37 @@
from decimal import Decimal
from typing import ClassVar, Type

from tests import testmodels
from tortoise import Model
from tortoise.contrib import test
from tortoise.exceptions import IntegrityError
from tortoise.exceptions import ValidationError
from tortoise.expressions import F


class TestIntFields(test.TestCase):
class TestIntNum(test.TestCase):
model: ClassVar[Type[Model]] = testmodels.IntFields

async def test_empty(self):
with self.assertRaises(IntegrityError):
await testmodels.IntFields.create()
with self.assertRaises(ValidationError):
await self.model.create()

async def test_value_range(self):
try:
field = self.model._meta.fields_map["intnum"]
except KeyError:
field = self.model._meta.fields_map["smallintnum"]
min_, max_ = field.constraints["ge"], field.constraints["le"]
with self.assertRaises(ValidationError):
await self.model.create(intnum=min_ - 1)
with self.assertRaises(ValidationError):
await self.model.create(intnum=max_ + 1)
with self.assertRaises(ValidationError):
await self.model.create(intnum=max_ + 1.1)
with self.assertRaises(ValidationError):
await self.model.create(intnum=Decimal(max_ + 1.1))


class TestIntFields(test.TestCase):
async def test_create(self):
obj0 = await testmodels.IntFields.create(intnum=2147483647)
obj = await testmodels.IntFields.get(id=obj0.id)
Expand Down Expand Up @@ -60,10 +83,8 @@ async def test_f_expression(self):
self.assertEqual(obj1.intnum, 2)


class TestSmallIntFields(test.TestCase):
async def test_empty(self):
with self.assertRaises(IntegrityError):
await testmodels.SmallIntFields.create()
class TestSmallIntFields(TestIntNum):
model = testmodels.SmallIntFields

async def test_create(self):
obj0 = await testmodels.SmallIntFields.create(smallintnum=32767)
Expand Down Expand Up @@ -102,10 +123,8 @@ async def test_f_expression(self):
self.assertEqual(obj1.smallintnum, 2)


class TestBigIntFields(test.TestCase):
async def test_empty(self):
with self.assertRaises(IntegrityError):
await testmodels.BigIntFields.create()
class TestBigIntFields(TestIntNum):
model = testmodels.BigIntFields

async def test_create(self):
obj0 = await testmodels.BigIntFields.create(intnum=9223372036854775807)
Expand Down
4 changes: 3 additions & 1 deletion tortoise/fields/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from tortoise.exceptions import ConfigurationError, FieldError
from tortoise.fields.base import Field
from tortoise.timezone import get_default_timezone, get_timezone, get_use_tz, localtime
from tortoise.validators import MaxLengthValidator
from tortoise.validators import MaxLengthValidator, ValueRangeValidator

try:
from ciso8601 import parse_datetime
Expand Down Expand Up @@ -80,6 +80,8 @@ def __init__(self, primary_key: Optional[bool] = None, **kwargs: Any) -> None:
if primary_key or kwargs.get("pk"):
kwargs["generated"] = bool(kwargs.get("generated", True))
super().__init__(primary_key=primary_key, **kwargs)
min_value, max_value = self.constraints["ge"], self.constraints["le"]
self.validators.append(ValueRangeValidator(min_value, max_value))

@property
def constraints(self) -> dict:
Expand Down
19 changes: 19 additions & 0 deletions tortoise/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,25 @@ def __call__(self, value: int | float | Decimal) -> None:
raise ValidationError(f"Value should be less or equal to {self.max_value}")


class ValueRangeValidator(MinValueValidator):
"""
Value range validator for IntField, SmallIntField, BigIntField
"""

def __init__(self, min_value: int | float | Decimal, max_value: int | float | Decimal) -> None:
super().__init__(min_value)
self._validate_type(max_value)
self.max_value = max_value

def __call__(self, value: int | float | Decimal) -> None:
self._validate_type(value)
if not self.min_value <= value <= self.max_value:
raise ValidationError(
f"Value should be greater or equal to {self.min_value},"
f" and less or equal to {self.max_value}"
)


class CommaSeparatedIntegerListValidator(Validator):
"""
A validator to validate whether the given value is valid comma separated integer list or not.
Expand Down

0 comments on commit af59d3c

Please sign in to comment.