Refactor llm_tokenizer_bpe: move code to constructor

This commit is contained in:
jaime-m-p 2024-05-25 02:20:55 +02:00
parent fe3c531915
commit 6f4c300bff

View file

@ -12272,107 +12272,101 @@ struct llm_bigram_bpe {
}; };
struct llm_tokenizer_bpe { struct llm_tokenizer_bpe {
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {} llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
int final_prev_index = -1;
bool ignore_merges = false;
std::vector<std::string> word_collection;
switch (vocab.type) {
case LLAMA_VOCAB_TYPE_BPE:
switch (vocab.type_pre) { switch (vocab.type_pre) {
case LLAMA_VOCAB_PRE_TYPE_LLAMA3: case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
ignore_merges = true; ignore_merges = true;
word_collection = unicode_regex_split(text, { regex_exprs = {
// original regex from tokenizer.json // original regex from tokenizer.json
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
// adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_DBRX: case LLAMA_VOCAB_PRE_TYPE_DBRX:
word_collection = unicode_regex_split(text, { regex_exprs = {
// same as llama3 // same as llama3
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM: case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
word_collection = unicode_regex_split(text, { regex_exprs = {
"[\r\n]", "[\r\n]",
"\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
"\\s?[!-/:-~---‟ -。]+", "\\s?[!-/:-~---‟ -。]+",
"\\s+$", "\\s+$",
"[一-龥ࠀ-一가-퟿]+", "[一-龥ࠀ-一가-퟿]+",
"\\p{N}+", "\\p{N}+",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
word_collection = unicode_regex_split(text, { regex_exprs = {
"[\r\n]", "[\r\n]",
"\\s?\\p{L}+", "\\s?\\p{L}+",
"\\s?\\p{P}+", "\\s?\\p{P}+",
"[一-龥ࠀ-一가-퟿]+", "[一-龥ࠀ-一가-퟿]+",
"\\p{N}", "\\p{N}",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_FALCON: case LLAMA_VOCAB_PRE_TYPE_FALCON:
word_collection = unicode_regex_split(text, { regex_exprs = {
"[\\p{P}\\$\\+<=>\\^~\\|]+", "[\\p{P}\\$\\+<=>\\^~\\|]+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
"[0-9][0-9][0-9]", "[0-9][0-9][0-9]",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_MPT: case LLAMA_VOCAB_PRE_TYPE_MPT:
// TODO: MPT pre-tokenization regexes are unknown // TODO: MPT pre-tokenization regexes are unknown
// the following are close, but not exact. run the following: // the following are close, but not exact. run the following:
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf // ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed"); GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
word_collection = unicode_regex_split(text, { regex_exprs = {
"\\s?\\p{L}+", "\\s?\\p{L}+",
"\\s?\\p{P}+", "\\s?\\p{P}+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_STARCODER: case LLAMA_VOCAB_PRE_TYPE_STARCODER:
case LLAMA_VOCAB_PRE_TYPE_REFACT: case LLAMA_VOCAB_PRE_TYPE_REFACT:
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
word_collection = unicode_regex_split(text, { regex_exprs = {
"\\p{N}", "\\p{N}",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_GPT2: case LLAMA_VOCAB_PRE_TYPE_GPT2:
case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_OLMO:
word_collection = unicode_regex_split(text, { regex_exprs = {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
}); };
break; break;
case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
case LLAMA_VOCAB_PRE_TYPE_QWEN2: case LLAMA_VOCAB_PRE_TYPE_QWEN2:
word_collection = unicode_regex_split(text, { regex_exprs = {
// original regex from tokenizer.json // original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}); };
break; break;
default: default:
// default regex for BPE tokenization pre-processing // default regex for BPE tokenization pre-processing
word_collection = unicode_regex_split(text, { regex_exprs = {
"[\\p{P}\\$\\+<=>\\^~\\|]+", "[\\p{P}\\$\\+<=>\\^~\\|]+",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
"\\p{N}+", "\\p{N}+",
"[0-9][0-9][0-9]", "[0-9][0-9][0-9]",
}); };
break; break;
} }
break;
default:
GGML_ASSERT(false);
break;
} }
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
int final_prev_index = -1;
const auto word_collection = unicode_regex_split(text, regex_exprs);
symbols_final.clear(); symbols_final.clear();
for (auto & word : word_collection) { for (auto & word : word_collection) {
@ -12505,6 +12499,9 @@ private:
const llama_vocab & vocab; const llama_vocab & vocab;
bool ignore_merges = false;
std::vector<std::string> regex_exprs;
std::vector<llm_symbol> symbols; std::vector<llm_symbol> symbols;
std::vector<llm_symbol> symbols_final; std::vector<llm_symbol> symbols_final;
@ -12852,6 +12849,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
} break; } break;
case LLAMA_VOCAB_TYPE_BPE: case LLAMA_VOCAB_TYPE_BPE:
{ {
llm_tokenizer_bpe tokenizer(vocab);
if (add_special && vocab.special_add_bos != 0) { if (add_special && vocab.special_add_bos != 0) {
GGML_ASSERT(vocab.special_bos_id != -1); GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id); output.push_back(vocab.special_bos_id);
@ -12864,7 +12863,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif #endif
llm_tokenizer_bpe tokenizer(vocab);
tokenizer.tokenize(raw_text, output); tokenizer.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token); output.push_back(fragment.token);