From 43f5ef3d06f77390ba773db04389ce35baafa3ee Mon Sep 17 00:00:00 2001 From: Carl Simon Adorf Date: Thu, 19 Jan 2023 18:44:30 +0100 Subject: [PATCH] Refactor API decorators (#5026) Authors: - Carl Simon Adorf (https://github.com/csadorf) Approvers: - William Hicks (https://github.com/wphicks) URL: https://github.com/rapidsai/cuml/pull/5026 --- python/cuml/internals/__init__.py | 2 - python/cuml/internals/api_decorators.py | 839 +++++---------------- python/cuml/tests/test_cuml_descr_decor.py | 22 +- 3 files changed, 219 insertions(+), 644 deletions(-) diff --git a/python/cuml/internals/__init__.py b/python/cuml/internals/__init__.py index 48c081b77a..65c8620670 100644 --- a/python/cuml/internals/__init__.py +++ b/python/cuml/internals/__init__.py @@ -28,9 +28,7 @@ api_base_return_generic_skipall, api_base_return_generic, api_base_return_sparse_array, - api_ignore, api_return_any, - api_return_array_skipall, api_return_array, api_return_generic, api_return_sparse_array, diff --git a/python/cuml/internals/api_decorators.py b/python/cuml/internals/api_decorators.py index 53f73e0751..69e5f2ed26 100644 --- a/python/cuml/internals/api_decorators.py +++ b/python/cuml/internals/api_decorators.py @@ -18,13 +18,10 @@ import functools import inspect import typing -from functools import wraps import warnings -import cuml.internals.array -import cuml.internals.array_sparse -import cuml.internals.input_utils -from cuml.internals.type_utils import _DecoratorType, wraps_typed +# TODO: Try to resolve circular import that makes this necessary: +from cuml.internals import input_utils as iu from cuml.internals.api_context_managers import BaseReturnAnyCM from cuml.internals.api_context_managers import BaseReturnArrayCM from cuml.internals.api_context_managers import BaseReturnGenericCM @@ -39,646 +36,220 @@ from cuml.internals.constants import CUML_WRAPPED_FLAG from cuml.internals.global_settings import GlobalSettings from cuml.internals.memory_utils import using_output_type +from cuml.internals.type_utils import _DecoratorType, wraps_typed from cuml.internals import logger -class DecoratorMetaClass(type): - """ - This metaclass is used to prevent wrapping functions multiple times by - adding `__cuml_is_wrapped = True` to the function __dict__ - """ - def __new__(cls, classname, bases, classDict): - - if ("__call__" in classDict): +def _wrap_once(wrapped, *args, **kwargs): + """Prevent wrapping functions multiple times.""" + setattr(wrapped, CUML_WRAPPED_FLAG, True) + return functools.wraps(wrapped, *args, **kwargs) - func = classDict["__call__"] - @wraps(func) - def wrap_call(*args, **kwargs): - ret_val = func(*args, **kwargs) +def _has_self(sig): + return "self" in sig.parameters and list(sig.parameters)[0] == "self" - ret_val.__dict__[CUML_WRAPPED_FLAG] = True - return ret_val +def _find_arg(sig, arg_name, default_position): + params = list(sig.parameters) - classDict["__call__"] = wrap_call + # Check for default name in input args + if arg_name in sig.parameters: + return arg_name, params.index(arg_name) + # Otherwise use argument in list by position + elif arg_name is ...: + index = int(_has_self(sig)) + default_position + return params[index], index + else: + raise ValueError( + f"Unable to find parameter '{arg_name}'." + ) - return type.__new__(cls, classname, bases, classDict) - - -class WithArgsDecoratorMixin(object): - """ - This decorator mixin handles processing the input arguments for all api - decorators. It supplies the input_arg, target_arg properties - """ - def __init__(self, - *, - input_arg: str = ..., - target_arg: str = ..., - needs_self=True, - needs_input=False, - needs_target=False): - super().__init__() - # For input_arg and target_arg, use Ellipsis to auto detect, None to - # skip (this has different functionality on Base where it can determine - # the output type like CumlArrayDescriptor) - self.input_arg = input_arg - self.target_arg = target_arg - - self.needs_self = needs_self - self.needs_input = needs_input - self.needs_target = needs_target - - def prep_arg_to_use(self, func) -> bool: - - # Determine from the signature what processing needs to be done. This - # is executed once per function on import - sig = inspect.signature(func, follow_wrapped=True) - sig_args = list(sig.parameters.keys()) - - self.has_self = "self" in sig.parameters and sig_args.index( - "self") == 0 - - if (not self.has_self and self.needs_self): - raise Exception("No self found on function!") - - # Return early if we dont need args - if (not self.needs_input and not self.needs_target): - return - - self_offset = (1 if self.has_self else 0) - - if (self.needs_input): - input_arg_to_use = self.input_arg - input_arg_to_use_name = None - - # if input_arg is None, then set to first non self argument - if (input_arg_to_use is ...): - - # Check for "X" in input args - if ("X" in sig_args): - input_arg_to_use = "X" - else: - if (len(sig.parameters) <= self_offset): - raise Exception("No input_arg could be determined!") - - input_arg_to_use = sig_args[self_offset] - - # Now convert that to an index - if (isinstance(input_arg_to_use, str)): - input_arg_to_use_name = input_arg_to_use - input_arg_to_use = sig_args.index(input_arg_to_use) - - assert input_arg_to_use != -1 and input_arg_to_use is not None, \ - "Could not determine input_arg" - - # Save the name and argument to use later - self.input_arg_to_use = input_arg_to_use - self.input_arg_to_use_name = input_arg_to_use_name - - if (self.needs_target): - - target_arg_to_use = self.target_arg - target_arg_to_use_name = None - - # if input_arg is None, then set to first non self argument - if (target_arg_to_use is ...): - - # Check for "y" in args - if ("y" in sig_args): - target_arg_to_use = "y" - else: - if (len(sig.parameters) <= self_offset + 1): - raise Exception("No target_arg could be determined!") - - target_arg_to_use = sig_args[self_offset + 1] - - # Now convert that to an index - if (isinstance(target_arg_to_use, str)): - target_arg_to_use_name = target_arg_to_use - target_arg_to_use = sig_args.index(target_arg_to_use) - - assert target_arg_to_use != -1 and target_arg_to_use is not None, \ - "Could not determine target_arg" - - # Save the name and argument to use later - self.target_arg_to_use = target_arg_to_use - self.target_arg_to_use_name = target_arg_to_use_name - - return True - - def get_arg_values(self, *args, **kwargs): - """ - This function is called once per function invocation to get the values - of self, input and target. - - Returns - ------- - tuple - Returns a tuple of self, input, target values - - Raises - ------ - IndexError - Raises an exception if the specified input argument is not - available or called with the wrong number of arguments - """ - self_val = None - input_val = None - target_val = None - - if (self.has_self): - self_val = args[0] - - if (self.needs_input): - # Check if its set to a string - if (isinstance(self.input_arg_to_use, str)): - input_val = kwargs[self.input_arg_to_use] - - # If all arguments are set by name, then this can happen - elif (self.input_arg_to_use >= len(args)): - # Check for the name in kwargs - if (self.input_arg_to_use_name in kwargs): - input_val = kwargs[self.input_arg_to_use_name] - else: - raise IndexError( - ("Specified input_arg idx: {}, and argument name: {}, " - "were not found in args or kwargs").format( - self.input_arg_to_use, - self.input_arg_to_use_name)) - else: - # Otherwise return the index - input_val = args[self.input_arg_to_use] - - if (self.needs_target): - # Check if its set to a string - if (isinstance(self.target_arg_to_use, str)): - target_val = kwargs[self.target_arg_to_use] - - # If all arguments are set by name, then this can happen - elif (self.target_arg_to_use >= len(args)): - # Check for the name in kwargs - if (self.target_arg_to_use_name in kwargs): - target_val = kwargs[self.target_arg_to_use_name] - else: - raise IndexError(( - "Specified target_arg idx: {}, and argument name: {}, " - "were not found in args or kwargs").format( - self.target_arg_to_use, - self.target_arg_to_use_name)) +def _get_value(args, kwargs, name, index): + """Determine value for a given set of args, kwargs, name and index.""" + try: + return kwargs[name] + except KeyError: + try: + return args[index] + except IndexError: + raise IndexError( + f"Specified arg idx: {index}, and argument name: {name}, " + "were not found in args or kwargs.") + + +def _make_decorator_function( + context_manager_cls: InternalAPIContextBase, + process_return=True, + needs_self: bool = False, + ** defaults, +) -> typing.Callable[..., _DecoratorType]: + # This function generates a function to be applied as decorator to a + # wrapped function. For example: + # + # a_decorator = _make_decorator_function(...) + # + # ... + # + # @a_decorator(...) # apply decorator where appropriate + # def fit(X, y): + # ... + # + # Note: The decorator function can be partially closed by directly + # providing keyword arguments to this function to be used as defaults. + + def decorator_function( + input_arg: str = ..., + target_arg: str = ..., + get_output_type: bool = False, + set_output_type: bool = False, + get_output_dtype: bool = False, + set_output_dtype: bool = False, + set_n_features_in: bool = False, + ) -> _DecoratorType: + + def decorator_closure(func): + # This function constitutes the closed decorator that will return + # the wrapped function. It performs function introspection at + # function definition time. The code within the wrapper function is + # executed at function execution time. + + # Prepare arguments + sig = inspect.signature(func, follow_wrapped=True) + + has_self = _has_self(sig) + if needs_self and not has_self: + raise Exception("No self found on function!") + + if input_arg is not None and ( + set_output_type + or set_output_dtype + or set_n_features_in + or get_output_type + ): + input_arg_ = _find_arg(sig, input_arg or "X", 0) else: - # Otherwise return the index - target_val = args[self.target_arg_to_use] - - return self_val, input_val, target_val - - -class HasSettersDecoratorMixin(object): - """ - This mixin is responsible for handling any "set_XXX" methods used by api - decorators. Mostly used by `fit()` functions - """ - def __init__(self, - *, - set_output_type=True, - set_output_dtype=False, - set_n_features_in=True) -> None: - - super().__init__() - - self.set_output_type = set_output_type - self.set_output_dtype = set_output_dtype - self.set_n_features_in = set_n_features_in - - self.has_setters = (self.set_output_type or self.set_output_dtype - or self.set_n_features_in) - - def do_setters(self, *, self_val, input_val, target_val): - if (self.set_output_type): - assert input_val is not None, \ - "`set_output_type` is False but no input_arg detected" - self_val._set_output_type(input_val) - - if (self.set_output_dtype): - assert target_val is not None, \ - "`set_output_dtype` is True but no target_arg detected" - self_val._set_target_dtype(target_val) - - if (self.set_n_features_in): - assert input_val is not None, \ - "`set_n_features_in` is False but no input_arg detected" - if (len(input_val.shape) >= 2): - self_val._set_n_features_in(input_val) - - def has_setters_input(self): - return self.set_output_type or self.set_n_features_in - - def has_setters_target(self): - return self.set_output_dtype - - -class HasGettersDecoratorMixin(object): - """ - This mixin is responsible for handling any "get_XXX" methods used by api - decorators. Used for many functions like `predict()`, `transform()`, etc. - """ - def __init__(self, - *, - get_output_type=False, - get_output_dtype=False) -> None: - - super().__init__() - - self.get_output_type = get_output_type - self.get_output_dtype = get_output_dtype - - self.has_getters = (self.get_output_type or self.get_output_dtype) - - def do_getters_with_self_no_input(self, *, self_val): - if (self.get_output_type): - out_type = self_val.output_type - - if (out_type == "input"): - out_type = self_val._input_type - - set_api_output_type(out_type) - - if (self.get_output_dtype): - set_api_output_dtype(self_val._get_target_dtype()) - - def do_getters_with_self(self, *, self_val, input_val): - if (self.get_output_type): - out_type = self_val._get_output_type(input_val) - assert out_type is not None, \ - ("`get_output_type` is False but output_type could not " - "be determined from input_arg") - set_api_output_type(out_type) - - if (self.get_output_dtype): - set_api_output_dtype(self_val._get_target_dtype()) - - def do_getters_no_self(self, *, input_val, target_val): - if (self.get_output_type): - assert input_val is not None, \ - "`get_output_type` is False but no input_arg detected" - set_api_output_type( - cuml.internals.input_utils.determine_array_type(input_val)) - - if (self.get_output_dtype): - assert target_val is not None, \ - "`get_output_dtype` is False but no target_arg detected" - set_api_output_dtype( - cuml.internals.input_utils.determine_array_dtype(target_val)) - - def has_getters_input(self): - return self.get_output_type - - def has_getters_target(self, needs_self): - return False if needs_self else self.get_output_dtype - - -class ReturnDecorator(metaclass=DecoratorMetaClass): - def __init__(self): - super().__init__() - - self.do_autowrap = False - - def __call__(self, func: _DecoratorType) -> _DecoratorType: - raise NotImplementedError() - - def _recreate_cm(self, func, args) -> InternalAPIContextBase: - raise NotImplementedError() - - -class ReturnAnyDecorator(ReturnDecorator): - def __call__(self, func: _DecoratorType) -> _DecoratorType: - @wraps(func) - def inner(*args, **kwargs): - with self._recreate_cm(func, args): - return func(*args, **kwargs) - - return inner - - def _recreate_cm(self, func, args): - return ReturnAnyCM(func, args) - - -class BaseReturnAnyDecorator(ReturnDecorator, - HasSettersDecoratorMixin, - WithArgsDecoratorMixin): - def __init__(self, - *, - input_arg: str = ..., - target_arg: str = ..., - set_output_type=True, - set_output_dtype=False, - set_n_features_in=True) -> None: - - ReturnDecorator.__init__(self) - HasSettersDecoratorMixin.__init__(self, - set_output_type=set_output_type, - set_output_dtype=set_output_dtype, - set_n_features_in=set_n_features_in) - WithArgsDecoratorMixin.__init__(self, - input_arg=input_arg, - target_arg=target_arg, - needs_self=True, - needs_input=self.has_setters_input(), - needs_target=self.has_setters_target()) - - self.do_autowrap = self.has_setters - - def __call__(self, func: _DecoratorType) -> _DecoratorType: - - self.prep_arg_to_use(func) - - @wraps(func) - def inner_with_setters(*args, **kwargs): - - with self._recreate_cm(func, args): - - self_val, input_val, target_val = \ - self.get_arg_values(*args, **kwargs) - - self.do_setters(self_val=self_val, - input_val=input_val, - target_val=target_val) - - return func(*args, **kwargs) - - @wraps(func) - def inner(*args, **kwargs): - - with self._recreate_cm(func, args): - return func(*args, **kwargs) - - # Return the function depending on whether or not we do any automatic - # wrapping - return inner_with_setters if self.has_setters else inner - - def _recreate_cm(self, func, args): - return BaseReturnAnyCM(func, args) - - -class ReturnArrayDecorator(ReturnDecorator, - HasGettersDecoratorMixin, - WithArgsDecoratorMixin): - def __init__(self, - *, - input_arg: str = ..., - target_arg: str = ..., - get_output_type=False, - get_output_dtype=False) -> None: - - ReturnDecorator.__init__(self) - HasGettersDecoratorMixin.__init__(self, - get_output_type=get_output_type, - get_output_dtype=get_output_dtype) - WithArgsDecoratorMixin.__init__( - self, - input_arg=input_arg, - target_arg=target_arg, - needs_self=False, - needs_input=self.has_getters_input(), - needs_target=self.has_getters_target(False)) - - self.do_autowrap = self.has_getters - - def __call__(self, func: _DecoratorType) -> _DecoratorType: + input_arg_ = None - self.prep_arg_to_use(func) - - @wraps(func) - def inner_with_getters(*args, **kwargs): - with self._recreate_cm(func, args) as cm: - - # Get input/target values - _, input_val, target_val = self.get_arg_values(*args, **kwargs) - - # Now execute the getters - self.do_getters_no_self(input_val=input_val, - target_val=target_val) - - # Call the function - ret_val = func(*args, **kwargs) - - return cm.process_return(ret_val) - - @wraps(func) - def inner(*args, **kwargs): - with self._recreate_cm(func, args) as cm: - - ret_val = func(*args, **kwargs) - - return cm.process_return(ret_val) - - return inner_with_getters if self.has_getters else inner - - def _recreate_cm(self, func, args): - - return ReturnArrayCM(func, args) - - -class ReturnSparseArrayDecorator(ReturnArrayDecorator): - def _recreate_cm(self, func, args): - - return ReturnSparseArrayCM(func, args) - - -class BaseReturnArrayDecorator(ReturnDecorator, - HasSettersDecoratorMixin, - HasGettersDecoratorMixin, - WithArgsDecoratorMixin): - def __init__(self, - *, - input_arg: str = ..., - target_arg: str = ..., - get_output_type=True, - get_output_dtype=False, - set_output_type=False, - set_output_dtype=False, - set_n_features_in=False) -> None: - - ReturnDecorator.__init__(self) - HasSettersDecoratorMixin.__init__(self, - set_output_type=set_output_type, - set_output_dtype=set_output_dtype, - set_n_features_in=set_n_features_in) - HasGettersDecoratorMixin.__init__(self, - get_output_type=get_output_type, - get_output_dtype=get_output_dtype) - WithArgsDecoratorMixin.__init__( - self, - input_arg=input_arg, - target_arg=target_arg, - needs_self=True, - needs_input=(self.has_setters_input() or self.has_getters_input()) - and input_arg is not None, - needs_target=self.has_setters_target() - or self.has_getters_target(True)) - - self.do_autowrap = self.has_setters or self.has_getters - - def __call__(self, func: _DecoratorType) -> _DecoratorType: - - self.prep_arg_to_use(func) - - @wraps(func) - def inner_set_get(*args, **kwargs): - with self._recreate_cm(func, args) as cm: - - # Get input/target values - self_val, input_val, target_val = \ - self.get_arg_values(*args, **kwargs) - - # Must do the setters first - self.do_setters(self_val=self_val, - input_val=input_val, - target_val=target_val) - - # Now execute the getters - if (self.needs_input): - self.do_getters_with_self(self_val=self_val, - input_val=input_val) - else: - self.do_getters_with_self_no_input(self_val=self_val) - - # Call the function - ret_val = func(*args, **kwargs) - - return cm.process_return(ret_val) - - @wraps(func) - def inner_set(*args, **kwargs): - with self._recreate_cm(func, args) as cm: - - # Get input/target values - self_val, input_val, target_val = \ - self.get_arg_values(*args, **kwargs) - - # Must do the setters first - self.do_setters(self_val=self_val, - input_val=input_val, - target_val=target_val) - - # Call the function - ret_val = func(*args, **kwargs) - - return cm.process_return(ret_val) - - @wraps(func) - def inner_get(*args, **kwargs): - with self._recreate_cm(func, args) as cm: - - # Get input/target values - self_val, input_val, _ = self.get_arg_values(*args, **kwargs) - - # Do the getters - if (self.needs_input): - self.do_getters_with_self(self_val=self_val, - input_val=input_val) - else: - self.do_getters_with_self_no_input(self_val=self_val) - - # Call the function - ret_val = func(*args, **kwargs) - - return cm.process_return(ret_val) - - @wraps(func) - def inner(*args, **kwargs): - with self._recreate_cm(func, args) as cm: - - # Call the function - ret_val = func(*args, **kwargs) - - return cm.process_return(ret_val) - - # Return the function depending on whether or not we do any automatic - # wrapping - if (self.has_getters and self.has_setters): - return inner_set_get - elif (self.has_getters): - return inner_get - elif (self.has_setters): - return inner_set - else: - return inner - - def _recreate_cm(self, func, args): - - return BaseReturnArrayCM(func, args) - - -class BaseReturnSparseArrayDecorator(BaseReturnArrayDecorator): - def _recreate_cm(self, func, args): - - return BaseReturnSparseArrayCM(func, args) - - -class ReturnGenericDecorator(ReturnArrayDecorator): - def _recreate_cm(self, func, args): - - return ReturnGenericCM(func, args) - - -class BaseReturnGenericDecorator(BaseReturnArrayDecorator): - def _recreate_cm(self, func, args): - - return BaseReturnGenericCM(func, args) - - -class BaseReturnArrayFitTransformDecorator(BaseReturnArrayDecorator): - """ - Identical to `BaseReturnArrayDecorator`, however the defaults have been - changed to better suit `fit_transform` methods - """ - def __init__(self, - *, - input_arg: str = ..., - target_arg: str = ..., - get_output_type=True, - get_output_dtype=False, - set_output_type=True, - set_output_dtype=False, - set_n_features_in=True) -> None: - - super().__init__(input_arg=input_arg, - target_arg=target_arg, - get_output_type=get_output_type, - get_output_dtype=get_output_dtype, - set_output_type=set_output_type, - set_output_dtype=set_output_dtype, - set_n_features_in=set_n_features_in) - - -api_return_any = ReturnAnyDecorator -api_base_return_any = BaseReturnAnyDecorator -api_return_array = ReturnArrayDecorator -api_base_return_array = BaseReturnArrayDecorator -api_return_generic = ReturnGenericDecorator -api_base_return_generic = BaseReturnGenericDecorator -api_base_fit_transform = BaseReturnArrayFitTransformDecorator - -api_return_sparse_array = ReturnSparseArrayDecorator -api_base_return_sparse_array = BaseReturnSparseArrayDecorator - -api_return_array_skipall = ReturnArrayDecorator(get_output_dtype=False, - get_output_type=False) - -api_base_return_any_skipall = BaseReturnAnyDecorator(set_output_type=False, - set_n_features_in=False) -api_base_return_array_skipall = BaseReturnArrayDecorator(get_output_type=False) -api_base_return_generic_skipall = BaseReturnGenericDecorator( - get_output_type=False) - - -def api_ignore(func: _DecoratorType) -> _DecoratorType: - - func.__dict__[CUML_WRAPPED_FLAG] = True - - return func + if set_output_dtype or (get_output_dtype and not has_self): + target_arg_ = _find_arg(sig, target_arg or "y", 1) + else: + target_arg_ = None + + @_wrap_once(func) + def wrapper(*args, **kwargs): + # Wraps the decorated function, executed at runtime. + + with context_manager_cls(func, args) as cm: + + self_val = args[0] if has_self else None + + if input_arg_: + input_val = _get_value(args, kwargs, * input_arg_) + else: + input_val = None + if target_arg_: + target_val = _get_value(args, kwargs, * target_arg_) + else: + target_val = None + + if set_output_type: + assert self_val is not None + self_val._set_output_type(input_val) + if set_output_dtype: + assert self_val is not None + self_val._set_target_dtype(target_val) + if set_n_features_in and len(input_val.shape) >= 2: + assert self_val is not None + self_val._set_n_features_in(input_val) + + if get_output_type: + if self_val is None: + assert input_val is not None + out_type = iu.determine_array_type(input_val) + elif input_val is None: + out_type = self_val.output_type + if out_type == "input": + out_type = self_val._input_type + else: + out_type = self_val._get_output_type(input_val) + + set_api_output_type(out_type) + + if get_output_dtype: + if self_val is None: + assert target_val is not None + output_dtype = iu.determine_array_dtype(target_val) + else: + output_dtype = self_val._get_target_dtype() + + set_api_output_dtype(output_dtype) + + if process_return: + ret = func(*args, **kwargs) + else: + return func(*args, **kwargs) + + return cm.process_return(ret) + + return wrapper + + return decorator_closure + + return functools.partial(decorator_function, **defaults) + + +api_return_any = _make_decorator_function(ReturnAnyCM, process_return=False) +api_base_return_any = _make_decorator_function( + BaseReturnAnyCM, + needs_self=True, + set_output_type=True, + set_n_features_in=True, +) +api_return_array = _make_decorator_function(ReturnArrayCM, process_return=True) +api_base_return_array = _make_decorator_function( + BaseReturnArrayCM, + needs_self=True, + process_return=True, + get_output_type=True, +) +api_return_generic = _make_decorator_function( + ReturnGenericCM, process_return=True +) +api_base_return_generic = _make_decorator_function( + BaseReturnGenericCM, + needs_self=True, + process_return=True, + get_output_type=True, +) +api_base_fit_transform = _make_decorator_function( + # TODO: add tests for this decorator( + BaseReturnArrayCM, + needs_self=True, + process_return=True, + get_output_type=True, + set_output_type=True, + set_n_features_in=True, +) + +api_return_sparse_array = _make_decorator_function( + ReturnSparseArrayCM, process_return=True +) +api_base_return_sparse_array = _make_decorator_function( + BaseReturnSparseArrayCM, + needs_self=True, + process_return=True, + get_output_type=True, +) + +api_base_return_any_skipall = api_base_return_any( + set_output_type=False, set_n_features_in=False +) +api_base_return_array_skipall = api_base_return_array(get_output_type=False) +api_base_return_generic_skipall = api_base_return_generic( + get_output_type=False +) @contextlib.contextmanager @@ -706,7 +277,7 @@ def mirror_args( assigned=('__doc__', '__annotations__'), updated=functools.WRAPPER_UPDATES ) -> typing.Callable[[_DecoratorType], _DecoratorType]: - return wraps(wrapped=wrapped, assigned=assigned, updated=updated) + return _wrap_once(wrapped=wrapped, assigned=assigned, updated=updated) class _deprecate_pos_args: diff --git a/python/cuml/tests/test_cuml_descr_decor.py b/python/cuml/tests/test_cuml_descr_decor.py index 63f8c738a2..9007566f4e 100644 --- a/python/cuml/tests/test_cuml_descr_decor.py +++ b/python/cuml/tests/test_cuml_descr_decor.py @@ -314,14 +314,20 @@ def test_func(X, y): return X - if (input_arg == "bad" or target_arg == "bad"): - pytest.xfail("Expected error with bad arg name") - - test_func = cuml.internals.api_return_array( - input_arg=input_arg, - target_arg=target_arg, - get_output_type=get_output_type, - get_output_dtype=get_output_dtype)(test_func) + expected_to_fail = (input_arg == "bad" and get_output_type) \ + or (target_arg == "bad" and get_output_dtype) + + try: + test_func = cuml.internals.api_return_array( + input_arg=input_arg, + target_arg=target_arg, + get_output_type=get_output_type, + get_output_dtype=get_output_dtype)(test_func) + except ValueError: + assert expected_to_fail + return + else: + assert not expected_to_fail X_out = test_func(X=X_in, y=Y_in)