Add C API for adding special tokens

This commit is contained in:
Igor Pissolati 2023-08-07 17:30:12 -03:00
parent 099119f532
commit d9791bb48b
2 changed files with 20 additions and 9 deletions

View file

@ -281,6 +281,15 @@ struct llama_vocab {
llama_trie special_token_trie;
std::unordered_map<token, id> special_token_to_id;
size_t max_special_token_length = 0;
void add_special_token(const token & word, id token_id) {
special_token_trie.add(word);
special_token_to_id[word] = token_id;
if (max_special_token_length < word.size()) {
max_special_token_length = word.size();
}
}
};
struct llama_model {
@ -624,15 +633,8 @@ struct llama_file_loader {
for (uint32_t i = 0; i < vocab_sp; i++) {
llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
const auto & word = vocab.id_to_token[token_id].tok;
if (word.empty()) {
continue;
}
vocab.special_token_trie.add(word);
vocab.special_token_to_id[word] = token_id;
if (vocab.max_special_token_length < word.size()) {
vocab.max_special_token_length = word.size();
if (!word.empty()) {
vocab.add_special_token(word, token_id);
}
}
}
@ -4263,6 +4265,10 @@ llama_token llama_token_nl() {
return 13;
}
void llama_add_special_token(struct llama_model * model, const char * token, llama_token token_id) {
model->vocab.add_special_token(token, token_id);
}
struct llama_timings llama_get_timings(struct llama_context * ctx) {
struct llama_timings result = {
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,

View file

@ -373,6 +373,11 @@ extern "C" {
LLAMA_API llama_token llama_token_eos(); // end-of-sentence
LLAMA_API llama_token llama_token_nl(); // next-line
LLAMA_API void llama_add_special_token(
struct llama_model * model,
const char * token,
llama_token token_id);
// Grammar
//
LLAMA_API struct llama_grammar * llama_grammar_init(