Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

complex dtypes support for activation functions #21805

Merged
merged 24 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
89d6de3
Added softplus activation function support for complex dtypes
Aug 13, 2023
7501ae5
errors fixed
Aug 14, 2023
47aa912
Merge branch 'unifyai:main' into main
mohame54 Aug 14, 2023
28723e4
Merge branch 'main' of https://github.com/mohame54/ivy
Aug 14, 2023
62e95a6
Merge branch 'unifyai:main' into main
mohame54 Aug 15, 2023
b0a6e4f
Merge branch 'unifyai:main' into main
mohame54 Aug 15, 2023
facfc54
Merge branch 'unifyai:main' into main
mohame54 Aug 15, 2023
a7ab849
Merge branch 'unifyai:main' into main
mohame54 Aug 16, 2023
258d4c5
Merge branch 'unifyai:main' into main
mohame54 Aug 16, 2023
87c63a2
Merge branch 'unifyai:main' into main
mohame54 Aug 17, 2023
5a8dc2d
Merge branch 'unifyai:main' into main
mohame54 Aug 17, 2023
0bf323f
Made softplus activation function support complex dtype
Aug 18, 2023
10c1f1a
PR refactored
Aug 18, 2023
be21aba
refactored the some activation functions doc string like relu, leaky_…
Aug 18, 2023
b742d30
conflicts resolved
Aug 18, 2023
6060619
lint error fixed
Aug 18, 2023
25d5606
Merge branch 'unifyai:main' into main
mohame54 Aug 18, 2023
b6a72ab
Merge branch 'main' of https://github.com/mohame54/ivy
Aug 18, 2023
dfc9613
Merge branch 'unifyai:main' into main
mohame54 Aug 19, 2023
bbad132
Merge branch 'unifyai:main' into main
mohame54 Aug 20, 2023
b81f2da
Quick fix to the docstring of `ivy.softplus`
jshepherd01 Aug 21, 2023
eaeaa13
refactored threshold parameter and it's doc string
Aug 21, 2023
fe5a626
Merge branch 'unifyai:main' into main
mohame54 Aug 22, 2023
a5283e3
chore: add complex_mode and fix supported dtypes in backends
jshepherd01 Aug 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions ivy/data_classes/array/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def relu(
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -69,7 +70,8 @@ def leaky_relu(
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -93,6 +95,7 @@ def gelu(
*,
approximate: bool = False,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.gelu. This method simply wraps the
Expand All @@ -108,6 +111,9 @@ def gelu(
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -196,6 +202,7 @@ def softplus(
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.softplus. This method simply wraps the
Expand All @@ -212,6 +219,9 @@ def softplus(
the threshold parameter of the softplus function.
out
optional output array, for writing the result to. It must have a shape
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand All @@ -235,7 +245,13 @@ def softplus(
>>> print(x)
ivy.array([1.55, 2.13, 2.13])
"""
return ivy.softplus(self._data, beta=beta, threshold=threshold, out=out)
return ivy.softplus(
self._data,
beta=beta,
threshold=threshold,
out=out,
complex_mode=complex_mode,
)

def log_softmax(
self: ivy.Array,
Expand Down
28 changes: 22 additions & 6 deletions ivy/data_classes/container/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def _static_relu(
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -109,7 +110,8 @@ def relu(
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -176,7 +178,8 @@ def _static_leaky_relu(
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -243,7 +246,8 @@ def leaky_relu(
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -310,7 +314,8 @@ def _static_gelu(
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -376,7 +381,8 @@ def gelu(
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -660,6 +666,7 @@ def _static_softplus(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.softplus. This method simply wraps
Expand Down Expand Up @@ -688,6 +695,9 @@ def _static_softplus(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -721,6 +731,7 @@ def _static_softplus(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

def softplus(
Expand All @@ -734,6 +745,7 @@ def softplus(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.softplus. This method simply wraps
Expand Down Expand Up @@ -762,6 +774,9 @@ def softplus(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -793,6 +808,7 @@ def softplus(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions ivy/functional/backends/jax/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def softplus(
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[JaxArray] = None,
complex_mode="jax",
) -> JaxArray:
if beta is not None and beta != 1:
x_beta = x * beta
Expand Down
1 change: 1 addition & 0 deletions ivy/functional/backends/mxnet/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def softplus(
beta: Optional[Union[(int, float)]] = None,
threshold: Optional[Union[(int, float)]] = None,
out: Optional[None] = None,
complex_mode="jax",
) -> None:
if beta is not None and beta != 1:
x_beta = x * beta
Expand Down
1 change: 1 addition & 0 deletions ivy/functional/backends/numpy/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def softplus(
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[np.ndarray] = None,
complex_mode="jax",
) -> np.ndarray:
if beta is not None and beta != 1:
x_beta = x * beta
Expand Down
1 change: 1 addition & 0 deletions ivy/functional/backends/paddle/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def softplus(
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[paddle.Tensor] = None,
complex_mode="jax",
) -> paddle.Tensor:
if beta is not None and beta != 1:
x_beta = x * beta
Expand Down
3 changes: 3 additions & 0 deletions ivy/functional/backends/paddle/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,9 @@ def square(
return paddle_backend.pow(x, 2).astype(x.dtype)


@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
)
def pow(
x1: Union[float, paddle.Tensor],
x2: Union[float, paddle.Tensor],
Expand Down
13 changes: 12 additions & 1 deletion ivy/functional/backends/tensorflow/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,17 @@ def softmax(


@with_supported_dtypes(
{"2.13.0 and below": ("float16", "bfloat16", "float32", "float64")}, backend_version
{
"2.13.0 and below": (
"float16",
"bfloat16",
"float32",
"float64",
"complex64",
"complex128",
)
},
backend_version,
)
def softplus(
x: Tensor,
Expand All @@ -56,6 +66,7 @@ def softplus(
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[Tensor] = None,
complex_mode="jax",
) -> Tensor:
if beta is not None and beta != 1:
x_beta = x * beta
Expand Down
5 changes: 2 additions & 3 deletions ivy/functional/backends/torch/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,15 @@ def softmax(
return torch.nn.functional.softmax(x, axis)


@with_unsupported_dtypes(
{"2.0.1 and below": ("complex", "float16", "bfloat16")}, backend_version
)
@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, backend_version)
def softplus(
x: torch.Tensor,
/,
*,
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[torch.Tensor] = None,
complex_mode="jax",
) -> torch.Tensor:
kwargs = {
k: v for k, v in {"beta": beta, "threshold": threshold}.items() if v is not None
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/frontends/jax/nn/non_linear_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def softmax(x, axis=-1, where=None, initial=None):
@to_ivy_arrays_and_back
def softplus(x):
x = _type_conversion(x)
return ivy.softplus(x).astype(x.dtype)
return ivy.softplus(x, complex_mode="jax").astype(x.dtype)


@to_ivy_arrays_and_back
Expand Down
57 changes: 56 additions & 1 deletion ivy/functional/ivy/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,46 @@ def softmax(
return current_backend(x).softmax(x, axis=axis, out=out)


def _wrap_between(y, a):
"""Wrap y between [-a, a]"""
a = ivy.array(a, dtype=y.dtype)
a2 = ivy.array(2 * a, dtype=y.dtype)
zero = ivy.array(0, dtype=y.dtype)
rem = ivy.remainder(ivy.add(y, a), a2)
rem = ivy.where(rem < zero, rem + a2, rem) - a
return rem


def _softplus_jax_like(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
fn_original=None,
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[ivy.Array] = None,
):
if beta is not None:
x_beta = ivy.multiply(x, ivy.array(beta, dtype=x.dtype))
else:
x_beta = x
amax = ivy.relu(x_beta)
res = ivy.subtract(x_beta, ivy.multiply(amax, ivy.array(2, dtype=x.dtype)))
res = ivy.add(amax, ivy.log(ivy.add(1, ivy.exp(res))))
res = ivy.real(res) + _wrap_between(ivy.imag(res), ivy.pi).astype(
x.dtype
) * ivy.astype(1j, x.dtype)
if beta is not None:
res = ivy.divide(res, ivy.array(beta, dtype=x.dtype))
if threshold is not None:
res = ivy.where(
ivy.real(x_beta) < threshold,
res,
x,
).astype(x.dtype)
return res
Comment on lines +501 to +538
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still not sure whether this should be here or in the backends so I'm going to request a review from @Ishticode for a second opinion



@handle_exceptions
@handle_backend_invalid
@handle_nestable
Expand All @@ -506,28 +546,40 @@ def softmax(
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_complex_input
def softplus(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
Apply the softplus function element-wise.

If the input is complex, then by default we apply the softplus operation
`log(1+ exp(x))` to each element
If threshold is set we check if either its real part is strictly negative or
if its real part is zero and its imaginary part is negative then we apply
`input×β > threshold`.

Comment on lines +562 to +567
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this docstring quite confusing, could you reword it a bit? It still seems to reference the number being negative.

Although, I suppose instead you could remove this and edit the description of the threshold parameter (since this just describes the standard way softplus works, and then describes special behaviour of threshold with complex inputs)

Parameters
----------
x
input array.
beta
The beta value for the softplus formation. Default: ``None``.
threshold
values above this revert to a linear function. Default: ``None``.
values above this revert to a linear function
If the input is complex, only its real part is considered. Default: ``None``
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Returns
-------
Expand Down Expand Up @@ -557,6 +609,9 @@ def softplus(
return current_backend(x).softplus(x, beta=beta, threshold=threshold, out=out)


softplus.jax_like = _softplus_jax_like
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since softplus is well defined on the complex numbers, I'd suggest that unless you're including threshold it would be better handled with changes to the backend functions rather than with a jax_like function here



@handle_exceptions
@handle_backend_invalid
@handle_nestable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,10 @@ def test_jax_softmax(
@handle_frontend_test(
fn_tree="jax.nn.softplus",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float_and_integer"),
large_abs_safety_factor=2,
small_abs_safety_factor=2,
safety_factor_scale="linear",
available_dtypes=helpers.get_dtypes("numeric"),
large_abs_safety_factor=4,
small_abs_safety_factor=4,
safety_factor_scale="log",
),
test_with_out=st.just(False),
)
Expand Down
Loading
Loading