Moving byte decoding back to token_to_piece ...

... because everyone is using it.
This commit is contained in:
goerch 2023-09-19 13:24:04 +02:00
parent c0990bb739
commit 1b7c3692af
2 changed files with 15 additions and 12 deletions

View file

@ -847,16 +847,6 @@ std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_to
return result;
}
std::string llama_decode_text(const std::string& text) {
std::string decoded_text;
auto unicode_sequences = codepoints_from_utf8(text);
for (auto& unicode_sequence : unicode_sequences) {
decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence));
}
return decoded_text;
}
std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_token> & tokens) {
std::string piece;
std::string result;
@ -867,7 +857,8 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
result += piece;
}
return llama_decode_text(result);
// NOTE: the original tokenizer decodes bytes after collecting the pieces.
return result;
}
//

View file

@ -7207,6 +7207,16 @@ int llama_token_to_piece(const struct llama_context * ctx, llama_token token, ch
return llama_token_to_piece_with_model(&ctx->model, token, buf, length);
}
static std::string llama_decode_text(const std::string& text) {
std::string decoded_text;
auto unicode_sequences = codepoints_from_utf8(text);
for (auto& unicode_sequence : unicode_sequences) {
decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence));
}
return decoded_text;
}
// does not write null-terminator to buf
int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
if (0 <= token && token < llama_model_n_vocab(model)) {
@ -7214,6 +7224,8 @@ int llama_token_to_piece_with_model(const struct llama_model * model, llama_toke
std::string result = model->vocab.id_to_token[token].text;
if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
llama_unescape_whitespace(result);
} else if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_BPE) {
result = llama_decode_text(result);
}
if (length < (int) result.length()) {
return -result.length();
@ -7239,7 +7251,7 @@ int llama_token_to_piece_with_model(const struct llama_model * model, llama_toke
return 1;
}
else {
std::string result = model->vocab.id_to_token[token].text;
std::string result = llama_decode_text(model->vocab.id_to_token[token].text);
if (length < (int)result.length()) {
return -result.length();
}