llama : support RWKV v6 models (#8980)
* convert_hf_to_gguf: Add support for RWKV v6 Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Add RWKV tokenization
* Fix build Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Do not use special tokens when matching in RWKV tokenizer
* Fix model loading
* Add (broken) placeholder graph builder for RWKV
* Add workaround for kv cache
* Add logits conversion to rwkv5
* Add rwkv5 layer norms
* Add time mix KVRG & correct merge mistake
* Add remaining time mix parameters
* Add time mix output loading
* Add placeholder llm_build_time_mix
* Fix build Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Load more tensors for rwkv v6 Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Fix rwkv tokenizer Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* ggml: Add unary operator Exp Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* RWKV v6 graph building Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Add ``rescale_every_n_layers`` parameter Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Add ``wkv.head_size`` key for RWKV so it doesn't reuse Mamba ssm parameters Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Fix offloading layers to CUDA Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Fix parallel inferencing for RWKV Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Remove trailing whitespaces Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* build_rwkv: Avoid using inplace operations Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* convert_hf_to_gguf: rwkv: Avoid using ``eval`` Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Update convert_hf_to_gguf.py Co-authored-by: compilade <git@compilade.net>
* ggml: Add backward computation for unary op ``exp`` Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Update convert_hf_to_gguf.py Co-authored-by: compilade <git@compilade.net>
* Update convert_hf_to_gguf.py Co-authored-by: compilade <git@compilade.net>
* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* build_rwkv6: Simplify graph Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Detect model.type Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Fix tensor loading for 7B/14B models Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Fix group_norm assertion failure with Metal Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Clean up Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Add quantization tensor exclusion Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Use the new advanced batch splits Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Update src/llama.cpp Co-authored-by: compilade <git@compilade.net>
* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm`` Co-authored-by: compilade <git@compilade.net>
* llama: rwkv6: Apply code style and misc changes Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* converter: Use class name ``Rwkv6Model`` Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Make use of key ``feed_forward_length`` Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim`` Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32 Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Remove unused nodes Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Apply code format changes Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* llama: rwkv6: Add lora for some supported tensors (currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight) Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* rwkv : speed-up tokenization using trie
* minor : style + indentation
* llama: rwkv6: Avoid division by zero Co-authored-by: compilade <git@compilade.net>
* ggml: rwkv_wkv: Avoid copying the state Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Layl Bongers <3094382+LaylBongers@users.noreply.github.com>
Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
parent a47667cff4
commit 8f1d81a0b6
9 changed files with 1266 additions and 103 deletions
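For context on how the new tokenizer is reached from user code: a GGUF converted from an RWKV v6 checkpoint is tokenized through the regular llama.h C API. The sketch below is illustrative only and is not part of this commit; the model path is a placeholder and it assumes the llama_tokenize signature in llama.h at the time of this change (a negative return value means the token buffer was too small).

// Illustrative only, not part of this commit: tokenize a string with a
// converted RWKV v6 GGUF through the public llama.h API. The model path is a
// placeholder; a negative return from llama_tokenize is the required count.
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("rwkv-v6.gguf", mparams); // placeholder path
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    const std::string text = "Hello world";

    std::vector<llama_token> tokens(text.size() + 8); // generous first guess
    int32_t n = llama_tokenize(model, text.c_str(), (int32_t) text.size(),
                               tokens.data(), (int32_t) tokens.size(),
                               /*add_special*/ false, /*parse_special*/ false);
    if (n < 0) {
        tokens.resize(-n); // buffer was too small; -n is the required count
        n = llama_tokenize(model, text.c_str(), (int32_t) text.size(),
                           tokens.data(), (int32_t) tokens.size(), false, false);
    }
    tokens.resize(n > 0 ? n : 0);

    for (llama_token t : tokens) {
        printf("%d ", t);
    }
    printf("\n");

    llama_free_model(model);
    llama_backend_free();
    return 0;
}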
@@ -58,17 +58,17 @@ struct naive_trie {
         auto res = children.find(c);
         if (res != children.end()) {
             return res->second.get_longest_prefix(key, len, offset + 1);
         } else {
             return std::make_pair(key, offset);
         }
 
         return std::make_pair(key, offset);
     }
-    struct naive_trie * traverse(const char c) {
+    const struct naive_trie * traverse(const char c) const {
         auto res = children.find(c);
         if (res != children.end()) {
             return &res->second;
         } else {
             return NULL;
         }
 
         return NULL;
     }
     std::map<char, struct naive_trie> children;
     bool has_value;
@@ -843,7 +843,7 @@ struct llm_tokenizer_ugm {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -1097,6 +1097,111 @@ private:
     struct naive_trie token_matcher;
 };
 
+//
+// RWKV tokenizer
+//
+
+static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escaped) {
+    std::vector<uint8_t> output;
+    output.reserve(escaped.size());
+
+    // Parser state
+    bool escaping = false;
+    uint8_t hex_remaining = 0;
+    uint8_t hex_acc = 0;
+
+    // Step through characters, performing parsing
+    for (const char & c : escaped) {
+        // If we're parsing a hex code, interpret the next character
+        if (hex_remaining != 0) {
+            uint8_t value = (c >= 'a') ? (c - 'a' + 10) : (c - '0');
+            hex_acc = (hex_acc << 4) + value;
+
+            hex_remaining -= 1;
+            if (hex_remaining == 0) {
+                output.push_back(hex_acc);
+                hex_acc = 0;
+            }
+
+            continue;
+        }
+
+        // If we got an escape character, interpret it
+        if (escaping) {
+            if (c == 't') {
+                output.push_back('\t');
+            } else if (c == 'n') {
+                output.push_back('\n');
+            } else if (c == 'r') {
+                output.push_back('\r');
+            } else if (c == 'x') {
+                hex_remaining = 2;
+            } else {
+                output.push_back(c);
+            }
+
+            escaping = false;
+            continue;
+        }
+
+        if (c == '\\') {
+            escaping = true;
+            continue;
+        }
+
+        output.push_back(c);
+    }
+
+    return output;
+}
+
+struct llm_tokenizer_rwkv {
+    llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+        // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
+        // For now, we decode the vocab here into the lookup we'll use for tokenization.
+
+        // build trie
+        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
+            const auto & token = vocab.id_to_token[id];
+            const auto data = llama_unescape_rwkv_token(token.text);
+            token_matcher.insert((const char *) data.data(), data.size(), id);
+        }
+    }
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        uint32_t position = 0;
+
+        while (position < text.size()) {
+            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            if (node == NULL) {
+                // no matching token found, add unknown token
+                output.push_back(vocab.special_unk_id);
+                position += 1;
+                continue;
+            }
+
+            // traverse the trie to find the longest matching token
+            uint32_t token_id = 0;
+            uint32_t token_length = 0;
+            while (node != NULL) {
+                if (node->has_value) {
+                    token_id = node->value;
+                    token_length = position + 1;
+                }
+                node = node->traverse(text[++position]);
+            }
+
+            // add the longest matching token
+            output.push_back(token_id);
+            position = token_length;
+        }
+    }
+
+    const llama_vocab & vocab;
+
+    struct naive_trie token_matcher;
+};
+
 //
 // (de-) tokenize
 //
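The tokenizer added in the hunk above resolves overlapping vocabulary entries by greedy longest match over a byte trie: walk the trie as far as the input allows, remember the deepest node that carries a token id, emit that token, and resume from the end of the match. Below is a self-contained sketch of the same strategy using toy types and a made-up vocab; it is not the llama.cpp code, just the idea in isolation.

// Standalone illustration only -- toy types and a made-up vocab, not the
// llama.cpp structs. It mirrors the greedy longest-match loop of
// llm_tokenizer_rwkv::tokenize: walk the trie byte by byte, remember the
// deepest node that carries a token id, emit it, then resume after the match.
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct toy_trie {
    void insert(const std::string & s, int32_t id) {
        toy_trie * node = this;
        for (char c : s) {
            node = &node->children[c];
        }
        node->has_value = true;
        node->value     = id;
    }
    std::map<char, toy_trie> children;
    bool    has_value = false;
    int32_t value     = -1;
};

static std::vector<int32_t> tokenize_greedy(const toy_trie & root, const std::string & text) {
    std::vector<int32_t> out;
    size_t pos = 0;
    while (pos < text.size()) {
        const toy_trie * node = &root;
        int32_t best_id  = -1;       // stands in for the unknown-token id
        size_t  best_end = pos + 1;  // consume at least one byte when nothing matches
        size_t  cur      = pos;
        while (cur < text.size()) {
            auto it = node->children.find(text[cur]);
            if (it == node->children.end()) {
                break;
            }
            node = &it->second;
            ++cur;
            if (node->has_value) {   // deepest match seen so far wins
                best_id  = node->value;
                best_end = cur;
            }
        }
        out.push_back(best_id);
        pos = best_end;
    }
    return out;
}

int main() {
    toy_trie trie;
    trie.insert("a",   1);
    trie.insert("ab",  2);
    trie.insert("abc", 3);
    trie.insert("d",   4);

    // "abcd" -> 3 4: "abc" wins over its shorter prefixes "a" and "ab", then "d".
    for (int32_t id : tokenize_greedy(trie, "abcd")) {
        printf("%d ", id);
    }
    printf("\n");
    return 0;
}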
@@ -1401,6 +1506,23 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+
+                        llm_tokenizer_rwkv tokenizer(vocab);
+                        tokenizer.tokenize(raw_text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
         case LLAMA_VOCAB_TYPE_NONE:
             GGML_ABORT("fatal error");
     }
@@ -1616,6 +1738,17 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
             }
             break;
         }
+        case LLAMA_VOCAB_TYPE_RWKV: {
+            std::vector<uint8_t> result = llama_unescape_rwkv_token(token_text);
+
+            // If we don't have enough space, return an error
+            if (result.size() > (size_t)length) {
+                return -(int)result.size();
+            }
+
+            memcpy(buf, result.data(), result.size());
+            return (int)result.size();
+        }
         default:
             GGML_ABORT("fatal error");
     }
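The negative return value in the hunk above follows the existing llama_token_to_piece convention: when the caller's buffer is too small, the function returns the required size negated. A minimal caller-side sketch of the resize-and-retry pattern follows; it is not from this commit, it assumes the public llama.h signature at the time of this change, and the model and token values are placeholders.

// Hypothetical caller-side sketch (not from this commit) of the negative-length
// convention: llama_token_to_piece returns -(needed size) when the buffer is
// too small, so the caller grows the buffer and retries.
#include "llama.h"

#include <string>
#include <vector>

static std::string token_to_piece(const llama_model * model, llama_token token) {
    std::vector<char> buf(8); // deliberately small to exercise the retry path
    int32_t n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(),
                                     /*lstrip*/ 0, /*special*/ false);
    if (n < 0) {
        buf.resize(-n); // grow to the size reported by the first call
        n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), 0, false);
    }
    return std::string(buf.data(), n > 0 ? (size_t) n : 0);
}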