Work on the BPE tokenizer (#3252)
* Work on the BPE tokenizer

  Tokenizer tests work for Falcon-7B

* Try to fix build problem
* Fix debug assertion failure
* Fix MSVC Unicode BOM problem
* Cleanup and an improvement
* Fix compiler warning
* Cleanup
* Test doesn't work over the full range of Unicode
* Update .gitignore and Makefile
* Another Makefile rule
* Testing Aquila
* Moving byte decoding back to `token_to_piece` ...

  ... because everyone is using it.

* Guarding some unusable code paths
* Streamlining code and adding some more assertions

  Important change: I'm classifying added tokens as control tokens now for BPE.

* Adding a comment
* Adding another assertion
* Fixed vocabulary guarding assertions
* Fix PR for recent change
* Fix PR for recent change
* Fix for compiler warning
* Fix PR for recent change
* Fix PR for recent change
* Fix PR for recent change
* Fix for compiler warning
* Fixes for more compiler warnings
* Remove unused code
* Fix initialization of static maps
* Add scores and token types back, adapt gptneox
* Update llama.cpp

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update unicode.h

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update unicode.h

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Ported Starcoder and added some assertions
* Fix coding style
* Apply @jploski's fix for missing tokens

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
parent 1c84003c08
commit ff5a3f0c09

15 changed files with 852 additions and 227 deletions

llama.cpp (278 changed lines shown below)
@@ -1,6 +1,8 @@
 #define LLAMA_API_INTERNAL
 #include "llama.h"
 
+#include "unicode.h"
+
 #include "ggml.h"
 
 #include "ggml-alloc.h"
@@ -1980,6 +1982,7 @@ static void llm_load_vocab(
 
         for (int i = 0; i < n_merges; i++) {
             const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
+            GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
 
             std::string first;
             std::string second;
@@ -2014,6 +2017,7 @@ static void llm_load_vocab(
 
         for (uint32_t i = 0; i < n_vocab; i++) {
             std::string word = gguf_get_arr_str(ctx, token_idx, i);
+            GGML_ASSERT(codepoints_from_utf8(word).size() > 0);
 
             vocab.token_to_id[word] = i;
 
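Note: both vocab-loading loops above now assert codepoints_from_utf8(word).size() > 0, i.e. every merge entry and every vocab entry must decode to at least one Unicode code point. The decoder itself ships in the new unicode.h; the standalone sketch below illustrates the validation property the assertions rely on (an approximation for illustration, not the shipped implementation):

    #include <cstdint>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Sketch of UTF-8 -> code point decoding in the spirit of codepoints_from_utf8.
    // Malformed input raises instead of being silently accepted, so a vocab entry
    // that passes the GGML_ASSERT above is guaranteed to be well-formed UTF-8.
    static std::vector<uint32_t> decode_utf8(const std::string & s) {
        std::vector<uint32_t> cps;
        size_t i = 0;
        while (i < s.size()) {
            const uint8_t b = (uint8_t) s[i];
            int n = 0;
            uint32_t cp = 0;
            if      (b < 0x80)           { n = 1; cp = b;        } // ASCII
            else if ((b & 0xE0) == 0xC0) { n = 2; cp = b & 0x1F; } // 2-byte sequence
            else if ((b & 0xF0) == 0xE0) { n = 3; cp = b & 0x0F; } // 3-byte sequence
            else if ((b & 0xF8) == 0xF0) { n = 4; cp = b & 0x07; } // 4-byte sequence
            else { throw std::invalid_argument("invalid UTF-8 lead byte"); }
            if (i + n > s.size()) { throw std::invalid_argument("truncated UTF-8 sequence"); }
            for (int k = 1; k < n; k++) {
                const uint8_t c = (uint8_t) s[i + k];
                if ((c & 0xC0) != 0x80) { throw std::invalid_argument("invalid continuation byte"); }
                cp = (cp << 6) | (c & 0x3F);
            }
            cps.push_back(cp);
            i += n;
        }
        return cps;
    }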
@@ -2022,12 +2026,13 @@ static void llm_load_vocab(
             token_data.score = scores ? scores[i] : 0.0f;
             token_data.type  = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL;
         }
+        GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
 
         // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
         if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
             vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
         } else {
-            vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false)[0];
+            vocab.linefeed_id = llama_tokenize_internal(vocab, "\u010A", false)[0];
         }
 
         // special tokens
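Note: the "\u010A" lookup in the BPE branch follows from the GPT-2 byte encoder, which maps every raw byte to a printable Unicode stand-in and shifts the non-printable bytes past U+00FF. A minimal sketch of that table, following bytes_to_unicode from OpenAI's encoder.py (the shipped counterpart is bytes_to_unicode_bpe in unicode.h):

    #include <cstdint>
    #include <map>

    // Printable bytes map to themselves; every other byte is assigned
    // 0x100 + n in ascending byte order. '\n' (0x0A) is the 11th
    // non-printable byte, so it maps to U+010A ("Ċ"): tokenizing "\u010A"
    // therefore yields the linefeed token for BPE vocabularies.
    static std::map<uint8_t, uint32_t> gpt2_bytes_to_unicode() {
        std::map<uint8_t, uint32_t> m;
        uint32_t n = 0;
        for (int b = 0; b < 256; b++) {
            const bool printable = (b >= 0x21 && b <= 0x7E)   // '!'..'~'
                                || (b >= 0xA1 && b <= 0xAC)   // '¡'..'¬'
                                || (b >= 0xAE && b <= 0xFF);  // '®'..'ÿ'
            m[(uint8_t) b] = printable ? (uint32_t) b : 0x100 + n++;
        }
        return m;
    }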
@@ -4236,18 +4241,41 @@ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
     return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
 }
 
-static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
+static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
+    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
+}
+
+static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
     GGML_ASSERT(llama_is_byte_token(vocab, id));
     const auto& token_data = vocab.id_to_token.at(id);
-    auto buf = token_data.text.substr(3, 2);
-    return strtol(buf.c_str(), NULL, 16);
+    switch (llama_vocab_get_type(vocab)) {
+    case LLAMA_VOCAB_TYPE_SPM: {
+        auto buf = token_data.text.substr(3, 2);
+        return strtol(buf.c_str(), NULL, 16);
+    }
+    case LLAMA_VOCAB_TYPE_BPE: {
+        GGML_ASSERT(false);
+        return unicode_to_bytes_bpe(token_data.text);
+    }
+    default:
+        GGML_ASSERT(false);
+    }
 }
 
 static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
-    char buf[7];
-    int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch);
-    GGML_ASSERT(0 <= result && result < 7);
-    return vocab.token_to_id.at(buf);
+    switch (llama_vocab_get_type(vocab)) {
+    case LLAMA_VOCAB_TYPE_SPM: {
+        char buf[7];
+        int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch);
+        GGML_ASSERT(0 <= result && result < 7);
+        return vocab.token_to_id.at(buf);
+    }
+    case LLAMA_VOCAB_TYPE_BPE: {
+        return vocab.token_to_id.at(bytes_to_unicode_bpe(ch));
+    }
+    default:
+        GGML_ASSERT(false);
+    }
 }
 
 static void llama_escape_whitespace(std::string & text) {
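Note: unicode_to_bytes_bpe is the inverse of the byte-to-unicode mapping sketched above. One plausible realization, shown here only as a sketch (the commit's "Fix initialization of static maps" entry concerns initializing exactly this kind of lookup table), builds the inverse once in a function-local static, which C++11 guarantees to be thread-safe:

    #include <cstdint>
    #include <map>

    // Inverse of gpt2_bytes_to_unicode() from the sketch above: translate a
    // unicode stand-in back to its raw byte, e.g. U+010A ("Ċ") -> 0x0A.
    static uint8_t unicode_to_byte(uint32_t cp) {
        static const std::map<uint32_t, uint8_t> inv = [] {
            std::map<uint32_t, uint8_t> m;
            for (const auto & kv : gpt2_bytes_to_unicode()) {
                m[kv.second] = kv.first;
            }
            return m;
        }();
        return inv.at(cp); // throws std::out_of_range for unmapped code points
    }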
@@ -4527,15 +4555,9 @@ struct llm_tokenizer_bpe {
                     std::string byte_str(1, *j);
                     auto token_multibyte = vocab.token_to_id.find(byte_str);
                     if (token_multibyte == vocab.token_to_id.end()) {
-                        try {
-                            llama_token token_byte = llama_byte_to_token(vocab, *j);
-                            output.push_back(token_byte);
-                        } catch (const std::out_of_range & err) {
-                            fprintf(stderr,"ERROR: byte not found in vocab: '%s'\n", byte_str.c_str());
-                        }
-                    } else {
-                        output.push_back((*token_multibyte).second);
+                        throw std::runtime_error("ERROR: byte not found in vocab");
                     }
+                    output.push_back((*token_multibyte).second);
                 }
             } else {
                 output.push_back((*token).second);
@@ -4572,23 +4594,144 @@ private:
         work_queue.push(bigram);
     }
 
-    // probably not 100% correct
-    static std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
-        std::vector<std::string> words;
-
-        // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
-        const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
-        const std::regex re(pattern);
-
-        auto words_begin = std::sregex_iterator(text.begin(), text.end(), re);
-        auto words_end = std::sregex_iterator();
-        auto n_words = std::distance(words_begin, words_end);
-        words.reserve(n_words);
-        for (auto it = words_begin; it != words_end; ++it) {
-            words.push_back(it->str());
-        }
-        return words;
-    }
+    std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
+        std::vector<std::string> bpe_words;
+        std::vector<std::string> bpe_encoded_words;
+
+        std::string token = "";
+        // GPT2 system regex:  's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
+        bool collecting_numeric = false;
+        bool collecting_letter = false;
+        bool collecting_special = false;
+        bool collecting_whitespace_lookahead = false;
+        bool collecting = false;
+
+        std::vector<std::string> text_utf;
+        text_utf.reserve(text.size());
+        bpe_words.reserve(text.size());
+        bpe_encoded_words.reserve(text.size());
+
+        auto cps = codepoints_from_utf8(text);
+        for (size_t i = 0; i < cps.size(); ++i)
+            text_utf.emplace_back(codepoint_to_utf8(cps[i]));
+
+        for (int i = 0; i < (int)text_utf.size(); i++) {
+            const std::string & utf_char = text_utf[i];
+            bool split_condition = false;
+            // const char* text_pos = raw_text_p + utf_char.seq_offset_bytes;
+            int bytes_remain = text_utf.size() - i;
+            // forward backward lookups
+            const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
+            const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
+
+            // handling contractions
+            if (!split_condition && bytes_remain >= 2) {
+                // 's|'t|'m|'d
+                if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
+                    split_condition = true;
+                }
+                if (split_condition) {
+                    if (token.size()) {
+                        bpe_words.emplace_back(token); // push previous content as token
+                    }
+                    token = utf_char + utf_char_next;
+                    bpe_words.emplace_back(token);
+                    token = "";
+                    i++;
+                    continue;
+                }
+            }
+            if (!split_condition && bytes_remain >= 3) {
+                // 're|'ve|'ll
+                if (utf_char == "\'" && (
+                    (utf_char_next == "r" || utf_char_next_next == "e") ||
+                    (utf_char_next == "v" || utf_char_next_next == "e") ||
+                    (utf_char_next == "l" || utf_char_next_next == "l"))
+                    ) {
+                    split_condition = true;
+                }
+                if (split_condition) {
+                    // current token + next token can be defined
+                    if (token.size()) {
+                        bpe_words.emplace_back(token); // push previous content as token
+                    }
+                    token = utf_char + utf_char_next + utf_char_next_next;
+                    bpe_words.emplace_back(token); // the contraction
+                    token = "";
+                    i += 2;
+                    continue;
+                }
+            }
+
+            if (!split_condition && !collecting) {
+                if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
+                    collecting_letter = true;
+                    collecting = true;
+                }
+                else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
+                    collecting_numeric = true;
+                    collecting = true;
+                }
+                else if (
+                    ((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
+                    (!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
+                    ) {
+                    collecting_special = true;
+                    collecting = true;
+                }
+                else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
+                    collecting_whitespace_lookahead = true;
+                    collecting = true;
+                }
+                else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
+                    split_condition = true;
+                }
+            }
+            else if (!split_condition && collecting) {
+                if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) {
+                    split_condition = true;
+                }
+                else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) {
+                    split_condition = true;
+                }
+                else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
+                    split_condition = true;
+                }
+                else if (collecting_whitespace_lookahead && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) {
+                    split_condition = true;
+                }
+            }
+
+            if (utf_char_next == "") {
+                split_condition = true; // final
+                token += utf_char;
+            }
+
+            if (split_condition) {
+                if (token.size()) {
+                    bpe_words.emplace_back(token);
+                }
+                token = utf_char;
+                collecting = false;
+                collecting_letter = false;
+                collecting_numeric = false;
+                collecting_special = false;
+                collecting_whitespace_lookahead = false;
+            }
+            else {
+                token += utf_char;
+            }
+        }
+
+        for (std::string & word : bpe_words) {
+            std::string encoded_token = "";
+            for (char & c : word) {
+                encoded_token += bytes_to_unicode_bpe(c);
+            }
+            bpe_encoded_words.emplace_back(encoded_token);
+        }
+
+        return bpe_encoded_words;
+    }
 
     const llama_vocab & vocab;
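Note: the hand-written scanner replaces the std::regex path because std::regex has no Unicode character classes: [[:alpha:]] and [[:digit:]] only classify ASCII (hence the old "probably not 100% correct" comment and the "Test doesn't work over the full range of Unicode" commit entry). On ASCII-only input the two approaches agree; the snippet below demonstrates the removed pattern for comparison:

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
        // The pattern from the removed code path (ASCII approximation of the
        // GPT-2 regex); the new scanner handles the full Unicode range instead.
        const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
        const std::regex re(pattern);
        const std::string text = "Hello world, I'm testing!";
        for (auto it = std::sregex_iterator(text.begin(), text.end(), re); it != std::sregex_iterator(); ++it) {
            std::cout << "[" << it->str() << "]";
        }
        std::cout << "\n"; // prints: [Hello][ world][,][ I]['m][ testing][!]
        return 0;
    }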
@@ -7532,35 +7675,66 @@ int llama_tokenize(
     return res.size();
 }
 
+static std::string llama_decode_text(const std::string & text) {
+    std::string decoded_text;
+    auto unicode_sequences = codepoints_from_utf8(text);
+    for (auto& unicode_sequence : unicode_sequences) {
+        decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence));
+    }
+
+    return decoded_text;
+}
+
 // does not write null-terminator to buf
 int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) {
     if (0 <= token && token < llama_n_vocab(model)) {
-        if (llama_is_normal_token(model->vocab, token)) {
-            std::string result = model->vocab.id_to_token[token].text;
-            if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
-                llama_unescape_whitespace(result);
-            }
-            if (length < (int) result.length()) {
-                return -result.length();
-            }
-            memcpy(buf, result.c_str(), result.length());
-            return result.length();
-        } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
-            if (length < 3) {
-                return -3;
-            }
-            buf[0] = '\xe2';
-            buf[1] = '\x96';
-            buf[2] = '\x85';
-            return 3;
-        } else if (llama_is_control_token(model->vocab, token)) {
-            // do nothing
-        } else if (llama_is_byte_token(model->vocab, token)) {
-            if (length < 1) {
-                return -1;
-            }
-            buf[0] = llama_token_to_byte(model->vocab, token);
-            return 1;
+        switch (llama_vocab_get_type(model->vocab)) {
+        case LLAMA_VOCAB_TYPE_SPM: {
+            if (llama_is_normal_token(model->vocab, token)) {
+                std::string result = model->vocab.id_to_token[token].text;
+                llama_unescape_whitespace(result);
+                if (length < (int) result.length()) {
+                    return -result.length();
+                }
+                memcpy(buf, result.c_str(), result.length());
+                return result.length();
+            } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
+                if (length < 3) {
+                    return -3;
+                }
+                memcpy(buf, "\xe2\x96\x85", 3);
+                return 3;
+            } else if (llama_is_control_token(model->vocab, token)) {
+                ;
+            } else if (llama_is_byte_token(model->vocab, token)) {
+                if (length < 1) {
+                    return -1;
+                }
+                buf[0] = llama_token_to_byte(model->vocab, token);
+                return 1;
+            } else {
+                GGML_ASSERT(false);
+            }
+            break;
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            if (llama_is_normal_token(model->vocab, token)) {
+                std::string result = model->vocab.id_to_token[token].text;
+                result = llama_decode_text(result);
+                if (length < (int) result.length()) {
+                    return -result.length();
+                }
+                memcpy(buf, result.c_str(), result.length());
+                return result.length();
+            } else if (llama_is_control_token(model->vocab, token)) {
+                ;
+            } else {
+                GGML_ASSERT(false);
+            }
+            break;
+        }
+        default:
+            GGML_ASSERT(false);
         }
     }
     return 0;
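Note: llama_token_to_piece returns the piece length on success and the negative required size when the buffer is too small (control tokens yield 0 bytes). A usage sketch of the resulting call-twice pattern, not part of this diff (assumes llama.h is included):

    #include <string>
    #include <vector>

    // Convert one token to a std::string, growing the buffer on demand.
    static std::string token_to_piece(const struct llama_model * model, llama_token token) {
        std::vector<char> buf(8); // deliberately small to exercise the retry path
        int n = llama_token_to_piece(model, token, buf.data(), (int) buf.size());
        if (n < 0) {
            buf.resize(-n); // -n is the exact number of bytes required
            n = llama_token_to_piece(model, token, buf.data(), (int) buf.size());
        }
        return std::string(buf.data(), n > 0 ? (size_t) n : 0);
    }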