
added support for sigmoid activation function #21424


mosesdaudu001
Contributor

The PR is not fully ready yet, as I still need to implement this for 7 more functions.

So far, I have only done it for the sigmoid activation function.

@github-actions
Contributor

github-actions bot commented Aug 7, 2023

Thanks for contributing to Ivy! 😊👏
Here are some of the important points from our Contributing Guidelines 📝:
1. Feel free to ignore the run_tests (1), run_tests (2), … jobs, and only look at the display_test_results job. 👀 It contains the following two sections:
- Combined Test Results: This shows the results of all the ivy tests that ran on the PR. ✔️
- New Failures Introduced: This lists the tests that are passing on master, but fail on the PR Fork. Please try to make sure that there are no such tests. 💪
2. The lint / Check formatting / check-formatting tests check for the formatting of your code. 📜 If it fails, please check the exact error message in the logs and fix the same. ⚠️🔧
3. Finally, the test-docstrings / run-docstring-tests check for the changes made in docstrings of the functions. This may be skipped, as well. 📚
Happy coding! 🎉👨‍💻

@ivy-leaves added the Array API, JAX Frontend, and Ivy Functional API labels Aug 7, 2023
@ivy-leaves

If you are working on an open task, please edit the PR description to link to the issue you've created.

For more information, please check ToDo List Issues Guide.

Thank you 🤗

@jshepherd01
Contributor

I think you need to add the complex_mode to the container class's _static_hardswish and _static_sigmoid functions as well

@mosesdaudu001
Contributor Author

> I think you need to add the complex_mode to the container class's _static_hardswish and _static_sigmoid functions as well

Sure I will add that
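
For readers unfamiliar with the `complex_mode` argument discussed here, the toy sketch below shows the general idea of a wrapper that decides how a real-valued activation is applied to a complex input. It is not Ivy's implementation: `toy_handle_complex_input` is a hypothetical name, and the mode names `"split"`, `"magnitude"`, and `"jax"` are assumptions based on how the option is usually described. Only Python's standard library is used.

```python
import cmath
import math
from typing import Callable


def toy_handle_complex_input(fn: Callable[[float], float]) -> Callable:
    """Hypothetical, simplified stand-in for a complex-input wrapper."""

    def wrapped(x, complex_mode: str = "jax"):
        if not isinstance(x, complex):
            return fn(x)  # real inputs go straight to the wrapped function
        if complex_mode == "split":
            # Apply fn to the real and imaginary parts independently.
            return complex(fn(x.real), fn(x.imag))
        if complex_mode == "magnitude":
            # Apply fn to the magnitude and keep the phase unchanged.
            return cmath.rect(fn(abs(x)), cmath.phase(x))
        if complex_mode == "jax":
            # Defer to a separately supplied complex-aware implementation.
            return wrapped.jax_like(x)
        raise ValueError(f"unknown complex_mode: {complex_mode!r}")

    wrapped.jax_like = None  # assigned after decoration, see below
    return wrapped


@toy_handle_complex_input
def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


# Complex-aware version used when complex_mode == "jax" (assumed convention).
sigmoid.jax_like = lambda z: 1 / (1 + cmath.exp(-z))

for mode in ("split", "magnitude", "jax"):
    print(mode, sigmoid(-1 - 1j, complex_mode=mode))
```

The three modes give three different answers for the same input, which is why every layer that forwards arguments to the function (such as the container methods mentioned above) would need to pass the chosen mode through.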

@jshepherd01 left a comment
Contributor

There's some fiddly parts of the testing which will need changing related to the dtypes. I also think you've missed out relu6 in the jax frontend.

About testing: there are currently a lot of test failures, mostly for paddle and torch. Also, complex values never seem to be generated with the tensorflow backend (and are therefore never generated for the functional API tests, since tensorflow is the ground truth for those).

@jshepherd01
Contributor

Looks like tests are still failing. If you run the tests for the jax frontend, you'll see that some backends aren't handling complex values correctly, and if you print the input_dtype for each test you'll see that hypothesis doesn't generate any complex inputs when tensorflow is the chosen backend. Those will both need fixing.

Another thing is that as of yesterday (specifically as of #21902) the stateful API classes no longer need to take parameters like complex_mode in the _forward method, so they'll need to be removed here

@jshepherd01 left a comment
Contributor

I've done some testing of each of the functions you've modified and attached the feedback to each one separately to make it easier to read through

@@ -376,8 +376,13 @@ def relu(
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_complex_input
def sigmoid(
Contributor

Complex numbers are marked as unsupported for sigmoid in the tf and paddle backends

Contributor Author

To fix this, I need to add a jax-like function for sigmoid. I just did that in the latest commit.

Contributor

Are you sure that's needed? I just tested on the main branch and ivy.sigmoid(3+4j) works just fine with both paddle and tensorflow backends

Contributor Author

Running the test without the sigmoid jax-like function gives errors in tensorflow and paddle.

[cpu-paddle-False-False] Failed: [undefined]AssertionError: the results from backend paddle and ground truth framework jax do not match (0.2689414322376251+0j)!=(0.21795842936625076-0.20194822765801285j)

[cpu-tensorflow-False-False] Failed: [undefined]AssertionError: the results from backend tensorflow and ground truth framework jax do not match (1+0j)!=-0j
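
For reference, the numbers in the paddle failure above can be checked by hand. The sketch below uses only Python's standard library (not Ivy, paddle, or jax), and the helper names are invented for this example. It evaluates the sigmoid at `-1-1j`, the input reported for the paddle case in this thread, once on the full complex value and once on the real part alone.

```python
import cmath
import math


def complex_sigmoid(z: complex) -> complex:
    """Sigmoid extended to complex inputs: 1 / (1 + exp(-z))."""
    return 1 / (1 + cmath.exp(-z))


def real_part_sigmoid(z: complex) -> float:
    """Sigmoid of the real part only; the imaginary part is discarded."""
    return 1 / (1 + math.exp(-z.real))


x = -1 - 1j  # input from the paddle falsifying example

jax_truth = 0.21795842936625076 - 0.20194822765801285j  # ground truth from the log
paddle_result = 0.2689414322376251 + 0j                 # paddle output from the log

print("complex sigmoid:  ", complex_sigmoid(x))
print("real-part sigmoid:", real_part_sigmoid(x))
print("matches jax truth:", abs(complex_sigmoid(x) - jax_truth) < 1e-9)
print("matches paddle:   ", abs(real_part_sigmoid(x) - paddle_result.real) < 1e-6)
```

The ground-truth value agrees with the complex formula, while the paddle value agrees with the sigmoid of the real part alone (to single precision). That is consistent with the imaginary part being dropped somewhere along the paddle path, although this is an inference from the numbers and is not confirmed in the thread.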

Contributor

Do you still have the test cases that returned those errors?

Contributor Author

For Paddle:
`Falsifying example: test_jax_sigmoid(
on_device='cpu',
frontend='jax',
backend_fw='paddle',
dtype_and_x=(['complex128'], [array(-1.-1.j)]),
test_flags=FrontendFunctionTestFlags(
num_positional_args=0,
with_out=False,
inplace=False,
as_variable=[False],
native_arrays=[False],
test_compile=False,
generate_frontend_arrays=False,
),
fn_tree='ivy.functional.frontends.jax.nn.sigmoid',
)

You can reproduce this example by temporarily adding @reproduce_failure('6.82.5', b'AXicY2QAAkYGCGBEYzMwAAAAfgAG') as a decorator on your test case`

For Tensorflow:
`Falsifying example: test_jax_sigmoid(
on_device='cpu',
frontend='jax',
backend_fw='tensorflow',
dtype_and_x=(['complex128'], [array(-709.-1.j)]),
test_flags=FrontendFunctionTestFlags(
num_positional_args=0,
with_out=False,
inplace=False,
as_variable=[False],
native_arrays=[False],
test_compile=False,
generate_frontend_arrays=False,
),
fn_tree='ivy.functional.frontends.jax.nn.sigmoid',
)

You can reproduce this example by temporarily adding @reproduce_failure('6.82.5', b'AXicY2QAAkYGMGA6imBDaQAObADM') as a decorator on your test case`
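
A side note on the input in the tensorflow example: a real part of -709 sits just inside the range where `exp` stays finite in 64-bit floats (the limit is roughly 709.78), so the failing input is close to an overflow boundary. The sketch below is a general illustration of why such inputs stress a naive sigmoid. It uses real inputs and the standard library only, the function names are made up for this example, and it does not claim to explain the `(1+0j)` result, which is attributed later in this thread to a TensorFlow issue.

```python
import math


def naive_sigmoid(x: float) -> float:
    # Direct formula; math.exp(-x) raises OverflowError once -x exceeds ~709.78.
    return 1.0 / (1.0 + math.exp(-x))


def stable_sigmoid(x: float) -> float:
    # Only ever exponentiate a non-positive number, so exp() cannot overflow.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


print(naive_sigmoid(-709.0))  # still finite: exp(709) fits in a float64

try:
    naive_sigmoid(-710.0)
except OverflowError as err:
    print("naive form overflowed:", err)

print(stable_sigmoid(-710.0))  # tiny positive number, no exception
```

Branching on the sign of the input is one common way to keep the computation finite; libraries may use other formulations.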

ivy/functional/ivy/activations.py (resolved)
ivy/functional/ivy/experimental/activations.py (outdated, resolved)
@@ -350,8 +410,13 @@ def selu(
@to_native_arrays_and_back
@handle_array_function
@handle_device_shifting
@handle_complex_input
def silu(
Contributor

Complex numbers are marked as unsupported for silu in the tf, torch, and paddle backends. The implementation in the paddle backend throws an error:

    @with_unsupported_device_and_dtypes(
        {"2.5.1 and below": {"cpu": ("bfloat16",)}},
        backend_version,
    )
    def silu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
        if x.dtype in [paddle.float32, paddle.float64]:
            return F.silu(x)
        if paddle.is_complex(x):
>           return x * (1.0 / (1.0 + paddle_backend.exp(-x)))
E           ValueError: (InvalidArgument) __mul__(): argument (position 1) must be int, float, bool or Tensor, but got Array (at /paddle/paddle/fluid/pybind/eager_utils.cc:1435)

Contributor Author

Done

Contributor

I don't think this should be done using a jax_like function; it would be better to modify the paddle backend implementation so that it handles complex inputs properly.

Contributor Author

Without the silu jax-like function, there are errors in tf and paddle during testing:

./ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py::test_jax_silu[cpu-tensorflow-False-False] Failed: [undefined]AssertionError: the results from backend tensorflow and ground truth framework jax do not match (-88-1j)!=(-0+0j)

./ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py::test_jax_silu[cpu-paddle-False-False] Failed: [undefined]AssertionError: the results from backend paddle and ground truth framework jax do not match (-0.2689414322376251-0.2689414322376251j)!=(-0.4199066758155823-0.016010195016860962j)
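
The paddle mismatch for silu can be checked the same way, taking `-1-1j` as the input (the value reported for the paddle case in this thread). The sketch below is standard-library Python with an invented helper name, not the Ivy or paddle code. It compares `x * sigmoid(x)` on the full complex input with `x * sigmoid(real(x))`.

```python
import cmath


def complex_silu(z: complex) -> complex:
    """silu(z) = z * sigmoid(z), with sigmoid evaluated on the full complex input."""
    return z / (1 + cmath.exp(-z))


x = -1 - 1j  # input from the paddle falsifying example

jax_truth = -0.4199066758155823 - 0.016010195016860962j     # ground truth from the log
paddle_result = -0.2689414322376251 - 0.2689414322376251j   # paddle output from the log

# Variant where the sigmoid only sees the real part of the input.
real_only = x / (1 + cmath.exp(-x.real))

print("complex silu:     ", complex_silu(x))
print("real-part variant:", real_only)
print("matches jax truth:", abs(complex_silu(x) - jax_truth) < 1e-6)
print("matches paddle:   ", abs(real_only - paddle_result) < 1e-6)
```

The tolerances are loose because the logged values are single precision (complex64). As with sigmoid, the agreement suggests, but does not prove, that the paddle path applied the sigmoid to the real part only.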

Contributor

Do you have the test cases for these handy?

Contributor Author

Falsifying example: test_jax_silu(
on_device='cpu',
frontend='jax',
backend_fw='tensorflow',
dtype_and_x=(['complex64'], [array(-88.-1.j, dtype=complex64)]),
fn_tree='ivy.functional.frontends.jax.nn.silu',
test_flags=FrontendFunctionTestFlags(
num_positional_args=0,
with_out=False,
inplace=False,
as_variable=[False],
native_arrays=[False],
test_compile=False,
generate_frontend_arrays=False,
),
)

You can reproduce this example by temporarily adding @reproduce_failure('6.82.5', b'AXicY2AAAkYGCIhAYkNoAAZ9AFw=') as a decorator on your test case

Falsifying example: test_jax_silu(
on_device='cpu',
frontend='jax',
backend_fw='paddle',
dtype_and_x=(['complex64'], [array(-1.-1.j, dtype=complex64)]),
fn_tree='ivy.functional.frontends.jax.nn.silu',
test_flags=FrontendFunctionTestFlags(
num_positional_args=0,
with_out=False,
inplace=False,
as_variable=[False],
native_arrays=[False],
test_compile=False,
generate_frontend_arrays=False,
),
)

You can reproduce this example by temporarily adding @reproduce_failure('6.82.5', b'AXicY2AAAkYGCGBEYzMwAAAAXwAF') as a decorator on your test case

ivy/functional/ivy/experimental/activations.py (outdated, resolved)
ivy/functional/backends/tensorflow/activations.py (outdated, resolved)
ivy/functional/backends/tensorflow/activations.py (outdated, resolved)
ivy/functional/backends/torch/activations.py (outdated, resolved)
@jshepherd01
Contributor

I've added more comments to some of the threads I started on the last review. I'll do a proper code review early next week, but if you could address those in the meantime that would be great. Thanks :)

mosesdaudu001 and others added 10 commits August 26, 2023 09:44
And a couple other bugs that were causing things to break, such as
removing duplicated atol from `test_jax_hard_swish` and adding complex
dtype to some tests. These functions now pass all tests in the jax
frontend, except sigmoid and silu which fail for tf backend due to a
known issue
Update supported dtypes for certain functions, and change the safety
factor for logsigmoid. All tests now pass locally, except problems
caused by tensorflow's sigmoid and a problem with the test for elu
(which seems to feed in the wrong values).
@jshepherd01
Contributor

I started experimenting with fixing the errors and ended up just writing up a fix myself so I thought I might as well push it. There remain some test failures caused by an issue with tensorflow's sigmoid (reported here: tensorflow/tensorflow#61800) and some caused by bfloat16 problems with numpy, as well as an issue caused by an apparent bug in the test for elu which still needs to be fixed.

@jshepherd01 removed their assignment Sep 19, 2023
@ivy-seed

This PR has been labelled as stale because it has been inactive for more than 7 days. If you would like to continue working on this PR, then please add another comment or this PR will be closed in 7 days.

@ivy-seed added the Stale label Oct 14, 2023
@github-actions
Contributor

Thank you for this PR, here are the CI results:


This pull request does not result in any additional test failures. Congratulations!

@NripeshN closed this Nov 25, 2023