Build vocab.special_tokens_cache using vocab token types

This commit is contained in:
jaime-m-p 2024-05-27 20:12:43 +02:00
parent af45703f74
commit 938cb4941a

107
llama.cpp
View file

@ -2086,7 +2086,7 @@ struct llama_vocab {
std::unordered_map<token, id> token_to_id; std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token; std::vector<token_data> id_to_token;
std::unordered_map<token, id> special_tokens_cache; std::vector<id> special_tokens_cache;
std::map<std::pair<std::string, std::string>, int> bpe_ranks; std::map<std::pair<std::string, std::string>, int> bpe_ranks;
@ -4724,97 +4724,19 @@ static void llm_load_vocab(
// build special tokens cache // build special tokens cache
{ {
// TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type, for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
// and will always be correctly labeled in 'added_tokens.json' etc.
// The assumption is, since special tokens aren't meant to be exposed to end user, they are designed
// to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer
// are special tokens.
// From testing, this appears to correlate 1:1 with special tokens.
//
// Counting special tokens and verifying in only one direction
// is sufficient to detect difference in those two sets.
//
uint32_t special_tokens_count_by_type = 0;
uint32_t special_tokens_count_from_verification = 0;
bool special_tokens_definition_mismatch = false;
for (const auto & t : vocab.token_to_id) {
const auto & token = t.first;
const auto & id = t.second;
// Count all non-normal tokens in the vocab while iterating
if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) { if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
special_tokens_count_by_type++; vocab.special_tokens_cache.push_back(id);
}
// Skip single character tokens
if (token.length() > 1) {
bool is_tokenizable = false;
// Split token string representation in two, in all possible ways
// and check if both halves can be matched to a valid token
for (unsigned i = 1; i < token.length();) {
const auto left = token.substr(0, i);
const auto right = token.substr(i);
// check if we didnt partition in the middle of a utf sequence
auto utf = utf8_len(left.at(left.length() - 1));
if (utf == 1) {
if (vocab.token_to_id.find(left) != vocab.token_to_id.end() &&
vocab.token_to_id.find(right) != vocab.token_to_id.end() ) {
is_tokenizable = true;
break;
}
i++;
} else {
// skip over the rest of multibyte utf sequence
i += utf - 1;
} }
} }
if (!is_tokenizable) { std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
// Some tokens are multibyte, but they are utf sequences with equivalent text length of 1 [&] (const llama_vocab::id a, const llama_vocab::id b) {
// it's faster to re-filter them here, since there are way less candidates now return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
// Calculate a total "utf" length of a token string representation
size_t utf8_str_len = 0;
for (unsigned i = 0; i < token.length();) {
utf8_str_len++;
i += utf8_len(token.at(i));
} }
// And skip the ones which are one character
if (utf8_str_len > 1) {
// At this point what we have left are special tokens only
vocab.special_tokens_cache[token] = id;
// Count manually found special tokens
special_tokens_count_from_verification++;
// If this manually found special token is not marked as such, flag a mismatch
if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
special_tokens_definition_mismatch = true;
}
}
}
}
}
if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
__func__,
special_tokens_count_from_verification, vocab.id_to_token.size(),
special_tokens_count_by_type, vocab.id_to_token.size()
); );
} else {
LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n", LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
__func__,
special_tokens_count_from_verification, vocab.id_to_token.size()
);
}
} }
} }
@ -12898,9 +12820,8 @@ struct fragment_buffer_variant {
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) { static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
// for each special token // for each special token
for (const auto & st: vocab.special_tokens_cache) { for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
const auto & special_token = st.first; const auto & special_token = vocab.id_to_token[special_id].text;
const auto & special_id = st.second;
// for each text fragment // for each text fragment
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin(); std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
@ -12909,7 +12830,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
// if a fragment is text ( not yet processed ) // if a fragment is text ( not yet processed )
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto * raw_text = &(fragment.raw_text); auto & raw_text = fragment.raw_text;
auto raw_text_base_offset = fragment.offset; auto raw_text_base_offset = fragment.offset;
auto raw_text_base_length = fragment.length; auto raw_text_base_length = fragment.length;
@ -12919,7 +12840,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
// find the first occurrence of a given special token in this fragment // find the first occurrence of a given special token in this fragment
// passing offset argument only limit the "search area" but match coordinates // passing offset argument only limit the "search area" but match coordinates
// are still relative to the source full raw_text // are still relative to the source full raw_text
auto match = raw_text->find(special_token, raw_text_base_offset); auto match = raw_text.find(special_token, raw_text_base_offset);
// no occurrences found, stop processing this fragment for a given special token // no occurrences found, stop processing this fragment for a given special token
if (match == std::string::npos) break; if (match == std::string::npos) break;
@ -12938,7 +12859,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
// left // left
const int64_t left_reminder_offset = raw_text_base_offset + 0; const int64_t left_reminder_offset = raw_text_base_offset + 0;
const int64_t left_reminder_length = match - raw_text_base_offset; const int64_t left_reminder_length = match - raw_text_base_offset;
buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length); buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str()); LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
@ -12954,7 +12875,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) { if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
const int64_t right_reminder_offset = match + special_token.length(); const int64_t right_reminder_offset = match + special_token.length();
const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length); buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str()); LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());