Support Bounded Shape Input #26265

Open
yaochengji opened this issue Feb 2, 2025 · 0 comments
Labels
enhancement New feature or request


JAX currently supports dynamic-shape inputs through padding or bucketed padding. However, this approach is not always ideal: some models need complex logic inside the model itself to handle the padding. Take the Pixtral encoder as an example. Padding the input image shapes to minimize recompilation requires additional logic for the position embeddings and masks, because both depend on the input shapes. This padding logic is not obvious to users. Moreover, different models need different padding logic, which makes it a major challenge for JAX/XLA to support such models.
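
To make the current workaround concrete, here is a minimal sketch of bucketed padding in today's JAX. The bucket sizes and the helper names `pad_to_bucket` and `masked_mean` are hypothetical and chosen only for illustration; the mask inside `masked_mean` stands in for the model-specific padding logic described above.

```python
import jax
import jax.numpy as jnp
import numpy as np

BUCKETS = (8, 16, 32)  # hypothetical bucket sizes, chosen for illustration


def pad_to_bucket(x):
    """Pad a 1-D NumPy array to the smallest bucket that fits it.

    Returns the padded array and the true length as an int32 scalar.
    Inputs longer than the largest bucket are not handled in this sketch.
    """
    n = x.shape[0]
    bucket = next(b for b in BUCKETS if b >= n)
    padded = np.zeros((bucket,), dtype=x.dtype)
    padded[:n] = x
    return jnp.asarray(padded), jnp.int32(n)


@jax.jit
def masked_mean(padded, true_len):
    # User-written padding logic: only the first `true_len` entries are real.
    mask = jnp.arange(padded.shape[0]) < true_len
    total = jnp.sum(jnp.where(mask, padded, 0))
    return total / true_len


# Lengths 5 and 7 both land in the 8-element bucket, so the jitted
# function sees the same input shape for both calls.
a = masked_mean(*pad_to_bucket(np.arange(5, dtype=np.float32)))
b = masked_mean(*pad_to_bucket(np.arange(7, dtype=np.float32)))
print(a, b)  # 2.0 3.0
```

Every operation that reduces over or indexes into the padded axis needs a similar mask, which is the burden this request would like to move into the compiler.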

This feature would allow the compiler to handle padding internally, simplifying user code and improving the overall experience. It requires work on both the JAX framework and the XLA compiler. Discussions with the XLA team have confirmed that XLA already supports this capability. In the example below, the XLA compiler needs additional scalar input(s) to know the actual shape of the parameter.

  param = s32[4] parameter(0)
  size = s32[] parameter(1)
  param_dynamic = s32[<=4] set-dimension-size(param, size)
  // ... use param_dynamic in the body of the program...
  output_dynamic = s32[<=4] ....
  output_size = s32[] get-dimension-size(output_dynamic)
  output_static = s32[4] remove-dynamic-size(output_dynamic)
  ROOT (s32[4], s32[]) tuple(output_static, output_size)
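
For comparison, the calling convention of the HLO above (a statically shaped buffer plus a scalar size going in, and the same pair coming out) can be emulated by hand in today's JAX. The sketch below is only an illustration of that interface, not a use of XLA's bounded shapes; the function `keep_positive` and its filtering behavior are hypothetical, picked because the output size differs from the input size.

```python
import jax
import jax.numpy as jnp


@jax.jit
def keep_positive(param, size):
    """Keep the positive entries among the first `size` elements of `param`.

    Returns (output_static, output_size), mirroring the HLO tuple above:
    a buffer with the static bound shape and an int32 scalar giving the
    number of valid leading elements.
    """
    bound = param.shape[0]
    positions = jnp.arange(bound)
    valid = positions < size          # hand-written stand-in for set-dimension-size
    keep = valid & (param > 0)
    output_size = jnp.sum(keep).astype(jnp.int32)
    # `size=` gives nonzero a static output shape so it can be used under jit.
    idx = jnp.nonzero(keep, size=bound, fill_value=0)[0]
    output_static = jnp.where(positions < output_size, param[idx], 0)
    return output_static, output_size


out, n = keep_positive(jnp.array([3, -1, 5, 9], dtype=jnp.int32), jnp.int32(3))
print(out, n)  # [3 5 0 0] 2
```

The element `9` is ignored because it lies beyond the declared size of 3, and the caller has to carry `n` alongside `out` manually.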

Enabling XLA's bounded shape capabilities within JAX would be highly beneficial. One possible solution could involve the following steps:

  1. Augment the jax.Array class with a new attribute to store the actual shape of the array.
  2. Introduce a method for creating a "bounded shape array" derived from a standard array.
  3. Extend the lowering process so that these bounded shape arrays are correctly translated into HLO (a param, a size, and a set-dimension-size instruction, as in the example above).
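
Steps 1 and 2 could look roughly like the following from the user's side. This is a sketch under stated assumptions: `BoundedArray` and `bounded` are hypothetical names that do not exist in JAX, and the masking inside `bounded_sum` is written by hand here, whereas the point of step 3 is that the compiler would handle it.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class BoundedArray(NamedTuple):
    """Hypothetical container, not a JAX API."""
    data: jax.Array  # buffer with the static upper-bound shape
    size: jax.Array  # int32 scalar: number of valid leading elements


def bounded(x, bound):
    """Hypothetical constructor: pad a 1-D array to `bound` and record its true length."""
    n = x.shape[0]
    if n > bound:
        raise ValueError(f"length {n} exceeds bound {bound}")
    return BoundedArray(jnp.pad(x, (0, bound - n)), jnp.asarray(n, dtype=jnp.int32))


@jax.jit
def bounded_sum(b):
    # NamedTuples are pytrees, so `b` passes through jit as (data, size).
    mask = jnp.arange(b.data.shape[0]) < b.size
    return jnp.sum(jnp.where(mask, b.data, 0))


x3 = bounded(jnp.array([1, 2, 3]), bound=4)
x2 = bounded(jnp.array([1, 2]), bound=4)
print(bounded_sum(x3), bounded_sum(x2))  # 6 3
```

Both calls present the same static shape to the jitted function, so arrays of different true lengths share one trace as long as they fit within the bound.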