diff --git a/src/cpu/aarch64/acl_deconvolution.hpp b/src/cpu/aarch64/acl_deconvolution.hpp index 97413c7ba65..ff1af13f7c9 100644 --- a/src/cpu/aarch64/acl_deconvolution.hpp +++ b/src/cpu/aarch64/acl_deconvolution.hpp @@ -193,8 +193,9 @@ struct acl_deconvolution_fwd_t : public primitive_t { } // Data layout - const auto acl_layout = is_nspc ? arm_compute::DataLayout::NHWC - : arm_compute::DataLayout::NCHW; + const arm_compute::DataLayout acl_layout = is_nspc + ? arm_compute::DataLayout::NHWC + : arm_compute::DataLayout::NCHW; acl_pd_conf.src_info = arm_compute::TensorInfo(is_nspc ? arm_compute::TensorShape(ic, iw, ih, mb) @@ -243,18 +244,15 @@ struct acl_deconvolution_fwd_t : public primitive_t { // padding is set for convolution. Otherwise, describe deconvolution as convolution of // upsampling input with stride = 1 and pad = 0. arm_compute::ConvolutionMethod conv_method; - arm_compute::TensorInfo *conv_src_info; + arm_compute::TensorInfo conv_src_info( + acl_pd_conf.src_info.clone()->set_is_resizable(true)); unsigned int pad_left = 0; unsigned int pad_right = 0; unsigned int pad_top = 0; unsigned int pad_bottom = 0; if (sh != 1 || sw != 1) { - arm_compute::TensorInfo scale_out_info( - acl_pd_conf.src_info.clone() - ->set_is_resizable(true) - .reset_padding() - .set_tensor_shape(scale_out_shape)); - conv_src_info = &scale_out_info; + conv_src_info.reset_padding(); + conv_src_info.set_tensor_shape(scale_out_shape); } else { // compute correct padding here pad_left = pr > pl ? pr - pl : 0; @@ -269,15 +267,13 @@ struct acl_deconvolution_fwd_t : public primitive_t { pad_right += deconv_pad_x / 2; pad_top += deconv_pad_y / 2; pad_bottom += deconv_pad_y / 2; - - conv_src_info = &acl_pd_conf.src_info; } const arm_compute::PadStrideInfo conv_info(1, 1, pad_left, pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::CEIL); conv_method = arm_compute::NEConvolutionLayer::get_convolution_method( - conv_src_info, &acl_pd_conf.wei_info, + &conv_src_info, &acl_pd_conf.wei_info, &acl_pd_conf.dst_info, conv_info, arm_compute::WeightsInfo(), arm_compute::Size2D(1U, 1U), diff --git a/src/cpu/aarch64/brgemm/brgemm.cpp b/src/cpu/aarch64/brgemm/brgemm.cpp index b7d962b9487..64f73814e30 100644 --- a/src/cpu/aarch64/brgemm/brgemm.cpp +++ b/src/cpu/aarch64/brgemm/brgemm.cpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2020-2023 Intel Corporation * Copyright 2023-2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -170,8 +171,8 @@ status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa, if (brg == nullptr) return status::invalid_arguments; if (transA || transB) return status::unimplemented; - brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout, alpha, - beta, LDA, LDB, LDC, M, N, K, strides); + CHECK(brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout, + alpha, beta, LDA, LDB, LDC, M, N, K, strides)); if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments; bool ldx_check = (brg->is_row_major()) ? (LDA < K) @@ -197,8 +198,8 @@ status_t brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa, if (transA || layout != brgemm_row_major || alpha != 1.0f || beta != 0.f) return status::unimplemented; - brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout, alpha, - beta, LDA, LDC, M, N, strides); + CHECK(brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout, + alpha, beta, LDA, LDC, M, N, strides)); const bool ldx_check = (LDA < N || LDC < N); if (ldx_check) return status::invalid_arguments; diff --git a/src/cpu/aarch64/brgemm/brgemm_utils.cpp b/src/cpu/aarch64/brgemm/brgemm_utils.cpp index 109436db6bf..24e109584a8 100644 --- a/src/cpu/aarch64/brgemm/brgemm_utils.cpp +++ b/src/cpu/aarch64/brgemm/brgemm_utils.cpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2022-2023 Intel Corporation * Copyright 2023-2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,15 +48,18 @@ impl::data_type_t get_accum_datatype(brgemm_t *brg) { return brg->is_int8 ? data_type::s32 : data_type::f32; } -void init_kernel_datatype( +status_t init_kernel_datatype( brgemm_t *brg, impl::data_type_t dt_a, impl::data_type_t dt_b) { - assert(dt_a != data_type::undef && dt_b != data_type::undef); + if (dt_a != data_type::undef && dt_b != data_type::undef) + return status::unimplemented; brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8) && utils::one_of(dt_b, data_type::u8, data_type::s8); brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16); brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32); brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b); - assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16); + if (brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16) + return status::unimplemented; + return status::success; } void init_common_conf(brgemm_t *brg, brgemm_batch_kind_t type, float alpha, @@ -88,7 +92,7 @@ void maybe_try_bf32(brgemm_t *brg) { // } -void set_isa_impl(brgemm_t *brg) { +status_t set_isa_impl(brgemm_t *brg) { auto is_isa_ok = [&](cpu_isa_t isa) { return mayiuse(isa) && // maybe IMPLICATION(brg->isa_user != isa_undef, @@ -96,19 +100,14 @@ void set_isa_impl(brgemm_t *brg) { one_of(brg->isa_user, isa_undef, isa); }; - if (brg->is_bf32) { - assert(!"unsupported case"); - } else if (brg->is_f32) { - brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(sve_512), sve_512, - is_isa_ok(sve_256), sve_256); - } else if (brg->is_bf16) { - assert(!"unsupported case"); - } else if (brg->is_f16) { - assert(!"unsupported case"); - } else if (brg->is_int8) { + if (brg->is_bf32 || brg->is_bf16 || brg->is_f16) { + return status::unimplemented; + } else if (brg->is_f32 || brg->is_int8) { brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(sve_512), sve_512, is_isa_ok(sve_256), sve_256); + return status::success; } + return status::success; } void set_brg_vmm(brgemm_t *brg) { @@ -187,7 +186,7 @@ inline size_t data_type_vnni_granularity(data_type_t data_type) { } status_t brgemm_blocking(brgemm_t *brg) { - set_isa_impl(brg); + CHECK(set_isa_impl(brg)); if (brg->isa_impl == isa_undef) return status::unimplemented; assert(!brg->is_dgmm); // should not be called from brdgmm set_brg_vmm(brg); @@ -296,10 +295,11 @@ status_t brdgmm_blocking(brgemm_t *brg) { return status::success; } -void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, - impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout, - float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, - dim_t N, dim_t K, const brgemm_strides_t *strides, bool is_bf32) { +status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, + brgemm_batch_kind_t type, impl::data_type_t dt_a, + impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta, + dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K, + const brgemm_strides_t *strides, bool is_bf32) { init_common_conf(brg, type, alpha, beta, strides); @@ -307,7 +307,7 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, brg->dt_a = brg->is_row_major() ? dt_a : dt_b; brg->dt_b = brg->is_row_major() ? dt_b : dt_a; - init_kernel_datatype(brg, brg->dt_a, brg->dt_b); + CHECK(init_kernel_datatype(brg, brg->dt_a, brg->dt_b)); brg->dt_c = get_accum_datatype(brg); brg->dt_d = brg->dt_c; @@ -319,7 +319,7 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, brg->typesize_D = types::data_type_size(brg->dt_d); brg->isa_user = isa; - set_isa_impl(brg); + CHECK(set_isa_impl(brg)); brg->is_bf32 = false; brg->has_int8_vnni = true; @@ -352,11 +352,13 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, brg->rd_step = has_no_vnni_compute_instruction ? 1 : data_type_vnni_granularity(brg->dt_b); + return status::success; } -void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, - impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout, - float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N, +status_t init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, + brgemm_batch_kind_t type, impl::data_type_t dt_a, + impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta, + dim_t LDA, dim_t LDC, dim_t M, dim_t N, const brgemm_strides_t *strides) { init_common_conf(brg, type, alpha, beta, strides); @@ -365,7 +367,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, brg->dt_a = dt_a; brg->dt_b = dt_b; - init_kernel_datatype(brg, brg->dt_a, brg->dt_b); + CHECK(init_kernel_datatype(brg, brg->dt_a, brg->dt_b)); brg->dt_c = get_accum_datatype(brg); brg->dt_d = brg->dt_c; @@ -394,6 +396,7 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, brg->bcast_dim = M; brg->load_dim = N; + return status::success; } } // namespace brgemm_utils @@ -402,4 +405,4 @@ void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, } // namespace impl } // namespace dnnl -//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s \ No newline at end of file +//vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s diff --git a/src/cpu/aarch64/brgemm/brgemm_utils.hpp b/src/cpu/aarch64/brgemm/brgemm_utils.hpp index 485b5fde961..563a5d734ac 100644 --- a/src/cpu/aarch64/brgemm/brgemm_utils.hpp +++ b/src/cpu/aarch64/brgemm/brgemm_utils.hpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2022 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,20 +45,21 @@ status_t brdgmm_blocking(brgemm_t *brg); * having to depend on BRGeMM's API. An additional feature is that this * function can be modified depending on needs without requiring changes * at the API level. */ -void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, - impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout, - float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, - dim_t N, dim_t K, const brgemm_strides_t *strides = nullptr, - bool is_bf32 = false); +status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, + brgemm_batch_kind_t type, impl::data_type_t dt_a, + impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta, + dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, dim_t N, dim_t K, + const brgemm_strides_t *strides = nullptr, bool is_bf32 = false); /* The purpose of this function is to enable initialization of brgemm values * and then call additional functions like blocking heuristics without * having to depend on BRDGeMM's API. An additional feature is that this * function can be modified depending on needs without requiring changes * at the API level. */ -void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, - impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout, - float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N, +status_t init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, + brgemm_batch_kind_t type, impl::data_type_t dt_a, + impl::data_type_t dt_b, brgemm_layout_t layout, float alpha, float beta, + dim_t LDA, dim_t LDC, dim_t M, dim_t N, const brgemm_strides_t *strides = nullptr); } // namespace brgemm_utils diff --git a/src/cpu/aarch64/jit_brgemm_conv_utils.cpp b/src/cpu/aarch64/jit_brgemm_conv_utils.cpp index b93db5c423d..3b9d3422594 100644 --- a/src/cpu/aarch64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/aarch64/jit_brgemm_conv_utils.cpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2021-2023 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -725,9 +726,9 @@ status_t brg_blocking_t::estimate_brgemm_ur() { const float alpha = 1.0; const float beta = 0.0; brgemm_t brg; - brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt, + CHECK(brgemm_utils::init_brgemm_conf(&brg, isa, brgemm_addr, src_dt, wei_dt, brgemm_row_major, alpha, beta, LDA, LDB, LDC, vM, vN, vK, nullptr, - is_bf32); + is_bf32)); CHECK(brgemm_utils::brgemm_blocking(&brg)); ur = brg.bd_block; ur_block = brg.bd_block; @@ -771,9 +772,9 @@ status_t brg_blocking_t::get_brgemm_ur( * rnd_up(oc, oc_block) * wei_dsz; const auto strides_ptr = (brg_type == brgemm_strd) ? &brg_strides : nullptr; - brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, src_dt, - wei_dt, brgemm_row_major, alpha, vbeta, LDA, LDB, LDC, - vM, vN, vK, strides_ptr, is_bf32); + CHECK(brgemm_utils::init_brgemm_conf(&brg, isa, brg_type, + src_dt, wei_dt, brgemm_row_major, alpha, vbeta, LDA, + LDB, LDC, vM, vN, vK, strides_ptr, is_bf32)); CHECK(brgemm_utils::brgemm_blocking(&brg)); brgemm_attr_t brgattr; diff --git a/src/cpu/aarch64/matmul/brgemm_matmul.cpp b/src/cpu/aarch64/matmul/brgemm_matmul.cpp index bebdae12041..7ede5613803 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul.cpp @@ -1,6 +1,7 @@ /******************************************************************************* * Copyright 2021-2023 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2024 Arm Ltd. and affiliates * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -642,7 +643,6 @@ void brgemm_matmul_t::copy_b_chunk_in_buffer( = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx); ctx.current_K_start = k; ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K); - assert(isa == sve_512); (*copy_B_kernel_)(&ctx); } @@ -654,7 +654,6 @@ void brgemm_matmul_t::copy_b_chunk_in_buffer( = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx); ctx.current_K_start = k; ctx.current_K_iters = bgmmc.K % bgmmc.K_blk; - assert(isa == sve_512); (*copy_B_kernel_)(&ctx); } } diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp b/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp index bd9bc023eaf..2adec6e360f 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright 2021-2023 Intel Corporation +* Copyright 2024 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -129,7 +130,7 @@ bool post_ops_ok(brgemm_matmul_conf_t &bgmmc, const primitive_attr_t &attr, } status_t check_isa_with_datatype( - const cpu_isa_t isa, const brgemm_matmul_conf_utils_t &bm_conf_utils) { + const brgemm_matmul_conf_utils_t &bm_conf_utils) { if (bm_conf_utils.is_f32() && !bm_conf_utils.is_int8() && !bm_conf_utils.is_bf16() && !bm_conf_utils.is_f16() && !bm_conf_utils.is_int8()) @@ -732,8 +733,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, dst_d.format_kind() == format_kind::any, bias_md.format_kind == format_kind::any); - VCHECK_BG(check_isa_with_datatype(isa, bm_conf_utils), - VERBOSE_ISA_DT_MISMATCH); + VCHECK_BG(check_isa_with_datatype(bm_conf_utils), VERBOSE_ISA_DT_MISMATCH); bgmmc.a_dt_sz = bgmmc.tr_a_dt_sz = types::data_type_size(bgmmc.src_dt); bgmmc.b_dt_sz = bgmmc.tr_b_dt_sz = types::data_type_size(bgmmc.wei_dt); @@ -1107,4 +1107,4 @@ void init_scratchpad(memory_tracking::registrar_t &scratchpad, } // namespace aarch64 } // namespace cpu } // namespace impl -} // namespace dnnl \ No newline at end of file +} // namespace dnnl