Rewrite special token handling from #1931

staviq 2023-10-08 02:43:23 +02:00
parent c47066d833
commit b592c70deb
6 changed files with 243 additions and 31 deletions


@@ -862,21 +862,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
bool add_bos) {
return llama_tokenize(llama_get_model(ctx), text, add_bos);
bool add_bos,
bool allow_special_tokens) {
return llama_tokenize(llama_get_model(ctx), text, add_bos, allow_special_tokens);
}
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos) {
bool add_bos,
bool allow_special_tokens) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);


@@ -151,12 +151,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
bool add_bos);
bool add_bos,
bool allow_special_tokens = false);
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos);
bool add_bos,
bool allow_special_tokens = false);
// converts a token id into a text piece
// should work similarly to Python's `tokenizer.id_to_piece`
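
For orientation (not part of this diff): a minimal sketch of calling the updated helper, assuming a llama_context created elsewhere. Because allow_special_tokens defaults to false, existing callers keep their current behavior; opting in is a single extra argument.

#include "common.h"

// sketch only: 'ctx' is assumed to be a valid, already-initialized context
static std::vector<llama_token> tokenize_with_specials(const llama_context * ctx, const std::string & prompt) {
    const bool add_bos = true;
    // passing true lets exact matches of special-token text inside 'prompt' become single token ids
    return ::llama_tokenize(ctx, prompt, add_bos, true);
}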


@@ -863,7 +863,7 @@ size_t tokenize_file(
(int) buf.size(),
out_tokens.data(),
(int) out_tokens.size(),
false);
false, false);
if (n_tokens < 0) {
out_tokens.resize(-n_tokens);
n_tokens = llama_tokenize(
@@ -872,7 +872,7 @@ size_t tokenize_file(
(int) buf.size(),
out_tokens.data(),
(int) out_tokens.size(),
false);
false, false);
}
if (n_tokens >= 0) {
out_tokens.resize(n_tokens);
@@ -966,7 +966,7 @@ size_t tokenize_file(
(int) buf_sample.size(),
tok_sample.data(),
(int) tok_sample.size(),
false);
false, false);
if (n_tokens < 0) {
tok_sample.resize(-n_tokens);
n_tokens = llama_tokenize(llama_get_model(lctx),
@@ -974,7 +974,7 @@ size_t tokenize_file(
(int) buf_sample.size(),
tok_sample.data(),
(int) tok_sample.size(),
false);
false, false);
GGML_ASSERT(n_tokens >= 0);
}
GGML_ASSERT(n_tokens <= (int) tok_sample.size());


@@ -237,7 +237,7 @@ int main(int argc, char ** argv) {
if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
LOG("tokenize the prompt\n");
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
} else {
LOG("use session tokens\n");
embd_inp = session_tokens;
@@ -259,10 +259,10 @@ int main(int argc, char ** argv) {
if (ctx_guidance) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos, true);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
original_prompt_len = original_inp.size();
@@ -316,8 +316,8 @@ int main(int argc, char ** argv) {
}
// prefix & suffix for instruct mode
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
@@ -715,7 +715,7 @@ int main(int argc, char ** argv) {
if (params.interactive) {
if (!params.antiprompt.empty()) {
// tokenize and inject first reverse prompt
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
is_antiprompt = true;
}
@@ -780,7 +780,7 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
}
const auto line_inp = ::llama_tokenize(ctx, buffer, false);
const auto line_inp = ::llama_tokenize(ctx, buffer, false, true);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
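
A brief sketch (not in the commit) of what the hard-coded true in the calls above opts into; the helper and logging macros are the ones already used in main.cpp, and the two results differ only when the buffer contains the literal text of a special token.

// sketch only: log both tokenizations of the same user input
static void log_both_tokenizations(llama_context * ctx, const std::string & buffer) {
    const auto as_text    = ::llama_tokenize(ctx, buffer, false);        // previous behavior
    const auto as_special = ::llama_tokenize(ctx, buffer, false, true);  // behavior after this change
    LOG("as plain text:    %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, as_text));
    LOG("with special ids: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, as_special));
}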

llama.cpp

@@ -75,6 +75,7 @@
#include <thread>
#include <unordered_map>
#include <set>
#include <forward_list>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -1154,6 +1155,8 @@ struct llama_vocab {
std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token;
std::unordered_map<token, id> special_tokens_cache;
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
// default LLaMA special tokens
@@ -2063,7 +2066,7 @@ static void llm_load_hparams(
}
// TODO: This should probably be in llama.h
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos);
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool allow_special_tokens = false);
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
static void llm_load_vocab(
@@ -2179,6 +2182,74 @@ static void llm_load_vocab(
GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
// build special tokens cache
{
// TODO: It is unclear (to me) at this point whether special tokens are guaranteed to be of a deterministic type,
// and whether they will always be correctly labeled in 'added_tokens.json' etc.
// The assumption is that, since special tokens aren't meant to be exposed to the end user, they are designed
// to be unmatchable by the tokenizer; therefore any vocab token that the tokenizer itself cannot reconstruct
// from other vocab tokens is a special token.
// From testing, this appears to correlate 1:1 with special tokens.
//
for (const auto & t: vocab.token_to_id)
{
const auto & token = t.first;
const auto & id = t.second;
if( token.length() > 1 )
{
bool is_tokenizable = false;
for (unsigned i = 1; i < token.length();)
{
const auto left = token.substr(0, i);
const auto right = token.substr(i);
// check that we didn't partition in the middle of a UTF-8 sequence
auto utf = utf8_len( left.at( left.length() -1 ) );
if( utf == 1 )
{
//fprintf(stderr, "BSTC . '%s' '%s' '%s'\n", token.c_str(), left.c_str(), right.c_str());
if (vocab.token_to_id.find( left ) != vocab.token_to_id.end() &&
vocab.token_to_id.find( right ) != vocab.token_to_id.end() )
{
is_tokenizable = true;
break;
}
i++;
}
else
{
// fprintf(stderr, "BSTC SKIP '%s' '%s' '%s'\n", token.c_str(), left.c_str(), right.c_str());
// skip over the rest of the multi-byte UTF-8 sequence
i += utf - 1;
}
}
if (!is_tokenizable)
{
// it's faster to re-filter them here, since there are far fewer candidates now
size_t utf8_str_len = 0;
for (unsigned i = 0; i < token.length();)
{
utf8_str_len++;
i += utf8_len( token.at(i) );
}
if (utf8_str_len > 1)
{
//fprintf(stderr, "BSTC SPECIAL '%s' '%d' ('%ld')\n", token.c_str(), id, utf8_str_len);
vocab.special_tokens_cache[token] = id;
}
}
}
}
}
}
static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
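
A standalone sketch (simplified; not the commit's code) of the heuristic built above: a multi-character vocab entry is treated as "special" if it cannot be split at any valid UTF-8 boundary into two pieces that are both vocab entries themselves. toy_vocab, utf8_len_local and is_special_candidate are illustrative names; the real code additionally skips tokens that are only a single Unicode code point.

#include <cstdio>
#include <string>
#include <unordered_map>

// toy stand-in for vocab.token_to_id; contents are illustrative only
static const std::unordered_map<std::string, int> toy_vocab = {
    {"he", 10}, {"llo", 11}, {"hello", 12}, {"<|endoftext|>", 13},
};

// same idea as utf8_len() in llama.cpp: sequence length from the first byte
static size_t utf8_len_local(char src) {
    const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
    return lookup[static_cast<unsigned char>(src) >> 4];
}

// true if 'token' cannot be formed by concatenating two other vocab entries
static bool is_special_candidate(const std::string & token) {
    for (size_t i = 1; i < token.length();) {
        const auto left  = token.substr(0, i);
        const auto right = token.substr(i);
        const size_t utf = utf8_len_local(left.back());
        if (utf == 1) {
            if (toy_vocab.count(left) && toy_vocab.count(right)) {
                return false; // splittable, so the tokenizer can match it: not special
            }
            i++;
        } else {
            i += utf - 1; // never split inside a multi-byte UTF-8 sequence
        }
    }
    return true;
}

int main() {
    printf("hello         -> %d\n", is_special_candidate("hello"));         // 0: "he" + "llo" are both in the vocab
    printf("<|endoftext|> -> %d\n", is_special_candidate("<|endoftext|>")); // 1: no two-piece split exists
}
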
@@ -5686,7 +5757,115 @@ private:
llm_bigram_bpe::queue work_queue;
};
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{
FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
} FRAGMENT_BUFFER_VARIANT_TYPE;
struct fragment_buffer_variant{
fragment_buffer_variant(llama_vocab::id token)
:
type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
token(token){}
fragment_buffer_variant(std::string raw_text)
:
type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
raw_text(raw_text){}
FRAGMENT_BUFFER_VARIANT_TYPE type;
llama_vocab::id token;
std::string raw_text;
};
void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
{
// for each special token
for( const auto & st: vocab.special_tokens_cache )
{
const auto & special_token = st.first;
const auto & special_id = st.second;
// for each text fragment
//for (auto & fragment: buffer)
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
while (it != buffer.end())
{
auto & fragment = (*it);
// if a fragment is text ( not yet processed )
if( fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT )
{
auto * raw_text = &(fragment.raw_text);
// loop over the text
while(true)
{
// find the first occurrence of a given special token in this fragment
auto match = raw_text->find( special_token );
// no occurrences found, stop processing this fragment for this special token
if (match == std::string::npos)
{
break;
}
auto source = std::distance( buffer.begin(), it );
if( match > 0 )
{
// left
buffer.emplace_after(it, raw_text->substr(0, match));
it++;
}
// special token
buffer.emplace_after(it, special_id);
it++;
// right
if (match + special_token.length() < raw_text->length())
{
buffer.emplace_after(it, raw_text->substr(match + special_token.length()));
it++;
if (source == 0)
{
buffer.erase_after(buffer.before_begin());
}
else
{
auto prev = std::prev( buffer.begin(), -(source-1) );
buffer.erase_after(prev);
}
//it = std::prev( it, 1 );
// repeat for the right side
raw_text = &((*it).raw_text);
}
else
{
if (source == 0)
{
buffer.erase_after(buffer.before_begin());
}
else
{
auto prev = std::prev( buffer.begin(), -(source) );
buffer.erase_after(prev);
}
//it = std::prev( it, 1 );
break;
}
}
}
it++;
}
}
}
static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool allow_special_tokens) {
std::vector<llama_vocab::id> output;
// OG tokenizer behavior:
@@ -5702,20 +5881,48 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
return output;
}
std::forward_list<fragment_buffer_variant> fragment_buffer;
fragment_buffer.emplace_front( raw_text );
if (allow_special_tokens) {
tokenizer_st_partition( vocab, fragment_buffer );
}
switch (vocab.type) {
case LLAMA_VOCAB_TYPE_SPM:
{
// without adding this leading whitespace, we do not get the same results as the original tokenizer
raw_text = " " + raw_text;
for (const auto & fragment: fragment_buffer)
{
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
{
// without adding this leading whitespace, we do not get the same results as the original tokenizer
auto raw_text = " " + fragment.raw_text;
llm_tokenizer_spm tokenizer(vocab);
llama_escape_whitespace(raw_text);
tokenizer.tokenize(raw_text, output);
llm_tokenizer_spm tokenizer(vocab);
llama_escape_whitespace(raw_text);
tokenizer.tokenize(raw_text, output);
}
else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
{
output.push_back(fragment.token);
}
}
} break;
case LLAMA_VOCAB_TYPE_BPE:
{
llm_tokenizer_bpe tokenizer(vocab);
tokenizer.tokenize(raw_text, output);
for (const auto & fragment: fragment_buffer)
{
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
{
llm_tokenizer_bpe tokenizer(vocab);
tokenizer.tokenize(raw_text, output);
}
else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
{
output.push_back(fragment.token);
}
}
} break;
}
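
To make the partitioning above concrete: for a cached special token (say "<|endoftext|>" with a hypothetical id of 13) and the input "Hello<|endoftext|>World", tokenizer_st_partition turns the single raw-text fragment into RAW_TEXT("Hello"), TOKEN(13), RAW_TEXT("World"); the switch above then tokenizes raw-text fragments normally and copies token fragments through unchanged. A simplified single-token sketch follows (not the commit's code, which edits a std::forward_list in place and loops over every cached special token):

#include <cstdio>
#include <string>
#include <vector>

// simplified fragment: either raw text (to be tokenized later) or an already-resolved token id
struct frag { bool is_token; int token; std::string text; };

// split 'text' on every exact occurrence of one special token
static std::vector<frag> partition_on_special(const std::string & text, const std::string & st, int st_id) {
    std::vector<frag> out;
    size_t pos = 0;
    while (pos < text.size()) {
        const size_t match = text.find(st, pos);
        if (match == std::string::npos) {
            out.push_back({false, 0, text.substr(pos)});              // trailing raw text
            break;
        }
        if (match > pos) {
            out.push_back({false, 0, text.substr(pos, match - pos)}); // raw text to the left of the match
        }
        out.push_back({true, st_id, ""});                             // the special token itself
        pos = match + st.size();                                      // continue with the right side
    }
    return out;
}

int main() {
    for (const auto & f : partition_on_special("Hello<|endoftext|>World", "<|endoftext|>", 13)) {
        if (f.is_token) printf("TOKEN(%d) ", f.token);
        else            printf("RAW(\"%s\") ", f.text.c_str());
    }
    printf("\n"); // expected: RAW("Hello") TOKEN(13) RAW("World")
}
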
@@ -8629,15 +8836,15 @@ llama_token llama_token_eot(const struct llama_context * ctx) {
return ctx->model.vocab.special_eot_id;
}
int llama_tokenize(
const struct llama_model * model,
const char * text,
int text_len,
llama_token * tokens,
int n_max_tokens,
bool add_bos) {
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos);
bool add_bos,
bool allow_special_tokens) {
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, allow_special_tokens);
if (n_max_tokens < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);


@@ -521,7 +521,8 @@ extern "C" {
int text_len,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
bool add_bos,
bool allow_special_tokens);
// Token Id -> Piece.
// Uses the vocabulary in the provided context.
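
For completeness, a hedged sketch (not part of the commit) of driving the updated C API directly, using the same negative-return/resize/retry convention that the common helper above implements; the model pointer is assumed to come from the caller.

#include <string>
#include <vector>
#include "llama.h"

// sketch only: tokenize 'text' via the raw C API, honoring the resize convention
static std::vector<llama_token> tokenize_raw(const llama_model * model, const std::string & text,
                                             bool add_bos, bool allow_special_tokens) {
    std::vector<llama_token> tokens(text.length() + (add_bos ? 1 : 0)); // upper bound on the token count
    int n = llama_tokenize(model, text.data(), (int) text.length(),
                           tokens.data(), (int) tokens.size(), add_bos, allow_special_tokens);
    if (n < 0) {
        tokens.resize(-n); // buffer too small: -n is the required size
        n = llama_tokenize(model, text.data(), (int) text.length(),
                           tokens.data(), (int) tokens.size(), add_bos, allow_special_tokens);
    }
    tokens.resize(n > 0 ? n : 0);
    return tokens;
}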