Rewrite special token handling from #1931
parent c47066d833
commit b592c70deb
6 changed files with 243 additions and 31 deletions
@@ -862,21 +862,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 std::vector<llama_token> llama_tokenize(
     const struct llama_context * ctx,
     const std::string & text,
-    bool add_bos) {
-    return llama_tokenize(llama_get_model(ctx), text, add_bos);
+    bool add_bos,
+    bool allow_special_tokens) {
+    return llama_tokenize(llama_get_model(ctx), text, add_bos, allow_special_tokens);
 }
 
 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
     const std::string & text,
-    bool add_bos) {
+    bool add_bos,
+    bool allow_special_tokens) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + add_bos;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -151,12 +151,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 std::vector<llama_token> llama_tokenize(
     const struct llama_context * ctx,
     const std::string & text,
-    bool add_bos);
+    bool add_bos,
+    bool allow_special_tokens = false);
 
 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
     const std::string & text,
-    bool add_bos);
+    bool add_bos,
+    bool allow_special_tokens = false);
 
 // tokenizes a token into a piece
 // should work similar to Python's `tokenizer.id_to_piece`
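
For reference, a minimal usage sketch (not part of this patch) of the updated helpers. The context is assumed to be created elsewhere, and the "<|im_start|>"/"<|im_end|>" strings are placeholders that only resolve to special tokens if the loaded model's vocabulary defines them:

#include <string>
#include <vector>
#include "common.h"

void tokenize_example(llama_context * ctx) {
    const std::string prompt = "<|im_start|>user\nHello<|im_end|>";

    // default: special-token strings are treated as ordinary text
    std::vector<llama_token> as_text = ::llama_tokenize(ctx, prompt, /*add_bos=*/true);

    // opt-in: matching substrings collapse to their single special-token ids
    std::vector<llama_token> as_special = ::llama_tokenize(ctx, prompt, /*add_bos=*/true, /*allow_special_tokens=*/true);

    // as_special is expected to be shorter, since each special token becomes one id
}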
@@ -863,7 +863,7 @@ size_t tokenize_file(
             (int) buf.size(),
             out_tokens.data(),
             (int) out_tokens.size(),
-            false);
+            false,false);
         if (n_tokens < 0) {
             out_tokens.resize(-n_tokens);
             n_tokens = llama_tokenize(
@@ -872,7 +872,7 @@ size_t tokenize_file(
             (int) buf.size(),
             out_tokens.data(),
             (int) out_tokens.size(),
-            false);
+            false,false);
         }
         if (n_tokens >= 0) {
             out_tokens.resize(n_tokens);
@@ -966,7 +966,7 @@ size_t tokenize_file(
                 (int) buf_sample.size(),
                 tok_sample.data(),
                 (int) tok_sample.size(),
-                false);
+                false,false);
             if (n_tokens < 0) {
                 tok_sample.resize(-n_tokens);
                 n_tokens = llama_tokenize(llama_get_model(lctx),
@@ -974,7 +974,7 @@ size_t tokenize_file(
                 (int) buf_sample.size(),
                 tok_sample.data(),
                 (int) tok_sample.size(),
-                false);
+                false,false);
             GGML_ASSERT(n_tokens >= 0);
         }
         GGML_ASSERT(n_tokens <= (int) tok_sample.size());
@@ -237,7 +237,7 @@ int main(int argc, char ** argv) {
 
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
-        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
     } else {
         LOG("use session tokens\n");
         embd_inp = session_tokens;
@@ -259,10 +259,10 @@ int main(int argc, char ** argv) {
     if (ctx_guidance) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
+        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos, true);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
 
         original_prompt_len = original_inp.size();
@@ -316,8 +316,8 @@ int main(int argc, char ** argv) {
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos);
-    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
+    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
 
     LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
     LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
@@ -715,7 +715,7 @@ int main(int argc, char ** argv) {
         if (params.interactive) {
             if (!params.antiprompt.empty()) {
                 // tokenize and inject first reverse prompt
-                const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
                 embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                 is_antiprompt = true;
             }
@@ -780,7 +780,7 @@ int main(int argc, char ** argv) {
                     embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
                 }
 
-                const auto line_inp = ::llama_tokenize(ctx, buffer, false);
+                const auto line_inp = ::llama_tokenize(ctx, buffer, false, true);
                 LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
 
                 embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
llama.cpp (219 changed lines)
@@ -75,6 +75,7 @@
 #include <thread>
 #include <unordered_map>
 #include <set>
+#include <forward_list>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -1154,6 +1155,8 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
+    std::unordered_map<token, id> special_tokens_cache;
+
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
     // default LLaMA special tokens
@@ -2063,7 +2066,7 @@ static void llm_load_hparams(
 }
 
 // TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos);
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool allow_special_tokens = false);
 static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
 
 static void llm_load_vocab(
@@ -2179,6 +2182,74 @@ static void llm_load_vocab(
     GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
     GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
     GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
+
+    // build special tokens cache
+    {
+        // TODO: It is unclear (to me) at this point, whether special tokens are guaranteed to be of a deterministic type,
+        //  and will always be correctly labeled in 'added_tokens.json' etc.
+        // The assumption is, since special tokens aren't meant to be exposed to the end user, they are designed
+        //  to be unmatchable by the tokenizer; therefore tokens from the vocab which are unmatchable by the tokenizer
+        //  are special tokens.
+        // From testing, this appears to correlate 1:1 with special tokens.
+        //
+        for (const auto & t: vocab.token_to_id)
+        {
+            const auto & token = t.first;
+            const auto & id    = t.second;
+
+            if( token.length() > 1 )
+            {
+                bool is_tokenizable = false;
+
+                for (unsigned i = 1; i < token.length();)
+                {
+                    const auto left  = token.substr(0, i);
+                    const auto right = token.substr(i);
+
+                    // check that we didn't partition in the middle of a utf sequence
+                    auto utf = utf8_len( left.at( left.length() -1 ) );
+
+                    if( utf == 1 )
+                    {
+                        //fprintf(stderr, "BSTC . '%s' '%s' '%s'\n", token.c_str(), left.c_str(), right.c_str());
+
+                        if (vocab.token_to_id.find( left )  != vocab.token_to_id.end() &&
+                            vocab.token_to_id.find( right ) != vocab.token_to_id.end() )
+                        {
+                            is_tokenizable = true;
+                            break;
+                        }
+
+                        i++;
+                    }
+                    else
+                    {
+                        // fprintf(stderr, "BSTC SKIP '%s' '%s' '%s'\n", token.c_str(), left.c_str(), right.c_str());
+                        // skip over the rest of the multibyte utf sequence
+                        i += utf - 1;
+                    }
+                }
+
+                if (!is_tokenizable)
+                {
+                    // it's faster to re-filter them here, since there are far fewer candidates now
+                    size_t utf8_str_len = 0;
+                    for (unsigned i = 0; i < token.length();)
+                    {
+                        utf8_str_len++;
+                        i += utf8_len( token.at(i) );
+                    }
+
+                    if (utf8_str_len > 1)
+                    {
+                        //fprintf(stderr, "BSTC SPECIAL '%s' '%d' ('%ld')\n", token.c_str(), id, utf8_str_len);
+
+                        vocab.special_tokens_cache[token] = id;
+                    }
+                }
+            }
+        }
+    }
 }
 
 static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
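
For reference, a standalone sketch (not part of this patch) of the classification rule described in the comment above: a token is treated as special when no split on a UTF-8 boundary produces two halves that are both vocabulary entries. The toy vocabulary and token strings are illustrative; utf8_len mirrors the helper of the same name in llama.cpp, and the loader additionally requires the candidate to be longer than one character.

#include <cstdint>
#include <map>
#include <string>

// leading-byte length rule, as in llama.cpp's utf8_len helper
static size_t utf8_len(char src) {
    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
    return lookup[highbits];
}

// true if `token` cannot be rebuilt from two smaller vocab entries,
// i.e. the cache-building loop would classify it as a special token
static bool is_special(const std::string & token, const std::map<std::string, int> & vocab) {
    if (token.length() <= 1) {
        return false;
    }
    for (size_t i = 1; i < token.length();) {
        const std::string left  = token.substr(0, i);
        const std::string right = token.substr(i);

        const size_t utf = utf8_len(left.back());
        if (utf != 1) {
            i += utf - 1; // don't split inside a multi-byte sequence
            continue;
        }
        if (vocab.count(left) && vocab.count(right)) {
            return false; // the tokenizer could assemble this string from smaller pieces
        }
        i++;
    }
    return true;
}

int main() {
    // toy vocabulary: "<s>" has no two-part decomposition, "ab" does
    const std::map<std::string, int> vocab = { {"<s>", 1}, {"a", 2}, {"b", 3}, {"ab", 4}, {"<", 5} };
    // is_special("ab",  vocab) -> false ("a" + "b" are both entries)
    // is_special("<s>", vocab) -> true  (no split yields two entries)
    return is_special("<s>", vocab) && !is_special("ab", vocab) ? 0 : 1;
}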
@@ -5686,7 +5757,115 @@ private:
     llm_bigram_bpe::queue work_queue;
 };
 
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
+typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{
+    FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
+    FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
+} FRAGMENT_BUFFER_VARIANT_TYPE;
+
+struct fragment_buffer_variant{
+    fragment_buffer_variant(llama_vocab::id token)
+    :
+        type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
+        token(token){}
+    fragment_buffer_variant(std::string raw_text)
+    :
+        type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
+        raw_text(raw_text){}
+
+    FRAGMENT_BUFFER_VARIANT_TYPE type;
+    llama_vocab::id token;
+    std::string raw_text;
+};
+
+void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
+{
+    // for each special token
+    for( const auto & st: vocab.special_tokens_cache )
+    {
+        const auto & special_token = st.first;
+        const auto & special_id    = st.second;
+
+        // for each text fragment
+        //for (auto & fragment: buffer)
+        std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
+        while (it != buffer.end())
+        {
+            auto & fragment = (*it);
+            // if a fragment is text ( not yet processed )
+            if( fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT )
+            {
+                auto * raw_text = &(fragment.raw_text);
+
+                // loop over the text
+                while(true)
+                {
+                    // find the first occurrence of a given special token in this fragment
+                    auto match = raw_text->find( special_token );
+
+                    // no occurrences found, stop processing this fragment for a given special token
+                    if (match == std::string::npos)
+                    {
+                        break;
+                    }
+
+                    auto source = std::distance( buffer.begin(), it );
+
+                    if( match > 0 )
+                    {
+                        // left
+                        buffer.emplace_after(it, raw_text->substr(0, match));
+                        it++;
+                    }
+
+                    // special token
+                    buffer.emplace_after(it, special_id);
+                    it++;
+
+                    // right
+                    if (match + special_token.length() < raw_text->length())
+                    {
+                        buffer.emplace_after(it, raw_text->substr(match + special_token.length()));
+                        it++;
+
+                        if (source == 0)
+                        {
+                            buffer.erase_after(buffer.before_begin());
+                        }
+                        else
+                        {
+                            auto prev = std::prev( buffer.begin(), -(source-1) );
+                            buffer.erase_after(prev);
+                        }
+                        //it = std::prev( it, 1 );
+
+                        // repeat for the right side
+                        raw_text = &((*it).raw_text);
+                    }
+                    else
+                    {
+                        if (source == 0)
+                        {
+                            buffer.erase_after(buffer.before_begin());
+                        }
+                        else
+                        {
+                            auto prev = std::prev( buffer.begin(), -(source) );
+                            buffer.erase_after(prev);
+                        }
+                        //it = std::prev( it, 1 );
+
+                        break;
+                    }
+                }
+            }
+
+            it++;
+        }
+    }
+}
+
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool allow_special_tokens) {
     std::vector<llama_vocab::id> output;
 
     // OG tokenizer behavior:
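
For reference, a simplified sketch (not part of this patch) of what tokenizer_st_partition does for a single special token: the text around each match stays raw, while the match itself becomes a resolved token id. The committed code performs this in place on the std::forward_list and repeats it for every entry of special_tokens_cache; the "<|endoftext|>" string and the id 2 below are placeholders.

#include <cstdio>
#include <string>
#include <vector>

// minimal stand-in for fragment_buffer_variant: either raw text or a resolved token id
struct fragment {
    bool        is_token;
    int         token;    // valid when is_token
    std::string raw_text; // valid when !is_token
};

// split `text` on every occurrence of `special`, keeping surrounding text raw
static std::vector<fragment> partition(const std::string & text, const std::string & special, int special_id) {
    std::vector<fragment> out;
    size_t pos = 0;
    while (true) {
        const size_t match = text.find(special, pos);
        if (match == std::string::npos) {
            break;
        }
        if (match > pos) {
            out.push_back({false, 0, text.substr(pos, match - pos)}); // left side stays raw
        }
        out.push_back({true, special_id, ""});                        // the special token itself
        pos = match + special.length();
    }
    if (pos < text.length()) {
        out.push_back({false, 0, text.substr(pos)});                  // trailing raw text
    }
    return out;
}

int main() {
    // hypothetical: id 2 stands for "<|endoftext|>"
    for (const auto & f : partition("Answer: 42 <|endoftext|> next", "<|endoftext|>", 2)) {
        if (f.is_token) {
            printf("TOKEN %d\n", f.token);
        } else {
            printf("RAW   '%s'\n", f.raw_text.c_str());
        }
    }
    // expected output:
    //   RAW   'Answer: 42 '
    //   TOKEN 2
    //   RAW   ' next'
    return 0;
}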
@@ -5702,20 +5881,48 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
         return output;
     }
 
+    std::forward_list<fragment_buffer_variant> fragment_buffer;
+
+    fragment_buffer.emplace_front( raw_text );
+
+    if (allow_special_tokens) {
+        tokenizer_st_partition( vocab, fragment_buffer );
+    }
+
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
+                for (const auto & fragment: fragment_buffer)
+                {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
+                    {
                 // without adding this leading whitespace, we do not get the same results as the original tokenizer
-                raw_text = " " + raw_text;
+                auto raw_text = " " + fragment.raw_text;
 
                 llm_tokenizer_spm tokenizer(vocab);
                 llama_escape_whitespace(raw_text);
                 tokenizer.tokenize(raw_text, output);
+                    }
+                    else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                    {
+                        output.push_back(fragment.token);
+                    }
+                }
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
+                for (const auto & fragment: fragment_buffer)
+                {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
+                    {
                 llm_tokenizer_bpe tokenizer(vocab);
                 tokenizer.tokenize(raw_text, output);
+                    }
+                    else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                    {
+                        output.push_back(fragment.token);
+                    }
+                }
             } break;
     }
@@ -8629,15 +8836,15 @@ llama_token llama_token_eot(const struct llama_context * ctx) {
     return ctx->model.vocab.special_eot_id;
 }
 
 
 int llama_tokenize(
     const struct llama_model * model,
                   const char * text,
                          int   text_len,
                  llama_token * tokens,
                          int   n_max_tokens,
-                        bool   add_bos) {
-    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos);
+                        bool   add_bos,
+                        bool   allow_special_tokens) {
+    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, allow_special_tokens);
 
     if (n_max_tokens < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
llama.h (3 changed lines)
@@ -521,7 +521,8 @@ extern "C" {
                          int   text_len,
                  llama_token * tokens,
                          int   n_max_tokens,
-                        bool   add_bos);
+                        bool   add_bos,
+                        bool   allow_special_tokens);
 
     // Token Id -> Piece.
     // Uses the vocabulary in the provided context.
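
For reference, a sketch (not part of this patch) of calling the updated C API directly, using the existing convention that a negative return value is the negated number of tokens required; the model is assumed to be loaded elsewhere.

#include <string>
#include <vector>
#include "llama.h"

std::vector<llama_token> tokenize_with_specials(const llama_model * model, const std::string & text) {
    std::vector<llama_token> tokens(text.size() + 1); // rough upper bound, as in common.cpp

    int n = llama_tokenize(model, text.data(), (int) text.size(),
                           tokens.data(), (int) tokens.size(),
                           /*add_bos=*/true, /*allow_special_tokens=*/true);
    if (n < 0) {
        // buffer was too small: -n is the exact number of tokens needed
        tokens.resize(-n);
        n = llama_tokenize(model, text.data(), (int) text.size(),
                           tokens.data(), (int) tokens.size(),
                           /*add_bos=*/true, /*allow_special_tokens=*/true);
    }
    tokens.resize(n);
    return tokens;
}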