diff --git a/llama.cpp b/llama.cpp
index b73281d03..af005c12c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -12200,12 +12200,14 @@ struct llm_tokenizer_bpe {
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         int final_prev_index = -1;
+        bool ignore_merges = false;
 
         std::vector<std::string> word_collection;
         switch (vocab.type) {
             case LLAMA_VOCAB_TYPE_BPE:
                 switch (vocab.type_pre) {
                     case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
+                        ignore_merges = true;
                     case LLAMA_VOCAB_PRE_TYPE_DBRX:
                         word_collection = unicode_regex_split(text, {
                             // original regex from tokenizer.json
@@ -12292,7 +12294,7 @@ struct llm_tokenizer_bpe {
         symbols_final.clear();
 
         for (auto & word : word_collection) {
-            if (vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
                 llm_symbol sym;
                 sym.text = word.c_str();
                 sym.n = word.size();
diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp
index 015d6a2e4..63e261d8d 100644
--- a/tests/test-tokenizer-1-bpe.cpp
+++ b/tests/test-tokenizer-1-bpe.cpp
@@ -13,15 +13,27 @@
 #include <vector>
 
 int main(int argc, char **argv) {
-    if (argc < 2) {
-        fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
+    if (argc < 2 || argc > 3) {
+        fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]);
         return 1;
     }
 
     const std::string fname = argv[1];
+    bool ignore_merges = false;
+    if (argc == 3) {
+        if (std::strcmp(argv[2], "--ignore-merges") != 0) {
+            fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]);
+            return 1;
+        }
+        ignore_merges = true;
+    }
 
     fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
 
+    if (ignore_merges) {
+        fprintf(stderr, "%s : ignoring merges for tokens inside vocab\n", __func__);
+    }
+
     llama_model * model;
     llama_context * ctx;
 
@@ -66,7 +78,7 @@ int main(int argc, char **argv) {
         try {
             auto cps = unicode_cpts_from_utf8(str);
             std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
-            if (tokens.size() > 1) {
+            if (ignore_merges && tokens.size() > 1) {
                 fprintf(stderr,
                     "%s : error: token %d detokenizes to '%s'(%zu) but "
                     "tokenization of this to multiple tokens: [",
@@ -76,6 +88,7 @@
                     fprintf(stderr, ", %d", tokens[i]);
                 }
                 fprintf(stderr, "]\n");
+                return 2;
             }
             std::string check = llama_detokenize_bpe(ctx, tokens);
             if (check != str) {
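
For context, a minimal standalone sketch of the fast path the first two hunks introduce: when ignore_merges is set (the LLAMA3 pre-tokenizer), a pre-split word that already exists in the vocabulary is emitted as a single token and the BPE merge loop is bypassed. Everything below (type aliases, the toy vocab, the bpe_merge_word stub) is illustrative, not llama.cpp's API.

    // Standalone illustration; names and the merge stub are hypothetical,
    // not the llama.cpp implementation.
    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    using token_id = int;

    // Stand-in for the iterative BPE pair-merge loop that normally splits a word.
    static std::vector<token_id> bpe_merge_word(const std::string & word) {
        (void) word;
        return {};
    }

    static std::vector<token_id> tokenize_word(
            const std::string & word,
            const std::unordered_map<std::string, token_id> & token_to_id,
            bool ignore_merges) {
        // The fast path from the patch: with ignore_merges set, a word that is
        // already a vocab entry becomes exactly one token and merges are skipped.
        if (ignore_merges) {
            const auto it = token_to_id.find(word);
            if (it != token_to_id.end()) {
                return { it->second };
            }
        }
        return bpe_merge_word(word);
    }

    int main() {
        const std::unordered_map<std::string, token_id> vocab = { { "hello", 42 } };
        const auto ids = tokenize_word("hello", vocab, /*ignore_merges=*/true);
        std::printf("n_tokens = %zu, id = %d\n", ids.size(), ids.empty() ? -1 : ids[0]);
        return 0;
    }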
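For reference, with the test changes applied the checker might be invoked along these lines (the vocab path is illustrative, not prescribed by the patch): ./test-tokenizer-1-bpe ggml-vocab-llama-3.gguf --ignore-merges. Without the flag the behavior is unchanged; with it, any vocab token that re-tokenizes into multiple tokens is reported and the test exits with status 2, since a single-token round trip is only guaranteed when merges are ignored for whole-word vocab hits.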