Fix broadcasting when shapes have different sizes #579

Merged 3 commits on Feb 21, 2024

Changes from all commits
13 changes: 2 additions & 11 deletions src/operators/nn/functional/gemm.cairo
@@ -1,4 +1,3 @@
-use alexandria_data_structures::array_ext::SpanTraitExt;
use core::array::SpanTrait;

use orion::numbers::NumberTrait;
@@ -49,16 +48,8 @@ fn gemm<

    match C {
        Option::Some(c) => {
-            let broadcast_c_shape = if c.shape.len() == 1 {
-                array![1].span().concat(c.shape)
-            } else {
-                c.shape
-            };
-
-            let c = Tensor { shape: broadcast_c_shape, data: c.data };
-
            return mul_by_scalar(@A.matmul(@B), alpha) + mul_by_scalar(@c, beta);
        },
-        Option::None => { return mul_by_scalar(@A.matmul(@B), alpha); }
+        Option::None(_) => { return mul_by_scalar(@A.matmul(@B), alpha); }
    }
}
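
Before this PR, broadcasting required both operands to have the same rank, so `gemm` promoted a rank-1 bias `C` of shape `[n]` to `[1, n]` before the addition; the deleted block above is exactly that workaround (and the removed `SpanTraitExt` import existed only for its `concat`). A minimal standalone sketch of the promotion, with illustrative names, not Orion code:

```cairo
use core::array::{ArrayTrait, SpanTrait};

// Sketch of the removed workaround: pad a rank-1 shape [n] to [1, n] so the
// old equal-rank broadcasting could accept the bias.
fn promote_rank_1(shape: Span<usize>) -> Span<usize> {
    if shape.len() == 1 {
        let mut padded: Array<usize> = array![1];
        padded.append(*shape[0]);
        padded.span()
    } else {
        shape
    }
}
```

With the generalized mapping added in `helpers.cairo` below, `[m, n] + [n]` broadcasts directly, so the promotion becomes dead code.
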
26 changes: 8 additions & 18 deletions src/operators/tensor/core.cairo
@@ -1,3 +1,4 @@
+use alexandria_data_structures::array_ext::ArrayTraitExt;
use core::array::{ArrayTrait, SpanTrait};
use core::serde::Serde;
use core::option::OptionTrait;
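
The added `ArrayTraitExt` import brings in Alexandria's array extensions; the rewritten `stride` below (and `broadcast_shape` in `helpers.cairo`) builds its result back-to-front and relies on the extension's `reverse()` to restore order.
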
@@ -5743,33 +5744,22 @@ fn unravel_index(index: usize, mut shape: Span<usize>) -> Span<usize> {

/// Cf: TensorTrait::stride docstring
fn stride(mut shape: Span<usize>) -> Span<usize> {
-    let shape_len = shape.len();
-    assert(shape_len > 0, 'shape cannot be empty');
-
-    let mut result: Array<usize> = ArrayTrait::new();
-    let mut accumulated: usize = 1;
-    let mut temp_result = ArrayTrait::new();
+    let mut strides = ArrayTrait::new();
+    let mut stride = 1;
    loop {
        match shape.pop_back() {
-            Option::Some(i) => {
-                temp_result.append(accumulated);
-                accumulated *= *i;
+            Option::Some(size) => {
+                strides.append(stride);
+                stride *= *size;
            },
            Option::None => { break; }
        };
    };

-    let mut temp_result = temp_result.span();
-    loop {
-        match temp_result.pop_back() {
-            Option::Some(val) => { result.append(*val); },
-            Option::None => { break; }
-        };
-    };
-
-    return result.span();
+    strides.reverse().span()
}
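
The rewritten `stride` computes row-major strides in a single backward pass: walking the shape right-to-left it appends the running product of dimension sizes, then reverses the result with Alexandria's `reverse`. Note the old empty-shape assert is gone, so an empty shape now yields an empty stride span instead of panicking. A minimal check of the logic, assuming the `stride` function above is in scope:

```cairo
// For shape [2, 3, 4] the loop appends 1, then 4 (= 1 * 4), then 12 (= 4 * 3);
// reversing gives the row-major strides [12, 4, 1].
#[test]
fn stride_example() {
    let strides = stride(array![2, 3, 4].span());
    assert(*strides[0] == 12, 'stride 0 should be 12');
    assert(*strides[1] == 4, 'stride 1 should be 4');
    assert(*strides[2] == 1, 'stride 2 should be 1');
}
```
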


/// Cf: TensorTrait::reshape docstring
fn reshape<T>(self: @Tensor<T>, target_shape: Span<usize>) -> Tensor<T> {
new_tensor(target_shape, *self.data)
118 changes: 84 additions & 34 deletions src/operators/tensor/helpers.cairo
@@ -51,21 +51,37 @@ fn check_shape<T>(shape: Span<usize>, data: Span<T>) {
/// # Panics
/// * Panics if the shapes are not compatible for broadcasting.
fn check_compatibility(mut shape_1: Span<usize>, mut shape_2: Span<usize>) {
-    assert(shape_1.len() == shape_2.len(), 'tensors shape must match');
-
-    loop {
-        match shape_1.pop_front() {
-            Option::Some(shape_1_val) => {
-                let shape_2_val = *shape_2.pop_front().unwrap();
-
-                assert(
-                    *shape_1_val == shape_2_val || *shape_1_val == 1 || shape_2_val == 1,
-                    'tensors shape must match'
-                );
-            },
-            Option::None => { break; }
-        };
-    };
+    // Start from the last dimension by getting the length of each shape
+    let mut iter_1 = shape_1.len();
+    let mut iter_2 = shape_2.len();
+
+    // Iterate while there are dimensions left in either shape
+    while iter_1 > 0 || iter_2 > 0 {
+        // Get the current dimension for each shape, defaulting to 1 if we've run out of dimensions
+        let dim_1 = if iter_1 > 0 {
+            *shape_1[iter_1 - 1]
+        } else {
+            1
+        };
+        let dim_2 = if iter_2 > 0 {
+            *shape_2[iter_2 - 1]
+        } else {
+            1
+        };
+
+        // Check the broadcasting rule for the current dimension
+        if dim_1 != dim_2 && dim_1 != 1 && dim_2 != 1 {
+            panic(array!['tensors shape must match']);
+        }
+
+        // Move to the next dimension
+        if iter_1 > 0 {
+            iter_1 -= 1;
+        }
+        if iter_2 > 0 {
+            iter_2 -= 1;
+        }
+    }
}
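
The rewrite drops the equal-rank assert: dimensions are now compared from the right, and a shape that runs out of dimensions contributes an implicit 1, which is compatible with anything. For example `[2, 3, 4]` vs `[1, 4]` pairs up as (4, 4), (3, 1), (2, implicit 1) and passes, while `[2, 3]` vs `[4]` fails immediately on (3, 4). A sketch of both cases, assuming the function above is in scope (test names are illustrative):

```cairo
#[test]
fn compatibility_ok() {
    // Pairs from the right: (4, 4), (3, 1), (2, implicit 1) -- all compatible.
    check_compatibility(array![2, 3, 4].span(), array![1, 4].span());
}

#[test]
#[should_panic(expected: ('tensors shape must match',))]
fn compatibility_mismatch() {
    // Rightmost pair is (3, 4): neither equal nor 1, so this panics.
    check_compatibility(array![2, 3].span(), array![4].span());
}
```
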

/// Computes the index in the broadcasted tensor corresponding to the given indices and shape.
@@ -81,7 +97,15 @@ fn check_compatibility(mut shape_1: Span<usize>, mut shape_2: Span<usize>) {
/// # Returns
/// * A usize representing the index in the broadcasted tensor.
fn broadcast_index_mapping(mut shape: Span<usize>, mut indices: Span<usize>) -> usize {
-    assert(shape.len() == indices.len(), 'shape/indices len must be equal');
+    if shape.len() == indices.len() {
+        broadcast_index_mapping_equal_shape(shape, indices)
+    } else {
+        broadcast_index_mapping_non_equal_shape(shape, indices)
+    }
+}
+
+
+fn broadcast_index_mapping_equal_shape(mut shape: Span<usize>, mut indices: Span<usize>) -> usize {
let mut result = 0_usize;
let mut stride = stride(shape);

@@ -101,6 +125,47 @@ fn broadcast_index_mapping(mut shape: Span<usize>, mut indices: Span<usize>) ->
return result;
}
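
`broadcast_index_mapping_equal_shape` is the pre-existing equal-rank mapping, renamed; most of its body is elided in this diff. Conceptually, the flat offset is the stride-weighted sum of the indices, with size-1 (broadcast) axes absorbing any index. A standalone illustration of that idea; the modulo reduction here is an assumption about the elided body, shown only to make the size-1 behaviour explicit:

```cairo
// Illustrative only, not the library's code: reducing each index modulo its
// dimension sends every index on a size-1 axis to 0.
fn flat_offset(mut shape: Span<usize>, mut strides: Span<usize>, mut indices: Span<usize>) -> usize {
    let mut result = 0_usize;
    loop {
        match shape.pop_front() {
            Option::Some(dim) => {
                let idx = *indices.pop_front().unwrap();
                let s = *strides.pop_front().unwrap();
                result += (idx % *dim) * s;
            },
            Option::None => { break; }
        };
    };
    result
}
```

For shape `[3, 1, 4]` with strides `[4, 4, 1]` and indices `[2, 5, 3]`: `(2 % 3) * 4 + (5 % 1) * 4 + (3 % 4) * 1 = 11`.
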

+fn broadcast_index_mapping_non_equal_shape(
+    mut shape: Span<usize>, mut indices: Span<usize>
+) -> usize {
+    let mut result = 0_usize;
+    let mut stride = stride(shape.clone());
+
+    // Calculate the offset to align indices with the rightmost dimensions of the shape
+    let mut offset = if shape.len() > indices.len() {
+        shape.len() - indices.len()
+    } else {
+        0
+    };
+
+    loop {
+        match shape.pop_back() {
+            Option::Some(_) => {
+                let stride_val = stride
+                    .pop_back()
+                    .unwrap_or(@1); // Default stride for non-existent dimensions is 1
+
+                // Calculate the index, using 0 for dimensions beyond the length of indices
+                let index_val = if offset > 0 {
+                    offset -= 1; // Decrement offset until we align indices with the shape
+                    0 // Use 0 for indices beyond the length of the indices span
+                } else {
+                    *indices
+                        .pop_back()
+                        .unwrap_or(@0) // Use actual index value or 0 if indices are exhausted
+                };
+
+                let index = index_val * *stride_val;
+                result += index;
+            },
+            Option::None => { break; }
+        };
+    };
+
+    result
+}
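
The new helper lets `indices` carry fewer dimensions than `shape`: per its comments, the indices align with the rightmost dimensions and the missing leading dimensions are read as index 0. Under that alignment, shape `[2, 3, 4]` (strides `[12, 4, 1]`) with indices `[1, 2]` is treated as `[0, 1, 2]`, giving flat offset `0 * 12 + 1 * 4 + 2 * 1 = 6`. A standalone reference sketch of the alignment (illustrative; assumes `shape` has at least as many dimensions as `indices` and that the `stride` helper is in scope):

```cairo
fn right_aligned_offset(shape: Span<usize>, indices: Span<usize>) -> usize {
    let strides = stride(shape); // row-major strides, e.g. [12, 4, 1] for [2, 3, 4]
    let pad = shape.len() - indices.len(); // leading dims missing from `indices`
    let mut flat = 0_usize;
    let mut d = pad; // dimensions [0, pad) contribute index 0
    while d < shape.len() {
        flat += *indices[d - pad] * *strides[d];
        d += 1;
    }
    flat
}
```
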


/// Generates the output shape after reducing a tensor along a specified axis.
///
/// # Arguments
@@ -256,32 +321,17 @@ fn broadcast_shape(mut shape1: Span<usize>, mut shape2: Span<usize>) -> Span<usize> {
check_compatibility(shape1, shape2);
let mut result: Array<usize> = ArrayTrait::new();

-    loop {
-        let mut dim1 = 1;
-        let mut dim2 = 1;
-
-        match shape1.pop_front() {
-            Option::Some(item) => { dim1 = *item; },
-            Option::None => { if shape1.len() == 0 && shape2.len() == 0 {
-                break ();
-            }; }
-        };
-
-        match shape2.pop_front() {
-            Option::Some(item) => { dim2 = *item; },
-            Option::None => { if shape1.len() == 0 && shape2.len() == 0 {
-                break ();
-            }; }
-        };
+    while !shape1.is_empty() || !shape2.is_empty() {
+        let dim1 = *shape1.pop_back().unwrap_or(@1);
+        let dim2 = *shape2.pop_back().unwrap_or(@1);

        let broadcasted_dim = u32_max(dim1, dim2);
        result.append(broadcasted_dim);
    };

-    return result.span();
+    return result.reverse().span();
}
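
`broadcast_shape` now mirrors the same right-to-left walk: pop both shapes from the back, default a missing dimension to 1 via `unwrap_or(@1)`, take the max, and reverse once at the end. This replaces the old pop-front loop with its duplicated break conditions. A quick sketch, assuming the function is in scope (test name illustrative):

```cairo
#[test]
fn broadcast_shape_example() {
    // Equal rank: [3, 1] x [1, 4] -> [3, 4].
    let out = broadcast_shape(array![3, 1].span(), array![1, 4].span());
    assert(*out[0] == 3 && *out[1] == 4, 'expected [3, 4]');

    // Mixed rank: [2, 3, 4] x [4] -> [2, 3, 4].
    let out = broadcast_shape(array![2, 3, 4].span(), array![4].span());
    assert(out.len() == 3 && *out[2] == 4, 'expected [2, 3, 4]');
}
```
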


/// Substitute a value in a shape at a given index
///
/// # Arguments
1 change: 0 additions & 1 deletion tests/nodes/gemm_default_vector_bias.cairo
@@ -3,7 +3,6 @@ mod input_1;
mod input_2;
mod output_0;

-
use orion::operators::nn::NNTrait;
use orion::numbers::FixedTrait;
use orion::utils::{assert_eq, assert_seq_eq};
1 change: 0 additions & 1 deletion tests/nodes/gemm_default_vector_bias/input_2.cairo
@@ -5,7 +5,6 @@ use orion::numbers::{FixedTrait, FP16x16};

fn input_2() -> Tensor<FP16x16> {
let mut shape = ArrayTrait::<usize>::new();
-    shape.append(1);
shape.append(4);

let mut data = ArrayTrait::new();