Fix special tokens rtrim

This commit is contained in:
jaime-m-p 2024-05-19 00:08:35 +02:00
parent 5b61c04223
commit dd0d1590f6

View file

@ -4582,7 +4582,8 @@ static void llm_load_vocab(
(t.first == "<|eot_id|>" ||
t.first == "<|im_end|>" ||
t.first == "<|end|>" ||
t.first == "<end_of_turn>"
t.first == "<end_of_turn>" ||
t.first == "<|endoftext|>"
)
) {
vocab.special_eot_id = t.second;
@ -12800,6 +12801,10 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
output.push_back(vocab.special_bos_id);
}
static const bool rtrim = true; //TODO: as param
bool is_prev_special = false;
bool special_token_rtrim = false;
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
// without adding this leading whitespace, we do not get the same results as the original tokenizer
@ -12809,9 +12814,21 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
// and passing 'add space prefix' as bool argument
//
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
if (&fragment == &fragment_buffer.front()) {
if (special_token_rtrim) {
uint num_whitespaces = 0;
while (isspace(raw_text[num_whitespaces])) {
num_whitespaces++;
}
if(num_whitespaces == raw_text.size()) {
continue; // skip if all whitespaces
}
raw_text = raw_text.substr(num_whitespaces);
}
if(vocab.add_space_prefix) {
raw_text = " " + raw_text; // prefix with space if the first token is not special
if (!output.size() || is_prev_special) { // prefix with space if first token
raw_text = " " + raw_text;
}
}
@ -12823,6 +12840,12 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
tokenizer.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
is_prev_special = true;
// phi-3 special tokens without rtrim, works fine for llama-spm too
special_token_rtrim = rtrim
&& fragment.token != vocab.special_bos_id
&& fragment.token != vocab.special_unk_id
&& fragment.token != vocab.special_eos_id;
}
}