Add RWKV tokenization
This commit is contained in:
parent
8d2eca3507
commit
dc0767f4b3
3 changed files with 147 additions and 0 deletions
|
@ -66,6 +66,7 @@ extern "C" {
|
||||||
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
|
||||||
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
|
||||||
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
|
LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
|
||||||
|
LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
|
||||||
};
|
};
|
||||||
|
|
||||||
// pre-tokenization types
|
// pre-tokenization types
|
||||||
|
|
|
@ -1097,6 +1097,104 @@ private:
|
||||||
struct naive_trie token_matcher;
|
struct naive_trie token_matcher;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// RWKV tokenizer
|
||||||
|
//
|
||||||
|
|
||||||
|
static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
|
||||||
|
std::vector<uint8_t> output;
|
||||||
|
|
||||||
|
// Parser state
|
||||||
|
bool escaping = false;
|
||||||
|
uint8_t hex_remaining = 0;
|
||||||
|
uint8_t hex_acc = 0;
|
||||||
|
|
||||||
|
// Step through characters, performing parsing
|
||||||
|
for (const char & c : escaped) {
|
||||||
|
// If we're parsing a hex code, interpret the next character
|
||||||
|
if (hex_remaining != 0) {
|
||||||
|
uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
|
||||||
|
hex_acc = (hex_acc << 4) + value;
|
||||||
|
|
||||||
|
hex_remaining -= 1;
|
||||||
|
if (hex_remaining == 0) {
|
||||||
|
output.push_back(hex_acc);
|
||||||
|
hex_acc = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we got an escape character, interpret it
|
||||||
|
if (escaping) {
|
||||||
|
if (c == 't') {
|
||||||
|
output.push_back('\t');
|
||||||
|
} else if (c == 'n') {
|
||||||
|
output.push_back('\n');
|
||||||
|
} else if (c == 'r') {
|
||||||
|
output.push_back('\r');
|
||||||
|
} else if (c == 'x') {
|
||||||
|
hex_remaining = 2;
|
||||||
|
} else {
|
||||||
|
output.push_back(c);
|
||||||
|
}
|
||||||
|
|
||||||
|
escaping = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (c == '\\') {
|
||||||
|
escaping = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
output.push_back(c);
|
||||||
|
}
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llm_tokenizer_rwkv {
|
||||||
|
llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
|
||||||
|
// RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
|
||||||
|
// For now, we decode the vocab here into the lookup we'll use for tokenization.
|
||||||
|
for (const auto & token : vocab.id_to_token) {
|
||||||
|
auto data = llama_unescape_rwkv_token(token.text);
|
||||||
|
tokens.push_back(data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||||
|
uint32_t position = 0;
|
||||||
|
|
||||||
|
while (position < text.size()) {
|
||||||
|
// Iterate through possible tokens backwards, starting with the largest
|
||||||
|
for (int32_t i = (int32_t)tokens.size() - 1; i >= 0; i--) {
|
||||||
|
uint32_t token_size = tokens[i].size();
|
||||||
|
|
||||||
|
// If there's not enough left for this token
|
||||||
|
if (text.size() - position < token_size) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the token doesn't match the data
|
||||||
|
if (std::memcmp(text.data() + position, tokens[i].data(), token_size) != 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the token and advance
|
||||||
|
output.push_back(i);
|
||||||
|
position += token_size;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_vocab & vocab;
|
||||||
|
|
||||||
|
std::vector<std::vector<uint8_t>> tokens;
|
||||||
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
// (de-) tokenize
|
// (de-) tokenize
|
||||||
//
|
//
|
||||||
|
@ -1401,6 +1499,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
||||||
output.push_back(vocab.special_eos_id);
|
output.push_back(vocab.special_eos_id);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLAMA_VOCAB_TYPE_RWKV:
|
||||||
|
{
|
||||||
|
for (const auto & fragment : fragment_buffer) {
|
||||||
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
|
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||||
|
|
||||||
|
#ifdef PRETOKENIZERDEBUG
|
||||||
|
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
||||||
|
#endif
|
||||||
|
|
||||||
|
llm_tokenizer_rwkv tokenizer(vocab);
|
||||||
|
tokenizer.tokenize(raw_text, output);
|
||||||
|
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||||
|
output.push_back(fragment.token);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_NONE:
|
case LLAMA_VOCAB_TYPE_NONE:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
@ -1616,6 +1731,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case LLAMA_VOCAB_TYPE_RWKV: {
|
||||||
|
std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
|
||||||
|
|
||||||
|
// If we don't have enough space, return an error
|
||||||
|
if (result.size() > (size_t)length) {
|
||||||
|
return -(int)result.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
memcpy(buf, result.data(), result.size());
|
||||||
|
return (int)result.size();
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
|
@ -212,6 +212,7 @@ enum llm_arch {
|
||||||
LLM_ARCH_JAIS,
|
LLM_ARCH_JAIS,
|
||||||
LLM_ARCH_NEMOTRON,
|
LLM_ARCH_NEMOTRON,
|
||||||
LLM_ARCH_EXAONE,
|
LLM_ARCH_EXAONE,
|
||||||
|
LLM_ARCH_RWKV,
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -259,6 +260,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
{ LLM_ARCH_JAIS, "jais" },
|
{ LLM_ARCH_JAIS, "jais" },
|
||||||
{ LLM_ARCH_NEMOTRON, "nemotron" },
|
{ LLM_ARCH_NEMOTRON, "nemotron" },
|
||||||
{ LLM_ARCH_EXAONE, "exaone" },
|
{ LLM_ARCH_EXAONE, "exaone" },
|
||||||
|
{ LLM_ARCH_RWKV, "rwkv" },
|
||||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1339,6 +1341,12 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_RWKV,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
{
|
{
|
||||||
|
@ -5919,6 +5927,16 @@ static void llm_load_vocab(
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
} else if (tokenizer_name == "rwkv") {
|
||||||
|
vocab.type = LLAMA_VOCAB_TYPE_RWKV;
|
||||||
|
|
||||||
|
// default special tokens
|
||||||
|
vocab.special_bos_id = 0;
|
||||||
|
vocab.special_eos_id = 0;
|
||||||
|
vocab.special_unk_id = -1;
|
||||||
|
vocab.special_sep_id = -1;
|
||||||
|
vocab.special_pad_id = -1;
|
||||||
|
vocab.add_space_prefix = false;
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
|
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
|
||||||
}
|
}
|
||||||
|
@ -17955,6 +17973,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||||
case LLM_ARCH_T5:
|
case LLM_ARCH_T5:
|
||||||
case LLM_ARCH_T5ENCODER:
|
case LLM_ARCH_T5ENCODER:
|
||||||
case LLM_ARCH_JAIS:
|
case LLM_ARCH_JAIS:
|
||||||
|
case LLM_ARCH_RWKV:
|
||||||
return LLAMA_ROPE_TYPE_NONE;
|
return LLAMA_ROPE_TYPE_NONE;
|
||||||
|
|
||||||
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
||||||
|
@ -18123,6 +18142,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) {
|
||||||
bool llama_model_is_recurrent(const struct llama_model * model) {
|
bool llama_model_is_recurrent(const struct llama_model * model) {
|
||||||
switch (model->arch) {
|
switch (model->arch) {
|
||||||
case LLM_ARCH_MAMBA: return true;
|
case LLM_ARCH_MAMBA: return true;
|
||||||
|
case LLM_ARCH_RWKV: return true;
|
||||||
default: return false;
|
default: return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue