Add RWKV tokenization
This commit is contained in:
parent
8d2eca3507
commit
dc0767f4b3
3 changed files with 147 additions and 0 deletions
|
@ -66,6 +66,7 @@ extern "C" {
|
|||
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
||||
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
||||
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
|
||||
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
|
||||
};
|
||||
|
||||
// pre-tokenization types
|
||||
|
|
|
@ -1097,6 +1097,104 @@ private:
|
|||
struct naive_trie token_matcher;
|
||||
};
|
||||
|
||||
//
|
||||
// RWKV tokenizer
|
||||
//
|
||||
|
||||
static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
|
||||
std::vector<uint8_t> output;
|
||||
|
||||
// Parser state
|
||||
bool escaping = false;
|
||||
uint8_t hex_remaining = 0;
|
||||
uint8_t hex_acc = 0;
|
||||
|
||||
// Step through characters, performing parsing
|
||||
for (const char & c : escaped) {
|
||||
// If we're parsing a hex code, interpret the next character
|
||||
if (hex_remaining != 0) {
|
||||
uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
|
||||
hex_acc = (hex_acc << 4) + value;
|
||||
|
||||
hex_remaining -= 1;
|
||||
if (hex_remaining == 0) {
|
||||
output.push_back(hex_acc);
|
||||
hex_acc = 0;
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// If we got an escape character, interpret it
|
||||
if (escaping) {
|
||||
if (c == 't') {
|
||||
output.push_back('\t');
|
||||
} else if (c == 'n') {
|
||||
output.push_back('\n');
|
||||
} else if (c == 'r') {
|
||||
output.push_back('\r');
|
||||
} else if (c == 'x') {
|
||||
hex_remaining = 2;
|
||||
} else {
|
||||
output.push_back(c);
|
||||
}
|
||||
|
||||
escaping = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (c == '\\') {
|
||||
escaping = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
output.push_back(c);
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
struct llm_tokenizer_rwkv {
|
||||
llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
|
||||
// RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
|
||||
// For now, we decode the vocab here into the lookup we'll use for tokenization.
|
||||
for (const auto & token : vocab.id_to_token) {
|
||||
auto data = llama_unescape_rwkv_token(token.text);
|
||||
tokens.push_back(data);
|
||||
}
|
||||
}
|
||||
|
||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||
uint32_t position = 0;
|
||||
|
||||
while (position < text.size()) {
|
||||
// Iterate through possible tokens backwards, starting with the largest
|
||||
for (int32_t i = (int32_t)tokens.size() - 1; i >= 0; i--) {
|
||||
uint32_t token_size = tokens[i].size();
|
||||
|
||||
// If there's not enough left for this token
|
||||
if (text.size() - position < token_size) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the token doesn't match the data
|
||||
if (std::memcmp(text.data() + position, tokens[i].data(), token_size) != 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add the token and advance
|
||||
output.push_back(i);
|
||||
position += token_size;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const llama_vocab & vocab;
|
||||
|
||||
std::vector<std::vector<uint8_t>> tokens;
|
||||
};
|
||||
|
||||
//
|
||||
// (de-) tokenize
|
||||
//
|
||||
|
@ -1401,6 +1499,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
|||
output.push_back(vocab.special_eos_id);
|
||||
}
|
||||
} break;
|
||||
case LLAMA_VOCAB_TYPE_RWKV:
|
||||
{
|
||||
for (const auto & fragment : fragment_buffer) {
|
||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||
|
||||
#ifdef PRETOKENIZERDEBUG
|
||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
||||
#endif
|
||||
|
||||
llm_tokenizer_rwkv tokenizer(vocab);
|
||||
tokenizer.tokenize(raw_text, output);
|
||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||
output.push_back(fragment.token);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLAMA_VOCAB_TYPE_NONE:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
@ -1616,6 +1731,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
|
|||
}
|
||||
break;
|
||||
}
|
||||
case LLAMA_VOCAB_TYPE_RWKV: {
|
||||
std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
|
||||
|
||||
// If we don't have enough space, return an error
|
||||
if (result.size() > (size_t)length) {
|
||||
return -(int)result.size();
|
||||
}
|
||||
|
||||
memcpy(buf, result.data(), result.size());
|
||||
return (int)result.size();
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
|
@ -212,6 +212,7 @@ enum llm_arch {
|
|||
LLM_ARCH_JAIS,
|
||||
LLM_ARCH_NEMOTRON,
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_RWKV,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@ -259,6 +260,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_JAIS, "jais" },
|
||||
{ LLM_ARCH_NEMOTRON, "nemotron" },
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_RWKV, "rwkv" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
@ -1339,6 +1341,12 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
|||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_RWKV,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
|
@ -5919,6 +5927,16 @@ static void llm_load_vocab(
|
|||
}
|
||||
#endif
|
||||
}
|
||||
} else if (tokenizer_name == "rwkv") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_RWKV;
|
||||
|
||||
// default special tokens
|
||||
vocab.special_bos_id = 0;
|
||||
vocab.special_eos_id = 0;
|
||||
vocab.special_unk_id = -1;
|
||||
vocab.special_sep_id = -1;
|
||||
vocab.special_pad_id = -1;
|
||||
vocab.add_space_prefix = false;
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
|
||||
}
|
||||
|
@ -17955,6 +17973,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
|||
case LLM_ARCH_T5:
|
||||
case LLM_ARCH_T5ENCODER:
|
||||
case LLM_ARCH_JAIS:
|
||||
case LLM_ARCH_RWKV:
|
||||
return LLAMA_ROPE_TYPE_NONE;
|
||||
|
||||
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
||||
|
@ -18123,6 +18142,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
|
|||
bool llama_model_is_recurrent(const struct llama_model * model) {
|
||||
switch (model->arch) {
|
||||
case LLM_ARCH_MAMBA: return true;
|
||||
case LLM_ARCH_RWKV: return true;
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue