Skip to content

Commit

Permalink
Fix type hint of schemas, for example for Required('key') (#478)
Browse files Browse the repository at this point in the history
  • Loading branch information
ds-cbo authored Feb 22, 2023
1 parent ecb0cdc commit 72641b8
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions voluptuous/schema_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,6 @@ def _isnamedtuple(obj):
return isinstance(obj, tuple) and hasattr(obj, '_fields')


primitive_types = (str, unicode, bool, int, float)


class Undefined(object):
def __nonzero__(self):
return False
Expand Down Expand Up @@ -165,7 +162,15 @@ def Extra(_) -> None:
# deprecated object, so we just leave an alias here instead.
extra = Extra

Schemable = typing.Union[dict, list, type, typing.Callable]
primitive_types = (bool, bytes, int, long, str, unicode, float, complex)

Schemable = typing.Union[
Extra, 'Schema', 'Object',
_Mapping,
list, tuple, frozenset, set,
bool, bytes, int, long, str, unicode, float, complex,
type, object, dict, type(None), typing.Callable
]


class Schema(object):
Expand Down Expand Up @@ -306,8 +311,7 @@ def _compile(self, schema):
type_ = type(schema)
if inspect.isclass(schema):
type_ = schema
if type_ in (bool, bytes, int, long, str, unicode, float, complex, object,
list, dict, type(None)) or callable(schema):
if type_ in (*primitive_types, object, list, dict, type(None)) or callable(schema):
return _compile_scalar(schema)
raise er.SchemaError('unsupported schema data type %r' %
type(schema).__name__)
Expand Down Expand Up @@ -733,7 +737,7 @@ def validate_set(path, data):

return validate_set

def extend(self, schema: dict, required: typing.Optional[bool] = None, extra: typing.Optional[int] = None) -> Schema:
def extend(self, schema: Schemable, required: typing.Optional[bool] = None, extra: typing.Optional[int] = None) -> Schema:
"""Create a new `Schema` by merging this and the provided `schema`.
Neither this `Schema` nor the provided `schema` are modified. The
Expand Down Expand Up @@ -947,7 +951,7 @@ class Msg(object):
... assert isinstance(e.errors[0], er.RangeInvalid)
"""

def __init__(self, schema: dict, msg: str, cls: typing.Optional[typing.Type[Error]] = None) -> None:
def __init__(self, schema: Schemable, msg: str, cls: typing.Optional[typing.Type[Error]] = None) -> None:
if cls and not issubclass(cls, er.Invalid):
raise er.SchemaError("Msg can only use subclases of"
" Invalid as custom class")
Expand All @@ -972,7 +976,7 @@ def __repr__(self):
class Object(dict):
"""Indicate that we should work with attributes, not keys."""

def __init__(self, schema, cls: object = UNDEFINED) -> None:
def __init__(self, schema: typing.Any, cls: object = UNDEFINED) -> None:
self.cls = cls
super(Object, self).__init__(schema)

Expand All @@ -988,7 +992,7 @@ def __repr__(self):
class Marker(object):
"""Mark nodes for special treatment."""

def __init__(self, schema_: dict, msg: typing.Optional[str] = None, description: typing.Optional[str] = None) -> None:
def __init__(self, schema_: Schemable, msg: typing.Optional[str] = None, description: typing.Optional[str] = None) -> None:
self.schema = schema_
self._schema = Schema(schema_)
self.msg = msg
Expand Down Expand Up @@ -1046,7 +1050,7 @@ class Optional(Marker):
{'key2': 'value'}
"""

def __init__(self, schema: dict, msg: typing.Optional[str] = None, default=UNDEFINED, description: typing.Optional[str] = None) -> None:
def __init__(self, schema: Schemable, msg: typing.Optional[str] = None, default=UNDEFINED, description: typing.Optional[str] = None) -> None:
super(Optional, self).__init__(schema, msg=msg,
description=description)
self.default = default_factory(default)
Expand Down Expand Up @@ -1088,7 +1092,7 @@ class Exclusive(Optional):
... 'social': {'social_network': 'barfoo', 'token': 'tEMp'}})
"""

def __init__(self, schema: dict, group_of_exclusion: str, msg: typing.Optional[str] = None, description: typing.Optional[str] = None) -> None:
def __init__(self, schema: Schemable, group_of_exclusion: str, msg: typing.Optional[str] = None, description: typing.Optional[str] = None) -> None:
super(Exclusive, self).__init__(schema, msg=msg,
description=description)
self.group_of_exclusion = group_of_exclusion
Expand Down Expand Up @@ -1136,7 +1140,7 @@ class Inclusive(Optional):
True
"""

def __init__(self, schema: dict, group_of_inclusion: str,
def __init__(self, schema: Schemable, group_of_inclusion: str,
msg: typing.Optional[str] = None, description: typing.Optional[str] = None, default=UNDEFINED) -> None:
super(Inclusive, self).__init__(schema, msg=msg,
default=default,
Expand All @@ -1159,7 +1163,7 @@ class Required(Marker):
{'key': []}
"""

def __init__(self, schema: dict, msg: typing.Optional[str] = None, default=UNDEFINED, description: typing.Optional[str] = None) -> None:
def __init__(self, schema: Schemable, msg: typing.Optional[str] = None, default=UNDEFINED, description: typing.Optional[str] = None) -> None:
super(Required, self).__init__(schema, msg=msg,
description=description)
self.default = default_factory(default)
Expand All @@ -1180,8 +1184,8 @@ class Remove(Marker):
[1, 2, 3, 5, '7']
"""

def __call__(self, v: object):
super(Remove, self).__call__(v)
def __call__(self, schema: Schemable):
super(Remove, self).__call__(schema)
return self.__class__

def __repr__(self):
Expand Down

0 comments on commit 72641b8

Please sign in to comment.