diff --git a/ivy/functional/frontends/tensorflow/general_functions.py b/ivy/functional/frontends/tensorflow/general_functions.py index cc98ef66453e8..96e9118fdf04e 100644 --- a/ivy/functional/frontends/tensorflow/general_functions.py +++ b/ivy/functional/frontends/tensorflow/general_functions.py @@ -400,6 +400,11 @@ def scan( return ivy.associative_scan(elems, fn, reverse=reverse) +@to_ivy_arrays_and_back +def scatter_nd(indices, updates, shape, name=None): + ivy.scatter_nd(indices, updates, shape) + + @to_ivy_arrays_and_back def searchsorted(sorted_sequence, values, side="left", out_type="int32"): out_type = to_ivy_dtype(out_type) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py index bec86d46301a2..8711483761cbc 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_general_functions.py @@ -367,6 +367,83 @@ def _strided_slice_helper(draw): return dtype, x, np.array(begin), np.array(end), np.array(strides), masks +@st.composite +def _values_and_ndindices( + draw, + *, + array_dtypes, + indices_dtypes=helpers.get_dtypes("integer"), + allow_inf=False, + x_min_value=None, + x_max_value=None, + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=10, +): + x_dtype, x, x_shape = draw( + helpers.dtype_and_values( + available_dtypes=array_dtypes, + allow_inf=allow_inf, + ret_shape=True, + min_value=x_min_value, + max_value=x_max_value, + min_num_dims=min_num_dims, + max_num_dims=max_num_dims, + min_dim_size=min_dim_size, + max_dim_size=max_dim_size, + ) + ) + x_dtype = x_dtype[0] if isinstance(x_dtype, (list)) else x_dtype + x = x[0] if isinstance(x, (list)) else x + # indices_dims defines how far into the array to index. + indices_dims = draw( + helpers.ints( + min_value=1, + max_value=len(x_shape) - 1, + ) + ) + + # num_ndindices defines the number of elements to generate. + num_ndindices = draw( + helpers.ints( + min_value=1, + max_value=x_shape[indices_dims], + ) + ) + + # updates_dims defines how far into the array to index. + updates_dtype, updates = draw( + helpers.dtype_and_values( + available_dtypes=array_dtypes, + allow_inf=allow_inf, + shape=x_shape[indices_dims:], + num_arrays=num_ndindices, + shared_dtype=True, + ) + ) + updates_dtype = ( + updates_dtype[0] if isinstance(updates_dtype, list) else updates_dtype + ) + updates = updates[0] if isinstance(updates, list) else updates + + indices = [] + indices_dtype = draw(st.sampled_from(indices_dtypes)) + for _ in range(num_ndindices): + nd_index = [] + for j in range(indices_dims): + axis_index = draw( + helpers.ints( + min_value=0, + max_value=max(0, x_shape[j] - 1), + ) + ) + nd_index.append(axis_index) + indices.append(nd_index) + indices = np.array(indices) + return [x_dtype, indices_dtype, updates_dtype], x, indices, updates + + @st.composite def _x_cast_dtype_shape(draw): x_dtype = draw(helpers.get_dtypes("valid", full=False)) @@ -1699,6 +1776,33 @@ def _test_fn(a, x): ) +@handle_frontend_test( + fn_tree="tensorflow.scatter_nd", + x=_values_and_ndindices( + array_dtypes=helpers.get_dtypes("numeric"), + indices_dtypes=["int32", "int64"], + x_min_value=0, + x_max_value=0, + min_num_dims=2, + allow_inf=False, + ), +) +def test_tensorflow_scatter_nd(x, frontend, backend_fw, test_flags, fn_tree, on_device): + (val_dtype, ind_dtype, update_dtype), vals, ind, updates = x + shape = vals.shape + helpers.test_frontend_function( + input_dtypes=[val_dtype, ind_dtype, update_dtype], + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + indices=np.asarray(ind, dtype=ind_dtype), + updates=updates, + shape=shape, + ) + + # searchsorted @handle_frontend_test( fn_tree="tensorflow.searchsorted",