Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added support for sigmoid activation fucntion #21424

Closed
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
6b90a8a
added support for sigmoid activation fucntion
mosesdaudu001 Aug 7, 2023
a9879a1
applied complex_input decorator to hardswish activation fucntion
mosesdaudu001 Aug 8, 2023
16c42d4
added argument to static sigmoid and hardswish container funcs
mosesdaudu001 Aug 8, 2023
a9a9a53
applied complex_input decorator to silu activation fucntion
mosesdaudu001 Aug 8, 2023
04e0002
resolved conflicts
mosesdaudu001 Aug 9, 2023
a101978
Merge branch 'master' into complex_input_activation_functions
mosesdaudu001 Aug 9, 2023
63eeb4c
applied complex_input decorator to selu activation fucntion
mosesdaudu001 Aug 9, 2023
6e97b30
Merge branch 'unifyai:master' into complex_input_activation_functions
mosesdaudu001 Aug 9, 2023
0636cfb
applied complex_input decorator to logsigmoid activation function
mosesdaudu001 Aug 9, 2023
01c3e41
applied complex_input decorator to elu activation function
mosesdaudu001 Aug 14, 2023
83a94aa
applied complex_input decorator to relu6 activation function
mosesdaudu001 Aug 14, 2023
54eb859
added float_complex to stateful test
mosesdaudu001 Aug 14, 2023
cfe6cdc
made changes as suggested by joe
mosesdaudu001 Aug 15, 2023
1c00385
fixed testing issues for sigmoid
mosesdaudu001 Aug 18, 2023
d4d70b1
Bring docstrings and call signatures in line with standards
jshepherd01 Aug 18, 2023
5a0956e
fixed testing issues for silu
mosesdaudu001 Aug 18, 2023
5bb2dab
fixed testing issues for selu
mosesdaudu001 Aug 18, 2023
8190ffd
Merge branch 'complex_input_activation_functions' of github.com:moses…
jshepherd01 Aug 18, 2023
12bd066
fixed testing issues for log-sigmoid
mosesdaudu001 Aug 18, 2023
3819218
added jax-like function for elu
mosesdaudu001 Aug 18, 2023
22ed2a6
modified alpha within the tests
mosesdaudu001 Aug 19, 2023
154f1f4
added complex_mode argument to all backends as suggested by Joe
mosesdaudu001 Aug 19, 2023
1317aeb
added jax-like function for selu
mosesdaudu001 Aug 19, 2023
0a278a9
Merge branch 'main' into complex_input_activation_functions
mosesdaudu001 Aug 19, 2023
55ef1c4
added jax-like fucntion for relu6
mosesdaudu001 Aug 19, 2023
0d785a7
fixed relu6 jax-like function
mosesdaudu001 Aug 24, 2023
32ce676
added jax-like function for hardswish
mosesdaudu001 Aug 25, 2023
86e2a88
added support for complex dtpes in hardwish tf and torch
mosesdaudu001 Aug 25, 2023
6a92760
added sigmoid jax-like function
mosesdaudu001 Aug 25, 2023
c0a70ba
removed complex from unsupported for sigmoid
mosesdaudu001 Aug 25, 2023
e1163ea
removed complex from unsupported for selu and updated jax-like function
mosesdaudu001 Aug 25, 2023
47d0231
removed complex from unsupported for silu
mosesdaudu001 Aug 25, 2023
f472575
removed complex from unsupported for elu
mosesdaudu001 Aug 25, 2023
c445b97
updated some backends as requested by Joe
mosesdaudu001 Aug 25, 2023
d48fe89
Merge branch 'main' into complex_input_activation_functions
mosesdaudu001 Aug 26, 2023
6747831
created custom implementation for tensorflow logsigmoid
mosesdaudu001 Aug 26, 2023
e35a406
added support for complex in tensorflow elu and selu
mosesdaudu001 Aug 26, 2023
43e1013
made changes to elu-jax-like function so as to extract the condition …
mosesdaudu001 Aug 29, 2023
5719e7e
Merge branch 'main' into complex_input_activation_functions
mosesdaudu001 Sep 1, 2023
a4df5a9
updated files
mosesdaudu001 Sep 1, 2023
b5fd995
removed duplicated file
mosesdaudu001 Sep 1, 2023
47abd9a
Merge branch 'master' into moses-complex
jshepherd01 Sep 6, 2023
e59e784
fix: silu and sigmoid now work in paddle, remove jax_like
jshepherd01 Sep 6, 2023
4bb0291
fix: update supported dtypes
jshepherd01 Sep 6, 2023
b7a0daf
style: update docstrings and function signatures for complex_mode
jshepherd01 Sep 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions ivy/data_classes/array/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,13 @@ def gelu(
"""
return ivy.gelu(self._data, approximate=approximate, out=out)

def sigmoid(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def sigmoid(
self: ivy.Array,
/,
*,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.sigmoid.

Expand All @@ -137,6 +143,9 @@ def sigmoid(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array
out
optional output array for writing the result to. It must have the same shape
the input broadcast to default: None
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -151,7 +160,7 @@ def sigmoid(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array
>>> print(y)
ivy.array([0.269, 0.731, 0.881])
"""
return ivy.sigmoid(self._data, out=out)
return ivy.sigmoid(self._data, out=out, complex_mode=complex_mode)

def softmax(
self: ivy.Array,
Expand Down Expand Up @@ -300,7 +309,13 @@ def mish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
"""
return ivy.mish(self._data, out=out)

def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def hardswish(
self: ivy.Array,
/,
*,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
Apply the hardswish activation function element-wise.

Expand All @@ -311,6 +326,9 @@ def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Arr
out
optional output array, for writing the result to. It must have
a shape that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -336,4 +354,4 @@ def hardswish(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Arr
b: ivy.array([0., 5.])
}
"""
return ivy.hardswish(self._data, out=out)
return ivy.hardswish(self._data, out=out, complex_mode=complex_mode)
53 changes: 44 additions & 9 deletions ivy/data_classes/array/experimental/activations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# global
import abc
from typing import Optional, Union
from typing import Optional, Union, Literal

# local
import ivy
Expand Down Expand Up @@ -115,7 +115,13 @@ def prelu(
"""
return ivy.prelu(self._data, slope, out=out)

def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def relu6(
self,
/,
*,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
Apply the rectified linear unit 6 function element-wise.

Expand All @@ -126,6 +132,9 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
out
optional output array, for writing the result to.
It must have a shape that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -148,10 +157,11 @@ def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
>>> print(y)
ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])
"""
return ivy.relu6(self._data, out=out)
return ivy.relu6(self._data, out=out, complex_mode=complex_mode)

def logsigmoid(
self: ivy.Array,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.logsigmoid. This method simply wraps
Expand All @@ -162,6 +172,9 @@ def logsigmoid(
----------
self
Input array.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -179,9 +192,15 @@ def logsigmoid(
>>> print(z)
ivy.array([-2.57888985, -0.31326169, -0.69314718, -0.01104775])
"""
return ivy.logsigmoid(self._data)
return ivy.logsigmoid(self._data, complex_mode=complex_mode)

def selu(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def selu(
self,
/,
*,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
Apply the scaled exponential linear unit function element-wise.

Expand All @@ -192,6 +211,9 @@ def selu(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
out
optional output array, for writing the result to.
It must have a shape that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -216,9 +238,15 @@ def selu(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
ivy.array([-1.11133075, 0., 1.05070102, 2.10140204, 3.15210295,
4.20280409, 5.25350523, 6.30420589, 7.35490704])
"""
return ivy.selu(self._data, out=out)
return ivy.selu(self._data, out=out, complex_mode=complex_mode)

def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def silu(
self: ivy.Array,
/,
*,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.silu. This method simply wraps the
function, and so the docstring for ivy.silu also applies to this method with
Expand All @@ -231,6 +259,9 @@ def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Examples
--------
Expand All @@ -239,14 +270,15 @@ def silu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
>>> print(y)
ivy.array([-0.26894143, 0. , 0.73105854])
"""
return ivy.silu(self._data, out=out)
return ivy.silu(self._data, out=out, complex_mode=complex_mode)

def elu(
self,
/,
*,
alpha: float = 1.0,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
Ivy.Array instance method variant of ivy.elu. This method simply wraps the
Expand All @@ -262,6 +294,9 @@ def elu(
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -275,4 +310,4 @@ def elu(
>>> print(y)
ivy.array([ 0.39, -0.57])
"""
return ivy.elu(self._data, alpha=alpha, out=out)
return ivy.elu(self._data, alpha=alpha, out=out, complex_mode=complex_mode)
20 changes: 20 additions & 0 deletions ivy/data_classes/container/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def _static_sigmoid(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.sigmoid. This method simply wraps the
Expand All @@ -438,6 +439,9 @@ def _static_sigmoid(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -462,6 +466,7 @@ def _static_sigmoid(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

def sigmoid(
Expand All @@ -473,6 +478,7 @@ def sigmoid(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.sigmoid. This method simply wraps
Expand All @@ -497,6 +503,9 @@ def sigmoid(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -520,6 +529,7 @@ def sigmoid(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

@staticmethod
Expand Down Expand Up @@ -1065,6 +1075,7 @@ def _static_hardswish(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.hardswish. This method simply wraps
Expand All @@ -1089,6 +1100,9 @@ def _static_hardswish(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -1114,6 +1128,7 @@ def _static_hardswish(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

def hardswish(
Expand All @@ -1125,6 +1140,7 @@ def hardswish(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.hardswish. This method simply wraps
Expand All @@ -1149,6 +1165,9 @@ def hardswish(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -1173,4 +1192,5 @@ def hardswish(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)
Loading
Loading