From 677168cffd5d17bd4d3fad9146e5d9abd95611d8 Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Mon, 19 Feb 2024 19:24:53 +0100
Subject: [PATCH 1/3] remove shape length check

---
 src/operators/tensor/helpers.cairo | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/operators/tensor/helpers.cairo b/src/operators/tensor/helpers.cairo
index 931eeb0af..e60b38b72 100644
--- a/src/operators/tensor/helpers.cairo
+++ b/src/operators/tensor/helpers.cairo
@@ -51,8 +51,6 @@ fn check_shape<T>(shape: Span<usize>, data: Span<T>) {
 /// # Panics
 /// * Panics if the shapes are not compatible for broadcasting.
 fn check_compatibility(mut shape_1: Span<usize>, mut shape_2: Span<usize>) {
-    assert(shape_1.len() == shape_2.len(), 'tensors shape must match');
-
     loop {
         match shape_1.pop_front() {
             Option::Some(shape_1_val) => {

From 39c66d66473f1f0f70e04aa12a194a2baf4ea0d9 Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Tue, 20 Feb 2024 12:17:16 +0100
Subject: [PATCH 2/3] fix check_compatibility

---
 src/operators/tensor/helpers.cairo | 42 +++++++++++++++++++++---------
 1 file changed, 30 insertions(+), 12 deletions(-)

diff --git a/src/operators/tensor/helpers.cairo b/src/operators/tensor/helpers.cairo
index e60b38b72..13eb7f43a 100644
--- a/src/operators/tensor/helpers.cairo
+++ b/src/operators/tensor/helpers.cairo
@@ -51,19 +51,37 @@ fn check_shape<T>(shape: Span<usize>, data: Span<T>) {
 /// # Panics
 /// * Panics if the shapes are not compatible for broadcasting.
 fn check_compatibility(mut shape_1: Span<usize>, mut shape_2: Span<usize>) {
-    loop {
-        match shape_1.pop_front() {
-            Option::Some(shape_1_val) => {
-                let shape_2_val = *shape_2.pop_front().unwrap();
-
-                assert(
-                    *shape_1_val == shape_2_val || *shape_1_val == 1 || shape_2_val == 1,
-                    'tensors shape must match'
-                );
-            },
-            Option::None => { break; }
+    // Start from the last dimension by getting the length of each shape
+    let mut iter_1 = shape_1.len();
+    let mut iter_2 = shape_2.len();
+
+    // Iterate while there are dimensions left in either shape
+    while iter_1 > 0 || iter_2 > 0 {
+        // Get the current dimension for each shape, defaulting to 1 if we've run out of dimensions
+        let dim_1 = if iter_1 > 0 {
+            *shape_1[iter_1 - 1]
+        } else {
+            1
         };
-    };
+        let dim_2 = if iter_2 > 0 {
+            *shape_2[iter_2 - 1]
+        } else {
+            1
+        };
+
+        // Check the broadcasting rule for the current dimension
+        if dim_1 != dim_2 && dim_1 != 1 && dim_2 != 1 {
+            panic(array!['tensors shape must match']);
+        }
+
+        // Move to the next dimension
+        if iter_1 > 0 {
+            iter_1 -= 1;
+        }
+        if iter_2 > 0 {
+            iter_2 -= 1;
+        }
+    }
 }
 
 /// Computes the index in the broadcasted tensor corresponding to the given indices and shape.
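Patch 2/3 switches `check_compatibility` to the numpy/ONNX-style broadcasting rule: shapes are compared right to left, and each dimension pair must be equal or contain a 1, with a missing leading dimension treated as 1. The standalone Cairo sketch below mirrors that rule outside Orion's types; `is_broadcast_compatible`, its test, and the boolean return are illustrative assumptions and not part of this PR (the patched helper panics with 'tensors shape must match' instead of returning a flag).

```cairo
use core::array::{ArrayTrait, SpanTrait};

/// Right-to-left broadcasting check: each dimension pair must be equal or
/// contain a 1; a dimension missing from the shorter shape behaves like 1.
fn is_broadcast_compatible(shape_1: Span<usize>, shape_2: Span<usize>) -> bool {
    let mut iter_1 = shape_1.len();
    let mut iter_2 = shape_2.len();
    let mut compatible = true;

    while iter_1 > 0 || iter_2 > 0 {
        // Read the current trailing dimension, defaulting to 1 once a shape is exhausted.
        let dim_1 = if iter_1 > 0 { *shape_1[iter_1 - 1] } else { 1 };
        let dim_2 = if iter_2 > 0 { *shape_2[iter_2 - 1] } else { 1 };

        // Same rule the patched check_compatibility enforces, minus the panic.
        if dim_1 != dim_2 && dim_1 != 1 && dim_2 != 1 {
            compatible = false;
        }

        if iter_1 > 0 { iter_1 -= 1; }
        if iter_2 > 0 { iter_2 -= 1; }
    };

    compatible
}

#[test]
fn broadcast_rule_examples() {
    // [2, 3, 4] vs [3, 1]: pairs (4,1), (3,3), (2, missing -> 1) all broadcast.
    assert(is_broadcast_compatible(array![2, 3, 4].span(), array![3, 1].span()), 'should broadcast');
    // [2, 3] vs [4]: trailing pair (3,4) differs and neither side is 1.
    assert(!is_broadcast_compatible(array![2, 3].span(), array![4].span()), 'should not broadcast');
}
```

This is exactly the case the removed `assert(shape_1.len() == shape_2.len(), ...)` in patch 1/3 used to reject: shapes of different rank can still be broadcast together once the comparison is anchored at the trailing dimensions.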
From e9cf5a7c3bb209208e915d0ec0ca6c3c626e7d48 Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Tue, 20 Feb 2024 20:07:56 +0100
Subject: [PATCH 3/3] fix broadcast_shape and broadcast_index_mapping

---
 src/operators/nn/functional/gemm.cairo     | 13 +---
 src/operators/tensor/core.cairo            | 26 ++-----
 src/operators/tensor/helpers.cairo         | 74 ++++++++++++++-----
 tests/nodes/gemm_default_vector_bias.cairo |  1 -
 .../gemm_default_vector_bias/input_2.cairo |  1 -
 5 files changed, 64 insertions(+), 51 deletions(-)

diff --git a/src/operators/nn/functional/gemm.cairo b/src/operators/nn/functional/gemm.cairo
index c37bda880..e5b997731 100644
--- a/src/operators/nn/functional/gemm.cairo
+++ b/src/operators/nn/functional/gemm.cairo
@@ -1,4 +1,3 @@
-use alexandria_data_structures::array_ext::SpanTraitExt;
 use core::array::SpanTrait;
 
 use orion::numbers::NumberTrait;
@@ -49,16 +48,8 @@ fn gemm<
 
     match C {
         Option::Some(c) => {
-            let broadcast_c_shape = if c.shape.len() == 1 {
-                array![1].span().concat(c.shape)
-            } else {
-                c.shape
-            };
-
-            let c = Tensor { shape: broadcast_c_shape, data: c.data };
-
             return mul_by_scalar(@A.matmul(@B), alpha) + mul_by_scalar(@c, beta);
         },
-        Option::None => { return mul_by_scalar(@A.matmul(@B), alpha); }
+        Option::None(_) => { return mul_by_scalar(@A.matmul(@B), alpha); }
     }
-}
+}
\ No newline at end of file
diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo
index 222b0f423..9342bf328 100644
--- a/src/operators/tensor/core.cairo
+++ b/src/operators/tensor/core.cairo
@@ -1,3 +1,4 @@
+use alexandria_data_structures::array_ext::ArrayTraitExt;
 use core::array::{ArrayTrait, SpanTrait};
 use core::serde::Serde;
 use core::option::OptionTrait;
@@ -5743,33 +5744,22 @@ fn unravel_index(index: usize, mut shape: Span<usize>) -> Span<usize> {
 
 /// Cf: TensorTrait::stride docstring
 fn stride(mut shape: Span<usize>) -> Span<usize> {
-    let shape_len = shape.len();
-    assert(shape_len > 0, 'shape cannot be empty');
-
-    let mut result: Array<usize> = ArrayTrait::new();
-    let mut accumulated: usize = 1;
-    let mut temp_result = ArrayTrait::new();
+    let mut strides = ArrayTrait::new();
+    let mut stride = 1;
     loop {
         match shape.pop_back() {
-            Option::Some(i) => {
-                temp_result.append(accumulated);
-                accumulated *= *i;
+            Option::Some(size) => {
+                strides.append(stride);
+                stride *= *size;
             },
             Option::None => { break; }
         };
     };
 
-    let mut temp_result = temp_result.span();
-    loop {
-        match temp_result.pop_back() {
-            Option::Some(val) => { result.append(*val); },
-            Option::None => { break; }
-        };
-    };
-
-    return result.span();
+    strides.reverse().span()
 }
 
+
 /// Cf: TensorTrait::reshape docstring
 fn reshape<T>(self: @Tensor<T>, target_shape: Span<usize>) -> Tensor<T> {
     new_tensor(target_shape, *self.data)
diff --git a/src/operators/tensor/helpers.cairo b/src/operators/tensor/helpers.cairo
index 13eb7f43a..a781be259 100644
--- a/src/operators/tensor/helpers.cairo
+++ b/src/operators/tensor/helpers.cairo
@@ -97,7 +97,15 @@ fn check_compatibility(mut shape_1: Span<usize>, mut shape_2: Span<usize>) {
 /// # Returns
 /// * A usize representing the index in the broadcasted tensor.
 fn broadcast_index_mapping(mut shape: Span<usize>, mut indices: Span<usize>) -> usize {
-    assert(shape.len() == indices.len(), 'shape/indices len must be equal');
+    if shape.len() == indices.len() {
+        broadcast_index_mapping_equal_shape(shape, indices)
+    } else {
+        broadcast_index_mapping_non_equal_shape(shape, indices)
+    }
+}
+
+
+fn broadcast_index_mapping_equal_shape(mut shape: Span<usize>, mut indices: Span<usize>) -> usize {
     let mut result = 0_usize;
     let mut stride = stride(shape);
 
@@ -117,6 +125,47 @@ fn broadcast_index_mapping(mut shape: Span<usize>, mut indices: Span<usize>) -> usize {
     return result;
 }
 
+fn broadcast_index_mapping_non_equal_shape(
+    mut shape: Span<usize>, mut indices: Span<usize>
+) -> usize {
+    let mut result = 0_usize;
+    let mut stride = stride(shape.clone());
+
+    // Calculate the offset to align indices with the rightmost dimensions of the shape
+    let mut offset = if shape.len() > indices.len() {
+        shape.len() - indices.len()
+    } else {
+        0
+    };
+
+    loop {
+        match shape.pop_back() {
+            Option::Some(_) => {
+                let stride_val = stride
+                    .pop_back()
+                    .unwrap_or(@1); // Default stride for non-existent dimensions is 1
+
+                // Calculate the index, using 0 for dimensions beyond the length of indices
+                let index_val = if offset > 0 {
+                    offset -= 1; // Decrement offset until we align indices with the shape
+                    0 // Use 0 for indices beyond the length of the indices span
+                } else {
+                    *indices
+                        .pop_back()
+                        .unwrap_or(@0) // Use actual index value or 0 if indices are exhausted
+                };
+
+                let index = index_val * *stride_val;
+                result += index;
+            },
+            Option::None => { break; }
+        };
+    };
+
+    result
+}
+
+
 /// Generates the output shape after reducing a tensor along a specified axis.
 ///
 /// # Arguments
@@ -272,32 +321,17 @@ fn broadcast_shape(mut shape1: Span<usize>, mut shape2: Span<usize>) -> Span<usize> {
     check_compatibility(shape1, shape2);
     let mut result: Array<usize> = ArrayTrait::new();
 
-    loop {
-        let mut dim1 = 1;
-        let mut dim2 = 1;
-
-        match shape1.pop_front() {
-            Option::Some(item) => { dim1 = *item; },
-            Option::None => { if shape1.len() == 0 && shape2.len() == 0 {
-                break ();
-            }; }
-        };
-
-        match shape2.pop_front() {
-            Option::Some(item) => { dim2 = *item; },
-            Option::None => { if shape1.len() == 0 && shape2.len() == 0 {
-                break ();
-            }; }
-        };
+    while !shape1.is_empty() || !shape2.is_empty() {
+        let dim1 = *shape1.pop_back().unwrap_or(@1);
+        let dim2 = *shape2.pop_back().unwrap_or(@1);
 
         let broadcasted_dim = u32_max(dim1, dim2);
         result.append(broadcasted_dim);
     };
 
-    return result.span();
+    return result.reverse().span();
 }
 
-
 /// Substitute a value in a shape at a given index
 ///
 /// # Arguments
diff --git a/tests/nodes/gemm_default_vector_bias.cairo b/tests/nodes/gemm_default_vector_bias.cairo
index 24826f739..fbed99929 100644
--- a/tests/nodes/gemm_default_vector_bias.cairo
+++ b/tests/nodes/gemm_default_vector_bias.cairo
@@ -3,7 +3,6 @@ mod input_1;
 mod input_2;
 mod output_0;
 
-
 use orion::operators::nn::NNTrait;
 use orion::numbers::FixedTrait;
 use orion::utils::{assert_eq, assert_seq_eq};
diff --git a/tests/nodes/gemm_default_vector_bias/input_2.cairo b/tests/nodes/gemm_default_vector_bias/input_2.cairo
index f340a6ea2..e3d351aff 100644
--- a/tests/nodes/gemm_default_vector_bias/input_2.cairo
+++ b/tests/nodes/gemm_default_vector_bias/input_2.cairo
@@ -5,7 +5,6 @@ use orion::numbers::{FixedTrait, FP16x16};
 
 fn input_2() -> Tensor<FP16x16> {
     let mut shape = ArrayTrait::<usize>::new();
-    shape.append(1);
     shape.append(4);
 
     let mut data = ArrayTrait::new();
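Patch 3/3 applies the same right-aligned rule to `broadcast_shape` (and, via `broadcast_index_mapping_non_equal_shape`, to index mapping): both shapes are consumed from the back, a missing dimension defaults to 1, the larger of each pair is kept, and the accumulated dimensions are reversed back into row-major order. A minimal standalone sketch of that shape computation follows; `broadcast_shape_sketch` and its test are illustrative names, the inline max stands in for Orion's `u32_max` helper, and the `reverse()` call assumes the `alexandria_data_structures` dependency that the patch itself imports in `core.cairo`.

```cairo
use core::array::{ArrayTrait, SpanTrait};
use alexandria_data_structures::array_ext::ArrayTraitExt; // same `reverse` helper the patch uses

/// Right-aligned broadcast of two shapes, mirroring the reworked broadcast_shape.
fn broadcast_shape_sketch(mut shape1: Span<usize>, mut shape2: Span<usize>) -> Span<usize> {
    let mut result: Array<usize> = ArrayTrait::new();

    // Consume both shapes from the back; a missing dimension defaults to 1.
    while !shape1.is_empty() || !shape2.is_empty() {
        let dim1 = *shape1.pop_back().unwrap_or(@1);
        let dim2 = *shape2.pop_back().unwrap_or(@1);
        let broadcasted_dim = if dim1 > dim2 { dim1 } else { dim2 };
        result.append(broadcasted_dim);
    };

    // Dimensions were collected last-to-first, so flip them back into row-major order.
    result.reverse().span()
}

#[test]
fn broadcast_shape_examples() {
    // [3, 1] broadcast with [2, 1, 4] -> [2, 3, 4]
    let out = broadcast_shape_sketch(array![3, 1].span(), array![2, 1, 4].span());
    assert(out.len() == 3, 'wrong rank');
    assert(*out[0] == 2 && *out[1] == 3 && *out[2] == 4, 'wrong dims');
}
```

With rank mismatch handled here, a vector bias of shape [4] broadcasts directly against a 2-D matmul result whose trailing dimension is 4, which is why patch 3/3 can drop both the shape-prepending workaround in `gemm.cairo` and the leading 1 in the `gemm_default_vector_bias` test input.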