-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
complex dtypes support for activation functions #21805
Changes from 12 commits
89d6de3
7501ae5
47aa912
28723e4
62e95a6
b0a6e4f
facfc54
a7ab849
258d4c5
87c63a2
5a8dc2d
0bf323f
10c1f1a
be21aba
b742d30
6060619
25d5606
b6a72ab
dfc9613
bbad132
b81f2da
eaeaa13
fe5a626
a5283e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -660,6 +660,7 @@ def _static_softplus( | |||||||||||
prune_unapplied: Union[bool, ivy.Container] = False, | ||||||||||||
map_sequences: Union[bool, ivy.Container] = False, | ||||||||||||
out: Optional[ivy.Container] = None, | ||||||||||||
complex_mode: Literal["split", "magnitude", "jax"] = "jax", | ||||||||||||
) -> ivy.Container: | ||||||||||||
""" | ||||||||||||
ivy.Container static method variant of ivy.softplus. This method simply wraps | ||||||||||||
|
@@ -688,6 +689,8 @@ def _static_softplus( | |||||||||||
out | ||||||||||||
optional output container, for writing the result to. It must have a shape | ||||||||||||
that the inputs broadcast to. | ||||||||||||
complex_mode | ||||||||||||
optional specifier for how to handle complex data types. See | ||||||||||||
`ivy.func_wrapper.handle_complex_input` for more detail. | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
||||||||||||
Returns | ||||||||||||
------- | ||||||||||||
|
@@ -721,6 +724,7 @@ def _static_softplus( | |||||||||||
prune_unapplied=prune_unapplied, | ||||||||||||
map_sequences=map_sequences, | ||||||||||||
out=out, | ||||||||||||
complex_mode=complex_mode, | ||||||||||||
) | ||||||||||||
|
||||||||||||
def softplus( | ||||||||||||
|
@@ -734,6 +738,7 @@ def softplus( | |||||||||||
prune_unapplied: Union[bool, ivy.Container] = False, | ||||||||||||
map_sequences: Union[bool, ivy.Container] = False, | ||||||||||||
out: Optional[ivy.Container] = None, | ||||||||||||
complex_mode: Literal["split", "magnitude", "jax"] = "jax", | ||||||||||||
) -> ivy.Container: | ||||||||||||
""" | ||||||||||||
ivy.Container instance method variant of ivy.softplus. This method simply wraps | ||||||||||||
|
@@ -762,6 +767,8 @@ def softplus( | |||||||||||
out | ||||||||||||
optional output container, for writing the result to. It must have a shape | ||||||||||||
that the inputs broadcast to. | ||||||||||||
complex_mode | ||||||||||||
optional specifier for how to handle complex data types. See | ||||||||||||
`ivy.func_wrapper.handle_complex_input` for more detail. | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
||||||||||||
Returns | ||||||||||||
------- | ||||||||||||
|
@@ -793,6 +800,7 @@ def softplus( | |||||||||||
prune_unapplied=prune_unapplied, | ||||||||||||
map_sequences=map_sequences, | ||||||||||||
out=out, | ||||||||||||
complex_mode=complex_mode, | ||||||||||||
) | ||||||||||||
|
||||||||||||
@staticmethod | ||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -498,6 +498,47 @@ def softmax( | |||||
return current_backend(x).softmax(x, axis=axis, out=out) | ||||||
|
||||||
|
||||||
def _wrap_between(y, a): | ||||||
"""Wrap y into the half-open interval [-a, a).""" | ||||||
a = ivy.array(a, dtype=y.dtype) | ||||||
a2 = ivy.array(2 * a, dtype=y.dtype) | ||||||
zero = ivy.array(0, dtype=y.dtype) | ||||||
rem = ivy.remainder(ivy.add(y, a), a2) | ||||||
rem = ivy.where(rem < zero, rem + a2, rem) - a | ||||||
return rem | ||||||
|
||||||
|
||||||
def _softplus_jax_like( | ||||||
x: Union[ivy.Array, ivy.NativeArray], | ||||||
/, | ||||||
*, | ||||||
fn_original=None, | ||||||
beta: Optional[Union[int, float]] = None, | ||||||
threshold: Optional[Union[int, float]] = None, | ||||||
out: Optional[ivy.Array] = None, | ||||||
): | ||||||
if beta is not None: | ||||||
x_beta = ivy.multiply(x, ivy.array(beta, dtype=x.dtype)) | ||||||
else: | ||||||
x_beta = x | ||||||
amax = ivy.relu(x_beta) | ||||||
res = ivy.subtract(x_beta, ivy.multiply(amax, ivy.array(2, dtype=x.dtype))) | ||||||
res = ivy.add(amax, ivy.log(ivy.add(1, ivy.exp(res)))) | ||||||
res = ivy.real(res) + _wrap_between(ivy.imag(res), ivy.pi).astype(x.dtype) * ivy.astype(1j, x.dtype) | ||||||
if beta is not None: | ||||||
res = ivy.divide(res, ivy.array(beta, dtype=x.dtype)) | ||||||
if threshold is not None: | ||||||
res = ivy.where( | ||||||
( | ||||||
ivy.logical_or( | ||||||
ivy.real(x_beta) < 0, ivy.logical_and(ivy.real(x_beta) == 0, ivy.imag(x_beta) < 0) | ||||||
) | ||||||
), | ||||||
res, | ||||||
x, | ||||||
).astype(x.dtype) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The actual value of |
||||||
return res | ||||||
|
||||||
@handle_exceptions | ||||||
@handle_backend_invalid | ||||||
@handle_nestable | ||||||
|
@@ -506,17 +547,22 @@ def softmax( | |||||
@to_native_arrays_and_back | ||||||
@handle_array_function | ||||||
@handle_device_shifting | ||||||
@handle_complex_input | ||||||
def softplus( | ||||||
x: Union[ivy.Array, ivy.NativeArray], | ||||||
/, | ||||||
*, | ||||||
beta: Optional[Union[int, float]] = None, | ||||||
threshold: Optional[Union[int, float]] = None, | ||||||
out: Optional[ivy.Array] = None, | ||||||
complex_mode: Literal["split", "magnitude", "jax"] = "jax", | ||||||
) -> ivy.Array: | ||||||
""" | ||||||
Apply the softplus function element-wise. | ||||||
|
||||||
If the input is complex, then by default we apply `log(1 + exp(x))` to each element. | ||||||
This behaviour can be changed by specifying a different `complex_mode`. | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
This behaviour is the same as with real numbers, so you could probably drop this part from the docstring - I don't think it needs to be explicitly spelled out, and we already have a description of what the |
||||||
Parameters | ||||||
---------- | ||||||
x | ||||||
|
@@ -528,6 +574,9 @@ def softplus( | |||||
out | ||||||
optional output array, for writing the result to. It must have a shape that the | ||||||
inputs broadcast to. | ||||||
complex_mode | ||||||
optional specifier for how to handle complex data types. See | ||||||
`ivy.func_wrapper.handle_complex_input` for more detail. | ||||||
jshepherd01 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
Returns | ||||||
------- | ||||||
|
@@ -557,6 +606,9 @@ def softplus( | |||||
return current_backend(x).softplus(x, beta=beta, threshold=threshold, out=out) | ||||||
|
||||||
|
||||||
softplus.jax_like = _softplus_jax_like | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since softplus is well defined on the complex numbers, I'd suggest that unless you're including |
||||||
|
||||||
|
||||||
@handle_exceptions | ||||||
@handle_backend_invalid | ||||||
@handle_nestable | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -185,11 +185,22 @@ def _forward(self, x, *, axis=None): | |||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
class Softplus(Module): | ||||||||||||||||||||||
def __init__(self): | ||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||
self, | ||||||||||||||||||||||
complex_mode: Literal["split", "magnitude", "jax"] = "jax", | ||||||||||||||||||||||
): | ||||||||||||||||||||||
"""Apply the SOFTPLUS activation function.""" | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Parameters should be included in docstrings. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually, |
||||||||||||||||||||||
Module.__init__(self) | ||||||||||||||||||||||
self._complex_mode = complex_mode | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _forward(self, x, *, beta=None, threshold=None): | ||||||||||||||||||||||
def _forward( | ||||||||||||||||||||||
self, | ||||||||||||||||||||||
x, | ||||||||||||||||||||||
*, | ||||||||||||||||||||||
beta=None, | ||||||||||||||||||||||
threshold=None, | ||||||||||||||||||||||
complex_mode=None, | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
): | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
|
||||||||||||||||||||||
Parameters | ||||||||||||||||||||||
|
@@ -201,14 +212,21 @@ def _forward(self, x, *, beta=None, threshold=None): | |||||||||||||||||||||
|
||||||||||||||||||||||
threshold | ||||||||||||||||||||||
values above this revert to a linear function. Default: ``None``. | ||||||||||||||||||||||
complex_mode | ||||||||||||||||||||||
optional specifier for how to handle complex data types. | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
We no longer use optional arguments in the forward method for stateful components, only in |
||||||||||||||||||||||
|
||||||||||||||||||||||
Returns | ||||||||||||||||||||||
------- | ||||||||||||||||||||||
ret | ||||||||||||||||||||||
The outputs following the SOFTPLUS activation *[batch_shape, d]* | ||||||||||||||||||||||
|
||||||||||||||||||||||
""" | ||||||||||||||||||||||
return ivy.softplus(x, beta=beta, threshold=threshold) | ||||||||||||||||||||||
return ivy.softplus( | ||||||||||||||||||||||
x, | ||||||||||||||||||||||
beta=beta, | ||||||||||||||||||||||
threshold=threshold, | ||||||||||||||||||||||
complex_mode=ivy.default(complex_mode, self._complex_mode) | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
class Mish(Module): | ||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -327,10 +327,10 @@ def test_jax_softmax( | |||||
@handle_frontend_test( | ||||||
fn_tree="jax.nn.softplus", | ||||||
dtype_and_x=helpers.dtype_and_values( | ||||||
available_dtypes=helpers.get_dtypes("float_and_integer"), | ||||||
large_abs_safety_factor=2, | ||||||
small_abs_safety_factor=2, | ||||||
safety_factor_scale="linear", | ||||||
available_dtypes=helpers.get_dtypes("float_and_complex"), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
large_abs_safety_factor=4, | ||||||
small_abs_safety_factor=4, | ||||||
safety_factor_scale="log", | ||||||
), | ||||||
test_with_out=st.just(False), | ||||||
) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.