From 5085f21f7e34f7b7e4bc193b1bf7ea746ca075b7 Mon Sep 17 00:00:00 2001 From: marwan Date: Wed, 2 Aug 2023 16:25:58 +0300 Subject: [PATCH] foo --- .../frontends/jax/numpy/statistical.py | 13 +++++- .../test_jax/test_numpy/test_statistical.py | 40 +++++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/ivy/functional/frontends/jax/numpy/statistical.py b/ivy/functional/frontends/jax/numpy/statistical.py index 4be5c7af4a301..1a304e4c9dd42 100644 --- a/ivy/functional/frontends/jax/numpy/statistical.py +++ b/ivy/functional/frontends/jax/numpy/statistical.py @@ -1,5 +1,5 @@ # local - +from typing import Optional, Union, Tuple import ivy from ivy.func_wrapper import with_unsupported_dtypes from ivy.functional.frontends.jax.func_wrapper import ( @@ -486,3 +486,14 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=N return ivy.cov( m, y, rowVar=rowvar, bias=bias, ddof=ddof, fweights=fweights, aweights=aweights ) + + +@to_ivy_arrays_and_back +def histogram( + a: Union[ivy.Array, ivy.NativeArray], + bins: Optional[Union[int, ivy.Array, ivy.NativeArray]] = 10, + range: Optional[Tuple[float]] = None, + weights: Optional[Union[ivy.Array, ivy.NativeArray]] = None, + density: Optional[bool] = False, +) -> ivy.Array: + return ivy.histogram(a, bins=bins, range=range, weights=weights, density=density) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py index 1f212f5d435f9..e567da4068ef9 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py @@ -13,6 +13,9 @@ _statistical_dtype_values, _get_castable_dtype, ) +from ivy_tests.test_ivy.test_functional.test_experimental.test_core.test_statistical import ( + _histogram_helper, +) from ivy import inf @@ -1185,3 +1188,40 @@ def test_jax_cov( fweights=fweights, aweights=aweights, ) + + +@handle_frontend_test(fn_tree="jax.numpy.histogram", values=_histogram_helper()) +def test_jax_histogram( + *, + values, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + ( + a, + bins, + axis, + extend_lower_interval, + extend_upper_interval, + dtype, + range, + weights, + density, + dtype_input, + ) = values + helpers.test_frontend_function( + a=a, + bins=bins, + range=range, + weights=weights, + density=density, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input_dtypes=[dtype_input], + )