Skip to content

Commit

Permalink
[CPU] Add RMSNorm jit implementation (openvinotoolkit#26147)
Browse files Browse the repository at this point in the history
### Details:
- *JIT implementation for RMSNorm; performance data on HBM with batch=2, input=1024 (unit: ms):*
| infer        | w/o rms: all | power+reduce | updateNodes | w/ rms: all | rms   | updateNodes |
| ------------ | ------------ | ------------ | ----------- | ----------- | ----- | ----------- |
| first token  | 1994.582     | 15.591       | 15.563      | 1991.487    | 5.49  | 13.768      |
| second token | 77.507       | 1.477        | 2.806       | 76.326      | 0.577 | 2.48        |

 - *...*

### Tickets:
 - *[150097](https://jira.devtools.intel.com/browse/CVS-150097)*
 - *[136265](https://jira.devtools.intel.com/browse/CVS-136265)*
  • Loading branch information
luo-cheng2021 authored Aug 27, 2024
1 parent dc21bad commit 528f85c
Show file tree
Hide file tree
Showing 23 changed files with 1,014 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@

_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)
_OPENVINO_OP_REG(RMS, ov::op::internal)
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace pass {
class RMSFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("RMSFusion", "0");
RMSFusion();
RMSFusion(bool force_tail_convert = true);
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ static std::function<bool(ov::Output<ov::Node>)> constant_value(const float targ
};
}

RMSFusion::RMSFusion() {
RMSFusion::RMSFusion(bool force_tail_convert) {
using namespace ov::pass::pattern;

// Detect RMS decomposition pattern
Expand Down Expand Up @@ -67,8 +67,11 @@ RMSFusion::RMSFusion() {
auto gamma = wrap_type<ov::op::v0::Constant>(type_matches(element::f32));
auto mul2 = wrap_type<ov::op::v1::Multiply>({gamma, mul1});

// compress RMS result
auto comp = wrap_type<ov::op::v0::Convert>({mul2});
std::shared_ptr<ov::Node> comp = mul2;
if (force_tail_convert) {
// compress RMS result
comp = wrap_type<ov::op::v0::Convert>({mul2});
}

ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,36 @@ TEST_F(TransformationTestsF, RMSNormFusionTest5) {
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rms}, ov::ParameterVector{input});
}
}

// no convert at the end of the subgraph
// Checks that RMSFusion(force_tail_convert = false) matches the decomposed
// RMSNorm pattern even when there is no Convert node compressing the result
// at the tail of the subgraph.
TEST_F(TransformationTestsF, RMSNormFusionTest6) {
    {
        // Decomposed pattern: x * 1/sqrt(mean(x^2, -1) + eps) * gamma,
        // with the division expressed as Power(sqrt, -1).
        auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 6});
        auto power_const = ov::opset10::Constant::create(ov::element::f32, {}, {2.f});
        auto power = std::make_shared<ov::opset10::Power>(input, power_const);  // x^2
        auto mean_axes = ov::opset10::Constant::create(ov::element::i64, ov::Shape{1}, {-1});
        auto mean = std::make_shared<ov::opset10::ReduceMean>(power, mean_axes, true);  // mean over last axis
        auto eps = ov::opset10::Constant::create(ov::element::f32, {}, {1e-5f});
        auto add_eps = std::make_shared<ov::opset10::Add>(mean, eps);
        auto sqrt = std::make_shared<ov::opset10::Sqrt>(add_eps);
        auto div_const = ov::opset10::Constant::create(ov::element::f32, {}, {-1});
        auto div = std::make_shared<ov::opset10::Power>(sqrt, div_const);  // 1/sqrt(...)
        auto mul1 = std::make_shared<ov::opset10::Multiply>(input, div);
        auto gamma = ov::opset10::Constant::create(ov::element::f32,
                                                   ov::Shape{6},
                                                   {0.029f, 0.014f, 0.003f, 0.013f, 0.015f, 0.009f});
        auto mul2 = std::make_shared<ov::opset10::Multiply>(gamma, mul1);  // no trailing Convert

        model = std::make_shared<ov::Model>(ov::NodeVector{mul2}, ov::ParameterVector{input});
        // force_tail_convert=false: pattern root is the Multiply, not a Convert.
        manager.register_pass<RMSFusion>(false);
    }
    {
        // Expected result: the whole subgraph fused into a single internal RMS op.
        auto input = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 6});
        auto rms_const = ov::opset10::Constant::create(ov::element::f32,
                                                       ov::Shape{6},
                                                       {0.029f, 0.014f, 0.003f, 0.013f, 0.015f, 0.009f});
        auto rms = std::make_shared<ov::op::internal::RMS>(input, rms_const, 1e-5f);

        model_ref = std::make_shared<ov::Model>(ov::NodeVector{rms}, ov::ParameterVector{input});
    }
}
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/cpu_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{"EmbeddingBagOffsets", Type::EmbeddingBagOffsets},
{"LLMMLP", Type::LLMMLP},
{"QKVProjection", Type::QKVProjection},
{"RMS", Type::RMS}
};
return type_to_name_tbl;
}
Expand Down Expand Up @@ -379,6 +380,7 @@ std::string NameFromType(const Type type) {
CASE(CausalMaskPreprocess);
CASE(LLMMLP);
CASE(QKVProjection);
CASE(RMS);
CASE(Unknown);
}
#undef CASE
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/cpu_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ enum class Type {
CausalMaskPreprocess,
LLMMLP,
QKVProjection,
RMS
};

enum class Algorithm {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ static inline void exp_ps_avx512(__m512& src) {
__m512 half = _mm512_loadu_ps(reinterpret_cast<const float*>(c_half)); // 0.5f
__m512 ln2f = _mm512_loadu_ps(reinterpret_cast<const float*>(c_ln2)); // ln(2)
__m512 one = _mm512_loadu_ps(reinterpret_cast<const float*>(c_1)); // 1.0f
__m512i exponent_bias = _mm512_load_epi32(c_bias); // 127
__m512i exponent_bias = _mm512_loadu_si512(c_bias); // 127
__m512 exp_pol1 = _mm512_loadu_ps(reinterpret_cast<const float*>(c_p1)); // p1 = 0.999999701f
__m512 exp_pol2 = _mm512_loadu_ps(reinterpret_cast<const float*>(c_p2)); // p2 = 0.499991506f
__m512 exp_pol3 = _mm512_loadu_ps(reinterpret_cast<const float*>(c_p3)); // p3 = 0.166676521f
Expand Down
241 changes: 241 additions & 0 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/rms_kernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "rms_kernel.hpp"

using namespace dnnl::impl::cpu::x64;
using namespace Xbyak;

namespace ov {
namespace intel_cpu {
namespace kernel {

#define GET_OFF(field) offsetof(jit_rms_call_args, field)

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_zmm_to_ymm(
        const Xmm &acc, const Xmm &tmp) {
    // Fold the upper 256 bits of the 512-bit accumulator onto the lower half:
    // acc[0..7] += acc[8..15]. Both views below alias the same physical
    // register as `acc`; `tmp` is scratch.
    const Zmm full_acc(acc.getIdx());
    const Ymm low_lanes(acc.getIdx());
    const Ymm high_lanes(tmp.getIdx());
    vextractf64x4(high_lanes, full_acc, 1);
    vaddps(low_lanes, low_lanes, high_lanes);
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_ymm_to_xmm(
        const Xmm &acc, const Xmm &tmp) {
    // Fold the upper 128 bits of the 256-bit accumulator onto the lower half:
    // acc[0..3] += acc[4..7]. Register views alias `acc`; `tmp` is scratch.
    const Ymm full_acc(acc.getIdx());
    const Xmm low_lanes(acc.getIdx());
    const Xmm high_lanes(tmp.getIdx());
    vextractf128(high_lanes, full_acc, 1);
    vaddps(low_lanes, low_lanes, high_lanes);
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_xmm_to_scalar(const Xmm &acc,
        const Xmm &tmp, const std::size_t number_of_values_to_reduce) {
    // Horizontally sum the lowest `number_of_values_to_reduce` f32 lanes of
    // `acc` into lane 0, using `tmp` as scratch.
    assert(number_of_values_to_reduce <= number_of_f32_in_xmm_);

    const Xmm xmm_acc(acc.getIdx());
    const Xmm xmm_scratch(tmp.getIdx());  // was misleadingly named "ymm_to_acc"

    // insertps immediate i: copy lane i+1 of the source into lane 0 of the
    // destination and zero the remaining lanes (zmask = 1110b), so each
    // vaddss adds exactly one further lane into acc lane 0.
    static constexpr int number_of_f32_to_move = number_of_f32_in_xmm_ - 1;
    static constexpr uint8_t insertps_configuration[number_of_f32_to_move]
            = {0b01001110, 0b10001110, 0b11001110};

    for (std::size_t i = 0; i < number_of_values_to_reduce - 1; i++) {
        vinsertps(xmm_scratch, xmm_scratch, xmm_acc, insertps_configuration[i]);
        vaddss(xmm_acc, xmm_acc, xmm_scratch);
    }
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_ymm_to_scalar(
        const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, const Xbyak::Xmm &tmp2,
        const std::size_t number_of_values_to_reduce) {
    // Horizontally sums the lowest `number_of_values_to_reduce` f32 lanes of
    // the 256-bit accumulator into lane 0 of `acc`. `tmp1`/`tmp2` are scratch.
    assert(number_of_values_to_reduce <= number_of_f32_in_ymm_);

    // All register views below alias the physical registers of the arguments.
    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Xmm xmm_tmp(tmp1.getIdx());
    const Xmm xmm_acc_upper_half(tmp2.getIdx());

    if (number_of_values_to_reduce == number_of_f32_in_ymm_) {
        // Full 8-lane reduction: fold halves, then collapse the xmm.
        reduce_ymm_to_xmm(ymm_acc, xmm_tmp);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
    } else if (number_of_values_to_reduce > number_of_f32_in_xmm_) {
        // 5..7 lanes: reduce the low 4 lanes fully, the upper half partially,
        // then add the two scalar results.
        vextractf128(xmm_acc_upper_half, ymm_acc, 1);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp);
        reduce_xmm_to_scalar(xmm_acc_upper_half, xmm_tmp,
                number_of_values_to_reduce - number_of_f32_in_xmm_);
        vaddss(xmm_acc, xmm_acc, xmm_acc_upper_half);
    } else if (number_of_values_to_reduce <= number_of_f32_in_xmm_) {
        // 1..4 lanes: everything fits in the low xmm.
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp, number_of_values_to_reduce);
    }
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::reduce_vmm_to_scalar(
        const Xbyak::Xmm &acc, const Xbyak::Xmm &tmp1, const Xbyak::Xmm &tmp2,
        const Xbyak::Xmm &tmp3, const std::size_t number_of_values_to_reduce) {
    // Horizontally sums the lowest `number_of_values_to_reduce` f32 lanes of
    // the 512-bit accumulator into lane 0 of `acc`. `tmp1`..`tmp3` are scratch.
    assert(number_of_values_to_reduce <= number_of_f32_in_zmm_);

    // All register views below alias the physical registers of the arguments.
    const Zmm zmm_acc(acc.getIdx());
    const Ymm ymm_acc(acc.getIdx());
    const Xmm xmm_acc(acc.getIdx());
    const Ymm ymm_acc_upper_half(tmp1.getIdx());
    const Xmm xmm_acc_upper_half(tmp1.getIdx());
    const Ymm ymm_tmp(tmp2.getIdx());
    const Xmm xmm_tmp1(tmp2.getIdx());
    const Xmm xmm_tmp2(tmp3.getIdx());

    if (number_of_values_to_reduce == number_of_f32_in_zmm_) {
        // Full 16-lane reduction: zmm -> ymm -> xmm -> scalar.
        reduce_zmm_to_ymm(zmm_acc, ymm_tmp);
        reduce_ymm_to_xmm(ymm_acc, xmm_tmp1);
        reduce_xmm_to_scalar(xmm_acc, xmm_tmp1);
    } else if (number_of_values_to_reduce > number_of_f32_in_ymm_) {
        // 9..15 lanes: reduce the low 8 lanes fully, the upper half partially,
        // then combine. NOTE(review): vaddps here (vs vaddss in the ymm
        // variant) adds all four xmm lanes, but only lane 0 is consumed by
        // callers, so the scalar result is the same — confirm intentional.
        vextractf64x4(ymm_acc_upper_half, zmm_acc, 1);
        reduce_ymm_to_scalar(ymm_acc, xmm_tmp1, xmm_tmp2);
        reduce_ymm_to_scalar(ymm_acc_upper_half, xmm_tmp1, xmm_tmp2,
                number_of_values_to_reduce - number_of_f32_in_ymm_);
        vaddps(xmm_acc, xmm_acc, xmm_acc_upper_half);
    } else if (number_of_values_to_reduce <= number_of_f32_in_ymm_) {
        // 1..8 lanes: everything fits in the low ymm.
        reduce_ymm_to_scalar(
                ymm_acc, xmm_tmp1, xmm_tmp2, number_of_values_to_reduce);
    }
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::generate() {
    // Emits a kernel computing, for one row of m_jcp.data_size elements:
    //   dst = src * 1/sqrt(mean(src^2) + eps) * scale
    // Pass 1 accumulates sum(x^2); pass 2 re-reads the row and scales it.
    this->preamble();
    mov(reg_src, ptr[abi_param1 + GET_OFF(src)]);
    mov(reg_scale, ptr[abi_param1 + GET_OFF(scale)]);
    mov(reg_dst, ptr[abi_param1 + GET_OFF(dst)]);
    // Zero the four partial-sum accumulators.
    uni_vpxor(vmm_sum0, vmm_sum0, vmm_sum0);
    uni_vpxor(vmm_sum1, vmm_sum1, vmm_sum1);
    uni_vpxor(vmm_sum2, vmm_sum2, vmm_sum2);
    uni_vpxor(vmm_sum3, vmm_sum3, vmm_sum3);
    mov(reg_src_org, reg_src);  // keep the row start for the second pass

    // x * 1/Sqrt(ReduceMean(x^2,axes)+eps) * gamma
    // sum(x^2): 4x unrolled over independent accumulators to hide FMA latency.
    // BUGFIX: emit the unrolled loop only when its JIT-time trip count is
    // non-zero. Previously the loop body always ran at least once and dec/jnz
    // on a zero counter underflowed, looping ~2^64 times (and reading far past
    // the row) whenever data_size < vec_size * 4.
    const size_t unrolled_iters = m_jcp.data_size / (vec_size * 4);
    if (unrolled_iters > 0) {
        mov(reg_size, unrolled_iters);
        align(16);
        Xbyak::Label loop_4reg;
        L(loop_4reg);
        {
            load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
            vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
            load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 1);
            vfmadd231ps(vmm_sum1, vmm_src, vmm_src);
            load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 2);
            vfmadd231ps(vmm_sum2, vmm_src, vmm_src);
            load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false, vec_size * m_jcp.src_prc.size() * 3);
            vfmadd231ps(vmm_sum3, vmm_src, vmm_src);

            add(reg_src, vec_size * m_jcp.src_prc.size() * 4);
            dec(reg_size);
            jnz(loop_4reg);
        }
    }
    // Remaining 1..3 full vectors (trip count known at JIT time, so fully
    // unrolled in the emitted code).
    for (size_t i = m_jcp.data_size / (vec_size * 4) * 4; i < m_jcp.data_size / vec_size; i++) {
        load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
        vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
        add(reg_src, vec_size * m_jcp.src_prc.size());
    }
    // Tail of data_size % vec_size elements (partial-lane load).
    if (m_jcp.data_size % vec_size) {
        load(vmm_src, reg_src, m_jcp.src_prc, m_jcp.data_size % vec_size, false);
        vfmadd231ps(vmm_sum0, vmm_src, vmm_src);
    }
    // Combine the four partial sums, then reduce all lanes into lane 0.
    vaddps(vmm_sum0, vmm_sum0, vmm_sum1);
    vaddps(vmm_sum2, vmm_sum2, vmm_sum3);
    vaddps(vmm_rsqrt, vmm_sum0, vmm_sum2);
    reduce_vmm_to_scalar(vmm_rsqrt, vmm_sum0, vmm_sum1, vmm_sum3, vec_size);

    // mean(x^2)
    mov(reg_tmp.cvt32(), float2int(1.0f / m_jcp.data_size));
    vmovd(xmm_tmp, reg_tmp.cvt32());
    vmulss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
    // mean(x^2) + eps
    mov(reg_tmp.cvt32(), float2int(m_jcp.eps));
    vmovd(xmm_tmp, reg_tmp.cvt32());
    vaddss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
    // rsqrt(mean(x^2) + eps)
    // NOTE(review): vrsqrtss is a ~12-bit approximation with no Newton-Raphson
    // refinement step — confirm this precision is acceptable for RMSNorm.
    vrsqrtss(xmm_rsqrt, xmm_rsqrt, xmm_rsqrt);

    // x * rsqrt(mean(x^2)+eps): a scalar gamma is folded into the
    // normalization factor once, before broadcasting.
    if (m_jcp.scale_size == 1) {
        vmovd(xmm_tmp, ptr[reg_scale]);
        vmulss(xmm_rsqrt, xmm_rsqrt, xmm_tmp);
    }
    vbroadcastss(vmm_rsqrt, xmm_rsqrt);

    // Second pass: dst = x * factor [* gamma]. Same zero-trip-count guard as
    // above — without it, data_size < vec_size underflowed the counter.
    mov(reg_src, reg_src_org);
    const size_t vec_iters = m_jcp.data_size / vec_size;
    if (vec_iters > 0) {
        mov(reg_size, vec_iters);
        align(16);
        Xbyak::Label loop_mul;
        L(loop_mul);
        {
            load(vmm_src, reg_src, m_jcp.src_prc, vec_size, false);
            vmulps(vmm_src, vmm_src, vmm_rsqrt);
            if (m_jcp.scale_size != 1) {
                // Per-element gamma: scale is always f32.
                load(vmm_tmp, reg_scale, ov::element::f32, vec_size, false);
                vmulps(vmm_src, vmm_src, vmm_tmp);
            }
            store(reg_dst, vmm_src, m_jcp.dst_prc, vec_size);

            add(reg_src, vec_size * m_jcp.src_prc.size());
            if (m_jcp.scale_size != 1) {
                add(reg_scale, vec_size * sizeof(float));
            }
            add(reg_dst, vec_size * m_jcp.dst_prc.size());
            dec(reg_size);
            jnz(loop_mul);
        }
    }
    // Tail elements of the second pass.
    if (m_jcp.data_size % vec_size) {
        load(vmm_src, reg_src, m_jcp.src_prc, m_jcp.data_size % vec_size, false);
        vmulps(vmm_src, vmm_src, vmm_rsqrt);
        if (m_jcp.scale_size != 1) {
            load(vmm_tmp, reg_scale, ov::element::f32, m_jcp.data_size % vec_size, false);
            vmulps(vmm_src, vmm_src, vmm_tmp);
        }
        store(reg_dst, vmm_src, m_jcp.dst_prc, m_jcp.data_size % vec_size);
    }

    this->postamble();
    // Emit out-of-line data for every cached load/store emitter after the
    // kernel body.
    for (const auto& emitter : emitters) {
        if (emitter.second)
            emitter.second->emit_data();
    }
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::load(const Vmm& vmm_dst, const Xbyak::Reg64& reg_src, ov::element::Type src_prc, const int& elt_num, bool fill, size_t offset) {
    // Loads `elt_num` elements from [reg_src + offset] into vmm_dst,
    // converting src_prc -> f32. Emitters are cached by their parameter hash
    // so each unique configuration is constructed only once.
    const auto key = load_emitter_params(src_prc, ov::element::f32, elt_num, fill, "float_min").hash();
    auto& emitter = emitters[key];
    if (!emitter) {
        emitter.reset(new jit_load_emitter(this, isa, src_prc, ov::element::f32, elt_num, ov::element::f32, fill, "float_min"));
    }
    const std::vector<size_t> src_args{static_cast<size_t>(reg_src.getIdx()), offset};
    const std::vector<size_t> dst_args{static_cast<size_t>(vmm_dst.getIdx())};
    emitter->emit_code(src_args, dst_args, pool_aux_vmm_idxs, pool_aux_gpr_idxs);
}

template <cpu_isa_t isa>
void jit_rms_kernel<isa>::store(const Xbyak::Reg64& reg_dst, const Vmm& vmm_src, ov::element::Type dst_prc, const int& elt_num, size_t offset) {
    // Stores `elt_num` elements from vmm_src to [reg_dst + offset],
    // converting f32 -> dst_prc. Emitters are cached by their parameter hash
    // so each unique configuration is constructed only once.
    const auto key = store_emitter_params(ov::element::f32, dst_prc, elt_num).hash();
    auto& emitter = emitters[key];
    if (!emitter) {
        emitter.reset(new jit_store_emitter(this, isa, ov::element::f32, dst_prc, elt_num));
    }
    const std::vector<size_t> src_args{static_cast<size_t>(vmm_src.getIdx())};
    const std::vector<size_t> dst_args{static_cast<size_t>(reg_dst.getIdx()), offset};
    emitter->emit_code(src_args, dst_args, pool_aux_vmm_idxs, pool_aux_gpr_idxs);
}

template struct jit_rms_kernel<cpu_isa_t::avx512_core>;
template struct jit_rms_kernel<cpu_isa_t::avx2>;

} // namespace kernel
} // namespace intel_cpu
} // namespace ov
Loading

0 comments on commit 528f85c

Please sign in to comment.