-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
complex dtypes support for activation functions #21805
Changes from 1 commit
89d6de3
7501ae5
47aa912
28723e4
62e95a6
b0a6e4f
facfc54
a7ab849
258d4c5
87c63a2
5a8dc2d
0bf323f
10c1f1a
be21aba
b742d30
6060619
25d5606
b6a72ab
dfc9613
bbad132
b81f2da
eaeaa13
fe5a626
a5283e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -660,6 +660,7 @@ def _static_softplus( | |||||||||||
prune_unapplied: Union[bool, ivy.Container] = False, | ||||||||||||
map_sequences: Union[bool, ivy.Container] = False, | ||||||||||||
out: Optional[ivy.Container] = None, | ||||||||||||
complex_mode: Literal["split", "magnitude", "jax"] = "jax", | ||||||||||||
) -> ivy.Container: | ||||||||||||
""" | ||||||||||||
ivy.Container static method variant of ivy.softplus. This method simply wraps | ||||||||||||
|
@@ -688,6 +689,8 @@ def _static_softplus( | |||||||||||
out | ||||||||||||
optional output container, for writing the result to. It must have a shape | ||||||||||||
that the inputs broadcast to. | ||||||||||||
complex_mode | ||||||||||||
optional specifier for how to handle complex data types. | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
||||||||||||
Returns | ||||||||||||
------- | ||||||||||||
|
@@ -721,6 +724,7 @@ def _static_softplus( | |||||||||||
prune_unapplied=prune_unapplied, | ||||||||||||
map_sequences=map_sequences, | ||||||||||||
out=out, | ||||||||||||
complex_mode=complex_mode, | ||||||||||||
) | ||||||||||||
|
||||||||||||
def softplus( | ||||||||||||
|
@@ -734,6 +738,7 @@ def softplus( | |||||||||||
prune_unapplied: Union[bool, ivy.Container] = False, | ||||||||||||
map_sequences: Union[bool, ivy.Container] = False, | ||||||||||||
out: Optional[ivy.Container] = None, | ||||||||||||
complex_mode: Literal["split", "magnitude", "jax"] = "jax", | ||||||||||||
) -> ivy.Container: | ||||||||||||
""" | ||||||||||||
ivy.Container instance method variant of ivy.softplus. This method simply wraps | ||||||||||||
|
@@ -762,6 +767,8 @@ def softplus( | |||||||||||
out | ||||||||||||
optional output container, for writing the result to. It must have a shape | ||||||||||||
that the inputs broadcast to. | ||||||||||||
complex_mode | ||||||||||||
optional specifier for how to handle complex data types. | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
||||||||||||
Returns | ||||||||||||
------- | ||||||||||||
|
@@ -793,6 +800,7 @@ def softplus( | |||||||||||
prune_unapplied=prune_unapplied, | ||||||||||||
map_sequences=map_sequences, | ||||||||||||
out=out, | ||||||||||||
complex_mode=complex_mode, | ||||||||||||
) | ||||||||||||
|
||||||||||||
@staticmethod | ||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -498,6 +498,32 @@ def softmax( | |
return current_backend(x).softmax(x, axis=axis, out=out) | ||
|
||
|
||
def _wrap_between(y, a): | ||
"""wrap y between [-a, a]""" | ||
a = ivy.array(a, dtype=y.dtype) | ||
a2 = ivy.array(2 * a, dtype=y.dtype) | ||
zero = ivy.array(0, dtype=y.dtype) | ||
rem = ivy.remainder(ivy.add(y, a), a2) | ||
rem = ivy.where(rem < zero, rem + a2, rem) - a | ||
return rem | ||
|
||
|
||
def _softplus_jax_like( | ||
x: Union[ivy.Array, ivy.NativeArray], | ||
/, | ||
*, | ||
fn_original=None, | ||
beta: Optional[Union[int, float]] = None, | ||
threshold: Optional[Union[int, float]] = None, | ||
out: Optional[ivy.Array] = None, | ||
): | ||
amax = ivy.relu(x) | ||
print(amax) | ||
x = ivy.subtract(x, ivy.multiply(amax, ivy.array(2, dtype=x.dtype))) | ||
x = ivy.add(amax, ivy.log(ivy.add(1, ivy.exp(x)))) | ||
x = ivy.real(x) + _wrap_between(ivy.imag(x), ivy.pi).astype(x.dtype) * ivy.astype(1j, x.dtype) | ||
return x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if I understand this implementation here, I thought softplus was simply There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, my bad I forgot to delete the print statement, about softplus(x) = ivy.log(ivy.add(1., ivy.exp(x)) each backend has it's own implementation of softplus also some other frameworks like doesn't support complex dtype There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but torch is going to make all the activation functions support support complex dtype that's why i didn't need to implement a special case for softplus op in torch. |
||
|
||
@handle_exceptions | ||
@handle_backend_invalid | ||
@handle_nestable | ||
|
@@ -506,17 +532,24 @@ def softmax( | |
@to_native_arrays_and_back | ||
@handle_array_function | ||
@handle_device_shifting | ||
@handle_complex_input | ||
def softplus( | ||
x: Union[ivy.Array, ivy.NativeArray], | ||
/, | ||
*, | ||
beta: Optional[Union[int, float]] = None, | ||
threshold: Optional[Union[int, float]] = None, | ||
out: Optional[ivy.Array] = None, | ||
complex_mode: Literal["split", "magnitude", "jax"] = "jax", | ||
) -> ivy.Array: | ||
""" | ||
Apply the softplus function element-wise. | ||
|
||
If the input is complex, then by default we apply the `log(1+ exp(x))` to each element | ||
and if the `threshold` is set we check whether the output is less than the threshold | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "less than" isn't well defined for complex numbers, so you might want to be more specific here (if you decide to use threshold in the complex version at all - it doesn't seem like you have) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay on it. |
||
we set it to the original value of the input if not we leave as it is. | ||
This behaviour can be changed by specifying a different `complex_mode`. | ||
|
||
Parameters | ||
---------- | ||
x | ||
|
@@ -528,6 +561,9 @@ def softplus( | |
out | ||
optional output array, for writing the result to. It must have a shape that the | ||
inputs broadcast to. | ||
complex_mode | ||
optional specifier for how to handle complex data types. See | ||
`ivy.func_wrapper.handle_complex_input` for more detail. | ||
jshepherd01 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Returns | ||
------- | ||
|
@@ -557,6 +593,9 @@ def softplus( | |
return current_backend(x).softplus(x, beta=beta, threshold=threshold, out=out) | ||
|
||
|
||
softplus.jax_like = _softplus_jax_like | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since softplus is well defined on the complex numbers, I'd suggest that unless you're including |
||
|
||
|
||
@handle_exceptions | ||
@handle_backend_invalid | ||
@handle_nestable | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -185,11 +185,22 @@ def _forward(self, x, *, axis=None): | |||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
class Softplus(Module): | ||||||||||||||||||||||
def __init__(self): | ||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||
self, | ||||||||||||||||||||||
complex_mode: Literal["split", "magnitude", "jax"] = "jax", | ||||||||||||||||||||||
): | ||||||||||||||||||||||
"""Apply the SOFTPLUS activation function.""" | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Parameters should be included in docstrings There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, |
||||||||||||||||||||||
Module.__init__(self) | ||||||||||||||||||||||
self._complex_mode = complex_mode | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _forward(self, x, *, beta=None, threshold=None): | ||||||||||||||||||||||
def _forward( | ||||||||||||||||||||||
self, | ||||||||||||||||||||||
x, | ||||||||||||||||||||||
*, | ||||||||||||||||||||||
beta=None, | ||||||||||||||||||||||
threshold=None, | ||||||||||||||||||||||
complex_mode=None, | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
): | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
|
||||||||||||||||||||||
Parameters | ||||||||||||||||||||||
|
@@ -201,14 +212,21 @@ def _forward(self, x, *, beta=None, threshold=None): | |||||||||||||||||||||
|
||||||||||||||||||||||
threshold | ||||||||||||||||||||||
values above this revert to a linear function. Default: ``None``. | ||||||||||||||||||||||
complex_mode | ||||||||||||||||||||||
optional specifier for how to handle complex data types. | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
We no longer use optional arguments in the forward method for stateful components, only in |
||||||||||||||||||||||
|
||||||||||||||||||||||
Returns | ||||||||||||||||||||||
------- | ||||||||||||||||||||||
ret | ||||||||||||||||||||||
The outputs following the SOFTPLUS activation *[batch_shape, d]* | ||||||||||||||||||||||
|
||||||||||||||||||||||
""" | ||||||||||||||||||||||
return ivy.softplus(x, beta=beta, threshold=threshold) | ||||||||||||||||||||||
return ivy.softplus( | ||||||||||||||||||||||
x, | ||||||||||||||||||||||
beta=beta, | ||||||||||||||||||||||
threshold=threshold, | ||||||||||||||||||||||
complex_mode=ivy.default(complex_mode, self._complex_mode) | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
class Mish(Module): | ||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -327,10 +327,10 @@ def test_jax_softmax( | |||||
@handle_frontend_test( | ||||||
fn_tree="jax.nn.softplus", | ||||||
dtype_and_x=helpers.dtype_and_values( | ||||||
available_dtypes=helpers.get_dtypes("float_and_integer"), | ||||||
large_abs_safety_factor=2, | ||||||
small_abs_safety_factor=2, | ||||||
safety_factor_scale="linear", | ||||||
available_dtypes=helpers.get_dtypes("float_and_complex"), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
large_abs_safety_factor=4, | ||||||
small_abs_safety_factor=4, | ||||||
safety_factor_scale="log", | ||||||
), | ||||||
test_with_out=st.just(False), | ||||||
) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.