From dc0767f4b3eaf23eae4a44248ac01b1697f64334 Mon Sep 17 00:00:00 2001 From: Layl Bongers <3094382+LaylBongers@users.noreply.github.com> Date: Thu, 4 Apr 2024 13:59:37 +0200 Subject: [PATCH] Add RWKV tokenization --- include/llama.h | 1 + src/llama-vocab.cpp | 126 ++++++++++++++++++++++++++++++++++++++++++++ src/llama.cpp | 20 +++++++ 3 files changed, 147 insertions(+) diff --git a/include/llama.h b/include/llama.h index 6cca6320b..53ff6e535 100644 --- a/include/llama.h +++ b/include/llama.h @@ -66,6 +66,7 @@ extern "C" { LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization }; // pre-tokenization types diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 323660ef5..30db1a042 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1097,6 +1097,104 @@ private: struct naive_trie token_matcher; }; +// +// RWKV tokenizer +// + +static std::vector llama_unescape_rwkv_token(const std::string & escaped) { + std::vector output; + + // Parser state + bool escaping = false; + uint8_t hex_remaining = 0; + uint8_t hex_acc = 0; + + // Step through characters, performing parsing + for (const char & c : escaped) { + // If we're parsing a hex code, interpret the next character + if (hex_remaining != 0) { + uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0'); + hex_acc = (hex_acc << 4) + value; + + hex_remaining -= 1; + if (hex_remaining == 0) { + output.push_back(hex_acc); + hex_acc = 0; + } + + continue; + } + + // If we got an escape character, interpret it + if (escaping) { + if (c == 't') { + output.push_back('\t'); + } else if (c == 'n') { + output.push_back('\n'); + } else if (c == 'r') { + output.push_back('\r'); + } else if (c == 'x') { + hex_remaining = 2; + } else { + output.push_back(c); + } + + escaping = false; + continue; + } + + if (c == '\\') { + escaping = true; + continue; + } + + output.push_back(c); + } + + return output; +} + +struct llm_tokenizer_rwkv { + llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) { + // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. + // For now, we decode the vocab here into the lookup we'll use for tokenization. + for (const auto & token : vocab.id_to_token) { + auto data = llama_unescape_rwkv_token(token.text); + tokens.push_back(data); + } + } + + void tokenize(const std::string & text, std::vector & output) { + uint32_t position = 0; + + while (position < text.size()) { + // Iterate through possible tokens backwards, starting with the largest + for (int32_t i = (int32_t)tokens.size() - 1; i >= 0; i--) { + uint32_t token_size = tokens[i].size(); + + // If there's not enough left for this token + if (text.size() - position < token_size) { + continue; + } + + // If the token doesn't match the data + if (std::memcmp(text.data() + position, tokens[i].data(), token_size) != 0) { + continue; + } + + // Add the token and advance + output.push_back(i); + position += token_size; + break; + } + } + } + + const llama_vocab & vocab; + + std::vector> tokens; +}; + // // (de-) tokenize // @@ -1401,6 +1499,23 @@ std::vector llama_tokenize_internal(const llama_vocab & vocab, output.push_back(vocab.special_eos_id); } } break; + case LLAMA_VOCAB_TYPE_RWKV: + { + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + + llm_tokenizer_rwkv tokenizer(vocab); + tokenizer.tokenize(raw_text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + } break; case LLAMA_VOCAB_TYPE_NONE: GGML_ABORT("fatal error"); } @@ -1616,6 +1731,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token } break; } + case LLAMA_VOCAB_TYPE_RWKV: { + std::vector result = llama_unescape_rwkv_token(token_text); + + // If we don't have enough space, return an error + if (result.size() > (size_t)length) { + return -(int)result.size(); + } + + memcpy(buf, result.data(), result.size()); + return (int)result.size(); + } default: GGML_ABORT("fatal error"); } diff --git a/src/llama.cpp b/src/llama.cpp index 8d5f24783..799c6059d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -212,6 +212,7 @@ enum llm_arch { LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, LLM_ARCH_EXAONE, + LLM_ARCH_RWKV, LLM_ARCH_UNKNOWN, }; @@ -259,6 +260,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV, "rwkv" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1339,6 +1341,12 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_RWKV, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -5919,6 +5927,16 @@ static void llm_load_vocab( } #endif } + } else if (tokenizer_name == "rwkv") { + vocab.type = LLAMA_VOCAB_TYPE_RWKV; + + // default special tokens + vocab.special_bos_id = 0; + vocab.special_eos_id = 0; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.add_space_prefix = false; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -17955,6 +17973,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: case LLM_ARCH_JAIS: + case LLM_ARCH_RWKV: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -18123,6 +18142,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) { bool llama_model_is_recurrent(const struct llama_model * model) { switch (model->arch) { case LLM_ARCH_MAMBA: return true; + case LLM_ARCH_RWKV: return true; default: return false; } }