Fixing llama_token_to_str for the different sentence_piece token types
This commit is contained in:
parent
b97a505c5d
commit
81fae1dc8f
3 changed files with 87 additions and 48 deletions
|
@ -1,4 +1,5 @@
|
|||
#include "ggml.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
|
120
llama.cpp
120
llama.cpp
|
@ -1765,6 +1765,64 @@ static bool llama_eval_internal(
|
|||
// tokenizer
|
||||
//
|
||||
|
||||
bool llama_is_normal_token(llama_token token) {
|
||||
return token >= 259;
|
||||
}
|
||||
|
||||
bool llama_is_unknown_token(llama_token token) {
|
||||
return token == 0;
|
||||
}
|
||||
|
||||
bool llama_is_control_token(llama_token token) {
|
||||
return token == 1 || token == 2;
|
||||
}
|
||||
|
||||
bool llama_is_bos_token(llama_token token) {
|
||||
return token == 1;
|
||||
}
|
||||
|
||||
bool llama_is_eos_token(llama_token token) {
|
||||
return token == 2;
|
||||
}
|
||||
|
||||
bool llama_is_user_defined_token(llama_token token) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool llama_is_unused_token(llama_token token) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool llama_is_byte_token(llama_token token) {
|
||||
return 3 <= token && token < 259;
|
||||
}
|
||||
|
||||
static std::string llama_escape_whitespace(const std::string& text) {
|
||||
std::string result;
|
||||
bool escaping = false;
|
||||
result += "\xe2\x96\x81";
|
||||
for (size_t offs = 0; offs < text.length(); ++offs) {
|
||||
if (text[offs] == ' ') {
|
||||
if (!escaping) {
|
||||
result += "\xe2\x96\x81";
|
||||
escaping = true;
|
||||
}
|
||||
}
|
||||
else {
|
||||
escaping = false;
|
||||
result += text[offs];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string llama_unescape_whitespace(const std::string& word) {
|
||||
if (word.length() >= 3 && word.substr(0, 3) == "\xe2\x96\x81") {
|
||||
return std::string(" ") + word.substr(3);
|
||||
}
|
||||
return word;
|
||||
}
|
||||
|
||||
static size_t utf8_len(char src) {
|
||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||
|
@ -1795,39 +1853,6 @@ struct llama_sp_bigram {
|
|||
size_t size;
|
||||
};
|
||||
|
||||
static std::string llama_escape_whitespace(const std::string& text) {
|
||||
std::string result;
|
||||
bool escaping = false;
|
||||
result += char(0xe2);
|
||||
result += char(0x96);
|
||||
result += char(0x81);
|
||||
for (size_t offs = 0; offs < text.length(); ++offs) {
|
||||
if (text[offs] == ' ') {
|
||||
if (!escaping) {
|
||||
result += char(0xe2);
|
||||
result += char(0x96);
|
||||
result += char(0x81);
|
||||
escaping = true;
|
||||
}
|
||||
}
|
||||
else {
|
||||
escaping = false;
|
||||
result += text[offs];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string llama_unescape_whitespace(const std::string& word) {
|
||||
if (word.length() >= 3 &&
|
||||
word[0] == char(0xe2) &&
|
||||
word[1] == char(0x96) &&
|
||||
word[2] == char(0x81)) {
|
||||
return std::string(" ") + word.substr(3);
|
||||
}
|
||||
return word;
|
||||
}
|
||||
|
||||
// original implementation:
|
||||
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
|
||||
struct llama_tokenizer {
|
||||
|
@ -3727,12 +3752,29 @@ float * llama_get_embeddings(struct llama_context * ctx) {
|
|||
|
||||
int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) {
|
||||
if (0 <= token && token < llama_n_vocab_from_model(model)) {
|
||||
std::string result = llama_unescape_whitespace(model->vocab.id_to_token[token].tok);
|
||||
if(result.length() > length) {
|
||||
return - result.length();
|
||||
}
|
||||
strcpy(str, result.c_str());
|
||||
return result.length();
|
||||
if (llama_is_normal_token(token)) {
|
||||
std::string result = llama_unescape_whitespace(model->vocab.id_to_token[token].tok);
|
||||
if(result.length() > length) {
|
||||
return - result.length();
|
||||
}
|
||||
strcpy(str, result.c_str());
|
||||
return result.length();
|
||||
} else if (llama_is_unknown_token(token)) {
|
||||
if(3 > length) {
|
||||
return -3;
|
||||
}
|
||||
strcpy(str, "\xe2\x96\x85");
|
||||
return 3;
|
||||
} else if (llama_is_control_token(token)) {
|
||||
;
|
||||
} else if (llama_is_byte_token(token)) {
|
||||
if(1 > length) {
|
||||
return -1;
|
||||
}
|
||||
str[0] = token - 3;
|
||||
str[1] = 0x00;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -12,15 +12,11 @@
|
|||
static std::string escape_whitespace(const std::string& text) {
|
||||
std::string result;
|
||||
bool escaping = false;
|
||||
result += char(0xe2);
|
||||
result += char(0x96);
|
||||
result += char(0x81);
|
||||
result += "\xe2\x96\x81";
|
||||
for (size_t offs = 0; offs < text.length(); ++offs) {
|
||||
if (text[offs] == ' ') {
|
||||
if (!escaping) {
|
||||
result += char(0xe2);
|
||||
result += char(0x96);
|
||||
result += char(0x81);
|
||||
result += "\xe2\x96\x81";
|
||||
escaping = true;
|
||||
}
|
||||
}
|
||||
|
@ -93,15 +89,15 @@ int main(int argc, char **argv) {
|
|||
if (n == 1) {
|
||||
if (i != tokens[0]) {
|
||||
std::string backward = llama_token_to_str(ctx, tokens[0]);
|
||||
fprintf(stderr, "%s : error: token %d is string %s but tokenize() returns token %d %s\n",
|
||||
fprintf(stderr, "%s : error: token %d is string %s but bpe returns token %d %s\n",
|
||||
__func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
|
||||
}
|
||||
} else {
|
||||
if (i <= 258) {
|
||||
fprintf(stderr, "%s : info: token %d is string %s and tokenize() returns tokens %s\n",
|
||||
fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n",
|
||||
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str());
|
||||
} else {
|
||||
fprintf(stderr, "%s : error: token %d is string %s but tokenize() returns tokens %s\n",
|
||||
fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n",
|
||||
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue