diff --git a/llama.cpp b/llama.cpp index c620d9897..3a50090f8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -281,6 +281,15 @@ struct llama_vocab { llama_trie special_token_trie; std::unordered_map special_token_to_id; size_t max_special_token_length = 0; + + void add_special_token(const token & word, id token_id) { + special_token_trie.add(word); + special_token_to_id[word] = token_id; + + if (max_special_token_length < word.size()) { + max_special_token_length = word.size(); + } + } }; struct llama_model { @@ -624,15 +633,8 @@ struct llama_file_loader { for (uint32_t i = 0; i < vocab_sp; i++) { llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i; const auto & word = vocab.id_to_token[token_id].tok; - if (word.empty()) { - continue; - } - - vocab.special_token_trie.add(word); - vocab.special_token_to_id[word] = token_id; - - if (vocab.max_special_token_length < word.size()) { - vocab.max_special_token_length = word.size(); + if (!word.empty()) { + vocab.add_special_token(word, token_id); } } } @@ -4263,6 +4265,10 @@ llama_token llama_token_nl() { return 13; } +void llama_add_special_token(struct llama_model * model, const char * token, llama_token token_id) { + model->vocab.add_special_token(token, token_id); +} + struct llama_timings llama_get_timings(struct llama_context * ctx) { struct llama_timings result = { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, diff --git a/llama.h b/llama.h index fa1977f2d..519ee716d 100644 --- a/llama.h +++ b/llama.h @@ -373,6 +373,11 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); // end-of-sentence LLAMA_API llama_token llama_token_nl(); // next-line + LLAMA_API void llama_add_special_token( + struct llama_model * model, + const char * token, + llama_token token_id); + // Grammar // LLAMA_API struct llama_grammar * llama_grammar_init(