Skip to content

Commit

Permalink
cpu: x64: gemm: disable po for unsupported data types & ISAs
Browse files Browse the repository at this point in the history
  • Loading branch information
tczeszun authored and vpirogov committed Jun 29, 2023
1 parent 190a9b2 commit b76d4ca
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
7 changes: 3 additions & 4 deletions src/cpu/gemm_inner_product_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ pp_kernel_t *pp_kernel_t::create(size_t OC, size_t MB, dim_t dst_mb_stride,
bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d,
const bcast_set_t &enabled_bcast_strategy) {
#if DNNL_X64
static constexpr auto isa_supported
= x64::inner_product_utils::jit_pp_kernel_supported_isa();
const auto isa_supported
= x64::inner_product_utils::get_max_jit_pp_kernel_supported_isa();
using namespace cpu::x64;
if (mayiuse(isa_supported)) {
using namespace x64::injector;
Expand All @@ -231,9 +231,8 @@ bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d,
is_binary_po_per_mb_w_bcast, utils::one_of(ndims, 3, 4))
&& IMPLICATION(
is_binary_po_per_w_bcast, utils::one_of(ndims, 3, 4));
const cpu_isa_t isa = get_max_cpu_isa();
return supported_binary_bcast
&& injector::post_ops_ok(post_ops_ok_args_t(isa,
&& injector::post_ops_ok(post_ops_ok_args_t(isa_supported,
{binary, eltwise, sum}, post_ops, dst_d,
sum_at_pos_0_only, sum_requires_scale_one,
sum_requires_zp_zero, sum_requires_same_params,
Expand Down
12 changes: 10 additions & 2 deletions src/cpu/x64/jit_gemm_inner_product_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
* Copyright 2020-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,7 +30,15 @@ cpu::inner_product_utils::pp_kernel_t *jit_pp_kernel_create(size_t OC,
data_type_t bias_dt, data_type_t acc_dt, const memory_desc_t *dst_md,
bool skip_sum);

constexpr cpu_isa_t jit_pp_kernel_supported_isa() {
inline cpu_isa_t get_max_jit_pp_kernel_supported_isa() {
#define CASE(isa) \
do { \
if (mayiuse(isa)) return isa; \
} while (false)
CASE(avx512_core_bf16);
CASE(avx512_core);
CASE(avx2);
#undef CASE
return sse41;
}

Expand Down

0 comments on commit b76d4ca

Please sign in to comment.