llama : default special tokens based on vocab type

commit 630d8b408a
parent 8c6d3939c7
Author: Georgi Gerganov
Date:   2023-08-23 21:39:09 +03:00

@@ -1654,9 +1654,17 @@ static void llm_load_vocab(
 
         if (tokenizer_name == "llama") {
             vocab.type = LLAMA_VOCAB_TYPE_SPM;
+
+            // default special tokens
+            vocab.special_bos_id = 1;
+            vocab.special_eos_id = 2;
+            vocab.special_unk_id = 0;
+            vocab.special_sep_id = -1;
+            vocab.special_pad_id = -1;
         } else if (tokenizer_name == "gpt2") {
             vocab.type = LLAMA_VOCAB_TYPE_BPE;
 
+            // read bpe merges and populate bpe ranks
             const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
             if (merges_keyidx == -1) {
                 throw std::runtime_error("cannot find tokenizer merges in model file\n");
@@ -1677,12 +1685,19 @@ static void llm_load_vocab(
                     second = word.substr(pos + 1);
                 }
 
+                // populate bpe ranks
                 vocab.bpe_ranks.emplace(std::make_pair(first, second), i);
             }
 
+            // default special tokens
+            vocab.special_bos_id = 11;
+            vocab.special_eos_id = 11;
+            vocab.special_unk_id = -1;
+            vocab.special_sep_id = -1;
+            vocab.special_pad_id = -1;
         } else {
             LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
             LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
 
             vocab.type = LLAMA_VOCAB_TYPE_SPM;
         }
     }
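
For context, a minimal standalone sketch of the behavior this commit introduces: the vocab type selected from the tokenizer name determines which hard-coded special token ids are used as defaults. The VocabType enum, SpecialTokens struct, and default_special_tokens() below are illustrative names, not llama.cpp API; only the id values come from the diff above.

// Illustrative sketch only: the enum, struct, and function names are
// hypothetical and not part of llama.cpp; the id values mirror this commit.
#include <cstdint>
#include <cstdio>

enum class VocabType { SPM, BPE };   // SentencePiece ("llama") vs. BPE ("gpt2")

struct SpecialTokens {
    int32_t bos = -1;   // beginning-of-sequence
    int32_t eos = -1;   // end-of-sequence
    int32_t unk = -1;   // unknown
    int32_t sep = -1;   // separator
    int32_t pad = -1;   // padding; -1 presumably means "no such token in this vocab"
};

// Choose per-vocab-type defaults, using the same ids as the diff above.
static SpecialTokens default_special_tokens(VocabType type) {
    SpecialTokens t;
    if (type == VocabType::SPM) {
        t.bos = 1;
        t.eos = 2;
        t.unk = 0;
    } else { // BPE
        t.bos = 11;
        t.eos = 11;
    }
    return t;
}

int main() {
    const SpecialTokens spm = default_special_tokens(VocabType::SPM);
    const SpecialTokens bpe = default_special_tokens(VocabType::BPE);

    std::printf("SPM: bos=%d eos=%d unk=%d sep=%d pad=%d\n",
                (int) spm.bos, (int) spm.eos, (int) spm.unk, (int) spm.sep, (int) spm.pad);
    std::printf("BPE: bos=%d eos=%d unk=%d sep=%d pad=%d\n",
                (int) bpe.bos, (int) bpe.eos, (int) bpe.unk, (int) bpe.sep, (int) bpe.pad);
    return 0;
}

The -1 sentinel presumably lets callers detect that a given special token is not defined for the vocab, and ids supplied by GGUF metadata, where present, would be expected to override these hard-coded defaults elsewhere in llm_load_vocab.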