Add RWKV tokenization

Layl Bongers 2024-04-04 13:59:37 +02:00 committed by Molly Sophia
parent 8d2eca3507
commit dc0767f4b3
3 changed files with 147 additions and 0 deletions


@@ -66,6 +66,7 @@ extern "C" {
         LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
         LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
         LLAMA_VOCAB_TYPE_UGM  = 4, // T5 tokenizer based on Unigram
+        LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
     };

     // pre-tokenization types


@@ -1097,6 +1097,104 @@ private:
     struct naive_trie token_matcher;
 };

+//
+// RWKV tokenizer
+//
+
+static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
+    std::vector<uint8_t> output;
+
+    // Parser state
+    bool escaping = false;
+    uint8_t hex_remaining = 0;
+    uint8_t hex_acc = 0;
+
+    // Step through characters, performing parsing
+    for (const char & c : escaped) {
+        // If we're parsing a hex code, interpret the next character
+        if (hex_remaining != 0) {
+            uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
+            hex_acc = (hex_acc << 4) + value;
+
+            hex_remaining -= 1;
+            if (hex_remaining == 0) {
+                output.push_back(hex_acc);
+                hex_acc = 0;
+            }
+
+            continue;
+        }
+
+        // If we got an escape character, interpret it
+        if (escaping) {
+            if (c == 't') {
+                output.push_back('\t');
+            } else if (c == 'n') {
+                output.push_back('\n');
+            } else if (c == 'r') {
+                output.push_back('\r');
+            } else if (c == 'x') {
+                hex_remaining = 2;
+            } else {
+                output.push_back(c);
+            }
+
+            escaping = false;
+            continue;
+        }
+
+        if (c == '\\') {
+            escaping = true;
+            continue;
+        }
+
+        output.push_back(c);
+    }
+
+    return output;
+}
+
+struct llm_tokenizer_rwkv {
+    llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+        // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
+        // For now, we decode the vocab here into the lookup we'll use for tokenization.
+        for (const auto & token : vocab.id_to_token) {
+            auto data = llama_unescape_rwkv_token(token.text);
+            tokens.push_back(data);
+        }
+    }
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        uint32_t position = 0;
+
+        while (position < text.size()) {
+            // Iterate through possible tokens backwards, starting with the largest
+            for (int32_t i = (int32_t)tokens.size() - 1; i >= 0; i--) {
+                uint32_t token_size = tokens[i].size();
+
+                // If there's not enough left for this token
+                if (text.size() - position < token_size) {
+                    continue;
+                }
+
+                // If the token doesn't match the data
+                if (std::memcmp(text.data() + position, tokens[i].data(), token_size) != 0) {
+                    continue;
+                }
+
+                // Add the token and advance
+                output.push_back(i);
+                position += token_size;
+                break;
+            }
+        }
+    }
+
+    const llama_vocab & vocab;
+    std::vector<std::vector<uint8_t>> tokens;
+};
+
 //
 // (de-) tokenize
 //
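As a quick illustration of the escape format handled by llama_unescape_rwkv_token above (\t, \n, \r, \xHH with lowercase hex digits, and a backslash passing any other character through), here is a minimal sketch of the parser's output. It is not part of the commit: it assumes the function is compiled into the same translation unit, and the input string is made up.

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

int main() {
    // The escaped vocab entry "ab\nc\x41" (9 characters on disk;
    // the backslashes below are doubled only for C++ source).
    std::string escaped = "ab\\nc\\x41";

    std::vector<uint8_t> bytes = llama_unescape_rwkv_token(escaped);
    for (uint8_t b : bytes) {
        printf("%02x ", b); // prints: 61 62 0a 63 41
    }
    printf("\n");
    return 0;
}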
@@ -1401,6 +1499,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+
+                        llm_tokenizer_rwkv tokenizer(vocab);
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
         case LLAMA_VOCAB_TYPE_NONE:
             GGML_ABORT("fatal error");
     }
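The new case above hands each raw-text fragment to llm_tokenizer_rwkv::tokenize, a plain greedy longest-match scan: at every position it tries vocab entries from the highest id down and, since the constructor's comment implies higher ids are the larger tokens, the first match is the longest one. Forward progress relies on the vocab containing a token for every single byte, so a match always exists. A self-contained toy version over a made-up three-entry vocab (not the real RWKV vocab) shows the behavior; note that rescanning the whole vocab at every position is O(vocab size × text length), which is why a trie, like the naive_trie the UGM tokenizer uses above, would be the natural optimization.

#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

int main() {
    // Toy vocab ordered so that longer entries have higher ids,
    // mirroring the assumption llm_tokenizer_rwkv makes.
    std::vector<std::string> vocab = { "a", "b", "ab" };
    std::string text = "aab";

    size_t position = 0;
    while (position < text.size()) {
        // Try candidates from the largest id (longest token) down.
        for (int i = (int)vocab.size() - 1; i >= 0; i--) {
            size_t n = vocab[i].size();
            if (text.size() - position < n) continue;
            if (std::memcmp(text.data() + position, vocab[i].data(), n) != 0) continue;
            printf("id %d -> '%s'\n", i, vocab[i].c_str());
            position += n;
            break;
        }
    }
    // prints: id 0 -> 'a', then id 2 -> 'ab' (greedy longest match)
    return 0;
}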
@@ -1616,6 +1731,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
             }
             break;
         }
+        case LLAMA_VOCAB_TYPE_RWKV: {
+            std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+            // If we don't have enough space, return an error
+            if (result.size() > (size_t)length) {
+                return -(int)result.size();
+            }
+
+            memcpy(buf, result.data(), result.size());
+            return (int)result.size();
+        }
         default:
             GGML_ABORT("fatal error");
     }
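The RWKV branch follows the convention already used by the other vocab types in this function: return the piece length on success, or the negated required size when the caller's buffer is too small. That enables a two-pass calling pattern. A hedged sketch, assuming a (vocab, token, buf, length) shape for the call; the real signature is clipped in the hunk header above and may carry further parameters:

// Probe with a small buffer, resize on a negative return, retry.
std::vector<char> buf(8);
int32_t n = llama_token_to_piece_impl(vocab, token, buf.data(), (int32_t)buf.size());
if (n < 0) {
    buf.resize((size_t)-n); // the magnitude is the required size
    n = llama_token_to_piece_impl(vocab, token, buf.data(), (int32_t)buf.size());
}
std::string piece(buf.data(), (size_t)n);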


@@ -212,6 +212,7 @@ enum llm_arch {
     LLM_ARCH_JAIS,
     LLM_ARCH_NEMOTRON,
     LLM_ARCH_EXAONE,
+    LLM_ARCH_RWKV,
     LLM_ARCH_UNKNOWN,
 };
@@ -259,6 +260,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_JAIS,            "jais"         },
     { LLM_ARCH_NEMOTRON,        "nemotron"     },
     { LLM_ARCH_EXAONE,          "exaone"       },
+    { LLM_ARCH_RWKV,            "rwkv"         },
     { LLM_ARCH_UNKNOWN,         "(unknown)"    },
 };
@@ -1339,6 +1341,12 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP,   "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_RWKV,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -5919,6 +5927,16 @@ static void llm_load_vocab(
                 }
 #endif
             }
+        } else if (tokenizer_model == "rwkv") {
+            vocab.type = LLAMA_VOCAB_TYPE_RWKV;
+
+            // default special tokens
+            vocab.special_bos_id = 0;
+            vocab.special_eos_id = 0;
+            vocab.special_unk_id = -1;
+            vocab.special_sep_id = -1;
+            vocab.special_pad_id = -1;
+            vocab.add_space_prefix = false;
         } else {
             throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
         }
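The tokenizer_model string compared against "rwkv" is read from the model's GGUF metadata, under the standard tokenizer.ggml.model key. A minimal sketch of inspecting that key with the public gguf API (header location and the exact return type of gguf_find_key vary between ggml versions; model.gguf is a placeholder path):

#include <cstdio>
#include "ggml.h" // gguf_* API

int main() {
    struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file("model.gguf", params);
    if (!ctx) return 1;

    int key = gguf_find_key(ctx, "tokenizer.ggml.model");
    if (key >= 0) {
        printf("tokenizer: %s\n", gguf_get_val_str(ctx, key)); // "rwkv" for these models
    }

    gguf_free(ctx);
    return 0;
}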
@@ -17955,6 +17973,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5:
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
+        case LLM_ARCH_RWKV:
             return LLAMA_ROPE_TYPE_NONE;

         // use what we call a normal RoPE, operating on pairs of consecutive head values
@@ -18123,6 +18142,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
 bool llama_model_is_recurrent(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_MAMBA: return true;
+        case LLM_ARCH_RWKV:  return true;
         default: return false;
     }
 }
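Taken together, the new pieces give a lossless byte-level round trip: greedy matching consumes exactly the bytes of each emitted token, and the escaped vocab text decodes back to those same bytes. A sketch of that property, assuming a populated llama_vocab with type LLAMA_VOCAB_TYPE_RWKV and the internal helpers above in scope (this test does not exist in the commit):

// Tokenize, then rebuild the input from the unescaped vocab entries.
std::string text = "Hello, RWKV!";

llm_tokenizer_rwkv tokenizer(vocab);
std::vector<llama_vocab::id> ids;
tokenizer.tokenize(text, ids);

std::string out;
for (llama_vocab::id id : ids) {
    // id_to_token stores the escaped form; unescape to raw bytes
    std::vector<uint8_t> bytes = llama_unescape_rwkv_token(vocab.id_to_token[id].text);
    out.append(bytes.begin(), bytes.end());
}

assert(out == text); // holds as long as the vocab covers every byte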