Skip to content

Commit

Permalink
Replace subtile_input_channel and subtile_output_channel with single subtile_input_output_channel
Browse files Browse the repository at this point in the history
  • Loading branch information
lukamac committed Dec 30, 2024
1 parent bffefb6 commit 9e12f7b
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 22 deletions.
28 changes: 11 additions & 17 deletions neureka_v2/hal/neureka_v2_task.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@ void neureka_v2_task_set_op_to_conv(neureka_v2_task_t *task,
const uint8_t depthwise) {
task->depthwise = depthwise;
task->kernel_shape = kernel_shape;
task->subtile_output_channel = depthwise
? NEUREKA_V2_SUBTILE_INPUT_CHANNEL_3x3
: NEUREKA_V2_SUBTILE_OUTPUT_CHANNEL;
task->subtile_input_channel = kernel_shape == 3
? NEUREKA_V2_SUBTILE_INPUT_CHANNEL_3x3
: NEUREKA_V2_SUBTILE_INPUT_CHANNEL_1x1;

const int flag_mode = kernel_shape == 1 ? NEUREKA_V2_FLAG_MODE_1x1
: depthwise == 1 ? NEUREKA_V2_FLAG_MODE_3x3_DW
Expand Down Expand Up @@ -181,8 +175,8 @@ void neureka_v2_task_set_strides(neureka_v2_task_t *task, const uint32_t k_in,
const uint32_t w_in_stride,
const uint32_t h_out_stride,
const uint32_t w_out_stride) {
const uint32_t num_k_in =
nnx_calculate_number_of_tiles(k_in, task->subtile_input_channel);
const uint32_t num_k_in = nnx_calculate_number_of_tiles(
k_in, NEUREKA_V2_SUBTILE_INPUT_OUTPUT_CHANNEL);

const neureka_v2_stride_t input_stride = {
.d0 = w_in_stride, .d1 = h_in_stride, .d2 = 0};
Expand All @@ -197,7 +191,7 @@ void neureka_v2_task_set_strides(neureka_v2_task_t *task, const uint32_t k_in,
task->data.cfg.weights_stride.d0 = NEUREKA_V2_WEIGHT_BANDWIDTH_BYTES;
if (task->kernel_shape == 1) { // 1x1
task->data.cfg.weights_stride.d1 =
num_k_in * task->qw * NEUREKA_V2_SUBTILE_INPUT_CHANNEL_1x1 / 8;
num_k_in * task->qw * NEUREKA_V2_SUBTILE_INPUT_OUTPUT_CHANNEL / 8;
} else if (!task->depthwise) { // 3x3
task->data.cfg.weights_stride.d1 =
NEUREKA_V2_WEIGHT_BANDWIDTH_BYTES * task->qw * num_k_in;
Expand All @@ -212,19 +206,19 @@ void neureka_v2_task_set_counters(neureka_v2_task_t *task, const uint32_t k_in,
const uint32_t k_out,
const uint8_t padding_bottom,
const uint8_t padding_right) {
const uint16_t num_Ko =
nnx_calculate_number_of_tiles(k_out, task->subtile_output_channel);
const uint16_t num_Ki =
nnx_calculate_number_of_tiles(k_in, task->subtile_input_channel);
const uint16_t num_Ko = nnx_calculate_number_of_tiles(
k_out, NEUREKA_V2_SUBTILE_INPUT_OUTPUT_CHANNEL);
const uint16_t num_Ki = nnx_calculate_number_of_tiles(
k_in, NEUREKA_V2_SUBTILE_INPUT_OUTPUT_CHANNEL);
const uint16_t num_Ho =
nnx_calculate_number_of_tiles(h_out, NEUREKA_V2_SUBTILE_OUTPUT_HEIGHT);
const uint16_t num_Wo =
nnx_calculate_number_of_tiles(w_out, NEUREKA_V2_SUBTILE_OUTPUT_WIDTH);

const uint16_t rem_Ko =
nnx_calculate_last_tile_size(k_out, task->subtile_output_channel);
const uint16_t rem_Ki =
nnx_calculate_last_tile_size(k_in, task->subtile_input_channel);
const uint16_t rem_Ko = nnx_calculate_last_tile_size(
k_out, NEUREKA_V2_SUBTILE_INPUT_OUTPUT_CHANNEL);
const uint16_t rem_Ki = nnx_calculate_last_tile_size(
k_in, NEUREKA_V2_SUBTILE_INPUT_OUTPUT_CHANNEL);
const uint16_t rem_Ho =
nnx_calculate_last_tile_size(h_out, NEUREKA_V2_SUBTILE_OUTPUT_HEIGHT);
const uint16_t rem_Wo =
Expand Down
2 changes: 0 additions & 2 deletions neureka_v2/hal/neureka_v2_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ typedef struct neureka_v2_task_data_t {
typedef struct neureka_v2_task_t {
neureka_v2_task_data_t data;
uint8_t qw;
uint8_t subtile_output_channel;
uint8_t subtile_input_channel;
uint8_t kernel_shape;
uint8_t depthwise;
uint8_t id;
Expand Down
5 changes: 2 additions & 3 deletions neureka_v2/hal/neureka_v2_task_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@

#define NEUREKA_V2_SUBTILE_INPUT_HEIGHT_1x1 (6)
#define NEUREKA_V2_SUBTILE_INPUT_WIDTH_1x1 (6)
#define NEUREKA_V2_SUBTILE_INPUT_CHANNEL_1x1 (32)

#define NEUREKA_V2_SUBTILE_INPUT_HEIGHT_3x3 (8)
#define NEUREKA_V2_SUBTILE_INPUT_WIDTH_3x3 (8)
#define NEUREKA_V2_SUBTILE_INPUT_CHANNEL_3x3 (32)

#define NEUREKA_V2_SUBTILE_OUTPUT_HEIGHT (6)
#define NEUREKA_V2_SUBTILE_OUTPUT_WIDTH (6)
#define NEUREKA_V2_SUBTILE_OUTPUT_CHANNEL (32)

#define NEUREKA_V2_SUBTILE_INPUT_OUTPUT_CHANNEL (32)

#define NEUREKA_V2_OUTPUT_BANDWIDTH_BYTES (32)
#define NEUREKA_V2_WEIGHT_BANDWIDTH_BYTES (36)
Expand Down

0 comments on commit 9e12f7b

Please sign in to comment.