complex dtypes support for activation functions #21805
Thanks for contributing to Ivy! 😊👏
If you are working on an open task, please edit the PR description to link to the issue you've created. For more information, please check ToDo List Issues Guide. Thank you 🤗
ivy/functional/ivy/activations.py
Outdated
amax = ivy.relu(x)
print(amax)
x = ivy.subtract(x, ivy.multiply(amax, ivy.array(2, dtype=x.dtype)))
x = ivy.add(amax, ivy.log(ivy.add(1, ivy.exp(x))))
x = ivy.real(x) + _wrap_between(ivy.imag(x), ivy.pi).astype(x.dtype) * ivy.astype(1j, x.dtype)
return x
I'm not sure I understand this implementation. I thought softplus was simply `ivy.log(ivy.add(1., ivy.exp(x)))`. I also can't see a `beta` or `threshold` in here, and I think we should implement at least `beta` for supersetting. I also think you've left a print statement in by mistake.
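For reference, the textbook form the reviewer mentions, extended with `beta` and `threshold`, can be sketched in plain Python. This is a minimal real-valued sketch and not Ivy's implementation; the default values follow PyTorch's convention and are an assumption here.

```python
import math


def softplus(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    """Real-valued softplus: (1 / beta) * log(1 + exp(beta * x))."""
    scaled = beta * x
    # Above the threshold the result is numerically indistinguishable from x,
    # so return x directly and avoid overflow in exp().
    if scaled > threshold:
        return x
    return math.log1p(math.exp(scaled)) / beta
```

With `beta=1` and no threshold hit, this reduces to `log(1 + exp(x))`, which is the expression quoted in the comment above.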
Yeah, my bad, I forgot to delete the print statement. About softplus(x) = ivy.log(ivy.add(1., ivy.exp(x))): each backend has its own implementation of softplus, and some other frameworks don't support complex dtypes.
But torch is going to make all the activation functions support complex dtypes, which is why I didn't need to implement a special case for the softplus op in torch.
ivy/functional/ivy/activations.py
Outdated
) -> ivy.Array:
    """
    Apply the softplus function element-wise.

    If the input is complex, then by default we apply the `log(1+ exp(x))` to each element
    and if the `threshold` is set we check whether the output is less than the threshold
"less than" isn't well defined for complex numbers, so you might want to be more specific here (if you decide to use threshold in the complex version at all - it doesn't seem like you have)
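The point about ordering can be checked directly in plain Python, where comparing two complex numbers with `<` raises a `TypeError`. The sketch below also shows one unambiguous alternative, comparing only the real part; the helper names are made up for illustration.

```python
def ordering_is_defined(a, b) -> bool:
    """Return True if `a < b` can be evaluated, False if it raises TypeError."""
    try:
        a < b
    except TypeError:
        return False
    return True


def real_part_below(z: complex, threshold: float) -> bool:
    """A well-defined comparison: only the real part is ordered."""
    return z.real < threshold


z1 = 1 + 2j
z2 = 1 - 3j
print(ordering_is_defined(1.0, 2.0))  # floats are ordered
print(ordering_is_defined(z1, z2))    # complex numbers are not
print(real_part_below(z1, 20.0))
```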
okay on it.
@@ -557,6 +593,9 @@ def softplus(
    return current_backend(x).softplus(x, beta=beta, threshold=threshold, out=out)


softplus.jax_like = _softplus_jax_like
Since softplus is well defined on the complex numbers, I'd suggest that unless you're including `threshold`, it would be better handled with changes to the backend functions rather than with a `jax_like` function here.
Hey @jshepherd01, since you've already started reviewing this PR, I'm removing myself from the assignees. Thanks!
After Ved's review of my PR (#21539) I've changed the guidance I'd written on how to do these functions, so I'm also updating this review to include those changes as suggestions. There's also a small point about the dtypes in the test function.
complex_mode | ||
optional specifier for how to handle complex data types. |
complex_mode | |
optional specifier for how to handle complex data types. | |
complex_mode | |
optional specifier for how to handle complex data types. See | |
``ivy.func_wrapper.handle_complex_input`` for more detail. |
complex_mode | ||
optional specifier for how to handle complex data types. |
complex_mode | |
optional specifier for how to handle complex data types. | |
complex_mode | |
optional specifier for how to handle complex data types. See | |
``ivy.func_wrapper.handle_complex_input`` for more detail. |
complex_mode | ||
optional specifier for how to handle complex data types. |
complex_mode | |
optional specifier for how to handle complex data types. | |
complex_mode | |
optional specifier for how to handle complex data types. See | |
``ivy.func_wrapper.handle_complex_input`` for more detail. |
ivy/stateful/activations.py
Outdated
def __init__( | ||
self, | ||
complex_mode: Literal["split", "magnitude", "jax"] = "jax", | ||
): | ||
"""Apply the SOFTPLUS activation function.""" |
"""Apply the SOFTPLUS activation function.""" | |
""" | |
Apply the SOFTPLUS activation function. | |
Parameters | |
---------- | |
complex_mode | |
Specifies how to handle complex input. See | |
``ivy.func_wrapper.handle_complex_input`` for more detail. | |
""" |
Parameters should be included in docstrings
Actually, `beta` and `threshold` should really be in the `__init__` part rather than `_forward` as well, but that's beside the point of this PR.
ivy/stateful/activations.py
Outdated
complex_mode | ||
optional specifier for how to handle complex data types. |
complex_mode | |
optional specifier for how to handle complex data types. |
We no longer use optional arguments in the forward method for stateful components, only in __init__
ivy/stateful/activations.py
Outdated
*, | ||
beta=None, | ||
threshold=None, | ||
complex_mode=None, |
complex_mode=None, |
ivy/stateful/activations.py
Outdated
x, | ||
beta=beta, | ||
threshold=threshold, | ||
complex_mode=ivy.default(complex_mode, self._complex_mode) |
complex_mode=ivy.default(complex_mode, self._complex_mode) | |
complex_mode=self._complex_mode, |
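The pattern requested in these suggestions, configuration fixed in `__init__` and a forward method that takes only the input, might look like the following. `SoftplusLayer` and `softplus_fn` are hypothetical stand-ins written for illustration, not Ivy classes or functions, and `complex_mode` is only stored and passed through because the stub handles real inputs only.

```python
import math
from typing import Literal, Optional

ComplexMode = Literal["split", "magnitude", "jax"]


def softplus_fn(
    x: float,
    *,
    beta: Optional[float] = None,
    threshold: Optional[float] = None,
    complex_mode: ComplexMode = "jax",
) -> float:
    """Hypothetical functional softplus for real inputs; complex_mode is unused."""
    scaled = x if beta is None else beta * x
    if threshold is not None and scaled > threshold:
        return x
    result = math.log1p(math.exp(scaled))
    return result if beta is None else result / beta


class SoftplusLayer:
    """Hypothetical stateful wrapper: every option is set once, at construction."""

    def __init__(
        self,
        beta: Optional[float] = None,
        threshold: Optional[float] = None,
        complex_mode: ComplexMode = "jax",
    ):
        self._beta = beta
        self._threshold = threshold
        self._complex_mode = complex_mode

    def _forward(self, x: float) -> float:
        # No per-call overrides: the stored configuration is always used.
        return softplus_fn(
            x,
            beta=self._beta,
            threshold=self._threshold,
            complex_mode=self._complex_mode,
        )

    def __call__(self, x: float) -> float:
        return self._forward(x)
```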
large_abs_safety_factor=2, | ||
small_abs_safety_factor=2, | ||
safety_factor_scale="linear", | ||
available_dtypes=helpers.get_dtypes("float_and_complex"), |
available_dtypes=helpers.get_dtypes("float_and_complex"), | |
available_dtypes=helpers.get_dtypes("numeric"), |
`"float_and_complex"` doesn't include integers, so we should use `"numeric"` (or `"valid"`) instead.
ivy/functional/ivy/activations.py
Outdated
If the input is complex, then by default we apply the `log(1+ exp(x))` to each element0 | ||
This behaviour can be changed by specifying a different `complex_mode`. | ||
|
If the input is complex, then by default we apply the `log(1+ exp(x))` to each element0 | |
This behaviour can be changed by specifying a different `complex_mode`. |
This behaviour is the same as with real numbers, so you could probably drop this part from the docstring. I don't think it needs to be explicitly spelled out, and we already have a description of what the `complex_mode` parameter does further down.
ivy/functional/ivy/activations.py
Outdated
res = ivy.where(
    (
        ivy.logical_or(
            ivy.real(x_beta) < 0, ivy.logical_and(ivy.real(x_beta) == 0, ivy.imag(x_beta) < 0)
        )
    ),
    res,
    x,
).astype(x.dtype)
The actual value of `threshold` doesn't seem to be used here.
ivy/functional/ivy/activations.py
Outdated
if threshold is not None:
    res = ivy.where(
        (
            ivy.logical_or(
                ivy.real(x_beta) < threshold,
                ivy.logical_and(
                    ivy.real(x_beta) == threshold, ivy.imag(x_beta) < threshold
                ),
            )
        ),
        res,
        x,
    ).astype(x.dtype)
if threshold is not None:
    res = ivy.where(
        (
            ivy.logical_or(
                ivy.real(x_beta) < threshold,
                ivy.logical_and(
                    ivy.real(x_beta) == threshold, ivy.imag(x_beta) < threshold
                ),
            )
        ),
        res,
        x,
    ).astype(x.dtype)

if threshold is not None:
    res = ivy.where(
        ivy.real(x_beta) < threshold,
        res,
        x,
    ).astype(x.dtype)
I've investigated a bit and I think this is a better solution, given what `threshold` is trying to do numerically.
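The reasoning behind thresholding on the real part alone can be illustrated with the standard library's `cmath`. In this sketch (an illustration, not the PR's code) the input is returned unchanged once the real part of `beta * z` reaches the threshold, because `log(1 + exp(w))` is already within roughly `exp(-Re(w))` of `w` there.

```python
import cmath
from typing import Optional


def complex_softplus(
    z: complex, beta: float = 1.0, threshold: Optional[float] = None
) -> complex:
    """Complex softplus with an optional threshold on the real part of beta * z."""
    scaled = beta * z
    # Only the real part is compared, so the test is well defined.
    if threshold is not None and scaled.real >= threshold:
        return z
    return cmath.log(1 + cmath.exp(scaled)) / beta


z = 25 + 1j
exact = complex_softplus(z)                 # full formula
shortcut = complex_softplus(z, threshold=20.0)  # returns z unchanged
print(abs(exact - shortcut))                # tiny: the shortcut loses almost nothing
```

The imaginary part plays no role in the decision: it only rotates `exp(w)`, while the magnitude of `exp(w)`, which is what causes overflow, depends on the real part.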
If the input is complex, then by default we apply the softplus operation
`log(1+ exp(x))` to each element
If threshold is set we check if either its real part is strictly negative or
if its real part is zero and its imaginary part is negative then we apply
`input×β > threshold`.
I find this docstring quite confusing. Could you reword it a bit? It still seems to reference the number being negative.
Alternatively, you could remove this and edit the description of the `threshold` parameter instead, since this paragraph just describes the standard way softplus works and then the special behaviour of `threshold` with complex inputs.
def _wrap_between(y, a):
    """Wrap y between [-a, a]"""
    a = ivy.array(a, dtype=y.dtype)
    a2 = ivy.array(2 * a, dtype=y.dtype)
    zero = ivy.array(0, dtype=y.dtype)
    rem = ivy.remainder(ivy.add(y, a), a2)
    rem = ivy.where(rem < zero, rem + a2, rem) - a
    return rem


def _softplus_jax_like(
    x: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    fn_original=None,
    beta: Optional[Union[int, float]] = None,
    threshold: Optional[Union[int, float]] = None,
    out: Optional[ivy.Array] = None,
):
    if beta is not None:
        x_beta = ivy.multiply(x, ivy.array(beta, dtype=x.dtype))
    else:
        x_beta = x
    amax = ivy.relu(x_beta)
    res = ivy.subtract(x_beta, ivy.multiply(amax, ivy.array(2, dtype=x.dtype)))
    res = ivy.add(amax, ivy.log(ivy.add(1, ivy.exp(res))))
    res = ivy.real(res) + _wrap_between(ivy.imag(res), ivy.pi).astype(
        x.dtype
    ) * ivy.astype(1j, x.dtype)
    if beta is not None:
        res = ivy.divide(res, ivy.array(beta, dtype=x.dtype))
    if threshold is not None:
        res = ivy.where(
            ivy.real(x_beta) < threshold,
            res,
            x,
        ).astype(x.dtype)
    return res
I'm still not sure whether this should be here or in the backends so I'm going to request a review from @Ishticode for a second opinion
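For readers following the `amax` trick and the wrapping helper in the function quoted above, here is a real-valued sketch in plain Python. It is an illustration of the two ideas under simplifying assumptions (scalar floats, no `beta` or `threshold`) and not a port of the PR's code.

```python
import math


def wrap_between(y: float, a: float) -> float:
    """Wrap y into the half-open interval [-a, a)."""
    # Python's % returns a non-negative result for a positive modulus,
    # so no extra sign correction is needed here.
    return (y + a) % (2 * a) - a


def stable_softplus(x: float) -> float:
    """Compute log(1 + exp(x)) as amax + log(1 + exp(x - 2 * amax)), amax = max(x, 0)."""
    amax = max(x, 0.0)
    # For x > 0 the exponent is -x, for x <= 0 it is x: never positive,
    # so exp() cannot overflow.
    return amax + math.log1p(math.exp(x - 2 * amax))


print(stable_softplus(1000.0))  # the naive math.exp(1000.0) would raise OverflowError
print(wrap_between(1.5 * math.pi, math.pi))
```

The rewrite relies on the identity `log(1 + exp(x)) = x + log(1 + exp(-x))` for positive `x`, which keeps the argument of `exp` non-positive in both branches.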
softplus activation function support for complex dtypes