update and refactor
This commit is contained in:
parent
7d971ee3d9
commit
3f4185b654
3 changed files with 34 additions and 70 deletions
|
@ -166,32 +166,37 @@ class Model:
|
|||
return RefactModel
|
||||
if model_architecture == "PersimmonForCausalLM":
|
||||
return PersimmonModel
|
||||
if model_architecture == "LlamaForCausalLM":
|
||||
return DeepseekCoderModel
|
||||
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
||||
return StableLMModel
|
||||
return Model
|
||||
|
||||
@staticmethod
|
||||
def from_model_name(model_name: str):
|
||||
if model_name == "StableLMEpoch":
|
||||
model_name_lower = model_name.lower()
|
||||
if model_name_lower == "stablelmepoch":
|
||||
return StableLMModel
|
||||
if model_name == "GPTNeoX":
|
||||
if model_name_lower == "gptneox":
|
||||
return GPTNeoXModel
|
||||
if model_name == "Bloom":
|
||||
if model_name_lower == "bloom":
|
||||
return BloomModel
|
||||
if model_name == "MPT":
|
||||
if model_name_lower == "mpt":
|
||||
return MPTModel
|
||||
if model_name in ("Baichuan", "BaiChuan"):
|
||||
if model_name_lower in ("baichuan", "baichuan"):
|
||||
return BaichuanModel
|
||||
if model_name in ("Falcon", "RW"):
|
||||
if model_name_lower in ("falcon", "rw"):
|
||||
return FalconModel
|
||||
if model_name == "GPTBigCode":
|
||||
if model_name_lower == "gptbigcode":
|
||||
return StarCoderModel
|
||||
if model_name == "GPTRefact":
|
||||
if model_name_lower == "gptrefact":
|
||||
return RefactModel
|
||||
if model_name == "Persimmon":
|
||||
if model_name_lower == "persimmon":
|
||||
return PersimmonModel
|
||||
if model_name == "DeepseekCoder":
|
||||
if model_name_lower == "deepseekcoder":
|
||||
return DeepseekCoderModel
|
||||
if model_name_lower == "stablelm":
|
||||
return StableLMModel
|
||||
return Model
|
||||
|
||||
def _is_model_safetensors(self) -> bool:
|
||||
|
@ -232,7 +237,7 @@ class Model:
|
|||
|
||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||
|
||||
def _set_vocab_gpt2(self):
|
||||
def _set_vocab_gpt2(self, tokenizer_model:str = "gpt2"):
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
tokens: list[bytearray] = []
|
||||
|
@ -261,7 +266,7 @@ class Model:
|
|||
tokens.append(reverse_vocab[i])
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_model(tokenizer_model)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
|
@ -842,20 +847,15 @@ class PersimmonModel(Model):
|
|||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
class DeepseekCoderModel(Model):
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
super().set_gguf_parameters()
|
||||
print(self.dir_model.name)
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
ctx_length = self.hparams["max_position_embeddings"]
|
||||
|
||||
self.gguf_writer.add_name("deepseek_coder")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
||||
|
@ -866,43 +866,7 @@ class DeepseekCoderModel(Model):
|
|||
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||
|
||||
def set_vocab(self):
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
tokens: list[bytearray] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer # type: ignore[attr-defined]
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model)
|
||||
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
assert max(tokenizer.vocab.values()) < vocab_size
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
special_tokens = tokenizer.all_special_tokens
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
pad_token = f"[PAD{i}]".encode('utf-8')
|
||||
tokens.append(bytearray(pad_token))
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
elif reverse_vocab[i] in added_vocab:
|
||||
tokens.append(reverse_vocab[i])
|
||||
if reverse_vocab[i] in special_tokens:
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
tokens.append(reverse_vocab[i])
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("deepseek_coder")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
|
||||
|
||||
self._set_vocab_gpt2("deepseek_coder")
|
||||
|
||||
|
||||
class StableLMModel(Model):
|
||||
|
|
24
llama.cpp
24
llama.cpp
|
@ -2300,10 +2300,10 @@ static void llm_load_vocab(
|
|||
vocab.special_sep_id = -1;
|
||||
vocab.special_pad_id = -1;
|
||||
} else if (tokenizer_name == "gpt2" || tokenizer_name == "deepseek_coder") {
|
||||
if(tokenizer_name == "gpt2"){
|
||||
if(tokenizer_name == "gpt2") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
||||
}
|
||||
else if (tokenizer_name == "deepseek_coder"){
|
||||
else if (tokenizer_name == "deepseek_coder") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKCODER;
|
||||
}
|
||||
|
||||
|
@ -2502,7 +2502,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
|||
// hparams
|
||||
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
|
||||
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
|
||||
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : (vocab.type ==LLAMA_VOCAB_TYPE_BPE ? "BPE" : "DEEPSEEKCODER")); // TODO: fix
|
||||
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : (vocab.type == LLAMA_VOCAB_TYPE_BPE ? "BPE" : "DEEPSEEKCODER")); // TODO: fix
|
||||
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
|
||||
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
|
||||
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
|
||||
|
@ -5984,7 +5984,7 @@ private:
|
|||
work_queue.push(bigram);
|
||||
}
|
||||
|
||||
std::vector<std::string> byte_encoding_process(const std::vector<std::string> &bpe_words){
|
||||
std::vector<std::string> byte_encoding_process(const std::vector<std::string> &bpe_words) {
|
||||
std::vector<std::string>bpe_encoded_words;
|
||||
for (auto word : bpe_words) {
|
||||
std::string text_utf = "";
|
||||
|
@ -6001,12 +6001,12 @@ private:
|
|||
return bpe_encoded_words;
|
||||
}
|
||||
|
||||
std::vector<std::string> regex_preprocess(const std::vector<std::string> &input, const std::string & regex_expr){
|
||||
std::vector<std::string> regex_preprocess(const std::vector<std::string> &input, const std::string & regex_expr) {
|
||||
std::regex expr(regex_expr);
|
||||
std::vector<std::string> bpe_words;
|
||||
// std::wsmatch m;
|
||||
// // use regex match to get where to split the test string
|
||||
for(auto& text:input){
|
||||
for(auto& text:input) {
|
||||
std::cregex_iterator it(text.data(), text.data() + text.size(), expr);
|
||||
std::cregex_iterator end;
|
||||
|
||||
|
@ -6015,14 +6015,14 @@ private:
|
|||
while (it != end) {
|
||||
std::cmatch match = *it;
|
||||
std::string match_str = match.str();
|
||||
if(match.position()>start_idx){
|
||||
if(match.position()>start_idx) {
|
||||
bpe_words.emplace_back(text.substr(start_idx, match.position()-start_idx));
|
||||
}
|
||||
bpe_words.emplace_back(match_str);
|
||||
start_idx = match.position() + match.length();
|
||||
++it;
|
||||
}
|
||||
if(start_idx < text.size()){
|
||||
if(start_idx < text.size()) {
|
||||
bpe_words.emplace_back(text.substr(start_idx, text.size()-start_idx));
|
||||
}
|
||||
}
|
||||
|
@ -6033,7 +6033,7 @@ private:
|
|||
|
||||
std::vector<std::string> bpe_words = {text};
|
||||
|
||||
for(auto & regex_expr : gpt2_regex){
|
||||
for(auto & regex_expr : gpt2_regex) {
|
||||
bpe_words = regex_preprocess(bpe_words, regex_expr);
|
||||
}
|
||||
|
||||
|
@ -6056,18 +6056,18 @@ private:
|
|||
while (it != end) {
|
||||
std::wcmatch match = *it;
|
||||
std::wstring match_str = match.str();
|
||||
if(match.position()>start_idx){
|
||||
if(match.position()>start_idx) {
|
||||
bpe_words.emplace_back(to_utf8(wtext.substr(start_idx, match.position()-start_idx)));
|
||||
}
|
||||
bpe_words.emplace_back(to_utf8(match_str));
|
||||
start_idx = match.position() + match.length();
|
||||
++it;
|
||||
}
|
||||
if(start_idx < wtext.size()){
|
||||
if(start_idx < wtext.size()) {
|
||||
bpe_words.emplace_back(to_utf8(wtext.substr(start_idx, wtext.size()-start_idx)));
|
||||
}
|
||||
|
||||
for(auto & regex_expr : deepseek_coder_regex){
|
||||
for(auto & regex_expr : deepseek_coder_regex) {
|
||||
bpe_words = regex_preprocess(bpe_words, regex_expr);
|
||||
}
|
||||
|
||||
|
|
Binary file not shown.
Loading…
Add table
Add a link
Reference in a new issue