From b76d4cae333fc4e015d47eb737e10551daf30334 Mon Sep 17 00:00:00 2001 From: Tomasz Czeszun Date: Mon, 26 Jun 2023 14:06:31 -0700 Subject: [PATCH] cpu: x64: gemm: disable po for unsupported data types & ISAs --- src/cpu/gemm_inner_product_utils.cpp | 7 +++---- src/cpu/x64/jit_gemm_inner_product_utils.hpp | 12 ++++++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/cpu/gemm_inner_product_utils.cpp b/src/cpu/gemm_inner_product_utils.cpp index cce5ff87d63..22948d6f8c3 100644 --- a/src/cpu/gemm_inner_product_utils.cpp +++ b/src/cpu/gemm_inner_product_utils.cpp @@ -203,8 +203,8 @@ pp_kernel_t *pp_kernel_t::create(size_t OC, size_t MB, dim_t dst_mb_stride, bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d, const bcast_set_t &enabled_bcast_strategy) { #if DNNL_X64 - static constexpr auto isa_supported - = x64::inner_product_utils::jit_pp_kernel_supported_isa(); + const auto isa_supported + = x64::inner_product_utils::get_max_jit_pp_kernel_supported_isa(); using namespace cpu::x64; if (mayiuse(isa_supported)) { using namespace x64::injector; @@ -231,9 +231,8 @@ bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d, is_binary_po_per_mb_w_bcast, utils::one_of(ndims, 3, 4)) && IMPLICATION( is_binary_po_per_w_bcast, utils::one_of(ndims, 3, 4)); - const cpu_isa_t isa = get_max_cpu_isa(); return supported_binary_bcast - && injector::post_ops_ok(post_ops_ok_args_t(isa, + && injector::post_ops_ok(post_ops_ok_args_t(isa_supported, {binary, eltwise, sum}, post_ops, dst_d, sum_at_pos_0_only, sum_requires_scale_one, sum_requires_zp_zero, sum_requires_same_params, diff --git a/src/cpu/x64/jit_gemm_inner_product_utils.hpp b/src/cpu/x64/jit_gemm_inner_product_utils.hpp index 6f77e0d4f8a..f3400dd7d77 100644 --- a/src/cpu/x64/jit_gemm_inner_product_utils.hpp +++ b/src/cpu/x64/jit_gemm_inner_product_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2021 Intel Corporation +* Copyright 2020-2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,7 +30,15 @@ cpu::inner_product_utils::pp_kernel_t *jit_pp_kernel_create(size_t OC, data_type_t bias_dt, data_type_t acc_dt, const memory_desc_t *dst_md, bool skip_sum); -constexpr cpu_isa_t jit_pp_kernel_supported_isa() { +inline cpu_isa_t get_max_jit_pp_kernel_supported_isa() { +#define CASE(isa) \ + do { \ + if (mayiuse(isa)) return isa; \ + } while (false) + CASE(avx512_core_bf16); + CASE(avx512_core); + CASE(avx2); +#undef CASE return sse41; }