Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

complex dtypes support for activation functions #21805

Merged
merged 24 commits into from
Aug 23, 2023
Merged

Conversation

mohame54
Copy link
Contributor

softplus activation function support for complex dtypes

@github-actions
Copy link
Contributor

Thanks for contributing to Ivy! 😊👏
Here are some of the important points from our Contributing Guidelines 📝:
1. Feel free to ignore the run_tests (1), run_tests (2), … jobs, and only look at the display_test_results job. 👀 It contains the following two sections:
- Combined Test Results: This shows the results of all the ivy tests that ran on the PR. ✔️
- New Failures Introduced: This lists the tests that are passing on master, but fail on the PR Fork. Please try to make sure that there are no such tests. 💪
2. The lint / Check formatting / check-formatting tests check for the formatting of your code. 📜 If it fails, please check the exact error message in the logs and fix the same. ⚠️🔧
3. Finally, the test-docstrings / run-docstring-tests check for the changes made in docstrings of the functions. This may be skipped, as well. 📚
Happy coding! 🎉👨‍💻

@ivy-leaves ivy-leaves added Array API Conform to the Array API Standard, created by The Consortium for Python Data API Standards JAX Frontend Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist Ivy Functional API labels Aug 13, 2023
@ivy-leaves
Copy link

If you are working on an open task, please edit the PR description to link to the issue you've created.

For more information, please check ToDo List Issues Guide.

Thank you 🤗

Comment on lines 520 to 525
amax = ivy.relu(x)
print(amax)
x = ivy.subtract(x, ivy.multiply(amax, ivy.array(2, dtype=x.dtype)))
x = ivy.add(amax, ivy.log(ivy.add(1, ivy.exp(x))))
x = ivy.real(x) + _wrap_between(ivy.imag(x), ivy.pi).astype(x.dtype) * ivy.astype(1j, x.dtype)
return x
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I understand this implementation here, I thought softplus was simply ivy.log(ivy.add(1., ivy.exp(x))). I also can't see a beta or threshold in here, and I think we should implement at least beta for supersetting, and I think you've left a print statement in by mistake

Copy link
Contributor Author

@mohame54 mohame54 Aug 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, my bad I forgot to delete the print statement, about softplus(x) = ivy.log(ivy.add(1., ivy.exp(x)) each backend has it's own implementation of softplus also some other frameworks like doesn't support complex dtype

Copy link
Contributor Author

@mohame54 mohame54 Aug 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but torch is going to make all the activation functions support support complex dtype that's why i didn't need to implement a special case for softplus op in torch.

) -> ivy.Array:
"""
Apply the softplus function element-wise.

If the input is complex, then by default we apply the `log(1+ exp(x))` to each element
and if the `threshold` is set we check whether the output is less than the threshold
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"less than" isn't well defined for complex numbers, so you might want to be more specific here (if you decide to use threshold in the complex version at all - it doesn't seem like you have)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay on it.

@@ -557,6 +593,9 @@ def softplus(
return current_backend(x).softplus(x, beta=beta, threshold=threshold, out=out)


softplus.jax_like = _softplus_jax_like
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since softplus is well defined on the complex numbers, I'd suggest that unless you're including threshold it would be better handled with changes to the backend functions rather than with a jax_like function here

@mohame54 mohame54 requested review from jshepherd01 and removed request for Ishticode August 14, 2023 18:00
@rishabgit
Copy link
Contributor

rishabgit commented Aug 17, 2023

Hey @jshepherd01, since you've already started reviewing this PR, I'm removing myself from the assignees. Thanks!
Quick note - doesn't seem to link the issue. @mohame54 hope this helps - https://unify.ai/docs/ivy/overview/contributing/the_basics.html#todo-list-issues 😄

@rishabgit rishabgit assigned jshepherd01 and unassigned rishabgit Aug 17, 2023
Copy link
Contributor

@jshepherd01 jshepherd01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After Ved's review of my PR (#21539) I've changed the guidance I'd written on how to do these functions, so I'm also updating this review to include those changes as suggestions. There's also a small point about the dtypes in the test function.

Comment on lines 216 to 217
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Comment on lines 692 to 693
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

Comment on lines 770 to 771
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.
complex_mode
optional specifier for how to handle complex data types. See
``ivy.func_wrapper.handle_complex_input`` for more detail.

ivy/functional/ivy/activations.py Outdated Show resolved Hide resolved
def __init__(
self,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
):
"""Apply the SOFTPLUS activation function."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Apply the SOFTPLUS activation function."""
"""
Apply the SOFTPLUS activation function.
Parameters
----------
complex_mode
Specifies how to handle complex input. See
``ivy.func_wrapper.handle_complex_input`` for more detail.
"""

Parameters should be included in docstrings

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, beta and threshold should really be in the __init__ part rather than _forward as well but that's beside the point of this PR

Comment on lines 215 to 216
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode
optional specifier for how to handle complex data types.

We no longer use optional arguments in the forward method for stateful components, only in __init__

*,
beta=None,
threshold=None,
complex_mode=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode=None,

x,
beta=beta,
threshold=threshold,
complex_mode=ivy.default(complex_mode, self._complex_mode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
complex_mode=ivy.default(complex_mode, self._complex_mode)
complex_mode=self._complex_mode,

large_abs_safety_factor=2,
small_abs_safety_factor=2,
safety_factor_scale="linear",
available_dtypes=helpers.get_dtypes("float_and_complex"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
available_dtypes=helpers.get_dtypes("float_and_complex"),
available_dtypes=helpers.get_dtypes("numeric"),

"float_and_complex" doesn't include integers, so we should use "numeric" (or "valid") instead

Comment on lines 551 to 553
If the input is complex, then by default we apply the `log(1+ exp(x))` to each element0
This behaviour can be changed by specifying a different `complex_mode`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If the input is complex, then by default we apply the `log(1+ exp(x))` to each element0
This behaviour can be changed by specifying a different `complex_mode`.

This behaviour is the same as with real numbers, so you could probably drop this part from the docstring - I don't think it needs to be explicitly spelled out, and we already have a description of what the complex_mode parameter does further down

Comment on lines 531 to 539
res = ivy.where(
(
ivy.logical_or(
ivy.real(x_beta) < 0, ivy.logical_and(ivy.real(x_beta) == 0, ivy.imag(x_beta) < 0)
)
),
res,
x,
).astype(x.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The actual value of threshold doesn't seem to be used here

Comment on lines 532 to 544
if threshold is not None:
res = ivy.where(
(
ivy.logical_or(
ivy.real(x_beta) < threshold,
ivy.logical_and(
ivy.real(x_beta) == threshold, ivy.imag(x_beta) < threshold
),
)
),
res,
x,
).astype(x.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if threshold is not None:
res = ivy.where(
(
ivy.logical_or(
ivy.real(x_beta) < threshold,
ivy.logical_and(
ivy.real(x_beta) == threshold, ivy.imag(x_beta) < threshold
),
)
),
res,
x,
).astype(x.dtype)
if threshold is not None:
res = ivy.where(
ivy.real(x_beta) < threshold,
res,
x,
).astype(x.dtype)

I've investigated a bit and I think this is a better solution, given what threshold is trying to do numerically

Comment on lines +569 to +574
If the input is complex, then by default we apply the softplus operation
`log(1+ exp(x))` to each element
If threshold is set we check if either its real part is strictly negative or
if its real part is zero and its imaginary part is negative then we apply
`input×β > threshold`.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this docstring quite confusing, could you reword it a bit? It still seems to reference the number being negative.

Although, I suppose instead you could remove this and edit the description of the threshold parameter (since this just describes the standard way softplus works, and then describes special behaviour of threshold with complex inputs)

Comment on lines +501 to +538
def _wrap_between(y, a):
"""Wrap y between [-a, a]"""
a = ivy.array(a, dtype=y.dtype)
a2 = ivy.array(2 * a, dtype=y.dtype)
zero = ivy.array(0, dtype=y.dtype)
rem = ivy.remainder(ivy.add(y, a), a2)
rem = ivy.where(rem < zero, rem + a2, rem) - a
return rem


def _softplus_jax_like(
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
fn_original=None,
beta: Optional[Union[int, float]] = None,
threshold: Optional[Union[int, float]] = None,
out: Optional[ivy.Array] = None,
):
if beta is not None:
x_beta = ivy.multiply(x, ivy.array(beta, dtype=x.dtype))
else:
x_beta = x
amax = ivy.relu(x_beta)
res = ivy.subtract(x_beta, ivy.multiply(amax, ivy.array(2, dtype=x.dtype)))
res = ivy.add(amax, ivy.log(ivy.add(1, ivy.exp(res))))
res = ivy.real(res) + _wrap_between(ivy.imag(res), ivy.pi).astype(
x.dtype
) * ivy.astype(1j, x.dtype)
if beta is not None:
res = ivy.divide(res, ivy.array(beta, dtype=x.dtype))
if threshold is not None:
res = ivy.where(
ivy.real(x_beta) < threshold,
res,
x,
).astype(x.dtype)
return res
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still not sure whether this should be here or in the backends so I'm going to request a review from @Ishticode for a second opinion

@mohame54 mohame54 requested review from jshepherd01 and Ishticode and removed request for Ishticode August 21, 2023 12:05
@jshepherd01 jshepherd01 merged commit 023cf84 into ivy-llc:main Aug 23, 2023
98 of 132 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Array API Conform to the Array API Standard, created by The Consortium for Python Data API Standards Ivy Functional API JAX Frontend Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants