Skip to content

Commit

Permalink
allow for functions on models beside properties. #68
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed May 26, 2020
1 parent 1aae554 commit ddeb5bb
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 12 deletions.
17 changes: 11 additions & 6 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,16 +256,21 @@ def _follow_field_source(model, path):
# end of traversal
if isinstance(field_or_property, property):
return field_or_property.fget
elif callable(field_or_property):
return field_or_property
else:
return get_field_from_model(model, field_or_property)
else:
if isinstance(field_or_property, property):
target_model = field_or_property.fget.__annotations__.get('return')
if isinstance(field_or_property, property) or callable(field_or_property):
if isinstance(field_or_property, property):
target_model = field_or_property.fget.__annotations__.get('return')
else:
target_model = field_or_property.__annotations__.get('return')
if not target_model:
raise UnableToProceedError(
f'could not follow field source through intermediate property "{path[0]}" '
f'on model {model}. please add a type hint on the model\'s property to '
f'enable traversal of the source path "{".".join(path)}".'
f'on model {model}. please add a type hint on the model\'s property/function'
f'to enable traversal of the source path "{".".join(path)}".'
)
return _follow_field_source(target_model, path[1:])
else:
Expand All @@ -284,11 +289,11 @@ def follow_field_source(model, path):
return _follow_field_source(model, path)
except UnableToProceedError as e:
warn(e)
except: # noqa: E722
except Exception as exc:
warn(
f'could not resolve field on model {model} with path "{".".join(path)}". '
f'this is likely a custom field that does some unknown magic. maybe '
f'consider annotating the field/property? defaulting to "string".'
f'consider annotating the field/property? defaulting to "string". (Exception: {exc})'
)

def dummy_property(obj) -> str:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def field_model_property_float(self) -> float:
def field_list(self):
return [1.1, 2.2, 3.3]

def model_function_basic(self) -> bool:
return True

def model_function_model(self) -> Aux:
return self.field_foreign


class AllFieldsSerializer(serializers.ModelSerializer):
field_decimal_uncoerced = serializers.DecimalField(
Expand Down Expand Up @@ -108,6 +114,9 @@ def get_field_method_object(self, obj) -> dict:
source='field_foreign.field_foreign.field_foreign.id',
allow_null=True, # force field output even if traversal fails
)
field_read_only_model_function_basic = serializers.ReadOnlyField(source='model_function_basic')
field_read_only_model_function_model = serializers.ReadOnlyField(source='model_function_model.id')

# override default writable bool field with readonly
field_bool_override = serializers.ReadOnlyField()

Expand Down
9 changes: 9 additions & 0 deletions tests/test_fields.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ components:
format: uuid
readOnly: true
nullable: true
field_read_only_model_function_basic:
type: boolean
readOnly: true
field_read_only_model_function_model:
type: string
format: uuid
readOnly: true
field_bool_override:
type: boolean
readOnly: true
Expand Down Expand Up @@ -259,6 +266,8 @@ components:
- field_o2o
- field_posint
- field_possmallint
- field_read_only_model_function_basic
- field_read_only_model_function_model
- field_read_only_nav_uuid
- field_read_only_nav_uuid_3steps
- field_regex
Expand Down
19 changes: 13 additions & 6 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,17 +458,21 @@ def get(self, request):
assert component['properties']['custom_int_field']['type'] == 'integer'


def test_follow_field_source_through_intermediate_property(no_warnings):
def test_follow_field_source_through_intermediate_property_or_function(no_warnings):
class FieldSourceTraversalModel2(models.Model):
y = models.IntegerField(choices=[(1, '1'), (2, '2'), (3, '3')])

class FieldSourceTraversalModel1(models.Model):
@property
def x(self) -> FieldSourceTraversalModel2: # property is required for traversal
def prop(self) -> FieldSourceTraversalModel2: # property is required for traversal
return # pragma: no cover

def func(self) -> FieldSourceTraversalModel2: # property is required for traversal
return # pragma: no cover

class XSerializer(serializers.ModelSerializer):
x = serializers.ReadOnlyField(source='x.y')
prop = serializers.ReadOnlyField(source='prop.y')
func = serializers.ReadOnlyField(source='func.y')

class Meta:
model = FieldSourceTraversalModel1
Expand All @@ -482,9 +486,12 @@ def get(self, request):
# this checks if field type is correctly estimated AND field was initialized
# with the model parameters (choices)
schema = generate_schema('x', view=XAPIView)
assert schema['components']['schemas']['X']['properties']['x']['readOnly'] is True
assert 'enum' in schema['components']['schemas']['XEnum']
assert schema['components']['schemas']['XEnum']['type'] == 'integer'
assert schema['components']['schemas']['X']['properties']['func']['readOnly'] is True
assert schema['components']['schemas']['X']['properties']['prop']['readOnly'] is True
assert 'enum' in schema['components']['schemas']['PropEnum']
assert 'enum' in schema['components']['schemas']['FuncEnum']
assert schema['components']['schemas']['PropEnum']['type'] == 'integer'
assert schema['components']['schemas']['FuncEnum']['type'] == 'integer'


def test_viewset_list_with_envelope(no_warnings):
Expand Down

0 comments on commit ddeb5bb

Please sign in to comment.