JAX currently supports dynamic input shapes only through padding or bucket padding. This approach isn't always ideal, because some models need complex logic inside their structure to manage the padding. Consider the Pixtral encoder. Padding input image shapes to minimize recompilation requires extra logic for the position embeddings and masks, since both depend on the input shape. This padding logic isn't obvious to users. What's more, different models need different padding logic, which makes it very hard for JAX/XLA to support such models in a general way.
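For reference, here is a minimal sketch of the bucket-padding workaround described above, assuming a 1-D input sequence and hypothetical bucket sizes `BUCKETS = (16, 32, 64)`. `pad_to_bucket` and `encode` are illustrative names, not JAX APIs, and the "encoder" is a toy. `jax.jit` compiles once per bucket shape, and the true length is passed as a traced scalar so that changing it does not trigger recompilation. The mask and position ids are the kind of hand-written, model-specific padding logic this issue would like the compiler to absorb.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical bucket sizes; a real model would pick these per workload.
BUCKETS = (16, 32, 64)


def pad_to_bucket(x: np.ndarray, buckets=BUCKETS):
    """Zero-pad axis 0 of `x` up to the smallest bucket that fits it."""
    n = x.shape[0]
    bucket = next(b for b in buckets if b >= n)  # StopIteration if x is too long
    padded = np.zeros((bucket,) + x.shape[1:], dtype=x.dtype)
    padded[:n] = x
    return padded, n


@jax.jit
def encode(tokens, length):
    # Padding-aware logic the user has to write by hand:
    # a validity mask and position ids that ignore the padded tail.
    idx = jnp.arange(tokens.shape[0])
    mask = idx < length
    positions = jnp.where(mask, idx, 0)
    # Toy "encoder": add positions, then average over valid entries only.
    h = tokens + positions
    return jnp.sum(jnp.where(mask, h, 0)) / length
```

For example, a 10-element input is padded to length 16, and `encode(*pad_to_bucket(np.arange(10, dtype=np.float32)))` returns 9.0; any other input of length up to 16 reuses the same compiled program.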
This feature would let the compiler handle padding internally, simplifying user code and improving the overall experience. It requires work in both the JAX framework and the XLA compiler. Discussions with the XLA team have confirmed that XLA already supports this capability. In the example below, the XLA compiler receives additional scalar input(s) that give the actual shape of `param`:
```
param = s32[4] parameter(0)
size = s32[] parameter(1)
param_dynamic = s32[<=4] set-dimension-size(param, size), dimensions={0}
// ... use param_dynamic in the body of the program ...
output_dynamic = s32[<=4] ....
output_size = s32[] get-dimension-size(output_dynamic), dimensions={0}
output_static = s32[4] remove-dynamic-size(output_dynamic)
ROOT result = (s32[4], s32[]) tuple(output_static, output_size)
```
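The contract of this HLO (a static `s32[4]` buffer plus a scalar size in; a static buffer plus a scalar size out) can be emulated in today's JAX with explicit masking. The sketch below is only such an emulation: `bounded_program` is a hypothetical name, the prefix sum is an arbitrary stand-in for the elided program body, and the padded tail is zeroed by choice. Unlike the real HLO, XLA sees only static shapes here, so every masking step has to be written by hand.

```python
import jax
import jax.numpy as jnp


@jax.jit
def bounded_program(param, size):
    """Emulate the HLO above with a static buffer plus an explicit size.

    param: int32[4] buffer; only the first `size` entries are meaningful.
    size:  int32 scalar, the runtime length (<= 4).
    """
    # Stand-in for set-dimension-size: mark which entries are valid.
    valid = jnp.arange(param.shape[0]) < size
    # Stand-in for the program body: a prefix sum over the valid entries.
    body = jnp.cumsum(jnp.where(valid, param, 0))
    # Stand-in for remove-dynamic-size: keep shape [4], zero the padded tail.
    output_static = jnp.where(valid, body, 0)
    # A prefix sum preserves length, so the output size equals the input size.
    output_size = size
    return output_static, output_size
```

For example, `bounded_program(jnp.array([1, 2, 3, 4], dtype=jnp.int32), jnp.int32(2))` returns the buffer `[1, 3, 0, 0]` and the size `2`.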
Enabling XLA's bounded shape capabilities within JAX would be highly beneficial. One possible solution could involve the following steps:
1. Augment the `jax.Array` class with a new attribute that stores the array's actual shape.
2. Introduce a method for creating a "bounded-shape array" from a standard array.
3. Extend the lowering so that bounded-shape arrays are translated correctly into HLO (a `param`, a scalar `size` parameter, and `set-dimension-size`).
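Steps 1 and 2 can be approximated at the user level today. The following is a minimal sketch under stated assumptions: instead of adding an attribute to `jax.Array`, it wraps a padded buffer and its actual length along axis 0 in a hypothetical `BoundedArray` pytree (the class, `from_array`, `mask`, and `masked_sum` are illustrative names, not JAX APIs). It does not implement step 3: the lowered HLO still contains only static shapes, so each op must apply the mask itself.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class BoundedArray:
    """Hypothetical bounded-shape array (not part of JAX).

    `data` has the static upper-bound shape; `size` is the actual length of axis 0.
    """

    def __init__(self, data, size):
        self.data = data
        self.size = size

    @classmethod
    def from_array(cls, x, bound):
        # Step 2 sketch: derive a bounded array from a standard one by padding axis 0.
        n = x.shape[0]
        if n > bound:
            raise ValueError(f"length {n} exceeds bound {bound}")
        pad = [(0, bound - n)] + [(0, 0)] * (x.ndim - 1)
        return cls(jnp.pad(x, pad), jnp.int32(n))

    def mask(self):
        # Validity mask over axis 0, reshaped to broadcast against `data`.
        m = jnp.arange(self.data.shape[0]) < self.size
        return m.reshape(m.shape + (1,) * (self.data.ndim - 1))

    # Pytree protocol: both fields are leaves, so `size` is traced, not static.
    def tree_flatten(self):
        return (self.data, self.size), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


@jax.jit
def masked_sum(b: BoundedArray):
    # Without step 3 (lowering to set-dimension-size), ops still mask manually.
    return jnp.sum(jnp.where(b.mask(), b.data, 0))
```

Because `size` is a pytree leaf rather than static metadata, `masked_sum` compiles once for a bound of 8 and is reused for `BoundedArray.from_array(jnp.arange(3.0), bound=8)` and `BoundedArray.from_array(jnp.arange(5.0), bound=8)`, returning 3.0 and 10.0. Lowering such a size leaf to `set-dimension-size` is what step 3 would add.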