added support for sigmoid activation function #21424
Conversation
Thanks for contributing to Ivy! 😊👏
If you are working on an open task, please edit the PR description to link to the issue you've created. For more information, please check the ToDo List Issues Guide. Thank you 🤗
I think you need to add the …
Sure, I will add that.
There are some fiddly parts of the testing which will need changing related to the dtypes. I also think you've missed out relu6 in the jax frontend.
About testing: it seems like there are currently a lot of test failures, mostly for paddle and torch. Also, complex values never seem to be generated with the tensorflow backend (and therefore are never generated for the functional API tests, since tensorflow is the ground truth for those).
Resolved review threads on:
ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py
Looks like tests are still failing. If you run the tests for the jax frontend, you'll see that some backends aren't handling complex values correctly, and if you print the … Another thing is that as of yesterday (specifically as of #21902) the stateful API classes no longer need to take parameters like …
I've done some testing of each of the functions you've modified and attached the feedback to each one separately to make it easier to read through
@@ -376,8 +376,13 @@ def relu(
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_complex_input
def sigmoid(
Complex numbers are marked as unsupported for sigmoid in the tf and paddle backends.
To fix this, I have to add a jax-like function for sigmoid; I just did so in the latest commit.
Are you sure that's needed? I just tested on the main branch and ivy.sigmoid(3+4j) works just fine with both paddle and tensorflow backends.
Running the test without the sigmoid jax-like function gives errors in tensorflow and paddle.
[cpu-paddle-False-False] Failed: [undefined]AssertionError: the results from backend paddle and ground truth framework jax do not match (0.2689414322376251+0j)!=(0.21795842936625076-0.20194822765801285j)
[cpu-tensorflow-False-False] Failed: [undefined]AssertionError: the results from backend tensorflow and ground truth framework jax do not match (1+0j)!=-0j
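For reference, jax's ground-truth value in the paddle case above can be reproduced with plain numpy, since sigmoid extends to complex inputs as 1/(1 + exp(-z)). A quick sketch, independent of ivy:

```python
import numpy as np

def sigmoid(z):
    # Logistic function; numpy's exp handles complex inputs natively.
    return 1.0 / (1.0 + np.exp(-z))

z = np.complex128(-1.0 - 1.0j)
print(sigmoid(z))  # ≈ 0.21795842936625076 - 0.20194822765801285j
```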
Do you still have the test cases that returned those errors?
For Paddle:
Falsifying example: test_jax_sigmoid(
on_device='cpu',
frontend='jax',
backend_fw='paddle',
dtype_and_x=(['complex128'], [array(-1.-1.j)]),
test_flags=FrontendFunctionTestFlags(
num_positional_args=0,
with_out=False,
inplace=False,
as_variable=[False],
native_arrays=[False],
test_compile=False,
generate_frontend_arrays=False,
),
fn_tree='ivy.functional.frontends.jax.nn.sigmoid',
)
You can reproduce this example by temporarily adding @reproduce_failure('6.82.5', b'AXicY2QAAkYGCGBEYzMwAAAAfgAG') as a decorator on your test case
For Tensorflow:
Falsifying example: test_jax_sigmoid(
on_device='cpu',
frontend='jax',
backend_fw='tensorflow',
dtype_and_x=(['complex128'], [array(-709.-1.j)]),
test_flags=FrontendFunctionTestFlags(
num_positional_args=0,
with_out=False,
inplace=False,
as_variable=[False],
native_arrays=[False],
test_compile=False,
generate_frontend_arrays=False,
),
fn_tree='ivy.functional.frontends.jax.nn.sigmoid',
)
You can reproduce this example by temporarily adding @reproduce_failure('6.82.5', b'AXicY2QAAkYGMGA6imBDaQAObADM') as a decorator on your test case
@@ -350,8 +410,13 @@ def selu(
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_complex_input
def silu(
Complex numbers are marked as unsupported for silu in the tf, torch, and paddle backends. The implementation in the paddle backend throws an error:
@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("bfloat16",)}},
backend_version,
)
def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
if x.dtype in [paddle.float32, paddle.float64]:
return F.silu(x)
if paddle.is_complex(x):
> return x * (1.0 / (1.0 + paddle_backend.exp(-x)))
E ValueError: (InvalidArgument) __mul__(): argument (position 1) must be int, float, bool or Tensor, but got Array (at /paddle/paddle/fluid/pybind/eager_utils.cc:1435)
Done
I don't think this should be done using a jax_like function; it would be better to modify the paddle backend implementation to work properly with it.
Without the silu jax-like function, there are errors in tf and paddle during testing:
./ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py::test_jax_silu[cpu-tensorflow-False-False] Failed: [undefined]AssertionError: the results from backend tensorflow and ground truth framework jax do not match (-88-1j)!=(-0+0j)
./ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py::test_jax_silu[cpu-paddle-False-False] Failed: [undefined]AssertionError: the results from backend paddle and ground truth framework jax do not match (-0.2689414322376251-0.2689414322376251j)!=(-0.4199066758155823-0.016010195016860962j)
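The jax ground truth in the paddle failure above follows directly from the definition silu(z) = z * sigmoid(z), which is well-defined for complex z. A small numpy sketch, independent of any ivy backend:

```python
import numpy as np

def silu(z):
    # silu(z) = z * sigmoid(z); numpy evaluates this for complex z directly.
    return z * (1.0 / (1.0 + np.exp(-z)))

z = np.complex128(-1.0 - 1.0j)
print(silu(z))  # ≈ -0.4199067 - 0.0160102j, matching the jax ground truth above to float32 precision
```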
Do you have the test cases for these handy?
Falsifying example: test_jax_silu(
on_device='cpu',
frontend='jax',
backend_fw='tensorflow',
dtype_and_x=(['complex64'], [array(-88.-1.j, dtype=complex64)]),
fn_tree='ivy.functional.frontends.jax.nn.silu',
test_flags=FrontendFunctionTestFlags(
num_positional_args=0,
with_out=False,
inplace=False,
as_variable=[False],
native_arrays=[False],
test_compile=False,
generate_frontend_arrays=False,
),
)
You can reproduce this example by temporarily adding @reproduce_failure('6.82.5', b'AXicY2AAAkYGCIhAYkNoAAZ9AFw=') as a decorator on your test case
Falsifying example: test_jax_silu(
on_device='cpu',
frontend='jax',
backend_fw='paddle',
dtype_and_x=(['complex64'], [array(-1.-1.j, dtype=complex64)]),
fn_tree='ivy.functional.frontends.jax.nn.silu',
test_flags=FrontendFunctionTestFlags(
num_positional_args=0,
with_out=False,
inplace=False,
as_variable=[False],
native_arrays=[False],
test_compile=False,
generate_frontend_arrays=False,
),
)
You can reproduce this example by temporarily adding @reproduce_failure('6.82.5', b'AXicY2AAAkYGCGBEYzMwAAAAXwAF') as a decorator on your test case
I've added more comments to some of the threads I started on the last review. I'll do a proper code review early next week, but if you could address those in the meantime that would be great. Thanks :)
…term into its own variable
And a couple of other bugs that were causing things to break, such as removing a duplicated atol from `test_jax_hard_swish` and adding complex dtypes to some tests. These functions now pass all tests in the jax frontend, except sigmoid and silu, which fail for the tf backend due to a known issue.
Update supported dtypes for certain functions, and change the safety factor for logsigmoid. All tests now pass locally, except problems caused by tensorflow's sigmoid and a problem with the test for elu (which seems to feed in the wrong values).
I started experimenting with fixing the errors and ended up just writing up a fix myself, so I thought I might as well push it. There remain some test failures caused by an issue with tensorflow's …
This PR has been labelled as stale because it has been inactive for more than 7 days. If you would like to continue working on this PR, then please add another comment or this PR will be closed in 7 days.
Thank you for this PR, here are the CI results: This pull request does not result in any additional test failures. Congratulations!
The PR is not fully ready yet, as I still need to implement this for 7 more functions. As of now, I have only done so for the sigmoid activation function.