diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp index a2be164f09c..af4bc39a867 100644 --- a/src/cpu/aarch64/acl_inner_product.hpp +++ b/src/cpu/aarch64/acl_inner_product.hpp @@ -93,20 +93,25 @@ struct acl_inner_product_fwd_t : public primitive_t { status_t init(engine_t *engine) { using namespace data_type; + const format_kind_t weights_format_kind_received + = weights_md_.format_kind; const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) && attr()->has_default_values( primitive_attr_t::skip_mask_t::post_ops, f16); const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) && attr()->has_default_values( primitive_attr_t::skip_mask_t::post_ops, f32); + const bool is_weights_md_format_ok + = utils::one_of(weights_format_kind_received, + format_kind::any, format_kind::blocked); const bool ok = is_fwd() && !has_zero_dim_memory() && utils::one_of(true, is_fp16_ok, is_fp32_ok) - && weights_md_.format_kind == format_kind::any - && set_default_params() == status::success; + && is_weights_md_format_ok + && set_default_params(true) == status::success; if (!ok) return status::unimplemented; - CHECK(init_conf_ip(engine)); + CHECK(init_conf_ip(engine, weights_format_kind_received)); return status::success; } @@ -115,7 +120,8 @@ struct acl_inner_product_fwd_t : public primitive_t { acl_post_ops_t post_ops; - status_t init_conf_ip(engine_t *engine) { + status_t init_conf_ip( + engine_t *engine, format_kind_t weights_format_kind_received) { ACL_CHECK_SUPPORT(src_md()->ndims != weights_md()->ndims, "source and weights dimensions must match"); @@ -257,10 +263,19 @@ struct acl_inner_product_fwd_t : public primitive_t { return status::unimplemented; } + const memory_desc_t weights_md_received = weights_md_; acl_utils::reorder_to_weight_format(aip.wei_tensor_info, weights_md_, expected_weight_format, inner_dim, o_dim, remaining_dims, {}); + ACL_CHECK_SUPPORT( + (weights_format_kind_received == format_kind::blocked) + && !(dnnl_memory_desc_equal( + &weights_md_received, &weights_md_)), + "specific blocked format not supported by ACL, use " + "format_kind_t::any to find a supported blocked format for " + "your platform"); + // clang-format off // Validate fully connected layer manually to check for return status