From 2968c8948225d18c9df19c94534ff7dc4343700c Mon Sep 17 00:00:00 2001
From: "Xuxin, Zeng"
Date: Tue, 9 Apr 2024 12:41:26 -0700
Subject: [PATCH] cpu: x64: dispatch the shape that requires large/small cache to jit on avx2

---
 src/cpu/x64/jit_brgemm_conv_utils.cpp | 42 +++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp
index e61d0aee90e..71449db3ca1 100644
--- a/src/cpu/x64/jit_brgemm_conv_utils.cpp
+++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp
@@ -2189,6 +2189,27 @@ status_t init_conf(jit_brgemm_conv_conf_t &jcp, bool use_inversion,
             !(jcp.ow_block == 0 || jcp.ic_block == 0 || jcp.oc_block == 0),
             VERBOSE_BLOCKING_FAIL);
 
+    // Dispatch the shape that requires large or small cache to JIT
+    // for better performance on AVX2
+    // The threshold is empirical
+    const size_t w_cache_sz
+            = static_cast<size_t>(jcp.src_dsz) * jcp.ic_block * jcp.iwp
+            + jcp.dst_dsz * jcp.ow * jcp.oc_block;
+    const size_t wei_cache_sz = static_cast<size_t>(jcp.wei_dsz) * jcp.kd_block
+            * jcp.kh_block * jcp.kw_block * jcp.ic_block * jcp.oc_block;
+    const size_t nthr_work_amount
+            = div_up(static_cast<size_t>(jcp.mb) * jcp.ngroups * jcp.nb_od
+                            * jcp.nb_oh * jcp.nb_ow * jcp.nb_oc,
+                    jcp.nthr);
+    const bool req_large_cache = jcp.oc >= 256 && jcp.ic >= 256
+            && nstl::max(w_cache_sz, wei_cache_sz) * nthr_work_amount
+                    >= brg_blocking_t::L2 * 10;
+    const bool req_small_cache = jcp.ic <= jcp.acc_simd_w
+            && nstl::max(w_cache_sz, wei_cache_sz) <= 2048;
+    VDISPATCH_CONV_IC(
+            !((req_large_cache || req_small_cache) && jcp.isa == avx2),
+            "Dispatch the shape that requires large/small cache size to jit");
+
     // to avoid cache concurrent write access from different threads
     size_t sc_size = sizeof(brgemm_batch_element_t);
     jcp.adjusted_batch_size
@@ -2403,6 +2424,27 @@ status_t init_1x1_conf(jit_brgemm_conv_conf_t &jcp, cpu_isa_t isa,
     VDISPATCH_CONV_IC(
             !(jcp.ic_block == 0 || jcp.oc_block == 0), VERBOSE_BLOCKING_FAIL);
 
+    // Dispatch the shape that requires large or small cache to JIT
+    // for better performance on AVX2
+    // The threshold is empirical
+    const size_t w_cache_sz
+            = static_cast<size_t>(jcp.src_dsz) * jcp.ic_block * jcp.iwp
+            + jcp.dst_dsz * jcp.ow * jcp.oc_block;
+    const size_t wei_cache_sz = static_cast<size_t>(jcp.wei_dsz) * jcp.kd_block
+            * jcp.kh_block * jcp.kw_block * jcp.ic_block * jcp.oc_block;
+    const size_t nthr_work_amount
+            = div_up(static_cast<size_t>(jcp.mb) * jcp.ngroups * jcp.nb_od
+                            * jcp.nb_oh * jcp.nb_ow * jcp.nb_oc,
+                    jcp.nthr);
+    const bool req_large_cache = jcp.oc >= 256 && jcp.ic >= 256
+            && nstl::max(w_cache_sz, wei_cache_sz) * nthr_work_amount
+                    >= brg_blocking_t::L2 * 10;
+    const bool req_small_cache = jcp.ic <= jcp.acc_simd_w
+            && nstl::max(w_cache_sz, wei_cache_sz) <= 2048;
+    VDISPATCH_CONV_IC(
+            !((req_large_cache || req_small_cache) && jcp.isa == avx2),
+            "Dispatch the shape that requires large/small cache size to jit");
+
     // Configure matrix sizes
     if (best_brgb.is_os_blocking) {
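
For reference, below is a minimal standalone C++ sketch of the dispatch heuristic this patch adds. Only the arithmetic and the two predicates mirror the patch: every concrete number (block sizes, data-type sizes, thread count, work decomposition, and the assumed 1.25 MiB per-core L2 size) is an illustrative assumption, not a value from oneDNN, and nstl::max/div_up are replaced with standard-library equivalents.

    // Standalone sketch of the AVX2 large/small-cache dispatch heuristic.
    // All concrete values are illustrative assumptions; only the arithmetic
    // mirrors the patch, not the full oneDNN implementation.
    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main() {
        // Hypothetical example shape and blocking parameters.
        const size_t src_dsz = 4, dst_dsz = 4, wei_dsz = 4; // f32 elements
        const size_t ic_block = 64, oc_block = 64, iwp = 112, ow = 112;
        const size_t kd_block = 1, kh_block = 3, kw_block = 3;
        const size_t mb = 1, ngroups = 1;
        const size_t nb_od = 1, nb_oh = 112, nb_ow = 1, nb_oc = 4;
        const size_t ic = 256, oc = 256, acc_simd_w = 8, nthr = 8;
        const size_t L2 = 1280 * 1024; // assumed per-core L2 size in bytes

        // Working-set estimates, as in the patch: an input row block plus an
        // output row block, vs. one weights block.
        const size_t w_cache_sz
                = src_dsz * ic_block * iwp + dst_dsz * ow * oc_block;
        const size_t wei_cache_sz = wei_dsz * kd_block * kh_block * kw_block
                * ic_block * oc_block;

        // Work items per thread, rounded up (div_up in the patch).
        const size_t work = mb * ngroups * nb_od * nb_oh * nb_ow * nb_oc;
        const size_t nthr_work_amount = (work + nthr - 1) / nthr;

        // Empirical thresholds from the patch: either a very large aggregate
        // footprint per thread (>= 10x L2) on wide shapes, or a tiny
        // footprint with few input channels.
        const bool req_large_cache = oc >= 256 && ic >= 256
                && std::max(w_cache_sz, wei_cache_sz) * nthr_work_amount
                        >= L2 * 10;
        const bool req_small_cache = ic <= acc_simd_w
                && std::max(w_cache_sz, wei_cache_sz) <= 2048;

        // When either predicate holds on avx2, the patch rejects the brgemm
        // implementation so dispatching falls through to the plain jit conv.
        std::printf("fall back to jit convolution: %s\n",
                (req_large_cache || req_small_cache) ? "yes" : "no");
        return 0;
    }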