From 7004323ecdd5f4dab77e626ea0e677fcf175542e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 30 Aug 2024 13:19:14 +0300
Subject: [PATCH] rwkv : speed-up tokenization using trie

---
 src/llama-vocab.cpp | 68 +++++++++++++++++++++++----------------------
 1 file changed, 35 insertions(+), 33 deletions(-)

diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 9be52d737..12fbb5971 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -58,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
-        } else {
-            return std::make_pair(key, offset);
         }
+
+        return std::make_pair(key, offset);
     }
-    struct naive_trie * traverse(const char c) {
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
-        } else {
-            return NULL;
         }
+
+        return NULL;
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -1103,6 +1103,7 @@ private:
 
 static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
     std::vector<uint8_t> output;
+    output.reserve(escaped.size());
 
     // Parser state
     bool escaping = false;
@@ -1158,9 +1159,12 @@ struct llm_tokenizer_rwkv {
     llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
-        for (const auto & token : vocab.id_to_token) {
-            auto data = llama_unescape_rwkv_token(token.text);
-            tokens.push_back(data);
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
         }
     }
 
@@ -1168,36 +1172,34 @@ struct llm_tokenizer_rwkv {
         uint32_t position = 0;
 
         while (position < text.size()) {
-            // Iterate through possible tokens backwards, starting with the largest
-            for (int32_t i = (int32_t)tokens.size() - 1; i >= 0; i--) {
-                // Skip tokens that aren't normal type, we can't match on those
-                if (!(vocab.id_to_token[i].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
-                    continue;
-                }
-
-                uint32_t token_size = tokens[i].size();
-
-                // If there's not enough left for this token
-                if (text.size() - position < token_size) {
-                    continue;
-                }
-
-                // If the token doesn't match the data
-                if (std::memcmp(text.data() + position, tokens[i].data(), token_size) != 0) {
-                    continue;
-                }
-
-                // Add the token and advance
-                output.push_back(i);
-                position += token_size;
-                break;
+            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
             }
+
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
+                }
+                node = node->traverse(text[++position]);
+            }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
         }
     }
 
     const llama_vocab & vocab;
-    std::vector<std::vector<uint8_t>> tokens;
+    struct naive_trie token_matcher;
 };
 
 //
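
For illustration, below is a self-contained sketch of the greedy longest-match tokenization this patch implements with naive_trie. The names trie_node, tokenize_greedy, and UNK_ID are simplified stand-ins rather than the llama.cpp API, and the inner loop is restructured slightly so a trie path that yields no stored token still advances the cursor:

    // Minimal sketch of trie-based greedy longest-match tokenization,
    // mirroring the logic of naive_trie + llm_tokenizer_rwkv::tokenize.
    // trie_node, tokenize_greedy, and UNK_ID are hypothetical stand-ins.
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    static const int32_t UNK_ID = 0; // stand-in for vocab.special_unk_id

    struct trie_node {
        std::map<char, trie_node> children;
        bool    has_value = false;
        int32_t value     = 0;

        void insert(const std::string & key, int32_t id) {
            trie_node * node = this;
            for (char c : key) {
                node = &node->children[c]; // creates the child if missing
            }
            node->has_value = true;
            node->value     = id;
        }

        const trie_node * traverse(char c) const {
            auto it = children.find(c);
            return it == children.end() ? nullptr : &it->second;
        }
    };

    // Scan the input once: at each position, walk the trie as far as it
    // matches and emit the id of the longest token seen along the way.
    static std::vector<int32_t> tokenize_greedy(const std::string & text, const trie_node & trie) {
        std::vector<int32_t> out;
        uint32_t pos = 0;
        while (pos < text.size()) {
            const trie_node * node = trie.traverse(text[pos]);
            if (node == nullptr) {
                out.push_back(UNK_ID); // no token starts with this byte
                pos += 1;
                continue;
            }
            int32_t  best_id  = UNK_ID;
            uint32_t best_end = pos + 1; // always advance, even without a match
            while (node != nullptr) {
                if (node->has_value) {
                    best_id  = node->value;
                    best_end = pos + 1;
                }
                ++pos;
                node = pos < text.size() ? node->traverse(text[pos]) : nullptr;
            }
            out.push_back(best_id);  // longest token seen on this path
            pos = best_end;          // backtrack to just past that token
        }
        return out;
    }

    int main() {
        trie_node trie;
        trie.insert("a",   1);
        trie.insert("ab",  2);
        trie.insert("abc", 3);
        for (int32_t id : tokenize_greedy("abcabx", trie)) {
            std::cout << id << ' '; // prints: 3 2 0
        }
        std::cout << '\n';
    }

The trie walk from any position is bounded by the length of the longest token, so each position costs O(max token length) instead of the old loop's scan over the entire vocab. The patch itself initializes token_length to 0 and relies on the RWKV vocab containing a token for every single byte, so once a first byte matches, some token value is always found along the path; the sketch's best_end = pos + 1 default makes the same loop safe for arbitrary vocabs.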