Skip to content

Commit

Permalink
cpu: aarch64: Expand brgemm aarch64 unsupported cases handling mechan…
Browse files Browse the repository at this point in the history
…ism (#2099)
  • Loading branch information
Radu2k authored and mgouicem committed Dec 4, 2024
1 parent 4793296 commit 9a1dc92
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 61 deletions.
20 changes: 8 additions & 12 deletions src/cpu/aarch64/acl_deconvolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,9 @@ struct acl_deconvolution_fwd_t : public primitive_t {
}

// Data layout
const auto acl_layout = is_nspc ? arm_compute::DataLayout::NHWC
: arm_compute::DataLayout::NCHW;
const arm_compute::DataLayout acl_layout = is_nspc
? arm_compute::DataLayout::NHWC
: arm_compute::DataLayout::NCHW;

acl_pd_conf.src_info = arm_compute::TensorInfo(is_nspc
? arm_compute::TensorShape(ic, iw, ih, mb)
Expand Down Expand Up @@ -243,18 +244,15 @@ struct acl_deconvolution_fwd_t : public primitive_t {
// padding is set for convolution. Otherwise, describe deconvolution as convolution of
// upsampling input with stride = 1 and pad = 0.
arm_compute::ConvolutionMethod conv_method;
arm_compute::TensorInfo *conv_src_info;
arm_compute::TensorInfo conv_src_info(
acl_pd_conf.src_info.clone()->set_is_resizable(true));
unsigned int pad_left = 0;
unsigned int pad_right = 0;
unsigned int pad_top = 0;
unsigned int pad_bottom = 0;
if (sh != 1 || sw != 1) {
arm_compute::TensorInfo scale_out_info(
acl_pd_conf.src_info.clone()
->set_is_resizable(true)
.reset_padding()
.set_tensor_shape(scale_out_shape));
conv_src_info = &scale_out_info;
conv_src_info.reset_padding();
conv_src_info.set_tensor_shape(scale_out_shape);
} else {
// compute correct padding here
pad_left = pr > pl ? pr - pl : 0;
Expand All @@ -269,15 +267,13 @@ struct acl_deconvolution_fwd_t : public primitive_t {
pad_right += deconv_pad_x / 2;
pad_top += deconv_pad_y / 2;
pad_bottom += deconv_pad_y / 2;

conv_src_info = &acl_pd_conf.src_info;
}
const arm_compute::PadStrideInfo conv_info(1, 1, pad_left,
pad_right, pad_top, pad_bottom,
arm_compute::DimensionRoundingType::CEIL);
conv_method
= arm_compute::NEConvolutionLayer::get_convolution_method(
conv_src_info, &acl_pd_conf.wei_info,
&conv_src_info, &acl_pd_conf.wei_info,
&acl_pd_conf.dst_info, conv_info,
arm_compute::WeightsInfo(),
arm_compute::Size2D(1U, 1U),
Expand Down
9 changes: 5 additions & 4 deletions src/cpu/aarch64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*******************************************************************************
* Copyright 2020-2023 Intel Corporation
* Copyright 2023-2024 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -170,8 +171,8 @@ status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
if (brg == nullptr) return status::invalid_arguments;
if (transA || transB) return status::unimplemented;

brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout, alpha,
beta, LDA, LDB, LDC, M, N, K, strides);
CHECK(brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout,
alpha, beta, LDA, LDB, LDC, M, N, K, strides));

if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments;
bool ldx_check = (brg->is_row_major()) ? (LDA < K)
Expand All @@ -197,8 +198,8 @@ status_t brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa,
if (transA || layout != brgemm_row_major || alpha != 1.0f || beta != 0.f)
return status::unimplemented;

brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout, alpha,
beta, LDA, LDC, M, N, strides);
CHECK(brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout,
alpha, beta, LDA, LDC, M, N, strides));

const bool ldx_check = (LDA < N || LDC < N);
if (ldx_check) return status::invalid_arguments;
Expand Down
55 changes: 29 additions & 26 deletions src/cpu/aarch64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*******************************************************************************
* Copyright 2022-2023 Intel Corporation
* Copyright 2023-2024 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -47,15 +48,18 @@ impl::data_type_t get_accum_datatype(brgemm_t *brg) {
return brg->is_int8 ? data_type::s32 : data_type::f32;
}

void init_kernel_datatype(
status_t init_kernel_datatype(
brgemm_t *brg, impl::data_type_t dt_a, impl::data_type_t dt_b) {
assert(dt_a != data_type::undef && dt_b != data_type::undef);
if (dt_a != data_type::undef && dt_b != data_type::undef)
return status::unimplemented;
brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8)
&& utils::one_of(dt_b, data_type::u8, data_type::s8);
brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16);
brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32);
brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b);
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16);
if (brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16)
return status::unimplemented;
return status::success;
}

void init_common_conf(brgemm_t *brg, brgemm_batch_kind_t type, float alpha,
Expand Down Expand Up @@ -88,27 +92,22 @@ void maybe_try_bf32(brgemm_t *brg) {
//
}

void set_isa_impl(brgemm_t *brg) {
status_t set_isa_impl(brgemm_t *brg) {
auto is_isa_ok = [&](cpu_isa_t isa) {
return mayiuse(isa) &&
// maybe IMPLICATION(brg->isa_user != isa_undef,
// is_superset(brg->isa_user, isa)), but the API is not clear.
one_of(brg->isa_user, isa_undef, isa);
};

if (brg->is_bf32) {
assert(!"unsupported case");
} else if (brg->is_f32) {
brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(sve_512), sve_512,
is_isa_ok(sve_256), sve_256);
} else if (brg->is_bf16) {
assert(!"unsupported case");
} else if (brg->is_f16) {
assert(!"unsupported case");
} else if (brg->is_int8) {
if (brg->is_bf32 || brg->is_bf16 || brg->is_f16) {
return status::unimplemented;
} else if (brg->is_f32 || brg->is_int8) {
brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(sve_512), sve_512,
is_isa_ok(sve_256), sve_256);
return status::success;
}
return status::success;
}

void set_brg_vmm(brgemm_t *brg) {
Expand Down Expand Up @@ -187,7 +186,7 @@ inline size_t data_type_vnni_granularity(data_type_t data_type) {
}
status_t brgemm_blocking(brgemm_t *brg) {

set_isa_impl(brg);
CHECK(set_isa_impl(brg));
if (brg->isa_impl == isa_undef) return status::unimplemented;
assert(!brg->is_dgmm); // should not be called from brdgmm
set_brg_vmm(brg);
Expand Down Expand Up @@ -296,18 +295,19 @@ status_t brdgmm_blocking(brgemm_t *brg) {
return status::success;
}

void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M,
dim_t N, dim_t K, const brgemm_strides_t *strides, bool is_bf32) {
status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa,
brgemm_batch_kind_t type, impl::data_type_t dt_a,
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K,
const brgemm_strides_t *strides, bool is_bf32) {

init_common_conf(brg, type, alpha, beta, strides);

brg->layout = layout;

brg->dt_a = brg->is_row_major() ? dt_a : dt_b;
brg->dt_b = brg->is_row_major() ? dt_b : dt_a;
init_kernel_datatype(brg, brg->dt_a, brg->dt_b);
CHECK(init_kernel_datatype(brg, brg->dt_a, brg->dt_b));

brg->dt_c = get_accum_datatype(brg);
brg->dt_d = brg->dt_c;
Expand All @@ -319,7 +319,7 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
brg->typesize_D = types::data_type_size(brg->dt_d);

brg->isa_user = isa;
set_isa_impl(brg);
CHECK(set_isa_impl(brg));
brg->is_bf32 = false;

brg->has_int8_vnni = true;
Expand Down Expand Up @@ -352,11 +352,13 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
brg->rd_step = has_no_vnni_compute_instruction
? 1
: data_type_vnni_granularity(brg->dt_b);
return status::success;
}

void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
status_t init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa,
brgemm_batch_kind_t type, impl::data_type_t dt_a,
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
dim_t LDA, dim_t LDC, dim_t M, dim_t N,
const brgemm_strides_t *strides) {

init_common_conf(brg, type, alpha, beta, strides);
Expand All @@ -365,7 +367,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,

brg->dt_a = dt_a;
brg->dt_b = dt_b;
init_kernel_datatype(brg, brg->dt_a, brg->dt_b);
CHECK(init_kernel_datatype(brg, brg->dt_a, brg->dt_b));

brg->dt_c = get_accum_datatype(brg);
brg->dt_d = brg->dt_c;
Expand Down Expand Up @@ -394,6 +396,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,

brg->bcast_dim = M;
brg->load_dim = N;
return status::success;
}

} // namespace brgemm_utils
Expand All @@ -402,4 +405,4 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
} // namespace impl
} // namespace dnnl

//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
18 changes: 10 additions & 8 deletions src/cpu/aarch64/brgemm/brgemm_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,20 +45,21 @@ status_t brdgmm_blocking(brgemm_t *brg);
* having to depend on BRGeMM's API. An additional feature is that this
* function can be modified depending on needs without requiring changes
* at the API level. */
void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M,
dim_t N, dim_t K, const brgemm_strides_t *strides = nullptr,
bool is_bf32 = false);
status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa,
brgemm_batch_kind_t type, impl::data_type_t dt_a,
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K,
const brgemm_strides_t *strides = nullptr, bool is_bf32 = false);

/* The purpose of this function is to enable initialization of brgemm values
* and then call additional functions like blocking heuristics without
* having to depend on BRDGeMM's API. An additional feature is that this
* function can be modified depending on needs without requiring changes
* at the API level. */
void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout,
float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N,
status_t init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa,
brgemm_batch_kind_t type, impl::data_type_t dt_a,
impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta,
dim_t LDA, dim_t LDC, dim_t M, dim_t N,
const brgemm_strides_t *strides = nullptr);

} // namespace brgemm_utils
Expand Down
11 changes: 6 additions & 5 deletions src/cpu/aarch64/jit_brgemm_conv_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -725,9 +726,9 @@ status_t brg_blocking_t::estimate_brgemm_ur() {
const float alpha = 1.0;
const float beta = 0.0;
brgemm_t brg;
brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt,
CHECK(brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt,
brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK, nullptr,
is_bf32);
is_bf32));
CHECK(brgemm_utils::brgemm_blocking(&brg));
ur = brg.bd_block;
ur_block = brg.bd_block;
Expand Down Expand Up @@ -771,9 +772,9 @@ status_t brg_blocking_t::get_brgemm_ur(
* rnd_up(oc, oc_block) * wei_dsz;
const auto strides_ptr
= (brg_type == brgemm_strd) ? &brg_strides : nullptr;
brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, src_dt,
wei_dt, brgemm_row_major, alpha, vbeta, LDA, LDB, LDC,
vM, vN, vK, strides_ptr, is_bf32);
CHECK(brgemm_utils::init_brgemm_conf(&brg, isa, brg_type,
src_dt, wei_dt, brgemm_row_major, alpha, vbeta, LDA,
LDB, LDC, vM, vN, vK, strides_ptr, is_bf32));
CHECK(brgemm_utils::brgemm_blocking(&brg));

brgemm_attr_t brgattr;
Expand Down
3 changes: 1 addition & 2 deletions src/cpu/aarch64/matmul/brgemm_matmul.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
Expand Down Expand Up @@ -642,7 +643,6 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
ctx.current_K_start = k;
ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K);
assert(isa == sve_512);
(*copy_B_kernel_)(&ctx);
}

Expand All @@ -654,7 +654,6 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
ctx.current_K_start = k;
ctx.current_K_iters = bgmmc.K % bgmmc.K_blk;
assert(isa == sve_512);
(*copy_B_kernel_)(&ctx);
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -129,7 +130,7 @@ bool post_ops_ok(brgemm_matmul_conf_t &bgmmc, const primitive_attr_t &attr,
}

status_t check_isa_with_datatype(
const cpu_isa_t isa, const brgemm_matmul_conf_utils_t &bm_conf_utils) {
const brgemm_matmul_conf_utils_t &bm_conf_utils) {
if (bm_conf_utils.is_f32() && !bm_conf_utils.is_int8()
&& !bm_conf_utils.is_bf16() && !bm_conf_utils.is_f16()
&& !bm_conf_utils.is_int8())
Expand Down Expand Up @@ -732,8 +733,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
dst_d.format_kind() == format_kind::any,
bias_md.format_kind == format_kind::any);

VCHECK_BG(check_isa_with_datatype(isa, bm_conf_utils),
VERBOSE_ISA_DT_MISMATCH);
VCHECK_BG(check_isa_with_datatype(bm_conf_utils), VERBOSE_ISA_DT_MISMATCH);

bgmmc.a_dt_sz = bgmmc.tr_a_dt_sz = types::data_type_size(bgmmc.src_dt);
bgmmc.b_dt_sz = bgmmc.tr_b_dt_sz = types::data_type_size(bgmmc.wei_dt);
Expand Down Expand Up @@ -1107,4 +1107,4 @@ void init_scratchpad(memory_tracking::registrar_t &scratchpad,
} // namespace aarch64
} // namespace cpu
} // namespace impl
} // namespace dnnl
} // namespace dnnl

0 comments on commit 9a1dc92

Please sign in to comment.