refactor: respect special token from metadata

when vocab.type is SPM, we will confirm the linefeed_id by searching the
char, and use special_pad_id instead if not found.

the special_*_id are usually record in metadata, to ensure the
special_pad_id can be used correctly, we need to obtain it from metadata
first and then perform the above confirmation logic.

Signed-off-by: thxCode <thxcode0824@gmail.com>
This commit is contained in:
thxCode 2024-08-06 17:10:21 +08:00
parent bb55b19c04
commit 0b90345749

View file

@ -5517,6 +5517,86 @@ static void llm_load_vocab(
} }
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
// special tokens
{
const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
{ LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
{ LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
{ LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
{ LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
{ LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
{ LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
{ LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
{ LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
{ LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
{ LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
{ LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
{ LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id },
};
for (const auto & it : special_token_types) {
const std::string & key = kv(std::get<0>(it));
int32_t & id = std::get<1>(it);
uint32_t new_id;
if (!ml.get_key(std::get<0>(it), new_id, false)) {
continue;
}
if (new_id >= vocab.id_to_token.size()) {
LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
__func__, key.c_str(), new_id, id);
} else {
id = new_id;
}
}
// Handle add_bos_token and add_eos_token
{
bool temp = true;
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
vocab.tokenizer_add_bos = temp;
}
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
vocab.tokenizer_add_eos = temp;
}
}
// find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
//
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
// for now, we apply this workaround to find the EOT token based on its text
if (vocab.special_eot_id == -1) {
for (const auto & t : vocab.token_to_id) {
if (
// TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
// need to fix convert script
//vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
(t.first == "<|eot_id|>" ||
t.first == "<|im_end|>" ||
t.first == "<|end|>" ||
t.first == "<end_of_turn>" ||
t.first == "<|endoftext|>"
)
) {
vocab.special_eot_id = t.second;
break;
}
}
}
// find EOM token: "<|eom_id|>"
//
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
// for now, we apply this workaround to find the EOM token based on its text
if (vocab.special_eom_id == -1) {
const auto & t = vocab.token_to_id.find("<|eom_id|>");
if (t != vocab.token_to_id.end()) {
vocab.special_eom_id = t->second;
}
}
}
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
// For Fill-In-the-Middle (FIM)/infill models which where converted // For Fill-In-the-Middle (FIM)/infill models which where converted
@ -5571,86 +5651,6 @@ static void llm_load_vocab(
vocab.linefeed_id = ids[0]; vocab.linefeed_id = ids[0];
} }
// special tokens
{
const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
{ LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
{ LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
{ LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
{ LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
{ LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
{ LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
{ LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
{ LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
{ LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
{ LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
{ LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
{ LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id },
};
for (const auto & it : special_token_types) {
const std::string & key = kv(std::get<0>(it));
int32_t & id = std::get<1>(it);
uint32_t new_id;
if (!ml.get_key(std::get<0>(it), new_id, false)) {
continue;
}
if (new_id >= vocab.id_to_token.size()) {
LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
__func__, key.c_str(), new_id, id);
} else {
id = new_id;
}
}
// Handle add_bos_token and add_eos_token
{
bool temp = true;
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
vocab.tokenizer_add_bos = temp;
}
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
vocab.tokenizer_add_eos = temp;
}
}
// find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
//
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
// for now, we apply this workaround to find the EOT token based on its text
if (vocab.special_eot_id == -1) {
for (const auto & t : vocab.token_to_id) {
if (
// TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
// need to fix convert script
//vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
(t.first == "<|eot_id|>" ||
t.first == "<|im_end|>" ||
t.first == "<|end|>" ||
t.first == "<end_of_turn>" ||
t.first == "<|endoftext|>"
)
) {
vocab.special_eot_id = t.second;
break;
}
}
}
// find EOM token: "<|eom_id|>"
//
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
// for now, we apply this workaround to find the EOM token based on its text
if (vocab.special_eom_id == -1) {
const auto & t = vocab.token_to_id.find("<|eom_id|>");
if (t != vocab.token_to_id.end()) {
vocab.special_eom_id = t->second;
}
}
}
// build special tokens cache // build special tokens cache
{ {
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {