Update gpt2 preprocess and add deepseek coder preprocess #4070

Open · wants to merge 23 commits into base: master
Changes from 3 commits
82 changes: 24 additions & 58 deletions convert-hf-to-gguf.py
@@ -166,32 +166,37 @@ def from_model_architecture(model_architecture):
return RefactModel
if model_architecture == "PersimmonForCausalLM":
return PersimmonModel
if model_architecture == "LlamaForCausalLM":
return DeepseekCoderModel
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return StableLMModel
return Model

@staticmethod
def from_model_name(model_name: str):
if model_name == "StableLMEpoch":
model_name_lower = model_name.lower()
if model_name_lower == "stablelmepoch":
return StableLMModel
if model_name == "GPTNeoX":
if model_name_lower == "gptneox":
return GPTNeoXModel
if model_name == "Bloom":
if model_name_lower == "bloom":
return BloomModel
if model_name == "MPT":
if model_name_lower == "mpt":
return MPTModel
if model_name in ("Baichuan", "BaiChuan"):
if model_name_lower == "baichuan":
return BaichuanModel
if model_name in ("Falcon", "RW"):
if model_name_lower in ("falcon", "rw"):
return FalconModel
if model_name == "GPTBigCode":
if model_name_lower == "gptbigcode":
return StarCoderModel
if model_name == "GPTRefact":
if model_name_lower == "gptrefact":
return RefactModel
if model_name == "Persimmon":
if model_name_lower == "persimmon":
return PersimmonModel
if model_name == "DeepseekCoder":
if model_name_lower == "deepseekcoder":
return DeepseekCoderModel
if model_name_lower == "stablelm":
return StableLMModel
return Model

def _is_model_safetensors(self) -> bool:
@@ -227,10 +232,12 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.PERSIMMON
if arch == "LlamaForCausalLM":
return gguf.MODEL_ARCH.LLAMA
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return gguf.MODEL_ARCH.STABLELM

raise NotImplementedError(f'Architecture "{arch}" not supported!')

def _set_vocab_gpt2(self):
def _set_vocab_gpt2(self, tokenizer_model: str = "gpt2"):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
@@ -259,7 +266,7 @@ def _set_vocab_gpt2(self):
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)

self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_model(tokenizer_model)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

@@ -840,20 +847,15 @@ def write_tensors(self):
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)


class DeepseekCoderModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
super().set_gguf_parameters()
print(self.dir_model.name)
head_count = self.hparams["num_attention_heads"]
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
ctx_length = self.hparams["max_position_embeddings"]

self.gguf_writer.add_name("deepseek_coder")
self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(head_count_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
@@ -864,43 +866,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

def set_vocab(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
toktypes: list[int] = []

from transformers import AutoTokenizer # type: ignore[attr-defined]
tokenizer = AutoTokenizer.from_pretrained(dir_model)
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size

reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
added_vocab = tokenizer.get_added_vocab()
special_tokens = tokenizer.all_special_tokens
for i in range(vocab_size):
if i not in reverse_vocab:
pad_token = f"[PAD{i}]".encode('utf-8')
tokens.append(bytearray(pad_token))
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if reverse_vocab[i] in special_tokens:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)

self.gguf_writer.add_tokenizer_model("deepseek_coder")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
special_vocab.add_to_gguf(self.gguf_writer)



self._set_vocab_gpt2("deepseek_coder")


class StableLMModel(Model):
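The net effect of the convert-hf-to-gguf.py changes above is that DeepseekCoderModel.set_vocab no longer carries its own copy of the GPT-2-style vocab export; it calls the shared helper with a different tokenizer-model string. Below is a minimal sketch of that shared pattern, written as a standalone function purely for illustration (the name export_hf_vocab is not part of the script; the gguf and transformers calls are the ones already used in the diff), assuming a Hugging Face tokenizer directory:

import gguf
from transformers import AutoTokenizer  # type: ignore[attr-defined]

def export_hf_vocab(gguf_writer, dir_model, hparams, tokenizer_model: str = "gpt2"):
    # Sketch of the logic consolidated into Model._set_vocab_gpt2 above.
    tokenizer = AutoTokenizer.from_pretrained(dir_model)
    vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
    assert max(tokenizer.vocab.values()) < vocab_size

    reverse_vocab = {id_: tok for tok, id_ in tokenizer.vocab.items()}
    added_vocab = tokenizer.get_added_vocab()
    special_tokens = tokenizer.all_special_tokens

    tokens: list[bytearray] = []
    toktypes: list[int] = []
    for i in range(vocab_size):
        if i not in reverse_vocab:
            # holes in the id space become placeholder tokens
            tokens.append(bytearray(f"[PAD{i}]".encode("utf-8")))
            toktypes.append(gguf.TokenType.USER_DEFINED)
        elif reverse_vocab[i] in added_vocab:
            # added tokens: CONTROL if special, otherwise USER_DEFINED
            tokens.append(reverse_vocab[i])
            toktypes.append(gguf.TokenType.CONTROL if reverse_vocab[i] in special_tokens
                            else gguf.TokenType.USER_DEFINED)
        else:
            tokens.append(reverse_vocab[i])
            toktypes.append(gguf.TokenType.NORMAL)

    gguf_writer.add_tokenizer_model(tokenizer_model)  # "gpt2" or "deepseek_coder"
    gguf_writer.add_token_list(tokens)
    gguf_writer.add_token_types(toktypes)
    gguf.SpecialVocab(dir_model, load_merges=True).add_to_gguf(gguf_writer)

Parameterizing only the tokenizer-model string keeps the token-type classification identical for both vocabularies, which is what the removed DeepseekCoderModel.set_vocab body duplicated before.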
24 changes: 12 additions & 12 deletions llama.cpp
@@ -2300,10 +2300,10 @@ static void llm_load_vocab(
vocab.special_sep_id = -1;
vocab.special_pad_id = -1;
} else if (tokenizer_name == "gpt2" || tokenizer_name == "deepseek_coder") {
if(tokenizer_name == "gpt2"){
if(tokenizer_name == "gpt2") {
vocab.type = LLAMA_VOCAB_TYPE_BPE;
}
else if (tokenizer_name == "deepseek_coder"){
else if (tokenizer_name == "deepseek_coder") {
vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKCODER;
}

@@ -2502,7 +2502,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
// hparams
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : (vocab.type ==LLAMA_VOCAB_TYPE_BPE ? "BPE" : "DEEPSEEKCODER")); // TODO: fix
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : (vocab.type == LLAMA_VOCAB_TYPE_BPE ? "BPE" : "DEEPSEEKCODER")); // TODO: fix
teleprint-me (Contributor) commented on Nov 29, 2023:
This is SPM?

>>> from transformers import AutoTokenizer
>>> model_directory = "models/deepseek-ai/deepseek-coder-6.7b-instruct"
>>> tokenizer = AutoTokenizer.from_pretrained(model_directory)
>>> tokenizer.vocab_files_names
{'vocab_file': 'tokenizer.model', 'tokenizer_file': 'tokenizer.json'}
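
The tokenizer.model file listed there is what raises the question, but the conversion in this PR goes through the GPT-2-style (BPE) path rather than SPM, as the _set_vocab_gpt2("deepseek_coder") call above shows. One way to check which backend the fast tokenizer actually uses, assuming the standard Hugging Face tokenizer.json layout where the trained model sits under the "model" key:

>>> import json
>>> tok_json = "models/deepseek-ai/deepseek-coder-6.7b-instruct/tokenizer.json"
>>> json.load(open(tok_json))["model"]["type"]  # e.g. "BPE" for a byte-level BPE backend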

LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
@@ -5984,7 +5984,7 @@ struct llm_tokenizer_bpe {
work_queue.push(bigram);
}

std::vector<std::string> byte_encoding_process(const std::vector<std::string> &bpe_words){
std::vector<std::string> byte_encoding_process(const std::vector<std::string> &bpe_words) {
std::vector<std::string>bpe_encoded_words;
for (auto word : bpe_words) {
std::string text_utf = "";
@@ -6001,12 +6001,12 @@ struct llm_tokenizer_bpe {
return bpe_encoded_words;
}

std::vector<std::string> regex_preprocess(const std::vector<std::string> &input, const std::string & regex_expr){
std::vector<std::string> regex_preprocess(const std::vector<std::string> &input, const std::string & regex_expr) {
std::regex expr(regex_expr);
std::vector<std::string> bpe_words;
// std::wsmatch m;
// // use regex match to get where to split the test string
for(auto& text:input){
for(auto& text:input) {
std::cregex_iterator it(text.data(), text.data() + text.size(), expr);
std::cregex_iterator end;

@@ -6015,14 +6015,14 @@
while (it != end) {
std::cmatch match = *it;
std::string match_str = match.str();
if(match.position()>start_idx){
if(match.position()>start_idx) {
bpe_words.emplace_back(text.substr(start_idx, match.position()-start_idx));
}
bpe_words.emplace_back(match_str);
start_idx = match.position() + match.length();
++it;
}
if(start_idx < text.size()){
if(start_idx < text.size()) {
bpe_words.emplace_back(text.substr(start_idx, text.size()-start_idx));
}
}
@@ -6033,7 +6033,7 @@

std::vector<std::string> bpe_words = {text};

for(auto & regex_expr : gpt2_regex){
for(auto & regex_expr : gpt2_regex) {
bpe_words = regex_preprocess(bpe_words, regex_expr);
}

@@ -6056,18 +6056,18 @@
while (it != end) {
std::wcmatch match = *it;
std::wstring match_str = match.str();
if(match.position()>start_idx){
if(match.position()>start_idx) {
bpe_words.emplace_back(to_utf8(wtext.substr(start_idx, match.position()-start_idx)));
}
bpe_words.emplace_back(to_utf8(match_str));
start_idx = match.position() + match.length();
++it;
}
if(start_idx < wtext.size()){
if(start_idx < wtext.size()) {
bpe_words.emplace_back(to_utf8(wtext.substr(start_idx, wtext.size()-start_idx)));
}

for(auto & regex_expr : deepseek_coder_regex){
for(auto & regex_expr : deepseek_coder_regex) {
bpe_words = regex_preprocess(bpe_words, regex_expr);
}

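The hunks above apply an ordered list of regular expressions (gpt2_regex or deepseek_coder_regex), each pass re-splitting the pieces produced by the previous one, before byte-level BPE runs. A rough Python equivalent of the splitting step, kept pattern-agnostic since the actual patterns live inside llama.cpp and are not reproduced here (the function names below are illustrative, and Python's re stands in for std::regex):

import re

def regex_split(pieces, pattern):
    # Split each piece into regex matches plus the text between matches,
    # preserving order; mirrors llm_tokenizer_bpe::regex_preprocess above.
    out = []
    for text in pieces:
        start = 0
        for m in re.finditer(pattern, text):
            if m.start() > start:        # text before the match
                out.append(text[start:m.start()])
            out.append(m.group(0))       # the match itself
            start = m.end()
        if start < len(text):            # trailing text after the last match
            out.append(text[start:])
    return out

def pretokenize(text, regex_list):
    pieces = [text]
    for pattern in regex_list:           # apply each pattern in sequence
        pieces = regex_split(pieces, pattern)
    return pieces

Keeping the pattern list per tokenizer is what lets the GPT-2 and DeepSeek Coder pre-tokenizers share the same splitting code while differing only in which regexes they apply.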
Binary file modified models/ggml-vocab-deepseek-coder.gguf