Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement tolerant deserialization to allow unknown fields #595

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ class Nested(Field):
value will be returned as output instead of a dictionary.
This parameter takes precedence over ``exclude``.
:param bool many: Whether the field is a collection of objects.
:param bool tolerant: Whether to include unknown fields in the result.
:param kwargs: The same keyword arguments that :class:`Field` receives.
"""

Expand All @@ -384,6 +385,7 @@ def __init__(self, nested, default=missing_, exclude=tuple(), only=None, **kwarg
self.only = only
self.exclude = exclude
self.many = kwargs.get('many', False)
self.tolerant = kwargs.get('tolerant', False)
self.__schema = None # Cached Schema instance
self.__updated_fields = False
super(Nested, self).__init__(default=default, **kwargs)
Expand Down Expand Up @@ -467,7 +469,7 @@ def _deserialize(self, value, attr, data):
value = [{self.only: v} for v in value]
else:
value = {self.only: value}
data, errors = self.schema.load(value)
data, errors = self.schema.load(value, tolerant=self.tolerant)
if errors:
raise ValidationError(errors, data=data)
return data
Expand Down
13 changes: 10 additions & 3 deletions marshmallow/marshalling.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def run_validator(self, validator_func, output,
errors.setdefault(field_name, []).append(text_type(err))

def deserialize(self, data, fields_dict, many=False, partial=False,
dict_class=dict, index_errors=True, index=None):
tolerant=False, dict_class=dict, index_errors=True, index=None):
"""Deserialize ``data`` based on the schema defined by ``fields_dict``.

:param dict data: The data to deserialize.
Expand All @@ -221,6 +221,7 @@ def deserialize(self, data, fields_dict, many=False, partial=False,
:param bool|tuple partial: Whether to ignore missing fields. If its
value is an iterable, only missing fields listed in that iterable
will be ignored.
:param bool tolerant: Whether to include unknown fields in the result.
:param type dict_class: Dictionary class used to construct the output.
:param bool index_errors: Whether to store the index of invalid items in
``self.errors`` when ``many=True``.
Expand All @@ -234,8 +235,9 @@ def deserialize(self, data, fields_dict, many=False, partial=False,
if many and data is not None:
self._pending = True
ret = [self.deserialize(d, fields_dict, many=False,
partial=partial, dict_class=dict_class,
index=idx, index_errors=index_errors)
partial=partial, tolerant=tolerant,
dict_class=dict_class, index=idx,
index_errors=index_errors)
for idx, d in enumerate(data)]

self._pending = False
Expand Down Expand Up @@ -297,6 +299,11 @@ def deserialize(self, data, fields_dict, many=False, partial=False,
if value is not missing:
key = fields_dict[attr_name].attribute or attr_name
items.append((key, value))

if tolerant:
for field in set(data) - set(fields_dict):
items.append((field, data[field]))

ret = dict_class(items)
else:
ret = None
Expand Down
23 changes: 18 additions & 5 deletions marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class Meta:
:param bool|tuple partial: Whether to ignore missing fields. If its value
is an iterable, only missing fields listed in that iterable will be
ignored.
:param bool tolerant: Whether to include unknown fields in the loaded data.

.. versionchanged:: 2.0.0
`__validators__`, `__preprocessors__`, and `__data_handlers__` are removed in favor of
Expand Down Expand Up @@ -319,7 +320,7 @@ class Meta:

def __init__(self, only=(), exclude=(), prefix='', strict=None,
many=False, context=None, load_only=(), dump_only=(),
partial=False):
partial=False, tolerant=False):
# copy declared fields from metaclass
self.declared_fields = copy.deepcopy(self._declared_fields)
self.many = many
Expand All @@ -331,6 +332,7 @@ def __init__(self, only=(), exclude=(), prefix='', strict=None,
self.load_only = set(load_only) or set(self.opts.load_only)
self.dump_only = set(dump_only) or set(self.opts.dump_only)
self.partial = partial
self.tolerant = tolerant
#: Dictionary mapping field_names -> :class:`Field` objects
self.fields = self.dict_class()
#: Callable marshalling object
Expand Down Expand Up @@ -485,7 +487,7 @@ def dumps(self, obj, many=None, update_fields=True, *args, **kwargs):
ret = self.opts.json_module.dumps(deserialized, *args, **kwargs)
return MarshalResult(ret, errors)

def load(self, data, many=None, partial=None):
def load(self, data, many=None, partial=None, tolerant=None):
"""Deserialize a data structure to an object defined by this Schema's
fields and :meth:`make_object`.

Expand All @@ -495,12 +497,15 @@ def load(self, data, many=None, partial=None):
:param bool|tuple partial: Whether to ignore missing fields. If `None`,
the value for `self.partial` is used. If its value is an iterable,
only missing fields listed in that iterable will be ignored.
:param bool tolerant: Whether to include unknown fields in the result.
If `None`, the value for `self.tolerant` is used.
:return: A tuple of the form (``data``, ``errors``)
:rtype: `UnmarshalResult`, a `collections.namedtuple`

.. versionadded:: 1.0.0
"""
result, errors = self._do_load(data, many, partial=partial, postprocess=True)
result, errors = self._do_load(data, many, partial=partial,
tolerant=tolerant, postprocess=True)
return UnmarshalResult(data=result, errors=errors)

def loads(self, json_data, many=None, *args, **kwargs):
Expand All @@ -512,6 +517,8 @@ def loads(self, json_data, many=None, *args, **kwargs):
:param bool|tuple partial: Whether to ignore missing fields. If `None`,
the value for `self.partial` is used. If its value is an iterable,
only missing fields listed in that iterable will be ignored.
:param bool tolerant: Whether to include unknown fields in the result.
If `None`, the value for `self.tolerant` is used.
:return: A tuple of the form (``data``, ``errors``)
:rtype: `UnmarshalResult`, a `collections.namedtuple`

Expand All @@ -521,9 +528,10 @@ def loads(self, json_data, many=None, *args, **kwargs):
# passing in positional args after `many` for use by `json.loads`, but
# ideally we shouldn't have to do this.
partial = kwargs.pop('partial', None)
tolerant = kwargs.pop('tolerant', None)

data = self.opts.json_module.loads(json_data, *args, **kwargs)
return self.load(data, many=many, partial=partial)
return self.load(data, many=many, partial=partial, tolerant=tolerant)

def validate(self, data, many=None, partial=None):
"""Validate `data` against the schema, returning a dictionary of
Expand All @@ -545,7 +553,8 @@ def validate(self, data, many=None, partial=None):

##### Private Helpers #####

def _do_load(self, data, many=None, partial=None, postprocess=True):
def _do_load(self, data, many=None, partial=None, tolerant=None,
postprocess=True):
"""Deserialize `data`, returning the deserialized result and a dictonary of
validation errors.

Expand All @@ -556,11 +565,14 @@ def _do_load(self, data, many=None, partial=None, postprocess=True):
only fields listed in that iterable will be ignored will be allowed missing.
If `True`, all fields will be allowed missing.
If `None`, the value for `self.partial` is used.
:param bool tolerant: Whether to include unknown fields in the result.
If `None`, the value for `self.tolerant` is used.
:param bool postprocess: Whether to run post_load methods..
:return: A tuple of the form (`data`, `errors`)
"""
errors = {}
many = self.many if many is None else bool(many)
tolerant = self.tolerant if tolerant is None else bool(tolerant)
if partial is None:
partial = self.partial
try:
Expand All @@ -579,6 +591,7 @@ def _do_load(self, data, many=None, partial=None, postprocess=True):
self.fields,
many=many,
partial=partial,
tolerant=tolerant,
dict_class=self.dict_class,
index_errors=self.opts.index_errors,
)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,40 @@ class MySchema(Schema):
errors = MySchema(partial=True).validate({'foo': 3}, partial=('bar', 'baz'))
assert not errors

def test_tolerant_fields_deserialization(self):
class MySchema(Schema):
foo = fields.Integer()

data, errors = MySchema().load({'foo': 3, 'bar': 5})
assert data['foo'] == 3
assert 'bar' not in data
assert not errors

data, errors = MySchema(tolerant=True).load({'foo': 3, 'bar': 5}, tolerant=False)
assert data['foo'] == 3
assert 'bar' not in data
assert not errors

data, errors = MySchema().load({'foo': 3, 'bar': 5}, tolerant=True)
assert data['foo'] == 3
assert data['bar']
assert not errors

data, errors = MySchema(tolerant=True).load({'foo': 3, 'bar': 5})
assert data['foo'] == 3
assert data['bar']
assert not errors

data, errors = MySchema(tolerant=True).load({'foo': "asd", 'bar': 5})
assert 'foo' in errors
assert data['bar']

schema = MySchema(tolerant=True, many=True)
data, errors = schema.load([{'foo': 3, 'bar': 5}])
assert 'foo' in data[0]
assert 'bar' in data[0]
assert not errors

validators_gen = (func for func in [lambda x: x <= 24, lambda x: 18 <= x])

validators_gen_float = (func for func in
Expand Down