Moving byte decoding back to token_to_piece ...

... because everyone is using it.
goerch · 2023-09-19 13:24:04 +02:00
commit 1b7c3692af, parent c0990bb739
2 changed files with 15 additions and 12 deletions

common/common.cpp

@@ -847,16 +847,6 @@ std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
     return result;
 }
 
-std::string llama_decode_text(const std::string& text) {
-    std::string decoded_text;
-    auto unicode_sequences = codepoints_from_utf8(text);
-    for (auto& unicode_sequence : unicode_sequences) {
-        decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence));
-    }
-    return decoded_text;
-}
-
 std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_token> & tokens) {
     std::string piece;
     std::string result;
@@ -867,7 +857,8 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_token> & tokens) {
         result += piece;
     }
 
-    return llama_decode_text(result);
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return result;
 }
 
 //
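
The helper being moved here undoes byte-level BPE encoding: each codepoint in a BPE vocab piece stands for one raw byte, so decoding maps codepoints back to bytes. For orientation, below is a minimal sketch of the GPT-2 style byte ↔ codepoint table that a function like unicode_to_bytes_bpe inverts; the names bytes_to_unicode/unicode_to_bytes and the table construction are illustrative, not the actual llama.cpp internals.

#include <cstdint>
#include <map>

// Forward table (GPT-2 style): printable single-byte values map to
// themselves; the remaining bytes are shifted into unused codepoints
// starting at U+0100 so every byte has a printable stand-in.
static std::map<uint8_t, uint32_t> bytes_to_unicode() {
    std::map<uint8_t, uint32_t> table;
    uint32_t n = 0;
    for (int b = 0; b < 256; ++b) {
        const bool printable =
            (b >= '!' && b <= '~') || (b >= 0xA1 && b <= 0xAC) || (b >= 0xAE && b <= 0xFF);
        table[(uint8_t) b] = printable ? (uint32_t) b : 256 + n++;
    }
    return table;
}

// Decoding inverts the table: each codepoint of a vocab piece is
// mapped back to the raw byte it stands for.
static std::map<uint32_t, uint8_t> unicode_to_bytes() {
    std::map<uint32_t, uint8_t> inv;
    for (const auto & kv : bytes_to_unicode()) {
        inv[kv.second] = kv.first;
    }
    return inv;
}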

llama.cpp

@@ -7207,6 +7207,16 @@ int llama_token_to_piece(const struct llama_context * ctx, llama_token token, char * buf, int length) {
     return llama_token_to_piece_with_model(&ctx->model, token, buf, length);
 }
 
+static std::string llama_decode_text(const std::string& text) {
+    std::string decoded_text;
+    auto unicode_sequences = codepoints_from_utf8(text);
+    for (auto& unicode_sequence : unicode_sequences) {
+        decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence));
+    }
+    return decoded_text;
+}
+
 // does not write null-terminator to buf
 int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
     if (0 <= token && token < llama_model_n_vocab(model)) {
@@ -7214,6 +7224,8 @@ int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
         std::string result = model->vocab.id_to_token[token].text;
         if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
             llama_unescape_whitespace(result);
+        } else if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_BPE) {
+            result = llama_decode_text(result);
         }
         if (length < (int) result.length()) {
             return -result.length();
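
The SPM branch above restores real spaces: SentencePiece marks a space with U+2581 ("▁"). A minimal sketch of that unescaping, assuming only that the marker is U+2581 (the actual llama_unescape_whitespace implementation may differ in detail):

#include <string>

// Illustrative stand-in for SPM whitespace unescaping: swap the
// UTF-8 encoding of U+2581 (E2 96 81) back to an ASCII space.
static void unescape_whitespace_sketch(std::string & s) {
    static const std::string mark = "\xE2\x96\x81"; // UTF-8 for U+2581 "▁"
    for (size_t pos = 0; (pos = s.find(mark, pos)) != std::string::npos; ) {
        s.replace(pos, mark.size(), " ");
        pos += 1;
    }
}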
@@ -7239,7 +7251,7 @@ int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
             return 1;
         }
         else {
-            std::string result = model->vocab.id_to_token[token].text;
+            std::string result = llama_decode_text(model->vocab.id_to_token[token].text);
             if (length < (int)result.length()) {
                 return -result.length();
             }
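
The buffer convention visible in these hunks: llama_token_to_piece writes up to length bytes (no null terminator) and returns the byte count, or the negative of the required size when the buffer is too small. A usage sketch of the resize-and-retry pattern callers follow; the wrapper name piece_of is ours, not part of the API.

#include <string>
#include <vector>
#include "llama.h"

// Convert one token to its text piece, growing the buffer on demand.
static std::string piece_of(const llama_context * ctx, llama_token token) {
    std::vector<char> buf(8); // deliberately small to show the retry path
    int n = llama_token_to_piece(ctx, token, buf.data(), (int) buf.size());
    if (n < 0) {
        buf.resize(-n); // the negative return is the exact size needed
        n = llama_token_to_piece(ctx, token, buf.data(), (int) buf.size());
    }
    return std::string(buf.data(), n);
}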