shorten param name, add st verification by type

staviq 2023-10-10 16:34:24 +02:00
parent b592c70deb
commit fc634d87a8
4 changed files with 67 additions and 25 deletions

common/common.cpp

@@ -863,22 +863,22 @@ std::vector<llama_token> llama_tokenize(
         const struct llama_context * ctx,
         const std::string & text,
         bool add_bos,
-        bool allow_special_tokens) {
-    return llama_tokenize(llama_get_model(ctx), text, add_bos, allow_special_tokens);
+        bool special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
 }
 
 std::vector<llama_token> llama_tokenize(
         const struct llama_model * model,
         const std::string & text,
         bool add_bos,
-        bool allow_special_tokens) {
+        bool special) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + add_bos;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
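
Not part of the commit: a minimal usage sketch of the renamed common helper. The tokenize_prompt wrapper is hypothetical and a loaded model is assumed:

    #include <string>
    #include <vector>
    #include "common.h"

    // Hypothetical caller of the common helper after the rename.
    // Passing special = true opts in to matching special/control tokens
    // that appear literally in the prompt text.
    std::vector<llama_token> tokenize_prompt(const llama_model * model, const std::string & prompt) {
        return llama_tokenize(model, prompt, /*add_bos=*/true, /*special=*/true);
    }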

common/common.h

@@ -152,13 +152,13 @@ std::vector<llama_token> llama_tokenize(
         const struct llama_context * ctx,
         const std::string & text,
         bool add_bos,
-        bool allow_special_tokens = false);
+        bool special = false);
 
 std::vector<llama_token> llama_tokenize(
         const struct llama_model * model,
         const std::string & text,
         bool add_bos,
-        bool allow_special_tokens = false);
+        bool special = false);
 
 // tokenizes a token into a piece
 // should work similar to Python's `tokenizer.id_to_piece`

llama.cpp

@@ -2066,7 +2066,7 @@ static void llm_load_hparams(
 }
 
 // TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool allow_special_tokens = false);
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false);
 static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
 
 static void llm_load_vocab(
@@ -2192,15 +2192,30 @@ static void llm_load_vocab(
         // are special tokens.
         // From testing, this appears to correlate 1:1 with special tokens.
         //
+        // Counting special tokens and verifying in only one direction
+        // is sufficient to detect a difference between those two sets.
+        //
+        uint32_t special_tokens_count_by_type = 0;
+        uint32_t special_tokens_count_from_verification = 0;
+        bool special_tokens_definition_mismatch = false;
+
         for (const auto & t: vocab.token_to_id)
         {
             const auto & token = t.first;
             const auto & id = t.second;
 
+            // Count all non-normal tokens in the vocab while iterating
+            if( vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL )
+                special_tokens_count_by_type++;
+
+            // Skip single character tokens
             if( token.length() > 1 )
             {
                 bool is_tokenizable = false;
+
                // Split token string representation in two, in all possible ways
+                // and check if both halves can be matched to a valid token
                 for (unsigned i = 1; i < token.length();)
                 {
                     const auto left = token.substr(0, i);
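
The one-directional check described in the new comments rests on a counting argument: the verification loop only tests whether each token it flags is also marked special by type, and the final comparison checks that the two totals agree. If no flagged token is typed as normal and the counts match, the two sets must be identical. A standalone sketch of the same argument, with hypothetical typed/flagged sets standing in for the two detection paths:

    #include <set>
    #include <string>

    // `typed`   = tokens marked special by type metadata
    // `flagged` = tokens the tokenization check flags as special
    // flagged being a subset of typed plus |flagged| == |typed| implies
    // the sets are equal, so checking one direction is sufficient.
    bool definitions_match(const std::set<std::string> & typed, const std::set<std::string> & flagged) {
        bool mismatch = false;
        for (const auto & tok : flagged) {
            if (typed.count(tok) == 0) {
                mismatch = true; // flagged token not marked special by type
            }
        }
        return !mismatch && flagged.size() == typed.size();
    }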
@@ -2211,8 +2226,6 @@ static void llm_load_vocab(
 
                     if( utf == 1 )
                     {
-                        //fprintf(stderr, "BSTC . '%s' '%s' '%s'\n", token.c_str(), left.c_str(), right.c_str());
-
                         if (vocab.token_to_id.find( left ) != vocab.token_to_id.end() &&
                             vocab.token_to_id.find( right ) != vocab.token_to_id.end() )
                         {
@@ -2224,7 +2237,6 @@ static void llm_load_vocab(
                     }
                     else
                     {
-                        // fprintf(stderr, "BSTC SKIP '%s' '%s' '%s'\n", token.c_str(), left.c_str(), right.c_str());
                         // skip over the rest of multibyte utf sequence
                         i += utf - 1;
                     }
@@ -2232,7 +2244,10 @@ static void llm_load_vocab(
 
                 if (!is_tokenizable)
                 {
-                    // it's faster to re-filter them here, since there is way less candidates now
+                    // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1
+                    // it's faster to re-filter them here, since there are way less candidates now
+
+                    // Calculate a total "utf" length of a token string representation
                     size_t utf8_str_len = 0;
                     for (unsigned i = 0; i < token.length();)
                     {
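
Both scanning loops rely on utf8_len, which sits outside these hunks; in llama.cpp at the time it was, roughly, a lookup on the top four bits of the leading byte. Treat this as a sketch, not the exact upstream source:

    #include <cstddef>
    #include <cstdint>

    // Sequence length of a UTF-8 character from its first byte:
    // 0xxxxxxx -> 1, 110xxxxx -> 2, 1110xxxx -> 3, 11110xxx -> 4.
    // Continuation bytes (10xxxxxx) fall into the '1' buckets, so a
    // malformed string still advances one byte at a time.
    static size_t utf8_len(char src) {
        const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
        uint8_t highbits = static_cast<uint8_t>(src) >> 4;
        return lookup[highbits];
    }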
@@ -2240,15 +2255,39 @@ static void llm_load_vocab(
                         i += utf8_len( token.at(i) );
                     }
 
+                    // And skip the ones which are one character
                     if (utf8_str_len > 1)
                     {
-                        //fprintf(stderr, "BSTC SPECIAL '%s' '%d' ('%ld')\n", token.c_str(), id, utf8_str_len);
+                        // At this point what we have left are special tokens only
                         vocab.special_tokens_cache[token] = id;
+
+                        // Count manually found special tokens
+                        special_tokens_count_from_verification ++;
+
+                        // If this manually found special token is not marked as such, flag a mismatch
+                        if( vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL )
+                            special_tokens_definition_mismatch = true;
                     }
                 }
             }
         }
+
+        if( special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type )
+        {
+            fprintf(stderr, "%s: WARNING: Mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
+                __func__,
+                special_tokens_count_from_verification, vocab.id_to_token.size(),
+                special_tokens_count_by_type, vocab.id_to_token.size()
+            );
+        }
+        else
+        {
+            fprintf(stderr, "%s: Special tokens definition check successful ( %u/%zu ).\n",
+                __func__,
+                special_tokens_count_from_verification, vocab.id_to_token.size()
+            );
+        }
     }
 }
@@ -5777,7 +5816,7 @@ struct fragment_buffer_variant{
     std::string raw_text;
 };
 
-void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
+static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
 {
     // for each special token
     for( const auto & st: vocab.special_tokens_cache )
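
The body of tokenizer_st_partition is largely outside these hunks; its job is to split each raw-text fragment around occurrences of every cached special token, so that each occurrence can later be emitted as a single token id. A minimal sketch of that partitioning idea over a plain string, as a hypothetical standalone function rather than the actual implementation:

    #include <string>
    #include <vector>

    // Split `text` around every occurrence of the special token `st`,
    // keeping each occurrence as its own fragment.
    std::vector<std::string> partition_on_special(const std::string & text, const std::string & st) {
        std::vector<std::string> fragments;
        size_t pos = 0;
        while (true) {
            size_t hit = text.find(st, pos);
            if (hit == std::string::npos) break;
            if (hit > pos) fragments.push_back(text.substr(pos, hit - pos)); // raw text before the match
            fragments.push_back(st);                                         // the special token itself
            pos = hit + st.length();
        }
        if (pos < text.length()) fragments.push_back(text.substr(pos));      // trailing raw text
        return fragments;
    }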
@@ -5834,7 +5873,8 @@ void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragmen
                         }
                         else
                         {
-                            auto prev = std::prev( buffer.begin(), -(source-1) );
+                            //auto prev = std::prev( buffer.begin(), -(source-1) );
+                            auto prev = std::next( buffer.begin(), (source-1) );
                             buffer.erase_after(prev);
                         }
 
                         //it = std::prev( it, 1 );
@@ -5850,7 +5890,8 @@ void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragmen
                         }
                         else
                         {
-                            auto prev = std::prev( buffer.begin(), -(source) );
+                            //auto prev = std::prev( buffer.begin(), -(source) );
+                            auto prev = std::next( buffer.begin(), (source) );
                             buffer.erase_after(prev);
                         }
 
                         //it = std::prev( it, 1 );
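
The std::prev to std::next change is the substantive fix in these two hunks: std::forward_list is singly linked and exposes only forward iterators, while std::prev requires a bidirectional iterator, so stepping "backwards" by a negative count was never valid here. Walking forward from the start with std::next reaches the same node correctly. A minimal standalone sketch of the corrected erase_after pattern:

    #include <forward_list>
    #include <iterator>

    // Erase the element at index `idx` of a singly linked list.
    // forward_list only supports erase_after, so we advance to the node
    // *preceding* idx with std::next; std::prev cannot be used because
    // forward_list iterators are not bidirectional.
    void erase_at(std::forward_list<int> & list, int idx) {
        auto prev = std::next(list.before_begin(), idx); // node before element idx
        list.erase_after(prev);
    }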
@@ -5865,7 +5906,7 @@ void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragmen
         }
     }
 }
 
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool allow_special_tokens) {
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) {
     std::vector<llama_vocab::id> output;
 
     // OG tokenizer behavior:
@@ -5885,7 +5926,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 
     fragment_buffer.emplace_front( raw_text );
 
-    if (allow_special_tokens) {
+    if (special) {
         tokenizer_st_partition( vocab, fragment_buffer );
     }
@@ -8843,8 +8884,8 @@ int llama_tokenize(
                  llama_token * tokens,
                          int   n_max_tokens,
                         bool   add_bos,
-                        bool   allow_special_tokens) {
-    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, allow_special_tokens);
+                        bool   special) {
+    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
 
     if (n_max_tokens < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);

llama.h

@@ -511,10 +511,11 @@ extern "C" {
     // Tokenization
     //
 
-    // Convert the provided text into tokens.
-    // The tokens pointer must be large enough to hold the resulting tokens.
-    // Returns the number of tokens on success, no more than n_max_tokens
-    // Returns a negative number on failure - the number of tokens that would have been returned
+    /// @details Convert the provided text into tokens.
+    /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
+    /// @return Returns the number of tokens on success, no more than n_max_tokens
+    /// @return Returns a negative number on failure - the number of tokens that would have been returned
+    /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
     LLAMA_API int llama_tokenize(
         const struct llama_model * model,
                       const char * text,
@@ -522,7 +523,7 @@ extern "C" {
                      llama_token * tokens,
                              int   n_max_tokens,
                             bool   add_bos,
-                            bool   allow_special_tokens);
+                            bool   special);
 
     // Token Id -> Piece.
     // Uses the vocabulary in the provided context.
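
For reference, a sketch of calling the renamed C API directly, showing the two-call pattern that the negative return value documented above supports; model loading is assumed to have happened elsewhere:

    #include <string>
    #include <vector>
    #include "llama.h"

    // Hypothetical caller of the C API after the rename.
    std::vector<llama_token> tokenize(llama_model * model, const std::string & text) {
        std::vector<llama_token> tokens(text.size() + 1); // upper bound, +1 for BOS
        int n = llama_tokenize(model, text.c_str(), (int) text.size(),
                               tokens.data(), (int) tokens.size(),
                               /*add_bos=*/true, /*special=*/true);
        if (n < 0) {          // buffer too small: -n is the required size
            tokens.resize(-n);
            n = llama_tokenize(model, text.c_str(), (int) text.size(),
                               tokens.data(), (int) tokens.size(),
                               /*add_bos=*/true, /*special=*/true);
        }
        tokens.resize(n);
        return tokens;
    }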