Fixing llama_token_to_str for the different sentence_piece token types

This commit is contained in:
goerch 2023-07-24 09:05:21 +02:00
parent b97a505c5d
commit 81fae1dc8f
3 changed files with 87 additions and 48 deletions

View file

@ -1,4 +1,5 @@
#include "ggml.h" #include "ggml.h"
#include "common.h"
#include "llama.h" #include "llama.h"
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>

108
llama.cpp
View file

@ -1765,6 +1765,64 @@ static bool llama_eval_internal(
// tokenizer // tokenizer
// //
bool llama_is_normal_token(llama_token token) {
return token >= 259;
}
bool llama_is_unknown_token(llama_token token) {
return token == 0;
}
bool llama_is_control_token(llama_token token) {
return token == 1 || token == 2;
}
bool llama_is_bos_token(llama_token token) {
return token == 1;
}
bool llama_is_eos_token(llama_token token) {
return token == 2;
}
bool llama_is_user_defined_token(llama_token token) {
return false;
}
bool llama_is_unused_token(llama_token token) {
return false;
}
bool llama_is_byte_token(llama_token token) {
return 3 <= token && token < 259;
}
static std::string llama_escape_whitespace(const std::string& text) {
std::string result;
bool escaping = false;
result += "\xe2\x96\x81";
for (size_t offs = 0; offs < text.length(); ++offs) {
if (text[offs] == ' ') {
if (!escaping) {
result += "\xe2\x96\x81";
escaping = true;
}
}
else {
escaping = false;
result += text[offs];
}
}
return result;
}
static std::string llama_unescape_whitespace(const std::string& word) {
if (word.length() >= 3 && word.substr(0, 3) == "\xe2\x96\x81") {
return std::string(" ") + word.substr(3);
}
return word;
}
static size_t utf8_len(char src) { static size_t utf8_len(char src) {
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
uint8_t highbits = static_cast<uint8_t>(src) >> 4; uint8_t highbits = static_cast<uint8_t>(src) >> 4;
@ -1795,39 +1853,6 @@ struct llama_sp_bigram {
size_t size; size_t size;
}; };
static std::string llama_escape_whitespace(const std::string& text) {
std::string result;
bool escaping = false;
result += char(0xe2);
result += char(0x96);
result += char(0x81);
for (size_t offs = 0; offs < text.length(); ++offs) {
if (text[offs] == ' ') {
if (!escaping) {
result += char(0xe2);
result += char(0x96);
result += char(0x81);
escaping = true;
}
}
else {
escaping = false;
result += text[offs];
}
}
return result;
}
static std::string llama_unescape_whitespace(const std::string& word) {
if (word.length() >= 3 &&
word[0] == char(0xe2) &&
word[1] == char(0x96) &&
word[2] == char(0x81)) {
return std::string(" ") + word.substr(3);
}
return word;
}
// original implementation: // original implementation:
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 // https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
struct llama_tokenizer { struct llama_tokenizer {
@ -3727,12 +3752,29 @@ float * llama_get_embeddings(struct llama_context * ctx) {
int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) { int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) {
if (0 <= token && token < llama_n_vocab_from_model(model)) { if (0 <= token && token < llama_n_vocab_from_model(model)) {
if (llama_is_normal_token(token)) {
std::string result = llama_unescape_whitespace(model->vocab.id_to_token[token].tok); std::string result = llama_unescape_whitespace(model->vocab.id_to_token[token].tok);
if(result.length() > length) { if(result.length() > length) {
return - result.length(); return - result.length();
} }
strcpy(str, result.c_str()); strcpy(str, result.c_str());
return result.length(); return result.length();
} else if (llama_is_unknown_token(token)) {
if(3 > length) {
return -3;
}
strcpy(str, "\xe2\x96\x85");
return 3;
} else if (llama_is_control_token(token)) {
;
} else if (llama_is_byte_token(token)) {
if(1 > length) {
return -1;
}
str[0] = token - 3;
str[1] = 0x00;
return 1;
}
} }
return 0; return 0;
} }

View file

@ -12,15 +12,11 @@
static std::string escape_whitespace(const std::string& text) { static std::string escape_whitespace(const std::string& text) {
std::string result; std::string result;
bool escaping = false; bool escaping = false;
result += char(0xe2); result += "\xe2\x96\x81";
result += char(0x96);
result += char(0x81);
for (size_t offs = 0; offs < text.length(); ++offs) { for (size_t offs = 0; offs < text.length(); ++offs) {
if (text[offs] == ' ') { if (text[offs] == ' ') {
if (!escaping) { if (!escaping) {
result += char(0xe2); result += "\xe2\x96\x81";
result += char(0x96);
result += char(0x81);
escaping = true; escaping = true;
} }
} }
@ -93,15 +89,15 @@ int main(int argc, char **argv) {
if (n == 1) { if (n == 1) {
if (i != tokens[0]) { if (i != tokens[0]) {
std::string backward = llama_token_to_str(ctx, tokens[0]); std::string backward = llama_token_to_str(ctx, tokens[0]);
fprintf(stderr, "%s : error: token %d is string %s but tokenize() returns token %d %s\n", fprintf(stderr, "%s : error: token %d is string %s but bpe returns token %d %s\n",
__func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str()); __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
} }
} else { } else {
if (i <= 258) { if (i <= 258) {
fprintf(stderr, "%s : info: token %d is string %s and tokenize() returns tokens %s\n", fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n",
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str()); __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str());
} else { } else {
fprintf(stderr, "%s : error: token %d is string %s but tokenize() returns tokens %s\n", fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n",
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str()); __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str());
} }
} }