Add C API for adding special tokens
This commit is contained in:
parent
099119f532
commit
d9791bb48b
2 changed files with 20 additions and 9 deletions
24
llama.cpp
24
llama.cpp
|
@ -281,6 +281,15 @@ struct llama_vocab {
|
||||||
llama_trie special_token_trie;
|
llama_trie special_token_trie;
|
||||||
std::unordered_map<token, id> special_token_to_id;
|
std::unordered_map<token, id> special_token_to_id;
|
||||||
size_t max_special_token_length = 0;
|
size_t max_special_token_length = 0;
|
||||||
|
|
||||||
|
void add_special_token(const token & word, id token_id) {
|
||||||
|
special_token_trie.add(word);
|
||||||
|
special_token_to_id[word] = token_id;
|
||||||
|
|
||||||
|
if (max_special_token_length < word.size()) {
|
||||||
|
max_special_token_length = word.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_model {
|
struct llama_model {
|
||||||
|
@ -624,15 +633,8 @@ struct llama_file_loader {
|
||||||
for (uint32_t i = 0; i < vocab_sp; i++) {
|
for (uint32_t i = 0; i < vocab_sp; i++) {
|
||||||
llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
|
llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
|
||||||
const auto & word = vocab.id_to_token[token_id].tok;
|
const auto & word = vocab.id_to_token[token_id].tok;
|
||||||
if (word.empty()) {
|
if (!word.empty()) {
|
||||||
continue;
|
vocab.add_special_token(word, token_id);
|
||||||
}
|
|
||||||
|
|
||||||
vocab.special_token_trie.add(word);
|
|
||||||
vocab.special_token_to_id[word] = token_id;
|
|
||||||
|
|
||||||
if (vocab.max_special_token_length < word.size()) {
|
|
||||||
vocab.max_special_token_length = word.size();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4263,6 +4265,10 @@ llama_token llama_token_nl() {
|
||||||
return 13;
|
return 13;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_add_special_token(struct llama_model * model, const char * token, llama_token token_id) {
|
||||||
|
model->vocab.add_special_token(token, token_id);
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_timings llama_get_timings(struct llama_context * ctx) {
|
struct llama_timings llama_get_timings(struct llama_context * ctx) {
|
||||||
struct llama_timings result = {
|
struct llama_timings result = {
|
||||||
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
|
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
|
||||||
|
|
5
llama.h
5
llama.h
|
@ -373,6 +373,11 @@ extern "C" {
|
||||||
LLAMA_API llama_token llama_token_eos(); // end-of-sentence
|
LLAMA_API llama_token llama_token_eos(); // end-of-sentence
|
||||||
LLAMA_API llama_token llama_token_nl(); // next-line
|
LLAMA_API llama_token llama_token_nl(); // next-line
|
||||||
|
|
||||||
|
LLAMA_API void llama_add_special_token(
|
||||||
|
struct llama_model * model,
|
||||||
|
const char * token,
|
||||||
|
llama_token token_id);
|
||||||
|
|
||||||
// Grammar
|
// Grammar
|
||||||
//
|
//
|
||||||
LLAMA_API struct llama_grammar * llama_grammar_init(
|
LLAMA_API struct llama_grammar * llama_grammar_init(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue