From 72641b83317cf91a1fc264c2b08ee31519c22b34 Mon Sep 17 00:00:00 2001 From: DS/Charlie <82801887+ds-cbo@users.noreply.github.com> Date: Wed, 22 Feb 2023 20:29:06 +0100 Subject: [PATCH] Fix type hint of schemas, for example for Required('key') (#478) --- voluptuous/schema_builder.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/voluptuous/schema_builder.py b/voluptuous/schema_builder.py index e7e8d35..946cefd 100644 --- a/voluptuous/schema_builder.py +++ b/voluptuous/schema_builder.py @@ -118,9 +118,6 @@ def _isnamedtuple(obj): return isinstance(obj, tuple) and hasattr(obj, '_fields') -primitive_types = (str, unicode, bool, int, float) - - class Undefined(object): def __nonzero__(self): return False @@ -165,7 +162,15 @@ def Extra(_) -> None: # deprecated object, so we just leave an alias here instead. extra = Extra -Schemable = typing.Union[dict, list, type, typing.Callable] +primitive_types = (bool, bytes, int, long, str, unicode, float, complex) + +Schemable = typing.Union[ + Extra, 'Schema', 'Object', + _Mapping, + list, tuple, frozenset, set, + bool, bytes, int, long, str, unicode, float, complex, + type, object, dict, type(None), typing.Callable +] class Schema(object): @@ -306,8 +311,7 @@ def _compile(self, schema): type_ = type(schema) if inspect.isclass(schema): type_ = schema - if type_ in (bool, bytes, int, long, str, unicode, float, complex, object, - list, dict, type(None)) or callable(schema): + if type_ in (*primitive_types, object, list, dict, type(None)) or callable(schema): return _compile_scalar(schema) raise er.SchemaError('unsupported schema data type %r' % type(schema).__name__) @@ -733,7 +737,7 @@ def validate_set(path, data): return validate_set - def extend(self, schema: dict, required: typing.Optional[bool] = None, extra: typing.Optional[int] = None) -> Schema: + def extend(self, schema: Schemable, required: typing.Optional[bool] = None, extra: typing.Optional[int] = None) -> Schema: """Create a new `Schema` by merging this and the provided `schema`. Neither this `Schema` nor the provided `schema` are modified. The @@ -947,7 +951,7 @@ class Msg(object): ... assert isinstance(e.errors[0], er.RangeInvalid) """ - def __init__(self, schema: dict, msg: str, cls: typing.Optional[typing.Type[Error]] = None) -> None: + def __init__(self, schema: Schemable, msg: str, cls: typing.Optional[typing.Type[Error]] = None) -> None: if cls and not issubclass(cls, er.Invalid): raise er.SchemaError("Msg can only use subclases of" " Invalid as custom class") @@ -972,7 +976,7 @@ def __repr__(self): class Object(dict): """Indicate that we should work with attributes, not keys.""" - def __init__(self, schema, cls: object = UNDEFINED) -> None: + def __init__(self, schema: typing.Any, cls: object = UNDEFINED) -> None: self.cls = cls super(Object, self).__init__(schema) @@ -988,7 +992,7 @@ def __repr__(self): class Marker(object): """Mark nodes for special treatment.""" - def __init__(self, schema_: dict, msg: typing.Optional[str] = None, description: typing.Optional[str] = None) -> None: + def __init__(self, schema_: Schemable, msg: typing.Optional[str] = None, description: typing.Optional[str] = None) -> None: self.schema = schema_ self._schema = Schema(schema_) self.msg = msg @@ -1046,7 +1050,7 @@ class Optional(Marker): {'key2': 'value'} """ - def __init__(self, schema: dict, msg: typing.Optional[str] = None, default=UNDEFINED, description: typing.Optional[str] = None) -> None: + def __init__(self, schema: Schemable, msg: typing.Optional[str] = None, default=UNDEFINED, description: typing.Optional[str] = None) -> None: super(Optional, self).__init__(schema, msg=msg, description=description) self.default = default_factory(default) @@ -1088,7 +1092,7 @@ class Exclusive(Optional): ... 'social': {'social_network': 'barfoo', 'token': 'tEMp'}}) """ - def __init__(self, schema: dict, group_of_exclusion: str, msg: typing.Optional[str] = None, description: typing.Optional[str] = None) -> None: + def __init__(self, schema: Schemable, group_of_exclusion: str, msg: typing.Optional[str] = None, description: typing.Optional[str] = None) -> None: super(Exclusive, self).__init__(schema, msg=msg, description=description) self.group_of_exclusion = group_of_exclusion @@ -1136,7 +1140,7 @@ class Inclusive(Optional): True """ - def __init__(self, schema: dict, group_of_inclusion: str, + def __init__(self, schema: Schemable, group_of_inclusion: str, msg: typing.Optional[str] = None, description: typing.Optional[str] = None, default=UNDEFINED) -> None: super(Inclusive, self).__init__(schema, msg=msg, default=default, @@ -1159,7 +1163,7 @@ class Required(Marker): {'key': []} """ - def __init__(self, schema: dict, msg: typing.Optional[str] = None, default=UNDEFINED, description: typing.Optional[str] = None) -> None: + def __init__(self, schema: Schemable, msg: typing.Optional[str] = None, default=UNDEFINED, description: typing.Optional[str] = None) -> None: super(Required, self).__init__(schema, msg=msg, description=description) self.default = default_factory(default) @@ -1180,8 +1184,8 @@ class Remove(Marker): [1, 2, 3, 5, '7'] """ - def __call__(self, v: object): - super(Remove, self).__call__(v) + def __call__(self, schema: Schemable): + super(Remove, self).__call__(schema) return self.__class__ def __repr__(self):