Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

complex dtypes support for activation functions #21805

Merged
merged 24 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
89d6de3
Added softplus activation function support for complex dtypes
Aug 13, 2023
7501ae5
errors fixed
Aug 14, 2023
47aa912
Merge branch 'unifyai:main' into main
mohame54 Aug 14, 2023
28723e4
Merge branch 'main' of https://github.com/mohame54/ivy
Aug 14, 2023
62e95a6
Merge branch 'unifyai:main' into main
mohame54 Aug 15, 2023
b0a6e4f
Merge branch 'unifyai:main' into main
mohame54 Aug 15, 2023
facfc54
Merge branch 'unifyai:main' into main
mohame54 Aug 15, 2023
a7ab849
Merge branch 'unifyai:main' into main
mohame54 Aug 16, 2023
258d4c5
Merge branch 'unifyai:main' into main
mohame54 Aug 16, 2023
87c63a2
Merge branch 'unifyai:main' into main
mohame54 Aug 17, 2023
5a8dc2d
Merge branch 'unifyai:main' into main
mohame54 Aug 17, 2023
0bf323f
Made softplus activation function support complex dtype
Aug 18, 2023
10c1f1a
PR refactored
Aug 18, 2023
be21aba
refactored some activation functions' docstrings like relu, leaky_…
Aug 18, 2023
b742d30
conflicts resolved
Aug 18, 2023
6060619
lint error fixed
Aug 18, 2023
25d5606
Merge branch 'unifyai:main' into main
mohame54 Aug 18, 2023
b6a72ab
Merge branch 'main' of https://github.com/mohame54/ivy
Aug 18, 2023
dfc9613
Merge branch 'unifyai:main' into main
mohame54 Aug 19, 2023
bbad132
Merge branch 'unifyai:main' into main
mohame54 Aug 20, 2023
b81f2da
Quick fix to the docstring of `ivy.softplus`
jshepherd01 Aug 21, 2023
eaeaa13
refactored threshold parameter and its docstring
Aug 21, 2023
fe5a626
Merge branch 'unifyai:main' into main
mohame54 Aug 22, 2023
a5283e3
chore: add complex_mode and fix supported dtypes in backends
jshepherd01 Aug 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ivy/data_classes/array/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def softplus(
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.softplus. This method simply wraps the
Expand All @@ -212,6 +213,8 @@ def softplus(
the threshold parameter of the softplus function.
out
optional output array, for writing the result to. It must have a shape
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.


Returns
-------
Expand All @@ -235,7 +238,7 @@ def softplus(
>>> print(x)
ivy.array([1.55, 2.13, 2.13])
"""
return ivy.softplus(self._data, beta=beta, threshold=threshold, out=out)
return ivy.softplus(self._data, beta=beta, threshold=threshold, out=out, complex_mode=complex_mode)

def log_softmax(
self: ivy.Array,
Expand Down
8 changes: 8 additions & 0 deletions ivy/data_classes/container/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ def _static_softplus(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.softplus. This method simply wraps
Expand Down Expand Up @@ -688,6 +689,8 @@ def _static_softplus(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.


Returns
-------
Expand Down Expand Up @@ -721,6 +724,7 @@ def _static_softplus(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

def softplus(
Expand All @@ -734,6 +738,7 @@ def softplus(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.softplus. This method simply wraps
Expand Down Expand Up @@ -762,6 +767,8 @@ def softplus(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.


Returns
-------
Expand Down Expand Up @@ -793,6 +800,7 @@ def softplus(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

@staticmethod
Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/backends/paddle/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,9 @@ def square(
return paddle.square(x)
return paddle_backend.pow(x, 2).astype(x.dtype)


@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
)
def pow(
x1: Union[float, paddle.Tensor],
x2: Union[float, paddle.Tensor],
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/frontends/jax/nn/non_linear_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def softmax(x, axis=-1, where=None, initial=None):
@to_ivy_arrays_and_back
def softplus(x):
x = _type_conversion(x)
return ivy.softplus(x).astype(x.dtype)
return ivy.softplus(x, complex_mode="jax").astype(x.dtype)


@to_ivy_arrays_and_back
Expand Down
52 changes: 52 additions & 0 deletions ivy/functional/ivy/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,47 @@ def softmax(
return current_backend(x).softmax(x, axis=axis, out=out)


def _wrap_between(y, a):
    """
    Wrap the values of ``y`` into the window spanning ``-a`` to ``a``.

    Parameters
    ----------
    y
        array whose values are wrapped; its dtype is used for every
        intermediate constant.
    a
        half-width of the window (e.g. ``ivy.pi`` when wrapping a phase).

    Returns
    -------
    ret
        ``y`` shifted by whole multiples of ``2 * a`` so that it lies in the
        window around zero.
        NOTE(review): for a positive ``a`` the upper bound itself appears to
        map to ``-a`` (half-open window) -- confirm against ``ivy.remainder``.
    """
    dtype = y.dtype
    half_width = ivy.array(a, dtype=dtype)
    full_width = ivy.array(2 * half_width, dtype=dtype)
    zero = ivy.array(0, dtype=dtype)

    # Shift so the window starts at zero, then reduce modulo the window width.
    shifted = ivy.remainder(ivy.add(y, half_width), full_width)
    # Lift any negative remainder back into the non-negative range so the
    # result does not depend on the sign convention of the remainder.
    non_negative = ivy.where(shifted < zero, shifted + full_width, shifted)
    # Undo the initial shift to re-centre the window on zero.
    return non_negative - half_width


def _softplus_jax_like(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
fn_original=None,
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[ivy.Array] = None,
):
if beta is not None:
x_beta = ivy.multiply(x, ivy.array(beta, dtype=x.dtype))
else:
x_beta = x
amax = ivy.relu(x_beta)
res = ivy.subtract(x_beta, ivy.multiply(amax, ivy.array(2, dtype=x.dtype)))
res = ivy.add(amax, ivy.log(ivy.add(1, ivy.exp(res))))
res = ivy.real(res) + _wrap_between(ivy.imag(res), ivy.pi).astype(x.dtype) * ivy.astype(1j, x.dtype)
if beta is not None:
res = ivy.divide(res, ivy.array(beta, dtype=x.dtype))
if threshold is not None:
res = ivy.where(
(
ivy.logical_or(
ivy.real(x_beta) < 0, ivy.logical_and(ivy.real(x_beta) == 0, ivy.imag(x_beta) < 0)
)
),
res,
x,
).astype(x.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The actual value of threshold doesn't seem to be used here

return res

@handle_exceptions
@handle_backend_invalid
@handle_nestable
Expand All @@ -506,17 +547,22 @@ def softmax(
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_complex_input
def softplus(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
Apply the softplus function element-wise.

If the input is complex, then by default we apply `log(1 + exp(x))` to each element.
This behaviour can be changed by specifying a different `complex_mode`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If the input is complex, then by default we apply `log(1 + exp(x))` to each element.
This behaviour can be changed by specifying a different `complex_mode`.

This behaviour is the same as with real numbers, so you could probably drop this part from the docstring - I don't think it needs to be explicitly spelled out, and we already have a description of what the complex_mode parameter does further down

Parameters
----------
x
Expand All @@ -528,6 +574,9 @@ def softplus(
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
`ivy.func_wrapper.handle_complex_input` for more detail.
jshepherd01 marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Expand Down Expand Up @@ -557,6 +606,9 @@ def softplus(
return current_backend(x).softplus(x, beta=beta, threshold=threshold, out=out)


softplus.jax_like = _softplus_jax_like
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since softplus is well defined on the complex numbers, I'd suggest that unless you're including threshold it would be better handled with changes to the backend functions rather than with a jax_like function here



@handle_exceptions
@handle_backend_invalid
@handle_nestable
Expand Down
24 changes: 21 additions & 3 deletions ivy/stateful/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,22 @@ def _forward(self, x, *, axis=None):


class Softplus(Module):
def __init__(self):
def __init__(
self,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
):
"""Apply the SOFTPLUS activation function."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Apply the SOFTPLUS activation function."""
"""
Apply the SOFTPLUS activation function.
Parameters
----------
complex_mode
Specifies how to handle complex input. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
"""

Parameters should be included in docstrings

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, beta and threshold should really be in the __init__ part rather than _forward as well but that's beside the point of this PR

Module.__init__(self)
self._complex_mode = complex_mode

def _forward(self, x, *, beta=None, threshold=None):
def _forward(
self,
x,
*,
beta=None,
threshold=None,
complex_mode=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode=None,

):
"""

Parameters
Expand All @@ -201,14 +212,21 @@ def _forward(self, x, *, beta=None, threshold=None):

threshold
values above this revert to a linear function. Default: ``None``.
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.

We no longer use optional arguments in the forward method for stateful components, only in __init__


Returns
-------
ret
The outputs following the SOFTPLUS activation *[batch_shape, d]*

"""
return ivy.softplus(x, beta=beta, threshold=threshold)
return ivy.softplus(
x,
beta=beta,
threshold=threshold,
complex_mode=ivy.default(complex_mode, self._complex_mode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode=ivy.default(complex_mode, self._complex_mode)
complex_mode=self._complex_mode,

)


class Mish(Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,10 @@ def test_jax_softmax(
@handle_frontend_test(
fn_tree="jax.nn.softplus",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float_and_integer"),
large_abs_safety_factor=2,
small_abs_safety_factor=2,
safety_factor_scale="linear",
available_dtypes=helpers.get_dtypes("float_and_complex"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
available_dtypes=helpers.get_dtypes("float_and_complex"),
available_dtypes=helpers.get_dtypes("numeric"),

"float_and_complex" doesn't include integers, so we should use "numeric" (or "valid") instead

large_abs_safety_factor=4,
small_abs_safety_factor=4,
safety_factor_scale="log",
),
test_with_out=st.just(False),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ def test_softmax(*, dtype_and_x, axis, test_flags, backend_fw, fn_name, on_devic
@handle_test(
fn_tree="functional.ivy.softplus",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
available_dtypes=helpers.get_dtypes("numeric"),
min_num_dims=1,
large_abs_safety_factor=8,
small_abs_safety_factor=8,
large_abs_safety_factor=4,
small_abs_safety_factor=4,
safety_factor_scale="log",
),
beta=st.one_of(helpers.number(min_value=0.1, max_value=10), st.none()),
Expand Down