[LLM Runtime] Enable phi-2&phi-1.5&phi-1 (#78)
intellinjun authored Jan 25, 2024
1 parent ea4b713 commit c212d89
Showing 15 changed files with 1,005 additions and 13 deletions.
12 changes: 12 additions & 0 deletions docs/supported_models.md
@@ -171,6 +171,18 @@ Neural Speed supports the following models:
<td> </td>
<td> </td>
<td>Latest</td>
</tr>
<tr>
<td><a href="https://huggingface.co/microsoft/phi-2" target="_blank" rel="noopener noreferrer">phi-2</a>,
<a href="https://huggingface.co/microsoft/phi-1_5" target="_blank" rel="noopener noreferrer">phi-1_5</a>,
<a href="https://huggingface.co/microsoft/phi-1" target="_blank" rel="noopener noreferrer">phi-1</a></td>
<td>✅</td>
<td> </td>
<td> </td>
<td>✅</td>
<td> </td>
<td> </td>
<td>Latest</td>
</tr>
<tr>
<td><a href="https://huggingface.co/openai/whisper-tiny" target="_blank" rel="noopener noreferrer">Whisper-tiny</a>,
4 changes: 4 additions & 0 deletions neural_speed/__init__.py
@@ -64,6 +64,10 @@ def __import_package(self, model_type):
import neural_speed.qwen_cpp as cpp_model
elif model_type == "mistral":
import neural_speed.mistral_cpp as cpp_model
elif model_type == "qwen":
import neural_speed.qwen_cpp as cpp_model
elif model_type == "phi":
import neural_speed.phi_cpp as cpp_model
elif model_type == "whisper":
import neural_speed.whisper_cpp as cpp_model
else:
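With this dispatch in place, a model_type of "phi" (as reported in a checkpoint's config.json) resolves to the compiled neural_speed.phi_cpp extension. Below is a minimal sketch of exercising the new path from Python; the Model.init()/generate() interface follows the pattern in the project README, so treat the exact argument names as assumptions rather than a definitive API.

# Hedged sketch: run microsoft/phi-2 through the high-level Neural Speed wrapper.
# Assumes the Model.init()/generate() API shown in the repository README; the
# weight_dtype/compute_dtype argument names may differ in your installed version.
from transformers import AutoTokenizer
from neural_speed import Model

model_name = "microsoft/phi-2"
prompt = "Explain the difference between a list and a tuple in Python."

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer(prompt, return_tensors="pt").input_ids

model = Model()
# init() quantizes and loads the checkpoint; because config.json reports
# model_type == "phi", __import_package() binds neural_speed.phi_cpp here.
model.init(model_name, weight_dtype="int4", compute_dtype="int8")
outputs = model.generate(inputs, max_new_tokens=128)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])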
8 changes: 6 additions & 2 deletions neural_speed/application/CMakeLists.txt
@@ -67,7 +67,8 @@ compile_quant(quant_chatglm quant_model.cpp chatglm chatglm)
compile_quant(quant_chatglm2 quant_model.cpp chatglm2 chatglm2)
compile_quant(quant_baichuan quant_model.cpp baichuan baichuan)
compile_quant(quant_mistral quant_model.cpp mistral llama)
compile_quant(quant_qwen quant_model.cpp qwen qwen)
compile_quant(quant_qwen quant_model.cpp qwen qwen)
compile_quant(quant_phi quant_model.cpp phi phi)
compile_quant(quant_whisper quant_whisper.cpp whisper whisper)

# all models running
@@ -90,7 +91,9 @@ set(mymap_baichuan 12)
set(mymap_polyglot 13)
set(mymap_mistral 14)
set(mymap_qwen 15)
set(mymap_whisper 16)
set(mymap_phi 16)
set(mymap_whisper 17)



function(compile_run TARGET MAIN_CPP MAIN_PY MODEL_NAME MODEL_LIB)
@@ -125,6 +128,7 @@ compile_run(run_chatglm main_run.cpp main_pybind.cpp chatglm chatglm)
compile_run(run_baichuan main_run.cpp main_pybind.cpp baichuan baichuan)
compile_run(run_mistral main_run.cpp main_pybind.cpp mistral llama)
compile_run(run_qwen main_run.cpp main_pybind.cpp qwen qwen)
compile_run(run_phi main_run.cpp main_pybind.cpp phi phi)

# speech recognition
compile_run(run_whisper audio_run.cpp whisper_pybind.cpp whisper whisper)
8 changes: 8 additions & 0 deletions neural_speed/application/main_pybind.cpp
@@ -664,6 +664,14 @@ PYBIND11_MODULE(mistral_cpp, m)

PYBIND11_MODULE(qwen_cpp, m)

#elif MODEL_NAME_ID == 16

PYBIND11_MODULE(phi_cpp, m)

#elif MODEL_NAME_ID == 17

PYBIND11_MODULE(whisper_cpp, m)

#endif
{
m.doc() = "cpp model python binding";
2 changes: 1 addition & 1 deletion neural_speed/application/whisper_pybind.cpp
@@ -455,7 +455,7 @@ void Model::inference(const std::string& fname_inp) {
return;
}

#if MODEL_NAME_ID == 16
#if MODEL_NAME_ID == 17

PYBIND11_MODULE(whisper_cpp, m)
#endif
2 changes: 0 additions & 2 deletions neural_speed/convert/convert-hf-to-gguf.py
@@ -96,8 +96,6 @@ def set_gguf_parameters(self):
self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))

def write_tensors(self):
import pdb
pdb.set_trace()
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
for name, data_torch in self.get_tensors():
296 changes: 296 additions & 0 deletions neural_speed/convert/convert_phi.py
@@ -0,0 +1,296 @@
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Convert Hugging Face phi models (phi-1, phi-1_5, phi-2) to the NE or GGUF format
#
# Usage:
#
#   python convert_phi.py model_dir [--outtype {f32,f16}] [--outfile OUTFILE] [--format {NE,GGUF}]
#
# The structure mirrors the other convert scripts in neural_speed/convert.
#

import struct
import numpy as np
from pathlib import Path
import argparse
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar,
Union)
from transformers import AutoModelForCausalLM, AutoTokenizer
import gguf

# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
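# Illustrative check (not part of the converter): in the table returned above,
# the space byte maps to the visible GPT-2 marker, i.e.
# bytes_to_unicode()[ord(" ")] == "Ġ".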

def phi_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams):
print("phi.gguf converting: ")
list_vars = model.state_dict()
n_rot = int(hparams["partial_rotary_factor"]*hparams["hidden_size"]/hparams["num_attention_heads"])
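# Example: with phi-2's published config (hidden_size 2560, 32 heads,
# partial_rotary_factor 0.4) the expression above gives int(0.4 * 2560 / 32) = 32
# rotary dimensions out of an 80-dim head.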
for name in list_vars.keys():
print(name, list_vars[name].shape, list_vars[name].dtype)

print(hparams)

gguf_file = fname_out + '.gguf'
gguf_writer = gguf.GGUFWriter(gguf_file, "phi")

gguf_writer.add_uint32('magic', 0x67676d66)
gguf_writer.add_uint32('version', 1)
gguf_writer.add_uint32('n_vocab', hparams["vocab_size"])
gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_head_count(hparams["num_attention_heads"])
gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])

gguf_writer.add_block_count(hparams["num_hidden_layers"])
gguf_writer.add_rope_dimension_count(n_rot)
gguf_writer.add_uint32('ftype', ftype)
gguf_writer.add_context_length(hparams["max_position_embeddings"])
gguf_writer.add_feed_forward_length(hparams["intermediate_size"])

gguf_writer.add_bos_token_id(tokenizer.bos_token_id)
gguf_writer.add_eos_token_id(tokenizer.eos_token_id)
gguf_writer.add_pad_token_id(tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0)
gguf_writer.add_sep_token_id(tokenizer.sep_token_id if tokenizer.sep_token_id is not None else 0)

def write_vocab_gguf(dir_model, hparams, gguf_writer):
tokens: list[bytearray] = []
toktypes: list[int] = []

from transformers import AutoTokenizer # type: ignore[attr-defined]
tokenizer = AutoTokenizer.from_pretrained(dir_model)
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size

reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
added_vocab = tokenizer.get_added_vocab()

for i in range(vocab_size):
if i not in reverse_vocab:
pad_token = f"[PAD{i}]".encode('utf-8')
tokens.append(bytearray(pad_token))
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)

gguf_writer.add_tokenizer_model("gpt2")
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
special_vocab.add_to_gguf(gguf_writer)

write_vocab_gguf(dir_model, hparams, gguf_writer)

# tensor info
print("gguf: get tensor metadata")
for name in list_vars.keys():
data = list_vars[name].squeeze().numpy()

print("Processing variable: " + name + " with shape: ", data.shape)
if 'inv_freq' in name:
continue

n_dims = len(data.shape)

# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0
if ftype != 0:
if name[-7:] == ".weight" and n_dims == 2:
print(" Converting to float16")
data = data.astype(np.float16)
ftype_cur = 1
else:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
else:
if data.dtype != np.float32:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0

# print(f"[{i+1:{padi}d}/{len(model)}]
# Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4}")

gguf_writer.add_tensor(name, data)

print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()

gguf_writer.close()

print("Done. Output file: " + fname_out)
print("")

def phi_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
n_rot = int(hparams["partial_rotary_factor"]*hparams["hidden_size"]/hparams["num_attention_heads"])
model.eval()
for p in model.parameters():
p.requires_grad = False
hparams = model.config.to_dict()
print("Model loaded: ", dir_model)

fout = open(fname_out, "wb")

# 0x67676d6c is unversioned ne
# 0x67676d66 is versioned ggmf (requires token scores)
ne_file_magic = 0x67676d66
#ne_file_version = 0x00000001 # v1

fout.write(struct.pack("i", ne_file_magic)) # magic: ne in hex
fout.write(struct.pack("i", 1))
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["hidden_size"]))
fout.write(struct.pack("i", hparams["intermediate_size"])) # dummy data
fout.write(struct.pack("i", hparams["num_attention_heads"]))
fout.write(struct.pack("i", hparams["num_key_value_heads"])) # multi-query attention
fout.write(struct.pack("i", hparams["num_hidden_layers"]))
fout.write(struct.pack("i", n_rot))
fout.write(struct.pack("i", ftype))
fout.write(struct.pack("i", hparams["max_position_embeddings"]))
fout.write(struct.pack("f", 0.0))
fout.write(struct.pack("f", 0.0))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)

fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor
fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else -1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1))
fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))

for i in range(hparams["vocab_size"]):
if i < tokenizer.vocab_size:
text = tokenizer.decode([i]).encode('utf-8')
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", 0.0 - i))
else:
text = tokenizer.decode([tokenizer.vocab_size - 1]).encode('utf-8')
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", -10000))

list_vars = model.state_dict()

print(hparams)

for name in list_vars.keys():
# No gradients for these
list_vars[name].requires_grad = False
src = name
print(src, ' -> ', name)
data = list_vars[src].squeeze().numpy()
data = data.astype(np.float32)

n_dims = len(data.shape)
print(name, n_dims, data.shape)

# default type is fp32
ftype_cur = 0
if ftype == 1 and n_dims > 1:
print(" Converting to float16", data.shape, data[:3, :3].tolist())
data = data.astype(np.float16)
ftype_cur = 1
else:
print(" Converting to float32", data.shape, data[:3, :3].tolist() if n_dims > 1 else data[:3].tolist())
data = data.astype(np.float32)

# header
str = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
print(str)
fout.write(str)

# data
data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")

def main(args_in: Optional[List[str]] = None) -> None:
parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file")
parser.add_argument("--format",
type=str,
default="NE",
choices=["NE", "GGUF"],
help="convert to the GGUF or NE format")
args = parser.parse_args(args_in)

dir_model = args.model.as_posix()
fname_out = args.outfile.as_posix()

# possible data types
# ftype == 0 -> float32
# ftype == 1 -> float16
ftype = 0
if args.outtype == "f16":
ftype = 1

tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
print("Loading model: ", dir_model)
model = AutoModelForCausalLM.from_pretrained(dir_model, trust_remote_code=True)
hparams = model.config.to_dict()
if args.format == "GGUF":
phi_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams)
else:
phi_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)



if __name__ == '__main__':
main()

1 change: 1 addition & 0 deletions neural_speed/models/CMakeLists.txt
@@ -34,3 +34,4 @@ add_model(qwen qwen/qwen.cpp qwen/qwen_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(whisper whisper/whisper.cpp whisper/whisper_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(chatglm chatglm/chatglm.cpp chatglm/chatglm_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(chatglm2 chatglm/chatglm2.cpp chatglm/chatglm2_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(phi phi/phi.cpp phi/phi_utils.cpp ${MODEL_UTILS_SOURCE})