cpu: x64: matmul: update bf16 small shape dispatch condition
karasjoh000 authored and ankalinin committed Dec 27, 2023
1 parent fa43640 commit c0ae38c
Showing 1 changed file with 23 additions and 35 deletions.
src/cpu/x64/matmul/brgemm_matmul_utils.cpp (+23, −35)
@@ -47,18 +47,6 @@ using namespace dnnl::impl::utils;
 using namespace data_type;
 using namespace format_tag;
 
-// Condition generated by decision tree classifier.
-bool amx_xf16_is_small_shape(dim_t batch, dim_t M, dim_t K, dim_t N) {
-    const float b = static_cast<float>(batch);
-    const float m = static_cast<float>(M);
-    const float k = static_cast<float>(K);
-    const float n = static_cast<float>(N);
-    return (k <= 28 && n / k > 39.7) || (k <= 28 && n / k <= 8)
-            || (k > 28 && m * k <= 248.0 && n > 52.0)
-            || ((m * k * n) / b > 4862 && b * k >= 23 && n / k > 8
-                    && m * k * n <= 60817408);
-}
-
 int get_default_n_block(format_tag_t matrix_b_tag) {
     // Note: consider using weights mem_descriptor 'inner_blks' to
     // return B's inner block for non-default cases.
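
For context, the deleted classifier can be exercised on its own. A minimal
standalone sketch follows; the dim_t alias, main(), and sample shapes are
illustrative additions, and only the predicate body is verbatim from the diff:

#include <cstdint>
#include <cstdio>

using dim_t = int64_t; // assumption: stands in for oneDNN's dnnl_dim_t

// Predicate body copied verbatim from the lines removed above.
bool amx_xf16_is_small_shape(dim_t batch, dim_t M, dim_t K, dim_t N) {
    const float b = static_cast<float>(batch);
    const float m = static_cast<float>(M);
    const float k = static_cast<float>(K);
    const float n = static_cast<float>(N);
    return (k <= 28 && n / k > 39.7) || (k <= 28 && n / k <= 8)
            || (k > 28 && m * k <= 248.0 && n > 52.0)
            || ((m * k * n) / b > 4862 && b * k >= 23 && n / k > 8
                    && m * k * n <= 60817408);
}

int main() {
    // Skinny-K shape: k <= 28 and n / k > 39.7, so it is classified small
    // and would have been dispatched to VNNI.
    std::printf("%d\n", amx_xf16_is_small_shape(1, 1, 16, 4096)); // prints 1
    // Square 512 GEMM: no clause fires (n / k == 1), so it stays on AMX.
    std::printf("%d\n", amx_xf16_is_small_shape(1, 512, 512, 512)); // prints 0
    return 0;
}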
@@ -1045,29 +1033,6 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     bgmmc.bcast_B_desc.set_params(
             weights_d.dims(), dst_d.dims(), bgmmc.batch_ndims, bgmmc.batch);
 
-    // Dispatch small shapes to VNNI for better performance
-    const bool runtime_dims
-            = bgmmc.is_runtime_M || bgmmc.is_runtime_N || bgmmc.is_runtime_K;
-
-    bool is_small_shapes = bgmmc.is_amx && !runtime_dims;
-
-    // Disable 'small_shape' heuristic for amx_fp16 until it is validated with
-    // performance measurements.
-    is_small_shapes = is_small_shapes && (bgmmc.isa != avx512_core_amx_fp16);
-
-    if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()) {
-        is_small_shapes = is_small_shapes
-                && amx_xf16_is_small_shape(
-                        bgmmc.batch, bgmmc.M, bgmmc.K, bgmmc.N);
-    } else {
-        is_small_shapes = is_small_shapes && bgmmc.ndims < 3
-                && ((bgmmc.M == 1 && bgmmc.K == 256)
-                        || (bgmmc.M <= 32 && bgmmc.M * bgmmc.N <= 256)
-                        || bgmmc.K <= 16);
-    }
-
-    VCONDCHECK_BG(!is_small_shapes, VERBOSE_SMALL_SHAPES);
-
     // required granularity for k dimension
     bgmmc.required_k_granularity
             = bgmmc.is_amx ? data_type_vnni_granularity(bgmmc.wei_dt) : 1;
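
(Note on the trailing context above: required_k_granularity ties K padding to
the weights data type via data_type_vnni_granularity. A simplified sketch of
what such a helper computes, assuming the usual VNNI packing widths; the enum
and function below are illustrative, not the library's actual definitions:)

// Sketch (assumption): VNNI instructions consume K in dword-sized groups,
// so AMX kernels need K rounded up to the per-type group size.
enum class data_type_kind { f32, bf16, f16, s8, u8 };

int vnni_granularity_sketch(data_type_kind dt) {
    switch (dt) {
        case data_type_kind::bf16:
        case data_type_kind::f16: return 2; // two 16-bit values per dword
        case data_type_kind::s8:
        case data_type_kind::u8: return 4; // four 8-bit values per dword
        default: return 1; // f32: no VNNI repacking
    }
}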
@@ -1214,6 +1179,29 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,

     init_aux_values(bgmmc, src_d, weights_d, dst_d);
 
+    // Dispatch small shapes to VNNI for better performance
+    const bool runtime_dims
+            = bgmmc.is_runtime_M || bgmmc.is_runtime_N || bgmmc.is_runtime_K;
+
+    bool is_small_shapes = bgmmc.is_amx && !runtime_dims;
+
+    // Disable 'small_shape' heuristic for amx_fp16 until it is validated with
+    // performance measurements.
+    is_small_shapes = is_small_shapes && (bgmmc.isa != avx512_core_amx_fp16);
+
+    if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()) {
+        // empirical observation for performance breakpoint between amx and vnni bf16/f16
+        const dim_t buffer_a_chunk_sz_limit = 126;
+        is_small_shapes = is_small_shapes
+                && bgmmc.buffer_a_chunk_sz <= buffer_a_chunk_sz_limit;
+    } else {
+        is_small_shapes = is_small_shapes && bgmmc.ndims < 3
+                && ((bgmmc.M == 1 && bgmmc.K == 256)
+                        || (bgmmc.M <= 32 && bgmmc.M * bgmmc.N <= 256)
+                        || bgmmc.K <= 16);
+    }
+    VCONDCHECK_BG(!is_small_shapes, VERBOSE_SMALL_SHAPES);
+
     return status::success;
 }
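
Taken together, the change swaps the decision-tree classifier for a single
empirical threshold on the A-buffer chunk size in the bf16/f16 branch, and
moves the check after init_aux_values so that buffer_a_chunk_sz is populated;
the non-xf16 branch is unchanged. A distilled restatement of the resulting
dispatch decision follows; the free-standing function and its flattened
parameter list are illustrative, with all thresholds taken from the diff above:

#include <cstdint>

using dim_t = int64_t; // assumption: stands in for oneDNN's dnnl_dim_t

// Returns true when the problem should fall through to the VNNI matmul
// implementation instead of AMX (i.e. when VCONDCHECK_BG would reject).
bool dispatch_small_shape_to_vnni(bool is_amx, bool runtime_dims,
        bool is_amx_fp16, bool is_xf16, int ndims, dim_t M, dim_t N, dim_t K,
        dim_t buffer_a_chunk_sz) {
    if (!is_amx || runtime_dims) return false;
    // amx_fp16 keeps the AMX path until the heuristic is validated there.
    if (is_amx_fp16) return false;
    if (is_xf16) {
        // Empirical AMX-vs-VNNI performance breakpoint from this commit.
        return buffer_a_chunk_sz <= 126;
    }
    return ndims < 3
            && ((M == 1 && K == 256) || (M <= 32 && M * N <= 256)
                    || K <= 16);
}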

