add new gpt2
This commit is contained in:
parent
21fd874c8d
commit
5600bd8cbc
3 changed files with 130 additions and 126 deletions
|
@ -169,6 +169,30 @@ class Model:
|
||||||
if model_architecture == "PersimmonForCausalLM":
|
if model_architecture == "PersimmonForCausalLM":
|
||||||
return PersimmonModel
|
return PersimmonModel
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_model_name(model_name: str):
|
||||||
|
if model_name == "StableLMEpoch":
|
||||||
|
return StableLMModel
|
||||||
|
if model_name == "GPTNeoX":
|
||||||
|
return GPTNeoXModel
|
||||||
|
if model_name == "Bloom":
|
||||||
|
return BloomModel
|
||||||
|
if model_name == "MPT":
|
||||||
|
return MPTModel
|
||||||
|
if model_name in ("Baichuan", "BaiChuan"):
|
||||||
|
return BaichuanModel
|
||||||
|
if model_name in ("Falcon", "RW"):
|
||||||
|
return FalconModel
|
||||||
|
if model_name == "GPTBigCode":
|
||||||
|
return StarCoderModel
|
||||||
|
if model_name == "GPTRefact":
|
||||||
|
return RefactModel
|
||||||
|
if model_name == "Persimmon":
|
||||||
|
return PersimmonModel
|
||||||
|
if model_name == "DeepseekCoder":
|
||||||
|
return DeepseekCoderModel
|
||||||
|
return Model
|
||||||
|
|
||||||
def _is_model_safetensors(self) -> bool:
|
def _is_model_safetensors(self) -> bool:
|
||||||
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
|
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
|
||||||
|
@ -201,6 +225,8 @@ class Model:
|
||||||
return gguf.MODEL_ARCH.REFACT
|
return gguf.MODEL_ARCH.REFACT
|
||||||
if arch == "PersimmonForCausalLM":
|
if arch == "PersimmonForCausalLM":
|
||||||
return gguf.MODEL_ARCH.PERSIMMON
|
return gguf.MODEL_ARCH.PERSIMMON
|
||||||
|
if arch == "LlamaForCausalLM":
|
||||||
|
return gguf.MODEL_ARCH.LLAMA
|
||||||
|
|
||||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||||
|
|
||||||
|
@ -823,6 +849,68 @@ class PersimmonModel(Model):
|
||||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
self.gguf_writer.add_tensor(new_name, data)
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
|
class DeepseekCoderModel(Model):
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
block_count = self.hparams["num_hidden_layers"]
|
||||||
|
head_count = self.hparams["num_attention_heads"]
|
||||||
|
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||||
|
ctx_length = self.hparams["max_position_embeddings"]
|
||||||
|
|
||||||
|
self.gguf_writer.add_name("deepseek_coder")
|
||||||
|
self.gguf_writer.add_context_length(ctx_length)
|
||||||
|
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||||
|
self.gguf_writer.add_block_count(block_count)
|
||||||
|
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||||
|
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||||
|
self.gguf_writer.add_head_count(head_count)
|
||||||
|
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||||
|
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||||
|
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
||||||
|
|
||||||
|
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||||
|
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||||
|
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||||
|
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||||
|
|
||||||
|
def set_vocab(self):
|
||||||
|
dir_model = self.dir_model
|
||||||
|
hparams = self.hparams
|
||||||
|
tokens: list[bytearray] = []
|
||||||
|
toktypes: list[int] = []
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer # type: ignore[attr-defined]
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(dir_model)
|
||||||
|
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
|
||||||
|
assert max(tokenizer.vocab.values()) < vocab_size
|
||||||
|
|
||||||
|
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
||||||
|
added_vocab = tokenizer.get_added_vocab()
|
||||||
|
special_tokens = tokenizer.all_special_tokens
|
||||||
|
for i in range(vocab_size):
|
||||||
|
if i not in reverse_vocab:
|
||||||
|
pad_token = f"[PAD{i}]".encode('utf-8')
|
||||||
|
tokens.append(bytearray(pad_token))
|
||||||
|
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||||
|
elif reverse_vocab[i] in added_vocab:
|
||||||
|
tokens.append(reverse_vocab[i])
|
||||||
|
if reverse_vocab[i] in special_tokens:
|
||||||
|
toktypes.append(gguf.TokenType.CONTROL)
|
||||||
|
else:
|
||||||
|
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||||
|
else:
|
||||||
|
tokens.append(reverse_vocab[i])
|
||||||
|
toktypes.append(gguf.TokenType.NORMAL)
|
||||||
|
|
||||||
|
self.gguf_writer.add_tokenizer_model("deepseek_coder")
|
||||||
|
self.gguf_writer.add_token_list(tokens)
|
||||||
|
self.gguf_writer.add_token_types(toktypes)
|
||||||
|
|
||||||
|
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
|
||||||
|
special_vocab.add_to_gguf(self.gguf_writer)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
###### CONVERSION LOGIC ######
|
###### CONVERSION LOGIC ######
|
||||||
|
|
||||||
|
@ -845,6 +933,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
"model", type=Path,
|
"model", type=Path,
|
||||||
help="directory containing model file",
|
help="directory containing model file",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@ -871,7 +960,7 @@ print(f"Loading model: {dir_model.name}")
|
||||||
|
|
||||||
hparams = Model.load_hparams(dir_model)
|
hparams = Model.load_hparams(dir_model)
|
||||||
|
|
||||||
model_class = Model.from_model_architecture(hparams["architectures"][0])
|
model_class = Model.from_model_name(args.model_name) if args.model_name else Model.from_model_architecture(hparams["architectures"][0])
|
||||||
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)
|
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)
|
||||||
|
|
||||||
print("Set model parameters")
|
print("Set model parameters")
|
||||||
|
|
146
llama.cpp
146
llama.cpp
|
@ -2265,7 +2265,7 @@ static void llm_load_vocab(
|
||||||
vocab.special_unk_id = 0;
|
vocab.special_unk_id = 0;
|
||||||
vocab.special_sep_id = -1;
|
vocab.special_sep_id = -1;
|
||||||
vocab.special_pad_id = -1;
|
vocab.special_pad_id = -1;
|
||||||
} else if (tokenizer_name == "gpt2") {
|
} else if (tokenizer_name == "gpt2" || tokenizer_name == "deepseek_coder") {
|
||||||
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
||||||
|
|
||||||
// read bpe merges and populate bpe ranks
|
// read bpe merges and populate bpe ranks
|
||||||
|
@ -5682,136 +5682,32 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
|
std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
|
||||||
|
|
||||||
std::vector<std::string> bpe_words;
|
std::vector<std::string> bpe_words;
|
||||||
std::vector<std::string> bpe_encoded_words;
|
std::vector<std::string> bpe_encoded_words;
|
||||||
|
// convert input string to wstring
|
||||||
std::string token = "";
|
std::wstring input = from_utf8(text);
|
||||||
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
|
std::wstring regex = from_utf8(gpt2_regex);
|
||||||
bool collecting_numeric = false;
|
std::wregex expr(regex);
|
||||||
bool collecting_letter = false;
|
// std::wsmatch m;
|
||||||
bool collecting_special = false;
|
// // use regex match to get where to split the test string
|
||||||
bool collecting_whitespace_lookahead = false;
|
int array[] = {-1,0};
|
||||||
bool collecting = false;
|
std::wsregex_token_iterator iter(input.begin(), input.end(), expr, array);
|
||||||
|
std::wsregex_token_iterator end;
|
||||||
std::vector<std::string> text_utf;
|
for ( ; iter != end; ++iter){
|
||||||
text_utf.reserve(text.size());
|
if ((*iter).length()>0){
|
||||||
bpe_words.reserve(text.size());
|
bpe_words.push_back(to_utf8(*iter));
|
||||||
bpe_encoded_words.reserve(text.size());
|
|
||||||
|
|
||||||
auto cps = codepoints_from_utf8(text);
|
|
||||||
for (size_t i = 0; i < cps.size(); ++i)
|
|
||||||
text_utf.emplace_back(codepoint_to_utf8(cps[i]));
|
|
||||||
|
|
||||||
for (int i = 0; i < (int)text_utf.size(); i++) {
|
|
||||||
const std::string & utf_char = text_utf[i];
|
|
||||||
bool split_condition = false;
|
|
||||||
int bytes_remain = text_utf.size() - i;
|
|
||||||
// forward backward lookups
|
|
||||||
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
|
|
||||||
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
|
|
||||||
|
|
||||||
// handling contractions
|
|
||||||
if (!split_condition && bytes_remain >= 2) {
|
|
||||||
// 's|'t|'m|'d
|
|
||||||
if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
|
|
||||||
split_condition = true;
|
|
||||||
}
|
|
||||||
if (split_condition) {
|
|
||||||
if (token.size()) {
|
|
||||||
bpe_words.emplace_back(token); // push previous content as token
|
|
||||||
}
|
|
||||||
token = utf_char + utf_char_next;
|
|
||||||
bpe_words.emplace_back(token);
|
|
||||||
token = "";
|
|
||||||
i++;
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!split_condition && bytes_remain >= 3) {
|
// convert each word to utf8
|
||||||
// 're|'ve|'ll
|
|
||||||
if (utf_char == "\'" && (
|
|
||||||
(utf_char_next == "r" && utf_char_next_next == "e") ||
|
|
||||||
(utf_char_next == "v" && utf_char_next_next == "e") ||
|
|
||||||
(utf_char_next == "l" && utf_char_next_next == "l"))
|
|
||||||
) {
|
|
||||||
split_condition = true;
|
|
||||||
}
|
|
||||||
if (split_condition) {
|
|
||||||
// current token + next token can be defined
|
|
||||||
if (token.size()) {
|
|
||||||
bpe_words.emplace_back(token); // push previous content as token
|
|
||||||
}
|
|
||||||
token = utf_char + utf_char_next + utf_char_next_next;
|
|
||||||
bpe_words.emplace_back(token); // the contraction
|
|
||||||
token = "";
|
|
||||||
i += 2;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!split_condition && !collecting) {
|
|
||||||
if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
|
|
||||||
collecting_letter = true;
|
|
||||||
collecting = true;
|
|
||||||
}
|
|
||||||
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
|
|
||||||
collecting_numeric = true;
|
|
||||||
collecting = true;
|
|
||||||
}
|
|
||||||
else if (
|
|
||||||
((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
|
|
||||||
(!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
|
|
||||||
) {
|
|
||||||
collecting_special = true;
|
|
||||||
collecting = true;
|
|
||||||
}
|
|
||||||
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
|
|
||||||
collecting_whitespace_lookahead = true;
|
|
||||||
collecting = true;
|
|
||||||
}
|
|
||||||
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
|
|
||||||
split_condition = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (!split_condition && collecting) {
|
|
||||||
if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) {
|
|
||||||
split_condition = true;
|
|
||||||
}
|
|
||||||
else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
|
|
||||||
split_condition = true;
|
|
||||||
}
|
|
||||||
else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
|
|
||||||
split_condition = true;
|
|
||||||
}
|
|
||||||
else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
|
|
||||||
split_condition = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (utf_char_next == "") {
|
|
||||||
split_condition = true; // final
|
|
||||||
token += utf_char;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (split_condition) {
|
|
||||||
if (token.size()) {
|
|
||||||
bpe_words.emplace_back(token);
|
|
||||||
}
|
|
||||||
token = utf_char;
|
|
||||||
collecting = false;
|
|
||||||
collecting_letter = false;
|
|
||||||
collecting_numeric = false;
|
|
||||||
collecting_special = false;
|
|
||||||
collecting_whitespace_lookahead = false;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
token += utf_char;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (std::string & word : bpe_words) {
|
for (std::string & word : bpe_words) {
|
||||||
|
std::string text_utf = "";
|
||||||
|
auto utf_word = codepoints_from_utf8(word);
|
||||||
|
for (size_t i = 0; i < utf_word.size(); ++i)
|
||||||
|
text_utf += codepoint_to_utf8(utf_word[i]);
|
||||||
|
|
||||||
std::string encoded_token = "";
|
std::string encoded_token = "";
|
||||||
for (char & c : word) {
|
for (char & c : text_utf) {
|
||||||
encoded_token += bytes_to_unicode_bpe(c);
|
encoded_token += bytes_to_unicode_bpe(c);
|
||||||
}
|
}
|
||||||
bpe_encoded_words.emplace_back(encoded_token);
|
bpe_encoded_words.emplace_back(encoded_token);
|
||||||
|
|
19
unicode.h
19
unicode.h
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue