Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Added scatter_nd tensorflow frontend function #27274

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ivy/functional/frontends/tensorflow/general_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,11 @@ def scan(
return ivy.associative_scan(elems, fn, reverse=reverse)


@to_ivy_arrays_and_back
def scatter_nd(indices, updates, shape, name=None):
ivy.scatter_nd(indices, updates, shape)


@to_ivy_arrays_and_back
def searchsorted(sorted_sequence, values, side="left", out_type="int32"):
out_type = to_ivy_dtype(out_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,83 @@ def _strided_slice_helper(draw):
return dtype, x, np.array(begin), np.array(end), np.array(strides), masks


@st.composite
def _values_and_ndindices(
draw,
*,
array_dtypes,
indices_dtypes=helpers.get_dtypes("integer"),
allow_inf=False,
x_min_value=None,
x_max_value=None,
min_num_dims=2,
max_num_dims=5,
min_dim_size=1,
max_dim_size=10,
):
x_dtype, x, x_shape = draw(
helpers.dtype_and_values(
available_dtypes=array_dtypes,
allow_inf=allow_inf,
ret_shape=True,
min_value=x_min_value,
max_value=x_max_value,
min_num_dims=min_num_dims,
max_num_dims=max_num_dims,
min_dim_size=min_dim_size,
max_dim_size=max_dim_size,
)
)
x_dtype = x_dtype[0] if isinstance(x_dtype, (list)) else x_dtype
x = x[0] if isinstance(x, (list)) else x
# indices_dims defines how far into the array to index.
indices_dims = draw(
helpers.ints(
min_value=1,
max_value=len(x_shape) - 1,
)
)

# num_ndindices defines the number of elements to generate.
num_ndindices = draw(
helpers.ints(
min_value=1,
max_value=x_shape[indices_dims],
)
)

# updates_dims defines how far into the array to index.
updates_dtype, updates = draw(
helpers.dtype_and_values(
available_dtypes=array_dtypes,
allow_inf=allow_inf,
shape=x_shape[indices_dims:],
num_arrays=num_ndindices,
shared_dtype=True,
)
)
updates_dtype = (
updates_dtype[0] if isinstance(updates_dtype, list) else updates_dtype
)
updates = updates[0] if isinstance(updates, list) else updates

indices = []
indices_dtype = draw(st.sampled_from(indices_dtypes))
for _ in range(num_ndindices):
nd_index = []
for j in range(indices_dims):
axis_index = draw(
helpers.ints(
min_value=0,
max_value=max(0, x_shape[j] - 1),
)
)
nd_index.append(axis_index)
indices.append(nd_index)
indices = np.array(indices)
return [x_dtype, indices_dtype, updates_dtype], x, indices, updates


@st.composite
def _x_cast_dtype_shape(draw):
x_dtype = draw(helpers.get_dtypes("valid", full=False))
Expand Down Expand Up @@ -1699,6 +1776,33 @@ def _test_fn(a, x):
)


@handle_frontend_test(
fn_tree="tensorflow.scatter_nd",
x=_values_and_ndindices(
array_dtypes=helpers.get_dtypes("numeric"),
indices_dtypes=["int32", "int64"],
x_min_value=0,
x_max_value=0,
min_num_dims=2,
allow_inf=False,
),
)
def test_tensorflow_scatter_nd(x, frontend, backend_fw, test_flags, fn_tree, on_device):
(val_dtype, ind_dtype, update_dtype), vals, ind, updates = x
shape = vals.shape
helpers.test_frontend_function(
input_dtypes=[val_dtype, ind_dtype, update_dtype],
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
indices=np.asarray(ind, dtype=ind_dtype),
updates=updates,
shape=shape,
)


# searchsorted
@handle_frontend_test(
fn_tree="tensorflow.searchsorted",
Expand Down
Loading