-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix kernel data writing to DLA in TVM order (#73)
* Fix kernel data writing to DLA in TVM order
- Loading branch information
1 parent
0ee2fb8
commit 856d524
Showing 1 changed file with 35 additions and 8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -195,14 +195,11 @@ impl<T: Clone> Tensor4<T> { | |
|
||
let standard_shape = [kernels, channels, height, width]; | ||
let dim_order: [usize; 4] = order.into_position(); | ||
let kernels_ordered = | ||
standard_shape[unsafe { dim_order.iter().position(|&r| r == 0).unwrap_unchecked() }]; | ||
let channels_ordered = | ||
standard_shape[unsafe { dim_order.iter().position(|&r| r == 1).unwrap_unchecked() }]; | ||
let height_ordered = | ||
standard_shape[unsafe { dim_order.iter().position(|&r| r == 2).unwrap_unchecked() }]; | ||
let width_ordered = | ||
standard_shape[unsafe { dim_order.iter().position(|&r| r == 3).unwrap_unchecked() }]; | ||
|
||
let kernels_ordered = standard_shape[dim_order[0]]; | ||
let channels_ordered = standard_shape[dim_order[1]]; | ||
let height_ordered = standard_shape[dim_order[2]]; | ||
let width_ordered = standard_shape[dim_order[3]]; | ||
|
||
let data = Array::from_shape_vec( | ||
( | ||
|
@@ -324,8 +321,38 @@ impl<T: Clone> Tensor4<T> { | |
if order == self.order { | ||
return self.to_buffer(); | ||
} | ||
|
||
// NOTE:(20240925 [email protected]) TVM order fix | ||
if self.order == Order4::HWCK && order == Order4::HWKC { | ||
return self.tvm_layout_to_headsail(); | ||
} | ||
|
||
let mut data = self.clone(); | ||
data.permute(order); | ||
data.to_buffer() | ||
} | ||
|
||
/// Convert HWIO (HWCK) order to HWOI (HWKC) for headsail | ||
pub fn tvm_layout_to_headsail(&self) -> Vec<T> { | ||
let data = self.to_buffer(); | ||
|
||
let mut hwoi_flat = | ||
Vec::with_capacity(self.height() * self.width() * self.channels() * self.kernels()); | ||
// Iterate and assign values from HWIO to HWOI format | ||
for h_idx in 0..self.height() { | ||
for w_idx in 0..self.width() { | ||
for o_idx in 0..self.kernels() { | ||
for i_idx in 0..self.channels() { | ||
let hwio_index = h_idx * self.width() * self.channels() * self.kernels() | ||
+ w_idx * self.channels() * self.kernels() | ||
+ i_idx * self.kernels() | ||
+ o_idx; | ||
hwoi_flat.push(data[hwio_index].clone()); | ||
} | ||
} | ||
} | ||
} | ||
|
||
hwoi_flat | ||
} | ||
} |