From e4c5f7142ec83c8ce596a907d69ea4b617e7481f Mon Sep 17 00:00:00 2001 From: intellinjun <105184542+intellinjun@users.noreply.github.com> Date: Fri, 22 Mar 2024 15:39:19 +0800 Subject: [PATCH] Gemma-7b&&Gemma-2b (#171) --- docs/supported_models.md | 16 +- neural_speed/__init__.py | 2 + neural_speed/application/CMakeLists.txt | 5 + neural_speed/application/main_pybind.cpp | 3 + neural_speed/application/main_run.cpp | 2 +- neural_speed/convert/convert_baichuan.py | 1 + neural_speed/convert/convert_bloom.py | 1 + neural_speed/convert/convert_chatglm.py | 5 +- neural_speed/convert/convert_dolly.py | 1 + neural_speed/convert/convert_falcon.py | 1 + neural_speed/convert/convert_gemma.py | 195 ++++++++ neural_speed/convert/convert_gptj.py | 1 + neural_speed/convert/convert_gptneox.py | 1 + neural_speed/convert/convert_llama.py | 1 + neural_speed/convert/convert_mistral.py | 1 + neural_speed/convert/convert_mixtral.py | 1 + neural_speed/convert/convert_mpt.py | 1 + neural_speed/convert/convert_opt.py | 1 + neural_speed/convert/convert_phi.py | 1 + .../convert/convert_quantized_baichuan.py | 1 + .../convert/convert_quantized_falcon.py | 1 + .../convert/convert_quantized_gptj.py | 1 + .../convert/convert_quantized_llama.py | 1 + .../convert/convert_quantized_mistral.py | 1 + .../convert/convert_quantized_mixtral.py | 1 + neural_speed/convert/convert_quantized_phi.py | 1 + .../convert/convert_quantized_qwen.py | 1 + neural_speed/convert/convert_qwen.py | 1 + neural_speed/convert/convert_stablelm.py | 1 + neural_speed/convert/convert_starcoder.py | 1 + neural_speed/core/layers/Ops.h | 4 +- neural_speed/core/layers/ip_fusion_ffn.cpp | 15 + neural_speed/core/ne_bestla.h | 7 +- neural_speed/core/ne_layers.c | 55 ++- neural_speed/core/ne_layers.h | 3 + neural_speed/models/CMakeLists.txt | 1 + neural_speed/models/gemma/gemma.cpp | 416 ++++++++++++++++++ neural_speed/models/gemma/gemma.h | 60 +++ neural_speed/models/gemma/gemma_utils.cpp | 236 ++++++++++ neural_speed/models/model_utils/gguf.h | 6 + neural_speed/models/model_utils/model_files.h | 6 + neural_speed/models/model_utils/model_types.h | 5 +- .../models/model_utils/model_utils.cpp | 9 +- tests/model-test/cpp_graph_inference.sh | 8 +- 44 files changed, 1069 insertions(+), 13 deletions(-) create mode 100644 neural_speed/convert/convert_gemma.py create mode 100644 neural_speed/models/gemma/gemma.cpp create mode 100644 neural_speed/models/gemma/gemma.h create mode 100644 neural_speed/models/gemma/gemma_utils.cpp diff --git a/docs/supported_models.md b/docs/supported_models.md index 03c63e9ed..df8135677 100644 --- a/docs/supported_models.md +++ b/docs/supported_models.md @@ -276,7 +276,21 @@ Neural Speed supports the following models: Latest 2048 - + + gemma-2b-it , + gemma-7b + ✅ + + + + ✅ + + + + Latest + 8192 + + Whisper-tiny, Whisper-base Whisper-small diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py index 22566b2e1..c6b706d4b 100644 --- a/neural_speed/__init__.py +++ b/neural_speed/__init__.py @@ -70,6 +70,8 @@ def __import_package(self, model_type): import neural_speed.qwen_cpp as cpp_model elif model_type == "phi": import neural_speed.phi_cpp as cpp_model + elif model_type == "gemma": + import neural_speed.gemma_cpp as cpp_model elif model_type == "stablelm": import neural_speed.stablelm_cpp as cpp_model elif model_type == "whisper": diff --git a/neural_speed/application/CMakeLists.txt b/neural_speed/application/CMakeLists.txt index 2a34f5db3..ca659e1b1 100644 --- a/neural_speed/application/CMakeLists.txt +++ 
b/neural_speed/application/CMakeLists.txt @@ -71,6 +71,7 @@ compile_quant(quant_mistral quant_model.cpp mistral llama) compile_quant(quant_mixtral quant_model.cpp mixtral llama) compile_quant(quant_qwen quant_model.cpp qwen qwen) compile_quant(quant_phi quant_model.cpp phi phi) +compile_quant(quant_gemma quant_model.cpp gemma gemma) compile_quant(quant_stablelm quant_model.cpp stablelm stablelm) compile_quant(quant_whisper quant_whisper.cpp whisper whisper) @@ -99,6 +100,8 @@ set(mymap_stablelm 17) set(mymap_whisper 18) set(mymap_mixtral 19) set(mymap_chatglm3 20) +set(mymap_gemma 21) + function(compile_run TARGET MAIN_CPP MAIN_PY MODEL_NAME MODEL_LIB) @@ -135,8 +138,10 @@ compile_run(run_baichuan main_run.cpp main_pybind.cpp baichuan baichuan) compile_run(run_mistral main_run.cpp main_pybind.cpp mistral llama) compile_run(run_qwen main_run.cpp main_pybind.cpp qwen qwen) compile_run(run_phi main_run.cpp main_pybind.cpp phi phi) +compile_run(run_gemma main_run.cpp main_pybind.cpp gemma gemma) compile_run(run_stablelm main_run.cpp main_pybind.cpp stablelm stablelm) compile_run(run_mixtral main_run.cpp main_pybind.cpp mixtral llama) + # speech recognition compile_run(run_whisper audio_run.cpp whisper_pybind.cpp whisper whisper) diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp index eb3fb8777..d95934401 100644 --- a/neural_speed/application/main_pybind.cpp +++ b/neural_speed/application/main_pybind.cpp @@ -924,6 +924,9 @@ PYBIND11_MODULE(mixtral_cpp, m) #elif MODEL_NAME_ID == 20 PYBIND11_MODULE(chatglm3_cpp, m) +#elif MODEL_NAME_ID == 21 + +PYBIND11_MODULE(gemma_cpp, m) #endif { diff --git a/neural_speed/application/main_run.cpp b/neural_speed/application/main_run.cpp index c784348eb..6872d1133 100644 --- a/neural_speed/application/main_run.cpp +++ b/neural_speed/application/main_run.cpp @@ -229,7 +229,7 @@ int main(int argc, char** argv) { // NOLINT // tokenize the prompt bool add_bos = false; - if (params.model_arch == MODEL_LLAMA) { + if (params.model_arch == MODEL_LLAMA || params.model_arch == MODEL_GEMMA) { add_bos = true; } diff --git a/neural_speed/convert/convert_baichuan.py b/neural_speed/convert/convert_baichuan.py index b6893e310..1fa35805e 100644 --- a/neural_speed/convert/convert_baichuan.py +++ b/neural_speed/convert/convert_baichuan.py @@ -157,6 +157,7 @@ def baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(struct.pack("i", hparams["intermediate_size"])) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_bloom.py b/neural_speed/convert/convert_bloom.py index 0923c0e75..7a9609ddf 100644 --- a/neural_speed/convert/convert_bloom.py +++ b/neural_speed/convert/convert_bloom.py @@ -106,6 +106,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git 
a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py index be84facd6..6aa41e5d5 100644 --- a/neural_speed/convert/convert_chatglm.py +++ b/neural_speed/convert/convert_chatglm.py @@ -156,8 +156,6 @@ def chatglm3_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams gguf_file = fname_out gguf_writer = gguf.GGUFWriter(gguf_file, "chatglm3") gguf_writer.add_uint32('magic', 0x67676d66) - import pdb - pdb.set_trace() gguf_writer.add_uint32('version', 1) gguf_writer.add_uint32('n_vocab', hparams["padded_vocab_size"]) gguf_writer.add_embedding_length(hparams["hidden_size"]) @@ -561,6 +559,7 @@ def chatglm3_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor @@ -711,6 +710,7 @@ def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor @@ -862,6 +862,7 @@ def chatglm1_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base diff --git a/neural_speed/convert/convert_dolly.py b/neural_speed/convert/convert_dolly.py index 87d711b0f..ebab9d118 100644 --- a/neural_speed/convert/convert_dolly.py +++ b/neural_speed/convert/convert_dolly.py @@ -120,6 +120,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_falcon.py b/neural_speed/convert/convert_falcon.py index 955d02547..f1d82daf3 100644 --- a/neural_speed/convert/convert_falcon.py +++ b/neural_speed/convert/convert_falcon.py @@ -113,6 +113,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_gemma.py b/neural_speed/convert/convert_gemma.py new file mode 100644 index 000000000..f0b61df1a --- /dev/null +++ 
b/neural_speed/convert/convert_gemma.py @@ -0,0 +1,195 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Convert Hugging Face Gemma models (gemma-2b, gemma-7b) to the NE format +# +# Usage: +# +# python3 convert_gemma.py --outfile gemma-ne-f32.bin path/to/gemma-model +# +# This script is similar to "convert-pt-to-ne.py" +# + +import io +import os +import sys +import struct +import json +import code +import torch +import numpy as np +from pathlib import Path +import argparse +from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, + Union) + + + +# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. 
+ """ + bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def main(args_in: Optional[List[str]] = None) -> None: + parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file") + parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("--model_hub", choices=["huggingface","modelscope"], + default="huggingface", help="hub to load model") + parser.add_argument("model", type=Path, help="directory containing model file") + args = parser.parse_args(args_in) + + dir_model = args.model.as_posix() + fname_out = args.outfile.as_posix() + + # possible data types + # ftype == 0 -> float32 + # ftype == 1 -> float16 + ftype = 0 + if args.outtype == "f16": + ftype = 1 + if args.model_hub == "modelscope": + from modelscope import AutoModelForCausalLM, AutoTokenizer + else: + from transformers import AutoModelForCausalLM, AutoTokenizer + print("Loading model: ", dir_model) + model = AutoModelForCausalLM.from_pretrained(dir_model) + tokenizer = AutoTokenizer.from_pretrained(dir_model) + model.eval() + for p in model.parameters(): + p.requires_grad = False + hparams = model.config.to_dict() + print("Model loaded: ", dir_model) + + fout = open(fname_out, "wb") + + # 0x67676d6c is unversioned ne + # 0x67676d66 is versioned ggmf (requires token scores) + ne_file_magic = 0x67676d66 + #ne_file_version = 0x00000001 # v1 + + fout.write(struct.pack("i", ne_file_magic)) # magic: ne in hex + fout.write(struct.pack("i", 1)) + fout.write(struct.pack("i", hparams["vocab_size"])) + fout.write(struct.pack("i", hparams["hidden_size"])) + fout.write(struct.pack("i", hparams["intermediate_size"])) # dummy data + fout.write(struct.pack("i", hparams["num_attention_heads"])) + fout.write(struct.pack("i", hparams["num_key_value_heads"])) # multi-query attention + fout.write(struct.pack("i", hparams["num_hidden_layers"])) + fout.write(struct.pack("i", hparams["head_dim"])) + fout.write(struct.pack("i", ftype)) + fout.write( + struct.pack("i", hparams["seq_length"] if "seq_length" in hparams else hparams["max_position_embeddings"])) + fout.write(struct.pack("f", 0.0)) + fout.write(struct.pack("f", 0.0)) + fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt) + fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt) + + fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", hparams["intermediate_size"])) + fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", 0)) # n_experts + fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", hparams["head_dim"])) # n_embd_head_k + fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps + fout.write(struct.pack("f", 10000.0)) # freq_base + fout.write(struct.pack("f", 1.0)) # rope_factor + + fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled + fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings + fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) + fout.write(struct.pack("i", hparams["bos_token_id"])) + fout.write(struct.pack("i", hparams["eos_token_id"])) + 
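As an aside for readers tracing the long run of struct.pack calls above: the small helper below is an editorial sketch, not part of this patch; the function name and field names are illustrative and simply mirror the hparams written above. It reads the leading header fields back in the same order, which is a quick way to sanity-check a converted Gemma file.

import struct

def peek_gemma_header(path):
    # Field order mirrors the fout.write(struct.pack("i", ...)) calls above.
    with open(path, "rb") as f:
        magic, version = struct.unpack("2i", f.read(8))
        assert magic == 0x67676d66, "expected the versioned ggmf magic written above"
        (vocab_size, hidden_size, intermediate_size, n_head, n_head_kv,
         n_layer, head_dim, ftype, max_seq_len) = struct.unpack("9i", f.read(36))
        return {"version": version, "vocab_size": vocab_size, "hidden_size": hidden_size,
                "n_head": n_head, "n_head_kv": n_head_kv, "n_layer": n_layer,
                "head_dim": head_dim, "ftype": ftype, "max_seq_len": max_seq_len}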
fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1)) + fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1)) + + for i in range(hparams["vocab_size"]): + if i < tokenizer.vocab_size: + text = tokenizer.decode([i]).encode('utf-8') + fout.write(struct.pack("i", len(text))) + fout.write(text) + fout.write(struct.pack("f", 0.0 - i)) + else: + text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8') + fout.write(struct.pack("i", len(text))) + fout.write(text) + fout.write(struct.pack("f", -10000)) + + list_vars = model.state_dict() + + print(hparams) + + for name in list_vars.keys(): + # No gradients for these + list_vars[name].requires_grad = False + src = name + nn = name + + print(src, ' -> ', name) + data = list_vars[src].squeeze().numpy() + data = data.astype(np.float32) + + n_dims = len(data.shape) + print(name, n_dims, data.shape) + + # default type is fp32 + ftype_cur = 0 + if ftype == 1 and n_dims > 1: + print(" Converting to float16", data.shape, data[:3, :3].tolist()) + data = data.astype(np.float16) + ftype_cur = 1 + else: + print(" Converting to float32", data.shape, data[:3, :3].tolist() if n_dims > 1 else data[:3].tolist()) + data = data.astype(np.float32) + # gemma_rms: + # output = self._norm(x.float()).type_as(x) + # return output * (1 + self.weight) + if "norm" in name: + data = data + 1 + str = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + print(str) + fout.write(str) + + # data + data.tofile(fout) + + fout.close() + + print("Done. Output file: " + fname_out) + print("") + + +if __name__ == '__main__': + main() diff --git a/neural_speed/convert/convert_gptj.py b/neural_speed/convert/convert_gptj.py index 3fb12c8f6..2ff15c255 100644 --- a/neural_speed/convert/convert_gptj.py +++ b/neural_speed/convert/convert_gptj.py @@ -105,6 +105,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_gptneox.py b/neural_speed/convert/convert_gptneox.py index 322f47a01..9cbda72fa 100644 --- a/neural_speed/convert/convert_gptneox.py +++ b/neural_speed/convert/convert_gptneox.py @@ -121,6 +121,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_llama.py b/neural_speed/convert/convert_llama.py index 37af73ab4..328994938 100644 --- a/neural_speed/convert/convert_llama.py +++ b/neural_speed/convert/convert_llama.py @@ -1090,6 +1090,7 @@ def write_file_header(self, params: Params, file_type: NEFileType) -> None: self.fout.write(struct.pack("i", 0)) # n_experts self.fout.write(struct.pack("i", 0)) # n_expert_used + 
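A note on the data = data + 1 branch in convert_gemma.py above: Gemma's RMS norm scales by (1 + weight), so the converter folds the +1 into every stored norm tensor and the runtime can then apply an ordinary RMS-norm multiply. Below is a tiny NumPy check of that equivalence (editorial sketch, not part of the patch; the eps value is an assumption for illustration only).

import numpy as np

def rms_norm(x, eps=1e-6):  # eps chosen only for illustration
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

x = np.random.randn(4, 8).astype(np.float32)
w = np.random.randn(8).astype(np.float32)

reference = rms_norm(x) * (1.0 + w)  # Gemma definition quoted in the comment above
stored_w = w + 1.0                   # what the converter writes for norm tensors
engine = rms_norm(x) * stored_w      # plain RMS-norm scale at inference time
assert np.allclose(reference, engine)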
self.fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma self.fout.write(struct.pack("f", params.rms_norm_eps)) self.fout.write(struct.pack("f", params.rope_theta)) self.fout.write(struct.pack("f", params.rope_scale)) diff --git a/neural_speed/convert/convert_mistral.py b/neural_speed/convert/convert_mistral.py index 1889a9860..a1f12b6a8 100644 --- a/neural_speed/convert/convert_mistral.py +++ b/neural_speed/convert/convert_mistral.py @@ -1064,6 +1064,7 @@ def write_file_header(self, params: Params, file_type: NEFileType) -> None: self.fout.write(struct.pack("i", 0)) # n_experts self.fout.write(struct.pack("i", 0)) # n_expert_used + self.fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma self.fout.write(struct.pack("f", params.rms_norm_eps)) self.fout.write(struct.pack("f", params.rope_theta)) self.fout.write(struct.pack("f", params.rope_scale)) diff --git a/neural_speed/convert/convert_mixtral.py b/neural_speed/convert/convert_mixtral.py index 4166d94be..77a1376f4 100644 --- a/neural_speed/convert/convert_mixtral.py +++ b/neural_speed/convert/convert_mixtral.py @@ -1066,6 +1066,7 @@ def write_file_header(self, params: Params, file_type: NEFileType) -> None: self.fout.write(struct.pack("i", 0)) self.fout.write(struct.pack("i", 8)) self.fout.write(struct.pack("i", 2)) + self.fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma self.fout.write(struct.pack("f", params.rms_norm_eps)) self.fout.write(struct.pack("f", params.rope_theta)) self.fout.write(struct.pack("f", params.rope_scale)) diff --git a/neural_speed/convert/convert_mpt.py b/neural_speed/convert/convert_mpt.py index 04ee52421..c31961f23 100644 --- a/neural_speed/convert/convert_mpt.py +++ b/neural_speed/convert/convert_mpt.py @@ -102,6 +102,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_opt.py b/neural_speed/convert/convert_opt.py index c96885334..ea3f226d6 100644 --- a/neural_speed/convert/convert_opt.py +++ b/neural_speed/convert/convert_opt.py @@ -114,6 +114,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_phi.py b/neural_speed/convert/convert_phi.py index a916ccf5f..073b54fbf 100644 --- a/neural_speed/convert/convert_phi.py +++ b/neural_speed/convert/convert_phi.py @@ -199,6 +199,7 @@ def phi_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git 
a/neural_speed/convert/convert_quantized_baichuan.py b/neural_speed/convert/convert_quantized_baichuan.py index 22928a6bc..2220e3b8b 100644 --- a/neural_speed/convert/convert_quantized_baichuan.py +++ b/neural_speed/convert/convert_quantized_baichuan.py @@ -101,6 +101,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", hparams["intermediate_size"])) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_quantized_falcon.py b/neural_speed/convert/convert_quantized_falcon.py index 956b0a92b..4b90b1d75 100644 --- a/neural_speed/convert/convert_quantized_falcon.py +++ b/neural_speed/convert/convert_quantized_falcon.py @@ -80,6 +80,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_quantized_gptj.py b/neural_speed/convert/convert_quantized_gptj.py index 707b63aca..85a1e769b 100644 --- a/neural_speed/convert/convert_quantized_gptj.py +++ b/neural_speed/convert/convert_quantized_gptj.py @@ -140,6 +140,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_quantized_llama.py b/neural_speed/convert/convert_quantized_llama.py index 78f497ef4..9762492ee 100644 --- a/neural_speed/convert/convert_quantized_llama.py +++ b/neural_speed/convert/convert_quantized_llama.py @@ -150,6 +150,7 @@ def main(args_in: Optional[List[str]] = None) -> None: f.write(struct.pack("i", 0)) f.write(struct.pack("i", 0)) # n_experts f.write(struct.pack("i", 0)) # n_expert_used + f.write(struct.pack("i", 0)) # n_embd_head_k for gemma f.write(struct.pack("f", config["rms_norm_eps"])) f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000)) diff --git a/neural_speed/convert/convert_quantized_mistral.py b/neural_speed/convert/convert_quantized_mistral.py index 3e24295cd..c89bcec41 100644 --- a/neural_speed/convert/convert_quantized_mistral.py +++ b/neural_speed/convert/convert_quantized_mistral.py @@ -159,6 +159,7 @@ def main(args_in: Optional[List[str]] = None) -> None: f.write(struct.pack("i", 0)) f.write(struct.pack("i", 0)) # n_experts f.write(struct.pack("i", 0)) # n_expert_used + f.write(struct.pack("i", 0)) # n_embd_head_k for gemma f.write(struct.pack("f", config["rms_norm_eps"])) f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000)) diff --git a/neural_speed/convert/convert_quantized_mixtral.py b/neural_speed/convert/convert_quantized_mixtral.py index c320797d9..df793cda6 100644 --- 
a/neural_speed/convert/convert_quantized_mixtral.py +++ b/neural_speed/convert/convert_quantized_mixtral.py @@ -174,6 +174,7 @@ def main(args_in: Optional[List[str]] = None) -> None: f.write(struct.pack("i", 0)) f.write(struct.pack("i", 8)) # n_experts f.write(struct.pack("i", 2)) # n_expert_used + f.write(struct.pack("i", 0)) # n_embd_head_k for gemma f.write(struct.pack("f", config["rms_norm_eps"])) f.write(struct.pack("f", config["rope_theta"] if "rope_theta" in config else 10000)) diff --git a/neural_speed/convert/convert_quantized_phi.py b/neural_speed/convert/convert_quantized_phi.py index 8588116b1..9fe22bfaa 100644 --- a/neural_speed/convert/convert_quantized_phi.py +++ b/neural_speed/convert/convert_quantized_phi.py @@ -66,6 +66,7 @@ def convert_phi1_5_gptq_to_bestTLA(model_path, out_path, outtype, model, hparams fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_quantized_qwen.py b/neural_speed/convert/convert_quantized_qwen.py index 02ded7622..998695bc4 100644 --- a/neural_speed/convert/convert_quantized_qwen.py +++ b/neural_speed/convert/convert_quantized_qwen.py @@ -83,6 +83,7 @@ def main(args_in: Optional[List[str]] = None) -> None: f.write(struct.pack("i", 0)) f.write(struct.pack("i", 0)) # n_experts f.write(struct.pack("i", 0)) # n_expert_used + f.write(struct.pack("i", 0)) # n_embd_head_k for gemma f.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps f.write(struct.pack("f", 10000.0)) # freq_base f.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_qwen.py b/neural_speed/convert/convert_qwen.py index 7cc59b72b..f268b37e6 100644 --- a/neural_speed/convert/convert_qwen.py +++ b/neural_speed/convert/convert_qwen.py @@ -126,6 +126,7 @@ def main(args_in: Optional[List[str]] = None) -> None: fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/convert/convert_stablelm.py b/neural_speed/convert/convert_stablelm.py index af00ba4fc..f5f1d43fd 100644 --- a/neural_speed/convert/convert_stablelm.py +++ b/neural_speed/convert/convert_stablelm.py @@ -197,6 +197,7 @@ def stablelm_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_eps", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", hparams["rope_theta"])) # freq_base fout.write(struct.pack("f", 1.0)) # freq_scale, was removed in config.json (by default=1.0) diff --git a/neural_speed/convert/convert_starcoder.py b/neural_speed/convert/convert_starcoder.py index 8f1dba042..b00fc566f 100644 --- a/neural_speed/convert/convert_starcoder.py +++ b/neural_speed/convert/convert_starcoder.py @@ -117,6 +117,7 @@ def main(args_in: Optional[List[str]] = None) -> None: 
fout.write(struct.pack("i", 0)) fout.write(struct.pack("i", 0)) # n_experts fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("i", 0)) # n_embd_head_k for gemma fout.write(struct.pack("f", hparams.get("layer_norm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps fout.write(struct.pack("f", 10000.0)) # freq_base fout.write(struct.pack("f", 1.0)) # rope_factor diff --git a/neural_speed/core/layers/Ops.h b/neural_speed/core/layers/Ops.h index b0f441b42..e38dfbafd 100644 --- a/neural_speed/core/layers/Ops.h +++ b/neural_speed/core/layers/Ops.h @@ -45,11 +45,11 @@ enum ne_op { NE_OP_NORM, // normalize NE_OP_RMS_NORM, NE_OP_RMS_NORM_BACK, + NE_OP_RMS_ARGSORT, NE_OP_MUL_MAT, NE_OP_MUL_MAT_BIAS, NE_OP_MUL_MAT_ID, - NE_OP_MUL_ID_FFN_SILU, NE_OP_SCALE, NE_OP_SET, NE_OP_CPY, @@ -76,7 +76,9 @@ enum ne_op { NE_OP_MUL_QKV, NE_OP_MUL_FFN_SILU, NE_OP_MUL_FFN_GELU, + NE_OP_MUL_FFN_GELU_MUL, NE_OP_MUL_FFN_ADD_GELU, + NE_OP_MUL_ID_FFN_SILU, NE_OP_FLASH_ATTN, NE_OP_FLASH_ATTN_KV_UPDATE, NE_OP_FLASH_FF, diff --git a/neural_speed/core/layers/ip_fusion_ffn.cpp b/neural_speed/core/layers/ip_fusion_ffn.cpp index 5875bc5bd..6aed77f91 100644 --- a/neural_speed/core/layers/ip_fusion_ffn.cpp +++ b/neural_speed/core/layers/ip_fusion_ffn.cpp @@ -684,6 +684,11 @@ bool bestla_fusion_FFN_SiLu_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr return ffn_3w::bestla_fusion_ffn_f32f32_support(w1ptr, w2ptr, w3ptr, seq, fin, fmid, fout); } +bool bestla_fusion_FFN_Gelu_Mul_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr, int seq, int fin, int fmid, + int fout) { + return ffn_3w::bestla_fusion_ffn_f32f32_support(w1ptr, w2ptr, w3ptr, seq, fin, fmid, fout); +} + void bestla_fusion_FFN_SiLu_f32f32_forward(float* activation, void* w1ptr, void* w2ptr, void* w3ptr, float* tmp1, float* tmp2, float* output, int seq, int fin, int fmid, int fout, void* workspace) { @@ -695,6 +700,16 @@ void bestla_fusion_FFN_SiLu_f32f32_forward(float* activation, void* w1ptr, void* activation, w1ptr, w2ptr, w3ptr, tmp1, tmp2, output, seq, fin, fmid, fout, workspace, epi_args1, epi_args2); } +void bestla_fusion_FFN_Gelu_Mul_f32f32_forward(float* activation, void* w1ptr, void* w2ptr, void* w3ptr, float* tmp1, + float* tmp2, float* output, int seq, int fin, int fmid, int fout, + void* workspace) { + epilogue::gemm::ParamAccumulatorWriteBack epi_args1 = {tmp1, fmid}; + epilogue::gemm::ParamAccumulatorWriteBack epi_args2 = {output, fout}; + ffn_3w::bestla_fusion_ffn_f32f32_forward( + activation, w1ptr, w2ptr, w3ptr, tmp1, tmp2, output, seq, fin, fmid, fout, workspace, epi_args1, epi_args2); +} + bool bestla_fusion_FFN_GeLu_f32f32_support(void* w1ptr, void* w2ptr, int seq, int fin, int fmid, int fout) { return ffn_2w::bestla_fusion_ffn_f32f32_support(w1ptr, w2ptr, seq, fin, fmid, fout); } diff --git a/neural_speed/core/ne_bestla.h b/neural_speed/core/ne_bestla.h index a9ad0a5ea..7d34525fd 100644 --- a/neural_speed/core/ne_bestla.h +++ b/neural_speed/core/ne_bestla.h @@ -49,8 +49,13 @@ void bestla_fusion_QKV_f32f32_forward(float* activation, void* wqptr, void* wkpt unsigned long long bestla_fusion_FFN_f32f32_get_workspace_size(int seq, int fin, int fmid, int fout, void* w1ptr, void* w2ptr); -bool bestla_fusion_FFN_SiLu_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr, int seq, int fin, int fmid, int fout); +bool bestla_fusion_FFN_Gelu_Mul_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr, int seq, int fin, int fmid, + int fout); +void bestla_fusion_FFN_Gelu_Mul_f32f32_forward(float* activation, void* w1ptr, void* w2ptr, void* 
w3ptr, float* tmp1, + float* tmp2, float* output, int seq, int fin, int fmid, int fout, + void* workspace); +bool bestla_fusion_FFN_SiLu_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr, int seq, int fin, int fmid, int fout); void bestla_fusion_FFN_SiLu_f32f32_forward(float* activation, void* w1ptr, void* w2ptr, void* w3ptr, float* tmp1, float* tmp2, float* output, int seq, int fin, int fmid, int fout, void* workspace); diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c index 3f892a371..a23df031e 100644 --- a/neural_speed/core/ne_layers.c +++ b/neural_speed/core/ne_layers.c @@ -427,6 +427,7 @@ static const char* NE_OP_LABEL[NE_OP_COUNT] = { "MUL_QKV", "FFN_SILU", "FFN_GeLU", + "FFN_GeLU_MUL", "FFN_ADD_GeLU", "FFN_ID_SILU", "FLASH_ATTN", @@ -442,7 +443,7 @@ static const char* NE_OP_LABEL[NE_OP_COUNT] = { "DEBUG", }; -static_assert(NE_OP_COUNT == 67, "NE_OP_COUNT != 67"); +static_assert(NE_OP_COUNT == 69, "NE_OP_COUNT != 69"); static const char* NE_OP_SYMBOL[NE_OP_COUNT] = { "none", @@ -502,6 +503,7 @@ static const char* NE_OP_SYMBOL[NE_OP_COUNT] = { "ffn_silu(x)", "ffn_id_silu(x)", "ffn_gelu(x)", + "ffn_gelu_mul(x)", "ffn_gelu_with_bias(x)", "flash_attn(x)", "flash_attn_kv_update(x)", @@ -2396,6 +2398,33 @@ struct ne_tensor* ne_ffn_gelu(struct ne_context* ctx, struct ne_tensor* w1, stru return result; } +struct ne_tensor* ne_ffn_gelu_mul(struct ne_context* ctx, struct ne_tensor* w1, struct ne_tensor* w2, + struct ne_tensor* w3, struct ne_tensor* src) { + NE_ASSERT(ne_are_same_shape(w1, w3)); + NE_ASSERT(w2->ne[0] == w1->ne[1]); + + bool is_node = false; + + if (src->grad || w1->grad || w2->grad || w3->grad) { + is_node = true; + } + + const int64_t ne[4] = {w2->ne[1], src->ne[1], src->ne[2], src->ne[3]}; + struct ne_tensor* result = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, ne, NE_SIZE_CALC); + const int64_t tne[4] = {w1->ne[1], src->ne[1], src->ne[2], src->ne[3]}; + struct ne_tensor* tmp = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); + struct ne_tensor* tmp1 = ne_new_tensor(ctx, NE_TYPE_F32, src->n_dims, tne, NE_SIZE_CALC); + + result->op = NE_OP_MUL_FFN_GELU_MUL; + result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; + result->src0 = src; + result->src1 = w1; + result->opt[0] = w2; + result->opt[1] = w3; + result->opt[2] = tmp; + result->opt[3] = tmp1; + return result; +} // ne_scale struct ne_tensor* ne_scale_impl(struct ne_context* ctx, struct ne_tensor* a, struct ne_tensor* b, bool inplace) { @@ -7753,6 +7782,25 @@ static void ne_compute_forward_ffn_gelu(const struct ne_compute_params* params, seq, fin, fmid, fout, params->wdata); } +static void ne_compute_forward_ffn_gelu_mul(const struct ne_compute_params* params, const struct ne_tensor* src, + const struct ne_tensor* w1, const struct ne_tensor* w2, + struct ne_tensor* w3, const struct ne_tensor* tmp, struct ne_tensor* tmp1, + struct ne_tensor* dst) { + if (params->type == NE_TASK_INIT) { + return; + } + + if (params->type == NE_TASK_FINALIZE) { + return; + } + const int fin = src->ne[0]; + const int fout = dst->ne[0]; + const int fmid = w1->ne[1]; + const int seq = dst->ne[1]; + bestla_fusion_FFN_Gelu_Mul_f32f32_forward((float*)src->data, w1->data, w2->data, w3->data, (float*)tmp->data, + (float*)tmp1->data, (float*)dst->data, seq, fin, fmid, fout, params->wdata); +} + // ne_compute_forward_scale static void ne_compute_forward_scale_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, @@ -10394,6 +10442,10 @@ static void ne_compute_forward(struct ne_compute_params* params, struct ne_tenso case NE_OP_MUL_FFN_GELU: { ne_compute_forward_ffn_gelu(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor); } break; + case NE_OP_MUL_FFN_GELU_MUL: { + ne_compute_forward_ffn_gelu_mul(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], + tensor->opt[2], tensor->opt[3], tensor); + } break; case NE_OP_SCALE: { ne_compute_forward_scale(params, tensor->src0, tensor->src1, tensor); } break; @@ -11276,6 +11328,7 @@ void ne_graph_compute(struct ne_context* ctx, struct ne_cgraph* cgraph) { } break; case NE_OP_MUL_FFN_SILU: case NE_OP_MUL_FFN_GELU: + case NE_OP_MUL_FFN_GELU_MUL: case NE_OP_MUL_FFN_ADD_GELU: { size_t cur = 0; cur = bestla_fusion_FFN_f32f32_get_workspace_size(node->src0->ne[1], node->src0->ne[0], node->src1->ne[1], diff --git a/neural_speed/core/ne_layers.h b/neural_speed/core/ne_layers.h index c6a7c566d..19734c98d 100644 --- a/neural_speed/core/ne_layers.h +++ b/neural_speed/core/ne_layers.h @@ -272,6 +272,9 @@ NE_API struct ne_tensor* ne_mul_qkv(struct ne_context* ctx, struct ne_tensor* qw NE_API struct ne_tensor* ne_ffn_silu(struct ne_context* ctx, struct ne_tensor* w1, struct ne_tensor* w2, struct ne_tensor* w3, struct ne_tensor* src); +NE_API struct ne_tensor* ne_ffn_gelu_mul(struct ne_context* ctx, struct ne_tensor* w1, struct ne_tensor* w2, + struct ne_tensor* w3, struct ne_tensor* src); + NE_API struct ne_tensor* ne_ffn_add_gelu(struct ne_context* ctx, struct ne_tensor* w1, struct ne_tensor* w2, struct ne_tensor* b1, struct ne_tensor* b2, struct ne_tensor* src); diff --git a/neural_speed/models/CMakeLists.txt b/neural_speed/models/CMakeLists.txt index a1edeca6f..58185c6de 100644 --- a/neural_speed/models/CMakeLists.txt +++ b/neural_speed/models/CMakeLists.txt @@ -34,6 +34,7 @@ add_model(qwen qwen/qwen.cpp qwen/qwen_utils.cpp ${MODEL_UTILS_SOURCE}) add_model(whisper whisper/whisper.cpp whisper/whisper_utils.cpp ${MODEL_UTILS_SOURCE}) add_model(chatglm chatglm/chatglm.cpp chatglm/chatglm_utils.cpp ${MODEL_UTILS_SOURCE}) add_model(chatglm2 chatglm/chatglm2.cpp chatglm/chatglm2_utils.cpp ${MODEL_UTILS_SOURCE}) +add_model(gemma gemma/gemma.cpp 
gemma/gemma_utils.cpp ${MODEL_UTILS_SOURCE}) add_model(phi phi/phi.cpp phi/phi_utils.cpp ${MODEL_UTILS_SOURCE}) add_model(stablelm stablelm/stablelm.cpp stablelm/stablelm_utils.cpp ${MODEL_UTILS_SOURCE}) add_model(chatglm3 chatglm/chatglm2.cpp chatglm/chatglm2_utils.cpp ${MODEL_UTILS_SOURCE}) diff --git a/neural_speed/models/gemma/gemma.cpp b/neural_speed/models/gemma/gemma.cpp new file mode 100644 index 000000000..1bd069e1d --- /dev/null +++ b/neural_speed/models/gemma/gemma.cpp @@ -0,0 +1,416 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/data_types.h" +#include "core/ne.h" +#include "core/ne_layers.h" +#include "core/ne_bestla.h" +#include "core/layers/mha_dense.h" +#include "models/model_utils/model_config.h" +#include "models/model_utils/model_utils.h" +#include "models/model_utils/util.h" + +// feed-forward network +struct ne_tensor* gemma_ff(const model_layer& layer, const int batch_size, const int N, ne_context* ctx0, + ne_tensor* inp) { + struct ne_tensor* cur = inp; + if (bestla_fusion_FFN_Gelu_Mul_f32f32_support(layer.ffn[0]->data, layer.ffn[1]->data, layer.ffn[2]->data, N, + cur->ne[0], layer.ffn[0]->ne[1], layer.ffn[1]->ne[1])) { + cur = ne_ffn_gelu_mul(ctx0, layer.ffn[0], layer.ffn[1], layer.ffn[2], cur); + } else { + struct ne_tensor* cur_1 = ne_mul_mat(ctx0, layer.ffn[0], cur); + + struct ne_tensor* cur_2 = ne_mul_mat(ctx0, layer.ffn[2], cur); + + // GELU activation + cur_1 = ne_gelu(ctx0, cur_1); + + // projection + // cur = proj_w*cur + proj_b + cur = ne_mul(ctx0, cur_1, cur_2); + + cur = ne_mul_mat(ctx0, layer.ffn[1], cur); + } + + return cur; +} + +// evaluate the transformer +// +// - lctx: model context +// - tokens: new batch of tokens to process +// - n_past: the offset to which the kv is cached to +// - n_total: the number of tokens evaluated so far (including evicted tokens if there is any) +// - n_threads: number of threads to use +// +static bool gemma_model_eval_internal(model_context* ctx, const model_input* inputs, const int n_input, + const int n_threads) { + const int64_t t_start_us = ne_time_us(); + model_context& lctx = *ctx; + + // static batching for now + const int N = inputs->n_tokens; + const int n_past = inputs->n_past; + const int n_total = inputs->n_total; + const bool shift_roped_k = lctx.shift_roped_k; + const bool is_ring_full = shift_roped_k && n_total > n_past; + NE_ASSERT(("Shift-RoPE-K to be implemented for the neox-mode RoPE!", !is_ring_full)); + const int batch_size = lctx.batch_size; + MODEL_ASSERT(batch_size == n_input); + const int kv_n_ctx_block = lctx.kv_n_ctx_block; + + const auto& model = lctx.model; + const auto& hparams = model.hparams; + + const auto& kv_self = model.kv_self; + + MODEL_ASSERT(!!kv_self.ctx); + + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + const int n_ctx = 
lctx.n_ctx; + const int n_keep = lctx.n_keep; + const int n_head = hparams.n_head; + const int n_head_kv = hparams.n_head_kv; + const int n_vocab = hparams.n_vocab; + const int n_rot = hparams.n_embd_head_k; + const int head_dim = hparams.n_embd_head_k; + const int n_gqa_embd = head_dim * n_head_kv; + + auto& mem_per_token = lctx.mem_per_token; + auto& buf_compute = lctx.buf_compute; + + struct ne_init_params params = { + /*.mem_size =*/buf_compute.size, + /*.mem_buffer =*/buf_compute.addr, + /*.no_alloc =*/false, + }; + + struct ne_context* ctx0 = ne_init(params); + + // for big prompts, if BLAS is enabled, it is better to use only one thread + // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance + ne_cgraph gf = {}; + gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads; + const bool run_mha_reordered = kv_self.k->type == NE_TYPE_BTLA; + kv_cache_info_t kv_cache_info = {}; + if (run_mha_reordered) { + NE_ASSERT(("kv cache should be the same dtype", kv_self.v->type == NE_TYPE_BTLA)); + attn_shape_t attn_shape = { + /* .batch_size = */ 1, + /* .head_num = */ n_head, + /* .heads_kv = */ n_head_kv, + /* .head_size = */ head_dim, + /* .sl_q = */ N, // Note: make sure that bestla reordered attn supports next token inference + /* .sl_kv = */ n_past + N, + }; + + NE_ASSERT(("bestla managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead", + bestla_reordered_attn_fp32_support(&attn_shape))); + kv_shape_t kv_shape{ + /* .heads_kv = */ static_cast<uint32_t>(n_head_kv), + /* .head_size = */ static_cast<uint32_t>(head_dim), + /* .sl_kv_max = */ static_cast<uint32_t>(n_ctx), + }; + bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info); + } + struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N * batch_size); + ne_set_name(embd, "embd"); + + for (int i = 0; i < batch_size; ++i) { + memcpy(static_cast<model_token*>(embd->data) + i * N, (inputs + i)->tokens, N * ne_element_size(embd)); + } + + struct ne_tensor* inpL = ne_get_rows(ctx0, model.others[0], embd); + inpL = ne_scale(ctx0, inpL, ne_new_f32(ctx0, sqrtf(static_cast<float>(n_embd)))); + + for (int il = 0; il < n_layer; ++il) { + struct ne_tensor* cur; + + lctx.use_buf(ctx0, 0); + + { + // RMS + { + cur = ne_rms_norm(ctx0, inpL, hparams.norm_eps); + cur = ne_mul(ctx0, cur, model.layers[il].norm[0]); + ne_set_name(cur, "input_norm"); + } + + // compute QKV + struct ne_tensor* Kcur = + ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[1], cur), head_dim, n_head_kv, N); + ne_set_name(Kcur, "Kcur_matmul"); + struct ne_tensor* Qcur = + ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[0], cur), head_dim, n_head, N); + ne_set_name(Qcur, "Qcur_matmul"); + struct ne_tensor* Vcur = + ne_reshape_3d(ctx0, ne_mul_mat(ctx0, model.layers[il].attn[2], cur), head_dim, n_head_kv, N); + ne_set_name(Vcur, "Vcur_matmul"); + + // using mode = 2 for GPT-NeoX mode + Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale); + ne_set_name(Qcur, "Qcur"); + Qcur = ne_scale_inplace(ctx0, Qcur, ne_new_f32(ctx0, 1.0f / sqrt(static_cast<float>((n_gqa_embd) / n_head_kv)))); + + Kcur = ne_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale); + ne_set_name(Kcur, "kcur"); + const float attn_scale = 1.0f; + // store key and value to memory + if (!run_mha_reordered) { + { + std::vector<ne_tensor*> Kcur_bs(batch_size); + std::vector<ne_tensor*> Vcur_bs(batch_size); + std::vector<ne_tensor*> k_bs(batch_size); + std::vector<ne_tensor*> v_bs(batch_size); + for (int i = 0; i < batch_size; ++i) { + // batch 
K + Kcur_bs[i] = + ne_permute(ctx0, + ne_view_4d(ctx0, Kcur, head_dim, n_head_kv, N, 1, ne_element_size(Kcur) * head_dim, + ne_element_size(Kcur) * n_gqa_embd, ne_element_size(Kcur) * n_gqa_embd * N, + i * ne_element_size(Kcur) * n_gqa_embd * N), + 0, 2, 1, 3); + k_bs[i] = ne_view_4d( + ctx0, kv_self.k, head_dim, N, n_head_kv, 1, ne_element_size(kv_self.k) * head_dim, + ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_gqa_embd * n_ctx, + ((il * n_ctx) * ne_element_size(kv_self.k) * n_gqa_embd * kv_n_ctx_block + + i * n_ctx * n_gqa_embd * ne_element_size(kv_self.k) + head_dim * n_past * ne_element_size(kv_self.k))); + + // batch V + Vcur_bs[i] = + ne_permute(ctx0, + ne_reshape_4d(ctx0, + ne_view_2d(ctx0, Vcur, n_gqa_embd, N, ne_element_size(Vcur) * n_gqa_embd, + i * ne_element_size(Vcur) * n_gqa_embd * N), + head_dim, n_head_kv, N, 1), + 1, 2, 0, 3); + v_bs[i] = ne_view_4d( + ctx0, kv_self.v, N, head_dim, n_head_kv, 1, n_ctx * ne_element_size(kv_self.v), + n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_gqa_embd, + ((il * n_ctx) * ne_element_size(kv_self.v) * n_gqa_embd * kv_n_ctx_block + + i * n_ctx * n_gqa_embd * ne_element_size(kv_self.v) + n_past * ne_element_size(kv_self.v))); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Kcur_bs[i], k_bs[i])); + ne_build_forward_expand(&gf, ne_cpy(ctx0, Vcur_bs[i], v_bs[i])); + } + } + // Q = Qcur.contiguous().view(n_gqa_embd/n_head, n_head, N).permute(0, 2, 1, 3) + struct ne_tensor* Q = ne_permute(ctx0, ne_reshape_4d(ctx0, Qcur, head_dim, n_head, N, batch_size), 0, 2, 1, 3); + + // K = Kmem.view(n_gqa_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + struct ne_tensor* K = ne_view_4d( + ctx0, kv_self.k, head_dim, n_past + N, n_head_kv, batch_size, ne_element_size(kv_self.k) * head_dim, + ne_element_size(kv_self.k) * head_dim * n_ctx, ne_element_size(kv_self.k) * n_gqa_embd * n_ctx, + il * n_ctx * ne_element_size(kv_self.k) * n_gqa_embd * kv_n_ctx_block); + + // K * Q + struct ne_tensor* KQ = ne_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_gqa_embd/n_head) + struct ne_tensor* KQ_scaled = ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f)); + + // KQ_masked = mask_past(KQ_scaled) + struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + struct ne_tensor* KQ_soft_max = ne_soft_max(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_gqa_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ne_tensor* V = + ne_view_4d(ctx0, kv_self.v, n_past + N, head_dim, n_head_kv, batch_size, n_ctx * ne_element_size(kv_self.v), + n_ctx * ne_element_size(kv_self.v) * head_dim, n_ctx * ne_element_size(kv_self.v) * n_gqa_embd, + il * n_ctx * ne_element_size(kv_self.v) * n_gqa_embd * kv_n_ctx_block); + + // KQV = transpose(V) * KQ_soft_max + struct ne_tensor* KQV = ne_mul_mat(ctx0, V, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ne_tensor* KQV_merged = ne_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_gqa_embd, N) + cur = ne_cpy(ctx0, KQV_merged, + ne_new_tensor_2d(ctx0, NE_TYPE_F32, head_dim * n_head, N * batch_size, NE_SIZE_CALC)); + } else { + const auto seq_kv = n_past + N; + const auto k_size = kv_cache_info.k_bytes; + const auto v_size = kv_cache_info.v_bytes; + + // store key and value to memory + { + const auto k_cache = ne_view_3d(ctx0, kv_self.k, // tensor + head_dim, n_ctx, n_head_kv, // ne + 0, 0, // nb (bestla managed) + il * k_size); // offset + 
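For readers following the unfused attention branch above (Q pre-scaled by 1/sqrt(head_dim), K*Q, a causal mask offset by n_past, soft_max, then multiplication with V, with each group of query heads sharing one KV head), here is a small NumPy sketch of the same grouped-query attention math. It is an editorial illustration only; shapes and names are assumptions and this is not the ne_* API.

import numpy as np

def gqa_attention(Q, K, V):
    # Q: (n_head, N, head_dim); K, V: (n_head_kv, n_past + N, head_dim)
    n_head, N, head_dim = Q.shape
    n_head_kv = K.shape[0]
    group = n_head // n_head_kv                    # query heads per shared KV head
    n_past = K.shape[1] - N
    Q = Q / np.sqrt(head_dim)                      # matches the ne_scale_inplace on Qcur
    out = np.empty_like(Q)
    for h in range(n_head):
        k = K[h // group]
        v = V[h // group]
        scores = Q[h] @ k.T                        # KQ
        future = np.triu(np.ones_like(scores), k=n_past + 1)
        scores = np.where(future > 0, -np.inf, scores)   # diag_mask_inf with n_past offset
        probs = np.exp(scores - scores.max(-1, keepdims=True))
        probs /= probs.sum(-1, keepdims=True)      # soft_max
        out[h] = probs @ v                         # KQV
    return out                                     # (n_head, N, head_dim)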
ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past, false)); + const auto v_cache = ne_view_3d(ctx0, kv_self.v, // tensor + head_dim, n_ctx, n_head_kv, // ne + 0, 0, // nb (bestla managed) + il * v_size); // offset + ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur, n_past, false)); + } + + struct ne_tensor* Q = ne_permute(ctx0, Qcur, 0, 2, 1, 3); + ne_set_name(Q, "Q"); + + struct ne_tensor* K = + ne_view_3d(ctx0, kv_self.k, // tensor + head_dim, seq_kv, n_head_kv, // ne + kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num, // nb (bestla managed) + il * k_size); // offset + *reinterpret_cast(&K->nb[0]) = kv_cache_info.k_layout; // us nb0 for layout + ne_set_name(K, "K"); + struct ne_tensor* V = + ne_view_3d(ctx0, kv_self.v, // tensor + seq_kv, head_dim, n_head_kv, // ne + kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num, // nb (bestla managed) + il * v_size); // offset + *reinterpret_cast(&V->nb[0]) = kv_cache_info.v_layout; // us nb0 for layout + ne_set_name(V, "V"); + + ne_attn_flags_t attn_flags = 0; + if (n_past == 0) attn_flags |= NE_ATTN_FLAG_IS_CAUSAL; // no causal mask on next-token cases + struct ne_tensor* KQV_Out = ne_flash_attn(ctx0, Q, K, V, attn_scale, attn_flags); + cur = ne_view_2d(ctx0, KQV_Out, head_dim * n_head, N, head_dim * n_head * ne_element_size(KQV_Out), 0); + } + // projection + { cur = ne_mul_mat(ctx0, model.layers[il].attn[3], cur); } + } + lctx.use_buf(ctx0, 1); + + cur = ne_add(ctx0, cur, inpL); + inpL = cur; + + // this is independent of the self-attention result, so it could be done in parallel to the self-attention + // note here we pass inpL instead of cur + { + cur = ne_rms_norm(ctx0, cur, hparams.norm_eps); + cur = ne_mul(ctx0, cur, model.layers[il].norm[1]); + + // struct ne_tensor* cur2 = ne_mul(ctx0, cur, model.layers[il].norm[1]); + // cur = ne_add(ctx0,cur, cur2); + } + cur = gemma_ff(model.layers[il], N, batch_size, ctx0, cur); + + // input for next layer + inpL = ne_add(ctx0, cur, inpL); + } + + lctx.use_buf(ctx0, 0); + // used at the end to optionally extract the embeddings + struct ne_tensor* embeddings = nullptr; + // norm + { + inpL = ne_rms_norm(ctx0, inpL, hparams.norm_eps); + inpL = ne_mul(ctx0, inpL, model.others[1]); + } + lctx.use_buf(ctx0, -1); + // hidden_states = self.ln_f(hidden_states)&lm_head + { inpL = ne_mul_mat(ctx0, model.others[0], inpL); } + + // logits -> probs + // inpL = ne_soft_max_inplace(ctx0, inpL); + + // run the computation + ne_build_forward_expand(&gf, inpL); + ne_graph_compute(ctx0, &gf); + + if (ns_log_level() == 0 || ns_log_level() == 2) { + ne_graph_profiling(&gf); + } + + // update kv token count + lctx.model.kv_self.n = n_past + N; + + // extract logits + { + auto& logits_out = lctx.logits; + + size_t bs_stride = n_vocab * N; + if (lctx.logits_all) { + logits_out.resize(n_vocab * N * batch_size); + for (int i = 0; i < batch_size; ++i) { + memcpy(logits_out.data() + i * bs_stride, reinterpret_cast(ne_get_data(inpL)) + (i * bs_stride), + sizeof(float) * n_vocab * N); + } + } else { + // return result for just the last token + logits_out.resize(n_vocab * batch_size); + for (int i = 0; i < batch_size; ++i) { + memcpy(logits_out.data() + (i * n_vocab), + reinterpret_cast(ne_get_data(inpL)) + (i * bs_stride) + (n_vocab * (N - 1)), + sizeof(float) * n_vocab); + } + } + } + + // extract embeddings + if (!lctx.embedding.empty()) { + auto& embedding_out = lctx.embedding; + + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), 
reinterpret_cast(ne_get_data(embeddings)) + (n_embd * (N - 1)), + sizeof(float) * n_embd); + } + + if (mem_per_token == 0) { + mem_per_token = ne_used_mem(ctx0) / N; + } + + ne_free(ctx0); + + // measure the performance only for the single-token evals + int64_t time_interval = ne_time_us() - t_start_us; + if (N == 1) { + lctx.t_eval_us += time_interval; + lctx.n_eval++; + } else if (N > 1) { + lctx.t_p_eval_us += time_interval; + lctx.n_p_eval += N; + } + lctx.eval_times.push_back(time_interval); + + return true; +} + +int model_eval(struct model_context* ctx, const model_input* inputs, const int n_input, int n_threads) { + if (!gemma_model_eval_internal(ctx, inputs, n_input, n_threads)) { + fprintf(stderr, "%s: failed to eval\n", __func__); + return 1; + } + + // get a more accurate load time, upon first eval + + if (!ctx->has_evaluated_once) { + ctx->t_load_us = ne_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; + } + + return 0; +} diff --git a/neural_speed/models/gemma/gemma.h b/neural_speed/models/gemma/gemma.h new file mode 100644 index 000000000..cc733ae45 --- /dev/null +++ b/neural_speed/models/gemma/gemma.h @@ -0,0 +1,60 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GEMMA_H +#define GEMMA_H + +#include "models/model_utils/model_files.h" +#include "models/model_utils/model_types.h" + +enum gemma_model { + GEMMA_2B, + GEMMA_7B, +}; + +static const model_scratch gemma_mem_req(int n_layers, float enlarge_scale = 1.0f) { + switch (n_layers) { + case 18: + return { + static_cast(enlarge_scale * 1024) * MB, + static_cast(enlarge_scale * 1024) * MB, + static_cast(enlarge_scale * 1608) * MB, + }; + case 28: + return { + static_cast(enlarge_scale * 1024) * MB, + static_cast(enlarge_scale * 1024) * MB, + static_cast(enlarge_scale * 1608) * MB, + }; + default: + MODEL_ASSERT(false); + } +} + +class Gemma : public IModel { + private: + model_archs arch = MODEL_GEMMA; + std::unique_ptr ml; + uint32_t n_layer, n_embd, n_ff, n_vocab, n_head, n_head_kv, n_expert, n_expert_used, n_embd_head_k; + int n_gpu_layer; + bool use_mmap, use_mlock, vocab_only; + model_scratch scratch; + + public: + void init(const char* path_model, model_context* ctx, int n_gpu_layers, bool use_mmap_, bool use_mlock_, + bool vocab_only_) override; + void load(model_context* ctx, model_progress_callback progress_callback, void* progress_callback_user_data) override; +}; + +#endif // GEMMA_H diff --git a/neural_speed/models/gemma/gemma_utils.cpp b/neural_speed/models/gemma/gemma_utils.cpp new file mode 100644 index 000000000..d7c2c4637 --- /dev/null +++ b/neural_speed/models/gemma/gemma_utils.cpp @@ -0,0 +1,236 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/data_types.h" +#include "core/ne.h" +#include "core/ne_layers.h" +#include "models/gemma/gemma.h" +#include "models/model_utils/model_config.h" +#include "models/model_utils/model_files.h" +#include "models/model_utils/model_types.h" +#include "models/model_utils/quant_utils.h" +#include "models/model_utils/util.h" +#include "models/models.h" + +void model_load_internal(const std::string& fname, model_archs arch, model_context* ctx, int n_gpu_layers, + bool use_mmap, bool use_mlock, bool vocab_only, model_progress_callback progress_callback, + void* progress_callback_user_data) { + std::unique_ptr ms(new Gemma()); + ms->init(fname.c_str(), ctx, n_gpu_layers, use_mmap, use_mlock, vocab_only); + ms->load(ctx, progress_callback, progress_callback_user_data); + + model_context& lctx = *ctx; + lctx.support_bestla_kv = true; +} + +void Gemma::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bool use_mmap_, bool use_mlock_, + bool vocab_only_) { + model_context& lctx = *ctx; + n_gpu_layer = n_gpu_layer_; + use_mmap = use_mmap_; + use_mlock = use_mlock_; + vocab_only = vocab_only_; + auto& model = lctx.model; + ml.reset(new model_model_loader(path_model, use_mmap, vocab_only)); + lctx.vocab = std::move(ml->file_loaders.at(0)->vocab); + model.hparams = ml->file_loaders.at(0)->hparams; + model_file_version file_version = ml->file_loaders.at(0)->file_version; + auto& hparams = model.hparams; + n_ff = hparams.ffn_hidden_size; + fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.max_seq_len); + fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); + fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); + fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); + fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); + fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); + fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); + fprintf(stderr, "%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); + fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); + fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size()); + n_embd = hparams.n_embd; + n_vocab = hparams.n_vocab; + n_layer = hparams.n_layer; + n_head_kv = hparams.n_head_kv; + n_embd_head_k = hparams.n_embd_head_k; + n_head = hparams.n_head; + n_expert = hparams.n_experts; + n_expert_used = hparams.n_experts_used; + scratch = gemma_mem_req(n_layer, lctx.scratch_size_ratio); + model.scratchs = scratch; +} + +#define MODEL_BACKEND_OFFLOAD NE_BACKEND_CPU +void Gemma::load(model_context* ctx, model_progress_callback progress_callback, void* progress_callback_user_data) { + model_context& lctx = *ctx; + auto& model = lctx.model; + auto& ne_ctx = model.ctx; + size_t ctx_size; + size_t mmapped_size; + ml->calc_sizes(&ctx_size, &mmapped_size); + fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, 
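Gemma carries an explicit per-head dimension (n_embd_head_k) instead of relying on n_embd / n_head, so Gemma::init reads and logs it and the attention weight shapes loaded below are built from it. A small sketch of the resulting Q/K/V/O projection shapes; the gemma-7b numbers in main() come from the upstream Hugging Face config, not from this diff:

```cpp
#include <cstdio>

// Projection output/input dims as used when loading the attention weights below.
struct AttnShapes {
  int q_out, kv_out, o_in;  // q_proj out, k/v_proj out, o_proj in
};

AttnShapes gemma_attn_shapes(int n_embd, int n_head, int n_head_kv, int n_embd_head_k) {
  (void)n_embd;  // the input dim of q/k/v stays n_embd
  return {
      n_embd_head_k * n_head,     // attn_q.weight: {n_embd, n_embd_head_k * n_head}
      n_embd_head_k * n_head_kv,  // attn_k/v.weight: {n_embd, n_embd_head_k * n_head_kv}
      n_embd_head_k * n_head,     // attn_output.weight: {n_embd_head_k * n_head, n_embd}
  };
}

int main() {
  // e.g. gemma-7b: n_embd = 3072, 16 heads, 16 KV heads, head_dim = 256
  AttnShapes s = gemma_attn_shapes(3072, 16, 16, 256);
  std::printf("q: %d, k/v: %d, o_in: %d\n", s.q_out, s.kv_out, s.o_in);  // 4096, 4096, 4096
  return 0;
}
```

Note that 4096 differs from n_embd (3072), which is exactly why the head dimension has to be carried as its own hyperparameter.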
ctx_size / 1024.0 / 1024.0); + + // create the ne context + lctx.model.buf.resize(ctx_size); + if (use_mlock) { + lctx.model.mlock_buf.init(lctx.model.buf.addr); + lctx.model.mlock_buf.grow_to(lctx.model.buf.size); + } + + struct ne_init_params params = { + /*.mem_size =*/lctx.model.buf.size, + /*.mem_buffer =*/lctx.model.buf.addr, + /*.no_alloc =*/ml->use_mmap, + }; + + model.ctx = ne_init(params); + if (!model.ctx) { + throw format("ne_init() failed"); + } + + ml->ne_ctx = ne_ctx; + + const int i_gpu_start = n_layer - n_gpu_layer; + model.layers.resize(n_layer); + size_t vram_total = 0; + if (ml->verify_tensor("token_embd.weight")) { + model.others[0] = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); + model.others[1] = ml->get_tensor("output_norm.weight", {n_embd}, NE_BACKEND_CPU); + model.others[2] = model.others[0]; + + for (uint32_t i = 0; i < n_layer; ++i) { + const ne_backend backend = static_cast(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD; + auto& layer = model.layers[i]; + std::string layers_i = "blk." + std::to_string(i); + + // attention norm + layer.norm[0] = ml->get_tensor(layers_i + ".attn_norm.weight", {n_embd}, backend); + + // qkv GEMM + layer.attn[0] = ml->get_tensor(layers_i + ".attn_q.weight", {n_embd, n_embd_head_k * n_head}, backend); + layer.attn[1] = ml->get_tensor(layers_i + ".attn_k.weight", {n_embd, n_embd_head_k * n_head_kv}, backend); + layer.attn[2] = ml->get_tensor(layers_i + ".attn_v.weight", {n_embd, n_embd_head_k * n_head_kv}, backend); + layer.attn[3] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd_head_k * n_head, n_embd}, backend); + + // ffn norm + layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); + + // ffn GEMM + if (ml->verify_tensor(layers_i + ".ffn_gate.weight")) { + NE_ASSERT(n_expert == 0); + NE_ASSERT(n_expert_used == 0); + layer.ffn[0] = ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, n_ff}, backend); + layer.ffn[1] = ml->get_tensor(layers_i + ".ffn_down.weight", {n_ff, n_embd}, backend); + layer.ffn[2] = ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, n_ff}, backend); + } + if (backend != NE_BACKEND_CPU) { + vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + + ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.norm[1]) + + ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]); + } + } + } else { + model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); + model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU); + model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ne_backend backend = static_cast(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD; + auto& layer = model.layers[i]; + std::string layers_i = "model.layers." 
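The loader above probes for GGUF-style names first (token_embd.weight, blk.N.*) and otherwise falls back to the Hugging Face checkpoint names (model.embed_tokens.weight, model.layers.N.*); in the GGUF branch the output head simply reuses the embedding tensor. A sketch of the mapping between the two naming schemes (the helper is illustrative, not an API of the loader):

```cpp
#include <string>

// Map a logical per-layer tensor to whichever naming scheme the file uses,
// mirroring the two loading branches above.
std::string layer_tensor_name(bool gguf_names, int il, const std::string& logical) {
  if (gguf_names) {
    return "blk." + std::to_string(il) + "." + logical;  // e.g. blk.0.attn_q.weight
  }
  // Hugging Face checkpoint names used by the fallback branch.
  static const struct { const char* gguf; const char* hf; } table[] = {
      {"attn_norm.weight", "input_layernorm.weight"},
      {"attn_q.weight", "self_attn.q_proj.weight"},
      {"attn_k.weight", "self_attn.k_proj.weight"},
      {"attn_v.weight", "self_attn.v_proj.weight"},
      {"attn_output.weight", "self_attn.o_proj.weight"},
      {"ffn_norm.weight", "post_attention_layernorm.weight"},
      {"ffn_gate.weight", "mlp.gate_proj.weight"},
      {"ffn_down.weight", "mlp.down_proj.weight"},
      {"ffn_up.weight", "mlp.up_proj.weight"},
  };
  for (const auto& row : table) {
    if (logical == row.gguf) return "model.layers." + std::to_string(il) + "." + row.hf;
  }
  return {};
}
```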
+ std::to_string(i); + + // attention norm + layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend); + + // qkv GEMM + layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd_head_k * n_head}, backend); + layer.attn[1] = + ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd_head_k * n_head_kv}, backend); + layer.attn[2] = + ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd_head_k * n_head_kv}, backend); + layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd_head_k * n_head, n_embd}, backend); + + // ffn norm + layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend); + + // ffn GEMM + layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight", {n_embd, n_ff}, backend); + layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight", {n_ff, n_embd}, backend); + layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, n_ff}, backend); + if (backend != NE_BACKEND_CPU) { + vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + + ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.norm[1]) + + ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]); + } + } + } + + // print memory requirements + // this is the total memory required to run the inference + const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory + scratch.scratch0 + scratch.scratch1 + scratch.eval; + fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); + + (void)n_gpu_layer; + + // populate `tensors_by_name` + for (model_load_tensor& lt : ml->tensors_map.tensors) { + model.tensors_by_name.emplace_back(lt.name, lt.ne_tensor); + } + + ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? 
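The "mem required" figure printed above is the context metadata plus the mmapped weights, minus whatever was accounted to VRAM, plus the three scratch buffers. A sketch of that arithmetic; every number in main() is made up for illustration:

```cpp
#include <cstddef>
#include <cstdio>

// Same accounting as the "mem required" line above: host-side context plus the
// weights that stay in system memory, plus the per-model scratch buffers.
size_t gemma_mem_required(size_t ctx_size, size_t mmapped_size, size_t vram_total,
                          size_t scratch0, size_t scratch1, size_t eval) {
  return ctx_size + mmapped_size - vram_total + scratch0 + scratch1 + eval;
}

int main() {
  constexpr size_t MB = 1024ull * 1024ull;
  // Illustrative only: ~5 GB of weights fully on CPU, default gemma scratch budgets.
  size_t need = gemma_mem_required(64 * MB, 5120 * MB, 0, 1024 * MB, 1024 * MB, 1608 * MB);
  std::printf("mem required = %.2f MB\n", need / 1024.0 / 1024.0);  // 8840.00 MB
  return 0;
}
```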
&lctx.model.mlock_mmap : nullptr); + + if (progress_callback) { + progress_callback(1.0f, progress_callback_user_data); + } + + model.mapping = std::move(ml->mapping); +} + +#undef MODEL_BACKEND_OFFLOAD + +class gemma_quant_layer : public quant_layer_base { + public: + quant_params_internal get_layer_config(std::string layername, std::vector ne, ne_type type) override { + bool quantize = layername.rfind("weight") == layername.size() - 6; + if ((layername.find("embedding") != std::string::npos) || + (layername == "token_embd.weight" || layername == "model.embed_tokens.weight")) { + // special layer process, can be loaded by config file + return quant_params_internal{quant_bits::q8}; // q80 + } + quantize &= (ne.size() == 2); + if (quantize) { + return mGCfg; // use global quant config + } else { + return quant_params_internal{quant_bits::count}; // non-quant + } + } +}; +REGISTER_QUANT_LAYER_CLASS(gemma); diff --git a/neural_speed/models/model_utils/gguf.h b/neural_speed/models/model_utils/gguf.h index 22d0435ca..791fb97a9 100644 --- a/neural_speed/models/model_utils/gguf.h +++ b/neural_speed/models/model_utils/gguf.h @@ -232,6 +232,7 @@ enum llm_arch { LLM_ARCH_CHATGLM2, LLM_ARCH_CHATGLM3, LLM_ARCH_PHI, + LLM_ARCH_GEMMA, LLM_ARCH_QWEN2, LLM_ARCH_UNKNOWN, }; @@ -253,6 +254,7 @@ static std::map LLM_ARCH_NAMES = {{LLM_ARCH_LLAMA, "llama {LLM_ARCH_CHATGLM2, "chatglm2"}, {LLM_ARCH_CHATGLM3, "chatglm3"}, {LLM_ARCH_PHI, "phi"}, + {LLM_ARCH_GEMMA, "gemma"}, {LLM_ARCH_QWEN2, "qwen2"}}; struct gguf_tensor_info { @@ -432,6 +434,8 @@ enum llm_kv { LLM_KV_ATTENTION_HEAD_COUNT_KV, LLM_KV_ATTENTION_MAX_ALIBI_BIAS, LLM_KV_ATTENTION_CLAMP_KQV, + LLM_KV_ATTENTION_KEY_LENGTH, + LLM_KV_ATTENTION_VALUE_LENGTH, LLM_KV_ATTENTION_LAYERNORM_EPS, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, LLM_KV_NUM_EXPERTS, @@ -486,6 +490,8 @@ static std::map LLM_KV_NAMES = { {LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv"}, {LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias"}, {LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv"}, + {LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length"}, + {LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length"}, {LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon"}, {LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon"}, diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h index b702afc52..b586d0ea5 100644 --- a/neural_speed/models/model_utils/model_files.h +++ b/neural_speed/models/model_utils/model_files.h @@ -911,6 +911,8 @@ struct gguf_loader { GGUF_GET_KEY(ctx_gguf, hparams.n_experts, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_NUM_EXPERTS)); GGUF_GET_KEY(ctx_gguf, hparams.n_experts_used, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_NUM_EXPERTS_USED)); + GGUF_GET_KEY(ctx_gguf, hparams.n_embd_head_k, gguf_get_val_u32, GGUF_TYPE_UINT32, false, + kv(LLM_KV_ATTENTION_KEY_LENGTH)); GGUF_GET_KEY(ctx_gguf, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_BLOCK_COUNT)); GGUF_GET_KEY(ctx_gguf, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT)); @@ -1115,11 +1117,14 @@ struct model_file_loader { hparams.inner_hidden_size = file.read_u32(); hparams.n_experts = file.read_u32(); hparams.n_experts_used = file.read_u32(); + hparams.n_embd_head_k = file.read_u32(); printf("%-16s %d.hparams.inner_hidden_size = %-30d\n", __func__, count++, hparams.inner_hidden_size); printf("%-16s %d.hparams.n_experts = %-30d\n", 
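gemma_quant_layer above keeps the embedding table (which doubles as the output head) at q8 and sends every other 2-D *.weight tensor through the global quantization config, leaving 1-D norm weights untouched. The same decision logic restated as a standalone sketch (the enum and helper name are illustrative):

```cpp
#include <cstdint>
#include <string>
#include <vector>

enum class QuantChoice { kQ8Embedding, kGlobalConfig, kSkip };

// Mirrors the per-tensor decision made by gemma_quant_layer above.
QuantChoice gemma_quant_choice(const std::string& name, const std::vector<int64_t>& shape) {
  const bool is_weight = name.size() >= 6 && name.rfind("weight") == name.size() - 6;
  const bool is_embedding = name.find("embedding") != std::string::npos ||
                            name == "token_embd.weight" || name == "model.embed_tokens.weight";
  if (is_embedding) return QuantChoice::kQ8Embedding;  // shared embed / lm_head table stays q8
  if (is_weight && shape.size() == 2) return QuantChoice::kGlobalConfig;  // regular GEMM weights
  return QuantChoice::kSkip;  // norms and other 1-D tensors stay unquantized
}
```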
__func__, count++, hparams.n_experts); printf("%-16s %d.hparams.n_experts_used = %-30d\n", __func__, count++, hparams.n_experts_used); + printf("%-16s %d.hparams.n_embd_head_k = %-30d\n", __func__, count++, hparams.n_embd_head_k); file.read_raw(&hparams.norm_eps, sizeof(float)); + file.read_raw(&hparams.freq_base, sizeof(float)); file.read_raw(&hparams.freq_scale, sizeof(float)); printf("%-16s %d.hparams.norm_eps = %-30f\n", __func__, count++, hparams.norm_eps); @@ -1262,6 +1267,7 @@ struct model_file_saver { file.write_u32(hparams.inner_hidden_size); file.write_u32(hparams.n_experts); file.write_u32(hparams.n_experts_used); + file.write_u32(hparams.n_embd_head_k); file.write_raw(&hparams.norm_eps, sizeof(float)); file.write_raw(&hparams.freq_base, sizeof(float)); diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h index 3575a7e2b..454fbbb34 100644 --- a/neural_speed/models/model_utils/model_types.h +++ b/neural_speed/models/model_utils/model_types.h @@ -84,6 +84,7 @@ enum model_archs { MODEL_CHATGLM, MODEL_QWEN, MODEL_PHI, + MODEL_GEMMA, MODEL_STABLELM, MODEL_WHISPER }; @@ -144,6 +145,7 @@ struct model_hparams { int32_t inner_hidden_size = 0; uint32_t n_experts = 0; uint32_t n_experts_used = 0; + uint32_t n_embd_head_k = 0; float rope_scaling_factor = 0.0f; int32_t original_max_position_embeddings = 0; @@ -486,7 +488,8 @@ class model_name_to_arch { {"falcon", MODEL_FALCON}, {"bloom", MODEL_BLOOM}, {"chatglm2", MODEL_CHATGLM2}, {"chatglm", MODEL_CHATGLM}, {"baichuan", MODEL_BAICHUAN}, {"mistral", MODEL_LLAMA}, {"qwen", MODEL_QWEN}, {"phi", MODEL_PHI}, {"stablelm", MODEL_STABLELM}, - {"whisper", MODEL_WHISPER}, {"chatglm3", MODEL_CHATGLM3}, {"mixtral", MODEL_LLAMA}}; + {"whisper", MODEL_WHISPER}, {"chatglm3", MODEL_CHATGLM3}, {"mixtral", MODEL_LLAMA}, + {"gemma", MODEL_GEMMA}}; }; #ifdef __cplusplus diff --git a/neural_speed/models/model_utils/model_utils.cpp b/neural_speed/models/model_utils/model_utils.cpp index bb5b65fa7..266f9486b 100644 --- a/neural_speed/models/model_utils/model_utils.cpp +++ b/neural_speed/models/model_utils/model_utils.cpp @@ -63,7 +63,8 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c const bool shift_roped_k, model_struct* model) { const auto n_layer = hparams.n_layer; auto heads_kv = hparams.n_head_kv > 0 ? hparams.n_head_kv : hparams.n_head; - const auto head_size = hparams.n_embd / hparams.n_head; + const auto head_size = hparams.n_embd_head_k == 0 ? hparams.n_embd / hparams.n_head : hparams.n_embd_head_k; + #ifdef NS_TP_MODEL // when use TP, cached kv will also have smaller size parallel_context* p_ctx = init_parallel_context(); @@ -103,7 +104,7 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c auto& k_cache = model->layers[il].k_cache; auto& v_cache = model->layers[il].v_cache; if (wtype == NE_TYPE_F16) { // chatglm does not support fp32 kv-cache in original impl of chatglm_util.cpp - const int head_size = hparams.n_embd / hparams.n_head; + const auto head_size = hparams.n_embd_head_k == 0 ? hparams.n_embd / hparams.n_head : hparams.n_embd_head_k; const int heads_kv = hparams.multi_query_group_num > 0 ? 
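Both KV-cache paths now prefer the explicit per-head dimension and fall back to n_embd / n_head whenever the converter wrote 0 for the new field, which leaves every pre-existing model unchanged. The rule in isolation (helper name illustrative):

```cpp
#include <cstdint>

// Head size rule used by kv_cache_init and the beam-search KV copy above:
// prefer the explicit per-head dimension (Gemma), otherwise derive it.
inline uint32_t kv_head_size(uint32_t n_embd, uint32_t n_head, uint32_t n_embd_head_k) {
  return n_embd_head_k == 0 ? n_embd / n_head : n_embd_head_k;
}

// e.g. gemma-7b: kv_head_size(3072, 16, 256) == 256, not 3072 / 16 == 192
```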
hparams.multi_query_group_num : hparams.n_head; k_cache = d_ne_new_tensor_4d(model->ctx, NE_TYPE_F16, head_size, n_ctx, heads_kv, batch_size * beam_size); v_cache = d_ne_new_tensor_4d(model->ctx, NE_TYPE_F16, n_ctx, head_size, heads_kv, batch_size * beam_size); @@ -1833,7 +1834,7 @@ static void bestla_model_kv_cache_seq_cpy(struct model_context* ctx, const model const auto& kv_self = ctx->model.kv_self; const auto& hparams = ctx->model.hparams; int heads_kv = hparams.multi_query_group_num > 0 ? hparams.multi_query_group_num : hparams.n_head; - const int head_size = hparams.n_embd / hparams.n_head; + const auto head_size = hparams.n_embd_head_k == 0 ? hparams.n_embd / hparams.n_head : hparams.n_embd_head_k; #ifdef NS_TP_MODEL // when use TP, cached kv will also have smaller size parallel_context* p_ctx = init_parallel_context(); @@ -1857,7 +1858,7 @@ static void bestla_model_kv_cache_seq_cpy(struct model_context* ctx, const model /* .src = */ nullptr, /* .dst = */ nullptr, /* .heads_kv = */ heads_kv, - /* .head_size = */ head_size, + /* .head_size = */ static_cast(head_size), /* .seq_off = */ p0, /* .seq_size = */ p1 - p0, /* .seq_max = */ n_ctx, diff --git a/tests/model-test/cpp_graph_inference.sh b/tests/model-test/cpp_graph_inference.sh index 7702f3db6..97faa47a0 100644 --- a/tests/model-test/cpp_graph_inference.sh +++ b/tests/model-test/cpp_graph_inference.sh @@ -159,6 +159,7 @@ model_name_map["phi2"]="microsoft/phi-2" model_name_map["stablelm"]="stabilityai/stablelm-2-1_6b" model_name_map["qwen-1_5"]="Qwen/Qwen1.5-7B-Chat" model_name_map["mixtral"]="mistralai/Mixtral-8x7B-Instruct-v0.1" +model_name_map["gemma-2b"]="google/gemma-2b-it" model_name_map["mixtral-gptq"]="Mixtral-8x7B-Instruct-v0.1-GPTQ" model_name_map["qwen1.5-gptq"]="Qwen/Qwen1.5-7B-Chat-GPTQ" model_name_map["qwen-gptq"]="TheBloke/Qwen-7B-Chat-GPTQ" @@ -167,6 +168,7 @@ model_name_map["falcon7b-gptq"]="Falcon-7B-Instruct-GPTQ" model_name_map["baichuan13b-gptq"]="Baichuan2-13B-Chat-GPTQ" model_name_map["mistral-gptq"]="TheBloke/Mistral-7B-Instruct-v0.2-GPTQ" + function main() { conda_env="$1" model="$2" @@ -297,6 +299,10 @@ function main() { quant_script="./build/bin/quant_mixtral" convert_script="${convert_script}/convert_mixtral.py" infer_cmd="./build/bin/run_mixtral" + elif [[ "${model}" == "gemma-2b" ]]; then + quant_script="./build/bin/quant_gemma" + convert_script="${convert_script}/convert_gemma.py" + infer_cmd="./build/bin/run_gemma" elif [[ "${model}" == *"-gptq" ]]; then infer_cmd="python $working_dir/scripts/python_api_example_for_gptq.py ${model_path}" precision_list+=("default") @@ -440,7 +446,7 @@ function main() { else real_ctx=$ctx # TODO(Zhenzhong): use same ctx for chatglm & baichuan [[ "${model}" == "chatglm2" || "${model}" == "chatglm-6b" || - "${model}" == "baichuan-13b" || "${model}" == "baichuan2-13b" ]] && real_ctx=1300 + "${model}" == "baichuan-13b" || "${model}" == "baichuan2-13b" ]] && real_ctx=2048 if [[ "${model}" == *"gptq" ]]; then NEURAL_SPEED_VERBOSE=1 OMP_NUM_THREADS=$cores_per_instance numactl -m 0 -C 0-$(($cores_per_instance - 1)) $infer_cmd 2>&1 | tee ${WORKSPACE}/${logs_file} || true & else
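With the fp16 cache shapes above (K is head_size x n_ctx x heads_kv per batch-beam slot, V the same element count with the first two dims swapped), Gemma-2b's multi-query attention keeps the cache small even at a full context. A quick sizing sketch; the gemma-2b figures in main() (head_dim 256, 1 KV head, 18 layers, 8192-token context) come from the upstream model config rather than from this diff:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// fp16 KV-cache bytes per layer for the 4-D tensors created above
// (K and V hold the same number of elements).
size_t kv_cache_bytes_per_layer(size_t head_size, size_t n_ctx, size_t heads_kv,
                                size_t batch_size, size_t beam_size) {
  const size_t elems = head_size * n_ctx * heads_kv * batch_size * beam_size;
  return 2 /* K and V */ * elems * sizeof(uint16_t) /* fp16 */;
}

int main() {
  // gemma-2b: head_size 256, 1 KV head, 18 layers, 8192 context, batch = beam = 1
  size_t per_layer = kv_cache_bytes_per_layer(256, 8192, 1, 1, 1);
  std::printf("per layer: %.1f MB, total: %.1f MB\n",
              per_layer / 1024.0 / 1024.0, 18 * per_layer / 1024.0 / 1024.0);  // 8.0 MB, 144.0 MB
  return 0;
}
```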