Fixing llama_token_to_str for the different sentence_piece token types
This commit is contained in:
parent
b97a505c5d
commit
81fae1dc8f
3 changed files with 87 additions and 48 deletions
|
@ -1,4 +1,5 @@
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
#include "common.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
108
llama.cpp
108
llama.cpp
|
@ -1765,6 +1765,64 @@ static bool llama_eval_internal(
|
||||||
// tokenizer
|
// tokenizer
|
||||||
//
|
//
|
||||||
|
|
||||||
|
bool llama_is_normal_token(llama_token token) {
|
||||||
|
return token >= 259;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_is_unknown_token(llama_token token) {
|
||||||
|
return token == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_is_control_token(llama_token token) {
|
||||||
|
return token == 1 || token == 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_is_bos_token(llama_token token) {
|
||||||
|
return token == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_is_eos_token(llama_token token) {
|
||||||
|
return token == 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_is_user_defined_token(llama_token token) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_is_unused_token(llama_token token) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_is_byte_token(llama_token token) {
|
||||||
|
return 3 <= token && token < 259;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string llama_escape_whitespace(const std::string& text) {
|
||||||
|
std::string result;
|
||||||
|
bool escaping = false;
|
||||||
|
result += "\xe2\x96\x81";
|
||||||
|
for (size_t offs = 0; offs < text.length(); ++offs) {
|
||||||
|
if (text[offs] == ' ') {
|
||||||
|
if (!escaping) {
|
||||||
|
result += "\xe2\x96\x81";
|
||||||
|
escaping = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
escaping = false;
|
||||||
|
result += text[offs];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string llama_unescape_whitespace(const std::string& word) {
|
||||||
|
if (word.length() >= 3 && word.substr(0, 3) == "\xe2\x96\x81") {
|
||||||
|
return std::string(" ") + word.substr(3);
|
||||||
|
}
|
||||||
|
return word;
|
||||||
|
}
|
||||||
|
|
||||||
static size_t utf8_len(char src) {
|
static size_t utf8_len(char src) {
|
||||||
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||||
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||||
|
@ -1795,39 +1853,6 @@ struct llama_sp_bigram {
|
||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
static std::string llama_escape_whitespace(const std::string& text) {
|
|
||||||
std::string result;
|
|
||||||
bool escaping = false;
|
|
||||||
result += char(0xe2);
|
|
||||||
result += char(0x96);
|
|
||||||
result += char(0x81);
|
|
||||||
for (size_t offs = 0; offs < text.length(); ++offs) {
|
|
||||||
if (text[offs] == ' ') {
|
|
||||||
if (!escaping) {
|
|
||||||
result += char(0xe2);
|
|
||||||
result += char(0x96);
|
|
||||||
result += char(0x81);
|
|
||||||
escaping = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
escaping = false;
|
|
||||||
result += text[offs];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string llama_unescape_whitespace(const std::string& word) {
|
|
||||||
if (word.length() >= 3 &&
|
|
||||||
word[0] == char(0xe2) &&
|
|
||||||
word[1] == char(0x96) &&
|
|
||||||
word[2] == char(0x81)) {
|
|
||||||
return std::string(" ") + word.substr(3);
|
|
||||||
}
|
|
||||||
return word;
|
|
||||||
}
|
|
||||||
|
|
||||||
// original implementation:
|
// original implementation:
|
||||||
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
|
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
|
||||||
struct llama_tokenizer {
|
struct llama_tokenizer {
|
||||||
|
@ -3727,12 +3752,29 @@ float * llama_get_embeddings(struct llama_context * ctx) {
|
||||||
|
|
||||||
int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) {
|
int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * str, int length) {
|
||||||
if (0 <= token && token < llama_n_vocab_from_model(model)) {
|
if (0 <= token && token < llama_n_vocab_from_model(model)) {
|
||||||
|
if (llama_is_normal_token(token)) {
|
||||||
std::string result = llama_unescape_whitespace(model->vocab.id_to_token[token].tok);
|
std::string result = llama_unescape_whitespace(model->vocab.id_to_token[token].tok);
|
||||||
if(result.length() > length) {
|
if(result.length() > length) {
|
||||||
return - result.length();
|
return - result.length();
|
||||||
}
|
}
|
||||||
strcpy(str, result.c_str());
|
strcpy(str, result.c_str());
|
||||||
return result.length();
|
return result.length();
|
||||||
|
} else if (llama_is_unknown_token(token)) {
|
||||||
|
if(3 > length) {
|
||||||
|
return -3;
|
||||||
|
}
|
||||||
|
strcpy(str, "\xe2\x96\x85");
|
||||||
|
return 3;
|
||||||
|
} else if (llama_is_control_token(token)) {
|
||||||
|
;
|
||||||
|
} else if (llama_is_byte_token(token)) {
|
||||||
|
if(1 > length) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
str[0] = token - 3;
|
||||||
|
str[1] = 0x00;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,15 +12,11 @@
|
||||||
static std::string escape_whitespace(const std::string& text) {
|
static std::string escape_whitespace(const std::string& text) {
|
||||||
std::string result;
|
std::string result;
|
||||||
bool escaping = false;
|
bool escaping = false;
|
||||||
result += char(0xe2);
|
result += "\xe2\x96\x81";
|
||||||
result += char(0x96);
|
|
||||||
result += char(0x81);
|
|
||||||
for (size_t offs = 0; offs < text.length(); ++offs) {
|
for (size_t offs = 0; offs < text.length(); ++offs) {
|
||||||
if (text[offs] == ' ') {
|
if (text[offs] == ' ') {
|
||||||
if (!escaping) {
|
if (!escaping) {
|
||||||
result += char(0xe2);
|
result += "\xe2\x96\x81";
|
||||||
result += char(0x96);
|
|
||||||
result += char(0x81);
|
|
||||||
escaping = true;
|
escaping = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -93,15 +89,15 @@ int main(int argc, char **argv) {
|
||||||
if (n == 1) {
|
if (n == 1) {
|
||||||
if (i != tokens[0]) {
|
if (i != tokens[0]) {
|
||||||
std::string backward = llama_token_to_str(ctx, tokens[0]);
|
std::string backward = llama_token_to_str(ctx, tokens[0]);
|
||||||
fprintf(stderr, "%s : error: token %d is string %s but tokenize() returns token %d %s\n",
|
fprintf(stderr, "%s : error: token %d is string %s but bpe returns token %d %s\n",
|
||||||
__func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
|
__func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (i <= 258) {
|
if (i <= 258) {
|
||||||
fprintf(stderr, "%s : info: token %d is string %s and tokenize() returns tokens %s\n",
|
fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n",
|
||||||
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str());
|
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str());
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "%s : error: token %d is string %s but tokenize() returns tokens %s\n",
|
fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n",
|
||||||
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str());
|
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens.data(), n).c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue