This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Add eps as a parameter to ne_norm function (#162)
aahouzi authored Mar 13, 2024
1 parent 2309fbb commit 0ec1a6e
Showing 31 changed files with 83 additions and 80 deletions.
2 changes: 1 addition & 1 deletion neural_speed/convert/convert_baichuan.py
@@ -157,7 +157,7 @@ def baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", hparams["intermediate_size"]))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

2 changes: 1 addition & 1 deletion neural_speed/convert/convert_bloom.py
@@ -106,7 +106,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

6 changes: 4 additions & 2 deletions neural_speed/convert/convert_chatglm.py
@@ -394,7 +394,7 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-6))) # rms norm eps
fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

@@ -541,10 +541,12 @@ def chatglm1_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):

fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", hparams["inner_hidden_size"]))

fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps

fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
2 changes: 1 addition & 1 deletion neural_speed/convert/convert_dolly.py
@@ -120,7 +120,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
- fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
2 changes: 1 addition & 1 deletion neural_speed/convert/convert_falcon.py
@@ -113,7 +113,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
- fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

2 changes: 1 addition & 1 deletion neural_speed/convert/convert_gptj.py
@@ -105,7 +105,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
- fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

2 changes: 1 addition & 1 deletion neural_speed/convert/convert_gptneox.py
@@ -121,7 +121,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
- fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

2 changes: 1 addition & 1 deletion neural_speed/convert/convert_mpt.py
@@ -102,7 +102,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
- fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

2 changes: 1 addition & 1 deletion neural_speed/convert/convert_opt.py
@@ -114,7 +114,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
- fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

2 changes: 1 addition & 1 deletion neural_speed/convert/convert_phi.py
@@ -199,7 +199,7 @@ def phi_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
- fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
3 changes: 1 addition & 2 deletions neural_speed/convert/convert_quantized_gptj.py
@@ -140,8 +140,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
fout.write(struct.pack("f", hparams.get(
"rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
2 changes: 1 addition & 1 deletion neural_speed/convert/convert_qwen.py
@@ -126,7 +126,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
- fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

2 changes: 1 addition & 1 deletion neural_speed/convert/convert_starcoder.py
@@ -117,7 +117,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
- fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

16 changes: 9 additions & 7 deletions neural_speed/core/ne_layers.c
@@ -2065,7 +2065,7 @@ struct ne_tensor* ne_silu_back(struct ne_context* ctx, struct ne_tensor* a, stru

// ne_norm

- struct ne_tensor* ne_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool inplace) {
+ struct ne_tensor* ne_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool inplace, float eps) {
bool is_node = false;

if (!inplace && (a->grad)) {
@@ -2075,20 +2075,21 @@ struct ne_tensor* ne_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool

struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a);

+ ne_set_op_params(result, &eps, sizeof(eps));

result->op = NE_OP_NORM;
result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = NULL; // TODO: maybe store epsilon here?

return result;
}

- struct ne_tensor* ne_norm(struct ne_context* ctx, struct ne_tensor* a) {
-   return ne_norm_impl(ctx, a, false);
+ struct ne_tensor* ne_norm(struct ne_context* ctx, struct ne_tensor* a, float eps) {
+   return ne_norm_impl(ctx, a, false, eps);
}

- struct ne_tensor* ne_norm_inplace(struct ne_context* ctx, struct ne_tensor* a) {
-   return ne_norm_impl(ctx, a, true);
+ struct ne_tensor* ne_norm_inplace(struct ne_context* ctx, struct ne_tensor* a, float eps) {
+   return ne_norm_impl(ctx, a, true, eps);
}

struct ne_tensor* ne_rms_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool inplace, float eps) {
Expand Down Expand Up @@ -6184,7 +6185,8 @@ static void ne_compute_forward_norm_f32(const struct ne_compute_params* params,
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];

- const float eps = 1e-5f; // TODO: make this a parameter
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));

if (ne_is_contiguous(src0) && ne_is_contiguous(dst)) {
bestla_layernormalization(ne03 * ne02 * ne01, ne00, false, eps, (const float*)src0->data, (float*)dst->data);
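
The core change has two halves: ne_norm_impl serializes eps into the node's op_params buffer at graph-build time, and ne_compute_forward_norm_f32 reads it back at compute time instead of using the hardcoded 1e-5f. Below is a minimal self-contained sketch of that round-trip and of what eps does in the row-wise LayerNorm; layernorm_row and the buffer layout are illustrative stand-ins, not the library's API:

    #include <math.h>
    #include <stdio.h>
    #include <string.h>

    /* Row-wise LayerNorm x -> (x - mean) / sqrt(var + eps); eps keeps the
     * division stable when the row variance is near zero. */
    static void layernorm_row(const float* x, float* y, int n, float eps) {
      float mean = 0.0f, var = 0.0f;
      for (int i = 0; i < n; ++i) mean += x[i];
      mean /= (float)n;
      for (int i = 0; i < n; ++i) {
        const float d = x[i] - mean;
        var += d * d;
      }
      var /= (float)n;
      const float scale = 1.0f / sqrtf(var + eps);
      for (int i = 0; i < n; ++i) y[i] = (x[i] - mean) * scale;
    }

    int main(void) {
      /* Emulate the op_params round-trip with a raw byte buffer. */
      unsigned char op_params[8] = {0};
      const float eps_in = 1e-6f;                 /* from the model header */
      memcpy(op_params, &eps_in, sizeof(eps_in)); /* ne_set_op_params(...) */

      float eps;
      memcpy(&eps, op_params, sizeof(float));     /* compute-time read */

      const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
      float y[4];
      layernorm_row(x, y, 4, eps);
      printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
      return 0;
    }
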
3 changes: 1 addition & 2 deletions neural_speed/core/ne_layers.h
@@ -239,8 +239,7 @@ NE_API struct ne_tensor* ne_silu(struct ne_context* ctx, struct ne_tensor* a);
NE_API struct ne_tensor* ne_silu_back(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* b);

// normalize along rows
- // TODO: eps is hardcoded to 1e-5 for now
- NE_API struct ne_tensor* ne_norm(struct ne_context* ctx, struct ne_tensor* a);
+ NE_API struct ne_tensor* ne_norm(struct ne_context* ctx, struct ne_tensor* a, float eps);

NE_API struct ne_tensor* ne_rms_norm(struct ne_context* ctx, struct ne_tensor* a, float eps);

6 changes: 3 additions & 3 deletions neural_speed/models/baichuan/baichuan.cpp
@@ -146,7 +146,7 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
struct ne_tensor* residual = inpL;

// LayerNorm
- cur = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
+ cur = ne_rms_norm(ctx0, inpL, hparams.norm_eps);
cur = ne_mul(ctx0, cur, model.layers[il].norm[0]);
// SelfAttention
{
@@ -260,7 +260,7 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
residual = cur;

// post_attention_layernorm
- struct ne_tensor* hidden_states = ne_rms_norm(ctx0, cur, hparams.rms_norm_eps);
+ struct ne_tensor* hidden_states = ne_rms_norm(ctx0, cur, hparams.norm_eps);
hidden_states = ne_mul(ctx0, hidden_states, model.layers[il].norm[1]);

// mlp.forward
@@ -291,7 +291,7 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
struct ne_tensor* embeddings = nullptr;
// norm
{
- inpL = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
+ inpL = ne_rms_norm(ctx0, inpL, hparams.norm_eps);
inpL = ne_mul(ctx0, inpL, model.others[1]);
}
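
On the model side the pattern is uniform: every ne_norm and ne_rms_norm call site now passes hparams.norm_eps, which this commit appears to use as a single field serving both rms_norm_eps- and layer_norm_eps-style architectures. A tiny hypothetical sketch of that consolidation (the struct below is a stand-in, not the real model hparams type):

    #include <stdio.h>

    /* Stand-in for the model hyperparameters; the real struct carries many
     * more fields read from the converted header. */
    struct hparams_sketch {
      float norm_eps; /* one epsilon for both LayerNorm and RMSNorm calls */
    };

    int main(void) {
      struct hparams_sketch hp = {1e-6f};
      /* before: ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
       * after : ne_rms_norm(ctx0, inpL, hparams.norm_eps); */
      printf("eps threaded into every norm node: %g\n", hp.norm_eps);
      return 0;
    }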

8 changes: 4 additions & 4 deletions neural_speed/models/bloom/bloom.cpp
@@ -100,7 +100,7 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp

// word embeddings norm
{
- inpL = ne_norm(ctx0, inpL);
+ inpL = ne_norm(ctx0, inpL, hparams.norm_eps);
inpL = ne_mul(ctx0, ne_repeat(ctx0, model.others[1], inpL), inpL);
inpL = ne_add(ctx0, ne_repeat(ctx0, model.others[2], inpL), inpL);
}
@@ -112,7 +112,7 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp
lctx.use_buf(ctx0, 0);
// norm
{
- cur = ne_norm(ctx0, inpL);
+ cur = ne_norm(ctx0, inpL, hparams.norm_eps);

// cur = attention_norm*cur
cur = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[0], cur), cur);
@@ -205,7 +205,7 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp
{
// norm
{
- cur = ne_norm(ctx0, inpFF);
+ cur = ne_norm(ctx0, inpFF, hparams.norm_eps);

// cur = ffn_norm*cur + ffn_norm_b
cur = ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].ffn[0], cur), cur);
@@ -242,7 +242,7 @@ static bool bloom_model_eval_internal(model_context* ctx, const model_input* inp
lctx.use_buf(ctx0, -1);
// norm
{
- inpL = ne_norm(ctx0, inpL);
+ inpL = ne_norm(ctx0, inpL, hparams.norm_eps);

// inpL = norm*inpL
inpL = ne_mul(ctx0, ne_repeat(ctx0, model.others[3], inpL), inpL);
6 changes: 3 additions & 3 deletions neural_speed/models/chatglm/chatglm.cpp
@@ -118,7 +118,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i

lctx.use_buf(ctx0, 0);

- cur = ne_norm(ctx0, inpL);
+ cur = ne_norm(ctx0, inpL, hparams.norm_eps);

ne_set_name(cur, "cur");
cur = ne_mul(ctx0, cur, model.layers[il].norm[0]);
@@ -238,7 +238,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
attn_input = ne_scale_inplace(ctx0, attn_input, alpha);
inpL = ne_add_inplace(ctx0, attn_input, cur);

- struct ne_tensor* mlp_input = ne_norm(ctx0, inpL);
+ struct ne_tensor* mlp_input = ne_norm(ctx0, inpL, hparams.norm_eps);

ne_set_name(mlp_input, "mlp_input");
mlp_input = ne_mul(ctx0, mlp_input, model.layers[il].norm[2]);
@@ -270,7 +270,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
struct ne_tensor* embeddings = nullptr;
// norm
{
- inpL = ne_norm(ctx0, inpL);
+ inpL = ne_norm(ctx0, inpL, hparams.norm_eps);

ne_set_name(inpL, "inpL");
inpL = ne_mul(ctx0, inpL, model.others[1]);
6 changes: 3 additions & 3 deletions neural_speed/models/chatglm/chatglm2.cpp
@@ -149,7 +149,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
lctx.use_buf(ctx0, 0);

// self-attention
- cur = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
+ cur = ne_rms_norm(ctx0, inpL, hparams.norm_eps);
cur = ne_mul(ctx0, cur, model.layers[il].norm[0]);
{
// compute QKV
@@ -297,7 +297,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
struct ne_tensor* hidden_states = ne_add(ctx0, inpL, cur);

// mlp.forward
- struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states, hparams.rms_norm_eps);
+ struct ne_tensor* mlp_output = ne_rms_norm(ctx0, hidden_states, hparams.norm_eps);
mlp_output = ne_mul(ctx0, mlp_output, model.layers[il].norm[1]);

if (model.layers[il].ffn_fusion &&
@@ -332,7 +332,7 @@ static bool chatglm_model_eval_internal(model_context* ctx, const model_input* i
struct ne_tensor* embeddings = nullptr;
// norm
{
- inpL = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
+ inpL = ne_rms_norm(ctx0, inpL, hparams.norm_eps);
inpL = ne_mul(ctx0, inpL, model.others[1]);
}

6 changes: 3 additions & 3 deletions neural_speed/models/falcon/falcon.cpp
@@ -134,12 +134,12 @@ static bool falcon_model_eval_internal(model_context* ctx, const model_input* in
// self-attention
{
{
- layernorm_output = ne_norm(ctx0, inpL);
+ layernorm_output = ne_norm(ctx0, inpL, hparams.norm_eps);
layernorm_output =
ne_add(ctx0, ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[0], layernorm_output), layernorm_output),
ne_repeat(ctx0, model.layers[il].norm[1], layernorm_output));
if (n_head_kv == 8) { // 40B (FFN does not receive ATTN output)
- cur = ne_norm(ctx0, inpL);
+ cur = ne_norm(ctx0, inpL, hparams.norm_eps);
cur = ne_add(ctx0, ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[2], cur), cur),
ne_repeat(ctx0, model.layers[il].norm[3], cur));
} else { // 7B
@@ -299,7 +299,7 @@ static bool falcon_model_eval_internal(model_context* ctx, const model_input* in
struct ne_tensor* embeddings = nullptr;
// norm
{
- inpL = ne_norm(ctx0, inpL);
+ inpL = ne_norm(ctx0, inpL, hparams.norm_eps);

// inpL = ln_f_g*inpL + ln_f_b
inpL = ne_add(ctx0, ne_mul(ctx0, ne_repeat(ctx0, model.others[1], inpL), inpL),
4 changes: 2 additions & 2 deletions neural_speed/models/gptj/gptj.cpp
@@ -180,7 +180,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu
lctx.use_buf(ctx0, 0);

// norm
- cur = ne_norm(ctx0, inpL);
+ cur = ne_norm(ctx0, inpL, hparams.norm_eps);

// cur = ln_1_g*cur + ln_1_b
cur = ne_add(ctx0, ne_mul(ctx0, cur, model.layers[il].norm[0]), model.layers[il].norm[1]);
@@ -561,7 +561,7 @@ static bool gptj_model_eval_internal(model_context* ctx, const model_input* inpu

// norm
{
- inpL = ne_norm(ctx0, inpL);
+ inpL = ne_norm(ctx0, inpL, hparams.norm_eps);

// inpL = inpL*norm(broadcasted)
inpL = ne_add(ctx0, ne_mul(ctx0, inpL, model.others[1]), model.others[2]);
14 changes: 7 additions & 7 deletions neural_speed/models/gptneox/gptneox.cpp
@@ -38,9 +38,9 @@
#include "models/model_utils/util.h"

// feed-forward network
- struct ne_tensor* gpt_neox_ff(const model_layer& layer, const int batch_size, const int N, ne_context* ctx0,
-                               ne_tensor* inp) {
-   struct ne_tensor* cur = ne_norm(ctx0, inp);
+ struct ne_tensor* gpt_neox_ff(const model_layer& layer, const int batch_size, const int N, const float eps,
+                               ne_context* ctx0, ne_tensor* inp) {
+   struct ne_tensor* cur = ne_norm(ctx0, inp, eps);

cur = ne_add(ctx0, ne_mul(ctx0, ne_repeat(ctx0, layer.norm[2], cur), cur), ne_repeat(ctx0, layer.norm[3], cur));
if (bestla_fusion_FFN_Add_GeLu_f32f32_support(layer.ffn[0]->data, layer.ffn[2]->data, N * batch_size, cur->ne[0],
@@ -167,7 +167,7 @@ static bool gptneox_model_eval_internal(model_context* ctx, const model_input* i
// self-attention
{
{
- cur = ne_norm(ctx0, inpL);
+ cur = ne_norm(ctx0, inpL, hparams.norm_eps);

cur = ne_add(ctx0, ne_mul(ctx0, ne_repeat(ctx0, model.layers[il].norm[0], cur), cur),
ne_repeat(ctx0, model.layers[il].norm[1], cur));
@@ -315,7 +315,7 @@ static bool gptneox_model_eval_internal(model_context* ctx, const model_input* i
if (hparams.par_res == 0) {
struct ne_tensor* inpFF = ne_add(ctx0, cur, inpL);

- cur = gpt_neox_ff(model.layers[il], N, batch_size, ctx0, inpFF);
+ cur = gpt_neox_ff(model.layers[il], N, batch_size, hparams.norm_eps, ctx0, inpFF);

// input for next layer
inpL = ne_add(ctx0, cur, inpFF);
@@ -324,7 +324,7 @@ static bool gptneox_model_eval_internal(model_context* ctx, const model_input* i

// this is independent of the self-attention result, so it could be done in parallel to the self-attention
// note here we pass inpL instead of cur
- cur = gpt_neox_ff(model.layers[il], N, batch_size, ctx0, inpL);
+ cur = gpt_neox_ff(model.layers[il], N, batch_size, hparams.norm_eps, ctx0, inpL);

// layer input + FF
cur = ne_add(ctx0, cur, inpFF);
@@ -339,7 +339,7 @@ static bool gptneox_model_eval_internal(model_context* ctx, const model_input* i
struct ne_tensor* embeddings = nullptr;
// norm
{
- inpL = ne_norm(ctx0, inpL);
+ inpL = ne_norm(ctx0, inpL, hparams.norm_eps);

// inpL = ln_f_g*inpL + ln_f_b
inpL = ne_add(ctx0, ne_mul(ctx0, ne_repeat(ctx0, model.others[1], inpL), inpL),
(Diffs for the remaining changed files are not shown.)
