From 81fae1dc8fb502577dec6cee128d3c1684ea15f9 Mon Sep 17 00:00:00 2001
From: goerch
Date: Mon, 24 Jul 2023 09:05:21 +0200
Subject: [PATCH] Fixing llama_token_to_str for the different sentence_piece token types

llama_token_to_str_with_model now selects its output by token class:
normal tokens (id >= 259) yield their unescaped vocabulary text, the
unknown token (id 0) yields the UTF-8 bytes e2 96 85, byte tokens
(ids 3..258) yield the single byte `id - 3`, and the control tokens
(ids 1 and 2) yield nothing.

The text is copied with memcpy and NUL-terminated only when the buffer
has room for the terminator, so a token whose length equals the buffer
length no longer writes one byte past the end of the buffer. When the
buffer is too small the negated required size is returned.

---
 .../train-text-from-scratch.cpp               |   1 +
 llama.cpp                                     | 119 ++++++++++++------
 tests/test-tokenizer-1.cpp                    |  14 +-
 3 files changed, 86 insertions(+), 48 deletions(-)

diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 6c6806e5e..7375e77df 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1,4 +1,5 @@
 #include "ggml.h"
+#include "common.h"
 #include "llama.h"
 #include <unordered_map>
 #include <vector>
diff --git a/llama.cpp b/llama.cpp
index 7dea8c9c7..fb71aef19 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1765,6 +1765,64 @@ static bool llama_eval_internal(
 // tokenizer
 //
 
+bool llama_is_normal_token(llama_token token) {
+    return token >= 259; // ids above the unknown, control and byte token ranges below
+}
+
+bool llama_is_unknown_token(llama_token token) {
+    return token == 0;
+}
+
+bool llama_is_control_token(llama_token token) {
+    return token == 1 || token == 2; // the BOS and EOS ids tested below
+}
+
+bool llama_is_bos_token(llama_token token) {
+    return token == 1;
+}
+
+bool llama_is_eos_token(llama_token token) {
+    return token == 2;
+}
+
+bool llama_is_user_defined_token(llama_token token) {
+    return false;
+}
+
+bool llama_is_unused_token(llama_token token) {
+    return false;
+}
+
+bool llama_is_byte_token(llama_token token) {
+    return 3 <= token && token < 259; // 256 ids, one per byte value
+}
+
+static std::string llama_escape_whitespace(const std::string& text) {
+    std::string result;
+    bool escaping = false;
+    result += "\xe2\x96\x81"; // the output always starts with one marker
+    for (size_t offs = 0; offs < text.length(); ++offs) {
+        if (text[offs] == ' ') {
+            if (!escaping) {
+                result += "\xe2\x96\x81"; // a run of spaces collapses to a single marker
+                escaping = true;
+            }
+        }
+        else {
+            escaping = false;
+            result += text[offs];
+        }
+    }
+    return result;
+}
+
+static std::string llama_unescape_whitespace(const std::string& word) {
+    if (word.length() >= 3 && word.substr(0, 3) == "\xe2\x96\x81") {
+        return std::string(" ") + word.substr(3); // only a leading marker is turned back into a space
+    }
+    return word;
+}
+
 static size_t utf8_len(char src) {
     const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
     uint8_t highbits = static_cast<uint8_t>(src) >> 4;
@@ -1795,39 +1853,6 @@ struct llama_sp_bigram {
     size_t size;
 };
 
-static std::string llama_escape_whitespace(const std::string& text) {
-    std::string result;
-    bool escaping = false;
-    result += char(0xe2);
-    result += char(0x96);
-    result += char(0x81);
-    for (size_t offs = 0; offs < text.length(); ++offs) {
-        if (text[offs] == ' ') {
-            if (!escaping) {
-                result += char(0xe2);
-                result += char(0x96);
-                result += char(0x81);
-                escaping = true;
-            }
-        }
-        else {
-            escaping = false;
-            result += text[offs];
-        }
-    }
-    return result;
-}
-
-static std::string llama_unescape_whitespace(const std::string& word) {
-    if (word.length() >= 3 &&
-        word[0] == char(0xe2) &&
-        word[1] == char(0x96) &&
-        word[2] == char(0x81)) {
-        return std::string(" ") + word.substr(3);
-    }
-    return word;
-}
-
 // original implementation:
 // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
 struct llama_tokenizer {
@@ -3727,12 +3752,28 @@ float * llama_get_embeddings(struct llama_context * ctx) {
 
 int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) {
     if (0 <= token && token < llama_n_vocab_from_model(model)) {
-        std::string result = llama_unescape_whitespace(model->vocab.id_to_token[token].tok);
-        if(result.length() > length) {
-            return - result.length();
-        }
-        strcpy(str, result.c_str());
-        return result.length();
+        // Text to emit for this token; control tokens and any other class emit nothing.
+        std::string result;
+        if (llama_is_normal_token(token)) {
+            result = llama_unescape_whitespace(model->vocab.id_to_token[token].tok);
+        } else if (llama_is_unknown_token(token)) {
+            result = "\xe2\x96\x85"; // 3-byte placeholder for the unknown token
+        } else if (llama_is_byte_token(token)) {
+            result = std::string(1, (char) (token - 3)); // ids 3..258 map to byte values 0..255
+        } else {
+            return 0;
+        }
+        const int n = (int) result.length();
+        if (n > length) {
+            return -n; // negated size tells the caller how many bytes str must hold
+        }
+        // Copy exactly n bytes: strcpy would also store a NUL at str[n], which is
+        // one past the end of the buffer when n == length.
+        memcpy(str, result.data(), n);
+        if (n < length) {
+            str[n] = '\0'; // terminate only when there is room for it
+        }
+        return n;
     }
     return 0;
 }
diff --git a/tests/test-tokenizer-1.cpp b/tests/test-tokenizer-1.cpp
index d9a6293c0..836233f29 100644
--- a/tests/test-tokenizer-1.cpp
+++ b/tests/test-tokenizer-1.cpp
@@ -12,15 +12,11 @@
 static std::string escape_whitespace(const std::string& text) {
     std::string result;
     bool escaping = false;
-    result += char(0xe2);
-    result += char(0x96);
-    result += char(0x81);
+    result += "\xe2\x96\x81";
     for (size_t offs = 0; offs < text.length(); ++offs) {
         if (text[offs] == ' ') {
             if (!escaping) {
-                result += char(0xe2);
-                result += char(0x96);
-                result += char(0x81);
+                result += "\xe2\x96\x81";
                 escaping = true;
             }
         }
@@ -93,15 +89,15 @@ int main(int argc, char **argv) {
         if (n == 1) {
             if (i != tokens[0]) {
                 std::string backward = llama_token_to_str(ctx, tokens[0]);
-                fprintf(stderr, "%s : error: token %d is string %s but tokenize() returns token %d %s\n",
+                fprintf(stderr, "%s : error: token %d is string %s but bpe returns token %d %s\n",
                     __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
             }
         } else {
             if (i <= 258) {
-                fprintf(stderr, "%s : info: token %d is string %s and tokenize() returns tokens %s\n",
+                fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n",
                     __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str());
             } else {
-                fprintf(stderr, "%s : error: token %d is string %s but tokenize() returns tokens %s\n",
+                fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n",
                     __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str());
             }
         }