Discard all tokens when no match is found

This commit is contained in:
jaime-m-p 2024-05-27 20:17:01 +02:00
parent 117b091069
commit f3f6c0a930

View file

@@ -12660,7 +12660,7 @@ struct llm_tokenizer_wpm {
     llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}

     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-        auto * token_map = &vocab.token_to_id;
+        const auto & token_map = vocab.token_to_id;

         // normalize and split by whitespace
         std::vector<std::string> words = preprocess(text);
@@ -12675,36 +12675,34 @@ struct llm_tokenizer_wpm {
         }

         // prepend phantom space
-        std::string word1 = "\xe2\x96\x81" + word;
-        int n = word1.size();
+        const std::string word1 = "\xe2\x96\x81" + word;
+        const int n = word1.size();
+
+        const size_t current_tokens = output.size();

         // we're at the start of a new word
-        int i = 0;
-        bool match_any = false;
-
         // move through character position in word
-        while (i < n) {
+        for (int i = 0; i < n; ++i) {
             // loop through possible match length
             bool match = false;
             for (int j = n; j > i; j--) {
-                auto it = token_map->find(word1.substr(i, j - i));
-                if (it != token_map->end()) {
+                auto it = token_map.find(word1.substr(i, j - i));
+                if (it != token_map.end()) {
                     output.push_back(it->second);
                     match = true;
-                    match_any = true;
-                    i = j;
+                    i = j - 1;
                     break;
                 }
             }

-            // must be an unknown character
-            if (!match) {
-                i++;
+            if (!match) { // discard all
+                output.resize(current_tokens);
+                break;  // and discard next tokens
             }
         }

         // we didn't find any matches for this word
-        if (!match_any) {
+        if (current_tokens == output.size()) {
             output.push_back(vocab.special_unk_id);
         }
     }