Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
resubmit "Implement the YaRN rop scaling feature" (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiguiw authored Mar 4, 2024
1 parent 96dc559 commit 6c36f54
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 19 deletions.
114 changes: 95 additions & 19 deletions neural_speed/core/ne_layers.c
Original file line number Diff line number Diff line change
Expand Up @@ -3107,7 +3107,8 @@ struct ne_tensor* ne_soft_max_inplace(struct ne_context* ctx, struct ne_tensor*

struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, bool inplace, int n_keep, struct ne_tensor* cossin, int* n_padding,
bool padding_left, float freq_base, float freq_scale) {
bool padding_left, float freq_base, float freq_scale, int yarn_orig_ctx,
float ext_factor, float attn_factor, float beta_fast, float beta_slow) {
NE_ASSERT(n_past >= 0 || n_keep >= 0);
NE_ASSERT(padding_left);
bool is_node = false;
Expand Down Expand Up @@ -3147,7 +3148,9 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int

ne_scratch_load(ctx);

float params[] = {freq_base, freq_scale};
/* what the difference of setting parameters in b->data and in op_parameters */
/* float and int are in different data ?? */
float params[] = {freq_base, freq_scale, (float)yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow};
ne_set_op_params(result, &params, sizeof(params));

result->op = NE_OP_ROPE;
Expand All @@ -3161,19 +3164,36 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int

struct ne_tensor* ne_rope(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, float freq_base, float freq_scale) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale);
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, NULL, true, freq_base, freq_scale, 0,
0.0f, 1.0f, 0.0f, 0.0f);
}

struct ne_tensor* ne_rope_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, float freq_base, float freq_scale) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale);
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale, 0,
0.0f, 1.0f, 0.0f, 0.0f);
}

struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims, int mode,
int prompt_size, int n_keep, struct ne_tensor* cossin, float freq_base,
float freq_scale) {
return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
freq_scale);
freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
}

struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, float freq_base, float freq_scale, int yarn_orig_ctx,
float ext_factor, float attn_factor, float beta_fast, float beta_slow) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, NULL, true, freq_base, freq_scale,
yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow);
}

struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift, int n_dims,
int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
float freq_base, float freq_scale, int yarn_orig_ctx, float ext_factor,
float attn_factor, float beta_fast, float beta_slow) {
return ne_rope_impl(ctx, a, n_shift, n_dims, mode, prompt_size, true, n_keep, cossin, NULL, true, freq_base,
freq_scale, yarn_orig_ctx, ext_factor, attn_factor, beta_fast, beta_slow);
}

// ne_rope_back
Expand Down Expand Up @@ -3211,14 +3231,14 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int
struct ne_tensor* ne_rope_with_padding(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode,
int prompt_size, int* n_padding, float freq_base, float freq_scale) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, false, -1, NULL, n_padding, true, freq_base,
freq_scale);
freq_scale, 0, 0.0f, 1.0f, 0.0f, 0.0f);
}

struct ne_tensor* ne_rope_with_padding_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
int mode, int prompt_size, int* n_padding, float freq_base,
float freq_scale) {
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base,
freq_scale);
return ne_rope_impl(ctx, a, n_past, n_dims, mode, prompt_size, true, -1, NULL, n_padding, true, freq_base, freq_scale,
0, 0.0f, 1.0f, 0.0f, 0.0f);
}

// ne_alibi
Expand Down Expand Up @@ -8709,6 +8729,45 @@ static void ne_compute_forward_clamp(const struct ne_compute_params* params, con
}
}

static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
return 1.0 - MIN(1.0, MAX(0.0, y));
}

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor,
float mscale, float* cos_theta, float* sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}

#ifndef NE_PI
#define NE_PI (3.14159265358979323846)
#endif
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)NE_PI)) / (2 * logf(base));
}

void ggml_rope_yarn_corr_dims(int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow,
float dims[2]) {
// start and end correction dims
dims[0] = MAX(0, floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
}

// ne_compute_forward_rope
#define NE_TENSOR_UNARY_OP_LOCALS \
NE_TENSOR_LOCALS(int64_t, ne0, src0, ne); \
Expand All @@ -8721,12 +8780,18 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) {
return;
}

const int bs = src0->ne[3];
NE_ASSERT(src1->type == NE_TYPE_I32);
NE_ASSERT(ne_nelements(src1) == 5 + bs); // 5 + bs params

const float freq_base = ((float*)(dst->op_params))[0];
const float freq_scale = 1 / ((float*)(dst->op_params))[1];
const int n_orig_ctx = (int)((float*)(dst->op_params))[2];
const float ext_factor = ((float*)(dst->op_params))[3];
const float attn_factor = ((float*)(dst->op_params))[4];
const float beta_fast = ((float*)(dst->op_params))[5];
const float beta_slow = ((float*)(dst->op_params))[6];

const int64_t n_past = ((int32_t*)src1->data)[ROPE_NPAST_IDX];
const int64_t n_dims = ((int32_t*)src1->data)[ROPE_NDIMS_IDX];
Expand Down Expand Up @@ -8759,11 +8824,15 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
int ir = 0;

const float theta_scale = powf(freq_base, -2.0f / n_dims);
const float inv_ndims = -1.f / n_dims;
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

const bool skip = mode & 1;
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
const bool is_shift = n_keep >= 0;
const bool use_yarn = ((mode & 0x8) != 0);
NE_ASSERT(("RoPE shift not supported!", !is_shift));

NE_ASSERT(ne3 == bs);
Expand All @@ -8774,21 +8843,21 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
if (ir++ < ir0) continue;
if (ir > ir1) break;

float theta = freq_scale * (float)p;
float theta_base = (float)p;

// only for glm when mode == 4
if (is_glm) {
const int64_t n_padding = ((int32_t*)src1->data)[ROPE_PARAMS_NUM + i3];
// position ids
theta = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding);
theta_base = MIN(MAX(p - n_padding, 0), prompt_size - 2 - n_padding);
float block_theta = MAX(p - (prompt_size - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta_base);
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta);

theta *= theta_scale;
theta_base *= theta_scale;
block_theta *= theta_scale;

const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
Expand All @@ -8805,11 +8874,12 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
dst_data[n_dims / 2 * 3] = x2 * sin_block_theta + x3 * cos_block_theta;
}
} else if (!is_neox) {
// printf("theta_base = %ld, freq_scale %.4f, ne0 %d\n", p, freq_scale, ne0);
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

theta *= theta_scale; // theta = i2 * theta_scale^(i0/2)
theta_base *= theta_scale;

const float* const src = (float*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
float* dst_data = (float*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
Expand All @@ -8824,12 +8894,18 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
// TODO: this is probably wrong, but I can't figure it out ..
// ref:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
theta_base = theta_base * freq_scale;

for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;

theta *= theta_scale;
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta,
&sin_theta);

theta_base *= theta_scale;

const int64_t i0 = ib * n_dims + ic / 2;

Expand Down
14 changes: 14 additions & 0 deletions neural_speed/core/ne_layers.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,20 @@ NE_API struct ne_tensor* ne_rope_shift_inplace(struct ne_context* ctx, struct ne
int mode, int prompt_size, int n_keep, struct ne_tensor* cossin,
float freq_base, float freq_scale);

// in-place, returns view(a)
NE_API struct ne_tensor* ne_rope_custom_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims,
int mode, int prompt_size, float freq_base, float freq_scale,
int yarn_orig_ctx, float ext_factor, float attn_factor, float beta_fast,
float beta_slow);

// shift all tokens by a give p (n_shift)
// Optionally give a 1d tensor of precomputed interleaved cos/sin value of n_shift*scale^k for k \in [0, n_dims)
NE_API struct ne_tensor* ne_rope_custom_shift_inplace(struct ne_context* ctx, struct ne_tensor* a, int n_shift,
int n_dims, int mode, int prompt_size, int n_keep,
struct ne_tensor* cossin, float freq_base, float freq_scale,
int yarn_orig_ctx, float ext_factor, float attn_factor,
float beta_fast, float beta_slow);

// rotary position embedding backward, i.e compute dx from dy
// a - dy
NE_API struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int n_past, int n_dims, int mode);
Expand Down

0 comments on commit 6c36f54

Please sign in to comment.