Skip to content

Commit

Permalink
exec
Browse files Browse the repository at this point in the history
  • Loading branch information
aobolensk committed Feb 14, 2025
1 parent 7976fe9 commit ac26ffc
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,36 @@ void BrgemmAMXBatchedKernelExecutor::execute(const BrgemmAMXBatchedKernelExecuto
config.get_M(),
config.get_N(),
K_tail);
execute_brgemm_kernel(K_tail_kernel->brgemm_kernel, src_ptr, wei_ptr, args->C, scratch, false);
execute_brgemm(K_tail_kernel->brgemm_kernel, config.get_iter_count(), src_ptr, wei_ptr, args->C, scratch, false);
}
}

// Performs one call of the given brgemm kernel.
//
// kernel    - brgemm kernel to invoke; asserted to be non-null.
// bs        - forwarded to the kernel as the BS field of the call parameters.
// pin0      - stored as the A pointer of the single batch element.
// pin1      - stored as the B pointer of the single batch element.
// dst       - passed to the kernel as both ptr_C and ptr_D.
// scratch   - passed to the kernel as ptr_buf.
// with_comp - copied into both do_post_ops and do_apply_comp.
void BrgemmAMXBatchedKernelExecutor::execute_brgemm(const std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& kernel,
                                                    size_t bs,
                                                    const void* pin0,
                                                    const void* pin1,
                                                    void* dst,
                                                    void* scratch,
                                                    bool with_comp) {
    // Fail fast: nothing below is useful without a kernel to call.
    OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr Brgemm kernel");

    // The A/B inputs are handed over through a batch element rather than ptr_A/ptr_B.
    // NOTE(review): only one batch element is supplied while BS is set to `bs`;
    // confirm the kernel's batch kind never indexes past element 0 when bs > 1.
    brgemm_batch_element_t batch_elem;
    batch_elem.ptr.A = pin0;
    batch_elem.ptr.B = pin1;

    cpu::x64::brgemm_kernel_params_t call_params;

    // Inputs: described by the batch element, direct A/B pointers stay unused.
    call_params.batch = &batch_elem;
    call_params.BS = bs;
    call_params.ptr_A = nullptr;
    call_params.ptr_B = nullptr;

    // Outputs and auxiliary buffers: C and D share the same destination.
    call_params.ptr_C = dst;
    call_params.ptr_D = dst;
    call_params.ptr_buf = scratch;
    call_params.ptr_bias = nullptr;

    // Behaviour flags.
    call_params.do_post_ops = with_comp;
    call_params.do_apply_comp = with_comp;
    call_params.skip_accm = 0;

    (*kernel)(&call_params);
}

#undef INNER_K_BLK
#undef VNNI_FACTOR
#undef EQ
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ class BrgemmAMXBatchedKernelExecutor : public BrgemmBaseKernelExecutor,
const void* tr_src,
dnnl_dim_t M,
dnnl_dim_t K);

// Invokes `kernel` once with `bs` forwarded as the kernel's batch size (BS).
// `src` and `wei` are the A and B inputs, `dst` is the output buffer and `scratch` is the
// scratch buffer handed to the kernel. `with_comp` switches on the kernel's post-ops and
// compensation flags. `kernel` is expected to be non-null.
static void execute_brgemm(const std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& kernel,
size_t bs,
const void* src,
const void* wei,
void* dst,
void* scratch,
bool with_comp);
};
// #define GET_OFF_BRGEMM_AMX_ARGS(field) offsetof(BrgemmAMXBatchedKernelExecutor::call_args, field)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,7 @@ void BrgemmBatchedKernelExecutor::execute(const BrgemmBatchedKernelExecutor* exe

// Note: compensations should be applied only once, so we do it only on the first iteration, when beta == 0
const auto is_with_comp = config.get_beta() == 0 && config.is_with_comp();
for (size_t i = 0; i < config.get_iter_count(); i++) {
execute_brgemm(kernel->brgemm_kernel, 1, args->A, args->B, args->C, args->scratch, is_with_comp);
// execute_brgemm(kernel->brgemm_kernel, config.get_iter_count(), args->A, args->B, args->C, args->scratch, is_with_comp);
}
execute_brgemm(kernel->brgemm_kernel, config.get_iter_count(), args->A, args->B, args->C, args->scratch, is_with_comp);
}

void BrgemmBatchedKernelExecutor::execute_brgemm(const std::shared_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& kernel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ bool pass::BuildBrgemm::run(snippets::lowered::LinearIR& linear_ir,
for (auto expr_it = begin; expr_it != end; expr_it++) {
const auto& expr = *expr_it;
const auto gemm_node = ov::as_type_ptr<GemmCPU>(expr->get_node());
if (!gemm_node || gemm_node->is_dynamic() || with_compensations(gemm_node->get_type()) ||
with_amx(gemm_node->get_type())) {
if (!gemm_node || gemm_node->is_dynamic() || with_compensations(gemm_node->get_type())) {
continue;
}
const auto& loop_manager = linear_ir.get_loop_manager();
Expand Down

0 comments on commit ac26ffc

Please sign in to comment.