Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

complex dtypes support for activation functions #21805

Merged
merged 24 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
89d6de3
Added softplus activation function support for complex dtypes
Aug 13, 2023
7501ae5
errors fixed
Aug 14, 2023
47aa912
Merge branch 'unifyai:main' into main
mohame54 Aug 14, 2023
28723e4
Merge branch 'main' of https://github.com/mohame54/ivy
Aug 14, 2023
62e95a6
Merge branch 'unifyai:main' into main
mohame54 Aug 15, 2023
b0a6e4f
Merge branch 'unifyai:main' into main
mohame54 Aug 15, 2023
facfc54
Merge branch 'unifyai:main' into main
mohame54 Aug 15, 2023
a7ab849
Merge branch 'unifyai:main' into main
mohame54 Aug 16, 2023
258d4c5
Merge branch 'unifyai:main' into main
mohame54 Aug 16, 2023
87c63a2
Merge branch 'unifyai:main' into main
mohame54 Aug 17, 2023
5a8dc2d
Merge branch 'unifyai:main' into main
mohame54 Aug 17, 2023
0bf323f
Made softplus activation function support complex dtype
Aug 18, 2023
10c1f1a
PR refactored
Aug 18, 2023
be21aba
refactored the some activation functions doc string like relu, leaky_…
Aug 18, 2023
b742d30
conflicts resolved
Aug 18, 2023
6060619
lint error fixed
Aug 18, 2023
25d5606
Merge branch 'unifyai:main' into main
mohame54 Aug 18, 2023
b6a72ab
Merge branch 'main' of https://github.com/mohame54/ivy
Aug 18, 2023
dfc9613
Merge branch 'unifyai:main' into main
mohame54 Aug 19, 2023
bbad132
Merge branch 'unifyai:main' into main
mohame54 Aug 20, 2023
b81f2da
Quick fix to the docstring of `ivy.softplus`
jshepherd01 Aug 21, 2023
eaeaa13
refactored threshold parameter and it's doc string
Aug 21, 2023
fe5a626
Merge branch 'unifyai:main' into main
mohame54 Aug 22, 2023
a5283e3
chore: add complex_mode and fix supported dtypes in backends
jshepherd01 Aug 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ivy/data_classes/array/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def softplus(
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.softplus. This method simply wraps the
Expand All @@ -212,6 +213,8 @@ def softplus(
the threshold parameter of the softplus function.
out
optional output array, for writing the result to. It must have a shape
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.


Returns
-------
Expand All @@ -235,7 +238,7 @@ def softplus(
>>> print(x)
ivy.array([1.55, 2.13, 2.13])
"""
return ivy.softplus(self._data, beta=beta, threshold=threshold, out=out)
return ivy.softplus(self._data, beta=beta, threshold=threshold, out=out, complex_mode=complex_mode)

def log_softmax(
self: ivy.Array,
Expand Down
8 changes: 8 additions & 0 deletions ivy/data_classes/container/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ def _static_softplus(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.softplus. This method simply wraps
Expand Down Expand Up @@ -688,6 +689,8 @@ def _static_softplus(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.


Returns
-------
Expand Down Expand Up @@ -721,6 +724,7 @@ def _static_softplus(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

def softplus(
Expand All @@ -734,6 +738,7 @@ def softplus(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.softplus. This method simply wraps
Expand Down Expand Up @@ -762,6 +767,8 @@ def softplus(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.


Returns
-------
Expand Down Expand Up @@ -793,6 +800,7 @@ def softplus(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/frontends/jax/nn/non_linear_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def softmax(x, axis=-1, where=None, initial=None):
@to_ivy_arrays_and_back
def softplus(x):
x = _type_conversion(x)
return ivy.softplus(x).astype(x.dtype)
return ivy.softplus(x, complex_mode="jax").astype(x.dtype)


@to_ivy_arrays_and_back
Expand Down
39 changes: 39 additions & 0 deletions ivy/functional/ivy/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,32 @@ def softmax(
return current_backend(x).softmax(x, axis=axis, out=out)


def _wrap_between(y, a):
    """
    Wrap the values of ``y`` into the interval ``[-a, a]``.

    The input is shifted up by ``a``, reduced modulo ``2 * a``, any negative
    remainder is lifted by one full period, and the result is shifted back
    down by ``a``.

    Parameters
    ----------
    y
        array whose values are wrapped; its dtype is used for every constant
        created here.
    a
        half-width of the target interval.

    Returns
    -------
    ret
        the wrapped values.
    """
    dtype = y.dtype
    half_width = ivy.array(a, dtype=dtype)
    full_width = ivy.array(2 * half_width, dtype=dtype)
    shifted = ivy.add(y, half_width)
    wrapped = ivy.remainder(shifted, full_width)
    # Guard for a remainder that comes back negative: lift it by one period.
    is_negative = wrapped < ivy.array(0, dtype=dtype)
    wrapped = ivy.where(is_negative, wrapped + full_width, wrapped)
    return wrapped - half_width


def _softplus_jax_like(
    x: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    fn_original=None,
    beta: Optional[Union[int, float]] = None,
    threshold: Optional[Union[int, float]] = None,
    out: Optional[ivy.Array] = None,
):
    """
    Compute softplus for the ``"jax"`` complex mode.

    Evaluates ``(1 / beta) * log(1 + exp(beta * x))`` in the rearranged form
    ``amax + log(1 + exp(z - 2 * amax))`` with ``z = beta * x`` and
    ``amax = relu(z)``. The imaginary part of the logarithm result is then
    wrapped into ``[-pi, pi]``.

    Parameters
    ----------
    x
        input array.
    fn_original
        accepted for signature compatibility; not used by this function.
    beta
        scale applied to the input before the logarithm and divided out of
        the result afterwards. ``None`` means no scaling.
    threshold
        where the real part of ``beta * x`` is not below this value, the
        original input ``x`` is returned instead of the softplus value.
        ``None`` disables the fallback.
    out
        accepted for signature compatibility; not used by this function.

    Returns
    -------
    ret
        the softplus of ``x``.
    """
    if beta is not None:
        x_beta = ivy.multiply(x, ivy.array(beta, dtype=x.dtype))
    else:
        x_beta = x
    amax = ivy.relu(x_beta)
    # For real inputs, z - 2 * relu(z) == -|z|, so exp() never sees a large
    # positive argument.
    res = ivy.subtract(x_beta, ivy.multiply(amax, ivy.array(2, dtype=x.dtype)))
    res = ivy.add(amax, ivy.log(ivy.add(1, ivy.exp(res))))
    wrapped_imag = _wrap_between(ivy.imag(res), ivy.pi).astype(x.dtype)
    res = ivy.real(res) + wrapped_imag * ivy.astype(1j, x.dtype)
    if beta is not None:
        res = ivy.divide(res, ivy.array(beta, dtype=x.dtype))
    if threshold is not None:
        # Complex numbers have no ordering, so the comparison is made on the
        # real part of the scaled input.
        res = ivy.where(ivy.real(x_beta) < threshold, res, x).astype(x.dtype)
    return res
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I understand this implementation here, I thought softplus was simply ivy.log(ivy.add(1., ivy.exp(x))). I also can't see a beta or threshold in here, and I think we should implement at least beta for supersetting, and I think you've left a print statement in by mistake

Copy link
Contributor Author

@mohame54 mohame54 Aug 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, my bad, I forgot to delete the print statement. As for softplus(x) = ivy.log(ivy.add(1., ivy.exp(x))): each backend has its own implementation of softplus, and some other frameworks don't support complex dtypes.

Copy link
Contributor Author

@mohame54 mohame54 Aug 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But torch is going to make all the activation functions support complex dtypes, which is why I didn't need to implement a special case for the softplus op in torch.


@handle_exceptions
@handle_backend_invalid
@handle_nestable
Expand All @@ -506,17 +532,24 @@ def softmax(
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_complex_input
def softplus(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
Apply the softplus function element-wise.

If the input is complex, then by default we apply the `log(1+ exp(x))` to each element
and if the `threshold` is set we check whether the output is less than the threshold
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"less than" isn't well defined for complex numbers, so you might want to be more specific here (if you decide to use threshold in the complex version at all - it doesn't seem like you have)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay on it.

we set it to the original value of the input if not we leave as it is.
This behaviour can be changed by specifying a different `complex_mode`.

Parameters
----------
x
Expand All @@ -528,6 +561,9 @@ def softplus(
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types. See
`ivy.func_wrapper.handle_complex_input` for more detail.
jshepherd01 marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
Expand Down Expand Up @@ -557,6 +593,9 @@ def softplus(
return current_backend(x).softplus(x, beta=beta, threshold=threshold, out=out)


softplus.jax_like = _softplus_jax_like
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since softplus is well defined on the complex numbers, I'd suggest that unless you're including threshold it would be better handled with changes to the backend functions rather than with a jax_like function here



@handle_exceptions
@handle_backend_invalid
@handle_nestable
Expand Down
24 changes: 21 additions & 3 deletions ivy/stateful/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,22 @@ def _forward(self, x, *, axis=None):


class Softplus(Module):
def __init__(self):
def __init__(
self,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
):
"""Apply the SOFTPLUS activation function."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Apply the SOFTPLUS activation function."""
"""
Apply the SOFTPLUS activation function.
Parameters
----------
complex_mode
Specifies how to handle complex input. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
"""

Parameters should be included in docstrings

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, beta and threshold should really be in the __init__ part rather than _forward as well but that's beside the point of this PR

Module.__init__(self)
self._complex_mode = complex_mode

def _forward(self, x, *, beta=None, threshold=None):
def _forward(
self,
x,
*,
beta=None,
threshold=None,
complex_mode=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode=None,

):
"""

Parameters
Expand All @@ -201,14 +212,21 @@ def _forward(self, x, *, beta=None, threshold=None):

threshold
values above this revert to a linear function. Default: ``None``.
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.

We no longer use optional arguments in the forward method for stateful components, only in __init__


Returns
-------
ret
The outputs following the SOFTPLUS activation *[batch_shape, d]*

"""
return ivy.softplus(x, beta=beta, threshold=threshold)
return ivy.softplus(
x,
beta=beta,
threshold=threshold,
complex_mode=ivy.default(complex_mode, self._complex_mode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode=ivy.default(complex_mode, self._complex_mode)
complex_mode=self._complex_mode,

)


class Mish(Module):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,10 @@ def test_jax_softmax(
@handle_frontend_test(
fn_tree="jax.nn.softplus",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float_and_integer"),
large_abs_safety_factor=2,
small_abs_safety_factor=2,
safety_factor_scale="linear",
available_dtypes=helpers.get_dtypes("float_and_complex"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
available_dtypes=helpers.get_dtypes("float_and_complex"),
available_dtypes=helpers.get_dtypes("numeric"),

"float_and_complex" doesn't include integers, so we should use "numeric" (or "valid") instead

large_abs_safety_factor=4,
small_abs_safety_factor=4,
safety_factor_scale="log",
),
test_with_out=st.just(False),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ def test_softmax(*, dtype_and_x, axis, test_flags, backend_fw, fn_name, on_devic
@handle_test(
fn_tree="functional.ivy.softplus",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
available_dtypes=helpers.get_dtypes("numeric"),
min_num_dims=1,
large_abs_safety_factor=8,
small_abs_safety_factor=8,
large_abs_safety_factor=4,
small_abs_safety_factor=4,
safety_factor_scale="log",
),
beta=st.one_of(helpers.number(min_value=0.1, max_value=10), st.none()),
Expand Down
Loading