Skip to content

Commit

Permalink
fix (JAX backends)(statistical.py): adding a check to the implementat…
Browse files Browse the repository at this point in the history
…ion for `sum` to cast the output to an integer in cases when the input is a boolean array. This is also inline with the behavior of `jnp.sum`
  • Loading branch information
YushaArif99 committed Oct 3, 2024
1 parent 595040b commit da1d258
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions ivy/functional/backends/jax/statistical.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# global
import jax
import jax.numpy as jnp
from typing import Union, Optional, Sequence

Expand Down Expand Up @@ -112,6 +113,11 @@ def sum(
if dtype != x.dtype and not ivy.is_bool_dtype(x):
x = jnp.astype(x, dtype)
axis = tuple(axis) if isinstance(axis, list) else axis
if ivy.is_bool_dtype(x):
if jax.config.jax_enable_x64:
dtype = ivy.as_native_dtype('int64')
else:
dtype = ivy.as_native_dtype('int32')
return jnp.sum(a=x, axis=axis, dtype=dtype, keepdims=keepdims)


Expand Down

0 comments on commit da1d258

Please sign in to comment.