diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py
index 33fa7e1844cb1..0552357ba295a 100644
--- a/ivy/data_classes/array/activations.py
+++ b/ivy/data_classes/array/activations.py
@@ -30,7 +30,8 @@ def relu(
             optional output array, for writing the result to. It must have a shape
             that the inputs broadcast to.
         complex_mode
-            optional specifier for how to handle complex data types.
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -69,7 +70,8 @@ def leaky_relu(
             optional output array, for writing the result to. It must have a shape
             that the inputs broadcast to.
         complex_mode
-            optional specifier for how to handle complex data types.
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -93,6 +95,7 @@ def gelu(
         *,
         approximate: bool = False,
         out: Optional[ivy.Array] = None,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
     ) -> ivy.Array:
         """
         ivy.Array instance method variant of ivy.gelu. This method simply wraps the
@@ -108,6 +111,9 @@ def gelu(
         out
             optional output array, for writing the result to. It must have a shape
             that the inputs broadcast to.
+        complex_mode
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -196,6 +202,7 @@ def softplus(
         beta: Optional[Union[int, float]] = None,
         threshold: Optional[Union[int, float]] = None,
         out: Optional[ivy.Array] = None,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
     ) -> ivy.Array:
         """
         ivy.Array instance method variant of ivy.softplus. This method simply wraps the
@@ -212,6 +219,9 @@ def softplus(
             the threshold parameter of the softplus function.
         out
             optional output array, for writing the result to. It must have a shape
+        complex_mode
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -235,7 +245,13 @@ def softplus(
         >>> print(x)
         ivy.array([1.55, 2.13, 2.13])
         """
-        return ivy.softplus(self._data, beta=beta, threshold=threshold, out=out)
+        return ivy.softplus(
+            self._data,
+            beta=beta,
+            threshold=threshold,
+            out=out,
+            complex_mode=complex_mode,
+        )
 
     def log_softmax(
         self: ivy.Array,
diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py
index 104ae63ea31af..542f4c5065741 100644
--- a/ivy/data_classes/container/activations.py
+++ b/ivy/data_classes/container/activations.py
@@ -45,7 +45,8 @@ def _static_relu(
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
         complex_mode
-            optional specifier for how to handle complex data types.
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -109,7 +110,8 @@ def relu(
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
         complex_mode
-            optional specifier for how to handle complex data types.
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -176,7 +178,8 @@ def _static_leaky_relu(
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
         complex_mode
-            optional specifier for how to handle complex data types.
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -243,7 +246,8 @@ def leaky_relu(
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
         complex_mode
-            optional specifier for how to handle complex data types.
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -310,7 +314,8 @@ def _static_gelu(
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
         complex_mode
-            optional specifier for how to handle complex data types.
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -376,7 +381,8 @@ def gelu(
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
         complex_mode
-            optional specifier for how to handle complex data types.
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -660,6 +666,7 @@ def _static_softplus(
         prune_unapplied: Union[bool, ivy.Container] = False,
         map_sequences: Union[bool, ivy.Container] = False,
         out: Optional[ivy.Container] = None,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
     ) -> ivy.Container:
         """
         ivy.Container static method variant of ivy.softplus. This method simply wraps
@@ -688,6 +695,9 @@ def _static_softplus(
         out
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
+        complex_mode
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
         Returns
         -------
@@ -721,6 +731,7 @@ def _static_softplus(
             prune_unapplied=prune_unapplied,
             map_sequences=map_sequences,
             out=out,
+            complex_mode=complex_mode,
         )
 
     def softplus(
@@ -734,6 +745,7 @@ def softplus(
         prune_unapplied: Union[bool, ivy.Container] = False,
         map_sequences: Union[bool, ivy.Container] = False,
         out: Optional[ivy.Container] = None,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
     ) -> ivy.Container:
         """
         ivy.Container instance method variant of ivy.softplus. This method simply wraps
@@ -762,6 +774,9 @@ def softplus(
         out
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
+        complex_mode
+            optional specifier for how to handle complex data types. See
+            ``ivy.func_wrapper.handle_complex_input`` for more detail.
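Note: the two data-class files above only thread the new `complex_mode` keyword through to `ivy.softplus`; they add no numerics of their own. A minimal sketch of the resulting call surface (hypothetical values, assuming the patch is applied):

```python
import ivy

x = ivy.array([-0.3461 + 0.5j, -0.6491 - 0.2j])

# "jax" (the default) mimics JAX's handling of complex inputs;
# "split" and "magnitude" are the other modes understood by
# ivy.func_wrapper.handle_complex_input.
y = x.softplus(complex_mode="jax")

# The Container variant maps the same keyword across every leaf.
c = ivy.Container(a=x)
z = c.softplus(complex_mode="jax")
```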
 
         Returns
         -------
@@ -793,6 +808,7 @@ def softplus(
             prune_unapplied=prune_unapplied,
             map_sequences=map_sequences,
             out=out,
+            complex_mode=complex_mode,
         )
 
     @staticmethod
diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py
index d7705c9847d14..5de229f228dcc 100644
--- a/ivy/functional/backends/jax/activations.py
+++ b/ivy/functional/backends/jax/activations.py
@@ -51,6 +51,7 @@ def softplus(
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
     out: Optional[JaxArray] = None,
+    complex_mode="jax",
 ) -> JaxArray:
     if beta is not None and beta != 1:
         x_beta = x * beta
diff --git a/ivy/functional/backends/mxnet/activations.py b/ivy/functional/backends/mxnet/activations.py
index 9b6cfdcc9467f..40136c7d346b2 100644
--- a/ivy/functional/backends/mxnet/activations.py
+++ b/ivy/functional/backends/mxnet/activations.py
@@ -44,6 +44,7 @@ def softplus(
     beta: Optional[Union[(int, float)]] = None,
     threshold: Optional[Union[(int, float)]] = None,
     out: Optional[None] = None,
+    complex_mode="jax",
 ) -> None:
     if beta is not None and beta != 1:
         x_beta = x * beta
diff --git a/ivy/functional/backends/numpy/activations.py b/ivy/functional/backends/numpy/activations.py
index d2ee4c0675de3..73ceca88e969b 100644
--- a/ivy/functional/backends/numpy/activations.py
+++ b/ivy/functional/backends/numpy/activations.py
@@ -59,6 +59,7 @@ def softplus(
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
     out: Optional[np.ndarray] = None,
+    complex_mode="jax",
 ) -> np.ndarray:
     if beta is not None and beta != 1:
         x_beta = x * beta
diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py
index 185218223ce33..9f60af9cddb31 100644
--- a/ivy/functional/backends/paddle/activations.py
+++ b/ivy/functional/backends/paddle/activations.py
@@ -121,6 +121,7 @@ def softplus(
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
     out: Optional[paddle.Tensor] = None,
+    complex_mode="jax",
 ) -> paddle.Tensor:
     if beta is not None and beta != 1:
         x_beta = x * beta
diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py
index 722ccfeda78c7..0face1e720bad 100644
--- a/ivy/functional/backends/paddle/elementwise.py
+++ b/ivy/functional/backends/paddle/elementwise.py
@@ -733,6 +733,9 @@ def square(
     return paddle_backend.pow(x, 2).astype(x.dtype)
 
 
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
+)
 def pow(
     x1: Union[float, paddle.Tensor],
     x2: Union[float, paddle.Tensor],
diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py
index 5ba82d96269f6..1db0da2a5fce3 100644
--- a/ivy/functional/backends/tensorflow/activations.py
+++ b/ivy/functional/backends/tensorflow/activations.py
@@ -47,7 +47,17 @@ def softmax(
 
 
 @with_supported_dtypes(
-    {"2.13.0 and below": ("float16", "bfloat16", "float32", "float64")}, backend_version
+    {
+        "2.13.0 and below": (
+            "float16",
+            "bfloat16",
+            "float32",
+            "float64",
+            "complex64",
+            "complex128",
+        )
+    },
+    backend_version,
 )
 def softplus(
     x: Tensor,
@@ -56,6 +66,7 @@ def softplus(
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
     out: Optional[Tensor] = None,
+    complex_mode="jax",
 ) -> Tensor:
     if beta is not None and beta != 1:
         x_beta = x * beta
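All of the backend `softplus` signatures above gain the same inert `complex_mode="jax"` keyword: the backends only need to accept it, since the keyword is consumed by the `handle_complex_input` wrapper added to `ivy.softplus` further below, and it is never used in the backend bodies. For context, the numerics visible in the surrounding context lines (`x_beta = x * beta`) follow the usual numerically stable softplus; a rough NumPy sketch of that pattern, not the verbatim backend code:

```python
import numpy as np

def softplus_stable(x, beta=None, threshold=None):
    # Scale the input when a non-default beta is given.
    x_beta = x * beta if beta is not None and beta != 1 else x
    # log(1 + exp(t)) rewritten as max(t, 0) + log1p(exp(-|t|)),
    # which avoids overflow in exp() for large positive t.
    res = np.maximum(x_beta, 0) + np.log1p(np.exp(-np.abs(x_beta)))
    if beta is not None and beta != 1:
        res = res / beta
    if threshold is not None:
        # Past the threshold, softplus is numerically the identity.
        res = np.where(x_beta > threshold, x, res)
    return res

print(softplus_stable(np.array([-100.0, 0.0, 100.0])))  # ≈ [0, log(2), 100]
```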
diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py
index f94890adb1424..d0277edf8355b 100644
--- a/ivy/functional/backends/torch/activations.py
+++ b/ivy/functional/backends/torch/activations.py
@@ -71,9 +71,7 @@ def softmax(
     return torch.nn.functional.softmax(x, axis)
 
 
-@with_unsupported_dtypes(
-    {"2.0.1 and below": ("complex", "float16", "bfloat16")}, backend_version
-)
+@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, backend_version)
 def softplus(
     x: torch.Tensor,
     /,
@@ -81,6 +79,7 @@ def softplus(
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
     out: Optional[torch.Tensor] = None,
+    complex_mode="jax",
 ) -> torch.Tensor:
     kwargs = {
         k: v for k, v in {"beta": beta, "threshold": threshold}.items() if v is not None
diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py
index c9340407ab663..b11ff04ca7142 100644
--- a/ivy/functional/frontends/jax/nn/non_linear_activations.py
+++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py
@@ -291,7 +291,7 @@ def softmax(x, axis=-1, where=None, initial=None):
 @to_ivy_arrays_and_back
 def softplus(x):
     x = _type_conversion(x)
-    return ivy.softplus(x).astype(x.dtype)
+    return ivy.softplus(x, complex_mode="jax").astype(x.dtype)
 
 
 @to_ivy_arrays_and_back
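With the torch backend no longer blacklisting complex dtypes and the JAX frontend pinning `complex_mode="jax"`, the frontend's `softplus` should track JAX's complex behaviour on every backend. A hypothetical smoke test (illustrative only, not part of the patch):

```python
import ivy
from ivy.functional.frontends.jax import nn as jax_nn

ivy.set_backend("torch")  # complex inputs are no longer blocked here
out = jax_nn.softplus(ivy.array([0.3 + 0.4j]))
print(out)  # expected to agree with jax.nn.softplus on the same input
```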
diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py
index 67d2dcf838666..f8ff348d0ae71 100644
--- a/ivy/functional/ivy/activations.py
+++ b/ivy/functional/ivy/activations.py
@@ -498,6 +498,46 @@ def softmax(
     return current_backend(x).softmax(x, axis=axis, out=out)
 
 
+def _wrap_between(y, a):
+    """Wrap y between [-a, a]."""
+    a = ivy.array(a, dtype=y.dtype)
+    a2 = ivy.array(2 * a, dtype=y.dtype)
+    zero = ivy.array(0, dtype=y.dtype)
+    rem = ivy.remainder(ivy.add(y, a), a2)
+    rem = ivy.where(rem < zero, rem + a2, rem) - a
+    return rem
+
+
+def _softplus_jax_like(
+    x: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    fn_original=None,
+    beta: Optional[Union[int, float]] = None,
+    threshold: Optional[Union[int, float]] = None,
+    out: Optional[ivy.Array] = None,
+):
+    if beta is not None:
+        x_beta = ivy.multiply(x, ivy.array(beta, dtype=x.dtype))
+    else:
+        x_beta = x
+    amax = ivy.relu(x_beta)
+    res = ivy.subtract(x_beta, ivy.multiply(amax, ivy.array(2, dtype=x.dtype)))
+    res = ivy.add(amax, ivy.log(ivy.add(1, ivy.exp(res))))
+    res = ivy.real(res) + _wrap_between(ivy.imag(res), ivy.pi).astype(
+        x.dtype
+    ) * ivy.astype(1j, x.dtype)
+    if beta is not None:
+        res = ivy.divide(res, ivy.array(beta, dtype=x.dtype))
+    if threshold is not None:
+        res = ivy.where(
+            ivy.real(x_beta) < threshold,
+            res,
+            x,
+        ).astype(x.dtype)
+    return res
+
+
 @handle_exceptions
 @handle_backend_invalid
 @handle_nestable
@@ -506,6 +546,7 @@ def softmax(
 @to_native_arrays_and_back
 @handle_array_function
 @handle_device_shifting
+@handle_complex_input
 def softplus(
     x: Union[ivy.Array, ivy.NativeArray],
     /,
@@ -513,10 +554,17 @@ def softplus(
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
     out: Optional[ivy.Array] = None,
+    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
 ) -> ivy.Array:
     """
     Apply the softplus function element-wise.
 
+    If the input is complex, then by default the softplus operation
+    `log(1 + exp(x))` is applied to each element.
+    If `threshold` is set, the comparison against it uses only the real part
+    of `x * beta`: elements whose real part is at or above the threshold
+    revert to the linear function.
+
     Parameters
     ----------
     x
@@ -524,10 +572,14 @@ def softplus(
         input array.
     beta
         The beta value for the softplus formation. Default: ``None``.
     threshold
-        values above this revert to a linear function. Default: ``None``.
+        values above this revert to a linear function.
+        If the input is complex, only its real part is considered. Default: ``None``.
     out
         optional output array, for writing the result to. It must have a shape
         that the inputs broadcast to.
+    complex_mode
+        optional specifier for how to handle complex data types. See
+        ``ivy.func_wrapper.handle_complex_input`` for more detail.
 
     Returns
     -------
@@ -557,6 +609,9 @@ def softplus(
     return current_backend(x).softplus(x, beta=beta, threshold=threshold, out=out)
 
 
+softplus.jax_like = _softplus_jax_like
+
+
 @handle_exceptions
 @handle_backend_invalid
 @handle_nestable
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
index 34b77168b11bb..c7d69200a4be9 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
@@ -327,10 +327,10 @@ def test_jax_softmax(
 @handle_frontend_test(
     fn_tree="jax.nn.softplus",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float_and_integer"),
-        large_abs_safety_factor=2,
-        small_abs_safety_factor=2,
-        safety_factor_scale="linear",
+        available_dtypes=helpers.get_dtypes("numeric"),
+        large_abs_safety_factor=4,
+        small_abs_safety_factor=4,
+        safety_factor_scale="log",
     ),
     test_with_out=st.just(False),
 )
diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
index 0cc9653edeeaa..07e0b6dab339f 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
@@ -146,10 +146,10 @@ def test_softmax(*, dtype_and_x, axis, test_flags, backend_fw, fn_name, on_devic
 @handle_test(
     fn_tree="functional.ivy.softplus",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("valid"),
+        available_dtypes=helpers.get_dtypes("numeric"),
         min_num_dims=1,
-        large_abs_safety_factor=8,
-        small_abs_safety_factor=8,
+        large_abs_safety_factor=4,
+        small_abs_safety_factor=4,
         safety_factor_scale="log",
     ),
     beta=st.one_of(helpers.number(min_value=0.1, max_value=10), st.none()),
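For reference, the core of the change is the `_softplus_jax_like` helper added to `ivy/functional/ivy/activations.py` above: it evaluates `log(1 + exp(x))` in the shifted form `amax + log(1 + exp(x - 2*amax))` with `amax = relu(x*beta)`, the standard overflow-safe rewrite, then wraps the imaginary part of the result into `[-pi, pi]` via `_wrap_between`, presumably to keep the phase on the same branch JAX uses. Registering it as `softplus.jax_like` is what lets the `handle_complex_input` wrapper dispatch to it when `complex_mode="jax"`. A plain-NumPy sketch of the wrapping helper's behaviour (mirrors `_wrap_between`; values illustrative):

```python
import numpy as np

def wrap_between(y, a):
    # Shift by a, reduce modulo 2a, shift back: maps y into [-a, a).
    # np.remainder is always non-negative for a positive divisor, so the
    # extra ivy.where(rem < 0, ...) correction in the patch is a no-op here.
    return np.remainder(y + a, 2 * a) - a

print(wrap_between(np.array([3.5, -3.5, 0.5]), np.pi))
# ≈ [-2.7832  2.7832  0.5] — phases folded into [-pi, pi)
```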