diff --git a/ivy/functional/frontends/jax/array.py b/ivy/functional/frontends/jax/array.py index 8434a3578d0a0..290f7a2fd9c76 100644 --- a/ivy/functional/frontends/jax/array.py +++ b/ivy/functional/frontends/jax/array.py @@ -163,6 +163,27 @@ def sort(self, axis=-1, order=None): order=order, ) + def sum( + self, + axis=None, + dtype=None, + out=None, + keepdims=False, + initial=None, + where=None, + promote_integers=True, + ): + return jax_frontend.numpy.sum( + self, + axis=axis, + dtype=dtype, + out=out, + keepdims=keepdims, + initial=initial, + where=where, + promote_integers=promote_integers, + ) + def argsort(self, axis=-1, kind="stable", order=None): return jax_frontend.numpy.argsort(self, axis=axis, kind=kind, order=order) diff --git a/ivy/functional/frontends/jax/numpy/statistical.py b/ivy/functional/frontends/jax/numpy/statistical.py index 26ac5fc5741ec..069e6aedba66b 100644 --- a/ivy/functional/frontends/jax/numpy/statistical.py +++ b/ivy/functional/frontends/jax/numpy/statistical.py @@ -95,19 +95,20 @@ def sum( where=None, promote_integers=True, ): - if dtype is None: - dtype = "float32" if ivy.is_int_dtype(a.dtype) else ivy.as_ivy_dtype(a.dtype) - # TODO: promote_integers is only supported from JAX v0.4.10 if dtype is None and promote_integers: - if ivy.is_bool_dtype(dtype): + if ivy.is_bool_dtype(a.dtype): dtype = ivy.default_int_dtype() - elif ivy.is_uint_dtype(dtype): - if ivy.dtype_bits(dtype) < ivy.dtype_bits(ivy.default_uint_dtype()): - dtype = ivy.default_uint_dtype() - elif ivy.is_int_dtype(dtype): - if ivy.dtype_bits(dtype) < ivy.dtype_bits(ivy.default_int_dtype()): - dtype = ivy.default_int_dtype() + elif ivy.is_uint_dtype(a.dtype): + dtype = "uint64" + a = ivy.astype(a, dtype) + elif ivy.is_int_dtype(a.dtype): + dtype = "int64" + a = ivy.astype(a, dtype) + else: + dtype = a.dtype + elif dtype is None and not promote_integers: + dtype = "float32" if ivy.is_int_dtype(a.dtype) else ivy.as_ivy_dtype(a.dtype) if initial: if axis is None: diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py index 888c76c18b73b..8824b7261c030 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_array.py @@ -705,6 +705,49 @@ def test_jax_array_sort( ) +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="jax.numpy.array", + method_name="sum", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=1, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + valid_axis=True, + force_int_axis=True, + ), +) +def test_jax_sum( + dtype_and_x, + on_device, + frontend, + frontend_method_data, + backend_fw, + init_flags, + method_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_method( + backend_to_test=backend_fw, + init_input_dtypes=input_dtype, + init_all_as_kwargs_np={ + "object": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "axis": axis, + }, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, + atol_=1e-04, + ) + + @handle_frontend_method( class_tree=CLASS_TREE, init_tree="jax.numpy.array",