Implement 'rstrip' properly

This commit is contained in:
jaime-m-p 2024-06-01 20:30:42 +02:00
parent 33de247948
commit ada961cec2

View file

@ -13377,16 +13377,25 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
// right // right
if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) { if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
const int64_t right_reminder_offset = match + special_token.length(); int64_t right_reminder_offset = match + special_token.length();
const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
if (data.attribs & LLAMA_TOKEN_ATTRIB_RSTRIP) {
while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
right_reminder_offset++;
right_reminder_length--;
}
}
if (right_reminder_length > 0) {
buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
it++;
}
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str()); LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
#endif #endif
it++;
if (source == 0) { if (source == 0) {
buffer.erase_after(buffer.before_begin()); buffer.erase_after(buffer.before_begin());
} else { } else {
@ -13432,9 +13441,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
// tokenizer.encode('', add_special_tokens=True) returns [1] // tokenizer.encode('', add_special_tokens=True) returns [1]
// tokenizer.encode('', add_special_tokens=False) returns [] // tokenizer.encode('', add_special_tokens=False) returns []
static const bool rtrim = true; //TODO: as param
bool is_prev_special = false; bool is_prev_special = false;
bool special_token_rtrim = false;
if (add_special && vocab.special_add_bos != 0) { if (add_special && vocab.special_add_bos != 0) {
GGML_ASSERT(vocab.special_bos_id != -1); GGML_ASSERT(vocab.special_bos_id != -1);
@ -13444,25 +13451,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
// without adding this leading whitespace, we do not get the same results as the original tokenizer
// TODO: It's likely possible to get rid of this string copy entirely
// by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
// and passing 'add space prefix' as bool argument
//
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
if (special_token_rtrim) {
size_t num_whitespaces = 0;
while (isspace(raw_text[num_whitespaces])) {
num_whitespaces++;
}
if (num_whitespaces == raw_text.size()) {
continue; // skip if all whitespaces
}
raw_text = raw_text.substr(num_whitespaces);
}
if (vocab.add_space_prefix) { if (vocab.add_space_prefix) {
if (!output.size() || is_prev_special) { // prefix with space if first token if (!output.size() || is_prev_special) { // prefix with space if first token
raw_text = " " + raw_text; raw_text = " " + raw_text;
@ -13478,11 +13468,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token); output.push_back(fragment.token);
is_prev_special = true; is_prev_special = true;
// phi-3 special tokens without rtrim, works fine for llama-spm too
special_token_rtrim = rtrim
&& fragment.token != vocab.special_bos_id
&& fragment.token != vocab.special_unk_id
&& fragment.token != vocab.special_eos_id;
} }
} }