refactor: respect special token from metadata
when vocab.type is SPM, we will confirm the linefeed_id by searching the char, and use special_pad_id instead if not found. the special_*_id are usually record in metadata, to ensure the special_pad_id can be used correctly, we need to obtain it from metadata first and then perform the above confirmation logic. Signed-off-by: thxCode <thxcode0824@gmail.com>
This commit is contained in:
parent
bb55b19c04
commit
0b90345749
1 changed files with 80 additions and 80 deletions
160
src/llama.cpp
160
src/llama.cpp
|
@ -5517,6 +5517,86 @@ static void llm_load_vocab(
|
||||||
}
|
}
|
||||||
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
|
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
|
||||||
|
|
||||||
|
// special tokens
|
||||||
|
{
|
||||||
|
const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
|
||||||
|
{ LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
|
||||||
|
{ LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
|
||||||
|
{ LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
|
||||||
|
{ LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
|
||||||
|
{ LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
|
||||||
|
{ LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
|
||||||
|
{ LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
|
||||||
|
{ LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
|
||||||
|
{ LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
|
||||||
|
{ LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
|
||||||
|
{ LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
|
||||||
|
{ LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id },
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const auto & it : special_token_types) {
|
||||||
|
const std::string & key = kv(std::get<0>(it));
|
||||||
|
int32_t & id = std::get<1>(it);
|
||||||
|
|
||||||
|
uint32_t new_id;
|
||||||
|
if (!ml.get_key(std::get<0>(it), new_id, false)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (new_id >= vocab.id_to_token.size()) {
|
||||||
|
LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
|
||||||
|
__func__, key.c_str(), new_id, id);
|
||||||
|
} else {
|
||||||
|
id = new_id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle add_bos_token and add_eos_token
|
||||||
|
{
|
||||||
|
bool temp = true;
|
||||||
|
|
||||||
|
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
|
||||||
|
vocab.tokenizer_add_bos = temp;
|
||||||
|
}
|
||||||
|
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
|
||||||
|
vocab.tokenizer_add_eos = temp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
|
||||||
|
//
|
||||||
|
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
|
||||||
|
// for now, we apply this workaround to find the EOT token based on its text
|
||||||
|
if (vocab.special_eot_id == -1) {
|
||||||
|
for (const auto & t : vocab.token_to_id) {
|
||||||
|
if (
|
||||||
|
// TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
|
||||||
|
// need to fix convert script
|
||||||
|
//vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
|
||||||
|
(t.first == "<|eot_id|>" ||
|
||||||
|
t.first == "<|im_end|>" ||
|
||||||
|
t.first == "<|end|>" ||
|
||||||
|
t.first == "<end_of_turn>" ||
|
||||||
|
t.first == "<|endoftext|>"
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
vocab.special_eot_id = t.second;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// find EOM token: "<|eom_id|>"
|
||||||
|
//
|
||||||
|
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
|
||||||
|
// for now, we apply this workaround to find the EOM token based on its text
|
||||||
|
if (vocab.special_eom_id == -1) {
|
||||||
|
const auto & t = vocab.token_to_id.find("<|eom_id|>");
|
||||||
|
if (t != vocab.token_to_id.end()) {
|
||||||
|
vocab.special_eom_id = t->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
|
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
|
||||||
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
|
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
// For Fill-In-the-Middle (FIM)/infill models which where converted
|
// For Fill-In-the-Middle (FIM)/infill models which where converted
|
||||||
|
@ -5571,86 +5651,6 @@ static void llm_load_vocab(
|
||||||
vocab.linefeed_id = ids[0];
|
vocab.linefeed_id = ids[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
// special tokens
|
|
||||||
{
|
|
||||||
const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
|
|
||||||
{ LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
|
|
||||||
{ LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
|
|
||||||
{ LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
|
|
||||||
{ LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
|
|
||||||
{ LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
|
|
||||||
{ LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
|
|
||||||
{ LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
|
|
||||||
{ LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
|
|
||||||
{ LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
|
|
||||||
{ LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
|
|
||||||
{ LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
|
|
||||||
{ LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id },
|
|
||||||
};
|
|
||||||
|
|
||||||
for (const auto & it : special_token_types) {
|
|
||||||
const std::string & key = kv(std::get<0>(it));
|
|
||||||
int32_t & id = std::get<1>(it);
|
|
||||||
|
|
||||||
uint32_t new_id;
|
|
||||||
if (!ml.get_key(std::get<0>(it), new_id, false)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (new_id >= vocab.id_to_token.size()) {
|
|
||||||
LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
|
|
||||||
__func__, key.c_str(), new_id, id);
|
|
||||||
} else {
|
|
||||||
id = new_id;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle add_bos_token and add_eos_token
|
|
||||||
{
|
|
||||||
bool temp = true;
|
|
||||||
|
|
||||||
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
|
|
||||||
vocab.tokenizer_add_bos = temp;
|
|
||||||
}
|
|
||||||
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
|
|
||||||
vocab.tokenizer_add_eos = temp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
|
|
||||||
//
|
|
||||||
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
|
|
||||||
// for now, we apply this workaround to find the EOT token based on its text
|
|
||||||
if (vocab.special_eot_id == -1) {
|
|
||||||
for (const auto & t : vocab.token_to_id) {
|
|
||||||
if (
|
|
||||||
// TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
|
|
||||||
// need to fix convert script
|
|
||||||
//vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
|
|
||||||
(t.first == "<|eot_id|>" ||
|
|
||||||
t.first == "<|im_end|>" ||
|
|
||||||
t.first == "<|end|>" ||
|
|
||||||
t.first == "<end_of_turn>" ||
|
|
||||||
t.first == "<|endoftext|>"
|
|
||||||
)
|
|
||||||
) {
|
|
||||||
vocab.special_eot_id = t.second;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// find EOM token: "<|eom_id|>"
|
|
||||||
//
|
|
||||||
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
|
|
||||||
// for now, we apply this workaround to find the EOM token based on its text
|
|
||||||
if (vocab.special_eom_id == -1) {
|
|
||||||
const auto & t = vocab.token_to_id.find("<|eom_id|>");
|
|
||||||
if (t != vocab.token_to_id.end()) {
|
|
||||||
vocab.special_eom_id = t->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// build special tokens cache
|
// build special tokens cache
|
||||||
{
|
{
|
||||||
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
|
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue