Skip to content

Commit

Permalink
Fix kernel data writing to DLA in TVM order (#73)
Browse files Browse the repository at this point in the history
* Fix kernel data writing to DLA in TVM order
  • Loading branch information
vilukissa68 authored Sep 26, 2024
1 parent 0ee2fb8 commit 856d524
Showing 1 changed file with 35 additions and 8 deletions.
43 changes: 35 additions & 8 deletions examples/hpc/dla-driver/src/tensor4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,11 @@ impl<T: Clone> Tensor4<T> {

let standard_shape = [kernels, channels, height, width];
let dim_order: [usize; 4] = order.into_position();
let kernels_ordered =
standard_shape[unsafe { dim_order.iter().position(|&r| r == 0).unwrap_unchecked() }];
let channels_ordered =
standard_shape[unsafe { dim_order.iter().position(|&r| r == 1).unwrap_unchecked() }];
let height_ordered =
standard_shape[unsafe { dim_order.iter().position(|&r| r == 2).unwrap_unchecked() }];
let width_ordered =
standard_shape[unsafe { dim_order.iter().position(|&r| r == 3).unwrap_unchecked() }];

let kernels_ordered = standard_shape[dim_order[0]];
let channels_ordered = standard_shape[dim_order[1]];
let height_ordered = standard_shape[dim_order[2]];
let width_ordered = standard_shape[dim_order[3]];

let data = Array::from_shape_vec(
(
Expand Down Expand Up @@ -324,8 +321,38 @@ impl<T: Clone> Tensor4<T> {
if order == self.order {
return self.to_buffer();
}

// NOTE:(20240925 [email protected]) TVM order fix
if self.order == Order4::HWCK && order == Order4::HWKC {
return self.tvm_layout_to_headsail();
}

let mut data = self.clone();
data.permute(order);
data.to_buffer()
}

/// Convert HWIO (HWCK) order to HWOI (HWKC) for headsail
pub fn tvm_layout_to_headsail(&self) -> Vec<T> {
let data = self.to_buffer();

let mut hwoi_flat =
Vec::with_capacity(self.height() * self.width() * self.channels() * self.kernels());
// Iterate and assign values from HWIO to HWOI format
for h_idx in 0..self.height() {
for w_idx in 0..self.width() {
for o_idx in 0..self.kernels() {
for i_idx in 0..self.channels() {
let hwio_index = h_idx * self.width() * self.channels() * self.kernels()
+ w_idx * self.channels() * self.kernels()
+ i_idx * self.kernels()
+ o_idx;
hwoi_flat.push(data[hwio_index].clone());
}
}
}
}

hwoi_flat
}
}

0 comments on commit 856d524

Please sign in to comment.