shorten param name, add st verification by type

This commit is contained in:
staviq 2023-10-10 16:34:24 +02:00
parent b592c70deb
commit fc634d87a8
4 changed files with 67 additions and 25 deletions

View file

@ -863,22 +863,22 @@ std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
bool add_bos,
bool allow_special_tokens) {
return llama_tokenize(llama_get_model(ctx), text, add_bos, allow_special_tokens);
bool special) {
return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
}
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos,
bool allow_special_tokens) {
bool special) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);

View file

@ -152,13 +152,13 @@ std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
bool add_bos,
bool allow_special_tokens = false);
bool special = false);
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos,
bool allow_special_tokens = false);
bool special = false);
// tokenizes a token into a piece
// should work similar to Python's `tokenizer.id_to_piece`

View file

@ -2066,7 +2066,7 @@ static void llm_load_hparams(
}
// TODO: This should probably be in llama.h
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool allow_special_tokens = false);
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false);
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
static void llm_load_vocab(
@ -2192,15 +2192,30 @@ static void llm_load_vocab(
// are special tokens.
// From testing, this appears to corelate 1:1 with special tokens.
//
// Counting special tokens and verifying in only one direction
// is sufficient to detect difference in those two sets.
//
uint32_t special_tokens_count_by_type = 0;
uint32_t special_tokens_count_from_verification = 0;
bool special_tokens_definition_mismatch = false;
for (const auto & t: vocab.token_to_id)
{
const auto & token = t.first;
const auto & id = t.second;
// Count all non-normal tokens in the vocab while iterating
if( vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL )
special_tokens_count_by_type++;
// Skip single character tokens
if( token.length() > 1 )
{
bool is_tokenizable = false;
// Split token string representation in two, in all possible ways
// and check if both halves can be matched to a valid token
for (unsigned i = 1; i < token.length();)
{
const auto left = token.substr(0, i);
@ -2211,8 +2226,6 @@ static void llm_load_vocab(
if( utf == 1 )
{
//fprintf(stderr, "BSTC . '%s' '%s' '%s'\n", token.c_str(), left.c_str(), right.c_str());
if (vocab.token_to_id.find( left ) != vocab.token_to_id.end() &&
vocab.token_to_id.find( right ) != vocab.token_to_id.end() )
{
@ -2224,7 +2237,6 @@ static void llm_load_vocab(
}
else
{
// fprintf(stderr, "BSTC SKIP '%s' '%s' '%s'\n", token.c_str(), left.c_str(), right.c_str());
// skip over the rest of multibyte utf sequence
i += utf - 1;
}
@ -2232,7 +2244,10 @@ static void llm_load_vocab(
if (!is_tokenizable)
{
// it's faster to re-filter them here, since there is way less candidates now
// Some tokens are multibyte, but they are utf sequences with equivalent text length of 1
// it's faster to re-filter them here, since there are way less candidates now
// Calculate a total "utf" length of a token string representation
size_t utf8_str_len = 0;
for (unsigned i = 0; i < token.length();)
{
@ -2240,15 +2255,39 @@ static void llm_load_vocab(
i += utf8_len( token.at(i) );
}
// And skip the ones which are one character
if (utf8_str_len > 1)
{
//fprintf(stderr, "BSTC SPECIAL '%s' '%d' ('%ld')\n", token.c_str(), id, utf8_str_len);
// At this point what we have left are special tokens only
vocab.special_tokens_cache[token] = id;
// Count manually found special tokens
special_tokens_count_from_verification ++;
// If this manually found special token is not marked as such, flag a mismatch
if( vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL )
special_tokens_definition_mismatch = true;
}
}
}
}
if( special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type )
{
fprintf(stderr, "%s: WARNING: Mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
__func__,
special_tokens_count_from_verification, vocab.id_to_token.size(),
special_tokens_count_by_type, vocab.id_to_token.size()
);
}
else
{
fprintf(stderr, "%s: Special tokens definition check successful ( %u/%zu ).\n",
__func__,
special_tokens_count_from_verification, vocab.id_to_token.size()
);
}
}
}
@ -5777,7 +5816,7 @@ struct fragment_buffer_variant{
std::string raw_text;
};
void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
{
// for each special token
for( const auto & st: vocab.special_tokens_cache )
@ -5834,7 +5873,8 @@ void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragmen
}
else
{
auto prev = std::prev( buffer.begin(), -(source-1) );
//auto prev = std::prev( buffer.begin(), -(source-1) );
auto prev = std::next( buffer.begin(), (source-1) );
buffer.erase_after(prev);
}
//it = std::prev( it, 1 );
@ -5850,7 +5890,8 @@ void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragmen
}
else
{
auto prev = std::prev( buffer.begin(), -(source) );
//auto prev = std::prev( buffer.begin(), -(source) );
auto prev = std::next( buffer.begin(), (source) );
buffer.erase_after(prev);
}
//it = std::prev( it, 1 );
@ -5865,7 +5906,7 @@ void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragmen
}
}
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool allow_special_tokens) {
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) {
std::vector<llama_vocab::id> output;
// OG tokenizer behavior:
@ -5885,7 +5926,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
fragment_buffer.emplace_front( raw_text );
if (allow_special_tokens) {
if (special) {
tokenizer_st_partition( vocab, fragment_buffer );
}
@ -8843,8 +8884,8 @@ int llama_tokenize(
llama_token * tokens,
int n_max_tokens,
bool add_bos,
bool allow_special_tokens) {
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, allow_special_tokens);
bool special) {
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
if (n_max_tokens < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);

11
llama.h
View file

@ -511,10 +511,11 @@ extern "C" {
// Tokenization
//
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
/// @details Convert the provided text into tokens.
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
/// @return Returns the number of tokens on success, no more than n_max_tokens
/// @return Returns a negative number on failure - the number of tokens that would have been returned
/// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
LLAMA_API int llama_tokenize(
const struct llama_model * model,
const char * text,
@ -522,7 +523,7 @@ extern "C" {
llama_token * tokens,
int n_max_tokens,
bool add_bos,
bool allow_special_tokens);
bool special);
// Token Id -> Piece.
// Uses the vocabulary in the provided context.