[Feature Request] Add array API-compatible methods #3430

Closed
njzjz opened this issue Mar 7, 2024 · 0 comments · Fixed by #3922

Comments

@njzjz
Member

njzjz commented Mar 7, 2024

Summary

Add a new module that multiple backends can share to implement methods fully compatible with the Python array API standard.

Detailed Description

Here is an example taken from the array API standard paper.

```python
def some_function(x):
    # Retrieve a standard-compliant namespace
    xp = x.__array_namespace__()

    # Allocate a new array on the same device as x
    y = xp.linspace(0, 2*xp.pi, 100, device=x.device)

    # Perform computation (on device); x must broadcast against shape (100,)
    return xp.sin(y) * x
```

I think this design applies only to functions, not classes: each backend still needs its own classes. Sharing the same classes across backends requires too many hacks and is hard to maintain (I have tried it).
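To make the split concrete, here is a minimal sketch, assuming NumPy as the only backend: one function written once against the array API, plus a backend-specific class that owns its parameters and calls it. The names `normalize`, `NumpyLayer` and `_namespace` are hypothetical and are not deepmd-kit APIs.

```python
import numpy as np


def _namespace(x):
    # Prefer the standard entry point; NumPy < 2.0 arrays lack it, so fall back to NumPy
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np


def normalize(x):
    """Backend-agnostic function (hypothetical), written once against the array API."""
    xp = _namespace(x)
    # L2-normalize along the last axis using only standard functions
    norm = xp.sqrt(xp.sum(x * x, axis=-1, keepdims=True))
    return x / norm


class NumpyLayer:
    """Backend-specific class (hypothetical): owns its parameters as NumPy arrays."""

    def __init__(self, weight):
        self.weight = np.asarray(weight, dtype=np.float64)

    def __call__(self, x):
        # Backend-specific plumbing lives in the class; the math is shared
        return normalize(np.asarray(x, dtype=np.float64) @ self.weight)


layer = NumpyLayer(np.eye(2))
print(layer([[3.0, 4.0]]))
```

A JAX or PyTorch backend would define its own layer class holding its own parameters, but could reuse `normalize` unchanged.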

The main limitation at present is that most libraries do not yet support the latest revision of the array API standard. For example, NumPy will support the 2022.12 revision in v2.0, and JAX will support it in the coming months. As a workaround, array-api-compat is useful but still has limitations (for example, it does not support JAX). For certain methods, we may need to wait.
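The array-api-compat workaround mentioned above exposes `array_api_compat.array_namespace(x)`, which returns a standard-compliant wrapper namespace for supported array types. Below is a hedged sketch of a namespace lookup that prefers it and degrades gracefully when it is not installed; the helper name `get_namespace` and its fallback order are assumptions, not part of any library.

```python
import numpy as np


def get_namespace(x):
    """Return an array API namespace for x (hypothetical helper)."""
    try:
        # array-api-compat wraps libraries that are not fully compliant yet
        import array_api_compat
    except ImportError:
        array_api_compat = None
    if array_api_compat is not None:
        return array_api_compat.array_namespace(x)
    if hasattr(x, "__array_namespace__"):
        # Compliant libraries (e.g. NumPy >= 2.0) expose the namespace directly
        return x.__array_namespace__()
    if isinstance(x, np.ndarray):
        # Last resort for NumPy < 2.0: close to the standard, but not fully compliant
        return np
    raise TypeError(f"no array API namespace for {type(x).__name__}")


x = np.linspace(0.0, 1.0, 5)
xp = get_namespace(x)
print(xp.sum(x))
```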

Another limitation is that some APIs are not part of the standard, for example `np.take_along_axis`.
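Until such functions are standardized, they can often be emulated with functions that are in the standard. Here is a sketch, assuming gathering along the last axis is the case that matters: it flattens the array, offsets each row's indices, and gathers with `take`, using only `reshape`, `arange` and `take` from the 2022.12 revision. The helper name `take_along_last_axis` is hypothetical, and the namespace `xp` is passed in explicitly to keep the sketch short.

```python
import math

import numpy as np


def take_along_last_axis(xp, arr, indices):
    """Emulate take_along_axis(arr, indices, axis=-1) with standard operations.

    Hypothetical helper: arr has shape (..., n), indices has shape (..., m)
    with the same leading dimensions and values in [0, n).
    """
    lead = arr.shape[:-1]
    n = arr.shape[-1]
    m = indices.shape[-1]
    batch = math.prod(lead)  # 1 when arr is 1-D
    # Flatten so each row becomes a contiguous block of n elements
    flat_arr = xp.reshape(arr, (batch * n,))
    # Offset row i's indices by i * n so they address the flattened array
    offsets = xp.reshape(xp.arange(batch, dtype=indices.dtype) * n, (batch, 1))
    flat_idx = xp.reshape(xp.reshape(indices, (batch, m)) + offsets, (batch * m,))
    # On a 1-D array, take is a plain gather
    out = xp.take(flat_arr, flat_idx, axis=0)
    return xp.reshape(out, (*lead, m))


a = np.array([[10, 30, 20], [60, 40, 50]])
idx = np.argsort(a, axis=-1)
print(take_along_last_axis(np, a, idx))
```

Other axes are left out for brevity.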

Several array-consuming libraries already support it, such as SciPy and scikit-learn.

Further Information, Files, and Links

cc @wanghan-iapcm

njzjz added a commit to njzjz/deepmd-kit that referenced this issue Jun 28, 2024
Fix deepmodeling#3430.
This PR sets up basic support for the array API and makes one example function (`compute_smooth_weight`) array API-compatible.
I believe NumPy and JAX support it (directly or through `array-api-compat`), so we no longer need to write things twice for NumPy and JAX (although we could generate the second copy with ChatGPT, it's still better to maintain a single implementation). Using it in TorchScript is challenging, so I gave up on that.
Support for more functions can be added in follow-up PRs.

Signed-off-by: Jinzhe Zeng <[email protected]>
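The commit message mentions porting `compute_smooth_weight` without showing it. As a sketch only, here is what an array API version of a smooth cutoff weight could look like, assuming a quintic switching function (a common choice for smooth cutoffs; the actual deepmd-kit formula and signature may differ). The name `smooth_weight` is hypothetical. It clamps with `maximum`/`minimum` on 0-d arrays rather than Python scalars, because older revisions of the standard only guarantee array arguments for elementwise functions.

```python
import numpy as np


def smooth_weight(distance, rmin, rmax):
    """Quintic switching function written against the array API (hypothetical sketch).

    Returns 1 below rmin, 0 above rmax, and decays smoothly in between.
    """
    if rmin >= rmax:
        raise ValueError("rmin should be less than rmax.")
    # Fall back to NumPy for NumPy < 2.0 arrays, which lack __array_namespace__
    if hasattr(distance, "__array_namespace__"):
        xp = distance.__array_namespace__()
    else:
        xp = np
    lo = xp.asarray(rmin, dtype=distance.dtype)
    hi = xp.asarray(rmax, dtype=distance.dtype)
    # Clamp distances into [rmin, rmax] using array-array broadcasting only
    d = xp.minimum(xp.maximum(distance, lo), hi)
    uu = (d - lo) / (hi - lo)
    # 1 - 10u^3 + 15u^4 - 6u^5: value 1 at u=0, 0 at u=1, flat derivatives at both ends
    return uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1


print(smooth_weight(np.array([0.0, 0.5, 0.75, 1.0, 2.0]), 0.5, 1.0))
```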
njzjz linked a pull request Jun 28, 2024 that will close this issue
github-merge-queue bot pushed a commit that referenced this issue Jun 28, 2024
Fix #3430.
This PR sets up basic support for the array API and makes one example
function (`compute_smooth_weight`) array API-compatible. I believe
NumPy and JAX support it (directly or through `array-api-compat`), so we
no longer need to write things twice for NumPy and JAX (although we could
generate the second copy with ChatGPT, it's still better to maintain a
single implementation). Using it in TorchScript is challenging, so I
gave up on that. Support for more functions can be added in follow-up
PRs.

## Summary by CodeRabbit

- **New Features**
  - Introduced testing for the `compute_smooth_weight` function using
    `array_api_strict` for enhanced array operations.

- **Chores**
  - Updated dependencies to include `'array-api-compat'` and
    `'array-api-strict>=2'` for improved compatibility and testing
    capabilities.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz closed this as completed Jun 28, 2024
mtaillefumier pushed a commit to mtaillefumier/deepmd-kit that referenced this issue Sep 18, 2024
Projects
No open projects
Status: Done