diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index f079fcd42..c7d939a40 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -215,50 +215,50 @@ class Model(ABC): except KeyError: raise NotImplementedError(f'Architecture {arch!r} not supported!') from None - @staticmethod - def from_model_architecture(model_architecture): - if model_architecture == "GPTNeoXForCausalLM": - return GPTNeoXModel - if model_architecture == "BloomForCausalLM": - return BloomModel - if model_architecture == "MPTForCausalLM": - return MPTModel - if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"): - return BaichuanModel - if model_architecture in ("FalconForCausalLM", "RWForCausalLM"): - return FalconModel - if model_architecture == "GPTBigCodeForCausalLM": - return StarCoderModel - if model_architecture == "GPTRefactForCausalLM": - return RefactModel - if model_architecture == "PersimmonForCausalLM": - return PersimmonModel - if model_architecture == "LlamaForCausalLM": - return DeepseekCoderModel - if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"): - return StableLMModel - if model_architecture == "QWenLMHeadModel": - return QwenModel - if model_architecture == "Qwen2ForCausalLM": - return Model - if model_architecture == "MixtralForCausalLM": - return MixtralModel - if model_architecture == "GPT2LMHeadModel": - return GPT2Model - if model_architecture == "PhiForCausalLM": - return Phi2Model - if model_architecture == "PlamoForCausalLM": - return PlamoModel - if model_architecture == "CodeShellForCausalLM": - return CodeShellModel - if model_architecture == "OrionForCausalLM": - return OrionModel - if model_architecture == "InternLM2ForCausalLM": - return InternLM2Model - if model_architecture == "MiniCPMForCausalLM": - return MiniCPMModel - if model_architecture == "BertModel": - return BertModel + # @staticmethod + # def from_model_architecture(model_architecture): + # if model_architecture == "GPTNeoXForCausalLM": + # return GPTNeoXModel + # if model_architecture == "BloomForCausalLM": + # return BloomModel + # if model_architecture == "MPTForCausalLM": + # return MPTModel + # if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"): + # return BaichuanModel + # if model_architecture in ("FalconForCausalLM", "RWForCausalLM"): + # return FalconModel + # if model_architecture == "GPTBigCodeForCausalLM": + # return StarCoderModel + # if model_architecture == "GPTRefactForCausalLM": + # return RefactModel + # if model_architecture == "PersimmonForCausalLM": + # return PersimmonModel + # if model_architecture == "LlamaForCausalLM": + # return LlamaModel + # if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"): + # return StableLMModel + # if model_architecture == "QWenLMHeadModel": + # return QwenModel + # if model_architecture == "Qwen2ForCausalLM": + # return Model + # if model_architecture == "MixtralForCausalLM": + # return MixtralModel + # if model_architecture == "GPT2LMHeadModel": + # return GPT2Model + # if model_architecture == "PhiForCausalLM": + # return Phi2Model + # if model_architecture == "PlamoForCausalLM": + # return PlamoModel + # if model_architecture == "CodeShellForCausalLM": + # return CodeShellModel + # if model_architecture == "OrionForCausalLM": + # return OrionModel + # if model_architecture == "InternLM2ForCausalLM": + # return InternLM2Model + # if model_architecture == "MiniCPMForCausalLM": + # return MiniCPMModel + # if model_architecture == "BertModel": + # return BertModel @staticmethod def from_model_name(model_name: str): @@ -281,10 +281,8 @@ class Model(ABC): return RefactModel if model_name_lower == "persimmon": return PersimmonModel - if model_name_lower == "deepseekcoder": - return DeepseekCoderModel - if model_name_lower == "deepseekllm": - return DeepseekLLMModel + if model_name_lower in ("llama", "deepseekcoder", "deepseekllm"): + return LlamaModel return Model def _is_model_safetensors(self) -> bool: @@ -376,7 +374,6 @@ class Model(ABC): return tokens, toktypes - def _set_vocab_gpt2(self, tokenizer_model:str = "gpt2") -> None: tokens, toktypes = self.get_basic_vocab() self.gguf_writer.add_tokenizer_model(tokenizer_model) @@ -1312,31 +1309,7 @@ class PersimmonModel(Model): n_dims = len(data.shape) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") self.gguf_writer.add_tensor(new_name, data) -@Model.register("LlamaForCausalLM") -class DeepseekCoderModel(Model): - model_arch = gguf.MODEL_ARCH.LLAMA - def set_gguf_parameters(self): - super().set_gguf_parameters() - head_count = self.hparams["num_attention_heads"] - head_count_kv = self.hparams.get("num_key_value_heads", head_count) - self.gguf_writer.add_head_count(head_count) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) - self.gguf_writer.add_head_count_kv(head_count_kv) - self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) - self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) - - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "linear": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) - - def set_vocab(self): - self._set_vocab_gpt2("deepseek_coder") - -class DeepseekLLMModel(DeepseekCoderModel): - def set_vocab(self): - self._set_vocab_gpt2("deepseek_llm") @Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM") class StableLMModel(Model): @@ -1479,6 +1452,11 @@ class LlamaModel(Model): self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + # Same as super class, but permuting q_proj, k_proj def write_tensors(self): block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) diff --git a/llama.cpp b/llama.cpp index e3a32cf7e..09d8a0dd8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2114,6 +2114,7 @@ struct llama_vocab { ttype type; }; + enum llm_arch arch = LLM_ARCH_UNKNOWN; enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; std::unordered_map token_to_id; @@ -4243,10 +4244,6 @@ static void llm_load_vocab( } else { if (tokenizer_name == "gpt2") { vocab.type = LLAMA_VOCAB_TYPE_BPE; - } else if (tokenizer_name == "deepseek_coder") { - vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKCODER; - } else if (tokenizer_name == "deepseek_llm") { - vocab.type = LLAMA_VOCAB_TYPE_DEEPSEEKLLM; } else { LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); @@ -4287,6 +4284,8 @@ static void llm_load_vocab( vocab.special_cls_id = -1; vocab.special_mask_id = -1; } + + vocab.arch = model.arch; } const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); @@ -11784,10 +11783,9 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { auto buf = token_data.text.substr(3, 2); return strtol(buf.c_str(), NULL, 16); } - case LLAMA_VOCAB_TYPE_DEEPSEEKCODER: case LLAMA_VOCAB_TYPE_BPE: { GGML_ASSERT(false); - return unicode_utf8_to_byte(token_data.text); + return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT? } case LLAMA_VOCAB_TYPE_WPM: { GGML_ASSERT(false); @@ -11812,7 +11810,6 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { return vocab.token_to_id.at(buf2); } case LLAMA_VOCAB_TYPE_WPM: - case LLAMA_VOCAB_TYPE_DEEPSEEKCODER: case LLAMA_VOCAB_TYPE_BPE: { return vocab.token_to_id.at(unicode_byte_to_utf8(ch)); } @@ -12014,33 +12011,43 @@ struct llm_tokenizer_bpe { std::vector word_collection; switch (vocab.type) { case LLAMA_VOCAB_TYPE_BPE: - word_collection = unicode_regex_split(text, { - "\\p{P}+", - "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", - "\\p{N}+", - "[0-9][0-9][0-9]" - }); - break; - case LLAMA_VOCAB_TYPE_DEEPSEEKCODER: - word_collection = unicode_regex_split(text, { - "[\r\n]", - "\\s?\\p{L}+", - "\\s?\\p{P}+", - "[一-龥ࠀ-一가-퟿]+", - "\\p{N}+" - }); - break; - case LLAMA_VOCAB_TYPE_DEEPSEEKLLM: - word_collection = unicode_regex_split(text, { - "[\r\n]", - "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", - "\\s?[!-/:-~!-/:-~‘-‟ -。]+", - "\\s+$", - "[一-龥ࠀ-一가-퟿]+", - "\\p{N}+" - }); + switch (vocab.arch) { + // TODO: how to detect deepseek and llama v3 models? + //case LLM_ARCH_LLAMA: + //case LLM_ARCH_DEEPSEEK_CODER: + // word_collection = unicode_regex_split(text, { + // "[\r\n]", + // "\\s?\\p{L}+", + // "\\s?\\p{P}+", + // "[一-龥ࠀ-一가-퟿]+", + // "\\p{N}+" + // }); + // break; + //case LLM_ARCH_DEEPSEEK_LLM: + // word_collection = unicode_regex_split(text, { + // "[\r\n]", + // "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", + // "\\s?[!-/:-~!-/:-~‘-‟ -。]+", + // "\\s+$", + // "[一-龥ࠀ-一가-퟿]+", + // "\\p{N}+" + // }); + // break; + default: + // default regex for BPE tokenization pre-processing + { + word_collection = unicode_regex_split(text, { + "\\p{P}+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "\\p{N}+", + "[0-9][0-9][0-9]" + }); + } + break; + } break; default: + GGML_ASSERT(false); break; } @@ -12486,8 +12493,6 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(vocab.special_eos_id); } } break; - case LLAMA_VOCAB_TYPE_DEEPSEEKCODER: - case LLAMA_VOCAB_TYPE_DEEPSEEKLLM: case LLAMA_VOCAB_TYPE_BPE: { if (add_special && vocab.special_add_bos == 1) { @@ -17188,8 +17193,6 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token } break; } - case LLAMA_VOCAB_TYPE_DEEPSEEKCODER: - case LLAMA_VOCAB_TYPE_DEEPSEEKLLM: case LLAMA_VOCAB_TYPE_BPE: { // NOTE: we accept all unsupported token types, // suppressing them like CONTROL tokens. diff --git a/llama.h b/llama.h index 3062dbd2b..8aa763672 100644 --- a/llama.h +++ b/llama.h @@ -67,8 +67,6 @@ extern "C" { LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece - LLAMA_VOCAB_TYPE_DEEPSEEKCODER = 4, // Deepseek Coder - LLAMA_VOCAB_TYPE_DEEPSEEKLLM = 5, // Deepseek LLM }; // note: these values should be synchronized with ggml_rope diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 5f1bb729f..4f0889007 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -41,12 +41,13 @@ llama_test(test-quantize-perf.cpp) llama_test(test-sampling.cpp) llama_test(test-chat-template.cpp) +# TODO: tmp disabled LLaMA v3 and Deepseek tests +llama_test(test-tokenizer-0-llama.cpp NAME test-tokenizer-0-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) +#llama_test(test-tokenizer-0-llama-v3.cpp NAME test-tokenizer-0-llama-v3 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-v3.gguf) +llama_test(test-tokenizer-0-falcon.cpp NAME test-tokenizer-0-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf) -llama_test(test-tokenizer-0-llama.cpp NAME test-tokenizer-0-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) -llama_test(test-tokenizer-0-falcon.cpp NAME test-tokenizer-0-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf) - -llama_test(test-tokenizer-0-deepseek-coder.cpp NAME test-tokenizer-0-deepseek-coder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf) -llama_test(test-tokenizer-0-deepseek-llm.cpp NAME test-tokenizer-0-deepseek-llm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-llm.gguf) +#llama_test(test-tokenizer-0-deepseek-coder.cpp NAME test-tokenizer-0-deepseek-coder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf) +#llama_test(test-tokenizer-0-deepseek-llm.cpp NAME test-tokenizer-0-deepseek-llm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-llm.gguf) llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf) diff --git a/tests/test-tokenizer-0-deepseek-coder.cpp b/tests/test-tokenizer-0-deepseek-coder.cpp index 1be6b7ab7..a3dc0047d 100644 --- a/tests/test-tokenizer-0-deepseek-coder.cpp +++ b/tests/test-tokenizer-0-deepseek-coder.cpp @@ -89,8 +89,8 @@ int main(int argc, char **argv) { } } - if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_DEEPSEEKCODER) { - fprintf(stderr, "%s : error: vocab type is not DEEPSEEKCODER\n", __func__); + if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) { + fprintf(stderr, "%s : error: vocab type is not BPE\n", __func__); llama_free_model(model); llama_free(ctx); return 2; diff --git a/tests/test-tokenizer-0-deepseek-llm.cpp b/tests/test-tokenizer-0-deepseek-llm.cpp index 8afc0a81f..c621e02d9 100644 --- a/tests/test-tokenizer-0-deepseek-llm.cpp +++ b/tests/test-tokenizer-0-deepseek-llm.cpp @@ -89,8 +89,8 @@ int main(int argc, char **argv) { } } - if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_DEEPSEEKLLM) { - fprintf(stderr, "%s : error: vocab type is not DEEPSEEKLLM\n", __func__); + if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) { + fprintf(stderr, "%s : error: vocab type is not BPE\n", __func__); llama_free_model(model); llama_free(ctx); return 2; diff --git a/tests/test-tokenizer-0-llama-v3.cpp b/tests/test-tokenizer-0-llama-v3.cpp new file mode 100644 index 000000000..a0ecf6283 --- /dev/null +++ b/tests/test-tokenizer-0-llama-v3.cpp @@ -0,0 +1,197 @@ +#include "llama.h" +#include "common.h" +#include "console.h" + +#include +#include +#include +#include +#include + +// generate using test-tokenizer-0-llama.py +static const std::map> & k_tests() { + static std::map> _k_tests = { + { "" , { }, }, + { " " , { 220, }, }, + { " " , { 256, }, }, + { " " , { 262, }, }, + { "\t" , { 197, }, }, + { "\n" , { 198, }, }, + { "\t\n" , { 1602, }, }, + { "Hello world" , { 9906, 1917, }, }, + { " Hello world" , { 22691, 1917, }, }, + { "Hello World" , { 9906, 4435, }, }, + { " Hello World" , { 22691, 4435, }, }, + { " Hello World!" , { 22691, 4435, 0, }, }, + { "Hello, world!" , { 9906, 11, 1917, 0, }, }, + { " Hello, world!" , { 22691, 11, 1917, 0, }, }, + { " this is 🦙.cpp" , { 420, 374, 11410, 99, 247, 13, 11055, }, }, + { "w048 7tuijk dsdfhu" , { 86, 23904, 220, 22, 83, 2005, 42908, 11729, 3013, 17156, }, }, + { "нещо на Български" , { 79862, 102118, 13373, 64571, 34694, 3114, 112203, 80112, }, }, + { "កាន់តែពិសេសអាចខលចេញ" , { 21549, 222, 98629, 241, 45358, 233, 21549, 237, 45358, 224, 21549, 244, 21549, 115, 21549, 253, 45358, 223, 21549, 253, 21549, 95, 98629, 227, 21549, 223, 21549, 249, 21549, 227, 45358, 223, 21549, 231, }, }, + { "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", { 9468, 248, 222, 320, 8416, 8, 27623, 114, 102470, 9468, 234, 104, 31643, 320, 36773, 100166, 98634, 8, 26602, 227, 320, 3323, 43465, 430, 706, 1202, 1866, 4037, 8, }, }, + { "Hello" , { 9906, }, }, + { " Hello" , { 22691, }, }, + { " Hello" , { 220, 22691, }, }, + { " Hello" , { 256, 22691, }, }, + { " Hello" , { 262, 22691, }, }, + { " Hello\n Hello" , { 262, 22691, 198, 262, 22691, }, }, + { " (" , { 320, }, }, + { "\n =" , { 198, 284, }, }, + { "' era" , { 6, 11639, }, }, + { "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~", { 9906, 11, 379, 65948, 0, 2650, 527, 499, 27623, 223, 949, 37046, 101067, 19000, 23182, 102301, 9263, 18136, 16, 36827, 21909, }, }, + { "3" , { 18, }, }, + { "33" , { 1644, }, }, + { "333" , { 8765, }, }, + { "3333" , { 8765, 18, }, }, + { "33333" , { 8765, 1644, }, }, + { "333333" , { 8765, 8765, }, }, + { "3333333" , { 8765, 8765, 18, }, }, + { "33333333" , { 8765, 8765, 1644, }, }, + { "333333333" , { 8765, 8765, 8765, }, }, + }; + + return _k_tests; +} + +int main(int argc, char **argv) { + if (argc < 2) { + fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]); + return 1; + } + + const std::string fname = argv[1]; + + std::string fname_text; + if (argc > 2) { + fname_text = argv[2]; + } + + fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str()); + + llama_model * model; + llama_context * ctx; + + llama_backend_init(); + + // load the vocab + { + auto mparams = llama_model_default_params(); + + mparams.vocab_only = true; + + model = llama_load_model_from_file(fname.c_str(), mparams); + + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); + return 1; + } + + auto cparams = llama_context_default_params(); + + ctx = llama_new_context_with_model(model, cparams); + + if (ctx == NULL) { + fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); + llama_free_model(model); + return 1; + } + } + + if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) { + fprintf(stderr, "%s : error: vocab type is not BPE\n", __func__); + llama_free_model(model); + llama_free(ctx); + return 2; + } + +#ifdef _WIN32 + // We need this for unicode console support + console::init(false, false); + atexit([]() { console::cleanup(); }); +#endif + + bool success = true; + + for (const auto & test_kv : k_tests()) { + const std::vector res = llama_tokenize(ctx, test_kv.first, false); + + printf("\n"); + printf("src: '%s'\n", test_kv.first.c_str()); + printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str()); + printf("tok: "); + for (const auto & tok : res) { + printf("%d ", tok); + } + printf("\n"); + + bool correct = res.size() == test_kv.second.size(); + for (int i = 0; i < (int) res.size() && correct; ++i) { + if (test_kv.second[i] != res[i]) { + correct = false; + } + } + + if (!correct) { + fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); + fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__, + llama_detokenize_bpe(ctx, res).c_str(), + llama_detokenize_bpe(ctx, test_kv.second).c_str()); + fprintf(stderr, "%s : expected tokens: ", __func__); + for (const auto & t : test_kv.second) { + fprintf(stderr, "%6d, ", t); + } + fprintf(stderr, "\n"); + fprintf(stderr, "%s : got tokens: ", __func__); + for (const auto & t : res) { + fprintf(stderr, "%6d, ", t); + } + fprintf(stderr, "\n"); + + success = false; + } + } + + if (!fname_text.empty()) { + fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str()); + + std::string text; + { + std::ifstream ifs(fname_text); + if (!ifs) { + fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str()); + return 1; + } + text = std::string(std::istreambuf_iterator(ifs), std::istreambuf_iterator()); + } + + fprintf(stderr, "%s : text size: %zu\n", __func__, text.size()); + + const std::vector res = llama_tokenize(ctx, text, false); + + fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size()); + + { + const std::string fname_out = fname_text + ".tokcpp"; + + std::ofstream ofs(fname_out); + if (!ofs) { + fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str()); + return 1; + } + + for (const auto & tok : res) { + ofs << tok << " '" << llama_detokenize_bpe(ctx, std::vector{tok}) << "'" << std::endl; + } + } + + fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str()); + } + + llama_free_model(model); + llama_free(ctx); + + llama_backend_free(); + + return success ? 0 : 3; +} diff --git a/tests/test-tokenizer-0-llama.cpp b/tests/test-tokenizer-0-llama.cpp index e6b13ab91..fd407041b 100644 --- a/tests/test-tokenizer-0-llama.cpp +++ b/tests/test-tokenizer-0-llama.cpp @@ -126,7 +126,7 @@ int main(int argc, char **argv) { } printf("\n"); - bool correct = res_nobos.size() == test_kv.second.size() && res_bos.size() == res_nobos.size() + 1 && res_bos[0] == 1; + bool correct = res_nobos.size() == test_kv.second.size() && res_bos.size() == res_nobos.size() + 1 && res_bos[0] == llama_token_bos(model); for (int i = 0; i < (int) res_nobos.size() && correct; ++i) { if (test_kv.second[i] != res_bos[i + 1]) {