update and refactor

Bingxuan Wang 2023-11-16 13:58:59 +08:00
parent 7d971ee3d9
commit 3f4185b654
3 changed files with 34 additions and 70 deletions

convert-hf-to-gguf.py

@@ -166,32 +166,37 @@ class Model:
             return RefactModel
         if model_architecture == "PersimmonForCausalLM":
             return PersimmonModel
-        if model_architecture == "LlamaForCausalLM":
-            return DeepseekCoderModel
         if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return StableLMModel
         return Model

     @staticmethod
     def from_model_name(model_name: str):
-        if model_name == "StableLMEpoch":
+        model_name_lower = model_name.lower()
+        if model_name_lower == "stablelmepoch":
             return StableLMModel
-        if model_name == "GPTNeoX":
+        if model_name_lower == "gptneox":
             return GPTNeoXModel
-        if model_name == "Bloom":
+        if model_name_lower == "bloom":
             return BloomModel
-        if model_name == "MPT":
+        if model_name_lower == "mpt":
             return MPTModel
-        if model_name in ("Baichuan", "BaiChuan"):
+        if model_name_lower == "baichuan":
             return BaichuanModel
-        if model_name in ("Falcon", "RW"):
+        if model_name_lower in ("falcon", "rw"):
             return FalconModel
-        if model_name == "GPTBigCode":
+        if model_name_lower == "gptbigcode":
             return StarCoderModel
-        if model_name == "GPTRefact":
+        if model_name_lower == "gptrefact":
             return RefactModel
-        if model_name == "Persimmon":
+        if model_name_lower == "persimmon":
             return PersimmonModel
-        if model_name == "DeepseekCoder":
+        if model_name_lower == "deepseekcoder":
             return DeepseekCoderModel
+        if model_name_lower == "stablelm":
+            return StableLMModel
         return Model

     def _is_model_safetensors(self) -> bool:
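Two things change here: the blanket mapping of "LlamaForCausalLM" to DeepseekCoderModel is dropped (it would have routed every Llama-architecture checkpoint through the DeepSeek converter), and `from_model_name` now normalizes the name once, so any capitalization of "deepseekcoder" resolves to the same class. A minimal sketch of the lookup pattern, with stub classes standing in for the real converter classes:

```python
# Sketch of case-insensitive dispatch; the stub classes are stand-ins
# for the real converter classes.
class Model: ...
class StableLMModel(Model): ...
class DeepseekCoderModel(Model): ...

_BY_NAME = {
    "stablelmepoch": StableLMModel,
    "stablelm": StableLMModel,
    "deepseekcoder": DeepseekCoderModel,
}

def from_model_name(model_name: str) -> type:
    # Normalize once; "DeepSeekCoder" and "deepseekcoder" hit the same key.
    return _BY_NAME.get(model_name.lower(), Model)

assert from_model_name("DeepSeekCoder") is DeepseekCoderModel
```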
@@ -232,7 +237,7 @@ class Model:
         raise NotImplementedError(f'Architecture "{arch}" not supported!')

-    def _set_vocab_gpt2(self):
+    def _set_vocab_gpt2(self, tokenizer_model: str = "gpt2"):
         dir_model = self.dir_model
         hparams = self.hparams
         tokens: list[bytearray] = []
@@ -261,7 +266,7 @@ class Model:
             tokens.append(reverse_vocab[i])
             toktypes.append(gguf.TokenType.NORMAL)

-        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_model(tokenizer_model)
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_types(toktypes)
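These two hunks parameterize the shared GPT-2 vocab writer: `_set_vocab_gpt2` gains a `tokenizer_model` argument defaulting to "gpt2", so existing callers are untouched while a subclass can write its own tokenizer identifier into the GGUF metadata. A toy sketch of the pattern (`print` stands in for the real `gguf_writer.add_tokenizer_model` call):

```python
# Toy sketch: the defaulted parameter keeps old call sites unchanged.
def set_vocab_gpt2(tokenizer_model: str = "gpt2") -> None:
    # Real code: self.gguf_writer.add_tokenizer_model(tokenizer_model)
    print(f"tokenizer model = {tokenizer_model}")

set_vocab_gpt2()                  # existing callers still write "gpt2"
set_vocab_gpt2("deepseek_coder")  # DeepseekCoderModel's override
```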
@@ -842,20 +847,15 @@ class PersimmonModel(Model):
             print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
             self.gguf_writer.add_tensor(new_name, data)


 class DeepseekCoderModel(Model):
     def set_gguf_parameters(self):
-        block_count = self.hparams["num_hidden_layers"]
-        print(self.dir_model.name)
+        super().set_gguf_parameters()
         head_count = self.hparams["num_attention_heads"]
         head_count_kv = self.hparams.get("num_key_value_heads", head_count)
-        ctx_length = self.hparams["max_position_embeddings"]
-        self.gguf_writer.add_name("deepseek_coder")
-        self.gguf_writer.add_context_length(ctx_length)
-        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
-        self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
-        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
         self.gguf_writer.add_head_count(head_count)
+        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
         self.gguf_writer.add_head_count_kv(head_count_kv)
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
         self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
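The writes deleted here (model name, context length, embedding width, block count, feed-forward size) are the ones the base `Model.set_gguf_parameters` already performs, so the override now calls `super()` first and keeps only the DeepSeek-specific keys. A rough sketch of that split, with illustrative hparams and `print` standing in for the writer calls:

```python
# Hypothetical sketch of the base/override split; hparams values are
# illustrative, and print() stands in for the gguf_writer calls.
class Model:
    def __init__(self, hparams: dict):
        self.hparams = hparams

    def set_gguf_parameters(self) -> None:
        # Fields shared by every architecture, written once here.
        print("block_count =", self.hparams["num_hidden_layers"])
        print("context_len =", self.hparams["max_position_embeddings"])

class DeepseekCoderModel(Model):
    def set_gguf_parameters(self) -> None:
        super().set_gguf_parameters()  # common fields first
        head_count = self.hparams["num_attention_heads"]
        # Architecture-specific fields only.
        print("rope_dims   =", self.hparams["hidden_size"] // head_count)

DeepseekCoderModel({
    "num_hidden_layers": 32,
    "max_position_embeddings": 16384,
    "hidden_size": 4096,
    "num_attention_heads": 32,
}).set_gguf_parameters()
```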
@@ -866,43 +866,7 @@ class DeepseekCoderModel(Model):
             self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

     def set_vocab(self):
-        dir_model = self.dir_model
-        hparams = self.hparams
-        tokens: list[bytearray] = []
-        toktypes: list[int] = []
-
-        from transformers import AutoTokenizer  # type: ignore[attr-defined]
-        tokenizer = AutoTokenizer.from_pretrained(dir_model)
-        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
-        assert max(tokenizer.vocab.values()) < vocab_size
-
-        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
-        added_vocab = tokenizer.get_added_vocab()
-        special_tokens = tokenizer.all_special_tokens
-
-        for i in range(vocab_size):
-            if i not in reverse_vocab:
-                pad_token = f"[PAD{i}]".encode('utf-8')
-                tokens.append(bytearray(pad_token))
-                toktypes.append(gguf.TokenType.USER_DEFINED)
-            elif reverse_vocab[i] in added_vocab:
-                tokens.append(reverse_vocab[i])
-                if reverse_vocab[i] in special_tokens:
-                    toktypes.append(gguf.TokenType.CONTROL)
-                else:
-                    toktypes.append(gguf.TokenType.USER_DEFINED)
-            else:
-                tokens.append(reverse_vocab[i])
-                toktypes.append(gguf.TokenType.NORMAL)
-
-        self.gguf_writer.add_tokenizer_model("deepseek_coder")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-
-        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_gpt2("deepseek_coder")


 class StableLMModel(Model):
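The deleted body duplicated, almost line for line, the vocab walk that `_set_vocab_gpt2` already performs: fill index gaps with `[PAD{i}]` placeholders, mark added tokens as CONTROL or USER_DEFINED, and tag everything else NORMAL. With the tokenizer name parameterized, the override collapses to a single call. A compact sketch of that classification loop, with a stand-in enum for `gguf.TokenType`:

```python
# Sketch of the shared vocab walk; TokenType stands in for gguf.TokenType.
from enum import IntEnum

class TokenType(IntEnum):
    NORMAL = 1
    CONTROL = 3
    USER_DEFINED = 4

def classify(vocab_size, reverse_vocab, added_vocab, special_tokens):
    tokens, toktypes = [], []
    for i in range(vocab_size):
        if i not in reverse_vocab:
            tokens.append(f"[PAD{i}]")                 # fill index gaps
            toktypes.append(TokenType.USER_DEFINED)
        elif reverse_vocab[i] in added_vocab:
            tokens.append(reverse_vocab[i])
            toktypes.append(TokenType.CONTROL if reverse_vocab[i] in special_tokens
                            else TokenType.USER_DEFINED)
        else:
            tokens.append(reverse_vocab[i])
            toktypes.append(TokenType.NORMAL)          # plain vocab entry
    return tokens, toktypes

# e.g. one special token, one normal token, two padding slots:
print(classify(4, {0: "<s>", 1: "hello"}, {"<s>"}, {"<s>"}))
```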

llama.cpp

@@ -2300,10 +2300,10 @@ static void llm_load_vocab(
             vocab.special_sep_id = -1;
             vocab.special_pad_id = -1;
         } else if (tokenizer_name == "gpt2" || tokenizer_name == "deepseek_coder") {
-            if(tokenizer_name == "gpt2"){
+            if(tokenizer_name == "gpt2") {
                 vocab.type = LLAMA_VOCAB_TYPE_BPE;
             }
-            else if (tokenizer_name == "deepseek_coder"){
+            else if (tokenizer_name == "deepseek_coder") {
                 vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKCODER;
             }
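On the loading side, the tokenizer model string the converter wrote selects the vocab implementation: "gpt2" keeps ordinary BPE and "deepseek_coder" selects the new DeepSeek-specific type. A Python sketch of that dispatch (the enum is a stand-in for the `LLAMA_VOCAB_TYPE_*` constants; the SPM fallback mirrors the default elsewhere in `llm_load_vocab`):

```python
# Sketch of the vocab-type dispatch; VocabType is a stand-in enum.
from enum import Enum, auto

class VocabType(Enum):
    SPM = auto()
    BPE = auto()
    DEEPSEEKCODER = auto()

def pick_vocab_type(tokenizer_name: str) -> VocabType:
    if tokenizer_name == "gpt2":
        return VocabType.BPE
    if tokenizer_name == "deepseek_coder":
        return VocabType.DEEPSEEKCODER
    return VocabType.SPM  # "llama" default, handled earlier in llm_load_vocab

assert pick_vocab_type("deepseek_coder") is VocabType.DEEPSEEKCODER
```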
@@ -2502,7 +2502,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     // hparams
     LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
     LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
-    LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : (vocab.type ==LLAMA_VOCAB_TYPE_BPE ? "BPE" : "DEEPSEEKCODER")); // TODO: fix
+    LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : (vocab.type == LLAMA_VOCAB_TYPE_BPE ? "BPE" : "DEEPSEEKCODER")); // TODO: fix
     LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
     LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
     LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
@@ -5984,7 +5984,7 @@ private:
         work_queue.push(bigram);
     }

-    std::vector<std::string> byte_encoding_process(const std::vector<std::string> &bpe_words){
+    std::vector<std::string> byte_encoding_process(const std::vector<std::string> &bpe_words) {
        std::vector<std::string> bpe_encoded_words;
        for (auto word : bpe_words) {
            std::string text_utf = "";
@@ -6001,12 +6001,12 @@ private:
        return bpe_encoded_words;
    }

-    std::vector<std::string> regex_preprocess(const std::vector<std::string> &input, const std::string & regex_expr){
+    std::vector<std::string> regex_preprocess(const std::vector<std::string> &input, const std::string & regex_expr) {
        std::regex expr(regex_expr);
        std::vector<std::string> bpe_words;
        // std::wsmatch m;
        // // use regex match to get where to split the text string
-       for(auto& text:input){
+       for(auto& text:input) {
            std::cregex_iterator it(text.data(), text.data() + text.size(), expr);
            std::cregex_iterator end;
@@ -6015,14 +6015,14 @@ private:
            while (it != end) {
                std::cmatch match = *it;
                std::string match_str = match.str();
-               if(match.position()>start_idx){
+               if(match.position()>start_idx) {
                    bpe_words.emplace_back(text.substr(start_idx, match.position()-start_idx));
                }
                bpe_words.emplace_back(match_str);
                start_idx = match.position() + match.length();
                ++it;
            }
-           if(start_idx < text.size()){
+           if(start_idx < text.size()) {
                bpe_words.emplace_back(text.substr(start_idx, text.size()-start_idx));
            }
        }
@@ -6033,7 +6033,7 @@ private:
        std::vector<std::string> bpe_words = {text};
-       for(auto & regex_expr : gpt2_regex){
+       for(auto & regex_expr : gpt2_regex) {
            bpe_words = regex_preprocess(bpe_words, regex_expr);
        }
@@ -6056,18 +6056,18 @@ private:
        while (it != end) {
            std::wcmatch match = *it;
            std::wstring match_str = match.str();
-           if(match.position()>start_idx){
+           if(match.position()>start_idx) {
                bpe_words.emplace_back(to_utf8(wtext.substr(start_idx, match.position()-start_idx)));
            }
            bpe_words.emplace_back(to_utf8(match_str));
            start_idx = match.position() + match.length();
            ++it;
        }
-       if(start_idx < wtext.size()){
+       if(start_idx < wtext.size()) {
            bpe_words.emplace_back(to_utf8(wtext.substr(start_idx, wtext.size()-start_idx)));
        }
-       for(auto & regex_expr : deepseek_coder_regex){
+       for(auto & regex_expr : deepseek_coder_regex) {
            bpe_words = regex_preprocess(bpe_words, regex_expr);
        }
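The brace fixes above all land in the tokenizer's preprocessing pipeline, whose shape is worth spelling out: `regex_preprocess` splits every input piece at the matches of one expression, keeping both the matched substrings and the gaps between them, and the gpt2/deepseek paths then fold a whole list of expressions over the text this way. A Python sketch of those two steps (the patterns below are placeholders, not the real `gpt2_regex`/`deepseek_coder_regex` sets):

```python
import re

def regex_preprocess(pieces: list[str], pattern: str) -> list[str]:
    out = []
    for text in pieces:
        start = 0
        for m in re.finditer(pattern, text):
            if m.start() > start:      # gap before the match
                out.append(text[start:m.start()])
            out.append(m.group())      # the match itself
            start = m.end()
        if start < len(text):          # trailing gap
            out.append(text[start:])
    return out

# Fold a list of expressions over the text, as the gpt2 path does.
words = ["print('hello world')  # 42"]
for expr in [r"\d+", r"[A-Za-z]+"]:    # placeholder patterns
    words = regex_preprocess(words, expr)
print(words)
```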