Skip to content

Commit

Permalink
Merge pull request #468 from hakymulla/gather_elements
Browse files Browse the repository at this point in the history
Gather elements Operator
  • Loading branch information
raphaelDkhn authored Nov 30, 2023
2 parents 0b92daa + 3888e39 commit 16234c1
Show file tree
Hide file tree
Showing 80 changed files with 4,506 additions and 7 deletions.
5 changes: 5 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased] - 2023-11-20

## Added
- Gather Elements Operator.

## [Unreleased] - 2023-11-06

## Added
Expand Down
1 change: 1 addition & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
* [tensor.reduce\_sum\_square](framework/operators/tensor/tensor.reduce\_sum\_square.md)
* [tensor.reduce\_l2](framework/operators/tensor/tensor.reduce\_l2.md)
* [tensor.reduce\_l1](framework/operators/tensor/tensor.reduce\_l1.md)
* [tensor.gather\_elements](framework/operators/tensor/tensor.gather\_elements.md)
* [tensor.sequence\_length](framework/operators/tensor/tensor.sequence\_length.md)
* [tensor.sequence\_at](framework/operators/tensor/tensor.sequence\_at.md)
* [tensor.reduce\_min](framework/operators/tensor/tensor.reduce\_min.md)
Expand Down
1 change: 1 addition & 0 deletions docs/framework/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ You can see below the list of current supported ONNX Operators:
| [ConstantOfShape](operators/tensor/tensor.constant_of_shape.md) | :white\_check\_mark: |
| [ReduceL1](operators/tensor/tensor.reduce\_l1.md) | :white\_check\_mark: |
| [ReduceL2](operators/tensor/tensor.reduce\_l2.md) | :white\_check\_mark: |
| [GatherElements](operators/tensor/tensor.gather/_elements.md) | :white\_check\_mark: |
| [SequenceLength](operators/tensor/tensor.sequence\_length.md) | :white\_check\_mark: |
| [SequenceAt](operators/tensor/tensor.sequence\_at.md) | :white\_check\_mark: |
| [SequenceConstruct](operators/tensor/tensor.sequence\_construct.md) | :white\_check\_mark: |
Expand Down
3 changes: 3 additions & 0 deletions docs/framework/operators/tensor/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ use orion::operators::tensor::TensorTrait;
| [`tensor.scatter`](tensor.scatter.md) | Produces a copy of input data, and updates value to values specified by updates at specific index positions specified by indices. |
| [`tensor.reduce_sum_square`](tensor.reduce\_sum\_square.md) | Computes the sum square of the input tensor's elements along the provided axes. |
| [`tensor.reduce_l2`](tensor.reduce\_l2.md) | Computes the L2 norm of the input tensor's elements along the provided axes. |
| [`tensor.gather_elements`](tensor.gather\_elements.md) | GatherElements is an indexing operation that produces its output by indexing into the input data tensor at index positions determined by elements of the indices tensor. |
| [`tensor.reduce_min`](tensor.reduce\_min.md) | Computes the min of the input tensor's elements along the provided axes. |
| [`tensor.sequence_construct`](tensor.sequence\_construct.md) | Constructs a tensor sequence containing the input tensors. |
| [`tensor.sequence_length`](tensor.sequence\_length.md) | Returns the length of the input sequence. |
| [`tensor.shrink`](tensor.shrink.md) | Shrinks the input tensor element-wise to the output tensor with the same datatype and shape based on a defined formula. |
| [`tensor.sequence_empty`](tensor.sequence\_empty.md) | Returns an empty tensor sequence. |
Expand Down
47 changes: 47 additions & 0 deletions docs/framework/operators/tensor/tensor.gather_elements.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# tensor.gather_elements

```rust
fn gather_elements(self: @Tensor<T>, indices: Tensor<T>, axis: Option<usize>) -> Tensor<T>;
```

GatherElements is an indexing operation that produces its output by indexing into the input data tensor at index positions determined by elements of the indices tensor.

## Args

* `self`(`@Tensor<T>`) - The input tensor.
* `indices`(`Tensor<T>`) - Tensor of indices.
* `axis`(`Option<usize>`) - Axis to gather_elements on. Default: axis=0.

## Panics

* Panics if index values are not within bounds [-s, s-1] along axis of size s.

## Returns

A new `Tensor<T>` .

## Example

```rust
use array::{ArrayTrait, SpanTrait};

use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor};

fn gather_elements_example() -> Tensor<u32> {
let tensor = TensorTrait::<u32>::new(
shape: array![3, 3].span(),
data: array![[ 1, 2, 3],[4, 5, 6], [7, 8, 9]].span(),
);
let indices = TensorTrait::<u32>::new(
shape: array![1, 2, 0].span(),
data: array![2, 0, 0].span(),
);

return tensor.gather_elements(
indices: indices,
axis: Option::None(()),
);
}
>>> [[4. 8. 3.]
[7. 2. 3.]]
```
268 changes: 268 additions & 0 deletions nodegen/node/gather_elements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
import numpy as np
from nodegen.node import RunAll
from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait

def gather_elements(data, indices, axis=0): # type: ignore
data_swaped = np.swapaxes(data, 0, axis)
index_swaped = np.swapaxes(indices, 0, axis)
gathered = np.choose(index_swaped, data_swaped, mode="wrap")
y = np.swapaxes(gathered, 0, axis)
return y

class Gather_elements(RunAll):

@staticmethod
def gather_elements_fp16x16():
def gather_elements_3D():
def default():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64)
x2 = np.random.randint(low = 0,high=2, size=(3,3,3)).astype(np.uint32)
y = gather_elements(x1, x2, axis=0)

x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))

name = "gather_elements_fp16x16_3d_default"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(0))",
name= name)

def axis1():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64)
x2 = np.random.randint(low = 0,high=3, size=(3,3,3)).astype(np.uint32)
y = gather_elements(x1, x2, axis=1)

x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))

name = "gather_elements_fp16x16_3d_axis1"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(1))",
name= name)

def axis2():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64)
x2 = np.random.randint(low = 0,high=3, size=(3,3,3)).astype(np.uint32)
y = gather_elements(x1, x2, axis=2)

x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))

name = "gather_elements_fp16x16_3d_axis2"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(2))",
name= name)

default()
axis1()
axis2()
gather_elements_3D()


@staticmethod
def gather_elements_fp8x23():
def gather_elements_3D():
def default():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64)
x2 = np.random.randint(low = 0,high=2, size=(3,3,3)).astype(np.int64)
y = gather_elements(x1, x2, axis=0)

x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23))

name = "gather_elements_fp8x23_3d_default"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(0))",
name= name)

def axis1():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64)
x2 = np.random.randint(low = 0,high=3, size=(3,3,3)).astype(np.int64)
y = gather_elements(x1, x2, axis=1)

x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23))

name = "gather_elements_fp8x23_3d_axis1"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(1))",
name= name)

def axis2():
x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64)
x2 = np.random.randint(low = 0,high=3, size=(3,3,3)).astype(np.int64)
y = gather_elements(x1, x2, axis=2)

x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23))
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23))

name = "gather_elements_fp8x23_3d_axis2"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(2))",
name= name)

default()
axis1()
axis2()
gather_elements_3D()


@staticmethod
def gather_elements_i8():
def gather_elements_3D():
def default():
x1 = np.arange(0,9).reshape(3,3).astype(np.int8)
x2 = np.random.randint(low = 0,high=2, size=(3,3)).astype(np.int8)
y = gather_elements(x1, x2, axis=0)

x1 = Tensor(Dtype.I8, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())

name = "gather_elements_i8_3d_default"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(0))",
name= name)

def axis1():
x1 = np.arange(0,9).reshape(3,3).astype(np.int8)
x2 = np.random.randint(low = 0,high=2, size=(3,3)).astype(np.int8)
y = gather_elements(x1, x2, axis=1)

x1 = Tensor(Dtype.I8, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())

name = "gather_elements_i8_3d_axis1"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(1))",
name= name)

default()
axis1()
gather_elements_3D()


@staticmethod
def gather_elements_i32():
def gather_elements_3D():
def default():
x1 = np.arange(0,24).reshape(4,2,3).astype(np.int32)
x2 = np.random.randint(low = 0,high=2, size=(5,2,3)).astype(np.int32)
y = gather_elements(x1, x2, axis=0)

x1 = Tensor(Dtype.I32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())

name = "gather_elements_i32_3d_default"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(0))",
name= name)

def axis1():
x1 = np.arange(0,24).reshape(4,2,3).astype(np.int32)
x2 = np.random.randint(low = 0,high=2, size=(4,3,3)).astype(np.int32)
y = gather_elements(x1, x2, axis=1)

x1 = Tensor(Dtype.I32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())

name = "gather_elements_i32_3d_axis1"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(1))",
name= name)

def axis2():
x1 = np.arange(0,24).reshape(4,2,3).astype(np.int32)
x2 = np.random.randint(low = 0,high=2, size=(4,2,4)).astype(np.int32)
y = gather_elements(x1, x2, axis=2)

x1 = Tensor(Dtype.I32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())

name = "gather_elements_i32_3d_axis2"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(2))",
name= name)

default()
axis1()
axis2()
gather_elements_3D()

@staticmethod
def gather_elements_u32():
def gather_elements_3D():
def default():
x1 = np.arange(0,108).reshape(3,3,4,3).astype(np.int32)
x2 = np.random.randint(low = 0,high=3, size=(10,3,4,3)).astype(np.int32)
y = gather_elements(x1, x2, axis=0)

x1 = Tensor(Dtype.U32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())

name = "gather_elements_u32_default"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(0))",
name= name)

def axis1():
x1 = np.arange(0,108).reshape(3,3,4,3).astype(np.int32)
x2 = np.random.randint(low = 0,high=3, size=(3,5,4,3)).astype(np.int32)
y = gather_elements(x1, x2, axis=1)

x1 = Tensor(Dtype.U32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())

name = "gather_elements_u32_axis1"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(1))",
name= name)

def axis2():
x1 = np.arange(0,108).reshape(3,3,4,3).astype(np.int32)
x2 = np.random.randint(low = 0,high=3, size=(3,3,4,3)).astype(np.int32)
y = gather_elements(x1, x2, axis=2)

x1 = Tensor(Dtype.U32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())

name = "gather_elements_u32_axis2"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(2))",
name= name)

def axis3():
x1 = np.arange(0,108).reshape(3,3,4,3).astype(np.int32)
x2 = np.random.randint(low = 0,high=3, size=(3,3,4,6)).astype(np.int32)
y = gather_elements(x1, x2, axis=3)

x1 = Tensor(Dtype.U32, x1.shape, x1.flatten())
x2 = Tensor(Dtype.U32, x2.shape, x2.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())

name = "gather_elements_u32_axis3"
make_test(
inputs = [x1, x2], output = y, func_sig = "input_0.gather_elements(indices:input_1, axis:Option::Some(3))",
name= name)

default()
axis1()
axis2()
axis3()
gather_elements_3D()
Loading

0 comments on commit 16234c1

Please sign in to comment.