bugfix: assertions, wrong special token list

This commit is contained in:
jaime-m-p 2024-06-01 20:27:32 +02:00
parent 3ead1b9757
commit 33de247948

View file

@ -4903,16 +4903,19 @@ static void llm_load_vocab(
return false; return false;
}; };
auto _set_token_attrib = [&vocab] (const std::string & token, llama_token_attrib attrib, bool value) { auto _set_tokenid_attrib = [&] (const llama_vocab::id id, llama_token_attrib attrib, bool value) {
llama_vocab::id id = vocab.token_to_id.at(token);
uint32_t attribs = vocab.id_to_token[id].attribs; uint32_t attribs = vocab.id_to_token[id].attribs;
attribs = value ? (attribs | attrib) : (attribs & ~attrib); attribs = value ? (attribs | attrib) : (attribs & ~attrib);
vocab.id_to_token[id].attribs = (llama_token_attrib) attribs; vocab.id_to_token[id].attribs = (llama_token_attrib) attribs;
}; };
auto _set_token_attrib = [&] (const std::string & token, llama_token_attrib attrib, bool value) {
_set_tokenid_attrib(vocab.token_to_id.at(token), attrib, value);
};
if (_contains_any({"phi-3", "phi3"})) { if (_contains_any({"phi-3", "phi3"})) {
for (auto token : vocab.cache_token_to_piece_special) { for (auto id : vocab.cache_special_tokens) {
_set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, true); _set_tokenid_attrib(id, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
} }
for (auto token : {"</s>"}) { for (auto token : {"</s>"}) {
_set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, true); _set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
@ -13312,7 +13315,8 @@ struct fragment_buffer_variant {
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) { static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
// for each special token // for each special token
for (const llama_vocab::id special_id : vocab.cache_special_tokens) { for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
const auto & special_token = vocab.id_to_token[special_id].text; const auto & data = vocab.id_to_token[special_id];
const auto & special_token = data.text;
// for each text fragment // for each text fragment
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin(); std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
@ -13349,13 +13353,22 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
if (match > raw_text_base_offset) { if (match > raw_text_base_offset) {
// left // left
const int64_t left_reminder_offset = raw_text_base_offset + 0; const int64_t left_reminder_offset = raw_text_base_offset + 0;
const int64_t left_reminder_length = match - raw_text_base_offset; int64_t left_reminder_length = match - raw_text_base_offset;
buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
if (data.attribs & LLAMA_TOKEN_ATTRIB_LSTRIP) {
while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
left_reminder_length--;
}
}
if (left_reminder_length > 0) {
buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
it++;
}
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str()); LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
#endif #endif
it++;
} }
// special token // special token