From e4c5f7142ec83c8ce596a907d69ea4b617e7481f Mon Sep 17 00:00:00 2001
From: intellinjun <105184542+intellinjun@users.noreply.github.com>
Date: Fri, 22 Mar 2024 15:39:19 +0800
Subject: [PATCH] Gemma-7b&&Gemma-2b (#171)
---
docs/supported_models.md | 16 +-
neural_speed/__init__.py | 2 +
neural_speed/application/CMakeLists.txt | 5 +
neural_speed/application/main_pybind.cpp | 3 +
neural_speed/application/main_run.cpp | 2 +-
neural_speed/convert/convert_baichuan.py | 1 +
neural_speed/convert/convert_bloom.py | 1 +
neural_speed/convert/convert_chatglm.py | 5 +-
neural_speed/convert/convert_dolly.py | 1 +
neural_speed/convert/convert_falcon.py | 1 +
neural_speed/convert/convert_gemma.py | 195 ++++++++
neural_speed/convert/convert_gptj.py | 1 +
neural_speed/convert/convert_gptneox.py | 1 +
neural_speed/convert/convert_llama.py | 1 +
neural_speed/convert/convert_mistral.py | 1 +
neural_speed/convert/convert_mixtral.py | 1 +
neural_speed/convert/convert_mpt.py | 1 +
neural_speed/convert/convert_opt.py | 1 +
neural_speed/convert/convert_phi.py | 1 +
.../convert/convert_quantized_baichuan.py | 1 +
.../convert/convert_quantized_falcon.py | 1 +
.../convert/convert_quantized_gptj.py | 1 +
.../convert/convert_quantized_llama.py | 1 +
.../convert/convert_quantized_mistral.py | 1 +
.../convert/convert_quantized_mixtral.py | 1 +
neural_speed/convert/convert_quantized_phi.py | 1 +
.../convert/convert_quantized_qwen.py | 1 +
neural_speed/convert/convert_qwen.py | 1 +
neural_speed/convert/convert_stablelm.py | 1 +
neural_speed/convert/convert_starcoder.py | 1 +
neural_speed/core/layers/Ops.h | 4 +-
neural_speed/core/layers/ip_fusion_ffn.cpp | 15 +
neural_speed/core/ne_bestla.h | 7 +-
neural_speed/core/ne_layers.c | 55 ++-
neural_speed/core/ne_layers.h | 3 +
neural_speed/models/CMakeLists.txt | 1 +
neural_speed/models/gemma/gemma.cpp | 416 ++++++++++++++++++
neural_speed/models/gemma/gemma.h | 60 +++
neural_speed/models/gemma/gemma_utils.cpp | 236 ++++++++++
neural_speed/models/model_utils/gguf.h | 6 +
neural_speed/models/model_utils/model_files.h | 6 +
neural_speed/models/model_utils/model_types.h | 5 +-
.../models/model_utils/model_utils.cpp | 9 +-
tests/model-test/cpp_graph_inference.sh | 8 +-
44 files changed, 1069 insertions(+), 13 deletions(-)
create mode 100644 neural_speed/convert/convert_gemma.py
create mode 100644 neural_speed/models/gemma/gemma.cpp
create mode 100644 neural_speed/models/gemma/gemma.h
create mode 100644 neural_speed/models/gemma/gemma_utils.cpp
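
Note: this patch wires Gemma-2B/Gemma-7B through the whole stack: a new converter (neural_speed/convert/convert_gemma.py), a gemma_cpp pybind module, a fused gated-GeLU FFN op in the ne core, and the gemma model graph. As a minimal conversion sketch (paths and file names here are illustrative, not part of the patch), the converter's main(args_in) can be driven directly from Python:

    # Illustrative paths; convert_gemma.py (added below) exposes main(args_in)
    # with --outtype, --outfile and a positional model directory.
    from neural_speed.convert.convert_gemma import main as convert_gemma

    convert_gemma([
        "--outtype", "f16",                  # or "f32"
        "--outfile", "ne-gemma-7b-f16.bin",  # output ne/ggmf file
        "/path/to/gemma-7b",                 # local Hugging Face checkpoint dir
    ])
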
diff --git a/docs/supported_models.md b/docs/supported_models.md
index 03c63e9ed..df8135677 100644
--- a/docs/supported_models.md
+++ b/docs/supported_models.md
@@ -276,7 +276,21 @@ Neural Speed supports the following models:
Whisper-tiny,
Whisper-base
Whisper-small
diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py
index 22566b2e1..c6b706d4b 100644
--- a/neural_speed/__init__.py
+++ b/neural_speed/__init__.py
@@ -70,6 +70,8 @@ def __import_package(self, model_type):
import neural_speed.qwen_cpp as cpp_model
elif model_type == "phi":
import neural_speed.phi_cpp as cpp_model
+ elif model_type == "gemma":
+ import neural_speed.gemma_cpp as cpp_model
elif model_type == "stablelm":
import neural_speed.stablelm_cpp as cpp_model
elif model_type == "whisper":
diff --git a/neural_speed/application/CMakeLists.txt b/neural_speed/application/CMakeLists.txt
index 2a34f5db3..ca659e1b1 100644
--- a/neural_speed/application/CMakeLists.txt
+++ b/neural_speed/application/CMakeLists.txt
@@ -71,6 +71,7 @@ compile_quant(quant_mistral quant_model.cpp mistral llama)
compile_quant(quant_mixtral quant_model.cpp mixtral llama)
compile_quant(quant_qwen quant_model.cpp qwen qwen)
compile_quant(quant_phi quant_model.cpp phi phi)
+compile_quant(quant_gemma quant_model.cpp gemma gemma)
compile_quant(quant_stablelm quant_model.cpp stablelm stablelm)
compile_quant(quant_whisper quant_whisper.cpp whisper whisper)
@@ -99,6 +100,8 @@ set(mymap_stablelm 17)
set(mymap_whisper 18)
set(mymap_mixtral 19)
set(mymap_chatglm3 20)
+set(mymap_gemma 21)
+
function(compile_run TARGET MAIN_CPP MAIN_PY MODEL_NAME MODEL_LIB)
@@ -135,8 +138,10 @@ compile_run(run_baichuan main_run.cpp main_pybind.cpp baichuan baichuan)
compile_run(run_mistral main_run.cpp main_pybind.cpp mistral llama)
compile_run(run_qwen main_run.cpp main_pybind.cpp qwen qwen)
compile_run(run_phi main_run.cpp main_pybind.cpp phi phi)
+compile_run(run_gemma main_run.cpp main_pybind.cpp gemma gemma)
compile_run(run_stablelm main_run.cpp main_pybind.cpp stablelm stablelm)
compile_run(run_mixtral main_run.cpp main_pybind.cpp mixtral llama)
+
# speech recognition
compile_run(run_whisper audio_run.cpp whisper_pybind.cpp whisper whisper)
diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp
index eb3fb8777..d95934401 100644
--- a/neural_speed/application/main_pybind.cpp
+++ b/neural_speed/application/main_pybind.cpp
@@ -924,6 +924,9 @@ PYBIND11_MODULE(mixtral_cpp, m)
#elif MODEL_NAME_ID == 20
PYBIND11_MODULE(chatglm3_cpp, m)
+#elif MODEL_NAME_ID == 21
+
+PYBIND11_MODULE(gemma_cpp, m)
#endif
{
diff --git a/neural_speed/application/main_run.cpp b/neural_speed/application/main_run.cpp
index c784348eb..6872d1133 100644
--- a/neural_speed/application/main_run.cpp
+++ b/neural_speed/application/main_run.cpp
@@ -229,7 +229,7 @@ int main(int argc, char** argv) { // NOLINT
// tokenize the prompt
bool add_bos = false;
- if (params.model_arch == MODEL_LLAMA) {
+ if (params.model_arch == MODEL_LLAMA || params.model_arch == MODEL_GEMMA) {
add_bos = true;
}
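
The add_bos change mirrors how Gemma checkpoints are tokenized upstream: like LLaMA, Gemma expects a BOS token at the start of the prompt. A quick illustration with the Hugging Face tokenizer (sketch, not part of the patch):

    # With add_special_tokens=True the HF Gemma tokenizer prepends BOS,
    # which is what add_bos = true reproduces on the C++ side.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")  # illustrative id
    ids = tokenizer("Hello", add_special_tokens=True).input_ids
    assert ids[0] == tokenizer.bos_token_id
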
diff --git a/neural_speed/convert/convert_baichuan.py b/neural_speed/convert/convert_baichuan.py
index b6893e310..1fa35805e 100644
--- a/neural_speed/convert/convert_baichuan.py
+++ b/neural_speed/convert/convert_baichuan.py
@@ -157,6 +157,7 @@ def baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", hparams["intermediate_size"]))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_bloom.py b/neural_speed/convert/convert_bloom.py
index 0923c0e75..7a9609ddf 100644
--- a/neural_speed/convert/convert_bloom.py
+++ b/neural_speed/convert/convert_bloom.py
@@ -106,6 +106,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py
index be84facd6..6aa41e5d5 100644
--- a/neural_speed/convert/convert_chatglm.py
+++ b/neural_speed/convert/convert_chatglm.py
@@ -156,8 +156,6 @@ def chatglm3_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams
gguf_file = fname_out
gguf_writer = gguf.GGUFWriter(gguf_file, "chatglm3")
gguf_writer.add_uint32('magic', 0x67676d66)
- import pdb
- pdb.set_trace()
gguf_writer.add_uint32('version', 1)
gguf_writer.add_uint32('n_vocab', hparams["padded_vocab_size"])
gguf_writer.add_embedding_length(hparams["hidden_size"])
@@ -561,6 +559,7 @@ def chatglm3_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
@@ -711,6 +710,7 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
@@ -862,6 +862,7 @@ def chatglm1_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
diff --git a/neural_speed/convert/convert_dolly.py b/neural_speed/convert/convert_dolly.py
index 87d711b0f..ebab9d118 100644
--- a/neural_speed/convert/convert_dolly.py
+++ b/neural_speed/convert/convert_dolly.py
@@ -120,6 +120,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_falcon.py b/neural_speed/convert/convert_falcon.py
index 955d02547..f1d82daf3 100644
--- a/neural_speed/convert/convert_falcon.py
+++ b/neural_speed/convert/convert_falcon.py
@@ -113,6 +113,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_gemma.py b/neural_speed/convert/convert_gemma.py
new file mode 100644
index 000000000..f0b61df1a
--- /dev/null
+++ b/neural_speed/convert/convert_gemma.py
@@ -0,0 +1,195 @@
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Convert a Hugging Face Gemma model (gemma-2b / gemma-7b) to the ne format
+#
+# Usage:
+#
+#   python convert_gemma.py --outtype f32 --outfile ne-gemma-f32.bin model_path
+#
+# This script is similar to "convert-pt-to-ne.py"
+#
+
+import io
+import os
+import sys
+import struct
+import json
+import code
+import torch
+import numpy as np
+from pathlib import Path
+import argparse
+from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar,
+ Union)
+
+
+
+# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def main(args_in: Optional[List[str]] = None) -> None:
+ parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
+ parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+ parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+ parser.add_argument("--model_hub", choices=["huggingface","modelscope"],
+ default="huggingface", help="hub to load model")
+ parser.add_argument("model", type=Path, help="directory containing model file")
+ args = parser.parse_args(args_in)
+
+ dir_model = args.model.as_posix()
+ fname_out = args.outfile.as_posix()
+
+ # possible data types
+ # ftype == 0 -> float32
+ # ftype == 1 -> float16
+ ftype = 0
+ if args.outtype == "f16":
+ ftype = 1
+ if args.model_hub == "modelscope":
+ from modelscope import AutoModelForCausalLM, AutoTokenizer
+ else:
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ print("Loading model: ", dir_model)
+ model = AutoModelForCausalLM.from_pretrained(dir_model)
+ tokenizer = AutoTokenizer.from_pretrained(dir_model)
+ model.eval()
+ for p in model.parameters():
+ p.requires_grad = False
+ hparams = model.config.to_dict()
+ print("Model loaded: ", dir_model)
+
+ fout = open(fname_out, "wb")
+
+ # 0x67676d6c is unversioned ne
+ # 0x67676d66 is versioned ggmf (requires token scores)
+ ne_file_magic = 0x67676d66
+ #ne_file_version = 0x00000001 # v1
+
+ fout.write(struct.pack("i", ne_file_magic)) # magic: ne in hex
+ fout.write(struct.pack("i", 1))
+ fout.write(struct.pack("i", hparams["vocab_size"]))
+ fout.write(struct.pack("i", hparams["hidden_size"]))
+ fout.write(struct.pack("i", hparams["intermediate_size"])) # dummy data
+ fout.write(struct.pack("i", hparams["num_attention_heads"]))
+ fout.write(struct.pack("i", hparams["num_key_value_heads"])) # multi-query attention
+ fout.write(struct.pack("i", hparams["num_hidden_layers"]))
+ fout.write(struct.pack("i", hparams["head_dim"]))
+ fout.write(struct.pack("i", ftype))
+ fout.write(
+ struct.pack("i", hparams["seq_length"] if "seq_length" in hparams else hparams["max_position_embeddings"]))
+ fout.write(struct.pack("f", 0.0))
+ fout.write(struct.pack("f", 0.0))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
+ fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)
+
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", hparams["intermediate_size"]))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0)) # n_experts
+ fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", hparams["head_dim"])) # n_embd_head_k
+ fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
+ fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
+
+ fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
+ fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
+ fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
+ fout.write(struct.pack("i", hparams["bos_token_id"]))
+ fout.write(struct.pack("i", hparams["eos_token_id"]))
+ fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+ fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+ for i in range(hparams["vocab_size"]):
+ if i < tokenizer.vocab_size:
+ text = tokenizer.decode([i]).encode('utf-8')
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+ fout.write(struct.pack("f", 0.0 - i))
+ else:
+ text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8')
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+ fout.write(struct.pack("f", -10000))
+
+ list_vars = model.state_dict()
+
+ print(hparams)
+
+ for name in list_vars.keys():
+ # No gradients for these
+ list_vars[name].requires_grad = False
+ src = name
+ nn = name
+
+ print(src, ' -> ', name)
+ data = list_vars[src].squeeze().numpy()
+ data = data.astype(np.float32)
+
+ n_dims = len(data.shape)
+ print(name, n_dims, data.shape)
+
+ # default type is fp32
+ ftype_cur = 0
+ if ftype == 1 and n_dims > 1:
+ print(" Converting to float16", data.shape, data[:3, :3].tolist())
+ data = data.astype(np.float16)
+ ftype_cur = 1
+ else:
+ print(" Converting to float32", data.shape, data[:3, :3].tolist() if n_dims > 1 else data[:3].tolist())
+ data = data.astype(np.float32)
+ # gemma_rms:
+ # output = self._norm(x.float()).type_as(x)
+ # return output * (1 + self.weight)
+ if "norm" in name:
+ data = data + 1
+ str = name.encode('utf-8')
+ fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+ print(str)
+ fout.write(str)
+
+ # data
+ data.tofile(fout)
+
+ fout.close()
+
+ print("Done. Output file: " + fname_out)
+ print("")
+
+
+if __name__ == '__main__':
+ main()
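
The "data = data + 1" step above folds Gemma's RMSNorm convention, output = rms_norm(x) * (1 + weight), into the stored tensor so the runtime can keep using the standard ne_rms_norm followed by ne_mul. A small NumPy check of that equivalence (reference formulas only, not the converter's code path):

    import numpy as np

    def rms_norm(x, eps=1e-6):
        # reference RMSNorm without a learned scale
        return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

    rng = np.random.default_rng(0)
    x = rng.standard_normal(8).astype(np.float32)
    w = rng.standard_normal(8).astype(np.float32)   # Gemma norm weight as stored upstream

    reference = rms_norm(x) * (1.0 + w)   # Gemma: output * (1 + self.weight)
    w_folded = w + 1.0                    # what convert_gemma.py writes for "norm" tensors
    runtime = rms_norm(x) * w_folded      # ne_rms_norm + ne_mul in gemma.cpp
    assert np.allclose(reference, runtime)
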
diff --git a/neural_speed/convert/convert_gptj.py b/neural_speed/convert/convert_gptj.py
index 3fb12c8f6..2ff15c255 100644
--- a/neural_speed/convert/convert_gptj.py
+++ b/neural_speed/convert/convert_gptj.py
@@ -105,6 +105,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_gptneox.py b/neural_speed/convert/convert_gptneox.py
index 322f47a01..9cbda72fa 100644
--- a/neural_speed/convert/convert_gptneox.py
+++ b/neural_speed/convert/convert_gptneox.py
@@ -121,6 +121,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_llama.py b/neural_speed/convert/convert_llama.py
index 37af73ab4..328994938 100644
--- a/neural_speed/convert/convert_llama.py
+++ b/neural_speed/convert/convert_llama.py
@@ -1090,6 +1090,7 @@ def write_file_header(self, params: Params, file_type: NEFileType) -> None:
self.fout.write(struct.pack("i", 0)) # n_experts
self.fout.write(struct.pack("i", 0)) # n_expert_used
+ self.fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
self.fout.write(struct.pack("f", params.rms_norm_eps))
self.fout.write(struct.pack("f", params.rope_theta))
self.fout.write(struct.pack("f", params.rope_scale))
diff --git a/neural_speed/convert/convert_mistral.py b/neural_speed/convert/convert_mistral.py
index 1889a9860..a1f12b6a8 100644
--- a/neural_speed/convert/convert_mistral.py
+++ b/neural_speed/convert/convert_mistral.py
@@ -1064,6 +1064,7 @@ def write_file_header(self, params: Params, file_type: NEFileType) -> None:
self.fout.write(struct.pack("i", 0)) # n_experts
self.fout.write(struct.pack("i", 0)) # n_expert_used
+ self.fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
self.fout.write(struct.pack("f", params.rms_norm_eps))
self.fout.write(struct.pack("f", params.rope_theta))
self.fout.write(struct.pack("f", params.rope_scale))
diff --git a/neural_speed/convert/convert_mixtral.py b/neural_speed/convert/convert_mixtral.py
index 4166d94be..77a1376f4 100644
--- a/neural_speed/convert/convert_mixtral.py
+++ b/neural_speed/convert/convert_mixtral.py
@@ -1066,6 +1066,7 @@ def write_file_header(self, params: Params, file_type: NEFileType) -> None:
self.fout.write(struct.pack("i", 0))
self.fout.write(struct.pack("i", 8))
self.fout.write(struct.pack("i", 2))
+ self.fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
self.fout.write(struct.pack("f", params.rms_norm_eps))
self.fout.write(struct.pack("f", params.rope_theta))
self.fout.write(struct.pack("f", params.rope_scale))
diff --git a/neural_speed/convert/convert_mpt.py b/neural_speed/convert/convert_mpt.py
index 04ee52421..c31961f23 100644
--- a/neural_speed/convert/convert_mpt.py
+++ b/neural_speed/convert/convert_mpt.py
@@ -102,6 +102,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_opt.py b/neural_speed/convert/convert_opt.py
index c96885334..ea3f226d6 100644
--- a/neural_speed/convert/convert_opt.py
+++ b/neural_speed/convert/convert_opt.py
@@ -114,6 +114,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_phi.py b/neural_speed/convert/convert_phi.py
index a916ccf5f..073b54fbf 100644
--- a/neural_speed/convert/convert_phi.py
+++ b/neural_speed/convert/convert_phi.py
@@ -199,6 +199,7 @@ def phi_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_quantized_baichuan.py b/neural_speed/convert/convert_quantized_baichuan.py
index 22928a6bc..2220e3b8b 100644
--- a/neural_speed/convert/convert_quantized_baichuan.py
+++ b/neural_speed/convert/convert_quantized_baichuan.py
@@ -101,6 +101,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", hparams["intermediate_size"]))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_quantized_falcon.py b/neural_speed/convert/convert_quantized_falcon.py
index 956b0a92b..4b90b1d75 100644
--- a/neural_speed/convert/convert_quantized_falcon.py
+++ b/neural_speed/convert/convert_quantized_falcon.py
@@ -80,6 +80,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_quantized_gptj.py b/neural_speed/convert/convert_quantized_gptj.py
index 707b63aca..85a1e769b 100644
--- a/neural_speed/convert/convert_quantized_gptj.py
+++ b/neural_speed/convert/convert_quantized_gptj.py
@@ -140,6 +140,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_quantized_llama.py b/neural_speed/convert/convert_quantized_llama.py
index 78f497ef4..9762492ee 100644
--- a/neural_speed/convert/convert_quantized_llama.py
+++ b/neural_speed/convert/convert_quantized_llama.py
@@ -150,6 +150,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("i", 0))
f.write(struct.pack("i", 0)) # n_experts
f.write(struct.pack("i", 0)) # n_expert_used
+ f.write(struct.pack("i", 0)) # n_embd_head_k for gemma
f.write(struct.pack("f", config["rms_norm_eps"]))
f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000))
diff --git a/neural_speed/convert/convert_quantized_mistral.py b/neural_speed/convert/convert_quantized_mistral.py
index 3e24295cd..c89bcec41 100644
--- a/neural_speed/convert/convert_quantized_mistral.py
+++ b/neural_speed/convert/convert_quantized_mistral.py
@@ -159,6 +159,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("i", 0))
f.write(struct.pack("i", 0)) # n_experts
f.write(struct.pack("i", 0)) # n_expert_used
+ f.write(struct.pack("i", 0)) # n_embd_head_k for gemma
f.write(struct.pack("f", config["rms_norm_eps"]))
f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000))
diff --git a/neural_speed/convert/convert_quantized_mixtral.py b/neural_speed/convert/convert_quantized_mixtral.py
index c320797d9..df793cda6 100644
--- a/neural_speed/convert/convert_quantized_mixtral.py
+++ b/neural_speed/convert/convert_quantized_mixtral.py
@@ -174,6 +174,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("i", 0))
f.write(struct.pack("i", 8)) # n_experts
f.write(struct.pack("i", 2)) # n_expert_used
+ f.write(struct.pack("i", 0)) # n_embd_head_k for gemma
f.write(struct.pack("f", config["rms_norm_eps"]))
f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000))
diff --git a/neural_speed/convert/convert_quantized_phi.py b/neural_speed/convert/convert_quantized_phi.py
index 8588116b1..9fe22bfaa 100644
--- a/neural_speed/convert/convert_quantized_phi.py
+++ b/neural_speed/convert/convert_quantized_phi.py
@@ -66,6 +66,7 @@ def convert_phi1_5_gptq_to_bestTLA(model_path, out_path, outtype, model, hparams
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_quantized_qwen.py b/neural_speed/convert/convert_quantized_qwen.py
index 02ded7622..998695bc4 100644
--- a/neural_speed/convert/convert_quantized_qwen.py
+++ b/neural_speed/convert/convert_quantized_qwen.py
@@ -83,6 +83,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
f.write(struct.pack("i", 0))
f.write(struct.pack("i", 0)) # n_experts
f.write(struct.pack("i", 0)) # n_expert_used
+ f.write(struct.pack("i", 0)) # n_embd_head_k for gemma
f.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
f.write(struct.pack("f", 10000.0)) # freq_base
f.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_qwen.py b/neural_speed/convert/convert_qwen.py
index 7cc59b72b..f268b37e6 100644
--- a/neural_speed/convert/convert_qwen.py
+++ b/neural_speed/convert/convert_qwen.py
@@ -126,6 +126,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/convert/convert_stablelm.py b/neural_speed/convert/convert_stablelm.py
index af00ba4fc..f5f1d43fd 100644
--- a/neural_speed/convert/convert_stablelm.py
+++ b/neural_speed/convert/convert_stablelm.py
@@ -197,6 +197,7 @@ def stablelm_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", hparams["rope_theta"])) # freq_base
fout.write(struct.pack("f", 1.0)) # freq_scale, was removed in config.json (by default=1.0)
diff --git a/neural_speed/convert/convert_starcoder.py b/neural_speed/convert/convert_starcoder.py
index 8f1dba042..b00fc566f 100644
--- a/neural_speed/convert/convert_starcoder.py
+++ b/neural_speed/convert/convert_starcoder.py
@@ -117,6 +117,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma
fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
diff --git a/neural_speed/core/layers/Ops.h b/neural_speed/core/layers/Ops.h
index b0f441b42..e38dfbafd 100644
--- a/neural_speed/core/layers/Ops.h
+++ b/neural_speed/core/layers/Ops.h
@@ -45,11 +45,11 @@ enum ne_op {
NE_OP_NORM, // normalize
NE_OP_RMS_NORM,
NE_OP_RMS_NORM_BACK,
+ NE_OP_RMS_ARGSORT,
NE_OP_MUL_MAT,
NE_OP_MUL_MAT_BIAS,
NE_OP_MUL_MAT_ID,
- NE_OP_MUL_ID_FFN_SILU,
NE_OP_SCALE,
NE_OP_SET,
NE_OP_CPY,
@@ -76,7 +76,9 @@ enum ne_op {
NE_OP_MUL_QKV,
NE_OP_MUL_FFN_SILU,
NE_OP_MUL_FFN_GELU,
+ NE_OP_MUL_FFN_GELU_MUL,
NE_OP_MUL_FFN_ADD_GELU,
+ NE_OP_MUL_ID_FFN_SILU,
NE_OP_FLASH_ATTN,
NE_OP_FLASH_ATTN_KV_UPDATE,
NE_OP_FLASH_FF,
diff --git a/neural_speed/core/layers/ip_fusion_ffn.cpp b/neural_speed/core/layers/ip_fusion_ffn.cpp
index 5875bc5bd..6aed77f91 100644
--- a/neural_speed/core/layers/ip_fusion_ffn.cpp
+++ b/neural_speed/core/layers/ip_fusion_ffn.cpp
@@ -684,6 +684,11 @@ bool bestla_fusion_FFN_SiLu_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr
return ffn_3w::bestla_fusion_ffn_f32f32_support(w1ptr, w2ptr, w3ptr, seq, fin, fmid, fout);
}
+bool bestla_fusion_FFN_Gelu_Mul_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr, int seq, int fin, int fmid,
+ int fout) {
+ return ffn_3w::bestla_fusion_ffn_f32f32_support(w1ptr, w2ptr, w3ptr, seq, fin, fmid, fout);
+}
+
void bestla_fusion_FFN_SiLu_f32f32_forward(float* activation, void* w1ptr, void* w2ptr, void* w3ptr, float* tmp1,
float* tmp2, float* output, int seq, int fin, int fmid, int fout,
void* workspace) {
@@ -695,6 +700,16 @@ void bestla_fusion_FFN_SiLu_f32f32_forward(float* activation, void* w1ptr, void*
activation, w1ptr, w2ptr, w3ptr, tmp1, tmp2, output, seq, fin, fmid, fout, workspace, epi_args1, epi_args2);
}
+void bestla_fusion_FFN_Gelu_Mul_f32f32_forward(float* activation, void* w1ptr, void* w2ptr, void* w3ptr, float* tmp1,
+ float* tmp2, float* output, int seq, int fin, int fmid, int fout,
+ void* workspace) {
+ epilogue::gemm::ParamAccumulatorWriteBack epi_args1 = {tmp1, fmid};
+ epilogue::gemm::ParamAccumulatorWriteBack epi_args2 = {output, fout};
+ ffn_3w::bestla_fusion_ffn_f32f32_forward(
+ activation, w1ptr, w2ptr, w3ptr, tmp1, tmp2, output, seq, fin, fmid, fout, workspace, epi_args1, epi_args2);
+}
+
bool bestla_fusion_FFN_GeLu_f32f32_support(void* w1ptr, void* w2ptr, int seq, int fin, int fmid, int fout) {
return ffn_2w::bestla_fusion_ffn_f32f32_support(w1ptr, w2ptr, seq, fin, fmid, fout);
}
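
The new FFN_Gelu_Mul fusion targets Gemma's gated-GeLU feed-forward: gelu(x @ W_gate) multiplied elementwise with x @ W_up, then projected by W_down; the unfused fallback in gemma.cpp below builds exactly this out of ne_mul_mat, ne_gelu and ne_mul. A NumPy reference of the computation the fused path is expected to produce (shapes are illustrative; whether the kernel uses exact or tanh-approximate GELU is an implementation detail):

    import numpy as np

    def gelu(x):
        # tanh approximation of GELU
        return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

    def ffn_gelu_mul(x, w_gate, w_up, w_down):
        # x:      (seq, fin)
        # w_gate: (fin, fmid)   -> layer.ffn[0]
        # w_up:   (fin, fmid)   -> layer.ffn[2]
        # w_down: (fmid, fout)  -> layer.ffn[1]
        return (gelu(x @ w_gate) * (x @ w_up)) @ w_down

    rng = np.random.default_rng(0)
    x = rng.standard_normal((4, 16)).astype(np.float32)
    w_gate = rng.standard_normal((16, 32)).astype(np.float32)
    w_up = rng.standard_normal((16, 32)).astype(np.float32)
    w_down = rng.standard_normal((32, 16)).astype(np.float32)
    y = ffn_gelu_mul(x, w_gate, w_up, w_down)   # (seq, fout)
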
diff --git a/neural_speed/core/ne_bestla.h b/neural_speed/core/ne_bestla.h
index a9ad0a5ea..7d34525fd 100644
--- a/neural_speed/core/ne_bestla.h
+++ b/neural_speed/core/ne_bestla.h
@@ -49,8 +49,13 @@ void bestla_fusion_QKV_f32f32_forward(float* activation, void* wqptr, void* wkpt
unsigned long long bestla_fusion_FFN_f32f32_get_workspace_size(int seq, int fin, int fmid, int fout, void* w1ptr,
void* w2ptr);
-bool bestla_fusion_FFN_SiLu_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr, int seq, int fin, int fmid, int fout);
+bool bestla_fusion_FFN_Gelu_Mul_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr, int seq, int fin, int fmid,
+ int fout);
+void bestla_fusion_FFN_Gelu_Mul_f32f32_forward(float* activation, void* w1ptr, void* w2ptr, void* w3ptr, float* tmp1,
+ float* tmp2, float* output, int seq, int fin, int fmid, int fout,
+ void* workspace);
+bool bestla_fusion_FFN_SiLu_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr, int seq, int fin, int fmid, int fout);
void bestla_fusion_FFN_SiLu_f32f32_forward(float* activation, void* w1ptr, void* w2ptr, void* w3ptr, float* tmp1,
float* tmp2, float* output, int seq, int fin, int fmid, int fout,
void* workspace);
diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c
index 3f892a371..a23df031e 100644
--- a/neural_speed/core/ne_layers.c
+++ b/neural_speed/core/ne_layers.c
@@ -427,6 +427,7 @@ static const char* NE_OP_LABEL[NE_OP_COUNT] = {
"MUL_QKV",
"FFN_SILU",
"FFN_GeLU",
+ "FFN_GeLU_MUL",
"FFN_ADD_GeLU",
"FFN_ID_SILU",
"FLASH_ATTN",
@@ -442,7 +443,7 @@ static const char* NE_OP_LABEL[NE_OP_COUNT] = {
"DEBUG",
};
-static_assert(NE_OP_COUNT == 67, "NE_OP_COUNT != 67");
+static_assert(NE_OP_COUNT == 69, "NE_OP_COUNT != 69");
static const char* NE_OP_SYMBOL[NE_OP_COUNT] = {
"none",
@@ -502,6 +503,7 @@ static const char* NE_OP_SYMBOL[NE_OP_COUNT] = {
"ffn_silu(x)",
"ffn_id_silu(x)",
"ffn_gelu(x)",
+ "ffn_gelu_mul(x)",
"ffn_gelu_with_bias(x)",
"flash_attn(x)",
"flash_attn_kv_update(x)",
@@ -2396,6 +2398,33 @@ struct ne_tensor* ne_ffn_gelu(struct ne_context* ctx, struct ne_tensor* w1, stru
return result;
}
+struct ne_tensor* ne_ffn_gelu_mul(struct ne_context* ctx, struct ne_tensor* w1, struct ne_tensor* w2,
+ struct ne_tensor* w3, struct ne_tensor* src) {
+ NE_ASSERT(ne_are_same_shape(w1, w3));
+ NE_ASSERT(w2->ne[0] == w1->ne[1]);
+
+ bool is_node = false;
+
+ if (src->grad || w1->grad || w2->grad || w3->grad) {
+ is_node = true;
+ }
+
+ const int64_t ne[4] = {w2->ne[1], src->ne[1], src->ne[2], src->ne[3]};
+ struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC);
+ const int64_t tne[4] = {w1->ne[1], src->ne[1], src->ne[2], src->ne[3]};
+ struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC);
+ struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC);
+
+ result->op = NE_OP_MUL_FFN_GELU_MUL;
+ result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL;
+ result->src0 = src;
+ result->src1 = w1;
+ result->opt[0] = w2;
+ result->opt[1] = w3;
+ result->opt[2] = tmp;
+ result->opt[3] = tmp1;
+ return result;
+}
// ne_scale
struct ne_tensor* ne_scale_impl(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* b, bool inplace) {
@@ -7753,6 +7782,25 @@ static void ne_compute_forward_ffn_gelu(const struct ne_compute_params* params,
seq, fin, fmid, fout, params->wdata);
}
+static void ne_compute_forward_ffn_gelu_mul(const struct ne_compute_params* params, const struct ne_tensor* src,
+ const struct ne_tensor* w1, const struct ne_tensor* w2,
+ struct ne_tensor* w3, const struct ne_tensor* tmp, struct ne_tensor* tmp1,
+ struct ne_tensor* dst) {
+ if (params->type == NE_TASK_INIT) {
+ return;
+ }
+
+ if (params->type == NE_TASK_FINALIZE) {
+ return;
+ }
+ const int fin = src->ne[0];
+ const int fout = dst->ne[0];
+ const int fmid = w1->ne[1];
+ const int seq = dst->ne[1];
+ bestla_fusion_FFN_Gelu_Mul_f32f32_forward((float*)src->data, w1->data, w2->data, w3->data, (float*)tmp->data,
+ (float*)tmp1->data, (float*)dst->data, seq, fin, fmid, fout, params->wdata);
+}
+
// ne_compute_forward_scale
static void ne_compute_forward_scale_f32(const struct ne_compute_params* params, const struct ne_tensor* src0,
@@ -10394,6 +10442,10 @@ static void ne_compute_forward(struct ne_compute_params* params, struct ne_tenso
case NE_OP_MUL_FFN_GELU: {
ne_compute_forward_ffn_gelu(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor);
} break;
+ case NE_OP_MUL_FFN_GELU_MUL: {
+ ne_compute_forward_ffn_gelu_mul(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1],
+ tensor->opt[2], tensor->opt[3], tensor);
+ } break;
case NE_OP_SCALE: {
ne_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
} break;
@@ -11276,6 +11328,7 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) {
} break;
case NE_OP_MUL_FFN_SILU:
case NE_OP_MUL_FFN_GELU:
+ case NE_OP_MUL_FFN_GELU_MUL:
case NE_OP_MUL_FFN_ADD_GELU: {
size_t cur = 0;
cur = bestla_fusion_FFN_f32f32_get_workspace_size(node->src0->ne[1], node->src0->ne[0], node->src1->ne[1],
diff --git a/neural_speed/core/ne_layers.h b/neural_speed/core/ne_layers.h
index c6a7c566d..19734c98d 100644
--- a/neural_speed/core/ne_layers.h
+++ b/neural_speed/core/ne_layers.h
@@ -272,6 +272,9 @@ NE_API struct ne_tensor* ne_mul_qkv(struct ne_context* ctx, struct ne_tensor* qw
NE_API struct ne_tensor* ne_ffn_silu(struct ne_context* ctx, struct ne_tensor* w1, struct ne_tensor* w2,
struct ne_tensor* w3, struct ne_tensor* src);
+NE_API struct ne_tensor* ne_ffn_gelu_mul(struct ne_context* ctx, struct ne_tensor* w1, struct ne_tensor* w2,
+ struct ne_tensor* w3, struct ne_tensor* src);
+
NE_API struct ne_tensor* ne_ffn_add_gelu(struct ne_context* ctx, struct ne_tensor* w1, struct ne_tensor* w2,
struct ne_tensor* b1, struct ne_tensor* b2, struct ne_tensor* src);
diff --git a/neural_speed/models/CMakeLists.txt b/neural_speed/models/CMakeLists.txt
index a1edeca6f..58185c6de 100644
--- a/neural_speed/models/CMakeLists.txt
+++ b/neural_speed/models/CMakeLists.txt
@@ -34,6 +34,7 @@ add_model(qwen qwen/qwen.cpp qwen/qwen_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(whisper whisper/whisper.cpp whisper/whisper_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(chatglm chatglm/chatglm.cpp chatglm/chatglm_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(chatglm2 chatglm/chatglm2.cpp chatglm/chatglm2_utils.cpp ${MODEL_UTILS_SOURCE})
+add_model(gemma gemma/gemma.cpp gemma/gemma_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(phi phi/phi.cpp phi/phi_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(stablelm stablelm/stablelm.cpp stablelm/stablelm_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(chatglm3 chatglm/chatglm2.cpp chatglm/chatglm2_utils.cpp ${MODEL_UTILS_SOURCE})
diff --git a/neural_speed/models/gemma/gemma.cpp b/neural_speed/models/gemma/gemma.cpp
new file mode 100644
index 000000000..1bd069e1d
--- /dev/null
+++ b/neural_speed/models/gemma/gemma.cpp
@@ -0,0 +1,416 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "core/data_types.h"
+#include "core/ne.h"
+#include "core/ne_layers.h"
+#include "core/ne_bestla.h"
+#include "core/layers/mha_dense.h"
+#include "models/model_utils/model_config.h"
+#include "models/model_utils/model_utils.h"
+#include "models/model_utils/util.h"
+
+// feed-forward network
+struct ne_tensor* gemma_ff(const model_layer& layer, const int batch_size, const int N, ne_context* ctx0,
+ ne_tensor* inp) {
+ struct ne_tensor* cur = inp;
+ if (bestla_fusion_FFN_Gelu_Mul_f32f32_support(layer.ffn[0]->data, layer.ffn[1]->data, layer.ffn[2]->data, N,
+ cur->ne[0], layer.ffn[0]->ne[1], layer.ffn[1]->ne[1])) {
+ cur = ne_ffn_gelu_mul(ctx0, layer.ffn[0], layer.ffn[1], layer.ffn[2], cur);
+ } else {
+ struct ne_tensor* cur_1 = ne_mul_mat(ctx0, layer.ffn[0], cur);
+
+ struct ne_tensor* cur_2 = ne_mul_mat(ctx0, layer.ffn[2], cur);
+
+ // GELU activation
+ cur_1 = ne_gelu(ctx0, cur_1);
+
+ // projection
+ // cur = proj_w*cur + proj_b
+ cur = ne_mul(ctx0, cur_1, cur_2);
+
+ cur = ne_mul_mat(ctx0, layer.ffn[1], cur);
+ }
+
+ return cur;
+}
+
+// evaluate the transformer
+//
+// - lctx: model context
+// - tokens: new batch of tokens to process
+// - n_past: the offset to which the kv is cached to
+// - n_total: the number of tokens evaluated so far (including evicted tokens if there is any)
+// - n_threads: number of threads to use
+//
+static bool gemma_model_eval_internal(model_context* ctx, const model_input* inputs, const int n_input,
+ const int n_threads) {
+ const int64_t t_start_us = ne_time_us();
+ model_context& lctx = *ctx;
+
+ // static batching for now
+ const int N = inputs->n_tokens;
+ const int n_past = inputs->n_past;
+ const int n_total = inputs->n_total;
+ const bool shift_roped_k = lctx.shift_roped_k;
+ const bool is_ring_full = shift_roped_k && n_total > n_past;
+ NE_ASSERT(("Shift-RoPE-K to be implemented for the neox-mode RoPE!", !is_ring_full));
+ const int batch_size = lctx.batch_size;
+ MODEL_ASSERT(batch_size == n_input);
+ const int kv_n_ctx_block = lctx.kv_n_ctx_block;
+
+ const auto& model = lctx.model;
+ const auto& hparams = model.hparams;
+
+ const auto& kv_self = model.kv_self;
+
+ MODEL_ASSERT(!!kv_self.ctx);
+
+ const int n_embd = hparams.n_embd;
+ const int n_layer = hparams.n_layer;
+ const int n_ctx = lctx.n_ctx;
+ const int n_keep = lctx.n_keep;
+ const int n_head = hparams.n_head;
+ const int n_head_kv = hparams.n_head_kv;
+ const int n_vocab = hparams.n_vocab;
+ const int n_rot = hparams.n_embd_head_k;
+ const int head_dim = hparams.n_embd_head_k;
+ const int n_gqa_embd = head_dim * n_head_kv;
+
+ auto& mem_per_token = lctx.mem_per_token;
+ auto& buf_compute = lctx.buf_compute;
+
+ struct ne_init_params params = {
+ /*.mem_size =*/buf_compute.size,
+ /*.mem_buffer =*/buf_compute.addr,
+ /*.no_alloc =*/false,
+ };
+
+ struct ne_context* ctx0 = ne_init(params);
+
+ // for big prompts, if BLAS is enabled, it is better to use only one thread
+ // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
+ ne_cgraph gf = {};
+ gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;
+ const bool run_mha_reordered = kv_self.k->type == NE_TYPE_BTLA;
+ kv_cache_info_t kv_cache_info = {};
+ if (run_mha_reordered) {
+ NE_ASSERT(("kv cache should be the same dtype", kv_self.v->type == NE_TYPE_BTLA));
+ attn_shape_t attn_shape = {
+ /* .batch_size = */ 1,
+ /* .head_num = */ n_head,
+ /* .heads_kv = */ n_head_kv,
+ /* .head_size = */ head_dim,
+ /* .sl_q = */ N, // Note: make sure that bestla reordered attn supports next token inference
+ /* .sl_kv = */ n_past + N,
+ };
+
+ NE_ASSERT(("bestla managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
+ bestla_reordered_attn_fp32_support(&attn_shape)));
+ kv_shape_t kv_shape{
+ /* .heads_kv = */ static_cast<uint32_t>(n_head_kv),
+ /* .head_size = */ static_cast<uint32_t>(head_dim),
+ /* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
+ };
+ bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
+ }
+ struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N * batch_size);
+ ne_set_name(embd, "embd");
+
+ for (int i = 0; i < batch_size; ++i) {
+ memcpy(static_cast<model_token*>(embd->data) + i * N, (inputs + i)->tokens, N * ne_element_size(embd));
+ }
+
+ struct ne_tensor* inpL = ne_get_rows(ctx0, model.others[0], embd);
+ inpL = ne_scale(ctx0, inpL, ne_new_f32(ctx0, sqrtf(static_cast<float>(n_embd))));
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ne_tensor* cur;
+
+ lctx.use_buf(ctx0, 0);
+
+ {
+ // RMS
+ {
+ cur = ne_rms_norm(ctx0, inpL, hparams.norm_eps);
+ cur = ne_mul(ctx0, cur, model.layers[il].norm[0]);
+ ne_set_name(cur, "input_norm");
+ }
+
+ // compute QKV
+ struct ne_tensor* Kcur =
+ ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_dim, n_head_kv, N);
+ ne_set_name(Kcur, "Kcur_matmul");
+ struct ne_tensor* Qcur =
+ ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_dim, n_head, N);
+ ne_set_name(Qcur, "Qcur_matmul");
+ struct ne_tensor* Vcur =
+ ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), head_dim, n_head_kv, N);
+ ne_set_name(Vcur, "Vcur_matmul");
+
+ // using mode = 2 for GPT-NeoX mode
+ Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
+ ne_set_name(Qcur, "Qcur");
+ Qcur = ne_scale_inplace(ctx0, Qcur, ne_new_f32(ctx0, 1.0f / sqrt(static_cast<float>((n_gqa_embd) / n_head_kv))));
+
+ Kcur = ne_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
+ ne_set_name(Kcur, "kcur");
+ const float attn_scale = 1.0f;
+ // store key and value to memory
+ if (!run_mha_reordered) {
+ {
+ std::vector<ne_tensor*> Kcur_bs(batch_size);
+ std::vector<ne_tensor*> Vcur_bs(batch_size);
+ std::vector<ne_tensor*> k_bs(batch_size);
+ std::vector<ne_tensor*> v_bs(batch_size);
+ for (int i = 0; i < batch_size; ++i) {
+ // batch K
+ Kcur_bs[i] =
+ ne_permute(ctx0,
+ ne_view_4d(ctx0, Kcur, head_dim, n_head_kv, N, 1, ne_element_size(Kcur) * head_dim,
+ ne_element_size(Kcur) * n_gqa_embd, ne_element_size(Kcur) * n_gqa_embd * N,
+ i * ne_element_size(Kcur) * n_gqa_embd * N),
+ 0, 2, 1, 3);
+ k_bs[i] = ne_view_4d(
+ ctx0, kv_self.k, head_dim, N, n_head_kv, 1, ne_element_size(kv_self.k) * head_dim,
+ ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_gqa_embd * n_ctx,
+ ((il * n_ctx) * ne_element_size(kv_self.k) * n_gqa_embd * kv_n_ctx_block +
+ i * n_ctx * n_gqa_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k)));
+
+ // batch V
+ Vcur_bs[i] =
+ ne_permute(ctx0,
+ ne_reshape_4d(ctx0,
+ ne_view_2d(ctx0, Vcur, n_gqa_embd, N, ne_element_size(Vcur) * n_gqa_embd,
+ i * ne_element_size(Vcur) * n_gqa_embd * N),
+ head_dim, n_head_kv, N, 1),
+ 1, 2, 0, 3);
+ v_bs[i] = ne_view_4d(
+ ctx0, kv_self.v, N, head_dim, n_head_kv, 1, n_ctx * ne_element_size(kv_self.v),
+ n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_gqa_embd,
+ ((il * n_ctx) * ne_element_size(kv_self.v) * n_gqa_embd * kv_n_ctx_block +
+ i * n_ctx * n_gqa_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v)));
+ ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i]));
+ ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i]));
+ }
+ }
+ // Q = Qcur.contiguous().view(n_gqa_embd/n_head, n_head, N).permute(0, 2, 1, 3)
+ struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3);
+
+ // K = Kmem.view(n_gqa_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
+ struct ne_tensor* K = ne_view_4d(
+ ctx0, kv_self.k, head_dim, n_past + N, n_head_kv, batch_size, ne_element_size(kv_self.k) * head_dim,
+ ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_gqa_embd * n_ctx,
+ il * n_ctx * ne_element_size(kv_self.k) * n_gqa_embd * kv_n_ctx_block);
+
+ // K * Q
+ struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q);
+
+ // KQ_scaled = KQ / sqrt(n_gqa_embd/n_head)
+ struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f));
+
+ // KQ_masked = mask_past(KQ_scaled)
+ struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+
+ // KQ = soft_max(KQ_masked)
+ struct ne_tensor* KQ_soft_max = ne_soft_max(ctx0, KQ_masked);
+
+ // V_trans = Vmem.view(n_gqa_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
+ struct ne_tensor* V =
+ ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head_kv, batch_size, n_ctx * ne_element_size(kv_self.v),
+ n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_gqa_embd,
+ il * n_ctx * ne_element_size(kv_self.v) * n_gqa_embd * kv_n_ctx_block);
+
+ // KQV = transpose(V) * KQ_soft_max
+ struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max);
+
+ // KQV_merged = KQV.permute(0, 2, 1, 3)
+ struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3);
+
+ // cur = KQV_merged.contiguous().view(n_gqa_embd, N)
+ cur = ne_cpy(ctx0, KQV_merged,
+ ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_dim * n_head, N * batch_size, NE_SIZE_CALC));
+ } else {
+ const auto seq_kv = n_past + N;
+ const auto k_size = kv_cache_info.k_bytes;
+ const auto v_size = kv_cache_info.v_bytes;
+
+ // store key and value to memory
+ {
+ const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor
+ head_dim, n_ctx, n_head_kv, // ne
+ 0, 0, // nb (bestla managed)
+ il * k_size); // offset
+ ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past, false));
+ const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor
+ head_dim, n_ctx, n_head_kv, // ne
+ 0, 0, // nb (bestla managed)
+ il * v_size); // offset
+ ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past, false));
+ }
+
+ struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3);
+ ne_set_name(Q, "Q");
+
+ struct ne_tensor* K =
+ ne_view_3d(ctx0, kv_self.k, // tensor
+ head_dim, seq_kv, n_head_kv, // ne
+ kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (bestla managed)
+ il * k_size); // offset
+ *reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout; // use nb0 for layout
+ ne_set_name(K, "K");
+ struct ne_tensor* V =
+ ne_view_3d(ctx0, kv_self.v, // tensor
+ seq_kv, head_dim, n_head_kv, // ne
+ kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (bestla managed)
+ il * v_size); // offset
+ *reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout; // use nb0 for layout
+ ne_set_name(V, "V");
+
+ ne_attn_flags_t attn_flags = 0;
+ if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases
+ struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags);
+ cur = ne_view_2d(ctx0, KQV_Out, head_dim * n_head, N, head_dim * n_head * ne_element_size(KQV_Out), 0);
+ }
+ // projection
+ { cur = ne_mul_mat(ctx0, model.layers[il].attn[3], cur); }
+ }
+ lctx.use_buf(ctx0, 1);
+
+ cur = ne_add(ctx0, cur, inpL);
+ inpL = cur;
+
+ // feed-forward block: apply the pre-FFN RMS norm to the post-attention hidden state
+ // (inpL and cur refer to the same tensor at this point)
+ {
+ cur = ne_rms_norm(ctx0, cur, hparams.norm_eps);
+ cur = ne_mul(ctx0, cur, model.layers[il].norm[1]);
+
+ // struct ne_tensor* cur2 = ne_mul(ctx0, cur, model.layers[il].norm[1]);
+ // cur = ne_add(ctx0,cur, cur2);
+ }
+ cur = gemma_ff(model.layers[il], N, batch_size, ctx0, cur);
+
+ // input for next layer
+ inpL = ne_add(ctx0, cur, inpL);
+ }
+
+ lctx.use_buf(ctx0, 0);
+ // used at the end to optionally extract the embeddings
+ struct ne_tensor* embeddings = nullptr;
+ // norm
+ {
+ inpL = ne_rms_norm(ctx0, inpL, hparams.norm_eps);
+ inpL = ne_mul(ctx0, inpL, model.others[1]);
+ }
+ lctx.use_buf(ctx0, -1);
+ // lm_head: project to the vocabulary with the tied token-embedding matrix (model.others[0])
+ { inpL = ne_mul_mat(ctx0, model.others[0], inpL); }
+
+ // logits -> probs
+ // inpL = ne_soft_max_inplace(ctx0, inpL);
+
+ // run the computation
+ ne_build_forward_expand(&gf, inpL);
+ ne_graph_compute(ctx0, &gf);
+
+ if (ns_log_level() == 0 || ns_log_level() == 2) {
+ ne_graph_profiling(&gf);
+ }
+
+ // update kv token count
+ lctx.model.kv_self.n = n_past + N;
+
+ // extract logits
+ {
+ auto& logits_out = lctx.logits;
+
+ size_t bs_stride = n_vocab * N;
+ if (lctx.logits_all) {
+ logits_out.resize(n_vocab * N * batch_size);
+ for (int i = 0; i < batch_size; ++i) {
+ memcpy(logits_out.data() + i * bs_stride, reinterpret_cast<float*>(ne_get_data(inpL)) + (i * bs_stride),
+ sizeof(float) * n_vocab * N);
+ }
+ } else {
+ // return result for just the last token
+ logits_out.resize(n_vocab * batch_size);
+ for (int i = 0; i < batch_size; ++i) {
+ memcpy(logits_out.data() + (i * n_vocab),
+ reinterpret_cast<float*>(ne_get_data(inpL)) + (i * bs_stride) + (n_vocab * (N - 1)),
+ sizeof(float) * n_vocab);
+ }
+ }
+ }
+
+ // extract embeddings
+ if (!lctx.embedding.empty()) {
+ auto& embedding_out = lctx.embedding;
+
+ embedding_out.resize(n_embd);
+ memcpy(embedding_out.data(), reinterpret_cast<float*>(ne_get_data(embeddings)) + (n_embd * (N - 1)),
+ sizeof(float) * n_embd);
+ }
+
+ if (mem_per_token == 0) {
+ mem_per_token = ne_used_mem(ctx0) / N;
+ }
+
+ ne_free(ctx0);
+
+ // measure the performance only for the single-token evals
+ int64_t time_interval = ne_time_us() - t_start_us;
+ if (N == 1) {
+ lctx.t_eval_us += time_interval;
+ lctx.n_eval++;
+ } else if (N > 1) {
+ lctx.t_p_eval_us += time_interval;
+ lctx.n_p_eval += N;
+ }
+ lctx.eval_times.push_back(time_interval);
+
+ return true;
+}
+
+int model_eval(struct model_context* ctx, const model_input* inputs, const int n_input, int n_threads) {
+ if (!gemma_model_eval_internal(ctx, inputs, n_input, n_threads)) {
+ fprintf(stderr, "%s: failed to eval\n", __func__);
+ return 1;
+ }
+
+ // get a more accurate load time, upon first eval
+
+ if (!ctx->has_evaluated_once) {
+ ctx->t_load_us = ne_time_us() - ctx->t_start_us;
+ ctx->has_evaluated_once = true;
+ }
+
+ return 0;
+}
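
Two Gemma-specific scaling choices in the graph above are easy to miss: the token embeddings are multiplied by sqrt(n_embd) right after the lookup, and the query tensor is pre-scaled by 1/sqrt(head_dim) so that both the flash-attention call (attn_scale = 1.0) and the non-reordered path (KQ scaled by 1.0) can skip the usual logit scaling. A tiny NumPy sketch of why pre-scaling Q is equivalent:

    import numpy as np

    rng = np.random.default_rng(0)
    head_dim, q_len, kv_len = 8, 3, 5
    Q = rng.standard_normal((q_len, head_dim)).astype(np.float32)
    K = rng.standard_normal((kv_len, head_dim)).astype(np.float32)

    scale = 1.0 / np.sqrt(head_dim)
    logits_after = (Q @ K.T) * scale    # usual attention formulation
    logits_before = (Q * scale) @ K.T   # gemma.cpp: ne_scale_inplace on Qcur, attn_scale = 1.0
    assert np.allclose(logits_after, logits_before, atol=1e-5)
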
diff --git a/neural_speed/models/gemma/gemma.h b/neural_speed/models/gemma/gemma.h
new file mode 100644
index 000000000..cc733ae45
--- /dev/null
+++ b/neural_speed/models/gemma/gemma.h
@@ -0,0 +1,60 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GEMMA_H
+#define GEMMA_H
+
+#include "models/model_utils/model_files.h"
+#include "models/model_utils/model_types.h"
+
+enum gemma_model {
+ GEMMA_2B,
+ GEMMA_7B,
+};
+
+static const model_scratch gemma_mem_req(int n_layers, float enlarge_scale = 1.0f) {
+ switch (n_layers) {
+ case 18:
+ return {
+ static_cast<unsigned long long>(enlarge_scale * 1024) * MB,
+ static_cast<unsigned long long>(enlarge_scale * 1024) * MB,
+ static_cast<unsigned long long>(enlarge_scale * 1608) * MB,
+ };
+ case 28:
+ return {
+ static_cast<unsigned long long>(enlarge_scale * 1024) * MB,
+ static_cast<unsigned long long>(enlarge_scale * 1024) * MB,
+ static_cast<unsigned long long>(enlarge_scale * 1608) * MB,
+ };
+ default:
+ MODEL_ASSERT(false);
+ }
+}
+
+class Gemma : public IModel {
+ private:
+ model_archs arch = MODEL_GEMMA;
+ std::unique_ptr<model_model_loader> ml;
+ uint32_t n_layer, n_embd, n_ff, n_vocab, n_head, n_head_kv, n_expert, n_expert_used, n_embd_head_k;
+ int n_gpu_layer;
+ bool use_mmap, use_mlock, vocab_only;
+ model_scratch scratch;
+
+ public:
+ void init(const char* path_model, model_context* ctx, int n_gpu_layers, bool use_mmap_, bool use_mlock_,
+ bool vocab_only_) override;
+ void load(model_context* ctx, model_progress_callback progress_callback, void* progress_callback_user_data) override;
+};
+
+#endif // GEMMA_H
diff --git a/neural_speed/models/gemma/gemma_utils.cpp b/neural_speed/models/gemma/gemma_utils.cpp
new file mode 100644
index 000000000..d7c2c4637
--- /dev/null
+++ b/neural_speed/models/gemma/gemma_utils.cpp
@@ -0,0 +1,236 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <algorithm>
+#include <cassert>
+#include <cinttypes>
+#include <cstring>
+#include <exception>
+#include <fstream>
+#include <iterator>
+#include <memory>
+#include <random>
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "core/data_types.h"
+#include "core/ne.h"
+#include "core/ne_layers.h"
+#include "models/gemma/gemma.h"
+#include "models/model_utils/model_config.h"
+#include "models/model_utils/model_files.h"
+#include "models/model_utils/model_types.h"
+#include "models/model_utils/quant_utils.h"
+#include "models/model_utils/util.h"
+#include "models/models.h"
+
+void model_load_internal(const std::string& fname, model_archs arch, model_context* ctx, int n_gpu_layers,
+ bool use_mmap, bool use_mlock, bool vocab_only, model_progress_callback progress_callback,
+ void* progress_callback_user_data) {
+ std::unique_ptr<Gemma> ms(new Gemma());
+ ms->init(fname.c_str(), ctx, n_gpu_layers, use_mmap, use_mlock, vocab_only);
+ ms->load(ctx, progress_callback, progress_callback_user_data);
+
+ model_context& lctx = *ctx;
+ lctx.support_bestla_kv = true;
+}
+
+void Gemma::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bool use_mmap_, bool use_mlock_,
+ bool vocab_only_) {
+ model_context& lctx = *ctx;
+ n_gpu_layer = n_gpu_layer_;
+ use_mmap = use_mmap_;
+ use_mlock = use_mlock_;
+ vocab_only = vocab_only_;
+ auto& model = lctx.model;
+ ml.reset(new model_model_loader(path_model, use_mmap, vocab_only));
+ lctx.vocab = std::move(ml->file_loaders.at(0)->vocab);
+ model.hparams = ml->file_loaders.at(0)->hparams;
+ model_file_version file_version = ml->file_loaders.at(0)->file_version;
+ auto& hparams = model.hparams;
+ n_ff = hparams.ffn_hidden_size;
+ fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
+ fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.max_seq_len);
+ fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
+ fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
+ fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
+ fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
+ fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
+ fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
+ fprintf(stderr, "%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
+ fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
+ fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
+ n_embd = hparams.n_embd;
+ n_vocab = hparams.n_vocab;
+ n_layer = hparams.n_layer;
+ n_head_kv = hparams.n_head_kv;
+ n_embd_head_k = hparams.n_embd_head_k;
+ n_head = hparams.n_head;
+ n_expert = hparams.n_experts;
+ n_expert_used = hparams.n_experts_used;
+ scratch = gemma_mem_req(n_layer, lctx.scratch_size_ratio);
+ model.scratchs = scratch;
+}
+
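+// The offload backend is mapped to the CPU backend here, so layers counted as "GPU layers" are still evaluated on the CPU.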
+#define MODEL_BACKEND_OFFLOAD NE_BACKEND_CPU
+void Gemma::load(model_context* ctx, model_progress_callback progress_callback, void* progress_callback_user_data) {
+ model_context& lctx = *ctx;
+ auto& model = lctx.model;
+ auto& ne_ctx = model.ctx;
+ size_t ctx_size;
+ size_t mmapped_size;
+ ml->calc_sizes(&ctx_size, &mmapped_size);
+ fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+
+ // create the ne context
+ lctx.model.buf.resize(ctx_size);
+ if (use_mlock) {
+ lctx.model.mlock_buf.init(lctx.model.buf.addr);
+ lctx.model.mlock_buf.grow_to(lctx.model.buf.size);
+ }
+
+ struct ne_init_params params = {
+ /*.mem_size =*/lctx.model.buf.size,
+ /*.mem_buffer =*/lctx.model.buf.addr,
+ /*.no_alloc =*/ml->use_mmap,
+ };
+
+ model.ctx = ne_init(params);
+ if (!model.ctx) {
+ throw format("ne_init() failed");
+ }
+
+ ml->ne_ctx = ne_ctx;
+
+ const int i_gpu_start = n_layer - n_gpu_layer;
+ model.layers.resize(n_layer);
+ size_t vram_total = 0;
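+ // Two tensor-name layouts are supported: GGUF-style names ("token_embd.*", "blk.N.*") and HF-style names ("model.embed_tokens.*", "model.layers.N.*").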
+ if (ml->verify_tensor("token_embd.weight")) {
+ model.others[0] = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+ model.others[1] = ml->get_tensor("output_norm.weight", {n_embd}, NE_BACKEND_CPU);
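+ // Gemma ties the output projection to the token embedding, so reuse the embedding tensor as the lm_head weight.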
+ model.others[2] = model.others[0];
+
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
+ auto& layer = model.layers[i];
+ std::string layers_i = "blk." + std::to_string(i);
+
+ // attention norm
+ layer.norm[0] = ml->get_tensor(layers_i + ".attn_norm.weight", {n_embd}, backend);
+
+ // qkv GEMM
+ layer.attn[0] = ml->get_tensor(layers_i + ".attn_q.weight", {n_embd, n_embd_head_k * n_head}, backend);
+ layer.attn[1] = ml->get_tensor(layers_i + ".attn_k.weight", {n_embd, n_embd_head_k * n_head_kv}, backend);
+ layer.attn[2] = ml->get_tensor(layers_i + ".attn_v.weight", {n_embd, n_embd_head_k * n_head_kv}, backend);
+ layer.attn[3] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd_head_k * n_head, n_embd}, backend);
+
+ // ffn norm
+ layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
+
+ // ffn GEMM
+ if (ml->verify_tensor(layers_i + ".ffn_gate.weight")) {
+ NE_ASSERT(n_expert == 0);
+ NE_ASSERT(n_expert_used == 0);
+ layer.ffn[0] = ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, n_ff}, backend);
+ layer.ffn[1] = ml->get_tensor(layers_i + ".ffn_down.weight", {n_ff, n_embd}, backend);
+ layer.ffn[2] = ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, n_ff}, backend);
+ }
+ if (backend != NE_BACKEND_CPU) {
+ vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) +
+ ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.norm[1]) +
+ ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]);
+ }
+ }
+ } else {
+ model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+ model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
+ model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
+ auto& layer = model.layers[i];
+ std::string layers_i = "model.layers." + std::to_string(i);
+
+ // attention norm
+ layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);
+
+ // qkv GEMM
+ layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd_head_k * n_head}, backend);
+ layer.attn[1] =
+ ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd_head_k * n_head_kv}, backend);
+ layer.attn[2] =
+ ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd_head_k * n_head_kv}, backend);
+ layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd_head_k * n_head, n_embd}, backend);
+
+ // ffn norm
+ layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);
+
+ // ffn GEMM
+ layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight", {n_embd, n_ff}, backend);
+ layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight", {n_ff, n_embd}, backend);
+ layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, n_ff}, backend);
+ if (backend != NE_BACKEND_CPU) {
+ vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) +
+ ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.norm[1]) +
+ ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]);
+ }
+ }
+ }
+
+ // print memory requirements
+ // this is the total memory required to run the inference
+ const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
+ scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
+
+ (void)n_gpu_layer;
+
+ // populate `tensors_by_name`
+ for (model_load_tensor& lt : ml->tensors_map.tensors) {
+ model.tensors_by_name.emplace_back(lt.name, lt.ne_tensor);
+ }
+
+ ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : nullptr);
+
+ if (progress_callback) {
+ progress_callback(1.0f, progress_callback_user_data);
+ }
+
+ model.mapping = std::move(ml->mapping);
+}
+
+#undef MODEL_BACKEND_OFFLOAD
+
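+// Quantization policy: token-embedding tensors are kept at q8, 2-D "*.weight" tensors follow the global quant config, and all other tensors are left unquantized.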
+class gemma_quant_layer : public quant_layer_base {
+ public:
+ quant_params_internal get_layer_config(std::string layername, std::vector<int64_t> ne, ne_type type) override {
+ bool quantize = layername.rfind("weight") == layername.size() - 6;
+ if ((layername.find("embedding") != std::string::npos) ||
+ (layername == "token_embd.weight" || layername == "model.embed_tokens.weight")) {
+ // special layer process, can be loaded by config file
+ return quant_params_internal{quant_bits::q8}; // q80
+ }
+ quantize &= (ne.size() == 2);
+ if (quantize) {
+ return mGCfg; // use global quant config
+ } else {
+ return quant_params_internal{quant_bits::count}; // non-quant
+ }
+ }
+};
+REGISTER_QUANT_LAYER_CLASS(gemma);
diff --git a/neural_speed/models/model_utils/gguf.h b/neural_speed/models/model_utils/gguf.h
index 22d0435ca..791fb97a9 100644
--- a/neural_speed/models/model_utils/gguf.h
+++ b/neural_speed/models/model_utils/gguf.h
@@ -232,6 +232,7 @@ enum llm_arch {
LLM_ARCH_CHATGLM2,
LLM_ARCH_CHATGLM3,
LLM_ARCH_PHI,
+ LLM_ARCH_GEMMA,
LLM_ARCH_QWEN2,
LLM_ARCH_UNKNOWN,
};
@@ -253,6 +254,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {{LLM_ARCH_LLAMA, "llama
{LLM_ARCH_CHATGLM2, "chatglm2"},
{LLM_ARCH_CHATGLM3, "chatglm3"},
{LLM_ARCH_PHI, "phi"},
+ {LLM_ARCH_GEMMA, "gemma"},
{LLM_ARCH_QWEN2, "qwen2"}};
struct gguf_tensor_info {
@@ -432,6 +434,8 @@ enum llm_kv {
LLM_KV_ATTENTION_HEAD_COUNT_KV,
LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
LLM_KV_ATTENTION_CLAMP_KQV,
+ LLM_KV_ATTENTION_KEY_LENGTH,
+ LLM_KV_ATTENTION_VALUE_LENGTH,
LLM_KV_ATTENTION_LAYERNORM_EPS,
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
LLM_KV_NUM_EXPERTS,
@@ -486,6 +490,8 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv"},
{LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias"},
{LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv"},
+ {LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length"},
+ {LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length"},
{LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon"},
{LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon"},
diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h
index b702afc52..b586d0ea5 100644
--- a/neural_speed/models/model_utils/model_files.h
+++ b/neural_speed/models/model_utils/model_files.h
@@ -911,6 +911,8 @@ struct gguf_loader {
GGUF_GET_KEY(ctx_gguf, hparams.n_experts, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_NUM_EXPERTS));
GGUF_GET_KEY(ctx_gguf, hparams.n_experts_used, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
kv(LLM_KV_NUM_EXPERTS_USED));
+ GGUF_GET_KEY(ctx_gguf, hparams.n_embd_head_k, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
+ kv(LLM_KV_ATTENTION_KEY_LENGTH));
GGUF_GET_KEY(ctx_gguf, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_BLOCK_COUNT));
GGUF_GET_KEY(ctx_gguf, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
@@ -1115,11 +1117,14 @@ struct model_file_loader {
hparams.inner_hidden_size = file.read_u32();
hparams.n_experts = file.read_u32();
hparams.n_experts_used = file.read_u32();
+ hparams.n_embd_head_k = file.read_u32();
printf("%-16s %d.hparams.inner_hidden_size = %-30d\n", __func__, count++, hparams.inner_hidden_size);
printf("%-16s %d.hparams.n_experts = %-30d\n", __func__, count++, hparams.n_experts);
printf("%-16s %d.hparams.n_experts_used = %-30d\n", __func__, count++, hparams.n_experts_used);
+ printf("%-16s %d.hparams.n_embd_head_k = %-30d\n", __func__, count++, hparams.n_embd_head_k);
file.read_raw(&hparams.norm_eps, sizeof(float));
+
file.read_raw(&hparams.freq_base, sizeof(float));
file.read_raw(&hparams.freq_scale, sizeof(float));
printf("%-16s %d.hparams.norm_eps = %-30f\n", __func__, count++, hparams.norm_eps);
@@ -1262,6 +1267,7 @@ struct model_file_saver {
file.write_u32(hparams.inner_hidden_size);
file.write_u32(hparams.n_experts);
file.write_u32(hparams.n_experts_used);
+ file.write_u32(hparams.n_embd_head_k);
file.write_raw(&hparams.norm_eps, sizeof(float));
file.write_raw(&hparams.freq_base, sizeof(float));
diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h
index 3575a7e2b..454fbbb34 100644
--- a/neural_speed/models/model_utils/model_types.h
+++ b/neural_speed/models/model_utils/model_types.h
@@ -84,6 +84,7 @@ enum model_archs {
MODEL_CHATGLM,
MODEL_QWEN,
MODEL_PHI,
+ MODEL_GEMMA,
MODEL_STABLELM,
MODEL_WHISPER
};
@@ -144,6 +145,7 @@ struct model_hparams {
int32_t inner_hidden_size = 0;
uint32_t n_experts = 0;
uint32_t n_experts_used = 0;
+ uint32_t n_embd_head_k = 0;
float rope_scaling_factor = 0.0f;
int32_t original_max_position_embeddings = 0;
@@ -486,7 +488,8 @@ class model_name_to_arch {
{"falcon", MODEL_FALCON}, {"bloom", MODEL_BLOOM}, {"chatglm2", MODEL_CHATGLM2},
{"chatglm", MODEL_CHATGLM}, {"baichuan", MODEL_BAICHUAN}, {"mistral", MODEL_LLAMA},
{"qwen", MODEL_QWEN}, {"phi", MODEL_PHI}, {"stablelm", MODEL_STABLELM},
- {"whisper", MODEL_WHISPER}, {"chatglm3", MODEL_CHATGLM3}, {"mixtral", MODEL_LLAMA}};
+ {"whisper", MODEL_WHISPER}, {"chatglm3", MODEL_CHATGLM3}, {"mixtral", MODEL_LLAMA},
+ {"gemma", MODEL_GEMMA}};
};
#ifdef __cplusplus
diff --git a/neural_speed/models/model_utils/model_utils.cpp b/neural_speed/models/model_utils/model_utils.cpp
index bb5b65fa7..266f9486b 100644
--- a/neural_speed/models/model_utils/model_utils.cpp
+++ b/neural_speed/models/model_utils/model_utils.cpp
@@ -63,7 +63,8 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c
const bool shift_roped_k, model_struct* model) {
const auto n_layer = hparams.n_layer;
auto heads_kv = hparams.n_head_kv > 0 ? hparams.n_head_kv : hparams.n_head;
- const auto head_size = hparams.n_embd / hparams.n_head;
+ const auto head_size = hparams.n_embd_head_k == 0 ? hparams.n_embd / hparams.n_head : hparams.n_embd_head_k;
+
#ifdef NS_TP_MODEL
// when use TP, cached kv will also have smaller size
parallel_context* p_ctx = init_parallel_context();
@@ -103,7 +104,7 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c
auto& k_cache = model->layers[il].k_cache;
auto& v_cache = model->layers[il].v_cache;
if (wtype == NE_TYPE_F16) { // chatglm does not support fp32 kv-cache in original impl of chatglm_util.cpp
- const int head_size = hparams.n_embd / hparams.n_head;
+ const auto head_size = hparams.n_embd_head_k == 0 ? hparams.n_embd / hparams.n_head : hparams.n_embd_head_k;
const int heads_kv = hparams.multi_query_group_num > 0 ? hparams.multi_query_group_num : hparams.n_head;
k_cache = d_ne_new_tensor_4d(model->ctx, NE_TYPE_F16, head_size, n_ctx, heads_kv, batch_size * beam_size);
v_cache = d_ne_new_tensor_4d(model->ctx, NE_TYPE_F16, n_ctx, head_size, heads_kv, batch_size * beam_size);
@@ -1833,7 +1834,7 @@ static void bestla_model_kv_cache_seq_cpy(struct model_context* ctx, const model
const auto& kv_self = ctx->model.kv_self;
const auto& hparams = ctx->model.hparams;
int heads_kv = hparams.multi_query_group_num > 0 ? hparams.multi_query_group_num : hparams.n_head;
- const int head_size = hparams.n_embd / hparams.n_head;
+ const auto head_size = hparams.n_embd_head_k == 0 ? hparams.n_embd / hparams.n_head : hparams.n_embd_head_k;
#ifdef NS_TP_MODEL
// when use TP, cached kv will also have smaller size
parallel_context* p_ctx = init_parallel_context();
@@ -1857,7 +1858,7 @@ static void bestla_model_kv_cache_seq_cpy(struct model_context* ctx, const model
/* .src = */ nullptr,
/* .dst = */ nullptr,
/* .heads_kv = */ heads_kv,
- /* .head_size = */ head_size,
+ /* .head_size = */ static_cast<int>(head_size),
/* .seq_off = */ p0,
/* .seq_size = */ p1 - p0,
/* .seq_max = */ n_ctx,
diff --git a/tests/model-test/cpp_graph_inference.sh b/tests/model-test/cpp_graph_inference.sh
index 7702f3db6..97faa47a0 100644
--- a/tests/model-test/cpp_graph_inference.sh
+++ b/tests/model-test/cpp_graph_inference.sh
@@ -159,6 +159,7 @@ model_name_map["phi2"]="microsoft/phi-2"
model_name_map["stablelm"]="stabilityai/stablelm-2-1_6b"
model_name_map["qwen-1_5"]="Qwen/Qwen1.5-7B-Chat"
model_name_map["mixtral"]="mistralai/Mixtral-8x7B-Instruct-v0.1"
+model_name_map["gemma-2b"]="google/gemma-2b-it"
model_name_map["mixtral-gptq"]="Mixtral-8x7B-Instruct-v0.1-GPTQ"
model_name_map["qwen1.5-gptq"]="Qwen/Qwen1.5-7B-Chat-GPTQ"
model_name_map["qwen-gptq"]="TheBloke/Qwen-7B-Chat-GPTQ"
@@ -167,6 +168,7 @@ model_name_map["falcon7b-gptq"]="Falcon-7B-Instruct-GPTQ"
model_name_map["baichuan13b-gptq"]="Baichuan2-13B-Chat-GPTQ"
model_name_map["mistral-gptq"]="TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
+
function main() {
conda_env="$1"
model="$2"
@@ -297,6 +299,10 @@ function main() {
quant_script="./build/bin/quant_mixtral"
convert_script="${convert_script}/convert_mixtral.py"
infer_cmd="./build/bin/run_mixtral"
+ elif [[ "${model}" == "gemma-2b" ]]; then
+ quant_script="./build/bin/quant_gemma"
+ convert_script="${convert_script}/convert_gemma.py"
+ infer_cmd="./build/bin/run_gemma"
elif [[ "${model}" == *"-gptq" ]]; then
infer_cmd="python $working_dir/scripts/python_api_example_for_gptq.py ${model_path}"
precision_list+=("default")
@@ -440,7 +446,7 @@ function main() {
else
real_ctx=$ctx # TODO(Zhenzhong): use same ctx for chatglm & baichuan
[[ "${model}" == "chatglm2" || "${model}" == "chatglm-6b" ||
- "${model}" == "baichuan-13b" || "${model}" == "baichuan2-13b" ]] && real_ctx=1300
+ "${model}" == "baichuan-13b" || "${model}" == "baichuan2-13b" ]] && real_ctx=2048
if [[ "${model}" == *"gptq" ]]; then
NEURAL_SPEED_VERBOSE=1 OMP_NUM_THREADS=$cores_per_instance numactl -m 0 -C 0-$(($cores_per_instance - 1)) $infer_cmd 2>&1 | tee ${WORKSPACE}/${logs_file} || true &
else
|