Move 'add_special_bos/eos' logic to llm_tokenizer_bpe

jaime-m-p 2024-05-25 04:32:22 +02:00
parent 6168399112
commit 0794b77714


@@ -12274,8 +12274,11 @@ struct llm_bigram_bpe {
 struct llm_tokenizer_bpe {
     llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
         GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
+        special_add_bos = vocab.special_add_bos == 1;
+        special_add_eos = vocab.special_add_eos == 1;
         switch (vocab.type_pre) {
             case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
+                special_add_bos = true;
                 ignore_merges = true;
                 regex_exprs = {
                     // original regex from tokenizer.json
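The constructor now seeds both flags from the vocab metadata (only an explicit 1 enables the token; -1/unset and 0 both disable it), and the LLAMA3 pre-tokenizer case then forces BOS on regardless of the flag. A minimal standalone sketch of that precedence, with made-up names (resolve_add_bos, PreType) standing in for the vocab fields rather than the actual llama.cpp API:

#include <cassert>

// Illustrative stand-ins (not the llama.cpp API): 'flag' plays the role of
// vocab.special_add_bos (-1 = unset, 0 = off, 1 = on) and PreType the role
// of vocab.type_pre.
enum PreType { PRE_DEFAULT, PRE_LLAMA3 };

static bool resolve_add_bos(int flag, PreType pre) {
    bool add_bos = (flag == 1); // only an explicit 1 enables BOS...
    if (pre == PRE_LLAMA3) {
        add_bos = true;         // ...but the LLAMA3 pre-type forces it on
    }
    return add_bos;
}

int main() {
    assert(!resolve_add_bos(-1, PRE_DEFAULT)); // unset metadata: no BOS
    assert( resolve_add_bos( 1, PRE_DEFAULT)); // explicit opt-in
    assert( resolve_add_bos( 0, PRE_LLAMA3));  // pre-type override wins
}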
@@ -12362,6 +12365,39 @@ struct llm_tokenizer_bpe {
         }
     }
 
+    bool add_special_bos(std::vector<llama_vocab::id> & output) const {
+        if (special_add_bos) {
+            GGML_ASSERT(vocab.special_bos_id != -1);
+            output.push_back(vocab.special_bos_id);
+            return true;
+        }
+        return false;
+    }
+
+    bool add_special_eos(std::vector<llama_vocab::id> & output) const {
+        if (special_add_eos) {
+            GGML_ASSERT(vocab.special_eos_id != -1);
+            output.push_back(vocab.special_eos_id);
+            return true;
+        }
+        return false;
+    }
+
+    void check_add_special(const std::vector<llama_vocab::id> & output) const {
+        if (special_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+            LLAMA_LOG_WARN(
+                "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                "Are you sure this is what you want?\n", __FUNCTION__);
+        }
+        if (special_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
+            LLAMA_LOG_WARN(
+                "%s: Added a EOS token to the prompt as specified by the model but the prompt "
+                "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
+                "Are you sure this is what you want?\n", __FUNCTION__);
+        }
+    }
+
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         int final_prev_index = -1;
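Note why check_add_special inspects output[1] and *(output.end()-2): by the time it runs, the BOS injected by add_special_bos occupies index 0 and the injected EOS the last slot, so a BOS/EOS that was already present in the prompt text lands exactly one position further in. A self-contained illustration with made-up token ids:

#include <cstdio>
#include <vector>

int main() {
    const int BOS = 1, EOS = 2; // made-up ids for illustration
    // Prompt text already began with BOS; the tokenizer injected another at index 0.
    std::vector<int> out = { BOS, BOS, 10, 11, EOS };
    if (out.size() >= 2 && out[1] == BOS) {
        std::printf("double BOS detected\n"); // the LLAMA_LOG_WARN path fires
    }
    if (out.size() >= 2 && *(out.end() - 2) == EOS) {
        std::printf("double EOS detected\n"); // not triggered: end()-2 holds 11
    }
}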
@@ -12499,8 +12535,10 @@ private:
     const llama_vocab & vocab;
 
-    bool ignore_merges = false;
     std::vector<std::string> regex_exprs;
 
+    bool special_add_bos = false;
+    bool special_add_eos = false;
+    bool ignore_merges = false;
+
     std::vector<llm_symbol> symbols;
     std::vector<llm_symbol> symbols_final;
@@ -12851,9 +12889,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
             {
                 llm_tokenizer_bpe tokenizer(vocab);
 
-                if (add_special && vocab.special_add_bos != 0) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
+                if (add_special) {
+                    tokenizer.add_special_bos(output);
                 }
 
                 for (const auto & fragment : fragment_buffer) {
@@ -12869,16 +12906,9 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     }
                 }
 
-                if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.special_add_eos == 1) {
-                    GGML_ASSERT(vocab.special_add_eos != -1);
-                    output.push_back(vocab.special_eos_id);
+                if (add_special) {
+                    tokenizer.add_special_eos(output);
+                    tokenizer.check_add_special(output);
                 }
             } break;
         case LLAMA_VOCAB_TYPE_WPM:
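With the flag handling moved into llm_tokenizer_bpe, the BPE call site reduces to the fixed sequence shown in the last two hunks: BOS, then the text, then EOS, then the duplicate check. A hedged sketch of that sequence as one free function (tokenize_bpe_prompt is illustrative, not part of the patch, and collapses the fragment_buffer loop to a single string):

// Illustrative wrapper around the new helpers; mirrors the BPE case above.
static std::vector<llama_vocab::id> tokenize_bpe_prompt(
        const llama_vocab & vocab, const std::string & text, bool add_special) {
    std::vector<llama_vocab::id> output;
    llm_tokenizer_bpe tokenizer(vocab);
    if (add_special) {
        tokenizer.add_special_bos(output);   // no-op when special_add_bos is false
    }
    tokenizer.tokenize(text, output);
    if (add_special) {
        tokenizer.add_special_eos(output);   // no-op when special_add_eos is false
        tokenizer.check_add_special(output); // warns about doubled BOS/EOS
    }
    return output;
}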