llama : normalize code-style

Commit 5c6b2be11f by Georgi Gerganov, 2023-10-12 14:47:29 +03:00 (parent 04ac0558de)
GPG key ID: 449E073F9DC10735 (no known key found for this signature in database)

3 changed files with 51 additions and 61 deletions

@@ -863,7 +863,7 @@ size_t tokenize_file(
             (int) buf.size(),
             out_tokens.data(),
             (int) out_tokens.size(),
-            false,false);
+            false, false);
     if (n_tokens < 0) {
         out_tokens.resize(-n_tokens);
         n_tokens = llama_tokenize(

@@ -872,7 +872,7 @@ size_t tokenize_file(
             (int) buf.size(),
             out_tokens.data(),
             (int) out_tokens.size(),
-            false,false);
+            false, false);
     }
     if (n_tokens >= 0) {
         out_tokens.resize(n_tokens);

@@ -966,7 +966,7 @@ size_t tokenize_file(
                 (int) buf_sample.size(),
                 tok_sample.data(),
                 (int) tok_sample.size(),
-                false,false);
+                false, false);
             if (n_tokens < 0) {
                 tok_sample.resize(-n_tokens);
                 n_tokens = llama_tokenize(llama_get_model(lctx),

@@ -974,7 +974,7 @@ size_t tokenize_file(
                 (int) buf_sample.size(),
                 tok_sample.data(),
                 (int) tok_sample.size(),
-                false,false);
+                false, false);
             GGML_ASSERT(n_tokens >= 0);
         }
         GGML_ASSERT(n_tokens <= (int) tok_sample.size());
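
Note on the pattern these four hunks restyle: llama_tokenize reports an undersized output buffer by returning the negative of the required token count, so callers tokenize in two passes, resizing in between. Below is a minimal sketch of that idiom, assuming the llama.h API of this revision (model handle, text pointer and length, output buffer and capacity, then the add_bos and special flags); the helper name tokenize_all is hypothetical and not part of the commit:

    #include <string>
    #include <vector>

    #include "llama.h" // also pulls in GGML_ASSERT via ggml.h

    // Hypothetical helper illustrating the resize-and-retry idiom above.
    // A negative return from llama_tokenize means "buffer too small", and
    // its magnitude is the exact number of tokens required.
    static std::vector<llama_token> tokenize_all(const struct llama_model * model, const std::string & text) {
        std::vector<llama_token> out(text.size()); // first guess: at most one token per byte
        int n_tokens = llama_tokenize(model, text.data(), (int) text.size(),
                                      out.data(), (int) out.size(),
                                      false, false); // add_bos = false, special = false
        if (n_tokens < 0) {
            out.resize(-n_tokens); // exact capacity reported by the failed call
            n_tokens = llama_tokenize(model, text.data(), (int) text.size(),
                                      out.data(), (int) out.size(),
                                      false, false);
            GGML_ASSERT(n_tokens >= 0); // the second attempt must succeed
        }
        out.resize(n_tokens);
        return out;
    }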

@@ -2260,89 +2260,79 @@ static void llm_load_vocab(
     //
     uint32_t special_tokens_count_by_type = 0;
     uint32_t special_tokens_count_from_verification = 0;
     bool special_tokens_definition_mismatch = false;
 
-    for (const auto & t: vocab.token_to_id)
-    {
+    for (const auto & t : vocab.token_to_id) {
         const auto & token = t.first;
         const auto & id = t.second;
 
         // Count all non-normal tokens in the vocab while iterating
-        if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL)
+        if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
             special_tokens_count_by_type++;
+        }
 
         // Skip single character tokens
-        if (token.length() > 1)
-        {
+        if (token.length() > 1) {
             bool is_tokenizable = false;
 
             // Split token string representation in two, in all possible ways
             // and check if both halves can be matched to a valid token
-            for (unsigned i = 1; i < token.length();)
-            {
+            for (unsigned i = 1; i < token.length();) {
                 const auto left = token.substr(0, i);
                 const auto right = token.substr(i);
 
                 // check if we didnt partition in the middle of a utf sequence
                 auto utf = utf8_len(left.at(left.length() - 1));
 
-                if (utf == 1)
-                {
-                    if (vocab.token_to_id.find( left ) != vocab.token_to_id.end() &&
-                        vocab.token_to_id.find( right ) != vocab.token_to_id.end() )
-                    {
+                if (utf == 1) {
+                    if (vocab.token_to_id.find(left) != vocab.token_to_id.end() &&
+                        vocab.token_to_id.find(right) != vocab.token_to_id.end() ) {
                         is_tokenizable = true;
                         break;
                     }
                     i++;
-                }
-                else
-                {
+                } else {
                     // skip over the rest of multibyte utf sequence
                     i += utf - 1;
                 }
             }
 
-            if (!is_tokenizable)
-            {
+            if (!is_tokenizable) {
                 // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1
                 // it's faster to re-filter them here, since there are way less candidates now
 
                 // Calculate a total "utf" length of a token string representation
                 size_t utf8_str_len = 0;
-                for (unsigned i = 0; i < token.length();)
-                {
+                for (unsigned i = 0; i < token.length();) {
                     utf8_str_len++;
                     i += utf8_len(token.at(i));
                 }
 
                 // And skip the ones which are one character
-                if (utf8_str_len > 1)
-                {
+                if (utf8_str_len > 1) {
                     // At this point what we have left are special tokens only
                     vocab.special_tokens_cache[token] = id;
 
                     // Count manually found special tokens
-                    special_tokens_count_from_verification ++;
+                    special_tokens_count_from_verification++;
 
                     // If this manually found special token is not marked as such, flag a mismatch
-                    if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL)
+                    if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
                         special_tokens_definition_mismatch = true;
+                    }
                 }
             }
         }
     }
 
-    if( special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type )
-    {
-        fprintf(stderr, "warning: %s: Mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
+    if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
+        fprintf(stderr, "%s: warning: Mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
             __func__,
             special_tokens_count_from_verification, vocab.id_to_token.size(),
             special_tokens_count_by_type, vocab.id_to_token.size()
         );
-    }
-    else
-    {
+    } else {
         fprintf(stderr, "%s: Special tokens definition check successful ( %u/%zu ).\n",
             __func__,
             special_tokens_count_from_verification, vocab.id_to_token.size()
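
The verification loop above leans on utf8_len to avoid splitting a token string in the middle of a multibyte sequence: a token is cached as special only if no split at a valid UTF-8 boundary yields two known vocab tokens, and its text is longer than one character. For reference, here is a sketch of that helper matching the lookup-table version defined elsewhere in llama.cpp; the high four bits of a UTF-8 lead byte determine the sequence length:

    #include <cstddef>
    #include <cstdint>

    // Sketch of the utf8_len helper used above: for a lead byte, the top
    // 4 bits select the sequence length (1 for ASCII and continuation
    // bytes, 2 for 0b110x, 3 for 0b1110, 4 for 0b1111).
    static size_t utf8_len(char src) {
        const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
        uint8_t highbits = static_cast<uint8_t>(src) >> 4;
        return lookup[highbits];
    }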
@@ -6611,30 +6601,28 @@ struct fragment_buffer_variant{
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
 {
     // for each special token
-    for (const auto & st: vocab.special_tokens_cache)
-    {
+    for (const auto & st: vocab.special_tokens_cache) {
         const auto & special_token = st.first;
         const auto & special_id = st.second;
 
         // for each text fragment
         std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
-        while (it != buffer.end())
-        {
+        while (it != buffer.end()) {
             auto & fragment = (*it);
 
             // if a fragment is text ( not yet processed )
-            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
-            {
+            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                 auto * raw_text = &(fragment.raw_text);
                 auto raw_text_base_offset = fragment.offset;
                 auto raw_text_base_length = fragment.length;
 
                 // loop over the text
-                while (true)
-                {
+                while (true) {
                     // find the first occurence of a given special token in this fragment
                     // passing offset argument only limit the "search area" but match coordinates
                     // are still relative to the source full raw_text
-                    auto match = raw_text->find( special_token, raw_text_base_offset );
+                    auto match = raw_text->find(special_token, raw_text_base_offset);
 
                     // no occurences found, stop processing this fragment for a given special token
                     if (match == std::string::npos) break;
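
The comment about the offset argument is easy to misread: std::string::find(needle, pos) restricts only where the search starts, while the returned index stays relative to the beginning of the whole string, which is why match can be compared directly against raw_text_base_offset. A one-line illustration:

    #include <cassert>
    #include <string>

    int main() {
        const std::string s = "abcabc";
        // the search starts at index 2, but the match index is absolute: 3, not 1
        assert(s.find("abc", 2) == 3);
    }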
@@ -6643,21 +6631,20 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
 
 #ifdef PRETOKENIZERDEBUG
-                    fprintf(stderr,"FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+                    fprintf(stderr, "FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
 #endif
                     auto source = std::distance(buffer.begin(), it);
 
                     // if match is further than base offset
                     // then we have some text to the left of it
-                    if (match > raw_text_base_offset)
-                    {
+                    if (match > raw_text_base_offset) {
                         // left
                         const int64_t left_reminder_offset = raw_text_base_offset + 0;
                         const int64_t left_reminder_length = match - raw_text_base_offset;
                         buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length);
 
 #ifdef PRETOKENIZERDEBUG
-                        fprintf(stderr,"FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
+                        fprintf(stderr, "FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
 #endif
                         it++;
                     }
@@ -6667,33 +6654,36 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     it++;
 
                     // right
-                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length)
-                    {
+                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
                         const int64_t right_reminder_offset = match + special_token.length();
                         const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
                         buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length);
 
 #ifdef PRETOKENIZERDEBUG
-                        fprintf(stderr,"FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
+                        fprintf(stderr, "FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
 #endif
                         it++;
 
-                        if (source == 0) buffer.erase_after(buffer.before_begin());
-                        else buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                        }
 
                         // repeat for the right side
                         raw_text_base_offset = right_reminder_offset;
                         raw_text_base_length = right_reminder_length;
 
 #ifdef PRETOKENIZERDEBUG
-                        fprintf(stderr,"RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+                        fprintf(stderr, "RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
 #endif
-                    }
-                    else
-                    {
-                        if (source == 0) buffer.erase_after(buffer.before_begin());
-                        else buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                    } else {
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                        }
 
                         break;
                     }
                 }
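
Taken together, these hunks are purely stylistic; the partitioning behavior is unchanged: each RAW_TEXT fragment is split around every occurrence of every cached special token into left text, the token itself, and right text. A self-contained toy version of that behavior, using simplified hypothetical types in place of fragment_buffer_variant (the real code splices std::forward_list nodes in place rather than building a new container):

    #include <cstdio>
    #include <string>
    #include <vector>

    // Hypothetical simplified fragment: either a raw text piece or a
    // resolved special-token id.
    struct Fragment {
        bool is_token;    // true -> token id, false -> raw text
        std::string text; // raw text piece (when is_token == false)
        int id;           // token id (when is_token == true)
    };

    // Partition raw text around every occurrence of a special token st,
    // mirroring the left/token/right split performed above.
    static std::vector<Fragment> partition(const std::string & raw, const std::string & st, int st_id) {
        std::vector<Fragment> out;
        size_t base = 0;
        while (true) {
            const size_t match = raw.find(st, base); // absolute index, search starts at base
            if (match == std::string::npos) break;
            if (match > base) {
                out.push_back({false, raw.substr(base, match - base), -1}); // left remainder
            }
            out.push_back({true, "", st_id});  // the special token itself
            base = match + st.length();        // continue on the right side
        }
        if (base < raw.length()) {
            out.push_back({false, raw.substr(base), -1}); // trailing text
        }
        return out;
    }

    int main() {
        // "<|endoftext|>" / 50256 are used here purely as familiar example values
        for (const auto & f : partition("Hello<|endoftext|>world", "<|endoftext|>", 50256)) {
            if (f.is_token) printf("[token %d]\n", f.id);
            else            printf("[text '%s']\n", f.text.c_str());
        }
    }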