This commit is contained in:
compilade 2024-09-02 17:12:24 +01:00 committed by GitHub
commit aab435a5a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 2608 additions and 802 deletions

View file

@ -2541,7 +2541,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
if (llama_model_has_decoder(model)) { if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
} }
llama_kv_cache_clear(lctx); llama_past_clear(lctx);
llama_synchronize(lctx); llama_synchronize(lctx);
llama_reset_timings(lctx); llama_reset_timings(lctx);
} }

View file

@ -2874,6 +2874,120 @@ class MambaModel(Model):
return [(new_name, data_torch)] return [(new_name, data_torch)]
@Model.register("JambaForCausalLM")
class JambaModel(Model):
model_arch = gguf.MODEL_ARCH.JAMBA
def get_vocab_base_pre(self, tokenizer) -> str:
del tokenizer # unused
return "gpt-2"
def set_vocab(self):
if (self.dir_model / "tokenizer.model").is_file():
# Using Jamba's tokenizer.json causes errors on model load
# (something about "byte not found in vocab"),
# but there's a working tokenizer.model
self._set_vocab_sentencepiece()
else:
# Some Jamba models only have a tokenizer.json, which works.
self._set_vocab_gpt2()
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
d_inner = self.hparams["mamba_expand"] * d_model
d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
# ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
n_kv_head = self.hparams["num_key_value_heads"]
attn_offset = self.hparams["attn_layer_offset"]
attn_period = self.hparams["attn_layer_period"]
n_kv_vec = [0 for _ in range(attn_offset)] + [
n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
]
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(n_kv_vec)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_file_type(self.ftype)
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Mini-Jamba
name = name.replace(".moe.", ".feed_forward.")
if bid is not None:
moe_offset = self.hparams["expert_layer_offset"]
moe_period = self.hparams["expert_layer_period"]
if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
name = name.replace(".experts.0.", ".")
# process the experts separately
if ".feed_forward.experts." in name:
n_experts = self.hparams["num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for wid in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
# using the same merged name as qwen2moe
merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"
new_name = self.map_tensor_name(merged_name)
yield new_name, data_torch
return
new_name = self.map_tensor_name(name)
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)
yield new_name, data_torch
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("CohereForCausalLM") @Model.register("CohereForCausalLM")
class CommandR2Model(Model): class CommandR2Model(Model):
model_arch = gguf.MODEL_ARCH.COMMAND_R model_arch = gguf.MODEL_ARCH.COMMAND_R

View file

@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
const auto t_pp_start = ggml_time_us(); const auto t_pp_start = ggml_time_us();
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
if (!decode_helper(ctx, batch, ctx_params.n_batch)) { if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__); LOG_TEE("%s: llama_decode() failed\n", __func__);
@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
if (is_pp_shared) { if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) { for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); llama_past_seq_cp(ctx, 0, i, -1, -1);
} }
} }

View file

@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 {
} }
for i in 1 ..< n_parallel { for i in 1 ..< n_parallel {
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) llama_past_seq_cp(context, 0, Int32(i), -1, -1)
} }
if n_parallel > 1 { if n_parallel > 1 {

View file

@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
//// assign the system KV cache to all parallel sequences //// assign the system KV cache to all parallel sequences
//// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
//for (int32_t i = 1; i < n_parallel; ++i) { //for (int32_t i = 1; i < n_parallel; ++i) {
// llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); // llama_past_seq_cp(ctx, 0, i, -1, -1);
//} //}
if (n_parallel > 1) { if (n_parallel > 1) {

View file

@ -338,7 +338,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
} }
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) { static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) { if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return false; return false;

View file

@ -35,7 +35,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
const struct llama_model * model = llama_get_model(ctx); const struct llama_model * model = llama_get_model(ctx);
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// run model // run model
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

View file

@ -43,7 +43,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
} }
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
llama_set_embeddings(ctx, true); llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false); llama_set_causal_attn(ctx, false);
@ -98,7 +98,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
const llama_model * mdl = llama_get_model(ctx); const llama_model * mdl = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl); llama_token eos_token = llama_token_eos(mdl);
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
llama_set_embeddings(ctx, false); llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true); llama_set_causal_attn(ctx, true);

View file

@ -499,7 +499,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;

View file

@ -385,8 +385,8 @@ int main(int argc, char ** argv) {
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard); n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard; n_past -= n_discard;

View file

@ -1515,7 +1515,7 @@ int main(int argc, char ** argv) {
test t(inst, lmodel, ctx); test t(inst, lmodel, ctx);
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// cool off before the test // cool off before the test
if (params.delay) { if (params.delay) {
@ -1549,7 +1549,7 @@ int main(int argc, char ** argv) {
} }
for (int i = 0; i < params.reps; i++) { for (int i = 0; i < params.reps; i++) {
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
uint64_t t_start = get_time_ns(); uint64_t t_start = get_time_ns();

View file

@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
} }
batch->logits[batch->n_tokens - 1] = true; batch->logits[batch->n_tokens - 1] = true;
llama_kv_cache_clear(context); llama_past_clear(context);
const auto t_pp_start = ggml_time_us(); const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) { if (llama_decode(context, *batch) != 0) {
@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
LOGi("Benchmark text generation (tg)"); LOGi("Benchmark text generation (tg)");
llama_kv_cache_clear(context); llama_past_clear(context);
const auto t_tg_start = ggml_time_us(); const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) { for (i = 0; i < tg; i++) {
@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
const auto t_tg_end = ggml_time_us(); const auto t_tg_end = ggml_time_us();
llama_kv_cache_clear(context); llama_past_clear(context);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@ -439,5 +439,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context)); llama_past_clear(reinterpret_cast<llama_context *>(context));
} }

View file

@ -216,7 +216,7 @@ actor LlamaContext {
} }
batch.logits[Int(batch.n_tokens) - 1] = 1 // true batch.logits[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context) llama_past_clear(context)
let t_pp_start = ggml_time_us() let t_pp_start = ggml_time_us()
@ -229,7 +229,7 @@ actor LlamaContext {
// bench text generation // bench text generation
llama_kv_cache_clear(context) llama_past_clear(context)
let t_tg_start = ggml_time_us() let t_tg_start = ggml_time_us()
@ -248,7 +248,7 @@ actor LlamaContext {
let t_tg_end = ggml_time_us() let t_tg_end = ggml_time_us()
llama_kv_cache_clear(context) llama_past_clear(context)
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@ -298,7 +298,7 @@ actor LlamaContext {
func clear() { func clear() {
tokens_list.removeAll() tokens_list.removeAll()
temporary_invalid_cchars.removeAll() temporary_invalid_cchars.removeAll()
llama_kv_cache_clear(context) llama_past_clear(context)
} }
private func tokenize(text: String, add_bos: Bool) -> [llama_token] { private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

View file

@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
for (int s = 1; s < W + G + 1; ++s) { for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); llama_past_seq_cp(ctx, 0, s, -1, -1);
} }
const auto t_enc_end = ggml_time_us(); const auto t_enc_end = ggml_time_us();
@ -438,17 +438,18 @@ int main(int argc, char ** argv) {
// KV cache management // KV cache management
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
llama_kv_cache_seq_rm(ctx, -1, n_past, -1); // FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx, -1, n_past, -1);
if (seq_id_best != 0) { if (seq_id_best != 0) {
// if a verification token matched, we keep the best sequence and remove the rest // if a verification token matched, we keep the best sequence and remove the rest
// this leads to some KV cache fragmentation // this leads to some KV cache fragmentation
llama_kv_cache_seq_keep(ctx, seq_id_best); llama_past_seq_keep(ctx, seq_id_best);
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); llama_past_seq_rm (ctx, seq_id_best, -1, -1);
for (int s = 1; s < W + G + 1; ++s) { for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); llama_past_seq_cp(ctx, 0, s, -1, -1);
} }
} }
} }

View file

@ -194,7 +194,8 @@ int main(int argc, char ** argv){
// KV cache management // KV cache management
// clean the cache of draft tokens that weren't accepted // clean the cache of draft tokens that weren't accepted
llama_kv_cache_seq_rm(ctx, 0, n_past, -1); // FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx, 0, n_past, -1);
llama_batch_clear(batch_tgt); llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

View file

@ -369,6 +369,10 @@ int main(int argc, char ** argv) {
} }
n_matching_session_tokens++; n_matching_session_tokens++;
} }
// remove any "future" tokens that we might have inherited from the previous session
n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1);
if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
LOG_TEE("%s: using full prompt from session file\n", __func__); LOG_TEE("%s: using full prompt from session file\n", __func__);
} else if (n_matching_session_tokens >= embd_inp.size()) { } else if (n_matching_session_tokens >= embd_inp.size()) {
@ -380,9 +384,6 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n", LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size()); __func__, n_matching_session_tokens, embd_inp.size());
} }
// remove any "future" tokens that we might have inherited from the previous session
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
} }
LOGLN( LOGLN(
@ -395,6 +396,8 @@ int main(int argc, char ** argv) {
LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1); LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1);
session_tokens.resize(embd_inp.size() - 1); session_tokens.resize(embd_inp.size() - 1);
} else {
session_tokens.resize(n_matching_session_tokens);
} }
// number of tokens to keep when resetting context // number of tokens to keep when resetting context
@ -624,8 +627,8 @@ int main(int argc, char ** argv) {
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard); n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); llama_past_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); llama_past_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard; n_past -= n_discard;
@ -652,9 +655,9 @@ int main(int argc, char ** argv) {
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
n_past -= bd; n_past -= bd;
@ -668,6 +671,8 @@ int main(int argc, char ** argv) {
if (n_session_consumed < (int) session_tokens.size()) { if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0; size_t i = 0;
for ( ; i < embd.size(); i++) { for ( ; i < embd.size(); i++) {
// TODO: are the session tokens guaranteed to all be matching here?
// Should n_matching_session_tokens be re-used instead?
if (embd[i] != session_tokens[n_session_consumed]) { if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed); session_tokens.resize(n_session_consumed);
break; break;

View file

@ -200,7 +200,7 @@ int main(int argc, char ** argv) {
// assign the system KV cache to all parallel sequences // assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= n_clients; ++i) { for (int32_t i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); llama_past_seq_cp(ctx, 0, i, -1, -1);
} }
LOG_TEE("\n"); LOG_TEE("\n");
@ -232,9 +232,9 @@ int main(int argc, char ** argv) {
if (batch.n_tokens == 0) { if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache // all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) { for (int i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, -1, -1); llama_past_seq_rm(ctx, i, -1, -1);
// but keep the system prompt // but keep the system prompt
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); llama_past_seq_cp(ctx, 0, i, -1, -1);
} }
LOG_TEE("%s: clearing the KV cache\n", __func__); LOG_TEE("%s: clearing the KV cache\n", __func__);
@ -371,8 +371,8 @@ int main(int argc, char ** argv) {
} }
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); llama_past_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1);
const auto t_main_end = ggml_time_us(); const auto t_main_end = ggml_time_us();

View file

@ -126,11 +126,11 @@ int main(int argc, char ** argv) {
const int ib = i/n_batch - 1; const int ib = i/n_batch - 1;
const int bd = n_batch_grp*(n_grp - 1); const int bd = n_batch_grp*(n_grp - 1);
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); llama_past_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_kv_cache_update(ctx); llama_kv_cache_update(ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; n_past = llama_past_seq_pos_max(ctx, 0) + 1;
} }
llama_batch_clear(batch); llama_batch_clear(batch);
@ -160,12 +160,12 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag(ctx); //llama_kv_cache_defrag(ctx);
llama_kv_cache_update(ctx); llama_kv_cache_update(ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; n_past = llama_past_seq_pos_max(ctx, 0) + 1;
llama_batch_clear(batch); llama_batch_clear(batch);
@ -191,12 +191,12 @@ int main(int argc, char ** argv) {
if (n_discard > 0) { if (n_discard > 0) {
LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag(ctx); //llama_kv_cache_defrag(ctx);
llama_kv_cache_update(ctx); llama_kv_cache_update(ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; n_past = llama_past_seq_pos_max(ctx, 0) + 1;
} }
} }

View file

@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
@ -575,7 +575,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
@ -944,7 +944,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
return; return;
} }
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// decode all tasks [i0, i1) // decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1221,7 +1221,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
return; return;
} }
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// decode all tasks [i0, i1) // decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1594,7 +1594,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
return; return;
} }
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// decode all tasks [i0, i1) // decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1780,7 +1780,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
} }
// clear the KV cache // clear the KV cache
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;

View file

@ -82,7 +82,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
// run model // run model
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

View file

@ -199,7 +199,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
// erase whole kv // erase whole kv
llama_kv_cache_clear(ctx3); llama_past_clear(ctx3);
fprintf(stderr, "%s : kv cache cleared\n", __func__); fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1 // restore kv into seq 1

View file

@ -1092,7 +1092,7 @@ struct server_context {
LOG_VERBOSE("clearing KV cache", {}); LOG_VERBOSE("clearing KV cache", {});
// clear the entire KV cache // clear the entire KV cache
llama_kv_cache_clear(ctx); llama_past_clear(ctx);
clean_kv_cache = false; clean_kv_cache = false;
} }
@ -1127,7 +1127,7 @@ struct server_context {
// assign the system KV cache to all parallel sequences // assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= params.n_parallel; ++i) { for (int32_t i = 1; i <= params.n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); llama_past_seq_cp(ctx, 0, i, -1, -1);
} }
} }
@ -1844,7 +1844,7 @@ struct server_context {
// Erase token cache // Erase token cache
const size_t n_erased = slot->cache_tokens.size(); const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); llama_past_seq_rm(ctx, slot->id + 1, -1, -1);
slot->cache_tokens.clear(); slot->cache_tokens.clear();
server_task_result result; server_task_result result;
@ -1949,8 +1949,8 @@ struct server_context {
{"n_cache_tokens", slot.cache_tokens.size()} {"n_cache_tokens", slot.cache_tokens.size()}
}); });
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); llama_past_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); llama_past_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
if (slot.params.cache_prompt) { if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@ -2189,23 +2189,28 @@ struct server_context {
} }
// keep only the common part // keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past; llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
p0 = (int) system_tokens.size(); // for recurrent and hybrid models, sometimes it goes back further than asked
if (p0 != 0) { llama_pos new_p0 = llama_past_seq_rm(ctx, slot.id + 1, p0, -1);
// copy over the system prompt when there is one
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1); if (new_p0 < p0) {
GGML_ASSERT(new_p0 >= (llama_pos) system_tokens.size());
slot.n_past -= p0 - new_p0;
if (slot.ga_i > 0) {
// TODO: test with an hybrid model (e.g. Jamba)
slot.n_past_se -= p0 - new_p0;
} }
// there is no common part left (except for the system prompt) // TODO: find a way to avoid rolling back the sampling context twice
slot.n_past = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
llama_sampling_reset(slot.ctx_sampling); llama_sampling_reset(slot.ctx_sampling);
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
}
p0 = new_p0;
} }
// remove the non-common part from the cache // remove the non-common part from the cache
@ -2310,9 +2315,9 @@ struct server_context {
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); llama_past_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); llama_past_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); llama_past_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
slot.n_past_se -= bd; slot.n_past_se -= bd;

View file

@ -399,14 +399,15 @@ int main(int argc, char ** argv) {
{ {
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_keep(ctx_dft, s_keep); llama_past_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0); llama_past_seq_keep(ctx_dft, 0);
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); // FIXME: recurrent and hybrid models
llama_kv_cache_seq_keep(ctx_tgt, s_keep); llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); llama_past_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_keep(ctx_tgt, 0); llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_past_seq_keep(ctx_tgt, 0);
} }
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
@ -423,7 +424,8 @@ int main(int argc, char ** argv) {
llama_batch_clear(batch_dft); llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); // FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft); llama_decode(ctx_dft, batch_dft);
@ -479,8 +481,8 @@ int main(int argc, char ** argv) {
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur); LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch // all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) { for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@ -558,9 +560,9 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens // evaluate the target model on the drafted tokens
{ {
llama_kv_cache_seq_keep(ctx_tgt, 0); llama_past_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) { for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); llama_past_seq_cp(ctx_tgt, 0, s, -1, -1);
} }
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());

View file

@ -19820,7 +19820,6 @@ struct ggml_cplan ggml_graph_plan(
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
} }
} break; } break;
case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS:
{ {
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);

View file

@ -215,6 +215,7 @@ class MODEL_ARCH(IntEnum):
STARCODER2 = auto() STARCODER2 = auto()
RWKV6 = auto() RWKV6 = auto()
MAMBA = auto() MAMBA = auto()
JAMBA = auto()
XVERSE = auto() XVERSE = auto()
COMMAND_R = auto() COMMAND_R = auto()
DBRX = auto() DBRX = auto()
@ -274,7 +275,10 @@ class MODEL_TENSOR(IntEnum):
SSM_CONV1D = auto() SSM_CONV1D = auto()
SSM_X = auto() SSM_X = auto()
SSM_DT = auto() SSM_DT = auto()
SSM_DT_NORM = auto()
SSM_A = auto() SSM_A = auto()
SSM_B_NORM = auto()
SSM_C_NORM = auto()
SSM_D = auto() SSM_D = auto()
SSM_OUT = auto() SSM_OUT = auto()
TIME_MIX_W1 = auto() TIME_MIX_W1 = auto()
@ -369,6 +373,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.JAMBA: "jamba",
MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx", MODEL_ARCH.DBRX: "dbrx",
@ -428,7 +433,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
@ -954,6 +962,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT, MODEL_TENSOR.SSM_OUT,
], ],
MODEL_ARCH.JAMBA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_X,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_DT_NORM,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_B_NORM,
MODEL_TENSOR.SSM_C_NORM,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.XVERSE: [ MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,

View file

@ -238,6 +238,8 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"encoder.layers.{bid}.post_attention_layernorm", # chatglm "encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm "transformer.layers.{bid}.ffn_norm", # openelm
"model.layers.{bid}.pre_ff_layernorm", # jamba
"model.layers.{bid}.pre_moe_layernorm", # mini-jamba
), ),
# Post feed-forward norm # Post feed-forward norm
@ -256,6 +258,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate", # qwen2moe "model.layers.{bid}.mlp.gate", # qwen2moe
"transformer.decoder_layer.{bid}.router", # Grok "transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx "transformer.blocks.{bid}.ffn.router.layer", # dbrx
"model.layers.{bid}.feed_forward.router", # jamba
), ),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@ -287,6 +290,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.c_fc", # starcoder2 "model.layers.{bid}.mlp.c_fc", # starcoder2
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic "model.layers.{bid}.residual_mlp.w3", # arctic
"model.layers.{bid}.feed_forward.up_proj", # jamba
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone "transformer.h.{bid}.mlp.c_fc_1", # exaone
), ),
@ -320,6 +324,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
"transformer.h.{bid}.mlp.linear_1", # refact "transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic "model.layers.{bid}.residual_mlp.w1", # arctic
"model.layers.{bid}.feed_forward.gate_proj", # jamba
"transformer.h.{bid}.mlp.c_fc_0", # exaone "transformer.h.{bid}.mlp.c_fc_0", # exaone
), ),
@ -359,6 +364,7 @@ class TensorNameMap:
"transformer.layers.{bid}.ffn.proj_2", # openelm "transformer.layers.{bid}.ffn.proj_2", # openelm
"model.layers.{bid}.residual_mlp.w2", # arctic "model.layers.{bid}.residual_mlp.w2", # arctic
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"model.layers.{bid}.feed_forward.down_proj", # jamba
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone "model.layers.h.{bid}.mlp.c_proj", # exaone
), ),
@ -406,38 +412,59 @@ class TensorNameMap:
), ),
MODEL_TENSOR.SSM_IN: ( MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj", "model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", "backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba
), ),
MODEL_TENSOR.SSM_CONV1D: ( MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d", "model.layers.{bid}.conv1d", # mamba-hf
"backbone.layers.{bid}.mixer.conv1d", "backbone.layers.{bid}.mixer.conv1d", # mamba
"model.layers.{bid}.mamba.conv1d", # jamba
), ),
MODEL_TENSOR.SSM_X: ( MODEL_TENSOR.SSM_X: (
"model.layers.{bid}.x_proj", "model.layers.{bid}.x_proj", # mamba-hf
"backbone.layers.{bid}.mixer.x_proj", "backbone.layers.{bid}.mixer.x_proj", # mamba
"model.layers.{bid}.mamba.x_proj", # jamba
), ),
MODEL_TENSOR.SSM_DT: ( MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj", "model.layers.{bid}.dt_proj", # mamba-hf
"backbone.layers.{bid}.mixer.dt_proj", "backbone.layers.{bid}.mixer.dt_proj", # mamba
"model.layers.{bid}.mamba.dt_proj", # jamba
),
MODEL_TENSOR.SSM_DT_NORM: (
"model.layers.{bid}.mamba.dt_layernorm", # jamba
), ),
MODEL_TENSOR.SSM_A: ( MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log", "model.layers.{bid}.A_log", # mamba-hf
"backbone.layers.{bid}.mixer.A_log", "backbone.layers.{bid}.mixer.A_log", # mamba
"model.layers.{bid}.mamba.A_log", # jamba
),
MODEL_TENSOR.SSM_B_NORM: (
"model.layers.{bid}.mamba.b_layernorm", # jamba
"model.layers.{bid}.mamba.B_layernorm", # mini-jamba
),
MODEL_TENSOR.SSM_C_NORM: (
"model.layers.{bid}.mamba.c_layernorm", # jamba
"model.layers.{bid}.mamba.C_layernorm", # mini-jamba
), ),
MODEL_TENSOR.SSM_D: ( MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D", "model.layers.{bid}.D", # mamba-hf
"backbone.layers.{bid}.mixer.D", "backbone.layers.{bid}.mixer.D", # mamba
"model.layers.{bid}.mamba.D", # jamba
), ),
MODEL_TENSOR.SSM_OUT: ( MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj", "model.layers.{bid}.out_proj", # mamba-hf
"backbone.layers.{bid}.mixer.out_proj", "backbone.layers.{bid}.mixer.out_proj", # mamba
"model.layers.{bid}.mamba.out_proj", # jamba
), ),
MODEL_TENSOR.TIME_MIX_W1: ( MODEL_TENSOR.TIME_MIX_W1: (

View file

@ -38,10 +38,10 @@
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 8 #define LLAMA_SESSION_VERSION 9
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 2 #define LLAMA_STATE_SEQ_VERSION 3
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
@ -621,6 +621,12 @@ extern "C" {
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
// Rebuild and check the validity of the recurrent state cache's tree of sequences.
// (slow, use only for debugging purposes)
// Returns whether or not the rs cache was valid.
// The errors are always corrected, but only logged when debug is true.
LLAMA_API bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug);
// Returns the number of tokens in the KV cache (slow, use only for debug) // Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
@ -628,36 +634,62 @@ extern "C" {
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache - both cell info is erased and KV data is zeroed // Returns the number of used recurrent state cells (i.e. have at least one sequence assigned to them)
LLAMA_API void llama_kv_cache_clear( LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed
LLAMA_API void llama_past_clear(
struct llama_context * ctx); struct llama_context * ctx);
LLAMA_API DEPRECATED(void llama_kv_cache_clear(
struct llama_context * ctx),
"use llama_past_clear instead");
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// seq_id < 0 : match any sequence // seq_id < 0 : match any sequence
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API bool llama_kv_cache_seq_rm( // Returns n_past (one more than the largest remaining pos in the seq_id)
// which is only meaningful to handle for partial removals.
LLAMA_API llama_pos llama_past_seq_rm(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1); llama_pos p1);
LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1),
"use llama_past_seq_rm instead, and handle its return value for partial removals");
// Copy all tokens that belong to the specified sequence to another sequence // Copy all tokens that belong to the specified sequence to another sequence
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_cp( // Returns n_past (one more than the largest remaining pos in the destination seq_id)
// which is only meaningful to handle when partially copying.
LLAMA_API llama_pos llama_past_seq_cp(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id_src, llama_seq_id seq_id_src,
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
llama_pos p0, llama_pos p0,
llama_pos p1); llama_pos p1);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1),
"use llama_past_seq_cp instead, and handle its return value for partial copies");
// Removes all tokens that do not belong to the specified sequence // Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_cache_seq_keep( LLAMA_API void llama_past_seq_keep(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_past_seq_keep instead");
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly: // If the KV cache is RoPEd, the KV data is updated accordingly:
@ -665,12 +697,19 @@ extern "C" {
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_add( LLAMA_API void llama_past_seq_add(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
llama_pos delta); llama_pos delta);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta),
"use llama_past_seq_add instead");
// Integer division of the positions by factor of `d > 1` // Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly: // If the KV cache is RoPEd, the KV data is updated accordingly:
@ -678,17 +717,28 @@ extern "C" {
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_div( LLAMA_API void llama_past_seq_div(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d); int d);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d),
"use llama_past_seq_div instead");
// Returns the largest position present in the KV cache for the specified sequence // Returns the largest position present in the KV and/or RS cache for the specified sequence
LLAMA_API llama_pos llama_kv_cache_seq_pos_max( LLAMA_API llama_pos llama_past_seq_pos_max(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells");
// Defragment the KV cache // Defragment the KV cache
// This will be applied: // This will be applied:

File diff suppressed because it is too large Load diff