Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to define required fields in schema constructor and meta #1670

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ def schema(self):
context=context,
load_only=self._nested_normalized_option("load_only"),
dump_only=self._nested_normalized_option("dump_only"),
required=self._nested_normalized_option("required"),
)
return self._schema

Expand Down
9 changes: 8 additions & 1 deletion src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def __init__(self, meta, ordered: bool = False):
self.include = getattr(meta, "include", {})
self.load_only = getattr(meta, "load_only", ())
self.dump_only = getattr(meta, "dump_only", ())
self.required = getattr(meta, "required", ())
self.unknown = getattr(meta, "unknown", RAISE)
self.register = getattr(meta, "register", True)

Expand Down Expand Up @@ -273,6 +274,7 @@ class AlbumSchema(Schema):
:class:`fields.Function` fields.
:param load_only: Fields to skip during serialization (write-only fields)
:param dump_only: Fields to skip during deserialization (read-only fields)
:param required: Fields to be considered required.
:param partial: Whether to ignore missing fields and not require
any fields declared. Propagates down to ``Nested`` fields as well. If
its value is an iterable, only missing fields listed in that iterable
Expand Down Expand Up @@ -354,6 +356,7 @@ class Meta:
of invalid items in a collection.
- ``load_only``: Tuple or list of fields to exclude from serialized results.
- ``dump_only``: Tuple or list of fields to exclude from deserialization
- ``required``: Tuple or list of fields to be considered required.
- ``unknown``: Whether to exclude, include, or raise an error for unknown
fields in the data. Use `EXCLUDE`, `INCLUDE` or `RAISE`.
- ``register``: Whether to register the `Schema` with marshmallow's internal
Expand All @@ -371,6 +374,7 @@ def __init__(
context: typing.Dict = None,
load_only: types.StrSequenceOrSet = (),
dump_only: types.StrSequenceOrSet = (),
required: types.StrSequenceOrSet = (),
partial: typing.Union[bool, types.StrSequenceOrSet] = False,
unknown: str = None
):
Expand All @@ -387,6 +391,7 @@ def __init__(
self.ordered = self.opts.ordered
self.load_only = set(load_only) or set(self.opts.load_only)
self.dump_only = set(dump_only) or set(self.opts.dump_only)
self.required = set(required) or set(self.opts.required)
self.partial = partial
self.unknown = unknown or self.opts.unknown
self.context = context or {}
Expand Down Expand Up @@ -1028,13 +1033,15 @@ def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:
"""Bind field to the schema, setting any necessary attributes on the
field (e.g. parent and name).

Also set field load_only and dump_only values if field_name was
Also set field load_only, dump_only and required values if field_name was
specified in ``class Meta``.
"""
if field_name in self.load_only:
field_obj.load_only = True
if field_name in self.dump_only:
field_obj.dump_only = True
if field_name in self.required:
field_obj.required = True
try:
field_obj._bind_to_schema(field_name, self)
except TypeError as error:
Expand Down
108 changes: 108 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,61 @@ def test_dump_only(self, schema, data):
assert "str_regular" in grand_child


def test_deeply_nested_required():
class GrandChildSchema(Schema):
str_required = fields.String()
str_regular = fields.String()

class ChildSchema(Schema):
str_required = fields.String()
str_regular = fields.String()
grand_child = fields.Nested(GrandChildSchema)

class ParentSchema(Schema):
str_required = fields.String()
str_regular = fields.String()
child = fields.Nested(ChildSchema)

schema = ParentSchema(
required=(
"str_required",
"child.str_required",
"child.grand_child.str_required",
),
)

valid_data = {
"str_required": "Required",
"child": {
"str_required": "Required",
"grand_child": {
"str_required": "Required",
},
},
}

# Assert no exception
schema.load(valid_data)

data = valid_data.copy()
del data["str_required"]
with pytest.raises(ValidationError) as excinfo:
schema.load(data)
assert "str_required" in excinfo.value.messages

data = valid_data.copy()
del data["child"]["str_required"]
with pytest.raises(ValidationError) as excinfo:
schema.load(data)
assert "str_required" in excinfo.value.messages["child"]

data = valid_data.copy()
del data["child"]["grand_child"]["str_required"]
with pytest.raises(ValidationError) as excinfo:
schema.load(data)
assert "str_required" in excinfo.value.messages["child"]["grand_child"]


class TestDeeplyNestedListLoadOnly:
@pytest.fixture()
def schema(self):
Expand Down Expand Up @@ -1449,6 +1504,38 @@ def test_dump_only(self, schema, data):
assert "str_regular" in child


def test_deeply_nested_list_required():
class ChildSchema(Schema):
str_required = fields.String()
str_regular = fields.String()

class ParentSchema(Schema):
str_required = fields.String()
str_regular = fields.String()
child = fields.List(fields.Nested(ChildSchema))

schema = ParentSchema(
required=("str_required", "child.str_required"),
)

valid_data = {"str_required": "Required", "child": [{"str_required": "Required"}]}

# Assert no exception
schema.load(valid_data)

data = valid_data.copy()
del data["str_required"]
with pytest.raises(ValidationError) as excinfo:
schema.load(data)
assert "str_required" in excinfo.value.messages

data = valid_data.copy()
del data["child"][0]["str_required"]
with pytest.raises(ValidationError) as excinfo:
schema.load(data)
assert "str_required" in excinfo.value.messages["child"][0]


def test_nested_constructor_only_and_exclude():
class GrandChildSchema(Schema):
goo = fields.Field()
Expand Down Expand Up @@ -2858,6 +2945,27 @@ class NoTldTestSchema(Schema):
assert result == data_with_no_top_level_domain


def test_required_in_meta():
class MySchema(Schema):
class Meta:
required = "str_required"

str_required = fields.String()
str_regular = fields.String()

data = {
"str_required": None,
"str_regular": "Regular String",
}

schema = MySchema()

with pytest.raises(ValidationError) as excinfo:
schema.load(data)

assert "str_required" in excinfo.value.messages


class TestFromDict:
def test_generates_schema(self):
MySchema = Schema.from_dict({"foo": fields.Str()})
Expand Down