diff --git a/llama-util.h b/llama-util.h
index 9c38bddd0..30a6c0eb5 100644
--- a/llama-util.h
+++ b/llama-util.h
@@ -16,6 +16,7 @@
 #include <string>
 #include <map>
 #include <unordered_map>
+#include <memory>
 #include <vector>
 
 #ifdef __has_include
@@ -546,7 +547,7 @@ typedef llama_buffer llama_ctx_buffer;
 struct llama_trie_node {
     llama_trie_node(): is_terminator(false) {}
 
-    std::unordered_map<char, llama_trie_node*> children;
+    std::unordered_map<char, std::unique_ptr<llama_trie_node>> children;
     bool is_terminator;
 };
 
@@ -561,24 +562,24 @@ public:
             return;
         }
 
-        llama_trie_node *ref = root_;
+        llama_trie_node *ref = root_.get();
         for (char c : word) {
             if (ref->children.find(c) == ref->children.end()) {
-                ref->children[c] = new llama_trie_node();
+                ref->children[c].reset(new llama_trie_node());
             }
-            ref = ref->children[c];
+            ref = ref->children[c].get();
         }
         ref->is_terminator = true;
     }
 
     // Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
     // Note that this trie will match the longest possible word first!
-    std::vector<int> split(const std::string & text) const {
-        std::map<int, llama_trie_node*> states;
-        std::vector<int> offsets{0};
+    std::vector<size_t> split(const std::string & text) const {
+        std::map<size_t, llama_trie_node*> states;
+        std::vector<size_t> offsets{0};
 
-        int skip = 0;
-        for (int current = 0; current < text.size(); current++) {
+        size_t skip = 0;
+        for (size_t current = 0; current < text.size(); current++) {
             char current_char = text[current];
             if (skip > 0 && current < skip) {
                 // Prevents the lookahead for matching twice
@@ -592,7 +593,7 @@ public:
 
             // In this case, we already have partial matches (But unfinished)
             for (auto state = states.begin(); state != states.end(); ) {
-                int start = state->first;
+                size_t start = state->first;
                 llama_trie_node *trie_pointer = state->second;
                 if (trie_pointer->is_terminator) {
                     // This is a final match, we need to reset and
@@ -603,11 +604,11 @@ public:
                     // Here we are also actively looking for other earlier partial
                     // matches
                     // "[CLS]", "L", we need to match CLS even if L is special
-                    int end = 0;
+                    size_t end = 0;
                     for (const auto & look : states) {
-                        int lookstart = look.first;
+                        size_t lookstart = look.first;
                         llama_trie_node *looktrie_pointer = look.second;
-                        int lookahead_index = 0;
+                        size_t lookahead_index = 0;
                         if (lookstart > start) {
                             // This partial match is later, we can stop looking
                             break;
@@ -633,7 +634,7 @@ public:
                         auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
                         while (looktrie_pointer_it != looktrie_pointer->children.end()) {
-                            looktrie_pointer = looktrie_pointer_it->second;
+                            looktrie_pointer = looktrie_pointer_it->second.get();
                             lookahead_index++;
                             if (looktrie_pointer->is_terminator) {
                                 start = lookstart;
@@ -660,7 +661,7 @@ public:
                 if (trie_pointer_it != trie_pointer->children.end()) {
                     // The current character being looked at has a match within the trie
                     // update the pointer (it will be stored back into states later).
-                    trie_pointer = trie_pointer_it->second;
+                    trie_pointer = trie_pointer_it->second.get();
                     states[start] = trie_pointer;
                     ++state;
                 } else {
@@ -679,18 +680,18 @@ public:
            // start keeping track of this partial match.
            auto children_it = root_->children.find(current_char);
            if (current >= skip && children_it != root_->children.end()) {
-                states[current] = children_it->second;
+                states[current] = children_it->second.get();
            }
         }
 
         // We have a cut at the end with states.
         for (const auto & state : states) {
-            int start = state.first;
+            size_t start = state.first;
             llama_trie_node *trie_pointer = state.second;
             if (trie_pointer->is_terminator) {
                 // This is a final match, we need to reset and
                 // store the results in `offsets`.
-                size_t end = text.size();
+                size_t end = text.size();
                 offsets.push_back(start);
                 offsets.push_back(end);
                 break;
@@ -702,7 +703,7 @@ public:
     }
 
 private:
-    llama_trie_node *root_;
+    std::unique_ptr<llama_trie_node> root_;
 };
 
 #endif
diff --git a/llama.cpp b/llama.cpp
index d7e0b3174..af12931e0 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -281,7 +281,7 @@ struct llama_vocab {
 
     llama_trie special_token_trie;
     std::unordered_map<token, id> special_token_to_id;
-    size_t max_special_token_length;
+    size_t max_special_token_length = 0;
 };
 
 struct llama_model {
@@ -578,7 +578,7 @@ struct llama_file_loader {
         vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
 
         for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
-            uint32_t token_id = file.read_u32();
+            llama_vocab::id token_id = file.read_u32();
             const auto & word = vocab.id_to_token[token_id].tok;
 
             vocab.special_token_trie.add(word);
@@ -2108,9 +2108,9 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         return output;
     }
 
-    std::vector<int> offsets = vocab.special_token_trie.split(text);
-    int start = 0;
-    for (int end : offsets) {
+    std::vector<size_t> offsets = vocab.special_token_trie.split(text);
+    size_t start = 0;
+    for (size_t end : offsets) {
         if (start >= end) {
             continue;
         }