diff --git a/src/cpu/aarch64/shuffle/jit_uni_shuffle.cpp b/src/cpu/aarch64/shuffle/jit_uni_shuffle.cpp index 4d8d15219eb..4d1cf145692 100644 --- a/src/cpu/aarch64/shuffle/jit_uni_shuffle.cpp +++ b/src/cpu/aarch64/shuffle/jit_uni_shuffle.cpp @@ -1,6 +1,6 @@ /******************************************************************************* -* Copyright 2020-2022 Intel Corporation -* Copyright 2022 FUJITSU LIMITED +* Copyright 2020-2024 Intel Corporation +* Copyright 2022-2024 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ template status_t jit_uni_shuffle_t::pd_t::init(engine_t *engine) { using namespace format_tag; using namespace data_type; + using namespace types; const memory_desc_wrapper src_d(is_fwd() ? src_md() : diff_src_md()); const memory_desc_wrapper dst_d(is_fwd() ? dst_md() : diff_dst_md()); @@ -58,7 +59,10 @@ status_t jit_uni_shuffle_t::pd_t::init(engine_t *engine) { if (blocked_format == format_tag::undef) return status::unimplemented; conf_.blk_size = src_d.blocking_desc().strides[ndims() - 1]; - conf_.simd_w = cpu_isa_traits::vlen / sizeof(float); + /* Because "ST1H { .S }, , [, .S, UXTW #1]" is used + to gather data for bf16, simd_w must be calculated + with sizeof(uint32_t). */ + conf_.simd_w = cpu_isa_traits::vlen / sizeof(uint32_t); const bool has_spatial = utils::one_of(ndims(), 3, 4, 5); const dim_t HW = H() * W(); diff --git a/src/cpu/aarch64/shuffle/jit_uni_shuffle_kernel.cpp b/src/cpu/aarch64/shuffle/jit_uni_shuffle_kernel.cpp index 72272616f5e..3ae88f5ad18 100644 --- a/src/cpu/aarch64/shuffle/jit_uni_shuffle_kernel.cpp +++ b/src/cpu/aarch64/shuffle/jit_uni_shuffle_kernel.cpp @@ -1,6 +1,6 @@ /******************************************************************************* -* Copyright 2021-2022 Intel Corporation -* Copyright 2022 FUJITSU LIMITED +* Copyright 2021-2024 Intel Corporation +* Copyright 2022-2024 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,9 +47,12 @@ jit_uni_shuffle_kernel_t::jit_uni_shuffle_kernel_t( template void jit_uni_shuffle_kernel_t::prepare_mask() { using namespace data_type; + using namespace types; if (conf_.simd_tail > 0) { - assert(utils::one_of(conf_.data_type, f32, s32)); - assert(conf_.simd_tail < isa_sveLen / sizeof(float)); + /* Because "ST1H { .S }, , [, .S, UXTW #1]" is used + to gather data for bf16, simd_tail must be evaluated + with sizeof(unsigned). */ + assert(conf_.simd_tail < isa_sveLen / sizeof(uint32_t)); index(vmm_tmp_.s, 0, 1); cmplt(k_tail_mask_.s, P_ALL_ONE / T_z, vmm_tmp_.s, conf_.simd_tail); } @@ -68,13 +71,17 @@ void jit_uni_shuffle_kernel_t::prepare_mask() {} template void jit_uni_shuffle_kernel_t::gather_data(const XReg ®_src_addr, const int indices_idx, const int data_idx, const bool is_tail) { - if (conf_.dt_size == sizeof(float)) { - const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_; + using namespace data_type; + const PReg &mask = is_tail ? k_tail_mask_ : k_full_mask_; + + if (utils::one_of(conf_.data_type, f32, s32)) { lsr(TRegS(indices_idx), TRegS(indices_idx), 2); ld1w(TRegS(data_idx), mask / T_z, ptr(reg_src_addr, TRegS(indices_idx), UXTW, 2)); - } else { - assert(!"unsupported emu_gather_data"); + } else if (conf_.data_type == bf16) { + lsr(TRegS(indices_idx), TRegS(indices_idx), 1); + ld1h(TRegS(data_idx), mask / T_z, + ptr(reg_src_addr, TRegS(indices_idx), UXTW, 1)); } } @@ -97,21 +104,26 @@ void jit_uni_shuffle_kernel_t::gather_data(const XReg &addr, template void jit_uni_shuffle_kernel_t::store_data(const int data_idx, const XReg ®_dst_addr, const int offset, const bool is_tail) { + using namespace data_type; const auto extend_for_padding = is_tail && padding_size_ + conf_.simd_tail >= conf_.simd_w; + const PReg &mask = is_tail ? k_tail_mask_ : P_ALL_ONE; + + add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0); + if (extend_for_padding) { sel(vmm_tmp_.s, k_tail_mask_, TRegS(data_idx), vmm_zero_.s); - add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0); - st1w(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR)); + if (utils::one_of(conf_.data_type, f32, s32)) + st1w(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR)); + else // bf16 + st1h(vmm_tmp_.s, P_ALL_ONE, ptr(X_DEFAULT_ADDR)); } else { - if (is_tail) { - add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0); - st1w(TRegS(data_idx), k_tail_mask_, ptr(X_DEFAULT_ADDR)); - } else { - add_imm(X_DEFAULT_ADDR, reg_dst_addr, offset, X_TMP_0); - st1w(TRegS(data_idx), P_ALL_ONE, ptr(X_DEFAULT_ADDR)); - } + if (utils::one_of(conf_.data_type, f32, s32)) + st1w(TRegS(data_idx), mask, ptr(X_DEFAULT_ADDR)); + else // bf16 + st1h(TRegS(data_idx), mask, ptr(X_DEFAULT_ADDR)); } + append_zero_padding( reg_dst_, isa_sveLen > 128 ? extend_for_padding : false); }