add new gpt2
This commit is contained in:
parent
21fd874c8d
commit
5600bd8cbc
3 changed files with 130 additions and 126 deletions
|
@ -169,6 +169,30 @@ class Model:
|
|||
if model_architecture == "PersimmonForCausalLM":
|
||||
return PersimmonModel
|
||||
return Model
|
||||
|
||||
@staticmethod
|
||||
def from_model_name(model_name: str):
|
||||
if model_name == "StableLMEpoch":
|
||||
return StableLMModel
|
||||
if model_name == "GPTNeoX":
|
||||
return GPTNeoXModel
|
||||
if model_name == "Bloom":
|
||||
return BloomModel
|
||||
if model_name == "MPT":
|
||||
return MPTModel
|
||||
if model_name in ("Baichuan", "BaiChuan"):
|
||||
return BaichuanModel
|
||||
if model_name in ("Falcon", "RW"):
|
||||
return FalconModel
|
||||
if model_name == "GPTBigCode":
|
||||
return StarCoderModel
|
||||
if model_name == "GPTRefact":
|
||||
return RefactModel
|
||||
if model_name == "Persimmon":
|
||||
return PersimmonModel
|
||||
if model_name == "DeepseekCoder":
|
||||
return DeepseekCoderModel
|
||||
return Model
|
||||
|
||||
def _is_model_safetensors(self) -> bool:
|
||||
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
|
||||
|
@ -201,6 +225,8 @@ class Model:
|
|||
return gguf.MODEL_ARCH.REFACT
|
||||
if arch == "PersimmonForCausalLM":
|
||||
return gguf.MODEL_ARCH.PERSIMMON
|
||||
if arch == "LlamaForCausalLM":
|
||||
return gguf.MODEL_ARCH.LLAMA
|
||||
|
||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||
|
||||
|
@ -823,6 +849,68 @@ class PersimmonModel(Model):
|
|||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
class DeepseekCoderModel(Model):
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
ctx_length = self.hparams["max_position_embeddings"]
|
||||
|
||||
self.gguf_writer.add_name("deepseek_coder")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
||||
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||
|
||||
def set_vocab(self):
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
tokens: list[bytearray] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
from transformers import AutoTokenizer # type: ignore[attr-defined]
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model)
|
||||
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
assert max(tokenizer.vocab.values()) < vocab_size
|
||||
|
||||
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
special_tokens = tokenizer.all_special_tokens
|
||||
for i in range(vocab_size):
|
||||
if i not in reverse_vocab:
|
||||
pad_token = f"[PAD{i}]".encode('utf-8')
|
||||
tokens.append(bytearray(pad_token))
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
elif reverse_vocab[i] in added_vocab:
|
||||
tokens.append(reverse_vocab[i])
|
||||
if reverse_vocab[i] in special_tokens:
|
||||
toktypes.append(gguf.TokenType.CONTROL)
|
||||
else:
|
||||
toktypes.append(gguf.TokenType.USER_DEFINED)
|
||||
else:
|
||||
tokens.append(reverse_vocab[i])
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("deepseek_coder")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
@ -845,6 +933,7 @@ def parse_args() -> argparse.Namespace:
|
|||
"model", type=Path,
|
||||
help="directory containing model file",
|
||||
)
|
||||
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -871,7 +960,7 @@ print(f"Loading model: {dir_model.name}")
|
|||
|
||||
hparams = Model.load_hparams(dir_model)
|
||||
|
||||
model_class = Model.from_model_architecture(hparams["architectures"][0])
|
||||
model_class = Model.from_model_name(args.model_name) if args.model_name else Model.from_model_architecture(hparams["architectures"][0])
|
||||
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)
|
||||
|
||||
print("Set model parameters")
|
||||
|
|
146
llama.cpp
146
llama.cpp
|
@ -2265,7 +2265,7 @@ static void llm_load_vocab(
|
|||
vocab.special_unk_id = 0;
|
||||
vocab.special_sep_id = -1;
|
||||
vocab.special_pad_id = -1;
|
||||
} else if (tokenizer_name == "gpt2") {
|
||||
} else if (tokenizer_name == "gpt2" || tokenizer_name == "deepseek_coder") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
||||
|
||||
// read bpe merges and populate bpe ranks
|
||||
|
@ -5682,136 +5682,32 @@ private:
|
|||
}
|
||||
|
||||
std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
|
||||
|
||||
std::vector<std::string> bpe_words;
|
||||
std::vector<std::string> bpe_encoded_words;
|
||||
|
||||
std::string token = "";
|
||||
// GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
|
||||
bool collecting_numeric = false;
|
||||
bool collecting_letter = false;
|
||||
bool collecting_special = false;
|
||||
bool collecting_whitespace_lookahead = false;
|
||||
bool collecting = false;
|
||||
|
||||
std::vector<std::string> text_utf;
|
||||
text_utf.reserve(text.size());
|
||||
bpe_words.reserve(text.size());
|
||||
bpe_encoded_words.reserve(text.size());
|
||||
|
||||
auto cps = codepoints_from_utf8(text);
|
||||
for (size_t i = 0; i < cps.size(); ++i)
|
||||
text_utf.emplace_back(codepoint_to_utf8(cps[i]));
|
||||
|
||||
for (int i = 0; i < (int)text_utf.size(); i++) {
|
||||
const std::string & utf_char = text_utf[i];
|
||||
bool split_condition = false;
|
||||
int bytes_remain = text_utf.size() - i;
|
||||
// forward backward lookups
|
||||
const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
|
||||
const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
|
||||
|
||||
// handling contractions
|
||||
if (!split_condition && bytes_remain >= 2) {
|
||||
// 's|'t|'m|'d
|
||||
if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
|
||||
split_condition = true;
|
||||
}
|
||||
if (split_condition) {
|
||||
if (token.size()) {
|
||||
bpe_words.emplace_back(token); // push previous content as token
|
||||
}
|
||||
token = utf_char + utf_char_next;
|
||||
bpe_words.emplace_back(token);
|
||||
token = "";
|
||||
i++;
|
||||
continue;
|
||||
// convert input string to wstring
|
||||
std::wstring input = from_utf8(text);
|
||||
std::wstring regex = from_utf8(gpt2_regex);
|
||||
std::wregex expr(regex);
|
||||
// std::wsmatch m;
|
||||
// // use regex match to get where to split the test string
|
||||
int array[] = {-1,0};
|
||||
std::wsregex_token_iterator iter(input.begin(), input.end(), expr, array);
|
||||
std::wsregex_token_iterator end;
|
||||
for ( ; iter != end; ++iter){
|
||||
if ((*iter).length()>0){
|
||||
bpe_words.push_back(to_utf8(*iter));
|
||||
}
|
||||
}
|
||||
if (!split_condition && bytes_remain >= 3) {
|
||||
// 're|'ve|'ll
|
||||
if (utf_char == "\'" && (
|
||||
(utf_char_next == "r" && utf_char_next_next == "e") ||
|
||||
(utf_char_next == "v" && utf_char_next_next == "e") ||
|
||||
(utf_char_next == "l" && utf_char_next_next == "l"))
|
||||
) {
|
||||
split_condition = true;
|
||||
}
|
||||
if (split_condition) {
|
||||
// current token + next token can be defined
|
||||
if (token.size()) {
|
||||
bpe_words.emplace_back(token); // push previous content as token
|
||||
}
|
||||
token = utf_char + utf_char_next + utf_char_next_next;
|
||||
bpe_words.emplace_back(token); // the contraction
|
||||
token = "";
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (!split_condition && !collecting) {
|
||||
if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
|
||||
collecting_letter = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
|
||||
collecting_numeric = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (
|
||||
((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
|
||||
(!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
|
||||
) {
|
||||
collecting_special = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
|
||||
collecting_whitespace_lookahead = true;
|
||||
collecting = true;
|
||||
}
|
||||
else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
|
||||
split_condition = true;
|
||||
}
|
||||
}
|
||||
else if (!split_condition && collecting) {
|
||||
if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) {
|
||||
split_condition = true;
|
||||
}
|
||||
else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
|
||||
split_condition = true;
|
||||
}
|
||||
else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
|
||||
split_condition = true;
|
||||
}
|
||||
else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
|
||||
split_condition = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (utf_char_next == "") {
|
||||
split_condition = true; // final
|
||||
token += utf_char;
|
||||
}
|
||||
|
||||
if (split_condition) {
|
||||
if (token.size()) {
|
||||
bpe_words.emplace_back(token);
|
||||
}
|
||||
token = utf_char;
|
||||
collecting = false;
|
||||
collecting_letter = false;
|
||||
collecting_numeric = false;
|
||||
collecting_special = false;
|
||||
collecting_whitespace_lookahead = false;
|
||||
}
|
||||
else {
|
||||
token += utf_char;
|
||||
}
|
||||
}
|
||||
|
||||
// convert each word to utf8
|
||||
for (std::string & word : bpe_words) {
|
||||
std::string text_utf = "";
|
||||
auto utf_word = codepoints_from_utf8(word);
|
||||
for (size_t i = 0; i < utf_word.size(); ++i)
|
||||
text_utf += codepoint_to_utf8(utf_word[i]);
|
||||
|
||||
std::string encoded_token = "";
|
||||
for (char & c : word) {
|
||||
for (char & c : text_utf) {
|
||||
encoded_token += bytes_to_unicode_bpe(c);
|
||||
}
|
||||
bpe_encoded_words.emplace_back(encoded_token);
|
||||
|
|
19
unicode.h
19
unicode.h
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue