diff --git a/common/common.cpp b/common/common.cpp index 4e40a5250..0d91a6a35 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -747,7 +747,7 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t return std::string(result.data(), result.size()); } -std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens) { +std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) { const llama_token bos_id = llama_token_bos(ctx); std::string piece; @@ -767,3 +767,15 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token return result; } +std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_token> & tokens) { + std::string piece; + std::string result; + + for (size_t i = 0; i < tokens.size(); ++i) { + piece = llama_token_to_piece(ctx, tokens[i]); + + result += piece; + } + + return result; +} diff --git a/common/common.h b/common/common.h index 1c1acf989..97fda2be7 100644 --- a/common/common.h +++ b/common/common.h @@ -129,9 +129,18 @@ std::string llama_token_to_piece( const struct llama_context * ctx, llama_token token); +// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function +// that takes into account the tokenizer type and decides how to handle the leading space +// // detokenizes a vector of tokens into a string // should work similar to Python's `tokenizer.decode` // removes the leading space from the first non-BOS token -std::string llama_detokenize( +std::string llama_detokenize_spm( + llama_context * ctx, + const std::vector<llama_token> & tokens); + +// detokenizes a vector of tokens into a string +// should work similar to Python's `tokenizer.decode` +std::string llama_detokenize_bpe( llama_context * ctx, const std::vector<llama_token> & tokens); diff --git a/tests/test-tokenizer-0-falcon.cpp b/tests/test-tokenizer-0-falcon.cpp index 3063e2c64..836fb8ad2 100644 --- a/tests/test-tokenizer-0-falcon.cpp +++ b/tests/test-tokenizer-0-falcon.cpp @@ 
-82,10 +82,8 @@ int main(int argc, char **argv) { } } - const int n_vocab = llama_n_vocab(ctx); - - if (n_vocab != 65024) { - fprintf(stderr, "%s : expected 65024 tokens, got %d\n", __func__, n_vocab); + if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_BPE) { + fprintf(stderr, "%s : error: vocab type is not BPE\n", __func__); llama_free_model(model); llama_free(ctx); return 2; @@ -98,7 +96,7 @@ int main(int argc, char **argv) { printf("\n"); printf("src: '%s'\n", test_kv.first.c_str()); - printf("res: '%s'\n", llama_detokenize(ctx, res).c_str()); + printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str()); printf("tok: "); for (const auto & tok : res) { printf("%d ", tok); @@ -116,8 +114,8 @@ int main(int argc, char **argv) { if (!correct) { fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__, - llama_detokenize(ctx, res).c_str(), - llama_detokenize(ctx, test_kv.second).c_str()); + llama_detokenize_bpe(ctx, res).c_str(), + llama_detokenize_bpe(ctx, test_kv.second).c_str()); fprintf(stderr, "%s : expected tokens: ", __func__); for (const auto & t : test_kv.second) { fprintf(stderr, "%6d, ", t); diff --git a/tests/test-tokenizer-0-llama.cpp b/tests/test-tokenizer-0-llama.cpp index c28cd2753..8630742c6 100644 --- a/tests/test-tokenizer-0-llama.cpp +++ b/tests/test-tokenizer-0-llama.cpp @@ -82,10 +82,8 @@ int main(int argc, char **argv) { } } - const int n_vocab = llama_n_vocab(ctx); - - if (n_vocab != 32000) { - fprintf(stderr, "%s : expected 32000 tokens, got %d\n", __func__, n_vocab); + if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_SPM) { + fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__); llama_free_model(model); llama_free(ctx); return 2; @@ -99,7 +97,7 @@ int main(int argc, char **argv) { printf("\n"); printf("src: '%s'\n", test_kv.first.c_str()); - printf("res: '%s'\n", llama_detokenize(ctx, res_bos).c_str()); + printf("res: '%s'\n", 
llama_detokenize_spm(ctx, res_bos).c_str()); printf("tok: "); for (const auto & tok : res_bos) { printf("%d ", tok); @@ -120,8 +118,8 @@ int main(int argc, char **argv) { if (!correct) { fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__, - llama_detokenize(ctx, res_nobos).c_str(), - llama_detokenize(ctx, test_kv.second).c_str()); + llama_detokenize_spm(ctx, res_nobos).c_str(), + llama_detokenize_spm(ctx, test_kv.second).c_str()); fprintf(stderr, "%s : expected tokens: ", __func__); for (const auto & t : test_kv.second) { fprintf(stderr, "%6d, ", t);