Add validation for special token ids to llama.cpp
Small optimization for llama_byte_to_token SPM mode
parent 8402566a7c
commit d1075f6e08

1 changed file with 25 additions and 9 deletions
llama.cpp | 34 +++++++++++++++++++++++++---------
--- a/llama.cpp
+++ b/llama.cpp
@@ -2235,15 +2235,32 @@ static void llm_load_vocab(
     if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
         vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
     } else {
-        vocab.linefeed_id = llama_tokenize_internal(vocab, "\u010A", false)[0];
+        const std::vector<int> ids = llama_tokenize_internal(vocab, "\u010A", false);
+        GGML_ASSERT(ids.size() == 1 && "model vocab missing newline token");
+        vocab.linefeed_id = ids[0];
     }
 
     // special tokens
-    GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
-    GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID));
-    GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
-    GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
-    GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
+    {
+        const std::vector<std::tuple<enum llm_kv, int32_t *>> special_token_types = {
+            { LLM_KV_TOKENIZER_BOS_ID, &vocab.special_bos_id },
+            { LLM_KV_TOKENIZER_EOS_ID, &vocab.special_eos_id },
+            { LLM_KV_TOKENIZER_UNK_ID, &vocab.special_unk_id },
+            { LLM_KV_TOKENIZER_SEP_ID, &vocab.special_sep_id },
+            { LLM_KV_TOKENIZER_PAD_ID, &vocab.special_pad_id },
+        };
+        for (auto & it : special_token_types) {
+            int32_t id = -1;
+            const std::string kstr = kv(std::get<0>(it));
+
+            GGUF_GET_KEY(ctx, id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kstr);
+            if (id != -1 && (id < 0 || size_t(id) >= vocab.id_to_token.size())) {
+                LLAMA_LOG_WARN("%s: bad special token value %d for key '%s' -- ignoring\n", __func__, id, kstr.c_str());
+                continue;
+            }
+            *(std::get<1>(it)) = id;
+        }
+    }
 
     // build special tokens cache
     {
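The validation added above reduces to one rule: a special token id read from the GGUF metadata is only usable if it is -1 (unset) or indexes an existing entry in vocab.id_to_token. A minimal standalone sketch of that rule follows; the helper name special_token_id_valid and the plain n_vocab size are hypothetical stand-ins, not llama.cpp API.

#include <cstdint>
#include <cstdio>

// Hypothetical helper, not llama.cpp API: mirrors the bounds check in the loop above.
static bool special_token_id_valid(int32_t id, size_t n_vocab) {
    // -1 is the "unset" default and passes through; anything else must be a real token index.
    return id == -1 || (id >= 0 && size_t(id) < n_vocab);
}

int main() {
    const size_t n_vocab = 32000;                             // assumed vocabulary size
    printf("%d\n", special_token_id_valid(1, n_vocab));       // 1: typical BOS id, accepted
    printf("%d\n", special_token_id_valid(-1, n_vocab));      // 1: unset, default is kept
    printf("%d\n", special_token_id_valid(999999, n_vocab));  // 0: out of range, rejected
    return 0;
}

In the diff itself an out-of-range value triggers the LLAMA_LOG_WARN and the continue, so the previously initialized default id is left in place rather than overwritten with a value that would index past the vocabulary.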
@@ -6084,11 +6101,10 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
 }
 
 static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
+    const char * hex = "0123456789ABCDEF";
     switch (llama_vocab_get_type(vocab)) {
         case LLAMA_VOCAB_TYPE_SPM: {
-            char buf[7];
-            int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch);
-            GGML_ASSERT(0 <= result && result < 7);
+            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
             return vocab.token_to_id.at(buf);
         }
         case LLAMA_VOCAB_TYPE_BPE: {
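The SPM change replaces the snprintf call with a direct lookup into a hex digit table; both build the same byte-token name of the form "<0xNN>". A small self-contained check, not part of the commit, comparing the two constructions for every possible byte value:

#include <cassert>
#include <cstdio>
#include <cstring>

int main() {
    const char * hex = "0123456789ABCDEF";
    for (int ch = 0; ch < 256; ++ch) {
        // Old path: format the SPM byte-token name with snprintf.
        char ref[8];
        snprintf(ref, sizeof(ref), "<0x%02X>", ch);
        // New path: build it directly from the hex lookup table.
        const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
        assert(strcmp(ref, buf) == 0);
    }
    printf("all 256 byte-token names match\n");
    return 0;
}

Besides skipping the formatting call, the array initializer lets buf be const and makes the result-length assertion unnecessary.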