diff --git a/voluptuous/schema_builder.py b/voluptuous/schema_builder.py index 5f686c9..8f93863 100644 --- a/voluptuous/schema_builder.py +++ b/voluptuous/schema_builder.py @@ -1018,15 +1018,71 @@ def wrapper(*args, **kwargs): return decorator -def validate_schema(*a, **kw): - schema = Schema(*a, **kw) +def _args_to_dict(func, args): + """Returns argument names as values as key-value pairs.""" + if sys.version_info >= (3, 0): + arg_count = func.__code__.co_argcount + arg_names = func.__code__.co_varnames[:arg_count] + else: + arg_count = func.func_code.co_argcount + arg_names = func.func_code.co_varnames[:arg_count] - def decorator(f): - @wraps(f) - def wrapper(*args, **kwargs): - result = f(*args, **kwargs) - schema(result) - return result - return wrapper + arg_value_list = list(args) + arguments = dict((arg_name, arg_value_list[i]) + for i, arg_name in enumerate(arg_names) + if i < len(arg_value_list)) + return arguments - return decorator + +def _merge_args_with_kwargs(args_dict, kwargs_dict): + """Merge args with kwargs.""" + ret = args_dict.copy() + ret.update(kwargs_dict) + return ret + + +def validate(*a, **kw): + """Decorator for validating arguments of a function against a given schema. + + Set restrictions for arguments: + + >>> @validate(arg1=int, arg2=int) + ... def foo(arg1, arg2): + ... return arg1 * arg2 + + Set restriction for returned value: + + >>> @validate(arg=int, __return__=int) + ... def foo(arg1): + ... return arg1 * 2 + + """ + RETURNS_KEY = '__return__' + + def validate_schema_decorator(func): + + returns_defined = False + returns = None + + schema_args_dict = _args_to_dict(func, a) + schema_arguments = _merge_args_with_kwargs(schema_args_dict, kw) + + if RETURNS_KEY in schema_arguments: + returns_defined = True + returns = schema_arguments[RETURNS_KEY] + del schema_arguments[RETURNS_KEY] + + input_schema = Schema(schema_arguments) if len(schema_arguments) != 0 else lambda x: x + output_schema = Schema(returns) if returns_defined else lambda x: x + + @wraps(func) + def func_wrapper(*args, **kwargs): + args_dict = _args_to_dict(func, args) + arguments = _merge_args_with_kwargs(args_dict, kwargs) + validated_arguments = input_schema(arguments) + output = func(**validated_arguments) + return output_schema(output) + + return func_wrapper + + return validate_schema_decorator diff --git a/voluptuous/tests/tests.py b/voluptuous/tests/tests.py index 79afa49..86a43ce 100644 --- a/voluptuous/tests/tests.py +++ b/voluptuous/tests/tests.py @@ -5,7 +5,7 @@ Schema, Required, Extra, Invalid, In, Remove, Literal, Url, MultipleInvalid, LiteralInvalid, NotIn, Match, Email, Replace, Range, Coerce, All, Any, Length, FqdnUrl, ALLOW_EXTRA, PREVENT_EXTRA, - validate_schema, ExactSequence, Equal, Unordered + validate, ExactSequence, Equal, Unordered ) from voluptuous.humanize import humanize_error @@ -423,14 +423,6 @@ def test_fix_157(): assert_raises(MultipleInvalid, s, ['four']) -def test_schema_decorator(): - @validate_schema(int) - def fn(arg): - return arg - - fn(1) - assert_raises(Invalid, fn, 1.0) - def test_range_exlcudes_nan(): s = Schema(Range(min=0, max=10)) @@ -485,3 +477,83 @@ def test_empty_list_as_exact(): s = Schema([]) assert_raises(Invalid, s, [1]) s([]) + + +def test_schema_decorator_match_with_args(): + @validate(int) + def fn(arg): + return arg + + fn(1) + + +def test_schema_decorator_unmatch_with_args(): + @validate(int) + def fn(arg): + return arg + + assert_raises(Invalid, fn, 1.0) + + +def test_schema_decorator_match_with_kwargs(): + @validate(arg=int) + def fn(arg): + return arg + + fn(1) + + +def test_schema_decorator_unmatch_with_kwargs(): + @validate(arg=int) + def fn(arg): + return arg + + assert_raises(Invalid, fn, 1.0) + + +def test_schema_decorator_match_return_with_args(): + @validate(int, __return__=int) + def fn(arg): + return arg + + fn(1) + + +def test_schema_decorator_unmatch_return_with_args(): + @validate(int, __return__=int) + def fn(arg): + return "hello" + + assert_raises(Invalid, fn, 1) + + +def test_schema_decorator_match_return_with_kwargs(): + @validate(arg=int, __return__=int) + def fn(arg): + return arg + + fn(1) + + +def test_schema_decorator_unmatch_return_with_kwargs(): + @validate(arg=int, __return__=int) + def fn(arg): + return "hello" + + assert_raises(Invalid, fn, 1) + + +def test_schema_decorator_return_only_match(): + @validate(__return__=int) + def fn(arg): + return arg + + fn(1) + + +def test_schema_decorator_return_only_unmatch(): + @validate(__return__=int) + def fn(arg): + return "hello" + + assert_raises(Invalid, fn, 1)