diff --git a/llama.cpp b/llama.cpp index 82b7638ae..03c73ee7b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2235,15 +2235,32 @@ static void llm_load_vocab( if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); } else { - vocab.linefeed_id = llama_tokenize_internal(vocab, "\u010A", false)[0]; + const std::vector ids = llama_tokenize_internal(vocab, "\u010A", false); + GGML_ASSERT(ids.size() == 1 && "model vocab missing newline token"); + vocab.linefeed_id = ids[0]; } // special tokens - GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID)); - GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID)); - GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID)); - GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID)); - GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID)); + { + const std::vector> special_token_types = { + { LLM_KV_TOKENIZER_BOS_ID, &vocab.special_bos_id }, + { LLM_KV_TOKENIZER_EOS_ID, &vocab.special_eos_id }, + { LLM_KV_TOKENIZER_UNK_ID, &vocab.special_unk_id }, + { LLM_KV_TOKENIZER_SEP_ID, &vocab.special_sep_id }, + { LLM_KV_TOKENIZER_PAD_ID, &vocab.special_pad_id }, + }; + for (auto & it : special_token_types ) { + int32_t id = -1; + const std::string kstr = kv(std::get<0>(it)); + + GGUF_GET_KEY(ctx, id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kstr); + if (id != -1 && (id < 0 || size_t(id) >= vocab.id_to_token.size())) { + LLAMA_LOG_WARN("%s: bad special token value %d for key '%s' -- ignoring\n", __func__, id, kstr.c_str()); + continue; + } + *(std::get<1>(it)) = id; + } + } // build special tokens cache { @@ -6084,11 +6101,10 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { } static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { + const char * hex = "0123456789ABCDEF"; switch (llama_vocab_get_type(vocab)) { case LLAMA_VOCAB_TYPE_SPM: { - char buf[7]; - int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); - GGML_ASSERT(0 <= result && result < 7); + const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; return vocab.token_to_id.at(buf); } case LLAMA_VOCAB_TYPE_BPE: {