Fix issues revealed by CI

Igor Pissolati 2023-06-20 01:27:36 -03:00
parent e468e75515
commit ca1fc20508
2 changed files with 25 additions and 24 deletions

View file

@@ -16,6 +16,7 @@
 #include <vector>
 #include <map>
 #include <unordered_map>
+#include <memory>
 #include <stdexcept>

 #ifdef __has_include
@@ -546,7 +547,7 @@ typedef llama_buffer llama_ctx_buffer;
 struct llama_trie_node {
     llama_trie_node(): is_terminator(false) {}

-    std::unordered_map<char, llama_trie_node*> children;
+    std::unordered_map<char, std::unique_ptr<llama_trie_node>> children;
     bool is_terminator;
 };
@@ -561,24 +562,24 @@ public:
             return;
         }

-        llama_trie_node *ref = root_;
+        llama_trie_node *ref = root_.get();
         for (char c : word) {
             if (ref->children.find(c) == ref->children.end()) {
-                ref->children[c] = new llama_trie_node();
+                ref->children[c].reset(new llama_trie_node());
             }
-            ref = ref->children[c];
+            ref = ref->children[c].get();
         }
         ref->is_terminator = true;
     }

     // Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
     // Note that this trie will match the longest possible word first!
-    std::vector<int> split(const std::string & text) const {
-        std::map<int, llama_trie_node*> states;
-        std::vector<int> offsets{0};
+    std::vector<size_t> split(const std::string & text) const {
+        std::map<size_t, llama_trie_node*> states;
+        std::vector<size_t> offsets{0};

-        int skip = 0;
-        for (int current = 0; current < text.size(); current++) {
+        size_t skip = 0;
+        for (size_t current = 0; current < text.size(); current++) {
             char current_char = text[current];
             if (skip > 0 && current < skip) {
                 // Prevents the lookahead for matching twice
@@ -592,7 +593,7 @@ public:
             // In this case, we already have partial matches (But unfinished)
             for (auto state = states.begin(); state != states.end(); ) {
-                int start = state->first;
+                size_t start = state->first;
                 llama_trie_node *trie_pointer = state->second;
                 if (trie_pointer->is_terminator) {
                     // This is a final match, we need to reset and
@@ -603,11 +604,11 @@ public:
                     // Here we are also actively looking for other earlier partial
                     // matches
                     // "[CLS]", "L", we need to match CLS even if L is special
-                    int end = 0;
+                    size_t end = 0;
                     for (const auto & look : states) {
-                        int lookstart = look.first;
+                        size_t lookstart = look.first;
                         llama_trie_node *looktrie_pointer = look.second;
-                        int lookahead_index = 0;
+                        size_t lookahead_index = 0;
                         if (lookstart > start) {
                             // This partial match is later, we can stop looking
                             break;
@@ -633,7 +634,7 @@ public:
                         auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
                         while (looktrie_pointer_it != looktrie_pointer->children.end()) {
-                            looktrie_pointer = looktrie_pointer_it->second;
+                            looktrie_pointer = looktrie_pointer_it->second.get();
                             lookahead_index++;
                             if (looktrie_pointer->is_terminator) {
                                 start = lookstart;
@@ -660,7 +661,7 @@ public:
                 if (trie_pointer_it != trie_pointer->children.end()) {
                     // The current character being looked at has a match within the trie
                     // update the pointer (it will be stored back into states later).
-                    trie_pointer = trie_pointer_it->second;
+                    trie_pointer = trie_pointer_it->second.get();
                     states[start] = trie_pointer;
                     ++state;
                 } else {
@@ -679,18 +680,18 @@ public:
             // start keeping track of this partial match.
             auto children_it = root_->children.find(current_char);
             if (current >= skip && children_it != root_->children.end()) {
-                states[current] = children_it->second;
+                states[current] = children_it->second.get();
             }
         }

         // We have a cut at the end with states.
         for (const auto & state : states) {
-            int start = state.first;
+            size_t start = state.first;
             llama_trie_node *trie_pointer = state.second;
             if (trie_pointer->is_terminator) {
                 // This is a final match, we need to reset and
                 // store the results in `offsets`.
-                int end = text.size();
+                size_t end = text.size();
                 offsets.push_back(start);
                 offsets.push_back(end);
                 break;
@@ -702,7 +703,7 @@ public:
     }

 private:
-    llama_trie_node *root_;
+    std::unique_ptr<llama_trie_node> root_;
 };

 #endif
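
For context only, here is a minimal usage sketch of the trie after this change; it is not part of the commit. The include name and the sample tokens are assumptions, but llama_trie and its add()/split() methods are the symbols shown in the hunks above, and the plain `llama_trie special_token_trie;` member in llama_vocab below implies the type is default-constructible.

// Illustrative sketch only. With root_ and children held in std::unique_ptr,
// the whole trie is released when `trie` goes out of scope; no manual deletes.
#include "llama-util.h"   // assumed name of the header changed above
#include <cstdio>
#include <string>
#include <vector>

int main() {
    llama_trie trie;
    trie.add("<s>");
    trie.add("</s>");

    // Per the comment on split(): the result holds the boundaries of the
    // words that were found, now as size_t offsets into `text`.
    std::string text = "<s> hello </s>";
    std::vector<size_t> offsets = trie.split(text);
    for (size_t off : offsets) {
        printf("boundary at %zu\n", off);
    }
    return 0;
}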

View file

@@ -281,7 +281,7 @@ struct llama_vocab {
     llama_trie special_token_trie;
     std::unordered_map<token, id> special_token_to_id;
-    size_t max_special_token_length;
+    size_t max_special_token_length = 0;
 };

 struct llama_model {
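
Aside (not from the repository): the `= 0` default member initializer matters because max_special_token_length may be read before any special token has been loaded. A tiny sketch of the pattern, with hypothetical names other than the field itself:

// Sketch only: an in-class default member initializer keeps the value
// well-defined even when no special token is ever added.
#include <algorithm>
#include <cstddef>
#include <string>

struct vocab_stats {                       // hypothetical stand-in for llama_vocab
    size_t max_special_token_length = 0;   // previously left uninitialized
};

static void note_special_token(vocab_stats & v, const std::string & tok) {
    v.max_special_token_length = std::max(v.max_special_token_length, tok.size());
}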
@@ -578,7 +578,7 @@ struct llama_file_loader {
         vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
         for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
-            uint32_t token_id = file.read_u32();
+            llama_vocab::id token_id = file.read_u32();
             const auto & word = vocab.id_to_token[token_id].tok;
             vocab.special_token_trie.add(word);
@@ -2108,9 +2108,9 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         return output;
     }

-    std::vector<int> offsets = vocab.special_token_trie.split(text);
-    int start = 0;
-    for (int end : offsets) {
+    std::vector<size_t> offsets = vocab.special_token_trie.split(text);
+    size_t start = 0;
+    for (size_t end : offsets) {
         if (start >= end) {
             continue;
         }
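
For illustration, the loop that begins in this hunk slices `text` at the boundaries returned by split(). The sketch below is not the remainder of llama_tokenize (its body lies outside the hunk); walk_segments is a hypothetical helper, and the int id type is a simplification of the map shown in the llama_vocab hunk.

// Hedged sketch of how the boundaries could be consumed.
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

static void walk_segments(const std::string & text,
                          const std::vector<size_t> & offsets,
                          const std::unordered_map<std::string, int> & special_token_to_id) {
    size_t start = 0;
    for (size_t end : offsets) {
        if (start >= end) {
            continue;  // empty or out-of-order boundary, as skipped in the hunk above
        }
        std::string segment = text.substr(start, end - start);
        auto it = special_token_to_id.find(segment);
        if (it != special_token_to_id.end()) {
            printf("special token '%s' -> id %d\n", segment.c_str(), it->second);
        } else {
            printf("ordinary text segment '%s'\n", segment.c_str());
        }
        start = end;
    }
}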