-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
complex dtypes support for activation functions #21805
Changes from 21 commits
89d6de3
7501ae5
47aa912
28723e4
62e95a6
b0a6e4f
facfc54
a7ab849
258d4c5
87c63a2
5a8dc2d
0bf323f
10c1f1a
be21aba
b742d30
6060619
25d5606
b6a72ab
dfc9613
bbad132
b81f2da
eaeaa13
fe5a626
a5283e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -498,6 +498,53 @@ def softmax( | |
return current_backend(x).softmax(x, axis=axis, out=out) | ||
|
||
|
||
def _wrap_between(y, a):
    """
    Wrap ``y`` into a window of width ``2 * a`` centred on zero.

    The value is shifted up by ``a``, reduced with ``ivy.remainder`` modulo
    ``2 * a``, lifted by one period wherever the remainder is negative, and
    finally shifted back down by ``a``.

    Parameters
    ----------
    y
        array of values to wrap; its dtype is used for every constant created here.
    a
        half-width of the target window.

    Returns
    -------
    ret
        the wrapped values.
    """
    half_width = ivy.array(a, dtype=y.dtype)
    period = ivy.array(2 * half_width, dtype=y.dtype)
    zero = ivy.array(0, dtype=y.dtype)
    shifted = ivy.remainder(ivy.add(y, half_width), period)
    # Guard against a negative remainder before undoing the initial shift.
    non_negative = ivy.where(shifted < zero, shifted + period, shifted)
    return non_negative - half_width
|
||
|
||
def _softplus_jax_like(
    x: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    fn_original=None,
    beta: Optional[Union[int, float]] = None,
    threshold: Optional[Union[int, float]] = None,
    out: Optional[ivy.Array] = None,
):
    """
    Compute softplus for the implementation attached as ``softplus.jax_like``.

    With ``z = x * beta`` (or ``z = x`` when ``beta`` is None) the value is
    evaluated as ``relu(z) + log(1 + exp(z - 2 * relu(z)))``. The imaginary part
    of that result is wrapped with ``_wrap_between`` using ``ivy.pi`` and the
    result is divided by ``beta`` when ``beta`` is given.

    Parameters
    ----------
    x
        input array.
    fn_original
        accepted for interface compatibility; not used in this body.
    beta
        optional scale applied to ``x`` before, and divided out after, the
        computation.
    threshold
        optional cut-off. When given, the computed value is kept where
        ``real(z) < threshold``, or where ``real(z) == threshold`` and
        ``imag(z) < threshold``; every other element is taken from ``x`` unchanged.
    out
        accepted for interface compatibility; not used in this body.

    Returns
    -------
    ret
        array holding the softplus values, cast to ``x.dtype`` when ``threshold``
        is given.
    """
    scaled = x if beta is None else ivy.multiply(x, ivy.array(beta, dtype=x.dtype))

    positive_part = ivy.relu(scaled)
    doubled = ivy.multiply(positive_part, ivy.array(2, dtype=x.dtype))
    shifted = ivy.subtract(scaled, doubled)
    log_term = ivy.log(ivy.add(1, ivy.exp(shifted)))
    raw = ivy.add(positive_part, log_term)

    # Rebuild the value from its real part and the wrapped imaginary part.
    real_part = ivy.real(raw)
    wrapped_imag = _wrap_between(ivy.imag(raw), ivy.pi).astype(x.dtype)
    res = real_part + wrapped_imag * ivy.astype(1j, x.dtype)

    if beta is not None:
        res = ivy.divide(res, ivy.array(beta, dtype=x.dtype))

    if threshold is None:
        return res

    scaled_real = ivy.real(scaled)
    strictly_below = scaled_real < threshold
    tie_broken_by_imag = ivy.logical_and(
        scaled_real == threshold, ivy.imag(scaled) < threshold
    )
    keep_softplus = ivy.logical_or(strictly_below, tie_broken_by_imag)
    return ivy.where(keep_softplus, res, x).astype(x.dtype)
|
||
|
||
@handle_exceptions | ||
@handle_backend_invalid | ||
@handle_nestable | ||
|
@@ -506,17 +553,25 @@ def softmax( | |
@to_native_arrays_and_back | ||
@handle_array_function | ||
@handle_device_shifting | ||
@handle_complex_input | ||
def softplus( | ||
x: Union[ivy.Array, ivy.NativeArray], | ||
/, | ||
*, | ||
beta: Optional[Union[int, float]] = None, | ||
threshold: Optional[Union[int, float]] = None, | ||
out: Optional[ivy.Array] = None, | ||
complex_mode: Literal["split", "magnitude", "jax"] = "jax", | ||
) -> ivy.Array: | ||
""" | ||
Apply the softplus function element-wise. | ||
|
||
If the input is complex, then by default we apply the softplus operation | ||
`log(1+ exp(x))` to each element | ||
If threshold is set we check if either its real part is strictly negative or | ||
if its real part is zero and its imaginary part is negative then we apply | ||
`input×β > threshold`. | ||
|
||
Comment on lines
+562
to
+567
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I find this docstring quite confusing, could you reword it a bit? It still seems to reference the number being negative. Although, I suppose instead you could remove this and edit the description of the |
||
Parameters | ||
---------- | ||
x | ||
|
@@ -528,6 +583,9 @@ def softplus( | |
out | ||
optional output array, for writing the result to. It must have a shape that the | ||
inputs broadcast to. | ||
complex_mode | ||
optional specifier for how to handle complex data types. See | ||
``ivy.func_wrapper.handle_complex_input`` for more detail. | ||
|
||
Returns | ||
------- | ||
|
@@ -557,6 +615,9 @@ def softplus( | |
return current_backend(x).softplus(x, beta=beta, threshold=threshold, out=out) | ||
|
||
|
||
softplus.jax_like = _softplus_jax_like | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since softplus is well defined on the complex numbers, I'd suggest that unless you're including |
||
|
||
|
||
@handle_exceptions | ||
@handle_backend_invalid | ||
@handle_nestable | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've investigated a bit and I think this is a better solution, given what threshold is trying to do numerically