Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve introspection of validators #533

Merged
merged 14 commits into from
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 79 additions & 44 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ def _map_serializer_field(self, field, direction, bypass_extensions=False):
return append_meta(build_array_type(component.ref), meta) if component else None
else:
schema = self._map_serializer_field(field.child, direction)
self._insert_field_validators(field.child, schema)
# remove automatically attached but redundant title
if is_trivial_string_variation(field.field_name, schema.get('title')):
schema.pop('title', None)
Expand Down Expand Up @@ -636,27 +637,32 @@ def _map_serializer_field(self, field, direction, bypass_extensions=False):
content = {**build_basic_type(OpenApiTypes.STR), 'format': 'decimal'}
if field.max_whole_digits:
content['pattern'] = (
f'^\\d{{0,{field.max_whole_digits}}}'
f'(\\.\\d{{0,{field.decimal_places}}})?$'
fr'^\d{{0,{field.max_whole_digits}}}'
ngnpope marked this conversation as resolved.
Show resolved Hide resolved
fr'(?:\.\d{{0,{field.decimal_places}}})?$'
)
else:
content = build_basic_type(OpenApiTypes.DECIMAL)
if field.max_whole_digits:
content['maximum'] = int(field.max_whole_digits * '9') + 1
content['minimum'] = -content['maximum']
self._map_min_max(field, content)
value = 10 ** field.max_whole_digits
content.update({
'maximum': value,
'minimum': -value,
'exclusiveMaximum': True,
'exclusiveMinimum': True,
})
self._insert_min_max(field, content)
return append_meta(content, meta)

if isinstance(field, serializers.FloatField):
content = build_basic_type(OpenApiTypes.FLOAT)
self._map_min_max(field, content)
self._insert_min_max(field, content)
return append_meta(content, meta)

if isinstance(field, serializers.IntegerField):
content = build_basic_type(OpenApiTypes.INT)
self._map_min_max(field, content)
# 2147483647 is max for int32_size, so we use int64 for format
if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647:
self._insert_min_max(field, content)
# Use int64 for format if value outside the 32-bit signed integer range [-2,147,483,648 to 2,147,483,647].
if not all(-2147483648 <= int(content.get(key, 0)) <= 2147483647 for key in ('maximum', 'minimum')):
content['format'] = 'int64'
return append_meta(content, meta)

Expand All @@ -682,6 +688,7 @@ def _map_serializer_field(self, field, direction, bypass_extensions=False):
content = build_basic_type(OpenApiTypes.OBJECT)
if not isinstance(field.child, _UnvalidatedField):
content['additionalProperties'] = self._map_serializer_field(field.child, direction)
self._insert_field_validators(field.child, content['additionalProperties'])
return append_meta(content, meta)

if isinstance(field, serializers.CharField):
Expand Down Expand Up @@ -722,11 +729,15 @@ def _map_serializer_field(self, field, direction, bypass_extensions=False):
warn(f'could not resolve serializer field "{field}". Defaulting to "string"')
return append_meta(build_basic_type(OpenApiTypes.STR), meta)

def _map_min_max(self, field, content):
if field.max_value:
def _insert_min_max(self, field, content):
if field.max_value is not None:
ngnpope marked this conversation as resolved.
Show resolved Hide resolved
content['maximum'] = field.max_value
if field.min_value:
if 'exclusiveMaximum' in content:
del content['exclusiveMaximum']
if field.min_value is not None:
content['minimum'] = field.min_value
if 'exclusiveMinimum' in content:
del content['exclusiveMinimum']

def _map_serializer(self, serializer, direction, bypass_extensions=False):
serializer = force_instance(serializer)
Expand Down Expand Up @@ -817,7 +828,7 @@ def _map_basic_serializer(self, serializer, direction):
if add_to_required:
required.add(field.field_name)

self._map_field_validators(field, schema)
self._insert_field_validators(field, schema)

if field.field_name in get_override(serializer, 'deprecate_fields', []):
schema['deprecated'] = True
Expand All @@ -833,38 +844,62 @@ def _map_basic_serializer(self, serializer, direction):
description=get_doc(serializer.__class__),
)

def _map_field_validators(self, field, schema):
def _insert_field_validators(self, field, schema):
schema_type = schema.get('type')

def update_constraint(schema, key, function, value, *, exclusive=False):
if callable(value):
value = value()
current_value = schema.get(key)
if current_value is not None:
new_value = function(current_value, value)
else:
new_value = value
schema[key] = new_value
if key in ('maximum', 'minimum'):
exclusive_key = f'exclusive{key.title()}'
if exclusive:
if new_value != current_value:
schema[exclusive_key] = True
elif exclusive_key in schema:
del schema[exclusive_key]

for v in field.validators:
if isinstance(v, validators.EmailValidator):
schema['format'] = 'email'
elif isinstance(v, validators.URLValidator):
schema['format'] = 'uri'
elif isinstance(v, validators.RegexValidator):
pattern = v.regex.pattern.encode('ascii', 'backslashreplace').decode()
pattern = pattern.replace(r'\x', r'\u00') # unify escaping
pattern = pattern.replace(r'\Z', '$').replace(r'\A', '^') # ECMA anchors
schema['pattern'] = pattern
elif isinstance(v, validators.MaxLengthValidator):
attr_name = 'maxLength'
if isinstance(field, serializers.ListField):
attr_name = 'maxItems'
schema[attr_name] = v.limit_value() if callable(v.limit_value) else v.limit_value
elif isinstance(v, validators.MinLengthValidator):
attr_name = 'minLength'
if isinstance(field, serializers.ListField):
attr_name = 'minItems'
schema[attr_name] = v.limit_value() if callable(v.limit_value) else v.limit_value
elif isinstance(v, validators.MaxValueValidator):
schema['maximum'] = v.limit_value() if callable(v.limit_value) else v.limit_value
elif isinstance(v, validators.MinValueValidator):
schema['minimum'] = v.limit_value() if callable(v.limit_value) else v.limit_value
elif isinstance(v, validators.DecimalValidator):
if v.max_digits:
digits = v.max_digits
if v.decimal_places is not None and v.decimal_places > 0:
digits -= v.decimal_places
schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum']
if schema_type == 'string':
if isinstance(v, validators.EmailValidator):
schema['format'] = 'email'
elif isinstance(v, validators.URLValidator):
schema['format'] = 'uri'
elif isinstance(v, validators.RegexValidator):
pattern = v.regex.pattern.encode('ascii', 'backslashreplace').decode()
pattern = pattern.replace(r'\x', r'\u00') # unify escaping
pattern = pattern.replace(r'\Z', '$').replace(r'\A', '^') # ECMA anchors
schema['pattern'] = pattern
elif isinstance(v, validators.MaxLengthValidator):
update_constraint(schema, 'maxLength', min, v.limit_value)
elif isinstance(v, validators.MinLengthValidator):
update_constraint(schema, 'minLength', max, v.limit_value)
elif isinstance(v, validators.FileExtensionValidator) and v.allowed_extensions:
schema['pattern'] = '(?:%s)$' % '|'.join([re.escape(extn) for extn in v.allowed_extensions])
elif schema_type in ('integer', 'number'):
if isinstance(v, validators.MaxValueValidator):
update_constraint(schema, 'maximum', min, v.limit_value)
elif isinstance(v, validators.MinValueValidator):
update_constraint(schema, 'minimum', max, v.limit_value)
elif isinstance(v, validators.DecimalValidator) and v.max_digits:
value = 10 ** (v.max_digits - (v.decimal_places or 0))
update_constraint(schema, 'maximum', min, value, exclusive=True)
update_constraint(schema, 'minimum', max, -value, exclusive=True)
elif schema_type == 'array':
if isinstance(v, validators.MaxLengthValidator):
update_constraint(schema, 'maxItems', min, v.limit_value)
elif isinstance(v, validators.MinLengthValidator):
update_constraint(schema, 'minItems', max, v.limit_value)
elif schema_type == 'object':
if isinstance(v, validators.MaxLengthValidator):
update_constraint(schema, 'maxProperties', min, v.limit_value)
elif isinstance(v, validators.MinLengthValidator):
update_constraint(schema, 'minProperties', max, v.limit_value)

def _map_response_type_hint(self, method):
hint = get_override(method, 'field') or get_type_hints(method).get('return')
Expand Down
4 changes: 3 additions & 1 deletion tests/test_fields.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ components:
format: double
maximum: 1000
minimum: -1000
exclusiveMaximum: true
exclusiveMinimum: true
field_method_float:
type: number
format: float
Expand Down Expand Up @@ -242,7 +244,7 @@ components:
field_decimal:
type: string
format: decimal
pattern: ^\d{0,3}(\.\d{0,3})?$
pattern: ^\d{0,3}(?:\.\d{0,3})?$
ngnpope marked this conversation as resolved.
Show resolved Hide resolved
field_file:
type: string
format: uri
Expand Down
62 changes: 59 additions & 3 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,7 @@ class XViewset(viewsets.ReadOnlyModelViewSet):
def test_manual_decimal_validator():
# manually test this validator as it is not part of the default workflow
class XSerializer(serializers.Serializer):
field = serializers.CharField(
field = serializers.FloatField(
validators=[validators.DecimalValidator(max_digits=4, decimal_places=2)]
)

Expand Down Expand Up @@ -1385,8 +1385,17 @@ def view_func(request):
pass # pragma: no cover

schema = generate_schema('/x/', view_function=view_func)
field = schema['components']['schemas']['X']['properties']['field']
assert field['minimum'] and field['maximum']
assert schema['components']['schemas']['X']['properties']['field'] == {
'type': 'number',
'format': 'double',
'maximum': Decimal('100.00'),
'minimum': Decimal('1'),
}
assert schema['components']['schemas']['X']['properties']['field_coerced'] == {
'type': 'string',
'format': 'decimal',
'pattern': r'^\d{0,3}(?:\.\d{0,2})?$',
}

schema_yml = OpenApiYamlRenderer().render(schema, renderer_context={})
assert b'maximum: 100.00\n' in schema_yml
Expand Down Expand Up @@ -2424,3 +2433,50 @@ class HexConverter(StringConverter):
assert schema['paths']['/c/{var}/']['get']['parameters'][0]['schema'] == {
'type': 'string', 'pattern': '[a-f0-9]+'
}


@pytest.mark.parametrize('kwargs,expected', [
(
{'max_value': -2147483648},
{'type': 'integer', 'maximum': -2147483648},
),
(
{'max_value': -2147483649},
{'type': 'integer', 'maximum': -2147483649, 'format': 'int64'},
),
(
{'max_value': 2147483647},
{'type': 'integer', 'maximum': 2147483647},
),
(
{'max_value': 2147483648},
{'type': 'integer', 'maximum': 2147483648, 'format': 'int64'},
),
(
{'min_value': -2147483648},
{'type': 'integer', 'minimum': -2147483648},
),
(
{'min_value': -2147483649},
{'type': 'integer', 'minimum': -2147483649, 'format': 'int64'},
),
(
{'min_value': 2147483647},
{'type': 'integer', 'minimum': 2147483647},
),
(
{'min_value': 2147483648},
{'type': 'integer', 'minimum': 2147483648, 'format': 'int64'},
),
])
def test_int64_detection(kwargs, expected, no_warnings):
class XSerializer(serializers.Serializer):
field = serializers.IntegerField(**kwargs)

@extend_schema(request=XSerializer, responses=XSerializer)
@api_view(['GET'])
def view_func(request, format=None):
pass # pragma: no cover

schema = generate_schema('x', view_function=view_func)
assert schema['components']['schemas']['X']['properties']['field'] == expected
Loading