From 6c36f544decc551e480e2272a965f85c2747c0f9 Mon Sep 17 00:00:00 2001
From: xiguiw <111278656+xiguiw@users.noreply.github.com>
Date: Mon, 4 Mar 2024 17:27:14 +0800
Subject: [PATCH] resubmit "Implement the YaRN rope scaling feature" (#147)

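Port of the YaRN context-extension method (https://github.com/jquesnelle/yarn):
per-dimension mixing of interpolated and extrapolated rotary angles plus an
attention magnitude correction, exposed through the new ne_rope_custom_inplace
and ne_rope_custom_shift_inplace entry points. The pre-existing rope entry
points pass ext_factor = 0.0f and attn_factor = 1.0f, which reduces to plain
scaled RoPE.

Minimal sketch of a call site (tensor names and values are illustrative
placeholders, not taken from this patch); note that the caller-side freq_scale
is the extension factor s, which the compute kernel inverts:

    struct ne_tensor* cur = ne_rope_custom_inplace(
        ctx, Qcur, n_past, n_rot, mode, prompt_size,
        10000.0f, /* freq_base */
        4.0f,     /* freq_scale: 4x context extension */
        4096,     /* yarn_orig_ctx: original training context */
        1.0f,     /* ext_factor: full YaRN mixing */
        1.0f,     /* attn_factor */
        32.0f,    /* beta_fast */
        1.0f);    /* beta_slow */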
---
 neural_speed/core/ne_layers.c | 114 ++++++++++++++++++++++++++++------
 neural_speed/core/ne_layers.h |  14 +++++
 2 files changed, 109 insertions(+), 19 deletions(-)

diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c
index 791c94076..a493f2b47 100644
--- a/neural_speed/core/ne_layers.c
+++ b/neural_speed/core/ne_layers.c
@@ -3107,7 +3107,8 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor*
 
 struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                                int prompt_size, bool inplace, int n_keep, struct ne_tensor* cossin, int* n_padding,
-                               bool padding_left, float freq_base, float freq_scale) {
+                               bool padding_left, float freq_base, float freq_scale, int yarn_orig_ctx,
+                               float ext_factor, float attn_factor, float beta_fast, float beta_slow) {
   NE_ASSERT(n_past >= 0 || n_keep >= 0);
   NE_ASSERT(padding_left);
   bool is_node = false;
@@ -3147,7 +3148,9 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
 
   ne_scratch_load(ctx);
 
-  float params[] = {freq_base, freq_scale};
+  /* op_params carries the rope hyper-parameters as a float array; the int
+   * yarn_orig_ctx is cast to float here and cast back in the compute functions */
+  float params[] = {freq_base, freq_scale, (float)yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow};
   ne_set_op_params(result, &params, sizeof(params));
 
   result->op = NE_OP_ROPE;
@@ -3161,19 +3164,36 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int
 
 struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                           int prompt_size, float freq_base, float freq_scale) {
-  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale);
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale, 0,
+                      0.0f, 1.0f, 0.0f, 0.0f);
 }
 
 struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                                   int prompt_size, float freq_base, float freq_scale) {
-  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale);
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale, 0,
+                      0.0f, 1.0f, 0.0f, 0.0f);
 }
 
 struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode,
                                         int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base,
                                         float freq_scale) {
   return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
-                      freq_scale);
+                      freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
+}
+
+struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
+                                         int prompt_size, float freq_base, float freq_scale, int yarn_orig_ctx,
+                                         float ext_factor, float attn_factor, float beta_fast, float beta_slow) {
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale,
+                      yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow);
+}
+
+struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims,
+                                               int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
+                                               float freq_base, float freq_scale, int yarn_orig_ctx, float ext_factor,
+                                               float attn_factor, float beta_fast, float beta_slow) {
+  return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
+                      freq_scale, yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow);
 }
 
 // ne_rope_back
@@ -3211,14 +3231,14 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int
 struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
                                        int prompt_size, int* n_padding, float freq_base, float freq_scale) {
   return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base,
-                      freq_scale);
+                      freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
 }
 
 struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
                                                int mode, int prompt_size, int* n_padding, float freq_base,
                                                float freq_scale) {
-  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base,
-                      freq_scale);
+  return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base, freq_scale,
+                      0, 0.0f, 1.0f, 0.0f, 0.0f);
 }
 
 // ne_alibi
@@ -8709,6 +8729,45 @@ static void ne_compute_forward_clamp(const struct ne_compute_params* params, con
   }
 }
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+  const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+  return 1.0f - MIN(1.0f, MAX(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor,
+                      float mscale, float* cos_theta, float* sin_theta) {
+  // Get n-d rotational scaling corrected for extrapolation
+  float theta_interp = freq_scale * theta_extrap;
+  float theta = theta_interp;
+  if (ext_factor != 0.0f) {
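+    // dims below corr_dims[0] (high frequency) get ramp = 1 and keep the extrapolated
+    // angle; dims above corr_dims[1] (long wavelength) get ramp = 0 and are interpolated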
+    float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+    theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+    // Get n-d magnitude scaling corrected for interpolation
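+    // (YaRN attention temperature: sqrt(1/t) = 0.1 * ln(s) + 1, with s = 1/freq_scale here)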
+    mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+  }
+  *cos_theta = cosf(theta) * mscale;
+  *sin_theta = sinf(theta) * mscale;
+}
+
+#ifndef NE_PI
+#define NE_PI (3.14159265358979323846)
+#endif
+// Solving `n_rot = max_pos_emb / (2pi * base^(2x / n_dims))` for x (the dimension pair
+// that completes n_rot rotations over the original context) gives
+// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
+  return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)NE_PI)) / (2 * logf(base));
+}
+
+void ggml_rope_yarn_corr_dims(int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow,
+                              float dims[2]) {
+  // start and end correction dims
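+  // illustrative example: n_dims = 128, n_orig_ctx = 4096, freq_base = 10000.0f,
+  // beta_fast = 32, beta_slow = 1 gives dims ~= {20, 46}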
+  dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
+  dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
 // ne_compute_forward_rope
 #define NE_TENSOR_UNARY_OP_LOCALS           \
   NE_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
@@ -8721,12 +8780,18 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
   if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) {
     return;
   }
+
   const int bs = src0->ne[3];
   NE_ASSERT(src1->type == NE_TYPE_I32);
   NE_ASSERT(ne_nelements(src1) == 5 + bs);  // 5 + bs params
 
   const float freq_base = ((float*)(dst->op_params))[0];
   const float freq_scale = 1 / ((float*)(dst->op_params))[1];
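+  // note: op_params[1] holds the caller-side scale factor s; the kernel works with
+  // freq_scale = 1/s, which is what rope_yarn() expects for interpolation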
+  const int n_orig_ctx = (int)((float*)(dst->op_params))[2];
+  const float ext_factor = ((float*)(dst->op_params))[3];
+  const float attn_factor = ((float*)(dst->op_params))[4];
+  const float beta_fast = ((float*)(dst->op_params))[5];
+  const float beta_slow = ((float*)(dst->op_params))[6];
 
   const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX];
   const int64_t n_dims = ((int32_t*)src1->data)[ROPE_NDIMS_IDX];
@@ -8759,11 +8824,15 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
   int ir = 0;
 
   const float theta_scale = powf(freq_base, -2.0f / n_dims);
+  const float inv_ndims = -1.f / n_dims;
+  float corr_dims[2];
+  ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
   const bool skip = mode & 1;
   const bool is_neox = mode & 2;
   const bool is_glm = mode & 4;
   const bool is_shift = n_keep >= 0;
+  const bool use_yarn = ((mode & 0x8) != 0);
   NE_ASSERT(("RoPE shift not supported!", !is_shift));
 
   NE_ASSERT(ne3 == bs);
@@ -8774,21 +8843,21 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
         if (ir++ < ir0) continue;
         if (ir > ir1) break;
 
-        float theta = freq_scale * (float)p;
+        float theta_base = (float)p;
 
         // only for glm when mode == 4
         if (is_glm) {
           const int64_t n_padding = ((int32_t*)src1->data)[ROPE_PARAMS_NUM + i3];
           // position ids
-          theta = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding);
+          theta_base = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding);
           float block_theta = MAX(p - (prompt_size - 2), 0);
           for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-            const float cos_theta = cosf(theta);
-            const float sin_theta = sinf(theta);
+            const float cos_theta = cosf(theta_base);
+            const float sin_theta = sinf(theta_base);
             const float cos_block_theta = cosf(block_theta);
             const float sin_block_theta = sinf(block_theta);
 
-            theta *= theta_scale;
+            theta_base *= theta_scale;
             block_theta *= theta_scale;
 
             const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
@@ -8805,11 +8874,12 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
             dst_data[n_dims / 2 * 3] = x2 * sin_block_theta + x3 * cos_block_theta;
           }
         } else if (!is_neox) {
+          // printf("theta_base = %ld, freq_scale %.4f, ne0 %d\n", p, freq_scale, ne0);
           for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-            const float cos_theta = cosf(theta);
-            const float sin_theta = sinf(theta);
+            float cos_theta, sin_theta;
+            rope_yarn(theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-            theta *= theta_scale;  // theta = i2 * theta_scale^(i0/2)
+            theta_base *= theta_scale;
 
             const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
             float* dst_data = (float*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
@@ -8824,12 +8894,18 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
           // TODO: this is probably wrong, but I can't figure it out ..
           // ref:
           // https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+          theta_base *= freq_scale;
+
           for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) {
             for (int64_t ic = 0; ic < n_dims; ic += 2) {
-              const float cos_theta = cosf(theta);
-              const float sin_theta = sinf(theta);
+              // simplified from `(ib * n_dims + ic) * inv_ndims`
+              float cur_rot = inv_ndims * ic - ib;
 
-              theta *= theta_scale;
+              float cos_theta, sin_theta;
+              rope_yarn(theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta,
+                        &sin_theta);
+
+              theta_base *= theta_scale;
 
               const int64_t i0 = ib * n_dims + ic / 2;
 
diff --git a/neural_speed/core/ne_layers.h b/neural_speed/core/ne_layers.h
index 21cd48d44..d95f16dd5 100644
--- a/neural_speed/core/ne_layers.h
+++ b/neural_speed/core/ne_layers.h
@@ -414,6 +414,20 @@ NE_API struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne
                                                int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
                                                float freq_base, float freq_scale);
 
+// in-place, returns view(a)
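+// YaRN variant of ne_rope_inplace. ext_factor = 0.0f (with attn_factor = 1.0f) reduces
+// to plain scaled RoPE; yarn_orig_ctx is the model's original training context length;
+// beta_fast / beta_slow set the correction range (typical YaRN defaults: 32.0f and 1.0f)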
+NE_API struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
+                                                int mode, int prompt_size, float freq_base, float freq_scale,
+                                                int yarn_orig_ctx, float ext_factor, float attn_factor, float beta_fast,
+                                                float beta_slow);
+
+// shift all tokens by a given p (n_shift)
+// Optionally give a 1d tensor of precomputed interleaved cos/sin values of n_shift*scale^k for k \in [0, n_dims)
+NE_API struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift,
+                                                      int n_dims, int mode, int prompt_size, int n_keep,
+                                                      struct ne_tensor* cossin, float freq_base, float freq_scale,
+                                                      int yarn_orig_ctx, float ext_factor, float attn_factor,
+                                                      float beta_fast, float beta_slow);
+
 // rotary position embedding backward, i.e compute dx from dy
 // a - dy
 NE_API struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode);