Feat: mish #604

Open · wants to merge 1 commit into base: develop
1 change: 1 addition & 0 deletions docs/framework/operators/neural-network/README.md
@@ -39,4 +39,5 @@ Orion supports currently these `NN` types.
| [`nn.col2im`](nn.col2im.md) | Rearranges column blocks back into a multidimensional image |
| [`nn.conv_transpose`](nn.conv\_transpose.md) | Performs the convolution transpose of the input data tensor and weight tensor. |
| [`nn.conv`](nn.conv.md) | Performs the convolution of the input data tensor and weight tensor. |
| [`nn.mish`](nn.mish.md) | A Self Regularized Non-Monotonic Neural Activation Function. |

50 changes: 50 additions & 0 deletions docs/framework/operators/neural-network/nn.mish.md
@@ -0,0 +1,50 @@
# NNTrait::mish

```rust
fn mish(tensor: @Tensor<T>) -> Tensor<T>;
```

A Self Regularized Non-Monotonic Neural Activation Function.
It is applied element-wise to the input tensor X using the formula:
```rust
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
```

## Args

* `tensor`(`@Tensor<T>`) - The input tensor.

## Returns

* A `Tensor<T>` with the same shape as the input tensor.

## Examples

```rust
use orion::operators::tensor::{FP8x23Tensor, FP8x23TensorAdd};
use core::array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor};
use orion::utils::{assert_eq, assert_seq_eq};
use orion::operators::tensor::FP8x23TensorPartialEq;
use orion::numbers::{FixedTrait, FP8x23};
use orion::operators::nn::NNTrait;
use orion::operators::nn::FP8x23NN;

fn example() -> Tensor<FP8x23> {
let mut shape = ArrayTrait::<usize>::new();
shape.append(2);
shape.append(3);

let mut data = ArrayTrait::new();
data.append(FP8x23 { mag: 29330286, sign: true });
data.append(FP8x23 { mag: 29576280, sign: false });
data.append(FP8x23 { mag: 605854, sign: false });
data.append(FP8x23 { mag: 26167402, sign: false });
data.append(FP8x23 { mag: 24733382, sign: false });
data.append(FP8x23 { mag: 5248967, sign: true });
let tensor1 = TensorTrait::new(shape.span(), data.span());

return NNTrait::mish(@tensor1);
}
>>> [FP8x23 { mag: 875391, sign: true }, FP8x23 { mag: 29527976, sign: false }, FP8x23 { mag: 377454, sign: false }, FP8x23 { mag: 26073864, sign: false }, FP8x23 { mag: 24610957, sign: false }, FP8x23 { mag: 2120704, sign: true }]
```
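As a quick sanity check on the example above, the expected outputs can be reproduced in NumPy. The decoding below is an assumption: FP8x23 is treated as a fixed-point value equal to `mag / 2**23`, negated when `sign` is true, and `from_fp8x23` is a hypothetical helper, not part of Orion.

```python
import numpy as np

# Hypothetical decoder, assuming FP8x23 encodes value = mag / 2**23,
# negated when sign is true.
def from_fp8x23(mag: int, sign: bool) -> float:
    return -mag / 2**23 if sign else mag / 2**23

x = np.array([
    from_fp8x23(29330286, True),   # ≈ -3.4964
    from_fp8x23(29576280, False),  # ≈  3.5258
    from_fp8x23(605854, False),    # ≈  0.0722
    from_fp8x23(26167402, False),  # ≈  3.1194
    from_fp8x23(24733382, False),  # ≈  2.9484
    from_fp8x23(5248967, True),    # ≈ -0.6258
])

# mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
y = x * np.tanh(np.log1p(np.exp(x)))
print(y)  # first entry ≈ -0.1044, i.e. FP8x23 { mag: 875391, sign: true }
```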
35 changes: 35 additions & 0 deletions nodegen/node/mish.py
@@ -0,0 +1,35 @@
import numpy as np
from nodegen.node import RunAll
from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait


class Mish(RunAll):

@staticmethod
def fp8x23():
x = np.random.uniform(-4, 4, (2, 3)).astype(np.float64)
y = x * np.tanh(np.log1p(np.exp(x)))

x = Tensor(Dtype.FP8x23, x.shape, to_fp(
x.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))

name = "mish_fp8x23"
make_test([x], y, "NNTrait::mish(@input_0)",
name, Trait.NN)

@staticmethod
def fp16x16():
x = np.random.uniform(-3, 3, (3, 2, 2, 3)).astype(np.float16)
y = x * np.tanh(np.log1p(np.exp(x)))

x = Tensor(Dtype.FP16x16, x.shape, to_fp(
x.flatten(), FixedImpl.FP16x16))
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))

name = "mish_fp16x16"
make_test([x], y, "NNTrait::mish(@input_0)",
name, Trait.NN)

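A side note on the reference computation: `np.log1p(np.exp(x))` overflows in the `exp` for large positive inputs (around x > 709 in float64). The sampled ranges here (±4 and ±3) make that moot, but a numerically stable softplus via `np.logaddexp` would be a drop-in replacement if wider ranges are ever generated — a sketch, not part of the PR:

```python
import numpy as np

def mish_stable(x: np.ndarray) -> np.ndarray:
    # softplus(x) = ln(1 + e^x) = logaddexp(0, x); no intermediate overflow
    return x * np.tanh(np.logaddexp(0.0, x))

x = np.array([-4.0, 0.0, 4.0, 800.0])
print(mish_stable(x))  # finite throughout; matches x * tanh(log1p(exp(x)))
                       # wherever the naive form doesn't overflow
```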
52 changes: 52 additions & 0 deletions src/operators/nn/core.cairo
@@ -18,6 +18,7 @@ use orion::operators::tensor::core::Tensor;
/// col2im - Rearranges column blocks back into a multidimensional image
/// conv_transpose - Performs the convolution transpose of the input data tensor and weight tensor.
/// conv - Performs the convolution of the input data tensor and weight tensor.
/// mish - A Self Regularized Non-Monotonic Neural Activation Function.
trait NNTrait<T> {
/// # NNTrait::relu
///
@@ -1304,4 +1305,55 @@ trait NNTrait<T> {
mode: Option<orion::operators::nn::functional::grid_sample::MODE>,
padding_mode: Option<orion::operators::nn::functional::grid_sample::PADDING_MODE>,
) -> Tensor<T>;
/// # NNTrait::mish
///
/// ```rust
/// fn mish(tensor: @Tensor<T>) -> Tensor<T>;
/// ```
///
/// A Self Regularized Non-Monotonic Neural Activation Function.
/// It is applied element-wise to the input tensor X using the formula:
/// ```rust
/// mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
/// ```
///
/// ## Args
///
/// * `tensor`(`@Tensor<T>`) - The input tensor.
///
/// ## Returns
///
/// * A `Tensor<T>` with the same shape as the input tensor.
///
/// ## Examples
///
/// ```rust
/// use orion::operators::tensor::{FP8x23Tensor, FP8x23TensorAdd};
/// use core::array::{ArrayTrait, SpanTrait};
/// use orion::operators::tensor::{TensorTrait, Tensor};
/// use orion::utils::{assert_eq, assert_seq_eq};
/// use orion::operators::tensor::FP8x23TensorPartialEq;
/// use orion::numbers::{FixedTrait, FP8x23};
/// use orion::operators::nn::NNTrait;
/// use orion::operators::nn::FP8x23NN;
///
/// fn example() -> Tensor<FP8x23> {
/// let mut shape = ArrayTrait::<usize>::new();
/// shape.append(2);
/// shape.append(3);
///
/// let mut data = ArrayTrait::new();
/// data.append(FP8x23 { mag: 29330286, sign: true });
/// data.append(FP8x23 { mag: 29576280, sign: false });
/// data.append(FP8x23 { mag: 605854, sign: false });
/// data.append(FP8x23 { mag: 26167402, sign: false });
/// data.append(FP8x23 { mag: 24733382, sign: false });
/// data.append(FP8x23 { mag: 5248967, sign: true });
/// let tensor1 = TensorTrait::new(shape.span(), data.span());
///
/// return NNTrait::mish(@tensor1);
/// }
/// >>> [FP8x23 { mag: 875391, sign: true }, FP8x23 { mag: 29527976, sign: false }, FP8x23 { mag: 377454, sign: false }, FP8x23 { mag: 26073864, sign: false }, FP8x23 { mag: 24610957, sign: false }, FP8x23 { mag: 2120704, sign: true }]
/// ```
fn mish(tensor: @Tensor<T>) -> Tensor<T>;
}
1 change: 1 addition & 0 deletions src/operators/nn/functional.cairo
@@ -16,3 +16,4 @@ mod conv_transpose;
mod depth_to_space;
mod space_to_depth;
mod conv;
mod mish;
47 changes: 47 additions & 0 deletions src/operators/nn/functional/mish.cairo
@@ -0,0 +1,47 @@
use orion::numbers::fixed_point::core::FixedTrait;
use orion::numbers::NumberTrait;
use orion::operators::tensor::core::{Tensor, TensorTrait};

fn mish<
T,
MAG,
impl TTensor: TensorTrait<T>,
impl TNumber: NumberTrait<T, MAG>,
impl TAdd: Add<T>,
impl TSub: Sub<T>,
impl TMul: Mul<T>,
impl TDiv: Div<T>,
impl TTensorAdd: Add<Tensor<T>>,
impl TPartialOrd: PartialOrd<T>,
impl TAddEq: AddEq<T>,
impl TCopy: Copy<T>,
impl TDrop: Drop<T>,
>(
tensor: Tensor<T>
) -> Tensor<T> {
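// e^x elementwise; used to build softplus(x) = ln(1 + e^x)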
let exp = tensor.exp();
let len = (tensor.data).len();
let mut arr1: Array<T> = array![];
let mut i: usize = 0;
while i != len {
let v = *(exp.data).at(i);
let r = v + NumberTrait::one();
arr1.append(r);
i += 1;
};
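// ln(1 + e^x) = softplus(x), then tanh(softplus(x))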
let log1p = TensorTrait::<T>::new(tensor.shape, arr1.span()).log();
let tanh = log1p.tanh();

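// elementwise product: mish(x) = x * tanh(softplus(x))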
let mut arr2: Array<T> = array![];
i = 0;
while i != len {
let v1 = *(tensor.data).at(i);
let v2 = *(tanh.data).at(i);
let r = v1 * v2;
arr2.append(r);
i += 1;
};
TensorTrait::<T>::new(tensor.shape, arr2.span())
}
4 changes: 4 additions & 0 deletions src/operators/nn/implementations/nn_fp16x16.cairo
@@ -143,4 +143,8 @@ impl FP16x16NN of NNTrait<FP16x16> {
) -> Tensor<FP16x16> {
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides)
}

fn mish(tensor: @Tensor<FP16x16>) -> Tensor<FP16x16> {
functional::mish::mish(*tensor)
}
}
4 changes: 4 additions & 0 deletions src/operators/nn/implementations/nn_fp32x32.cairo
@@ -137,4 +137,8 @@ impl FP32x32NN of NNTrait<FP32x32> {
) -> Tensor<FP32x32> {
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides)
}

fn mish(tensor: @Tensor<FP32x32>) -> Tensor<FP32x32> {
functional::mish::mish(*tensor)
}
}
4 changes: 4 additions & 0 deletions src/operators/nn/implementations/nn_fp64x64.cairo
@@ -137,4 +137,8 @@ impl FP64x64NN of NNTrait<FP64x64> {
) -> Tensor<FP64x64> {
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides)
}

fn mish(tensor: @Tensor<FP64x64>) -> Tensor<FP64x64> {
functional::mish::mish(*tensor)
}
}
4 changes: 4 additions & 0 deletions src/operators/nn/implementations/nn_fp8x23.cairo
@@ -139,4 +139,8 @@ impl FP8x23NN of NNTrait<FP8x23> {
) -> Tensor<FP8x23> {
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides)
}

fn mish(tensor: @Tensor<FP8x23>) -> Tensor<FP8x23> {
functional::mish::mish(*tensor)
}
}
4 changes: 4 additions & 0 deletions src/operators/nn/implementations/nn_i32.cairo
@@ -130,4 +130,8 @@ impl I32NN of NNTrait<i32> {
) -> Tensor<i32> {
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides)
}

fn mish(tensor: @Tensor<i32>) -> Tensor<i32> {
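// mish needs exp, ln and tanh, which have no integer-tensor implementations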
panic(array!['not supported!'])
}
}
4 changes: 4 additions & 0 deletions src/operators/nn/implementations/nn_i8.cairo
@@ -130,4 +130,8 @@ impl I8NN of NNTrait<i8> {
) -> Tensor<i8> {
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides)
}

fn mish(tensor: @Tensor<i8>) -> Tensor<i8> {
panic(array!['not supported!'])
}
}
4 changes: 4 additions & 0 deletions src/operators/nn/implementations/nn_u32.cairo
@@ -130,4 +130,8 @@ impl U32NN of NNTrait<u32> {
) -> Tensor<u32> {
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides)
}

fn mish(tensor: @Tensor<u32>) -> Tensor<u32> {
panic(array!['not supported!'])
}
}
2 changes: 2 additions & 0 deletions tests/nodes.cairo
@@ -1047,3 +1047,5 @@ mod label_encoder_fp8x23_default;
mod label_encoder_i8_default;
mod label_encoder_i32_default;
mod label_encoder_u32_default;
mod mish_fp8x23;
mod mish_fp16x16;
20 changes: 20 additions & 0 deletions tests/nodes/mish_fp16x16.cairo
@@ -0,0 +1,20 @@
mod input_0;
mod output_0;


use orion::operators::nn::NNTrait;
use orion::utils::{assert_eq, assert_seq_eq};
use orion::numbers::FixedTrait;
use orion::operators::tensor::FP16x16TensorPartialEq;
use orion::operators::nn::FP16x16NN;

#[test]
#[available_gas(2000000000)]
fn test_mish_fp16x16() {
let input_0 = input_0::input_0();
let z_0 = output_0::output_0();

let y_0 = NNTrait::mish(@input_0);

assert_eq(y_0, z_0);
}
51 changes: 51 additions & 0 deletions tests/nodes/mish_fp16x16/input_0.cairo
@@ -0,0 +1,51 @@
use core::array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor};
use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd};
use orion::numbers::{FixedTrait, FP16x16};

fn input_0() -> Tensor<FP16x16> {
let mut shape = ArrayTrait::<usize>::new();
shape.append(3);
shape.append(2);
shape.append(2);
shape.append(3);

let mut data = ArrayTrait::new();
data.append(FP16x16 { mag: 184064, sign: true });
data.append(FP16x16 { mag: 130304, sign: false });
data.append(FP16x16 { mag: 117248, sign: true });
data.append(FP16x16 { mag: 94016, sign: false });
data.append(FP16x16 { mag: 147072, sign: false });
data.append(FP16x16 { mag: 100800, sign: true });
data.append(FP16x16 { mag: 84544, sign: false });
data.append(FP16x16 { mag: 74496, sign: false });
data.append(FP16x16 { mag: 71616, sign: false });
data.append(FP16x16 { mag: 69760, sign: false });
data.append(FP16x16 { mag: 112832, sign: false });
data.append(FP16x16 { mag: 52928, sign: false });
data.append(FP16x16 { mag: 175488, sign: true });
data.append(FP16x16 { mag: 6936, sign: true });
data.append(FP16x16 { mag: 193664, sign: true });
data.append(FP16x16 { mag: 39648, sign: false });
data.append(FP16x16 { mag: 166528, sign: true });
data.append(FP16x16 { mag: 180096, sign: false });
data.append(FP16x16 { mag: 130944, sign: true });
data.append(FP16x16 { mag: 105792, sign: false });
data.append(FP16x16 { mag: 51392, sign: true });
data.append(FP16x16 { mag: 45408, sign: true });
data.append(FP16x16 { mag: 169344, sign: true });
data.append(FP16x16 { mag: 151936, sign: true });
data.append(FP16x16 { mag: 90112, sign: true });
data.append(FP16x16 { mag: 9808, sign: true });
data.append(FP16x16 { mag: 98368, sign: true });
data.append(FP16x16 { mag: 179584, sign: true });
data.append(FP16x16 { mag: 122048, sign: false });
data.append(FP16x16 { mag: 19856, sign: false });
data.append(FP16x16 { mag: 38944, sign: true });
data.append(FP16x16 { mag: 65792, sign: true });
data.append(FP16x16 { mag: 187136, sign: false });
data.append(FP16x16 { mag: 190336, sign: false });
data.append(FP16x16 { mag: 119744, sign: true });
data.append(FP16x16 { mag: 128832, sign: true });
TensorTrait::new(shape.span(), data.span())
}