rwkv : speed-up tokenization using trie
parent 7f2ef56639
commit 7004323ecd

1 changed file with 35 additions and 33 deletions
@@ -58,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
         }
+
+        return std::make_pair(key, offset);
     }
-    struct naive_trie * traverse(const char c) {
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
         }
+
+        return NULL;
-        } else {
-            return NULL;
-        }
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
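
This hunk flattens the if/else returns and, more importantly, makes traverse callable through a const pointer or reference, which the new RWKV tokenizer relies on. For context, here is a minimal self-contained sketch of the whole naive_trie; the insert body and the exact declaration of the value member are illustrative reconstructions, not verbatim source:

#include <cstddef>
#include <cstdint>
#include <map>

struct naive_trie {
    naive_trie() : has_value(false), value(0) {}

    // walk the key byte-by-byte, creating child nodes as needed;
    // the node reached by the last byte records the token id
    void insert(const char * key, size_t len, int32_t value = 0) {
        if (len == 0) {
            this->has_value = true;
            this->value = value;
            return;
        }
        children[key[0]].insert(key + 1, len - 1, value);
    }

    // follow one edge; NULL means no token continues with this byte
    const struct naive_trie * traverse(const char c) const {
        auto res = children.find(c);
        if (res != children.end()) {
            return &res->second;
        }

        return NULL;
    }

    std::map<char, struct naive_trie> children;
    bool has_value;
    int32_t value;
};
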
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
         // traverse the token matcher trie to find a matching token
         bool single_codepoint_token_found = false;
         const struct best_tokenization & current_best = tokenization_results[input_offset];
-        struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+        const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);

         while (prefix_offset <= input_len && node != NULL) {
             // check if we found valid token in prefix
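
This one-line change just propagates the const-correctness from the previous hunk into the existing UGM tokenizer. A hypothetical caller (not part of the patch) shows why the const qualifier matters:

// lookups now compile on a trie held or passed by const reference
int32_t lookup_one(const struct naive_trie & trie, char c) {
    const struct naive_trie * node = trie.traverse(c); // requires the const-qualified traverse
    return (node != NULL && node->has_value) ? node->value : -1;
}
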
@@ -1103,6 +1103,7 @@ private:

 static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
     std::vector<uint8_t> output;
+    output.reserve(escaped.size());

     // Parser state
     bool escaping = false;
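
The reserve is a small allocation win: unescaping can only shrink the data, so escaped.size() is a safe upper bound and output never reallocates during the parse. As a rough sketch of the kind of decoding involved, assuming token text that escapes raw bytes as \xNN and backslashes as \\ (the real parser in this file keeps more state, e.g. the escaping flag visible in the hunk):

#include <cstdint>
#include <string>
#include <vector>

// illustrative only: decode "\xNN" and "\\" escapes into raw bytes
static std::vector<uint8_t> unescape_sketch(const std::string & escaped) {
    std::vector<uint8_t> output;
    output.reserve(escaped.size()); // upper bound: unescaping never grows the data

    for (size_t i = 0; i < escaped.size(); ++i) {
        if (escaped[i] == '\\' && i + 3 < escaped.size() && escaped[i + 1] == 'x') {
            // two hex digits -> one byte
            output.push_back((uint8_t) std::stoi(escaped.substr(i + 2, 2), nullptr, 16));
            i += 3;
        } else if (escaped[i] == '\\' && i + 1 < escaped.size()) {
            output.push_back((uint8_t) escaped[i + 1]);
            ++i;
        } else {
            output.push_back((uint8_t) escaped[i]);
        }
    }
    return output;
}
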
@@ -1158,9 +1159,12 @@ struct llm_tokenizer_rwkv {
     llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
-        for (const auto & token : vocab.id_to_token) {
-            auto data = llama_unescape_rwkv_token(token.text);
-            tokens.push_back(data);
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
         }
     }

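
The loop switches from iterating token strings to iterating ids because the id is exactly what the trie must store at each token's terminal node. Construction is a one-time pass over the vocab, after which every tokenize call reuses token_matcher. With the sketch trie above, a made-up three-token vocab would be loaded like this:

#include <cstring>

// made-up vocab; ids are the array indices
void build_demo_trie(struct naive_trie & token_matcher) {
    const char * toks[] = { "a", "ab", "abc" };
    for (int32_t id = 0; id < 3; ++id) {
        token_matcher.insert(toks[id], std::strlen(toks[id]), id);
    }
    // resulting paths share nodes; has_value marks token boundaries:
    // root -a-> [id 0] -b-> [id 1] -c-> [id 2]
}
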
@@ -1168,36 +1172,34 @@ struct llm_tokenizer_rwkv {
         uint32_t position = 0;

         while (position < text.size()) {
-            // Iterate through possible tokens backwards, starting with the largest
-            for (int32_t i = (int32_t)tokens.size() - 1; i >= 0; i--) {
-                // Skip tokens that aren't normal type, we can't match on those
-                if (!(vocab.id_to_token[i].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
-                    continue;
-                }
-
-                uint32_t token_size = tokens[i].size();
-
-                // If there's not enough left for this token
-                if (text.size() - position < token_size) {
-                    continue;
-                }
-
-                // If the token doesn't match the data
-                if (std::memcmp(text.data() + position, tokens[i].data(), token_size) != 0) {
-                    continue;
-                }
-
-                // Add the token and advance
-                output.push_back(i);
-                position += token_size;
-                break;
+            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
+            }
+
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
+                }
+                node = node->traverse(text[++position]);
             }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
         }
     }

     const llama_vocab & vocab;

-    std::vector<std::vector<uint8_t>> tokens;
+    struct naive_trie token_matcher;
 };

 //
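
This last hunk is where the speed-up lands. The old loop tried every vocab entry at every text position, roughly O(text length x vocab size) with a memcmp per candidate; the new loop follows at most one trie edge per input byte, so each position costs no more than the length of the longest matching token. Greedy longest-match behavior is preserved because the inner while keeps overwriting token_id/token_length at every node that carries a value. A runnable end-to-end sketch of the same walk, reusing the naive_trie sketch above (vocab, text, and the -1 unknown id standing in for vocab.special_unk_id are made up):

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

int main() {
    // made-up vocab; ids are the array indices
    struct naive_trie token_matcher;
    const char * toks[] = { "a", "ab", "abc" };
    for (int32_t id = 0; id < 3; ++id) {
        token_matcher.insert(toks[id], std::strlen(toks[id]), id);
    }

    const std::string text = "ababc";
    std::vector<int32_t> output;

    uint32_t position = 0;
    while (position < text.size()) {
        const struct naive_trie * node = token_matcher.traverse(text[position]);
        if (node == NULL) {
            // no token starts with this byte -> unknown token
            output.push_back(-1);
            position += 1;
            continue;
        }

        // walk as deep as the input allows, remembering the deepest
        // node that ends a token (greedy longest match); unlike the
        // patch, which relies on every byte being in the vocab,
        // token_length starts at position + 1 so an unvalued path
        // still advances
        int32_t  token_id     = -1;
        uint32_t token_length = position + 1;
        while (node != NULL) {
            if (node->has_value) {
                token_id     = node->value;
                token_length = position + 1;
            }
            // std::string guarantees text[text.size()] == '\0'
            node = node->traverse(text[++position]);
        }

        output.push_back(token_id);
        position = token_length;
    }

    for (int32_t id : output) {
        printf("%d ", id); // prints: 1 2  ("ab" then "abc")
    }
    printf("\n");
}
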