From 6b90a8a477832ace543971ef14f3b236db44a444 Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Mon, 7 Aug 2023 17:09:52 +0100 Subject: [PATCH 01/38] added support for sigmoid activation fucntion --- ivy/data_classes/array/activations.py | 10 ++++++++-- ivy/data_classes/container/activations.py | 4 ++++ .../jax/nn/non_linear_activations.py | 2 +- ivy/functional/ivy/activations.py | 15 +++++++++++++- ivy/stateful/activations.py | 20 +++++++++++++++---- .../test_nn/test_non_linear_activations.py | 2 +- .../test_nn/test_activations.py | 2 +- .../test_stateful/test_activations.py | 2 +- 8 files changed, 46 insertions(+), 11 deletions(-) diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py index d7ffc3efb24d8..46e677abb81a6 100644 --- a/ivy/data_classes/array/activations.py +++ b/ivy/data_classes/array/activations.py @@ -115,7 +115,13 @@ def gelu( """ return ivy.gelu(self._data, approximate=approximate, out=out) - def sigmoid(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: + def sigmoid( + self: ivy.Array, + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, + ) -> ivy.Array: """ ivy.Array instance method variant of ivy.sigmoid. @@ -143,7 +149,7 @@ def sigmoid(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array >>> print(y) ivy.array([0.269, 0.731, 0.881]) """ - return ivy.sigmoid(self._data, out=out) + return ivy.sigmoid(self._data, complex_mode=complex_mode, out=out) def softmax( self: ivy.Array, diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py index dabd3be451f8e..cf54e76d11ad5 100644 --- a/ivy/data_classes/container/activations.py +++ b/ivy/data_classes/container/activations.py @@ -456,6 +456,7 @@ def sigmoid( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -478,6 +479,8 @@ def sigmoid( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. 
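A minimal usage sketch of the argument this commit adds to ivy.sigmoid; it assumes a backend with complex-dtype support has been set, and the input values are made up for illustration.

    import ivy

    ivy.set_backend("numpy")                         # any backend with complex support; the choice is illustrative

    x = ivy.array([0.25 + 0.5j, -1.0 - 2.0j])        # hypothetical complex input
    y_default = ivy.sigmoid(x)                       # the new default, complex_mode="jax"
    y_split = ivy.sigmoid(x, complex_mode="split")   # real and imaginary parts handled separately (per my reading)
    y_real = ivy.sigmoid(ivy.array([0.0]))           # real inputs are unchanged; sigmoid(0) = 1 / (1 + e**0) = 0.5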
@@ -503,6 +506,7 @@ def sigmoid( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index c94e6a69be361..305ad0c2d3acd 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -262,7 +262,7 @@ def relu6(x): @to_ivy_arrays_and_back def sigmoid(x): x = _type_conversion(x) - ret = ivy.sigmoid(x) + ret = ivy.sigmoid(x, complex_mode="jax") return ivy.astype(ret, x.dtype) diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index 0bfe93a8d5538..fec386717045a 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -312,16 +312,29 @@ def relu( @to_native_arrays_and_back @handle_array_function @handle_device_shifting +@handle_complex_input def sigmoid( - x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the sigmoid function element-wise. + If the input is complex, then by default each element is scaled by `alpha` if + either its real part is strictly negative or if its real part is zero and its + imaginary part is negative. This behaviour can be changed by specifying a different + `complex_mode`. + Parameters ---------- x input array. + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the input broadcast to. diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 0958770452912..04078408c9587 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -233,24 +233,36 @@ def _forward(self, x): class Sigmoid(Module): - def __init__(self): - """Apply the SIGMOID activation function.""" + def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): + """ + Apply the SIGMOID activation function. + + Parameter + ---------- + complex_mode + Specifies how to handle complex input. + """ + self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x): + def _forward(self, x, complex_mode=None): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. + complex_mode + Specifies how to handle complex input. 
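The docstrings point to `ivy.func_wrapper.handle_complex_input` for the meaning of the three modes. As I read them, they differ in how a real-valued activation f is extended to a complex z; the NumPy sketch below emulates that reading for sigmoid and is an illustration only, not the decorator's verified implementation.

    import numpy as np

    def sigmoid(v):
        return 1.0 / (1.0 + np.exp(-v))

    z = np.array([1.0 + 2.0j, -0.5 - 0.5j])                     # made-up complex samples

    jax_like = sigmoid(z)                                       # "jax": evaluate the same expression on z directly
    split = sigmoid(z.real) + 1j * sigmoid(z.imag)              # "split": act on real and imaginary parts separately
    magnitude = sigmoid(np.abs(z)) * np.exp(1j * np.angle(z))   # "magnitude": act on |z| and keep the phase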
Returns ------- ret The outputs following the SIGMOID activation *[batch_shape, d]* """ - return ivy.sigmoid(x) + return ivy.sigmoid( + x, complex_mode=ivy.default(complex_mode, self._complex_mode) + ) class Tanh(Module): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 0ac715d8cc786..595d2387ffa91 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -205,7 +205,7 @@ def test_jax_gelu( @handle_frontend_test( fn_tree="jax.nn.sigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="linear", diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py index 4a9f8c69b6249..27c4b895cce22 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py @@ -88,7 +88,7 @@ def test_gelu(*, dtype_and_x, approximate, test_flags, backend_fw, fn_name, on_d @handle_test( fn_tree="functional.ivy.sigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", diff --git a/ivy_tests/test_ivy/test_stateful/test_activations.py b/ivy_tests/test_ivy/test_stateful/test_activations.py index 1e7e1c7a49ead..040295e87a74b 100644 --- a/ivy_tests/test_ivy/test_stateful/test_activations.py +++ b/ivy_tests/test_ivy/test_stateful/test_activations.py @@ -416,7 +416,7 @@ def test_silu( @handle_method( method_tree="stateful.activations.Sigmoid.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", From a9879a1a9482770d05163e1a0ec94684fb6ada38 Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Tue, 8 Aug 2023 12:38:58 +0100 Subject: [PATCH 02/38] applied complex_input decorator to hardswish activation fucntion --- ivy/data_classes/array/activations.py | 14 +++++++++++-- ivy/data_classes/container/activations.py | 4 ++++ .../jax/nn/non_linear_activations.py | 2 +- ivy/functional/ivy/activations.py | 15 +++++++++++++- ivy/stateful/activations.py | 20 +++++++++++++++---- .../test_nn/test_non_linear_activations.py | 2 +- .../test_nn/test_activations.py | 2 +- .../test_stateful/test_activations.py | 2 +- 8 files changed, 50 insertions(+), 11 deletions(-) diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py index 46e677abb81a6..443ce143d3b8b 100644 --- a/ivy/data_classes/array/activations.py +++ b/ivy/data_classes/array/activations.py @@ -132,6 +132,8 @@ def sigmoid( ---------- self Input array + complex_mode + optional specifier for how to handle complex data types. out optional output array for writing the result to. 
It must have the same shape the input broadcast to default: None @@ -298,7 +300,13 @@ def mish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: """ return ivy.mish(self._data, out=out) - def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: + def hardswish( + self: ivy.Array, + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, + ) -> ivy.Array: """ Apply the hardswish activation function element-wise. @@ -306,6 +314,8 @@ def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Arr ---------- x input array + complex_mode + optional specifier for how to handle complex data types. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. @@ -334,4 +344,4 @@ def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Arr b: ivy.array([0., 5.]) } """ - return ivy.hardswish(self._data, out=out) + return ivy.hardswish(self._data, complex_mode=complex_mode, out=out) diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py index cf54e76d11ad5..840e5b33e35db 100644 --- a/ivy/data_classes/container/activations.py +++ b/ivy/data_classes/container/activations.py @@ -1112,6 +1112,7 @@ def hardswish( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -1134,6 +1135,8 @@ def hardswish( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -1160,5 +1163,6 @@ def hardswish( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index 305ad0c2d3acd..aadc3196d7cdf 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -144,7 +144,7 @@ def glu(x, axis=-1): @to_ivy_arrays_and_back def hard_swish(x): - res = (x * ivy.minimum(ivy.maximum(x + 3, 0.0), 6.0)) / 6 + res = ivy.hardswish(x, complex_mode="jax") return ivy.asarray(res, dtype=x.dtype) diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index fec386717045a..3f2faaf1e731e 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -565,16 +565,29 @@ def mish( @handle_out_argument @to_native_arrays_and_back @handle_array_function +@handle_complex_input def hardswish( - x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the hardswish activation function element-wise. + If the input is complex, then by default each element is scaled by `alpha` if + either its real part is strictly negative or if its real part is zero and its + imaginary part is negative. This behaviour can be changed by specifying a different + `complex_mode`. 
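A usage sketch for the hardswish change, assuming an active backend with complex support; the values are illustrative.

    import ivy

    x = ivy.array([-3.5, 0.0, 4.0])
    y = ivy.hardswish(x)                           # hardswish(x) = x * min(max(x + 3, 0), 6) / 6 -> [0.0, 0.0, 4.0]

    z = ivy.array([2.0 + 1.0j])                    # made-up complex input, now routed through handle_complex_input
    w = ivy.hardswish(z, complex_mode="magnitude")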
+ Parameters ---------- x input array + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 04078408c9587..356393943d98a 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -308,24 +308,36 @@ def _forward(self, x): class Hardswish(Module): - def __init__(self): - """Apply the HARDSWISH activation function.""" + def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): + """ + Apply the HARDSWISH activation function. + + Parameter + ---------- + complex_mode + Specifies how to handle complex input. + """ + self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x): + def _forward(self, x, complex_mode=None): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. + complex_mode + Specifies how to handle complex input. Returns ------- ret The outputs following the HARDSWISH activation *[batch_shape, d]* """ - return ivy.hardswish(x) + return ivy.hardswish( + x, complex_mode=ivy.default(complex_mode, self._complex_mode) + ) class Logit(Module): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 595d2387ffa91..ee5d5dc5b2311 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -694,7 +694,7 @@ def test_jax_swish( @handle_frontend_test( fn_tree="jax.nn.hard_swish", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), min_value=-10, max_value=10, safety_factor_scale="linear", diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py index 27c4b895cce22..c431c956ab7a1 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py @@ -226,7 +226,7 @@ def test_mish(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.hardswish", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", diff --git a/ivy_tests/test_ivy/test_stateful/test_activations.py b/ivy_tests/test_ivy/test_stateful/test_activations.py index 040295e87a74b..0368cdc9e3698 100644 --- a/ivy_tests/test_ivy/test_stateful/test_activations.py +++ b/ivy_tests/test_ivy/test_stateful/test_activations.py @@ -542,7 +542,7 @@ def test_relu6( @handle_method( method_tree="stateful.activations.Hardswish.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", From 16c42d48052fd321893e8f0e10ebcef91c723d7a Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Tue, 8 Aug 2023 16:15:11 +0100 Subject: [PATCH 03/38] added argument to static sigmoid and hardswish container funcs --- 
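The same argument at the container level, sketched with made-up keys and values; it assumes the instance methods touched in this series (Container.sigmoid, Container.hardswish) simply forward complex_mode as shown in these hunks.

    import ivy

    c = ivy.Container(
        a=ivy.array([0.0, 1.0]),                   # real leaf
        b=ivy.array([1.0 + 1.0j, -2.0j]),          # complex leaf (illustrative)
    )
    s = c.sigmoid(complex_mode="jax")              # instance method maps over both leaves
    h = c.hardswish(complex_mode="jax")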
ivy/data_classes/container/activations.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py index 840e5b33e35db..3d7b791c741b0 100644 --- a/ivy/data_classes/container/activations.py +++ b/ivy/data_classes/container/activations.py @@ -397,6 +397,7 @@ def _static_sigmoid( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -419,6 +420,8 @@ def _static_sigmoid( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -445,6 +448,7 @@ def _static_sigmoid( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) @@ -1052,6 +1056,7 @@ def _static_hardswish( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -1074,6 +1079,8 @@ def _static_hardswish( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -1101,6 +1108,7 @@ def _static_hardswish( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) From a9a9a532a9be17a5b00ab1a267f54e447ff9ec64 Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Tue, 8 Aug 2023 17:50:42 +0100 Subject: [PATCH 04/38] applied complex_input decorator to silu activation fucntion --- .../array/experimental/activations.py | 14 ++++++++-- .../container/experimental/activations.py | 10 ++++++- .../jax/nn/non_linear_activations.py | 2 +- .../ivy/experimental/activations.py | 28 +++++++++++++++---- ivy/stateful/activations.py | 18 +++++++++--- .../test_nn/test_non_linear_activations.py | 2 +- .../test_nn/test_activations.py | 2 +- .../test_stateful/test_activations.py | 2 +- 8 files changed, 60 insertions(+), 18 deletions(-) diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index 6079e3718211d..35830cff405fd 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -1,6 +1,6 @@ # global import abc -from typing import Optional, Union +from typing import Optional, Union, Literal # local import ivy @@ -231,7 +231,13 @@ def selu(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: """ return ivy.selu(self._data, out=out) - def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: + def silu( + self: ivy.Array, + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, + ) -> ivy.Array: """ ivy.Array instance method variant of ivy.silu. 
This method simply wraps the function, and so the docstring for ivy.silu also applies to this method with @@ -241,6 +247,8 @@ def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: ---------- self input array. + complex_mode + optional specifier for how to handle complex data types. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. @@ -252,7 +260,7 @@ def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: >>> print(y) ivy.array([-0.26894143, 0. , 0.73105854]) """ - return ivy.silu(self._data, out=out) + return ivy.silu(self._data, complex_mode=complex_mode, out=out) def elu( self, diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index e2a410e0d7d10..eeea95ca3312f 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -1,5 +1,5 @@ # global -from typing import Union, Optional, List, Dict +from typing import Union, Optional, List, Dict, Literal # local import ivy @@ -672,6 +672,7 @@ def _static_silu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -694,6 +695,8 @@ def _static_silu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -721,6 +724,7 @@ def _static_silu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) @@ -732,6 +736,7 @@ def silu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -754,6 +759,8 @@ def silu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. 
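A functional sketch for silu: silu(x) = x * sigmoid(x), which is exactly the product the JAX frontend computed by hand before this commit; the complex call is illustrative.

    import ivy

    x = ivy.array([-1.0, 0.0, 1.0])
    y = ivy.silu(x)                             # equals x * sigmoid(x)
    same = ivy.multiply(x, ivy.sigmoid(x))      # the formulation the frontend used previously

    z = ivy.array([0.5 + 0.5j])                 # made-up complex input
    yz = ivy.silu(z, complex_mode="split")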
@@ -780,6 +787,7 @@ def silu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index aadc3196d7cdf..6634ea994db81 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -273,7 +273,7 @@ def sigmoid(x): @to_ivy_arrays_and_back def silu(x): x = _type_conversion(x) - return ivy.multiply(x, ivy.sigmoid(x)) + return ivy.silu(x, complex_mode="jax") @to_ivy_arrays_and_back diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 0e0e20f676c59..ddf839b7f168b 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -1,5 +1,5 @@ # global -from typing import Union, Optional +from typing import Union, Optional, Literal # local import ivy @@ -13,6 +13,7 @@ handle_out_argument, inputs_to_ivy_arrays, handle_device_shifting, + handle_complex_input, ) @@ -357,16 +358,29 @@ def selu( @to_native_arrays_and_back @handle_array_function @handle_device_shifting +@handle_complex_input def silu( - x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the silu function element-wise. + If the input is complex, then by default each element is scaled by `alpha` if + either its real part is strictly negative or if its real part is zero and its + imaginary part is negative. This behaviour can be changed by specifying a different + `complex_mode`. + Parameters ---------- x input array. + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. @@ -464,14 +478,16 @@ def sequence_length( x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None ) -> ivy.int64: """ - Produces a scalar (tensor of empty shape) containing the number of tensors in the - ivy array input. + Produce a scalar (tensor of empty shape) containing the number of tensors in the ivy + array input. Parameters ---------- x - Can be a sequence of any tensor type: bool, complex128, complex64, double, float, - float16, int16, int32, int64, int8, string, uint16, uint32, uint64, uint8 + Can be a sequence of any tensor type: + bool, complex128, complex64, double, float, + float16, int16, int32, int64, int8, string, + uint16, uint32, uint64, uint8 Returns ------- diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 356393943d98a..467c81c9fbe74 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -212,24 +212,34 @@ def _forward(self, x): class SiLU(Module): - def __init__(self): - """Apply the SiLU activation function.""" + def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): + """ + Apply the SiLU activation function. + + Parameter + ---------- + complex_mode + Specifies how to handle complex input. + """ + self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x): + def _forward(self, x, complex_mode=None): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. 
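The stateful modules in this series keep a default chosen at construction and let a call site override it through ivy.default(complex_mode, self._complex_mode). A sketch, assuming Module.__call__ forwards keyword arguments to _forward as the new signature suggests:

    import ivy
    from ivy.stateful.activations import SiLU

    layer = SiLU(complex_mode="magnitude")           # constructor default
    x = ivy.array([0.3 - 0.7j, 1.2 + 0.1j])          # illustrative values
    y_default = layer(x)                             # uses the stored "magnitude" default
    y_override = layer(x, complex_mode="jax")        # per-call override wins inside ivy.default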
+ complex_mode + Specifies how to handle complex input. Returns ------- ret The outputs following the SiLU activation *[batch_shape, d]* """ - return ivy.silu(x) + return ivy.silu(x, complex_mode=ivy.default(complex_mode, self._complex_mode)) class Sigmoid(Module): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index ee5d5dc5b2311..ef31f054038b7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -727,7 +727,7 @@ def test_jax_hard_swish( @handle_frontend_test( fn_tree="jax.nn.hard_silu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=2, small_abs_safety_factor=2, ), diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index 1d334a33689ca..83900e7c8f534 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -157,7 +157,7 @@ def test_selu(*, dtype_and_input, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.experimental.silu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", diff --git a/ivy_tests/test_ivy/test_stateful/test_activations.py b/ivy_tests/test_ivy/test_stateful/test_activations.py index 0368cdc9e3698..33e7fbbcd05a3 100644 --- a/ivy_tests/test_ivy/test_stateful/test_activations.py +++ b/ivy_tests/test_ivy/test_stateful/test_activations.py @@ -373,7 +373,7 @@ def test_mish( @handle_method( method_tree="stateful.activations.SiLU.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", From 04e0002b4c582752b871e7a672c4e682d81393fb Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Wed, 9 Aug 2023 11:22:36 +0100 Subject: [PATCH 05/38] resolved conflicts --- ivy/functional/ivy/experimental/activations.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index ddf839b7f168b..3ce6801fb4a13 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -13,11 +13,13 @@ handle_out_argument, inputs_to_ivy_arrays, handle_device_shifting, + handle_backend_invalid, handle_complex_input, ) @handle_exceptions +@handle_backend_invalid @handle_nestable @handle_array_like_without_promotion @handle_out_argument @@ -125,6 +127,7 @@ def prelu( @handle_exceptions +@handle_backend_invalid @handle_nestable @handle_array_like_without_promotion @handle_out_argument @@ -185,6 +188,7 @@ def thresholded_relu( @handle_exceptions +@handle_backend_invalid @handle_nestable @handle_array_like_without_promotion @handle_out_argument @@ -243,6 +247,7 @@ def relu6( @handle_exceptions +@handle_backend_invalid 
@handle_nestable @handle_array_like_without_promotion @handle_out_argument @@ -293,6 +298,7 @@ def logsigmoid( @handle_exceptions +@handle_backend_invalid @handle_nestable @handle_array_like_without_promotion @handle_out_argument @@ -352,6 +358,7 @@ def selu( @handle_exceptions +@handle_backend_invalid @handle_nestable @handle_array_like_without_promotion @handle_out_argument @@ -414,6 +421,7 @@ def silu( @handle_exceptions +@handle_backend_invalid @handle_nestable @handle_array_like_without_promotion @handle_out_argument @@ -484,10 +492,9 @@ def sequence_length( Parameters ---------- x - Can be a sequence of any tensor type: - bool, complex128, complex64, double, float, - float16, int16, int32, int64, int8, string, - uint16, uint32, uint64, uint8 + Can be a sequence of any tensor type: bool, complex128, + complex64, double, float, float16, int16, int32, int64, + int8, string, uint16, uint32, uint64, uint8 Returns ------- From 63eeb4ccdb4bccefc6647d45488bd4a73ec56d40 Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Wed, 9 Aug 2023 12:32:08 +0100 Subject: [PATCH 06/38] applied complex_input decorator to selu activation fucntion --- .../array/experimental/activations.py | 12 ++++++++++-- .../container/experimental/activations.py | 8 ++++++++ .../frontends/jax/nn/non_linear_activations.py | 2 +- ivy/functional/ivy/experimental/activations.py | 15 ++++++++++++++- ivy/stateful/activations.py | 18 ++++++++++++++---- .../test_nn/test_non_linear_activations.py | 6 +++--- .../test_nn/test_activations.py | 2 +- .../test_ivy/test_stateful/test_activations.py | 2 +- 8 files changed, 52 insertions(+), 13 deletions(-) diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index 35830cff405fd..d052a4b49d3c2 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -194,7 +194,13 @@ def logsigmoid( """ return ivy.logsigmoid(self._data) - def selu(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: + def selu( + self, + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, + ) -> ivy.Array: """ Apply the scaled exponential linear unit function element-wise. @@ -202,6 +208,8 @@ def selu(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: ---------- self input array + complex_mode + optional specifier for how to handle complex data types. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. 
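A functional sketch for selu. The real-valued behaviour is unchanged, scale * x for x > 0 and scale * alpha * (exp(x) - 1) otherwise, matching the docstring example values; the complex call is new and illustrative.

    import ivy

    x = ivy.array([-1.0, 0.0, 1.0])
    y = ivy.selu(x)                          # approx. [-1.1113, 0.0, 1.0507], as in the docstring example
    z = ivy.array([1.0 + 1.0j])              # made-up complex input
    yz = ivy.selu(z, complex_mode="jax")     # the default mode, passed explicitly for clarity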
@@ -229,7 +237,7 @@ def selu(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: ivy.array([-1.11133075, 0., 1.05070102, 2.10140204, 3.15210295, 4.20280409, 5.25350523, 6.30420589, 7.35490704]) """ - return ivy.selu(self._data, out=out) + return ivy.selu(self._data, complex_mode=complex_mode, out=out) def silu( self: ivy.Array, diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index eeea95ca3312f..d1292a9c099a9 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -552,6 +552,7 @@ def static_selu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -574,6 +575,8 @@ def static_selu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -601,6 +604,7 @@ def static_selu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) @@ -612,6 +616,7 @@ def selu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -634,6 +639,8 @@ def selu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -660,6 +667,7 @@ def selu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index 534166ed4cd6f..44babeb66c778 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -297,7 +297,7 @@ def softplus(x): @to_ivy_arrays_and_back def selu(x): x = _type_conversion_64(x) - return ivy.selu(x) + return ivy.selu(x, complex_mode="jax") @to_ivy_arrays_and_back diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 3ce6801fb4a13..20044d97e2a9f 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -305,16 +305,29 @@ def logsigmoid( @to_native_arrays_and_back @handle_array_function @handle_device_shifting +@handle_complex_input def selu( - x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the scaled exponential linear unit function element-wise. + If the input is complex, then by default each element is scaled by `alpha` if + either its real part is strictly negative or if its real part is zero and its + imaginary part is negative. 
This behaviour can be changed by specifying a different + `complex_mode`. + Parameters ---------- x input array + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 674e14da552be..510a52e7472ba 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -414,24 +414,34 @@ def _forward(self, x, slope): class SeLU(Module): - def __init__(self): - """Apply the SELU activation function.""" + def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): + """ + Apply the SELU activation function. + + Parameter + ---------- + complex_mode + Specifies how to handle complex input. + """ + self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x): + def _forward(self, x, complex_mode=None): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. + complex_mode + Specifies how to handle complex input. Returns ------- ret The outputs following the SELU activation *[batch_shape, d]* """ - return ivy.selu(x) + return ivy.selu(x, complex_mode=ivy.default(complex_mode, self._complex_mode)) class ELU(Module): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index adbda62108229..f0658d60db45b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -102,7 +102,7 @@ def test_jax_soft_sign( @handle_frontend_test( fn_tree="jax.nn.silu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="linear", @@ -732,7 +732,7 @@ def test_jax_hard_swish( @handle_frontend_test( fn_tree="jax.nn.hard_silu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("float_and_integer"), large_abs_safety_factor=2, small_abs_safety_factor=2, ), @@ -792,7 +792,7 @@ def test_jax_hard_sigmoid( @handle_frontend_test( fn_tree="jax.nn.selu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="log", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index 83900e7c8f534..c2ee9c460b72f 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -133,7 +133,7 @@ def test_logsigmoid(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.experimental.selu", dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float_and_complex"), safety_factor_scale="log", small_abs_safety_factor=20, ), diff --git a/ivy_tests/test_ivy/test_stateful/test_activations.py 
b/ivy_tests/test_ivy/test_stateful/test_activations.py index f0947c1b6c2c6..37f53b39e8cdf 100644 --- a/ivy_tests/test_ivy/test_stateful/test_activations.py +++ b/ivy_tests/test_ivy/test_stateful/test_activations.py @@ -674,7 +674,7 @@ def test_prelu( @handle_method( method_tree="stateful.activations.SeLU.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", From 0636cfbf9a980722a91fc098f6fac8aaa84f2f92 Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Wed, 9 Aug 2023 12:57:55 +0100 Subject: [PATCH 07/38] applied complex_input decorator to logsigmoid activation function --- .../array/experimental/activations.py | 5 ++++- .../container/experimental/activations.py | 8 ++++++++ .../jax/nn/non_linear_activations.py | 2 +- .../ivy/experimental/activations.py | 10 +++++++++- ivy/stateful/activations.py | 20 +++++++++++++++---- .../test_nn/test_non_linear_activations.py | 2 +- .../test_nn/test_activations.py | 2 +- .../test_stateful/test_activations.py | 2 +- 8 files changed, 41 insertions(+), 10 deletions(-) diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index d052a4b49d3c2..24d22422ca694 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -165,6 +165,7 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: def logsigmoid( self: ivy.Array, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ ivy.Array instance method variant of ivy.logsigmoid. This method simply wraps @@ -175,6 +176,8 @@ def logsigmoid( ---------- self Input array. + complex_mode + optional specifier for how to handle complex data types. Returns ------- @@ -192,7 +195,7 @@ def logsigmoid( >>> print(z) ivy.array([-2.57888985, -0.31326169, -0.69314718, -0.01104775]) """ - return ivy.logsigmoid(self._data) + return ivy.logsigmoid(self._data, complex_mode=complex_mode) def selu( self, diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index d1292a9c099a9..03ec34e2d6e98 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -446,6 +446,7 @@ def static_logsigmoid( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container static method variant of ivy.logsigmoid. This method simply wraps @@ -467,6 +468,8 @@ def static_logsigmoid( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. Returns ------- @@ -501,6 +504,7 @@ def static_logsigmoid( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, ) def logsigmoid( @@ -511,6 +515,7 @@ def logsigmoid( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ Apply element-wise Log-sigmoid of x i.e. log(1 / (1 + exp(-x)). 
@@ -519,6 +524,8 @@ def logsigmoid( ---------- self Input container. + complex_mode + optional specifier for how to handle complex data types. Returns ------- @@ -541,6 +548,7 @@ def logsigmoid( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, ) @staticmethod diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index 44babeb66c778..3b05d290b202b 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -168,7 +168,7 @@ def leaky_relu(x, negative_slope=0.01): @to_ivy_arrays_and_back def log_sigmoid(x): x = _type_conversion(x) - return ivy.negative(ivy.softplus(ivy.negative(x))).astype(x.dtype) + return ivy.logsigmoid(x, complex_mode="jax").astype(x.dtype) @to_ivy_arrays_and_back diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 20044d97e2a9f..6356fdf991de2 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -253,8 +253,13 @@ def relu6( @handle_out_argument @to_native_arrays_and_back @handle_device_shifting +@handle_complex_input def logsigmoid( - input: Union[ivy.NativeArray, ivy.Array], /, *, out: Optional[ivy.Array] = None + input: Union[ivy.NativeArray, ivy.Array], + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply element-wise Log-sigmoid of x. @@ -265,6 +270,9 @@ def logsigmoid( ---------- input Input array. + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 510a52e7472ba..6f4abaf512884 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -466,21 +466,33 @@ def _forward(self, x, alpha=1.0): class LogSigmoid(Module): - def __init__(self): - """Apply the LogSigmoid activation function.""" + def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): + """ + Apply the LogSigmoid activation function. + + Parameter + ---------- + complex_mode + Specifies how to handle complex input. + """ + self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x): + def _forward(self, x, complex_mode=None): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. + complex_mode + Specifies how to handle complex input. 
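A sketch for logsigmoid: logsigmoid(x) = log(1 / (1 + exp(-x))) = -softplus(-x), which is the identity the JAX frontend spelled out before this commit; the complex call is illustrative.

    import ivy

    x = ivy.array([-1.0, 0.0, 2.0])
    y = ivy.logsigmoid(x)                               # logsigmoid(0) = log(0.5), about -0.6931
    same = ivy.negative(ivy.softplus(ivy.negative(x)))  # the frontend's previous formulation

    z = ivy.array([0.2 + 0.4j])                         # made-up complex input
    yz = ivy.logsigmoid(z, complex_mode="jax")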
Returns ------- ret The outputs following the LogSigmoid activation *[batch_shape, d]* """ - return ivy.logsigmoid(x) + return ivy.logsigmoid( + x, complex_mode=ivy.default(complex_mode, self._complex_mode) + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index f0658d60db45b..2778031757982 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -358,7 +358,7 @@ def test_jax_softplus( @handle_frontend_test( fn_tree="jax.nn.log_sigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), min_value=-100, max_value=100, large_abs_safety_factor=8, diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index c2ee9c460b72f..5f6769fa71657 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -110,7 +110,7 @@ def test_relu6(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.experimental.logsigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("float_and_complex"), safety_factor_scale="log", large_abs_safety_factor=120, ), diff --git a/ivy_tests/test_ivy/test_stateful/test_activations.py b/ivy_tests/test_ivy/test_stateful/test_activations.py index 37f53b39e8cdf..b0d2a593538d3 100644 --- a/ivy_tests/test_ivy/test_stateful/test_activations.py +++ b/ivy_tests/test_ivy/test_stateful/test_activations.py @@ -762,7 +762,7 @@ def test_elu( @handle_method( method_tree="stateful.activations.LogSigmoid.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", From 01c3e414901befb8b78b8bfdb49331261d7c4185 Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Mon, 14 Aug 2023 17:19:06 +0100 Subject: [PATCH 08/38] applied complex_input decorator to elu activation function --- .../array/experimental/activations.py | 5 ++++- .../container/experimental/activations.py | 8 +++++++ .../jax/nn/non_linear_activations.py | 4 +--- ivy/functional/ivy/activations.py | 10 --------- .../ivy/experimental/activations.py | 15 +++++-------- ivy/stateful/activations.py | 21 +++++++++++++++---- .../test_nn/test_non_linear_activations.py | 2 +- .../test_nn/test_activations.py | 2 +- 8 files changed, 37 insertions(+), 30 deletions(-) diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index 24d22422ca694..b52edc35cc4ab 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -278,6 +278,7 @@ def elu( /, *, alpha: float = 1.0, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, ) -> ivy.Array: """ @@ -291,6 +292,8 @@ def elu( input array. 
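A functional sketch for elu: elu(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise, the same where/expm1 expression the JAX frontend drops in this commit; the complex call is illustrative.

    import ivy

    x = ivy.array([0.39, -0.85])
    y = ivy.elu(x, alpha=1.0)                        # approx. [0.39, -0.57], matching the docstring example
    same = ivy.where(x > 0, x, 1.0 * ivy.expm1(x))   # the formulation removed from the frontend

    z = ivy.array([-0.3 + 0.9j])                     # made-up complex input
    yz = ivy.elu(z, alpha=0.5, complex_mode="jax")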
alpha scaler for controlling the slope of the function for x <= 0 Default: 1.0 + complex_mode + optional specifier for how to handle complex data types. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. @@ -307,4 +310,4 @@ def elu( >>> print(y) ivy.array([ 0.39, -0.57]) """ - return ivy.elu(self._data, alpha=alpha, out=out) + return ivy.elu(self._data, alpha=alpha, complex_mode=complex_mode, out=out) diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index 03ec34e2d6e98..0429103d855ce 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -817,6 +817,7 @@ def _static_elu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -841,6 +842,8 @@ def _static_elu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -868,6 +871,7 @@ def _static_elu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) @@ -880,6 +884,7 @@ def elu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -904,6 +909,8 @@ def elu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -930,5 +937,6 @@ def elu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index 3b05d290b202b..448e5b8f1e742 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -122,9 +122,7 @@ def celu(x, alpha=1.0): @to_ivy_arrays_and_back def elu(x, alpha=1.0): - ret = ivy.where(x > 0, x, alpha * ivy.expm1(x)) - dtype = _batch_promotion(x, alpha, default_dtype="float64") - return ivy.asarray(ret, dtype=dtype) + return ivy.elu(x, alpha=alpha, complex_mode="jax") @to_ivy_arrays_and_back diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index ec3aed0d5fb27..d4b96aefeab36 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -353,11 +353,6 @@ def sigmoid( """ Apply the sigmoid function element-wise. - If the input is complex, then by default each element is scaled by `alpha` if - either its real part is strictly negative or if its real part is zero and its - imaginary part is negative. This behaviour can be changed by specifying a different - `complex_mode`. 
- Parameters ---------- x @@ -610,11 +605,6 @@ def hardswish( """ Apply the hardswish activation function element-wise. - If the input is complex, then by default each element is scaled by `alpha` if - either its real part is strictly negative or if its real part is zero and its - imaginary part is negative. This behaviour can be changed by specifying a different - `complex_mode`. - Parameters ---------- x diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 6356fdf991de2..35dd34732e08d 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -324,11 +324,6 @@ def selu( """ Apply the scaled exponential linear unit function element-wise. - If the input is complex, then by default each element is scaled by `alpha` if - either its real part is strictly negative or if its real part is zero and its - imaginary part is negative. This behaviour can be changed by specifying a different - `complex_mode`. - Parameters ---------- x @@ -397,11 +392,6 @@ def silu( """ Apply the silu function element-wise. - If the input is complex, then by default each element is scaled by `alpha` if - either its real part is strictly negative or if its real part is zero and its - imaginary part is negative. This behaviour can be changed by specifying a different - `complex_mode`. - Parameters ---------- x @@ -448,11 +438,13 @@ def silu( @handle_out_argument @to_native_arrays_and_back @handle_array_function +@handle_complex_input def elu( x: Union[ivy.Array, ivy.NativeArray], /, *, alpha: float = 1.0, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, ) -> ivy.Array: """ @@ -464,6 +456,9 @@ def elu( Input array. alpha scaler for controlling the slope of the function for x <= 0 Default: 1.0 + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 6f4abaf512884..1d2979d9436e5 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -445,11 +445,19 @@ def _forward(self, x, complex_mode=None): class ELU(Module): - def __init__(self): - """Apply the ELU activation function.""" + def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): + """ + Apply the ELU activation function. + + Parameter + ---------- + complex_mode + Specifies how to handle complex input. + """ + self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, alpha=1.0): + def _forward(self, x, alpha=1.0, complex_mode=None): """ Parameters ---------- @@ -457,12 +465,17 @@ def _forward(self, x, alpha=1.0): Inputs to process *[batch_shape, d]*. alpha scaler for controlling the slope of the function for x <= 0 Default: 1.0 + complex_mode + Specifies how to handle complex input. 
+ Returns ------- ret The outputs following the ELU activation *[batch_shape, d]* """ - return ivy.elu(x, alpha=alpha) + return ivy.elu( + x, alpha=alpha, complex_mode=ivy.default(complex_mode, self._complex_mode) + ) class LogSigmoid(Module): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 2778031757982..8039a69f0516e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -589,7 +589,7 @@ def test_jax_celu( @handle_frontend_test( fn_tree="jax.nn.elu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("float_and_complex"), min_value=-5, max_value=5, safety_factor_scale="linear", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index 5f6769fa71657..0560be85cbdd9 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -181,7 +181,7 @@ def test_silu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.experimental.elu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", From 83a94aab272cf1875f7ef057e721e898c9efb198 Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Mon, 14 Aug 2023 17:34:23 +0100 Subject: [PATCH 09/38] applied complex_input decorator to relu6 activation function --- .../array/experimental/activations.py | 12 ++++++++++-- .../container/experimental/activations.py | 8 ++++++++ ivy/functional/ivy/experimental/activations.py | 10 +++++++++- ivy/stateful/activations.py | 18 ++++++++++++++---- .../test_nn/test_non_linear_activations.py | 2 +- .../test_nn/test_activations.py | 2 +- 6 files changed, 43 insertions(+), 9 deletions(-) diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index b52edc35cc4ab..f431112eaf8b3 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -115,7 +115,13 @@ def prelu( """ return ivy.prelu(self._data, slope, out=out) - def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: + def relu6( + self, + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, + ) -> ivy.Array: """ Apply the rectified linear unit 6 function element-wise. @@ -123,6 +129,8 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: ---------- self input array + complex_mode + optional specifier for how to handle complex data types. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. 
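Last, a sketch for relu6: relu6(x) = min(max(x, 0), 6), as its docstring examples (values above 6 capped at 6) illustrate; the complex call is illustrative.

    import ivy

    x = ivy.array([-2.0, 3.5, 9.0])
    y = ivy.relu6(x)                           # [0.0, 3.5, 6.0]: negatives clamp to 0, large values to 6

    z = ivy.array([4.0 + 3.0j])                # made-up complex input
    yz = ivy.relu6(z, complex_mode="magnitude")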
@@ -161,7 +169,7 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.]) } """ - return ivy.relu6(self._data, out=out) + return ivy.relu6(self._data, complex_mode=complex_mode, out=out) def logsigmoid( self: ivy.Array, diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index 0429103d855ce..c46d06c7328ab 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -320,6 +320,7 @@ def static_relu6( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -342,6 +343,8 @@ def static_relu6( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -372,6 +375,7 @@ def static_relu6( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) @@ -383,6 +387,7 @@ def relu6( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -405,6 +410,8 @@ def relu6( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. + complex_mode + optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. @@ -434,6 +441,7 @@ def relu6( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, + complex_mode=complex_mode, out=out, ) diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 35dd34732e08d..ffd7f160c3222 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -195,8 +195,13 @@ def thresholded_relu( @to_native_arrays_and_back @handle_array_function @handle_device_shifting +@handle_complex_input def relu6( - x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the rectified linear unit 6 function element-wise. @@ -205,6 +210,9 @@ def relu6( ---------- x input array + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. 
diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 1d2979d9436e5..0d7497173a29d 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -314,24 +314,34 @@ def _forward(self, x): class ReLU6(Module): - def __init__(self): - """Apply the RELU6 activation function.""" + def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): + """ + Apply the RELU6 activation function. + + Parameter + ---------- + complex_mode + Specifies how to handle complex input. + """ + self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x): + def _forward(self, x, complex_mode=None): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. + complex_mode + Specifies how to handle complex input. Returns ------- ret The outputs following the RELU6 activation *[batch_shape, d]* """ - return ivy.relu6(x) + return ivy.relu6(x, complex_mode=ivy.default(complex_mode, self._complex_mode)) class Hardswish(Module): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 8039a69f0516e..01b457a5c7be1 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -40,7 +40,7 @@ def test_jax_relu( @handle_frontend_test( fn_tree="jax.nn.relu6", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="linear", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index 0560be85cbdd9..3632969cd3208 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -88,7 +88,7 @@ def test_prelu(*, dtype_and_x, slope, test_flags, backend_fw, fn_name, on_device @handle_test( fn_tree="functional.ivy.experimental.relu6", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="log", From 54eb859c6fb2a876b0b7d7cc5c3455f0f81f284e Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Mon, 14 Aug 2023 17:43:20 +0100 Subject: [PATCH 10/38] added float_complex to stateful test --- ivy_tests/test_ivy/test_stateful/test_activations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy_tests/test_ivy/test_stateful/test_activations.py b/ivy_tests/test_ivy/test_stateful/test_activations.py index b0d2a593538d3..8747f0c384bce 100644 --- a/ivy_tests/test_ivy/test_stateful/test_activations.py +++ b/ivy_tests/test_ivy/test_stateful/test_activations.py @@ -500,7 +500,7 @@ def test_tanh( @handle_method( method_tree="stateful.activations.ReLU6.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", @@ -630,7 +630,7 @@ def test_logit( @handle_method( method_tree="stateful.activations.PReLU.__call__", dtype_and_x=helpers.dtype_and_values( - 
available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), num_arrays=2, shared_dtype=True, min_num_dims=2, From cfe6cdcc6264526cb72a2a508a32bc87e36e3709 Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Tue, 15 Aug 2023 15:12:35 +0100 Subject: [PATCH 11/38] made changes as suggested by joe --- .../frontends/jax/nn/non_linear_activations.py | 2 +- .../test_nn/test_non_linear_activations.py | 14 +++++++------- .../test_experimental/test_nn/test_activations.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index 448e5b8f1e742..eba550245f703 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -253,7 +253,7 @@ def relu(x): @to_ivy_arrays_and_back def relu6(x): - res = ivy.relu6(x) + res = ivy.relu6(x, complex_mode="jax") return _type_conversion_64(res) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 01b457a5c7be1..19cd6e8724912 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -40,7 +40,7 @@ def test_jax_relu( @handle_frontend_test( fn_tree="jax.nn.relu6", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("numeric"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="linear", @@ -102,7 +102,7 @@ def test_jax_soft_sign( @handle_frontend_test( fn_tree="jax.nn.silu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("numeric"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="linear", @@ -358,7 +358,7 @@ def test_jax_softplus( @handle_frontend_test( fn_tree="jax.nn.log_sigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("numeric"), min_value=-100, max_value=100, large_abs_safety_factor=8, @@ -589,7 +589,7 @@ def test_jax_celu( @handle_frontend_test( fn_tree="jax.nn.elu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("numeric"), min_value=-5, max_value=5, safety_factor_scale="linear", @@ -699,7 +699,7 @@ def test_jax_swish( @handle_frontend_test( fn_tree="jax.nn.hard_swish", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("numeric"), min_value=-10, max_value=10, safety_factor_scale="linear", @@ -762,7 +762,7 @@ def test_jax_hard_silu( @handle_frontend_test( fn_tree="jax.nn.hard_sigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("numeric"), large_abs_safety_factor=2, small_abs_safety_factor=2, ), @@ -792,7 +792,7 @@ def test_jax_hard_sigmoid( @handle_frontend_test( fn_tree="jax.nn.selu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("numeric"), large_abs_safety_factor=2, 
small_abs_safety_factor=2, safety_factor_scale="log", diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index 3632969cd3208..a136b343f2ca8 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -88,7 +88,7 @@ def test_prelu(*, dtype_and_x, slope, test_flags, backend_fw, fn_name, on_device @handle_test( fn_tree="functional.ivy.experimental.relu6", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("numeric"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="log", @@ -110,7 +110,7 @@ def test_relu6(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.experimental.logsigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("valid"), safety_factor_scale="log", large_abs_safety_factor=120, ), @@ -133,7 +133,7 @@ def test_logsigmoid(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.experimental.selu", dtype_and_input=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), + available_dtypes=helpers.get_dtypes("valid"), safety_factor_scale="log", small_abs_safety_factor=20, ), From 1c00385ad7eb99b70385fa8b05a591932bd6f827 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 18 Aug 2023 09:43:17 +0000 Subject: [PATCH 12/38] fixed testing issues for sigmoid --- .../backends/numpy/experimental/activations.py | 9 +++++++++ ivy/functional/backends/paddle/activations.py | 4 ++++ ivy/functional/backends/tensorflow/activations.py | 1 + ivy/stateful/activations.py | 8 ++------ 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py index d3f8282a22de2..8b9383e0630be 100644 --- a/ivy/functional/backends/numpy/experimental/activations.py +++ b/ivy/functional/backends/numpy/experimental/activations.py @@ -83,6 +83,15 @@ def silu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: silu.support_native_out = True +@with_unsupported_dtypes( + { + "1.25.2 and below": ( + "complex64", + "complex128", + ) + }, + backend_version, +) @_scalar_output_to_0d_array def elu( x: np.ndarray, /, *, alpha: float = 1.0, out: Optional[np.ndarray] = None diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 185218223ce33..7840bd8eec25b 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -84,6 +84,10 @@ def gelu( return F.gelu(x, approximate=approximate) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("bfloat16", "complex128", "complex64")}}, + backend_version, +) def sigmoid( x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index 1a37c65541f11..fb6190c5829ec 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -34,6 +34,7 @@ def relu(x: Tensor, /, *, out: 
Optional[Tensor] = None) -> Tensor: return tf.nn.relu(x) +@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) def sigmoid(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: if not ivy.is_array(x): x = float(x) diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 0d7497173a29d..5deca2487f7d3 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -272,24 +272,20 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, complex_mode=None): + def _forward(self, x): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. - complex_mode - Specifies how to handle complex input. Returns ------- ret The outputs following the SIGMOID activation *[batch_shape, d]* """ - return ivy.sigmoid( - x, complex_mode=ivy.default(complex_mode, self._complex_mode) - ) + return ivy.sigmoid(x) class Tanh(Module): From d4d70b1f2c6db8446e59a923bd70c50360b9cecd Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Fri, 18 Aug 2023 11:55:37 +0100 Subject: [PATCH 13/38] Bring docstrings and call signatures in line with standards --- ivy/data_classes/array/activations.py | 18 +++-- .../array/experimental/activations.py | 39 ++++++---- ivy/data_classes/container/activations.py | 36 +++++---- .../container/experimental/activations.py | 78 +++++++++++-------- ivy/functional/backends/jax/activations.py | 8 +- .../backends/jax/experimental/activations.py | 23 ++++-- ivy/functional/backends/mxnet/activations.py | 2 +- .../mxnet/experimental/activations.py | 8 +- ivy/functional/backends/numpy/activations.py | 8 +- .../numpy/experimental/activations.py | 23 ++++-- ivy/functional/backends/paddle/activations.py | 4 +- .../paddle/experimental/activations.py | 21 +++-- .../backends/tensorflow/activations.py | 8 +- .../tensorflow/experimental/activations.py | 24 +++--- ivy/functional/backends/torch/activations.py | 6 +- .../torch/experimental/activations.py | 21 +++-- ivy/functional/ivy/activations.py | 16 ++-- .../ivy/experimental/activations.py | 34 ++++---- ivy/stateful/activations.py | 63 +++++++-------- 19 files changed, 258 insertions(+), 182 deletions(-) diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py index 443ce143d3b8b..d0c4fc4090556 100644 --- a/ivy/data_classes/array/activations.py +++ b/ivy/data_classes/array/activations.py @@ -119,8 +119,8 @@ def sigmoid( self: ivy.Array, /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ ivy.Array instance method variant of ivy.sigmoid. @@ -132,11 +132,12 @@ def sigmoid( ---------- self Input array - complex_mode - optional specifier for how to handle complex data types. out optional output array for writing the result to. It must have the same shape the input broadcast to default: None + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. 
Returns ------- @@ -151,7 +152,7 @@ def sigmoid( >>> print(y) ivy.array([0.269, 0.731, 0.881]) """ - return ivy.sigmoid(self._data, complex_mode=complex_mode, out=out) + return ivy.sigmoid(self._data, out=out, complex_mode=complex_mode) def softmax( self: ivy.Array, @@ -304,8 +305,8 @@ def hardswish( self: ivy.Array, /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Apply the hardswish activation function element-wise. @@ -314,11 +315,12 @@ def hardswish( ---------- x input array - complex_mode - optional specifier for how to handle complex data types. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -344,4 +346,4 @@ def hardswish( b: ivy.array([0., 5.]) } """ - return ivy.hardswish(self._data, complex_mode=complex_mode, out=out) + return ivy.hardswish(self._data, out=out, complex_mode=complex_mode) diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index f431112eaf8b3..da88e5c1d48fb 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -119,8 +119,8 @@ def relu6( self, /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Apply the rectified linear unit 6 function element-wise. @@ -129,11 +129,12 @@ def relu6( ---------- self input array - complex_mode - optional specifier for how to handle complex data types. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -169,7 +170,7 @@ def relu6( b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.]) } """ - return ivy.relu6(self._data, complex_mode=complex_mode, out=out) + return ivy.relu6(self._data, out=out, complex_mode=complex_mode) def logsigmoid( self: ivy.Array, @@ -185,7 +186,8 @@ def logsigmoid( self Input array. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -209,8 +211,8 @@ def selu( self, /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Apply the scaled exponential linear unit function element-wise. @@ -219,11 +221,12 @@ def selu( ---------- self input array - complex_mode - optional specifier for how to handle complex data types. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. 
Returns ------- @@ -248,14 +251,14 @@ def selu( ivy.array([-1.11133075, 0., 1.05070102, 2.10140204, 3.15210295, 4.20280409, 5.25350523, 6.30420589, 7.35490704]) """ - return ivy.selu(self._data, complex_mode=complex_mode, out=out) + return ivy.selu(self._data, out=out, complex_mode=complex_mode) def silu( self: ivy.Array, /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ ivy.Array instance method variant of ivy.silu. This method simply wraps the @@ -266,11 +269,12 @@ def silu( ---------- self input array. - complex_mode - optional specifier for how to handle complex data types. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Examples -------- @@ -279,15 +283,15 @@ def silu( >>> print(y) ivy.array([-0.26894143, 0. , 0.73105854]) """ - return ivy.silu(self._data, complex_mode=complex_mode, out=out) + return ivy.silu(self._data, out=out, complex_mode=complex_mode) def elu( self, /, *, alpha: float = 1.0, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Ivy.Array instance method variant of ivy.elu. This method simply wraps the @@ -300,11 +304,12 @@ def elu( input array. alpha scaler for controlling the slope of the function for x <= 0 Default: 1.0 - complex_mode - optional specifier for how to handle complex data types. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -318,4 +323,4 @@ def elu( >>> print(y) ivy.array([ 0.39, -0.57]) """ - return ivy.elu(self._data, alpha=alpha, complex_mode=complex_mode, out=out) + return ivy.elu(self._data, alpha=alpha, out=out, complex_mode=complex_mode) diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py index 8a3f7dc683924..31bcd37b2b5da 100644 --- a/ivy/data_classes/container/activations.py +++ b/ivy/data_classes/container/activations.py @@ -405,8 +405,8 @@ def _static_sigmoid( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container static method variant of ivy.sigmoid. This method simply wraps the @@ -428,11 +428,12 @@ def _static_sigmoid( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. 
Returns ------- @@ -456,8 +457,8 @@ def _static_sigmoid( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) def sigmoid( @@ -468,8 +469,8 @@ def sigmoid( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container instance method variant of ivy.sigmoid. This method simply wraps @@ -491,11 +492,12 @@ def sigmoid( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -518,8 +520,8 @@ def sigmoid( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) @staticmethod @@ -1064,8 +1066,8 @@ def _static_hardswish( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container static method variant of ivy.hardswish. This method simply wraps @@ -1087,11 +1089,12 @@ def _static_hardswish( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -1116,8 +1119,8 @@ def _static_hardswish( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) def hardswish( @@ -1128,8 +1131,8 @@ def hardswish( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container instance method variant of ivy.hardswish. This method simply wraps @@ -1151,11 +1154,12 @@ def hardswish( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. 
Returns ------- @@ -1179,6 +1183,6 @@ def hardswish( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index c46d06c7328ab..dbec1d544ba85 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -320,8 +320,8 @@ def static_relu6( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container static method variant of ivy.relu6. This method simply wraps the @@ -343,11 +343,12 @@ def static_relu6( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -375,8 +376,8 @@ def static_relu6( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) def relu6( @@ -387,8 +388,8 @@ def relu6( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container instance method variant of ivy.relu6. This method simply wraps the @@ -410,11 +411,12 @@ def relu6( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -441,8 +443,8 @@ def relu6( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) @staticmethod @@ -477,7 +479,8 @@ def static_logsigmoid( Whether to also map method to sequences (lists, tuples). Default is ``False``. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -533,7 +536,8 @@ def logsigmoid( self Input container. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. 
Returns ------- @@ -568,8 +572,8 @@ def static_selu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container static method variant of ivy.selu. This method simply wraps the @@ -591,11 +595,12 @@ def static_selu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -620,8 +625,8 @@ def static_selu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) def selu( @@ -632,8 +637,8 @@ def selu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container instance method variant of ivy.selu. This method simply wraps the @@ -655,11 +660,12 @@ def selu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -683,8 +689,8 @@ def selu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) @staticmethod @@ -696,8 +702,8 @@ def _static_silu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container static method variant of ivy.silu. This method simply wraps the @@ -719,11 +725,12 @@ def _static_silu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. 
Returns ------- @@ -748,8 +755,8 @@ def _static_silu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) def silu( @@ -760,8 +767,8 @@ def silu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container instance method variant of ivy.silu. This method simply wraps the @@ -783,11 +790,12 @@ def silu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -811,8 +819,8 @@ def silu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) @staticmethod @@ -825,8 +833,8 @@ def _static_elu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container static method variant of ivy.elu. This method simply wraps the @@ -850,11 +858,12 @@ def _static_elu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -879,8 +888,8 @@ def _static_elu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) def elu( @@ -892,8 +901,8 @@ def elu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Container: """ ivy.Container instance method variant of ivy.elu. This method simply wraps the @@ -917,11 +926,12 @@ def elu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - complex_mode - optional specifier for how to handle complex data types. out optional output container, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. 
Returns ------- @@ -945,6 +955,6 @@ def elu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - complex_mode=complex_mode, out=out, + complex_mode=complex_mode, ) diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py index d7705c9847d14..af2a269623303 100644 --- a/ivy/functional/backends/jax/activations.py +++ b/ivy/functional/backends/jax/activations.py @@ -32,7 +32,9 @@ def relu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.maximum(x, 0) -def sigmoid(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def sigmoid( + x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" +) -> JaxArray: return 1 / (1 + jnp.exp(-x)) @@ -83,5 +85,7 @@ def mish(x: JaxArray, /, *, out: Optional[JaxArray] = None): return x * jnp.tanh(jax.nn.softplus(x)) -def hardswish(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def hardswish( + x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" +) -> JaxArray: return jax.nn.hard_swish(x) diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py index 788126f4b44cf..6494831629f5f 100644 --- a/ivy/functional/backends/jax/experimental/activations.py +++ b/ivy/functional/backends/jax/experimental/activations.py @@ -22,7 +22,9 @@ def logit( return jnp.log(x / (1 - x)) -def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def relu6( + x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" +) -> JaxArray: relu6_func = jax.nn.relu6 # sets gradient at 0 and 6 to 0 instead of 0.5 @@ -48,18 +50,24 @@ def thresholded_relu( return jnp.where(x > threshold, x, 0).astype(x.dtype) -def logsigmoid(input: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def logsigmoid( + input: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" +) -> JaxArray: return jax.nn.log_sigmoid(input) -def selu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def selu( + x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" +) -> JaxArray: ret = jax.nn.selu(x).astype(x.dtype) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) return ret -def silu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def silu( + x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" +) -> JaxArray: ret = jax.nn.silu(x) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) @@ -67,7 +75,12 @@ def silu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: def elu( - x: JaxArray, /, *, alpha: float = 1.0, out: Optional[JaxArray] = None + x: JaxArray, + /, + *, + alpha: float = 1.0, + out: Optional[JaxArray] = None, + complex_mode="jax", ) -> JaxArray: ret = jax.nn.elu(x, alpha) if ivy.exists(out): diff --git a/ivy/functional/backends/mxnet/activations.py b/ivy/functional/backends/mxnet/activations.py index 9b6cfdcc9467f..68b634c45f097 100644 --- a/ivy/functional/backends/mxnet/activations.py +++ b/ivy/functional/backends/mxnet/activations.py @@ -27,7 +27,7 @@ def relu(x: None, /, *, out: Optional[None] = None) -> None: return mx.nd.relu(x) -def sigmoid(x: None, /, *, out: Optional[None] = None) -> None: +def sigmoid(x: None, /, *, out: Optional[None] = None, complex_mode="jax") -> None: return mx.nd.sigmoid(x) diff --git a/ivy/functional/backends/mxnet/experimental/activations.py 
b/ivy/functional/backends/mxnet/experimental/activations.py index 2ab8f7443e0af..ef52d70dbccd6 100644 --- a/ivy/functional/backends/mxnet/experimental/activations.py +++ b/ivy/functional/backends/mxnet/experimental/activations.py @@ -20,17 +20,17 @@ def thresholded_relu( raise IvyNotImplementedException() -def relu6(x: None, /, *, out: Optional[None] = None) -> None: +def relu6(x: None, /, *, out: Optional[None] = None, complex_mode="jax") -> None: raise IvyNotImplementedException() -def logsigmoid(input: None) -> None: +def logsigmoid(input: None, complex_mode="jax") -> None: raise IvyNotImplementedException() -def selu(x: None, /, *, out: Optional[None] = None) -> None: +def selu(x: None, /, *, out: Optional[None] = None, complex_mode="jax") -> None: raise IvyNotImplementedException() -def silu(x: None, /, *, out: Optional[None] = None) -> None: +def silu(x: None, /, *, out: Optional[None] = None, complex_mode="jax") -> None: raise IvyNotImplementedException() diff --git a/ivy/functional/backends/numpy/activations.py b/ivy/functional/backends/numpy/activations.py index d2ee4c0675de3..1a485367de893 100644 --- a/ivy/functional/backends/numpy/activations.py +++ b/ivy/functional/backends/numpy/activations.py @@ -34,7 +34,9 @@ def gelu( return ivy.astype(ret, x.dtype, copy=False) -def sigmoid(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def sigmoid( + x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" +) -> np.ndarray: if not ivy.is_array(x): return np.asarray(1 / (1 + np.exp(-x))) return np.asarray(1 / (1 + np.exp(-x))).astype(x.dtype) @@ -117,7 +119,9 @@ def mish(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: @_scalar_output_to_0d_array -def hardswish(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def hardswish( + x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" +) -> np.ndarray: max_x_3 = np.maximum(x + 3, 0, dtype=x.dtype) return (x * np.minimum(max_x_3, 6, out=out, dtype=x.dtype) / 6).astype(x.dtype) diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py index 8b9383e0630be..9cc9002d190f2 100644 --- a/ivy/functional/backends/numpy/experimental/activations.py +++ b/ivy/functional/backends/numpy/experimental/activations.py @@ -43,7 +43,9 @@ def thresholded_relu( @_scalar_output_to_0d_array -def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def relu6( + x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" +) -> np.ndarray: return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype) @@ -52,12 +54,16 @@ def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: @with_unsupported_dtypes({"1.25.2 and below": ("bool",)}, backend_version) @_scalar_output_to_0d_array -def logsigmoid(input: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def logsigmoid( + input: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" +) -> np.ndarray: return -(np.log1p(np.exp(-(input)))) @_scalar_output_to_0d_array -def selu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def selu( + x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" +) -> np.ndarray: alpha = 1.6732632423543772848170429916717 scale = 1.0507009873554804934193349852946 ret = (scale * np.where(x > 0, x, alpha * np.expm1(x))).astype(x.dtype) @@ -70,7 +76,9 @@ def selu(x: np.ndarray, /, *, 
out: Optional[np.ndarray] = None) -> np.ndarray: @_scalar_output_to_0d_array -def silu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def silu( + x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" +) -> np.ndarray: ret = np.asarray(x * (1 / (1 + np.exp(-x)))) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) @@ -94,7 +102,12 @@ def silu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: ) @_scalar_output_to_0d_array def elu( - x: np.ndarray, /, *, alpha: float = 1.0, out: Optional[np.ndarray] = None + x: np.ndarray, + /, + *, + alpha: float = 1.0, + out: Optional[np.ndarray] = None, + complex_mode="jax", ) -> np.ndarray: # exp = np.expm1(x) ret = np.where(x > 0, x, np.multiply(alpha, np.expm1(x))).astype(x.dtype) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 7840bd8eec25b..6c75d1db46304 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -89,7 +89,7 @@ def gelu( backend_version, ) def sigmoid( - x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None + x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" ) -> paddle.Tensor: if x.dtype in unsupported_dtypes: if paddle.is_complex(x): @@ -183,6 +183,6 @@ def mish(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. {"2.5.1 and below": {"cpu": ("float16",)}}, backend_version ) def hardswish( - x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None + x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" ) -> paddle.Tensor: return F.hardswish(x) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 81bc6bdf25cb7..ce99e64e78ca6 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -47,7 +47,9 @@ def thresholded_relu( ) -def relu6(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: +def relu6( + x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" +) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.relu6(x) if paddle.is_complex(x): @@ -56,7 +58,7 @@ def relu6(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle def logsigmoid( - input: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None + input: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" ) -> paddle.Tensor: if input.dtype in [paddle.float32, paddle.float64]: return F.log_sigmoid(input) @@ -69,7 +71,9 @@ def logsigmoid( return F.log_sigmoid(input.cast("float32")).cast(input.dtype) -def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: +def selu( + x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" +) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.selu(x) if paddle.is_complex(x): @@ -87,7 +91,9 @@ def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. 
return F.selu(x.cast("float32")).cast(x.dtype) -def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: +def silu( + x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" +) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.silu(x) if paddle.is_complex(x): @@ -96,7 +102,12 @@ def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. def elu( - x: paddle.Tensor, /, *, alpha: float = 1.0, out: Optional[paddle.Tensor] = None + x: paddle.Tensor, + /, + *, + alpha: float = 1.0, + out: Optional[paddle.Tensor] = None, + complex_mode="jax", ) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.elu(x, alpha=alpha) diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index fb6190c5829ec..1dbc045927d71 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -35,7 +35,9 @@ def relu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: @with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) -def sigmoid(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: +def sigmoid( + x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" +) -> Tensor: if not ivy.is_array(x): x = float(x) return tf.nn.sigmoid(x) @@ -88,5 +90,7 @@ def mish( @with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) -def hardswish(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: +def hardswish( + x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" +) -> Tensor: return x * tf.nn.relu6(x + 3) / 6 diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index f798d7b907d98..22ee4d6f45284 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -38,17 +38,19 @@ def thresholded_relu( @with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) -def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: +def relu6(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: return tf.nn.relu6(x) @with_supported_dtypes({"2.13.0 and below": ("float",)}, backend_version) -def logsigmoid(input: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: +def logsigmoid( + input: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" +) -> Tensor: return tf.math.log_sigmoid(input) @with_supported_dtypes({"2.13.0 and below": ("float",)}, backend_version) -def selu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: +def selu(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: ret = tf.nn.selu(x) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) @@ -56,12 +58,7 @@ def selu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: @with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) -def silu( - x: Tensor, - /, - *, - out: Optional[Tensor] = None, -) -> Tensor: +def silu(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: ret = tf.nn.silu(x) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) @@ -69,7 +66,14 @@ def silu( @with_supported_dtypes({"2.13.0 and below": ("float",)}, backend_version) -def elu(x: Tensor, /, *, alpha: float = 1.0, out: Optional[Tensor] = None) -> 
Tensor: +def elu( + x: Tensor, + /, + *, + alpha: float = 1.0, + out: Optional[Tensor] = None, + complex_mode="jax", +) -> Tensor: alpha = tf.cast(alpha, x.dtype) ret = tf.cast(tf.where(x > 0, x, tf.multiply(alpha, tf.math.expm1(x))), x.dtype) if ivy.exists(out): diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py index f5f4f500974a7..ed1b798c31881 100644 --- a/ivy/functional/backends/torch/activations.py +++ b/ivy/functional/backends/torch/activations.py @@ -49,7 +49,9 @@ def gelu( @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) -def sigmoid(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def sigmoid( + x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" +) -> torch.Tensor: if not ivy.is_array(x): x = torch.tensor(x) return torch.sigmoid(x, out=out) @@ -132,6 +134,6 @@ def mish(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten backend_version, ) def hardswish( - x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None + x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" ) -> torch.Tensor: return torch.nn.functional.hardswish(x) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 98d72ac526c45..4bb944577f63a 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -33,19 +33,23 @@ def thresholded_relu( @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) -def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def relu6( + x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" +) -> torch.Tensor: return torch.nn.functional.relu6(x) @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def logsigmoid( - input: torch.Tensor, /, *, out: Optional[torch.Tensor] = None + input: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" ) -> torch.Tensor: return torch.nn.functional.logsigmoid(input) @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) -def selu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def selu( + x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" +) -> torch.Tensor: ret = torch.nn.functional.selu(x) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) @@ -53,13 +57,20 @@ def selu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) -def silu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def silu( + x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" +) -> torch.Tensor: return torch.nn.functional.silu(x) @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def elu( - x: torch.Tensor, /, *, alpha: float = 1.0, out: Optional[torch.Tensor] = None + x: torch.Tensor, + /, + *, + alpha: float = 1.0, + out: Optional[torch.Tensor] = None, + complex_mode="jax", ) -> torch.Tensor: ret = torch.nn.functional.elu(x, alpha) if ivy.exists(out): diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index d4b96aefeab36..3390731582993 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -347,8 
+347,8 @@ def sigmoid( x: Union[ivy.Array, ivy.NativeArray], /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Apply the sigmoid function element-wise. @@ -357,13 +357,13 @@ def sigmoid( ---------- x input array. - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the input broadcast to. default: None + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- @@ -599,8 +599,8 @@ def hardswish( x: Union[ivy.Array, ivy.NativeArray], /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Apply the hardswish activation function element-wise. @@ -609,12 +609,12 @@ def hardswish( ---------- x input array - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index ffd7f160c3222..bf2fa94d1630f 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -200,8 +200,8 @@ def relu6( x: Union[ivy.Array, ivy.NativeArray], /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Apply the rectified linear unit 6 function element-wise. @@ -210,12 +210,12 @@ def relu6( ---------- x input array - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- @@ -266,8 +266,8 @@ def logsigmoid( input: Union[ivy.NativeArray, ivy.Array], /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Apply element-wise Log-sigmoid of x. @@ -326,8 +326,8 @@ def selu( x: Union[ivy.Array, ivy.NativeArray], /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Apply the scaled exponential linear unit function element-wise. @@ -336,12 +336,12 @@ def selu( ---------- x input array - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. 
See + `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- @@ -394,8 +394,8 @@ def silu( x: Union[ivy.Array, ivy.NativeArray], /, *, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Apply the silu function element-wise. @@ -404,12 +404,12 @@ def silu( ---------- x input array. - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- @@ -452,8 +452,8 @@ def elu( /, *, alpha: float = 1.0, - complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ Apply the elu unit function element-wise. @@ -464,12 +464,12 @@ def elu( Input array. alpha scaler for controlling the slope of the function for x <= 0 Default: 1.0 - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 5deca2487f7d3..1cca92b061a64 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -236,27 +236,26 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): Parameter ---------- complex_mode - Specifies how to handle complex input. + Specifies how to handle complex input. See + `ivy.func_wrapper.handle_complex_input` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, complex_mode=None): + def _forward(self, x): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. - complex_mode - Specifies how to handle complex input. Returns ------- ret The outputs following the SiLU activation *[batch_shape, d]* """ - return ivy.silu(x, complex_mode=ivy.default(complex_mode, self._complex_mode)) + return ivy.silu(x, complex_mode=self._complex_mode) class Sigmoid(Module): @@ -267,7 +266,8 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): Parameter ---------- complex_mode - Specifies how to handle complex input. + Specifies how to handle complex input. See + `ivy.func_wrapper.handle_complex_input` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) @@ -317,27 +317,26 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): Parameter ---------- complex_mode - Specifies how to handle complex input. + Specifies how to handle complex input. See + `ivy.func_wrapper.handle_complex_input` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, complex_mode=None): + def _forward(self, x): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. - complex_mode - Specifies how to handle complex input. 
Returns ------- ret The outputs following the RELU6 activation *[batch_shape, d]* """ - return ivy.relu6(x, complex_mode=ivy.default(complex_mode, self._complex_mode)) + return ivy.relu6(x, complex_mode=self._complex_mode) class Hardswish(Module): @@ -348,29 +347,26 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): Parameter ---------- complex_mode - Specifies how to handle complex input. + Specifies how to handle complex input. See + `ivy.func_wrapper.handle_complex_input` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, complex_mode=None): + def _forward(self, x): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. - complex_mode - Specifies how to handle complex input. Returns ------- ret The outputs following the HARDSWISH activation *[batch_shape, d]* """ - return ivy.hardswish( - x, complex_mode=ivy.default(complex_mode, self._complex_mode) - ) + return ivy.hardswish(x, complex_mode=self._complex_mode) class Logit(Module): @@ -427,27 +423,26 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): Parameter ---------- complex_mode - Specifies how to handle complex input. + Specifies how to handle complex input. See + `ivy.func_wrapper.handle_complex_input` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, complex_mode=None): + def _forward(self, x): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. - complex_mode - Specifies how to handle complex input. Returns ------- ret The outputs following the SELU activation *[batch_shape, d]* """ - return ivy.selu(x, complex_mode=ivy.default(complex_mode, self._complex_mode)) + return ivy.selu(x, complex_mode=self._complex_mode) class ELU(Module): @@ -458,12 +453,13 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): Parameter ---------- complex_mode - Specifies how to handle complex input. + Specifies how to handle complex input. See + `ivy.func_wrapper.handle_complex_input` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, alpha=1.0, complex_mode=None): + def _forward(self, x, alpha=1.0): """ Parameters ---------- @@ -471,17 +467,13 @@ def _forward(self, x, alpha=1.0, complex_mode=None): Inputs to process *[batch_shape, d]*. alpha scaler for controlling the slope of the function for x <= 0 Default: 1.0 - complex_mode - Specifies how to handle complex input. Returns ------- ret The outputs following the ELU activation *[batch_shape, d]* """ - return ivy.elu( - x, alpha=alpha, complex_mode=ivy.default(complex_mode, self._complex_mode) - ) + return ivy.elu(x, alpha=alpha, complex_mode=self._complex_mode) class LogSigmoid(Module): @@ -492,26 +484,23 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): Parameter ---------- complex_mode - Specifies how to handle complex input. + Specifies how to handle complex input. See + `ivy.func_wrapper.handle_complex_input` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, complex_mode=None): + def _forward(self, x): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. - complex_mode - Specifies how to handle complex input. 
Returns ------- ret The outputs following the LogSigmoid activation *[batch_shape, d]* """ - return ivy.logsigmoid( - x, complex_mode=ivy.default(complex_mode, self._complex_mode) - ) + return ivy.logsigmoid(x, complex_mode=self._complex_mode) From 5a0956e9f0336d4d3cbe804e11f9155bb9e6a467 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 18 Aug 2023 11:17:37 +0000 Subject: [PATCH 14/38] fixed testing issues for silu --- ivy/functional/backends/paddle/experimental/activations.py | 4 ++++ ivy/stateful/activations.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 81bc6bdf25cb7..5fe379aa508b9 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -87,6 +87,10 @@ def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. return F.selu(x.cast("float32")).cast(x.dtype) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("bfloat16", "complex64", "complex128")}}, + backend_version, +) def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.silu(x) diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 5deca2487f7d3..592751efba9b2 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -241,7 +241,7 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, complex_mode=None): + def _forward(self, x): """ Parameters @@ -256,7 +256,7 @@ def _forward(self, x, complex_mode=None): ret The outputs following the SiLU activation *[batch_shape, d]* """ - return ivy.silu(x, complex_mode=ivy.default(complex_mode, self._complex_mode)) + return ivy.silu(x, complex_mode=self._complex_mode) class Sigmoid(Module): @@ -285,7 +285,7 @@ def _forward(self, x): ret The outputs following the SIGMOID activation *[batch_shape, d]* """ - return ivy.sigmoid(x) + return ivy.sigmoid(x, complex_mode=self._complex_mode) class Tanh(Module): From 5bb2dab7aef93844dd8207159d6c8c2587d0c1dc Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 18 Aug 2023 11:27:47 +0000 Subject: [PATCH 15/38] fixed testing issues for selu --- ivy/functional/backends/paddle/experimental/activations.py | 4 ++++ ivy/functional/backends/torch/experimental/activations.py | 4 ++-- ivy/stateful/activations.py | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 5fe379aa508b9..252b5be7d6449 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -69,6 +69,10 @@ def logsigmoid( return F.log_sigmoid(input.cast("float32")).cast(input.dtype) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("bfloat16", "complex64", "complex128")}}, + backend_version, +) def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.selu(x) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 98d72ac526c45..cfdca26d6629b 100644 --- 
a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -44,7 +44,7 @@ def logsigmoid( return torch.nn.functional.logsigmoid(input) -@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, backend_version) def selu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: ret = torch.nn.functional.selu(x) if ivy.exists(out): @@ -52,7 +52,7 @@ def selu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten return ivy.astype(ret, x.dtype) -@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, backend_version) def silu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.nn.functional.silu(x) diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 592751efba9b2..c1bb33a1d02bf 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -432,7 +432,7 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, complex_mode=None): + def _forward(self, x): """ Parameters @@ -447,7 +447,7 @@ def _forward(self, x, complex_mode=None): ret The outputs following the SELU activation *[batch_shape, d]* """ - return ivy.selu(x, complex_mode=ivy.default(complex_mode, self._complex_mode)) + return ivy.selu(x, complex_mode=self._complex_mode) class ELU(Module): From 12bd06602899412fa3ad65e33c1eb29d0c5b9375 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 18 Aug 2023 15:41:03 +0000 Subject: [PATCH 16/38] fixed testing issues for log-sigmoid --- ivy/functional/backends/paddle/experimental/activations.py | 5 ++++- ivy/functional/backends/torch/experimental/activations.py | 4 +++- .../test_jax/test_nn/test_non_linear_activations.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index e87e26d1b4f36..2185f91bb1900 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -57,6 +57,9 @@ def relu6( return F.relu6(x.cast("float32")).cast(x.dtype) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version +) def logsigmoid( input: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" ) -> paddle.Tensor: @@ -65,7 +68,7 @@ def logsigmoid( if paddle.is_complex(input): return paddle_backend.log( paddle_backend.divide( - 1.0, (paddle_backend.add(1.0, paddle_backend.exp(input))) + 1.0, (paddle_backend.add(1.0, paddle_backend.exp(-input))) ) ) return F.log_sigmoid(input.cast("float32")).cast(input.dtype) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 5366bc4ac8934..691d41f2a1834 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -41,8 +41,10 @@ def relu6( @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def logsigmoid( - input: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" + input: torch.Tensor, /, *, out: Optional[torch.Tensor] = None ) -> 
torch.Tensor: + if torch.is_complex(input): + return torch.log(torch.sigmoid(input)) return torch.nn.functional.logsigmoid(input) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 19cd6e8724912..8b77073c70a6d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -358,7 +358,7 @@ def test_jax_softplus( @handle_frontend_test( fn_tree="jax.nn.log_sigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float_and_complex"), min_value=-100, max_value=100, large_abs_safety_factor=8, From 38192186574bbf68b38984542a66ac0d30b28b14 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 18 Aug 2023 21:08:16 +0000 Subject: [PATCH 17/38] added jax-like function for elu --- .../numpy/experimental/activations.py | 9 --------- .../paddle/experimental/activations.py | 4 ++++ .../torch/experimental/activations.py | 2 +- .../ivy/experimental/activations.py | 19 +++++++++++++++++++ .../test_nn/test_non_linear_activations.py | 8 +++++--- 5 files changed, 29 insertions(+), 13 deletions(-) diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py index 9cc9002d190f2..b222a84a75167 100644 --- a/ivy/functional/backends/numpy/experimental/activations.py +++ b/ivy/functional/backends/numpy/experimental/activations.py @@ -91,15 +91,6 @@ def silu( silu.support_native_out = True -@with_unsupported_dtypes( - { - "1.25.2 and below": ( - "complex64", - "complex128", - ) - }, - backend_version, -) @_scalar_output_to_0d_array def elu( x: np.ndarray, diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 2185f91bb1900..ccd33672a192b 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -108,6 +108,10 @@ def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. 
return F.silu(x.cast("float32")).cast(x.dtype) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("float16", "bfloat16", "complex64", "complex128")}}, + backend_version, +) def elu( x: paddle.Tensor, /, diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 691d41f2a1834..59a80fd1f7603 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -61,7 +61,7 @@ def silu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten return torch.nn.functional.silu(x) -@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, backend_version) def elu( x: torch.Tensor, /, diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index bf2fa94d1630f..ebc7922cd8393 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -439,6 +439,22 @@ def silu( return current_backend(x).silu(x, out=out) +def _elu_jax_like( + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + fn_original=None, + alpha: float = 1.0, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + safe_x = ivy.where( + (x > 0), + 0.0, + x, + ) + return ivy.where((x > 0), x, ivy.astype(alpha * ivy.expm1(safe_x), x.dtype)) + + @handle_exceptions @handle_backend_invalid @handle_nestable @@ -506,6 +522,9 @@ def elu( return current_backend(x).elu(x, alpha=alpha, out=out) +elu.jax_like = _elu_jax_like + + def sequence_length( x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None ) -> ivy.int64: diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 8b77073c70a6d..83bab11929a21 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -589,18 +589,20 @@ def test_jax_celu( @handle_frontend_test( fn_tree="jax.nn.elu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float_and_complex"), min_value=-5, max_value=5, safety_factor_scale="linear", - num_arrays=2, + num_arrays=1, shared_dtype=True, ), + alpha=st.floats(min_value=0, max_value=1, allow_infinity=False), test_with_out=st.just(False), ) def test_jax_elu( *, dtype_and_x, + alpha, test_flags, on_device, fn_tree, @@ -616,7 +618,7 @@ def test_jax_elu( fn_tree=fn_tree, on_device=on_device, x=xs[0], - alpha=xs[1], + alpha=alpha, rtol=1e-03, atol=1e-03, ) From 22ed2a639764a0c7d921095c2b076c2db4d94971 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Sat, 19 Aug 2023 07:24:19 +0000 Subject: [PATCH 18/38] modified alpha within the tests --- .../test_jax/test_nn/test_non_linear_activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 83bab11929a21..6d8174bd91530 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -596,7 +596,7 @@ def test_jax_celu( 
num_arrays=1, shared_dtype=True, ), - alpha=st.floats(min_value=0, max_value=1, allow_infinity=False), + alpha=helpers.floats(min_value=0.1, max_value=1), test_with_out=st.just(False), ) def test_jax_elu( From 154f1f45e553f4f5f5d0229f04b47d40de0293d0 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Sat, 19 Aug 2023 07:33:51 +0000 Subject: [PATCH 19/38] added complex_mode argument to all backends as suggested by Joe --- ivy/functional/backends/torch/experimental/activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 59a80fd1f7603..0a8de884db013 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -41,7 +41,7 @@ def relu6( @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def logsigmoid( - input: torch.Tensor, /, *, out: Optional[torch.Tensor] = None + input: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" ) -> torch.Tensor: if torch.is_complex(input): return torch.log(torch.sigmoid(input)) return torch.nn.functional.logsigmoid(input) From 1317aebd452fc1c18da49305d834046c6e430582 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Sat, 19 Aug 2023 14:17:25 +0000 Subject: [PATCH 20/38] added jax-like function for selu --- ivy/functional/ivy/experimental/activations.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index ebc7922cd8393..defefec88cbe5 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -313,6 +313,18 @@ def logsigmoid( return ivy.current_backend(input).logsigmoid(input, out=out) +def _selu_jax_like( + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + fn_original=None, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + return scale * elu(x, alpha=alpha) + + @handle_exceptions @handle_backend_invalid @handle_nestable @@ -381,6 +393,9 @@ def selu( return current_backend(x).selu(x, out=out) +selu.jax_like = _selu_jax_like + + @handle_exceptions @handle_backend_invalid @handle_nestable From 55ef1c4d34fa0f0494019c9aa699ce00a9ab8372 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Sat, 19 Aug 2023 17:38:02 +0000 Subject: [PATCH 21/38] added jax-like function for relu6 --- .../paddle/experimental/activations.py | 3 +++ .../torch/experimental/activations.py | 2 +- .../ivy/experimental/activations.py | 21 +++++++++++++++++++++ .../test_nn/test_non_linear_activations.py | 4 ++-- 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index ccd33672a192b..f290199bfb145 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -47,6 +47,9 @@ def thresholded_relu( ) +@with_unsupported_device_and_dtypes( + {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version +) def relu6( x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" ) -> paddle.Tensor: diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 0a8de884db013..6ae2cec7619ff 100644 ---
a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -32,7 +32,7 @@ def thresholded_relu( return torch.threshold(x, threshold=threshold, value=0) -@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, backend_version) def relu6( x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" ) -> torch.Tensor: diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 6bea99007196b..e9d2da7e335cd 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -187,6 +187,24 @@ def thresholded_relu( return current_backend(x).thresholded_relu(x, threshold=threshold, out=out) +def _relu6_jax_like( + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + fn_original=None, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + return ivy.where( + ( + ivy.logical_or( + ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) + ) + ), + ivy.minimum(ivy.array(0.0, dtype=x.dtype), 6), + x, + ) + + @handle_exceptions @handle_backend_invalid @handle_nestable @@ -241,6 +259,9 @@ def relu6( return current_backend(x).relu6(x, out=out) +relu6.jax_like = _relu6_jax_like + + @handle_exceptions @handle_backend_invalid @handle_nestable diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 56740069edf6c..28ac4195171e0 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -41,8 +41,8 @@ def test_jax_relu( fn_tree="jax.nn.relu6", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("numeric"), - large_abs_safety_factor=2, - small_abs_safety_factor=2, + large_abs_safety_factor=3, + small_abs_safety_factor=3, safety_factor_scale="linear", ), test_with_out=st.just(False), From 0d785a780991a5182fa188308b762a22861a963b Mon Sep 17 00:00:00 2001 From: mosesdaudu001 Date: Thu, 24 Aug 2023 10:53:59 +0100 Subject: [PATCH 22/38] fixed relu6 jax-like function --- .../ivy/experimental/activations.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index e9d2da7e335cd..c0afc92c7cbef 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -194,16 +194,15 @@ def _relu6_jax_like( fn_original=None, out: Optional[ivy.Array] = None, ) -> ivy.Array: - return ivy.where( - ( - ivy.logical_or( - ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) - ) - ), - ivy.minimum(ivy.array(0.0, dtype=x.dtype), 6), - x, - ) - + return ivy.where( + ivy.logical_or(ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0)), + ivy.array(0, dtype=x.dtype), + ivy.where( + ivy.logical_or(ivy.real(x) > 6, ivy.logical_and(ivy.real(x) == 6, ivy.imag(x) > 0)), + ivy.array(6, dtype=x.dtype), + x, + ), + ) @handle_exceptions @handle_backend_invalid From 32ce6765f5c483053994ee23861b935dbd0f357d Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 25 Aug 2023 10:00:52 +0000 Subject: [PATCH 23/38] added jax-like function for hardswish --- 
ivy/functional/backends/paddle/activations.py | 3 ++- ivy/functional/ivy/activations.py | 24 ++++++++++++++++++++ .../test_nn/test_non_linear_activations.py | 2 +- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 6c75d1db46304..628263266da22 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -180,7 +180,8 @@ def mish(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. @with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("float16",)}}, backend_version + {"2.5.1 and below": {"cpu": ("float16", "bfloat16", "complex64", "complex128")}}, + backend_version, ) def hardswish( x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index 8d3fef4486a23..e79c05eba3b8c 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -621,6 +621,27 @@ def mish( return current_backend(x).mish(x, out=out) +def _hardswish_jax_like( + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + fn_original=None, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + def hard_sigmoid(x): + return ivy.relu6(x + 3.0) / 6 + + return ivy.where( + ( + ivy.logical_or( + ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) + ) + ), + x * hard_sigmoid(x), + x, + ) + + @handle_exceptions @handle_backend_invalid @handle_nestable @@ -675,3 +696,6 @@ def hardswish( } """ return current_backend(x).hardswish(x, out=out) + + +hardswish.jax_like = _hardswish_jax_like diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 28ac4195171e0..846ccbedb7e10 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -701,7 +701,7 @@ def test_jax_swish( @handle_frontend_test( fn_tree="jax.nn.hard_swish", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), + available_dtypes=helpers.get_dtypes("float_and_complex"), min_value=-10, max_value=10, safety_factor_scale="linear", From 86e2a88cbb025aefcc2945d8295deac51aa9b43a Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 25 Aug 2023 10:44:09 +0000 Subject: [PATCH 24/38] added support for complex dtypes in hardswish tf and torch --- .../backends/tensorflow/activations.py | 13 +++++++++++-- ivy/functional/backends/torch/activations.py | 19 ++++++++++++------- ivy/functional/ivy/activations.py | 10 +--------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index 0d31d42e341b3..b383fad5a0088 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -88,8 +88,17 @@ def mish( return x * tf.math.tanh(tf.math.softplus(x)) -@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) def hardswish( x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" ) -> Tensor: - return x * tf.nn.relu6(x + 3) / 6 + if x.dtype.is_complex: + real_part = tf.real(x) + imag_part = tf.imag(x) + + real_result = real_part * tf.nn.relu6(real_part + 3) / 6 + imag_result = imag_part *
tf.nn.relu6(imag_part + 3) / 6 + + result = tf.complex(real_result, imag_result) + else: + result = x * tf.nn.relu6(x + 3) / 6 + return result diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py index baee95b38ee2f..c916289e42542 100644 --- a/ivy/functional/backends/torch/activations.py +++ b/ivy/functional/backends/torch/activations.py @@ -125,15 +125,20 @@ def mish(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten @with_unsupported_dtypes( - { - "2.0.1 and below": ( - "complex", - "float16", - ) - }, + {"2.0.1 and below": ("float16",)}, backend_version, ) def hardswish( x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" ) -> torch.Tensor: - return torch.nn.functional.hardswish(x) + if x.dtype.is_complex: + real_part = x.real + imag_part = x.imag + + real_result = real_part * torch.nn.functional.relu6(real_part + 3) / 6 + imag_result = imag_part * torch.nn.functional.relu6(imag_part + 3) / 6 + + result = torch.complex(real_result, imag_result) + else: + result = torch.nn.functional.hardswish(x) + return result diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index e79c05eba3b8c..cbdd0faed5281 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -631,15 +631,7 @@ def _hardswish_jax_like( def hard_sigmoid(x): return ivy.relu6(x + 3.0) / 6 - return ivy.where( - ( - ivy.logical_or( - ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) - ) - ), - x * hard_sigmoid(x), - x, - ) + return x * hard_sigmoid(x) @handle_exceptions From 6a92760d313cba9f87e280f2c21bdd675068324d Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 25 Aug 2023 11:36:33 +0000 Subject: [PATCH 25/38] added sigmoid jax-like function --- ivy/functional/backends/paddle/activations.py | 3 +-- .../backends/tensorflow/activations.py | 16 +++++++++++----- ivy/functional/ivy/activations.py | 13 +++++++++++++ .../test_nn/test_non_linear_activations.py | 2 ++ 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 628263266da22..316d010b3f21c 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -85,8 +85,7 @@ def gelu( @with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("bfloat16", "complex128", "complex64")}}, - backend_version, + {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version ) def sigmoid( x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index b383fad5a0088..65f6035b73d5e 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -12,7 +12,6 @@ from tensorflow.python.types.core import Tensor # local -import ivy from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes from . 
import backend_version @@ -33,13 +32,20 @@ def relu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: return tf.nn.relu(x) -@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) def sigmoid( x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" ) -> Tensor: - if not ivy.is_array(x): - x = float(x) - return tf.nn.sigmoid(x) + if x.dtype.is_complex: + real_part = tf.math.real(x) + imag_part = tf.math.imag(x) + + real_result = 1 / (1 + tf.exp(-real_part)) + imag_result = 1 / (1 + tf.exp(-imag_part)) + + result = tf.complex(real_result, imag_result) + else: + result = tf.nn.sigmoid(x) + return result @with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index cbdd0faed5281..7b2ce4f10753f 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -368,6 +368,16 @@ def relu( relu.jax_like = _relu_jax_like +def _sigmoid_jax_like( + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + fn_original=None, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + return 1 / (1 + ivy.exp(-x)) + + @handle_exceptions @handle_backend_invalid @handle_nestable @@ -454,6 +464,9 @@ def sigmoid( return current_backend(x).sigmoid(x, out=out) +sigmoid.jax_like = _sigmoid_jax_like + + @handle_exceptions @handle_backend_invalid @handle_nestable diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 846ccbedb7e10..5fa21404fdd66 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -234,6 +234,8 @@ def test_jax_sigmoid( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + rtol=1e-02, + atol=1e-02, x=x[0], ) From c0a70ba58d72e6cd12854f672c0c777b378419f4 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 25 Aug 2023 12:01:12 +0000 Subject: [PATCH 26/38] removed complex from unsupported for sigmoid --- ivy/functional/backends/paddle/experimental/activations.py | 2 -- ivy/functional/backends/tensorflow/experimental/activations.py | 1 - ivy/functional/backends/torch/experimental/activations.py | 2 +- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index f290199bfb145..3e1232cca9661 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -55,8 +55,6 @@ def relu6( ) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.relu6(x) - if paddle.is_complex(x): - return paddle.complex(F.relu6(x.real()), F.relu6(x.imag())) return F.relu6(x.cast("float32")).cast(x.dtype) diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index 22ee4d6f45284..1abb806c1e812 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -37,7 +37,6 @@ def thresholded_relu( return tf.cast(tf.where(x > threshold, x, 0), x.dtype) -@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) def relu6(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: return 
tf.nn.relu6(x) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 6ae2cec7619ff..0a8de884db013 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -32,7 +32,7 @@ def thresholded_relu( return torch.threshold(x, threshold=threshold, value=0) -@with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def relu6( x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" ) -> torch.Tensor: From e1163ea553fbbe4a98dbab9a44e731a6fb080e44 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 25 Aug 2023 14:49:06 +0000 Subject: [PATCH 27/38] removed complex from unsupported for selu and updated jax-like function --- .../paddle/experimental/activations.py | 14 +------ ivy/functional/backends/torch/elementwise.py | 4 +- .../torch/experimental/activations.py | 2 +- .../ivy/experimental/activations.py | 38 ++++++++++++++----- .../test_nn/test_non_linear_activations.py | 4 +- 5 files changed, 35 insertions(+), 27 deletions(-) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 3e1232cca9661..39df230a5c9e8 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -76,24 +76,12 @@ def logsigmoid( @with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("bfloat16", "complex64", "complex128")}}, + {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version, ) def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.selu(x) - if paddle.is_complex(x): - alpha = 1.6732632423543772848170429916717 - scale = 1.0507009873554804934193349852946 - ret = paddle_backend.multiply( - scale, - paddle_backend.where( - paddle_backend.greater(x, 0), - x, - paddle_backend.multiply(alpha, paddle_backend.expm1(x)), - ), - ) - return ret return F.selu(x.cast("float32")).cast(x.dtype) diff --git a/ivy/functional/backends/torch/elementwise.py b/ivy/functional/backends/torch/elementwise.py index 784452d922e1f..3fde2560511a3 100644 --- a/ivy/functional/backends/torch/elementwise.py +++ b/ivy/functional/backends/torch/elementwise.py @@ -68,10 +68,12 @@ def imag( imag.support_native_out = False -@with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) @handle_numpy_arrays_in_specific_backend def expm1(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: x = _cast_for_unary_op(x) + if x.is_complex(): + return torch.exp(x) - 1 return torch.expm1(x, out=out) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 0a8de884db013..c0b6296986f83 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -48,7 +48,7 @@ def logsigmoid( return torch.nn.functional.logsigmoid(input) -@with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def selu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = 
None) -> torch.Tensor: ret = torch.nn.functional.selu(x) if ivy.exists(out): diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index c0afc92c7cbef..1bda09e5d689a 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -194,15 +194,20 @@ def _relu6_jax_like( fn_original=None, out: Optional[ivy.Array] = None, ) -> ivy.Array: - return ivy.where( - ivy.logical_or(ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0)), - ivy.array(0, dtype=x.dtype), - ivy.where( - ivy.logical_or(ivy.real(x) > 6, ivy.logical_and(ivy.real(x) == 6, ivy.imag(x) > 0)), + return ivy.where( + ivy.logical_or( + ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) + ), + ivy.array(0, dtype=x.dtype), + ivy.where( + ivy.logical_or( + ivy.real(x) > 6, ivy.logical_and(ivy.real(x) == 6, ivy.imag(x) > 0) + ), ivy.array(6, dtype=x.dtype), x, - ), - ) + ), + ) + @handle_exceptions @handle_backend_invalid @@ -327,9 +332,14 @@ def _selu_jax_like( fn_original=None, out: Optional[ivy.Array] = None, ) -> ivy.Array: + """ + Apply jax definition of selu to function + source: + [https://jax.readthedocs.io/en/latest/_modules/jax/_src/nn/functions.html#selu] + """ alpha = 1.6732632423543772848170429916717 scale = 1.0507009873554804934193349852946 - return scale * elu(x, alpha=alpha) + return ivy.multiply(scale, ivy.elu(x, alpha=alpha)) @handle_exceptions @@ -469,11 +479,19 @@ def _elu_jax_like( out: Optional[ivy.Array] = None, ) -> ivy.Array: safe_x = ivy.where( - (x > 0), + ivy.logical_or( + ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) + ), + x, 0.0, + ) + return ivy.where( + ivy.logical_or( + ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) + ), + ivy.astype(ivy.multiply(alpha, ivy.expm1(safe_x)), x.dtype), x, ) - return ivy.where((x > 0), x, ivy.astype(alpha * ivy.expm1(safe_x), x.dtype)) @handle_exceptions diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 5fa21404fdd66..8c948b4094beb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -621,8 +621,8 @@ def test_jax_elu( on_device=on_device, x=xs[0], alpha=alpha, - rtol=1e-03, - atol=1e-03, + rtol=1e-02, + atol=1e-02, ) From 47d023169d83927bb6de97ee38b7519cc875abf5 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 25 Aug 2023 15:04:09 +0000 Subject: [PATCH 28/38] removed complex from unsupported for silu --- .../backends/paddle/experimental/activations.py | 4 +--- .../backends/tensorflow/experimental/activations.py | 1 - .../backends/torch/experimental/activations.py | 2 +- ivy/functional/ivy/experimental/activations.py | 13 +++++++++++++ 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 39df230a5c9e8..5be966f9593af 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -86,14 +86,12 @@ def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. 
@with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("bfloat16", "complex64", "complex128")}}, + {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version, ) def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.silu(x) - if paddle.is_complex(x): - return x * (1.0 / (1.0 + paddle_backend.exp(-x))) return F.silu(x.cast("float32")).cast(x.dtype) diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index 1abb806c1e812..fed2929525025 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -56,7 +56,6 @@ def selu(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> T return ivy.astype(ret, x.dtype) -@with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) def silu(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: ret = tf.nn.silu(x) if ivy.exists(out): diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index c0b6296986f83..cc401b4a41cff 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -56,7 +56,7 @@ def selu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten return ivy.astype(ret, x.dtype) -@with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def silu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.nn.functional.silu(x) diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 1bda09e5d689a..b0bfb3f15d612 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -412,6 +412,16 @@ def selu( selu.jax_like = _selu_jax_like +def _silu_jax_like( + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + fn_original=None, + out: Optional[ivy.Array] = None, +) -> ivy.Array: + return ivy.multiply(x, ivy.sigmoid(x)) + + @handle_exceptions @handle_backend_invalid @handle_nestable @@ -470,6 +480,9 @@ def silu( return current_backend(x).silu(x, out=out) +silu.jax_like = _silu_jax_like + + def _elu_jax_like( x: Union[ivy.Array, ivy.NativeArray], /, From f47257515658e9db8eae8d6e6d9f35dda77e6651 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 25 Aug 2023 15:08:08 +0000 Subject: [PATCH 29/38] removed complex from unsupported for elu --- .../backends/paddle/experimental/activations.py | 9 ++++++++- .../backends/torch/experimental/activations.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 5be966f9593af..b16119cb3d149 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -96,7 +96,14 @@ def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. 
@with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("float16", "bfloat16", "complex64", "complex128")}}, + { + "2.5.1 and below": { + "cpu": ( + "float16", + "bfloat16", + ) + } + }, backend_version, ) def elu( diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index cc401b4a41cff..1bd7e8d8fb039 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -61,7 +61,7 @@ def silu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten return torch.nn.functional.silu(x) -@with_unsupported_dtypes({"2.0.1 and below": ("float16", "complex")}, backend_version) +@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def elu( x: torch.Tensor, /, From c445b97b6bc063d02178fd15a56453edab37f1f3 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 25 Aug 2023 15:19:17 +0000 Subject: [PATCH 30/38] updated some backends as requested by Joe --- ivy/functional/backends/paddle/activations.py | 2 -- .../backends/tensorflow/activations.py | 24 ++----------------- ivy/functional/backends/torch/activations.py | 12 +--------- .../ivy/experimental/activations.py | 4 ++-- 4 files changed, 5 insertions(+), 37 deletions(-) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 316d010b3f21c..02f355f0c49cc 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -91,8 +91,6 @@ def sigmoid( x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" ) -> paddle.Tensor: if x.dtype in unsupported_dtypes: - if paddle.is_complex(x): - return 1 / (1 + paddle_backend.exp(-x)) return F.sigmoid(x.cast("float32")).cast(x.dtype) return F.sigmoid(x) diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index 65f6035b73d5e..d8ad605d31d9e 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -35,17 +35,7 @@ def relu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: def sigmoid( x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" ) -> Tensor: - if x.dtype.is_complex: - real_part = tf.math.real(x) - imag_part = tf.math.imag(x) - - real_result = 1 / (1 + tf.exp(-real_part)) - imag_result = 1 / (1 + tf.exp(-imag_part)) - - result = tf.complex(real_result, imag_result) - else: - result = tf.nn.sigmoid(x) - return result + return tf.nn.sigmoid(x) @with_unsupported_dtypes({"2.13.0 and below": ("complex",)}, backend_version) @@ -97,14 +87,4 @@ def mish( def hardswish( x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" ) -> Tensor: - if x.dtype.is_complex: - real_part = tf.real(x) - imag_part = tf.imag(x) - - real_result = real_part * tf.nn.relu6(real_part + 3) / 6 - imag_result = imag_part * tf.nn.relu6(imag_part + 3) / 6 - - result = tf.complex(real_result, imag_result) - else: - result = x * tf.nn.relu6(x + 3) / 6 - return result + return x * tf.nn.relu6(x + 3) / 6 diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py index c916289e42542..673e5c33b503d 100644 --- a/ivy/functional/backends/torch/activations.py +++ b/ivy/functional/backends/torch/activations.py @@ -131,14 +131,4 @@ def mish(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten def hardswish( x: 
torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" ) -> torch.Tensor: - if x.dtype.is_complex: - real_part = x.real - imag_part = x.imag - - real_result = real_part * torch.nn.functional.relu6(real_part + 3) / 6 - imag_result = imag_part * torch.nn.functional.relu6(imag_part + 3) / 6 - - result = torch.complex(real_result, imag_result) - else: - result = torch.nn.functional.hardswish(x) - return result + return torch.nn.functional.hardswish(x) diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index b0bfb3f15d612..c81ee9bb4ea6c 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -333,8 +333,8 @@ def _selu_jax_like( out: Optional[ivy.Array] = None, ) -> ivy.Array: """ - Apply jax definition of selu to function - source: + Alpha and scale are taken from jax's implementation of selu + Source: [https://jax.readthedocs.io/en/latest/_modules/jax/_src/nn/functions.html#selu] """ alpha = 1.6732632423543772848170429916717 From 674783123e621dec0a3e481121fa2fa99c7ba3ef Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Sat, 26 Aug 2023 09:42:53 +0000 Subject: [PATCH 31/38] created custom implementation for tensorflow logsigmoid --- ivy/functional/backends/tensorflow/experimental/activations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index fed2929525025..39bf62f8f2fda 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -41,10 +41,11 @@ def relu6(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> return tf.nn.relu6(x) -@with_supported_dtypes({"2.13.0 and below": ("float",)}, backend_version) def logsigmoid( input: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" ) -> Tensor: + if input.dtype in [tf.complex64, tf.complex128]: + return tf.math.log(tf.nn.sigmoid(input)) return tf.math.log_sigmoid(input) From e35a406b359d964d74ad2fc5087dc0ea04f8f411 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Sat, 26 Aug 2023 10:12:56 +0000 Subject: [PATCH 32/38] added support for complex in tensorflow elu and selu --- .../backends/tensorflow/experimental/activations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index 39bf62f8f2fda..ffce16b1ccc3a 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -49,7 +49,7 @@ def logsigmoid( return tf.math.log_sigmoid(input) -@with_supported_dtypes({"2.13.0 and below": ("float",)}, backend_version) +@with_supported_dtypes({"2.13.0 and below": ("float", "complex")}, backend_version) def selu(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: ret = tf.nn.selu(x) if ivy.exists(out): @@ -64,7 +64,7 @@ def silu(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> T return ivy.astype(ret, x.dtype) -@with_supported_dtypes({"2.13.0 and below": ("float",)}, backend_version) +@with_supported_dtypes({"2.13.0 and below": ("float", "complex")}, backend_version) def elu( x: Tensor, /, From 43e1013b14a76237685161203b41cd73e791efd7 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Tue, 29 
Aug 2023 11:30:53 +0000 Subject: [PATCH 33/38] made changes to elu-jax-like function so as to extract the condition term into its own variable --- ivy/functional/ivy/experimental/activations.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index c81ee9bb4ea6c..1186b40be8df7 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -491,17 +491,12 @@ def _elu_jax_like( alpha: float = 1.0, out: Optional[ivy.Array] = None, ) -> ivy.Array: - safe_x = ivy.where( - ivy.logical_or( - ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) - ), - x, - 0.0, + cond = ivy.logical_or( + ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) ) + safe_x = ivy.where(cond, x, 0.0) return ivy.where( - ivy.logical_or( - ivy.real(x) < 0, ivy.logical_and(ivy.real(x) == 0, ivy.imag(x) < 0) - ), + cond, ivy.astype(ivy.multiply(alpha, ivy.expm1(safe_x)), x.dtype), x, ) From a4df5a9d5cfba54181fa380bad0d02b63c8d8c4d Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 1 Sep 2023 09:55:57 +0000 Subject: [PATCH 34/38] updated files --- .../test_jax/test_nn/test_non_linear_activations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 894fb416150b9..31ebf897a4253 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -609,7 +609,6 @@ def test_jax_relu6( ) - @handle_frontend_test( fn_tree="jax.nn.selu", dtype_and_x=helpers.dtype_and_values( From b5fd99584adbdafcdd43b68f2b362ce396debef4 Mon Sep 17 00:00:00 2001 From: Moses Daudu Date: Fri, 1 Sep 2023 10:00:26 +0000 Subject: [PATCH 35/38] removed duplicated file --- ivy/functional/frontends/jax/nn/non_linear_activations.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index 5f726edf149fd..9dd47cfbc2bdc 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -278,7 +278,7 @@ def relu6(x): @to_ivy_arrays_and_back def selu(x): x = _type_conversion_64(x) - return ivy.selu(x) + return ivy.selu(x, complex_mode="jax") @to_ivy_arrays_and_back @@ -316,12 +316,6 @@ def softplus(x): return ivy.softplus(x, complex_mode="jax").astype(x.dtype) -@to_ivy_arrays_and_back -def selu(x): - x = _type_conversion_64(x) - return ivy.selu(x, complex_mode="jax") - - @to_ivy_arrays_and_back def swish(x): ret = x / (1 + ivy.exp(-x)) From e59e784b111144ae50183c546e0d343ae14ee6a4 Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Wed, 6 Sep 2023 13:20:41 +0100 Subject: [PATCH 36/38] fix: silu and sigmoid now work in paddle, remove jax_like And a couple other bugs that were causing things to break, such as removing duplicated atol from `test_jax_hard_swish` and adding complex dtype to some tests. 
These functions now pass all tests in the jax frontend, except sigmoid and silu which fail for tf backend due to a known issue --- ivy/functional/backends/paddle/activations.py | 4 +++- .../backends/paddle/experimental/activations.py | 2 ++ ivy/functional/ivy/activations.py | 15 +-------------- ivy/functional/ivy/experimental/activations.py | 13 ------------- .../test_nn/test_non_linear_activations.py | 6 ++---- 5 files changed, 8 insertions(+), 32 deletions(-) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 8c879603db38e..96a1726b86b19 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -92,6 +92,8 @@ def gelu( def sigmoid( x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" ) -> paddle.Tensor: + if paddle.is_complex(x): + return 1.0 / (1.0 + paddle_backend.exp(-x)) if x.dtype in unsupported_dtypes: return F.sigmoid(x.cast("float32")).cast(x.dtype) return F.sigmoid(x) @@ -195,7 +197,7 @@ def mish(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. @with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("float16", "bfloat16", "complex64", "complex128")}}, + {"2.5.1 and below": {"cpu": ("float16", "bfloat16")}}, backend_version, ) def hardswish( diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index b16119cb3d149..816bd492d53fc 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -90,6 +90,8 @@ def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. backend_version, ) def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: + if paddle.is_complex(x): + return paddle.multiply(x, paddle_backend.sigmoid(x)) if x.dtype in [paddle.float32, paddle.float64]: return F.silu(x) return F.silu(x.cast("float32")).cast(x.dtype) diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index 6fd1ab2d987b5..c5e10018c9bb6 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -366,16 +366,6 @@ def relu( relu.jax_like = _relu_jax_like -def _sigmoid_jax_like( - x: Union[ivy.Array, ivy.NativeArray], - /, - *, - fn_original=None, - out: Optional[ivy.Array] = None, -) -> ivy.Array: - return 1 / (1 + ivy.exp(-x)) - - @handle_exceptions @handle_backend_invalid @handle_nestable @@ -462,9 +452,6 @@ def sigmoid( return current_backend(x).sigmoid(x, out=out) -sigmoid.jax_like = _sigmoid_jax_like - - @handle_exceptions @handle_backend_invalid @handle_nestable @@ -737,7 +724,7 @@ def _hardswish_jax_like( def hard_sigmoid(x): return ivy.relu6(x + 3.0) / 6 - return x * hard_sigmoid(x) + return ivy.multiply(x, hard_sigmoid(x)) @handle_exceptions diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 1186b40be8df7..bb76215c2ccb0 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -412,16 +412,6 @@ def selu( selu.jax_like = _selu_jax_like -def _silu_jax_like( - x: Union[ivy.Array, ivy.NativeArray], - /, - *, - fn_original=None, - out: Optional[ivy.Array] = None, -) -> ivy.Array: - return ivy.multiply(x, ivy.sigmoid(x)) - - @handle_exceptions @handle_backend_invalid @handle_nestable @@ -480,9 +470,6 @@ def silu( return 
current_backend(x).silu(x, out=out) -silu.jax_like = _silu_jax_like - - def _elu_jax_like( x: Union[ivy.Array, ivy.NativeArray], /, diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 31ebf897a4253..634f5161d2b12 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -271,8 +271,6 @@ def test_jax_hard_swish( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - rtol=1e-02, - atol=1e-02, x=x[0], ) @@ -612,7 +610,7 @@ def test_jax_relu6( @handle_frontend_test( fn_tree="jax.nn.selu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_integer"), + available_dtypes=helpers.get_dtypes("numeric"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="log", @@ -645,7 +643,7 @@ def test_jax_selu( @handle_frontend_test( fn_tree="jax.nn.sigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=2, small_abs_safety_factor=2, safety_factor_scale="linear", From 4bb0291596608d6c756e7b25d627502f63fca198 Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Wed, 6 Sep 2023 14:08:20 +0100 Subject: [PATCH 37/38] fix: update supported dtypes Update supported dtypes for certain functions, and change the safety factor for logsigmoid. All tests now pass locally, except problems caused by tensorflow's sigmoid and a problem with the test for elu (which seems to feed in the wrong values). --- .../backends/tensorflow/experimental/activations.py | 1 + ivy/functional/backends/torch/experimental/activations.py | 4 +++- ivy/functional/ivy/activations.py | 2 +- .../test_experimental/test_nn/test_activations.py | 5 +++-- .../test_ivy/test_functional/test_nn/test_activations.py | 4 ++-- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index ffce16b1ccc3a..81b81f37049b5 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -41,6 +41,7 @@ def relu6(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> return tf.nn.relu6(x) +@with_supported_dtypes({"2.13.0 and below": ("float", "complex")}, backend_version) def logsigmoid( input: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" ) -> Tensor: diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 1bd7e8d8fb039..39cc9eb7bd7d8 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -39,7 +39,9 @@ def relu6( return torch.nn.functional.relu6(x) -@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) +@with_unsupported_dtypes( + {"2.0.1 and below": ("float16", "int16", "int32", "int64", "bool")}, backend_version +) def logsigmoid( input: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" ) -> torch.Tensor: diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index c5e10018c9bb6..b2ebb11cf8f7a 100644 --- a/ivy/functional/ivy/activations.py +++ 
b/ivy/functional/ivy/activations.py @@ -724,7 +724,7 @@ def _hardswish_jax_like( def hard_sigmoid(x): return ivy.relu6(x + 3.0) / 6 - return ivy.multiply(x, hard_sigmoid(x)) + return ivy.multiply(x, hard_sigmoid(x).astype(x.dtype)) @handle_exceptions diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index 82495ae40aba1..90b67f391b841 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -10,7 +10,7 @@ @handle_test( fn_tree="functional.ivy.experimental.elu", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", @@ -67,8 +67,9 @@ def test_logit(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): fn_tree="functional.ivy.experimental.logsigmoid", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("valid"), - safety_factor_scale="log", + small_abs_safety_factor=2, large_abs_safety_factor=120, + safety_factor_scale="log", ), test_with_out=st.just(False), ) diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py index 2399daf623170..6f60cf5168b16 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py @@ -48,7 +48,7 @@ def test_gelu( @handle_test( fn_tree="functional.ivy.hardswish", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", @@ -184,7 +184,7 @@ def test_relu(*, dtype_and_x, complex_mode, test_flags, backend_fw, fn_name, on_ @handle_test( fn_tree="functional.ivy.sigmoid", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", From b7a0daf712301cfdfe87ee7e4f488be8b984b043 Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Wed, 6 Sep 2023 14:25:58 +0100 Subject: [PATCH 38/38] style: update docstrings and function signatures for complex_mode --- ivy/data_classes/array/activations.py | 20 ++--- .../array/experimental/activations.py | 40 +++++----- ivy/data_classes/container/activations.py | 40 +++++----- .../container/experimental/activations.py | 80 +++++++++---------- ivy/functional/backends/jax/activations.py | 4 +- .../backends/jax/experimental/activations.py | 10 +-- ivy/functional/backends/mxnet/activations.py | 2 +- .../mxnet/experimental/activations.py | 6 +- ivy/functional/backends/numpy/activations.py | 4 +- .../numpy/experimental/activations.py | 10 +-- ivy/functional/backends/paddle/activations.py | 4 +- .../paddle/experimental/activations.py | 14 ++-- .../backends/tensorflow/activations.py | 4 +- .../tensorflow/experimental/activations.py | 10 +-- ivy/functional/backends/torch/activations.py | 4 +- .../torch/experimental/activations.py | 14 ++-- ivy/functional/ivy/activations.py | 16 ++-- .../ivy/experimental/activations.py | 36 ++++----- ivy/stateful/activations.py | 12 +-- 19 files changed, 169 insertions(+), 161 
deletions(-) diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py index da4df0b304130..e14adaf5c7f68 100644 --- a/ivy/data_classes/array/activations.py +++ b/ivy/data_classes/array/activations.py @@ -135,8 +135,8 @@ def sigmoid( self: ivy.Array, /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ ivy.Array instance method variant of ivy.sigmoid. @@ -148,12 +148,12 @@ def sigmoid( ---------- self Input array - out - optional output array for writing the result to. It must have the same shape - the input broadcast to default: None complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output array for writing the result to. It must have the same shape + the input broadcast to default: None Returns ------- @@ -168,7 +168,7 @@ def sigmoid( >>> print(y) ivy.array([0.269, 0.731, 0.881]) """ - return ivy.sigmoid(self._data, out=out, complex_mode=complex_mode) + return ivy.sigmoid(self._data, complex_mode=complex_mode, out=out) def softmax( self: ivy.Array, @@ -331,8 +331,8 @@ def hardswish( self: ivy.Array, /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the hardswish activation function element-wise. @@ -341,12 +341,12 @@ def hardswish( ---------- x input array - out - optional output array, for writing the result to. It must have - a shape that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output array, for writing the result to. It must have + a shape that the inputs broadcast to. Returns ------- @@ -372,4 +372,4 @@ def hardswish( b: ivy.array([0., 5.]) } """ - return ivy.hardswish(self._data, out=out, complex_mode=complex_mode) + return ivy.hardswish(self._data, complex_mode=complex_mode, out=out) diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index b93aec38b0e52..4c1dcf5725d18 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -119,8 +119,8 @@ def relu6( self, /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the rectified linear unit 6 function element-wise. @@ -129,12 +129,12 @@ def relu6( ---------- self input array - out - optional output array, for writing the result to. - It must have a shape that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output array, for writing the result to. + It must have a shape that the inputs broadcast to. 
Returns ------- @@ -157,7 +157,7 @@ def relu6( >>> print(y) ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.]) """ - return ivy.relu6(self._data, out=out, complex_mode=complex_mode) + return ivy.relu6(self._data, complex_mode=complex_mode, out=out) def logsigmoid( self: ivy.Array, @@ -198,8 +198,8 @@ def selu( self, /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the scaled exponential linear unit function element-wise. @@ -208,12 +208,12 @@ def selu( ---------- self input array - out - optional output array, for writing the result to. - It must have a shape that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output array, for writing the result to. + It must have a shape that the inputs broadcast to. Returns ------- @@ -238,14 +238,14 @@ def selu( ivy.array([-1.11133075, 0., 1.05070102, 2.10140204, 3.15210295, 4.20280409, 5.25350523, 6.30420589, 7.35490704]) """ - return ivy.selu(self._data, out=out, complex_mode=complex_mode) + return ivy.selu(self._data, complex_mode=complex_mode, out=out) def silu( self: ivy.Array, /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ ivy.Array instance method variant of ivy.silu. This method simply wraps the @@ -256,12 +256,12 @@ def silu( ---------- self input array. - out - optional output array, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output array, for writing the result to. It must have a shape + that the inputs broadcast to. Examples -------- @@ -270,15 +270,15 @@ def silu( >>> print(y) ivy.array([-0.26894143, 0. , 0.73105854]) """ - return ivy.silu(self._data, out=out, complex_mode=complex_mode) + return ivy.silu(self._data, complex_mode=complex_mode, out=out) def elu( self, /, *, alpha: float = 1.0, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Ivy.Array instance method variant of ivy.elu. This method simply wraps the @@ -291,12 +291,12 @@ def elu( input array. alpha scaler for controlling the slope of the function for x <= 0 Default: 1.0 - out - optional output array, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output array, for writing the result to. It must have a shape + that the inputs broadcast to. 
Returns ------- @@ -310,4 +310,4 @@ def elu( >>> print(y) ivy.array([ 0.39, -0.57]) """ - return ivy.elu(self._data, alpha=alpha, out=out, complex_mode=complex_mode) + return ivy.elu(self._data, alpha=alpha, complex_mode=complex_mode, out=out) diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py index 714dcd3eb2d28..97d95156b68be 100644 --- a/ivy/data_classes/container/activations.py +++ b/ivy/data_classes/container/activations.py @@ -419,8 +419,8 @@ def _static_sigmoid( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.sigmoid. This method simply wraps the @@ -442,12 +442,12 @@ def _static_sigmoid( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -471,8 +471,8 @@ def _static_sigmoid( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def sigmoid( @@ -483,8 +483,8 @@ def sigmoid( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.sigmoid. This method simply wraps @@ -506,12 +506,12 @@ def sigmoid( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -534,8 +534,8 @@ def sigmoid( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) @staticmethod @@ -1090,8 +1090,8 @@ def _static_hardswish( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.hardswish. This method simply wraps @@ -1113,12 +1113,12 @@ def _static_hardswish( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. 
+ out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -1143,8 +1143,8 @@ def _static_hardswish( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def hardswish( @@ -1155,8 +1155,8 @@ def hardswish( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.hardswish. This method simply wraps @@ -1178,12 +1178,12 @@ def hardswish( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -1207,6 +1207,6 @@ def hardswish( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index 975cd84c2737b..e96a05a066487 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -320,8 +320,8 @@ def static_relu6( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.relu6. This method simply wraps the @@ -343,12 +343,12 @@ def static_relu6( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -374,8 +374,8 @@ def static_relu6( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def relu6( @@ -386,8 +386,8 @@ def relu6( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.relu6. This method simply wraps the @@ -409,12 +409,12 @@ def relu6( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. 
complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -439,8 +439,8 @@ def relu6( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) @staticmethod @@ -568,8 +568,8 @@ def static_selu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.selu. This method simply wraps the @@ -591,12 +591,12 @@ def static_selu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -621,8 +621,8 @@ def static_selu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def selu( @@ -633,8 +633,8 @@ def selu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.selu. This method simply wraps the @@ -656,12 +656,12 @@ def selu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -685,8 +685,8 @@ def selu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) @staticmethod @@ -698,8 +698,8 @@ def _static_silu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.silu. This method simply wraps the @@ -721,12 +721,12 @@ def _static_silu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. 
It must have a shape + that the inputs broadcast to. Returns ------- @@ -751,8 +751,8 @@ def _static_silu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def silu( @@ -763,8 +763,8 @@ def silu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.silu. This method simply wraps the @@ -786,12 +786,12 @@ def silu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -815,8 +815,8 @@ def silu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) @staticmethod @@ -829,8 +829,8 @@ def _static_elu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.elu. This method simply wraps the @@ -854,12 +854,12 @@ def _static_elu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -884,8 +884,8 @@ def _static_elu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def elu( @@ -897,8 +897,8 @@ def elu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.elu. This method simply wraps the @@ -922,12 +922,12 @@ def elu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. 
Returns ------- @@ -951,6 +951,6 @@ def elu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py index 27e5a1a5fd9da..d787751e0fa74 100644 --- a/ivy/functional/backends/jax/activations.py +++ b/ivy/functional/backends/jax/activations.py @@ -41,7 +41,7 @@ def relu( def sigmoid( - x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" + x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None ) -> JaxArray: return 1 / (1 + jnp.exp(-x)) @@ -100,6 +100,6 @@ def mish(x: JaxArray, /, *, out: Optional[JaxArray] = None): def hardswish( - x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" + x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None ) -> JaxArray: return jax.nn.hard_swish(x) diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py index 6494831629f5f..879fa2864ccf8 100644 --- a/ivy/functional/backends/jax/experimental/activations.py +++ b/ivy/functional/backends/jax/experimental/activations.py @@ -23,7 +23,7 @@ def logit( def relu6( - x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" + x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None ) -> JaxArray: relu6_func = jax.nn.relu6 @@ -51,13 +51,13 @@ def thresholded_relu( def logsigmoid( - input: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" + input: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None ) -> JaxArray: return jax.nn.log_sigmoid(input) def selu( - x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" + x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None ) -> JaxArray: ret = jax.nn.selu(x).astype(x.dtype) if ivy.exists(out): @@ -66,7 +66,7 @@ def selu( def silu( - x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" + x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None ) -> JaxArray: ret = jax.nn.silu(x) if ivy.exists(out): @@ -79,8 +79,8 @@ def elu( /, *, alpha: float = 1.0, - out: Optional[JaxArray] = None, complex_mode="jax", + out: Optional[JaxArray] = None, ) -> JaxArray: ret = jax.nn.elu(x, alpha) if ivy.exists(out): diff --git a/ivy/functional/backends/mxnet/activations.py b/ivy/functional/backends/mxnet/activations.py index 2348bf5cdced4..31cf6fe9a648d 100644 --- a/ivy/functional/backends/mxnet/activations.py +++ b/ivy/functional/backends/mxnet/activations.py @@ -36,7 +36,7 @@ def relu(x: None, /, *, complex_mode="jax", out: Optional[None] = None) -> None: return mx.nd.relu(x) -def sigmoid(x: None, /, *, out: Optional[None] = None, complex_mode="jax") -> None: +def sigmoid(x: None, /, *, complex_mode="jax", out: Optional[None] = None) -> None: return mx.nd.sigmoid(x) diff --git a/ivy/functional/backends/mxnet/experimental/activations.py b/ivy/functional/backends/mxnet/experimental/activations.py index ef52d70dbccd6..c9ae4ee756ddb 100644 --- a/ivy/functional/backends/mxnet/experimental/activations.py +++ b/ivy/functional/backends/mxnet/experimental/activations.py @@ -20,7 +20,7 @@ def thresholded_relu( raise IvyNotImplementedException() -def relu6(x: None, /, *, out: Optional[None] = None, complex_mode="jax") -> None: +def relu6(x: None, /, *, complex_mode="jax", out: Optional[None] = None) -> None: raise IvyNotImplementedException() @@ -28,9 +28,9 @@ def logsigmoid(input: 
None, complex_mode="jax") -> None: raise IvyNotImplementedException() -def selu(x: None, /, *, out: Optional[None] = None, complex_mode="jax") -> None: +def selu(x: None, /, *, complex_mode="jax", out: Optional[None] = None) -> None: raise IvyNotImplementedException() -def silu(x: None, /, *, out: Optional[None] = None, complex_mode="jax") -> None: +def silu(x: None, /, *, complex_mode="jax", out: Optional[None] = None) -> None: raise IvyNotImplementedException() diff --git a/ivy/functional/backends/numpy/activations.py b/ivy/functional/backends/numpy/activations.py index f85c5a2cabbb7..a5804cedc97ad 100644 --- a/ivy/functional/backends/numpy/activations.py +++ b/ivy/functional/backends/numpy/activations.py @@ -47,7 +47,7 @@ def gelu( def sigmoid( - x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" + x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None ) -> np.ndarray: if not ivy.is_array(x): return np.asarray(1 / (1 + np.exp(-x))) @@ -142,7 +142,7 @@ def mish(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: @_scalar_output_to_0d_array def hardswish( - x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" + x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None ) -> np.ndarray: max_x_3 = np.maximum(x + 3, 0, dtype=x.dtype) return (x * np.minimum(max_x_3, 6, out=out, dtype=x.dtype) / 6).astype(x.dtype) diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py index b222a84a75167..0b93fe8bafb91 100644 --- a/ivy/functional/backends/numpy/experimental/activations.py +++ b/ivy/functional/backends/numpy/experimental/activations.py @@ -44,7 +44,7 @@ def thresholded_relu( @_scalar_output_to_0d_array def relu6( - x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" + x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None ) -> np.ndarray: return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype) @@ -55,14 +55,14 @@ def relu6( @with_unsupported_dtypes({"1.25.2 and below": ("bool",)}, backend_version) @_scalar_output_to_0d_array def logsigmoid( - input: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" + input: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None ) -> np.ndarray: return -(np.log1p(np.exp(-(input)))) @_scalar_output_to_0d_array def selu( - x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" + x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None ) -> np.ndarray: alpha = 1.6732632423543772848170429916717 scale = 1.0507009873554804934193349852946 @@ -77,7 +77,7 @@ def selu( @_scalar_output_to_0d_array def silu( - x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" + x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None ) -> np.ndarray: ret = np.asarray(x * (1 / (1 + np.exp(-x)))) if ivy.exists(out): @@ -97,8 +97,8 @@ def elu( /, *, alpha: float = 1.0, - out: Optional[np.ndarray] = None, complex_mode="jax", + out: Optional[np.ndarray] = None, ) -> np.ndarray: # exp = np.expm1(x) ret = np.where(x > 0, x, np.multiply(alpha, np.expm1(x))).astype(x.dtype) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 96a1726b86b19..ce6e6dcf055a5 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -90,7 +90,7 @@ def gelu( {"2.5.1 
and below": {"cpu": ("bfloat16",)}}, backend_version ) def sigmoid( - x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" + x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: if paddle.is_complex(x): return 1.0 / (1.0 + paddle_backend.exp(-x)) @@ -201,6 +201,6 @@ def mish(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. backend_version, ) def hardswish( - x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" + x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: return F.hardswish(x) diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 816bd492d53fc..ffb77df42c8eb 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -51,7 +51,7 @@ def thresholded_relu( {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version ) def relu6( - x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" + x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.relu6(x) @@ -62,7 +62,7 @@ def relu6( {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version ) def logsigmoid( - input: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" + input: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: if input.dtype in [paddle.float32, paddle.float64]: return F.log_sigmoid(input) @@ -79,7 +79,9 @@ def logsigmoid( {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version, ) -def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: +def selu( + x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.selu(x) return F.selu(x.cast("float32")).cast(x.dtype) @@ -89,7 +91,9 @@ def selu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. 
{"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version, ) -def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: +def silu( + x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: if paddle.is_complex(x): return paddle.multiply(x, paddle_backend.sigmoid(x)) if x.dtype in [paddle.float32, paddle.float64]: @@ -113,8 +117,8 @@ def elu( /, *, alpha: float = 1.0, - out: Optional[paddle.Tensor] = None, complex_mode="jax", + out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: if x.dtype in [paddle.float32, paddle.float64]: return F.elu(x, alpha=alpha) diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index 316869c3a5657..8bc44e26dad1f 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -45,7 +45,7 @@ def relu(x: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None) -> T def sigmoid( - x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" + x: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None ) -> Tensor: return tf.nn.sigmoid(x) @@ -126,6 +126,6 @@ def mish( def hardswish( - x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" + x: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None ) -> Tensor: return x * tf.nn.relu6(x + 3) / 6 diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index 81b81f37049b5..3ea74a6276125 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -37,13 +37,13 @@ def thresholded_relu( return tf.cast(tf.where(x > threshold, x, 0), x.dtype) -def relu6(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: +def relu6(x: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None) -> Tensor: return tf.nn.relu6(x) @with_supported_dtypes({"2.13.0 and below": ("float", "complex")}, backend_version) def logsigmoid( - input: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax" + input: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None ) -> Tensor: if input.dtype in [tf.complex64, tf.complex128]: return tf.math.log(tf.nn.sigmoid(input)) @@ -51,14 +51,14 @@ def logsigmoid( @with_supported_dtypes({"2.13.0 and below": ("float", "complex")}, backend_version) -def selu(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: +def selu(x: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None) -> Tensor: ret = tf.nn.selu(x) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) return ivy.astype(ret, x.dtype) -def silu(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: +def silu(x: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None) -> Tensor: ret = tf.nn.silu(x) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) @@ -71,8 +71,8 @@ def elu( /, *, alpha: float = 1.0, - out: Optional[Tensor] = None, complex_mode="jax", + out: Optional[Tensor] = None, ) -> Tensor: alpha = tf.cast(alpha, x.dtype) ret = tf.cast(tf.where(x > 0, x, tf.multiply(alpha, tf.math.expm1(x))), x.dtype) diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py index c5c94bd9b00ce..8bc859692e448 100644 --- a/ivy/functional/backends/torch/activations.py +++ 
b/ivy/functional/backends/torch/activations.py @@ -54,7 +54,7 @@ def gelu( @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def sigmoid( - x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" + x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None ) -> torch.Tensor: if not ivy.is_array(x): x = torch.tensor(x) @@ -142,6 +142,6 @@ def mish(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten backend_version, ) def hardswish( - x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" + x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None ) -> torch.Tensor: return torch.nn.functional.hardswish(x) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 39cc9eb7bd7d8..d0a250819933e 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -34,7 +34,7 @@ def thresholded_relu( @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def relu6( - x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" + x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None ) -> torch.Tensor: return torch.nn.functional.relu6(x) @@ -43,7 +43,7 @@ def relu6( {"2.0.1 and below": ("float16", "int16", "int32", "int64", "bool")}, backend_version ) def logsigmoid( - input: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" + input: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None ) -> torch.Tensor: if torch.is_complex(input): return torch.log(torch.sigmoid(input)) @@ -51,7 +51,9 @@ def logsigmoid( @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) -def selu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def selu( + x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None +) -> torch.Tensor: ret = torch.nn.functional.selu(x) if ivy.exists(out): return ivy.inplace_update(out, ret).astype(x.dtype) @@ -59,7 +61,9 @@ def selu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) -def silu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def silu( + x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None +) -> torch.Tensor: return torch.nn.functional.silu(x) @@ -69,8 +73,8 @@ def elu( /, *, alpha: float = 1.0, - out: Optional[torch.Tensor] = None, complex_mode="jax", + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: ret = torch.nn.functional.elu(x, alpha) if ivy.exists(out): diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index b2ebb11cf8f7a..ba904e70c09d7 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -379,8 +379,8 @@ def sigmoid( x: Union[ivy.Array, ivy.NativeArray], /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the sigmoid function element-wise. @@ -389,13 +389,13 @@ def sigmoid( ---------- x input array. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. 
It must have a shape that the input broadcast to. default: None - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- @@ -739,8 +739,8 @@ def hardswish( x: Union[ivy.Array, ivy.NativeArray], /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the hardswish activation function element-wise. @@ -749,12 +749,12 @@ def hardswish( ---------- x input array + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index bb76215c2ccb0..07a561ca50bca 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -222,8 +222,8 @@ def relu6( x: Union[ivy.Array, ivy.NativeArray], /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the rectified linear unit 6 function element-wise. @@ -232,12 +232,12 @@ def relu6( ---------- x input array + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- @@ -278,8 +278,8 @@ def logsigmoid( input: Union[ivy.NativeArray, ivy.Array], /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply element-wise Log-sigmoid of x. @@ -292,7 +292,7 @@ def logsigmoid( Input array. complex_mode optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -355,8 +355,8 @@ def selu( x: Union[ivy.Array, ivy.NativeArray], /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the scaled exponential linear unit function element-wise. @@ -365,12 +365,12 @@ def selu( ---------- x input array + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- @@ -425,8 +425,8 @@ def silu( x: Union[ivy.Array, ivy.NativeArray], /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the silu function element-wise. @@ -435,12 +435,12 @@ def silu( ---------- x input array. 
+ complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- @@ -502,8 +502,8 @@ def elu( /, *, alpha: float = 1.0, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the elu unit function element-wise. @@ -514,12 +514,12 @@ def elu( Input array. alpha scaler for controlling the slope of the function for x <= 0 Default: 1.0 + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. - complex_mode - optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. Returns ------- diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 66b9155079d12..517f69532cf34 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -253,7 +253,7 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): ---------- complex_mode Specifies how to handle complex input. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) @@ -283,7 +283,7 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): ---------- complex_mode Specifies how to handle complex input. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) @@ -343,7 +343,7 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): ---------- complex_mode Specifies how to handle complex input. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) @@ -373,7 +373,7 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): ---------- complex_mode Specifies how to handle complex input. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) @@ -451,7 +451,7 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): ---------- complex_mode Specifies how to handle complex input. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. """ self._complex_mode = complex_mode Module.__init__(self) @@ -503,7 +503,7 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): ---------- complex_mode Specifies how to handle complex input. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. """ self._complex_mode = complex_mode Module.__init__(self)