Fix issues revealed by CI

parent e468e75515
commit ca1fc20508

2 changed files with 25 additions and 24 deletions
llama-util.h (39 lines changed)

@@ -16,6 +16,7 @@
 #include <vector>
 #include <map>
 #include <unordered_map>
+#include <memory>
 #include <stdexcept>
 
 #ifdef __has_include
@@ -546,7 +547,7 @@ typedef llama_buffer llama_ctx_buffer;
 struct llama_trie_node {
     llama_trie_node(): is_terminator(false) {}
 
-    std::unordered_map<char, llama_trie_node*> children;
+    std::unordered_map<char, std::unique_ptr<llama_trie_node>> children;
     bool is_terminator;
 };
 
@@ -561,24 +562,24 @@ public:
             return;
         }
 
-        llama_trie_node *ref = root_;
+        llama_trie_node *ref = root_.get();
         for (char c : word) {
             if (ref->children.find(c) == ref->children.end()) {
-                ref->children[c] = new llama_trie_node();
+                ref->children[c].reset(new llama_trie_node());
             }
-            ref = ref->children[c];
+            ref = ref->children[c].get();
         }
         ref->is_terminator = true;
     }
 
     // Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
     // Note that this trie will match the longest possible word first!
-    std::vector<int> split(const std::string & text) const {
-        std::map<int, llama_trie_node*> states;
-        std::vector<int> offsets{0};
+    std::vector<size_t> split(const std::string & text) const {
+        std::map<size_t, llama_trie_node*> states;
+        std::vector<size_t> offsets{0};
 
-        int skip = 0;
-        for (int current = 0; current < text.size(); current++) {
+        size_t skip = 0;
+        for (size_t current = 0; current < text.size(); current++) {
             char current_char = text[current];
             if (skip > 0 && current < skip) {
                 // Prevents the lookahead for matching twice
@@ -592,7 +593,7 @@ public:
 
             // In this case, we already have partial matches (But unfinished)
             for (auto state = states.begin(); state != states.end(); ) {
-                int start = state->first;
+                size_t start = state->first;
                 llama_trie_node *trie_pointer = state->second;
                 if (trie_pointer->is_terminator) {
                     // This is a final match, we need to reset and
@@ -603,11 +604,11 @@ public:
                     // Here we are also actively looking for other earlier partial
                     // matches
                     // "[CLS]", "L", we need to match CLS even if L is special
-                    int end = 0;
+                    size_t end = 0;
                     for (const auto & look : states) {
-                        int lookstart = look.first;
+                        size_t lookstart = look.first;
                         llama_trie_node *looktrie_pointer = look.second;
-                        int lookahead_index = 0;
+                        size_t lookahead_index = 0;
                         if (lookstart > start) {
                             // This partial match is later, we can stop looking
                             break;
@@ -633,7 +634,7 @@ public:
 
                         auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
                         while (looktrie_pointer_it != looktrie_pointer->children.end()) {
-                            looktrie_pointer = looktrie_pointer_it->second;
+                            looktrie_pointer = looktrie_pointer_it->second.get();
                             lookahead_index++;
                             if (looktrie_pointer->is_terminator) {
                                 start = lookstart;
@@ -660,7 +661,7 @@ public:
                 if (trie_pointer_it != trie_pointer->children.end()) {
                     // The current character being looked at has a match within the trie
                     // update the pointer (it will be stored back into states later).
-                    trie_pointer = trie_pointer_it->second;
+                    trie_pointer = trie_pointer_it->second.get();
                     states[start] = trie_pointer;
                     ++state;
                 } else {
@@ -679,18 +680,18 @@ public:
             // start keeping track of this partial match.
             auto children_it = root_->children.find(current_char);
             if (current >= skip && children_it != root_->children.end()) {
-                states[current] = children_it->second;
+                states[current] = children_it->second.get();
             }
         }
 
        // We have a cut at the end with states.
         for (const auto & state : states) {
-            int start = state.first;
+            size_t start = state.first;
             llama_trie_node *trie_pointer = state.second;
             if (trie_pointer->is_terminator) {
                 // This is a final match, we need to reset and
                 // store the results in `offsets`.
-                int end = text.size();
+                size_t end = text.size();
                 offsets.push_back(start);
                 offsets.push_back(end);
                 break;
@@ -702,7 +703,7 @@ public:
     }
 
 private:
-    llama_trie_node *root_;
+    std::unique_ptr<llama_trie_node> root_;
 };
 
 #endif
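A minimal standalone sketch of the ownership pattern these changes adopt: trie nodes held by std::unique_ptr, so destroying the root frees every node without a manual recursive delete. The names trie_node, trie and find_longest_at are illustrative only, not the committed API, and find_longest_at is a simplified stand-in for split(), which additionally tracks multiple partial matches in a std::map<size_t, llama_trie_node*>.

// Sketch only: each child node is owned by its parent's unique_ptr.
#include <cstdio>
#include <memory>
#include <string>
#include <unordered_map>

struct trie_node {
    std::unordered_map<char, std::unique_ptr<trie_node>> children;
    bool is_terminator = false;
};

struct trie {
    void add(const std::string & word) {
        if (word.empty()) {
            return;
        }
        trie_node * ref = root_.get();                   // borrowed, never owned
        for (char c : word) {
            if (ref->children.find(c) == ref->children.end()) {
                ref->children[c].reset(new trie_node()); // ownership stays with the parent
            }
            ref = ref->children[c].get();
        }
        ref->is_terminator = true;
    }

    // Length of the longest added word starting at `pos`, or 0 if none.
    size_t find_longest_at(const std::string & text, size_t pos) const {
        const trie_node * ref = root_.get();
        size_t best = 0;
        for (size_t i = pos; i < text.size(); i++) {
            auto it = ref->children.find(text[i]);
            if (it == ref->children.end()) {
                break;
            }
            ref = it->second.get();
            if (ref->is_terminator) {
                best = i - pos + 1;
            }
        }
        return best;
    }

private:
    std::unique_ptr<trie_node> root_{new trie_node()};
};

int main() {
    trie t;
    t.add("[CLS]");
    t.add("[SEP]");

    const std::string text = "[CLS] hello";
    printf("longest match at 0: %zu\n", t.find_longest_at(text, 0)); // prints 5 ("[CLS]")
    printf("longest match at 1: %zu\n", t.find_longest_at(text, 1)); // prints 0
    return 0;
}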
llama.cpp (10 lines changed)

@@ -281,7 +281,7 @@ struct llama_vocab {
 
     llama_trie special_token_trie;
     std::unordered_map<token, id> special_token_to_id;
-    size_t max_special_token_length;
+    size_t max_special_token_length = 0;
 };
 
 struct llama_model {
@@ -578,7 +578,7 @@ struct llama_file_loader {
         vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
 
         for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
-            uint32_t token_id = file.read_u32();
+            llama_vocab::id token_id = file.read_u32();
             const auto & word = vocab.id_to_token[token_id].tok;
 
             vocab.special_token_trie.add(word);
@@ -2108,9 +2108,9 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         return output;
     }
 
-    std::vector<int> offsets = vocab.special_token_trie.split(text);
-    int start = 0;
-    for (int end : offsets) {
+    std::vector<size_t> offsets = vocab.special_token_trie.split(text);
+    size_t start = 0;
+    for (size_t end : offsets) {
         if (start >= end) {
             continue;
         }
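The size_t change in llama_tokenize keeps the boundary arithmetic unsigned end to end. Below is a small standalone example of walking such a boundary list with the same start/end pattern as the hunk at line 2108; the offsets values are assumed for illustration, not produced by the real split().

#include <cstdio>
#include <string>
#include <vector>

int main() {
    const std::string text = "hello [SEP] world";
    // Assume the trie reported boundaries around the special token "[SEP]".
    std::vector<size_t> offsets = {0, 6, 11, text.size()};

    size_t start = 0;
    for (size_t end : offsets) {
        if (start >= end) {
            continue; // zero-length segment, e.g. the leading 0 boundary
        }
        // Segments printed in turn: "hello ", "[SEP]", " world"
        printf("segment: '%s'\n", text.substr(start, end - start).c_str());
        start = end;
    }
    return 0;
}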