Compare commits


41 commits

Author SHA1 Message Date
Francis Couture-Harpin
375de5b1f8 llama : use unused n_embd_k_gqa in k_shift
This also slightly reduces the diff from the master branch
2024-09-01 21:59:24 -04:00
Francis Couture-Harpin
5f62db790b llama : fix mixed signedness comparison 2024-09-01 21:50:27 -04:00
Francis Couture-Harpin
9d3f44dad4 convert_hf : fix Jamba conversion 2024-09-01 21:46:35 -04:00
Francis Couture-Harpin
a03e32a3c9 Merge branch 'master' into compilade/refactor-kv-cache 2024-09-01 21:16:32 -04:00
Francis Couture-Harpin
fcb889cf7f llama : session saving and reloading for hybrid models 2024-09-01 20:31:30 -04:00
Francis Couture-Harpin
bc320ef66d Merge branch 'master' into compilade/refactor-kv-cache 2024-08-31 21:17:11 -04:00
Francis Couture-Harpin
9b38f8bf65 Merge branch 'master' into compilade/refactor-kv-cache 2024-07-04 17:33:52 -04:00
Francis Couture-Harpin
10c3c419e9 Merge branch 'master' into compilade/refactor-kv-cache 2024-06-30 16:04:57 -04:00
Francis Couture-Harpin
33425a7e1e mamba : fix non-contiguous usage of ggml_silu 2024-06-12 12:57:02 -04:00
Francis Couture-Harpin
ff794f5535 Merge branch 'master' into compilade/refactor-kv-cache 2024-06-12 12:10:29 -04:00
Francis Couture-Harpin
43d8d4bf9e examples : replace llama_kv_cache_seq_* with llama_past_seq_* 2024-06-11 23:27:04 -04:00
Francis Couture-Harpin
372482dffe llama : rename llama_cache to llama_past
This can be changed back later if the name change is wrong.
I was renaming the functions anyway to generalize kv-cache-related
functions to hybrid and recurrent model architectures.
I think llama_past is a better name than llama_cache for a combined
kv cache and recurrent state cache, because the states it contains
pretty much always come before the newly-added ones for any particular
sequence. Also, 'llama_past_clear' is more obvious about what it does
than 'llama_kv_cache_clear'. The future is what the models generate.
(For embeddings, the kv cache isn't really used anyway)

Still, I'm open to better suggestions.
2024-06-08 17:58:40 -04:00
Francis Couture-Harpin
6840ac0bca Merge branch 'master' into compilade/refactor-kv-cache 2024-06-08 17:30:49 -04:00
Francis Couture-Harpin
fee3c1d740 llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
* ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors

The implementation already supported it,
and this makes Mamba's conv step slightly faster.
2024-06-03 13:54:39 -04:00
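
As a rough sketch of this formulation (not the actual ggml graph; tensor layout and the handling of the stored conv state differ), Mamba's depthwise causal convolution can be written as an elementwise multiply with the per-channel kernel followed by a sum over the kernel window:

    import numpy as np

    def causal_conv_mul_sum(x, w):
        # x: (d_inner, n_tokens) activations, w: (d_inner, d_conv) per-channel kernel
        d_inner, n_tokens = x.shape
        d_conv = w.shape[1]
        # zero left-padding stands in for the stored conv state of the sequence
        x_pad = np.concatenate([np.zeros((d_inner, d_conv - 1)), x], axis=1)
        y = np.zeros((d_inner, n_tokens))
        for t in range(n_tokens):
            window = x_pad[:, t:t + d_conv]     # sliding window, (d_inner, d_conv)
            y[:, t] = (window * w).sum(axis=1)  # MUL, then SUM_ROWS over the window
        return y
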
Francis Couture-Harpin
17f6c1ef3b llama : fix .base() compilation error on Windows 2024-06-03 00:41:15 -04:00
Francis Couture-Harpin
8fb57ac0fb llama : use im2col and mul_mat to perform convolution for Mamba
This removes the need for ggml_ssm_conv!!!
But performance seems slightly worse on my system,
especially for prompt processing.
Maybe ggml_mul_mat isn't optimized for small row sizes?
More performance testing is necessary before GGML_OP_SSM_CONV can be removed.

* ggml : make ggml_ssm_scan not modify its source tensors

* llama : fix shared recurrent tail cell count for small ubatch sizes

Otherwise it was impossible to run the 'parallel' example with '-ub 1'
with a Mamba or Jamba model.
2024-06-03 00:01:41 -04:00
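
The im2col + mul_mat formulation computes the same depthwise causal convolution as in the sketch above; a minimal NumPy sketch of the idea (the actual code uses ggml_im2col and ggml_mul_mat, and the memory layout differs):

    import numpy as np

    def causal_conv_im2col(x, w):
        # x: (d_inner, n_tokens) activations, w: (d_inner, d_conv) per-channel kernel
        d_inner, n_tokens = x.shape
        d_conv = w.shape[1]
        x_pad = np.concatenate([np.zeros((d_inner, d_conv - 1)), x], axis=1)
        # im2col: gather the sliding windows into one (d_inner, n_tokens, d_conv) tensor
        cols = np.stack([x_pad[:, t:t + d_conv] for t in range(n_tokens)], axis=1)
        # one small matrix product per channel: contract the kernel with its windows
        return np.einsum('cd,ctd->ct', w, cols)
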
Francis Couture-Harpin
eb589d5e36 llama : avoid copies for simple batch splits 2024-06-02 00:18:56 -04:00
Francis Couture-Harpin
61200ef29f llama : fix edge case finding batch seq_id of split recurrent cell
Without this, running the HellaSwag benchmark with small batch sizes
would crash.
2024-06-01 16:44:43 -04:00
Francis Couture-Harpin
18d1c14047 llama : minimize swaps when reordering logits
This reduces overhead when running HellaSwag
on thousands of sequences with very small (100k-parameter) Mamba models.
2024-06-01 15:06:59 -04:00
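
The usual way to minimize swaps is to follow the cycles of the permutation, so a cycle of length k costs only k - 1 swaps; a small illustrative sketch (the actual code reorders logits and output ids inside llama.cpp with different bookkeeping):

    def reorder_with_min_swaps(values, dest):
        # move values[i] to slot dest[i], swapping along the permutation's cycles
        dest = list(dest)  # work on a copy
        for i in range(len(values)):
            while dest[i] != i:
                j = dest[i]
                values[i], values[j] = values[j], values[i]
                dest[i], dest[j] = dest[j], dest[i]
        return values

    # reorder_with_min_swaps(['a', 'b', 'c'], [2, 0, 1]) -> ['b', 'c', 'a']
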
Francis Couture-Harpin
72eea49224 llama : fix batch split output count for embeddings 2024-06-01 12:24:19 -04:00
Francis Couture-Harpin
5d3c7b9585 Merge branch 'master' into compilade/refactor-kv-cache 2024-06-01 11:51:41 -04:00
Francis Couture-Harpin
3587a94987 llama : use equal-sequence-length sub-batches for recurrent models
* ggml : simplify SSM-related operators

* llama : make recurrent state slot allocation contiguous

* llama : adapt internal uses of batches to llama_ubatch
2024-06-01 11:49:17 -04:00
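
A simplified model of the equal-length splitting, assuming tokens are already grouped per sequence (the real llama_ubatch code also tracks positions, output flags and the n_ubatch budget more carefully):

    def split_equal(seqs, n_ubatch):
        # seqs: {seq_id: [tokens...]}; every sequence included in a sub-batch
        # contributes the same number of tokens, which recurrent models need
        seqs = {sid: list(toks) for sid, toks in seqs.items()}
        ubatches = []
        while any(seqs.values()):
            active = [sid for sid, toks in seqs.items() if toks]
            n_len = min(len(seqs[sid]) for sid in active)        # equal length
            n_len = min(n_len, max(1, n_ubatch // len(active)))  # respect the budget
            ubatches.append({sid: seqs[sid][:n_len] for sid in active})
            for sid in active:
                del seqs[sid][:n_len]
        return ubatches

    # split_equal({0: [1, 2, 3], 1: [4]}, n_ubatch=4) -> [{0: [1], 1: [4]}, {0: [2, 3]}]
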
Francis Couture-Harpin
4e4c41e553 Merge branch 'master' into compilade/refactor-kv-cache 2024-05-28 15:15:18 -04:00
Francis Couture-Harpin
3a414b0be2 llama : sequence-length-aware batch splitting 2024-05-28 15:07:32 -04:00
Francis Couture-Harpin
181dadf294 llama : fix Jamba quantization sanity checks 2024-05-28 15:07:32 -04:00
Francis Couture-Harpin
fc59407efe convert-hf : support Mini-Jamba conversion 2024-05-25 13:56:21 -04:00
Francis Couture-Harpin
ea2e63e9d2 convert-hf : check for unprocessed Jamba experts 2024-05-25 12:54:30 -04:00
Francis Couture-Harpin
61a88a1da3 llama : fix BERT inference without KV cache 2024-05-24 22:41:38 -04:00
Francis Couture-Harpin
0fd13e9473 Merge branch 'master' into compilade/refactor-kv-cache 2024-05-24 19:35:16 -04:00
Francis Couture-Harpin
cbc743e600 llama : support Jamba 2024-05-24 19:27:27 -04:00
Francis Couture-Harpin
7e13f19fb5 llama : rethink recurrent state cell counts
* llama : begin work on support for variable GQA

This will also be useful for Jamba if we consider the Mamba layers
to have 0 KV heads.

* llama : gracefully fail when not finding hybrid slot
2024-05-24 16:19:25 -04:00
Francis Couture-Harpin
3b57b55c6f Merge branch 'master' into compilade/refactor-kv-cache 2024-05-22 15:34:24 -04:00
Francis Couture-Harpin
b7ec12ebf7 Merge branch 'master' into compilade/refactor-kv-cache 2024-05-12 17:13:31 -04:00
Francis Couture-Harpin
b6fafd1747 llama : remove useless return value for some llama_cache_* functions 2024-04-29 12:59:43 -04:00
Francis Couture-Harpin
c460ff1a1c Merge branch 'master' into compilade/refactor-kv-cache 2024-04-29 10:31:39 -04:00
Francis Couture-Harpin
a09db95eab llama : rename many llama_kv_cache_* functions 2024-04-29 10:24:45 -04:00
Francis Couture-Harpin
d66849f628 Merge branch 'master' into compilade/refactor-kv-cache 2024-04-09 20:33:38 -04:00
Francis Couture-Harpin
0c8b3b2095 llama : correctly handle more edge cases for the rs cache 2024-04-09 17:35:52 -04:00
Francis Couture-Harpin
0028010d01 llama : state checkpoints for recurrent models 2024-04-08 09:54:35 -04:00
Francis Couture-Harpin
8db1e4d45f llama : use std::find for seq_nodes in llama_rs_cache 2024-04-04 10:46:43 -04:00
Francis Couture-Harpin
271104c65c wip: llama : separate recurrent states from the KV cache
This will be necessary to support Jamba
(and other recurrent models mixed with Attention).

Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states.
2024-04-03 20:47:34 -04:00
28 changed files with 2608 additions and 802 deletions

View file

@ -2541,7 +2541,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
}
llama_kv_cache_clear(lctx);
llama_past_clear(lctx);
llama_synchronize(lctx);
llama_reset_timings(lctx);
}

View file

@ -2874,6 +2874,120 @@ class MambaModel(Model):
return [(new_name, data_torch)]
@Model.register("JambaForCausalLM")
class JambaModel(Model):
model_arch = gguf.MODEL_ARCH.JAMBA
def get_vocab_base_pre(self, tokenizer) -> str:
del tokenizer # unused
return "gpt-2"
def set_vocab(self):
if (self.dir_model / "tokenizer.model").is_file():
# Using Jamba's tokenizer.json causes errors on model load
# (something about "byte not found in vocab"),
# but there's a working tokenizer.model
self._set_vocab_sentencepiece()
else:
# Some Jamba models only have a tokenizer.json, which works.
self._set_vocab_gpt2()
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
d_inner = self.hparams["mamba_expand"] * d_model
d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
# ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
n_kv_head = self.hparams["num_key_value_heads"]
attn_offset = self.hparams["attn_layer_offset"]
attn_period = self.hparams["attn_layer_period"]
n_kv_vec = [0 for _ in range(attn_offset)] + [
n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
]
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(n_kv_vec)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
self.gguf_writer.add_file_type(self.ftype)
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Mini-Jamba
name = name.replace(".moe.", ".feed_forward.")
if bid is not None:
moe_offset = self.hparams["expert_layer_offset"]
moe_period = self.hparams["expert_layer_period"]
if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
name = name.replace(".experts.0.", ".")
# process the experts separately
if ".feed_forward.experts." in name:
n_experts = self.hparams["num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for wid in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
# using the same merged name as qwen2moe
merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"
new_name = self.map_tensor_name(merged_name)
yield new_name, data_torch
return
new_name = self.map_tensor_name(name)
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)
yield new_name, data_torch
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
model_arch = gguf.MODEL_ARCH.COMMAND_R
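
To make JambaModel.set_gguf_parameters above concrete: every attn_layer_period-th block starting at attn_layer_offset is an attention block with n_kv_head KV heads, all Mamba blocks get 0 KV heads, and dt_rank falls back to ceil(d_model / 16). With hypothetical hyperparameters (chosen only for illustration):

    d_model, n_kv_head = 4096, 8
    attn_offset, attn_period, block_count = 4, 8, 32

    dt_rank = -(d_model // -16)  # ceiling division -> 256

    n_kv_vec = [0 for _ in range(attn_offset)] + [
        n_kv_head if (i - attn_offset) % attn_period == 0 else 0
        for i in range(attn_offset, block_count)
    ]
    # -> [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 8, ...]
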

View file

@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
const auto t_pp_start = ggml_time_us();
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_past_seq_cp(ctx, 0, i, -1, -1);
}
}

View file

@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 {
}
for i in 1 ..< n_parallel {
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
llama_past_seq_cp(context, 0, Int32(i), -1, -1)
}
if n_parallel > 1 {

View file

@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
//// assign the system KV cache to all parallel sequences
//// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
//for (int32_t i = 1; i < n_parallel; ++i) {
// llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
// llama_past_seq_cp(ctx, 0, i, -1, -1);
//}
if (n_parallel > 1) {

View file

@ -338,7 +338,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
}
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;

View file

@ -35,7 +35,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
const struct llama_model * model = llama_get_model(ctx);
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
// run model
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

View file

@ -43,7 +43,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
}
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false);
@ -98,7 +98,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
const llama_model * mdl = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl);
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);

View file

@ -499,7 +499,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;

View file

@ -385,8 +385,8 @@ int main(int argc, char ** argv) {
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;

View file

@ -1515,7 +1515,7 @@ int main(int argc, char ** argv) {
test t(inst, lmodel, ctx);
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
// cool off before the test
if (params.delay) {
@ -1549,7 +1549,7 @@ int main(int argc, char ** argv) {
}
for (int i = 0; i < params.reps; i++) {
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
uint64_t t_start = get_time_ns();

View file

@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
}
batch->logits[batch->n_tokens - 1] = true;
llama_kv_cache_clear(context);
llama_past_clear(context);
const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) {
@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
LOGi("Benchmark text generation (tg)");
llama_kv_cache_clear(context);
llama_past_clear(context);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {
@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
const auto t_tg_end = ggml_time_us();
llama_kv_cache_clear(context);
llama_past_clear(context);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@ -439,5 +439,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
llama_past_clear(reinterpret_cast<llama_context *>(context));
}

View file

@ -216,7 +216,7 @@ actor LlamaContext {
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context)
llama_past_clear(context)
let t_pp_start = ggml_time_us()
@ -229,7 +229,7 @@ actor LlamaContext {
// bench text generation
llama_kv_cache_clear(context)
llama_past_clear(context)
let t_tg_start = ggml_time_us()
@ -248,7 +248,7 @@ actor LlamaContext {
let t_tg_end = ggml_time_us()
llama_kv_cache_clear(context)
llama_past_clear(context)
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@ -298,7 +298,7 @@ actor LlamaContext {
func clear() {
tokens_list.removeAll()
temporary_invalid_cchars.removeAll()
llama_kv_cache_clear(context)
llama_past_clear(context)
}
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

View file

@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
llama_past_seq_cp(ctx, 0, s, -1, -1);
}
const auto t_enc_end = ggml_time_us();
@ -438,17 +438,18 @@ int main(int argc, char ** argv) {
// KV cache management
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
// FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx, -1, n_past, -1);
if (seq_id_best != 0) {
// if a verification token matched, we keep the best sequence and remove the rest
// this leads to some KV cache fragmentation
llama_kv_cache_seq_keep(ctx, seq_id_best);
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
llama_past_seq_keep(ctx, seq_id_best);
llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_past_seq_rm (ctx, seq_id_best, -1, -1);
for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
llama_past_seq_cp(ctx, 0, s, -1, -1);
}
}
}

View file

@ -194,7 +194,8 @@ int main(int argc, char ** argv){
// KV cache management
// clean the cache of draft tokens that weren't accepted
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
// FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx, 0, n_past, -1);
llama_batch_clear(batch_tgt);
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

View file

@ -369,6 +369,10 @@ int main(int argc, char ** argv) {
}
n_matching_session_tokens++;
}
// remove any "future" tokens that we might have inherited from the previous session
n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1);
if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
LOG_TEE("%s: using full prompt from session file\n", __func__);
} else if (n_matching_session_tokens >= embd_inp.size()) {
@ -380,9 +384,6 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size());
}
// remove any "future" tokens that we might have inherited from the previous session
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
}
LOGLN(
@ -395,6 +396,8 @@ int main(int argc, char ** argv) {
LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1);
session_tokens.resize(embd_inp.size() - 1);
} else {
session_tokens.resize(n_matching_session_tokens);
}
// number of tokens to keep when resetting context
@ -624,8 +627,8 @@ int main(int argc, char ** argv) {
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
llama_past_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_past_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard;
@ -652,9 +655,9 @@ int main(int argc, char ** argv) {
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
n_past -= bd;
@ -668,6 +671,8 @@ int main(int argc, char ** argv) {
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
for ( ; i < embd.size(); i++) {
// TODO: are the session tokens guaranteed to all be matching here?
// Should n_matching_session_tokens be re-used instead?
if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed);
break;

View file

@ -200,7 +200,7 @@ int main(int argc, char ** argv) {
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_past_seq_cp(ctx, 0, i, -1, -1);
}
LOG_TEE("\n");
@ -232,9 +232,9 @@ int main(int argc, char ** argv) {
if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, -1, -1);
llama_past_seq_rm(ctx, i, -1, -1);
// but keep the system prompt
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_past_seq_cp(ctx, 0, i, -1, -1);
}
LOG_TEE("%s: clearing the KV cache\n", __func__);
@ -371,8 +371,8 @@ int main(int argc, char ** argv) {
}
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
llama_past_seq_rm(ctx, client.id + 1, -1, -1);
llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1);
const auto t_main_end = ggml_time_us();

View file

@ -126,11 +126,11 @@ int main(int argc, char ** argv) {
const int ib = i/n_batch - 1;
const int bd = n_batch_grp*(n_grp - 1);
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_past_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_kv_cache_update(ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
n_past = llama_past_seq_pos_max(ctx, 0) + 1;
}
llama_batch_clear(batch);
@ -160,12 +160,12 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag(ctx);
llama_kv_cache_update(ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
n_past = llama_past_seq_pos_max(ctx, 0) + 1;
llama_batch_clear(batch);
@ -191,12 +191,12 @@ int main(int argc, char ** argv) {
if (n_discard > 0) {
LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_cache_defrag(ctx);
llama_kv_cache_update(ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
n_past = llama_past_seq_pos_max(ctx, 0) + 1;
}
}

View file

@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@ -575,7 +575,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@ -944,7 +944,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
return;
}
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1221,7 +1221,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
return;
}
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1594,7 +1594,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
return;
}
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1780,7 +1780,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
}
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;

View file

@ -82,7 +82,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
// run model
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

View file

@ -199,7 +199,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
// erase whole kv
llama_kv_cache_clear(ctx3);
llama_past_clear(ctx3);
fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1

View file

@ -1123,7 +1123,7 @@ struct server_context {
LOG_VERBOSE("clearing KV cache", {});
// clear the entire KV cache
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
clean_kv_cache = false;
}
@ -1158,7 +1158,7 @@ struct server_context {
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= params.n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_past_seq_cp(ctx, 0, i, -1, -1);
}
}
@ -1835,7 +1835,7 @@ struct server_context {
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
llama_past_seq_rm(ctx, slot->id + 1, -1, -1);
slot->cache_tokens.clear();
server_task_result result;
@ -1960,8 +1960,8 @@ struct server_context {
{"n_cache_tokens", slot.cache_tokens.size()}
});
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
llama_past_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_past_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@ -2200,23 +2200,28 @@ struct server_context {
}
// keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past;
p0 = (int) system_tokens.size();
if (p0 != 0) {
// copy over the system prompt when there is one
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
// for recurrent and hybrid models, sometimes it goes back further than asked
llama_pos new_p0 = llama_past_seq_rm(ctx, slot.id + 1, p0, -1);
if (new_p0 < p0) {
GGML_ASSERT(new_p0 >= (llama_pos) system_tokens.size());
slot.n_past -= p0 - new_p0;
if (slot.ga_i > 0) {
// TODO: test with a hybrid model (e.g. Jamba)
slot.n_past_se -= p0 - new_p0;
}
// there is no common part left (except for the system prompt)
slot.n_past = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
// TODO: find a way to avoid rolling back the sampling context twice
llama_sampling_reset(slot.ctx_sampling);
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
}
p0 = new_p0;
}
// remove the non-common part from the cache
@ -2321,9 +2326,9 @@ struct server_context {
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
llama_past_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_past_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_past_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
slot.n_past_se -= bd;

View file

@ -399,14 +399,15 @@ int main(int argc, char ** argv) {
{
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0);
llama_past_seq_keep(ctx_dft, s_keep);
llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_past_seq_keep(ctx_dft, 0);
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0);
// FIXME: recurrent and hybrid models
llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_past_seq_keep(ctx_tgt, s_keep);
llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_past_seq_keep(ctx_tgt, 0);
}
for (int s = 0; s < n_seq_dft; ++s) {
@ -423,7 +424,8 @@ int main(int argc, char ** argv) {
llama_batch_clear(batch_dft);
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
// FIXME: recurrent and hybrid models
llama_past_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft);
@ -479,8 +481,8 @@ int main(int argc, char ** argv) {
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1);
llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@ -558,9 +560,9 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);
llama_past_seq_keep(ctx_tgt, 0);
for (int s = 1; s < n_seq_dft; ++s) {
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
llama_past_seq_cp(ctx_tgt, 0, s, -1, -1);
}
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());

View file

@ -19817,7 +19817,6 @@ struct ggml_cplan ggml_graph_plan(
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);

View file

@ -215,6 +215,7 @@ class MODEL_ARCH(IntEnum):
STARCODER2 = auto()
RWKV6 = auto()
MAMBA = auto()
JAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
DBRX = auto()
@ -274,7 +275,10 @@ class MODEL_TENSOR(IntEnum):
SSM_CONV1D = auto()
SSM_X = auto()
SSM_DT = auto()
SSM_DT_NORM = auto()
SSM_A = auto()
SSM_B_NORM = auto()
SSM_C_NORM = auto()
SSM_D = auto()
SSM_OUT = auto()
TIME_MIX_W1 = auto()
@ -369,6 +373,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.JAMBA: "jamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
@ -428,7 +433,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
@ -954,6 +962,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.JAMBA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_X,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_DT_NORM,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_B_NORM,
MODEL_TENSOR.SSM_C_NORM,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View file

@ -238,6 +238,8 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm
"model.layers.{bid}.pre_ff_layernorm", # jamba
"model.layers.{bid}.pre_moe_layernorm", # mini-jamba
),
# Post feed-forward norm
@ -256,6 +258,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate", # qwen2moe
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
"model.layers.{bid}.feed_forward.router", # jamba
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@ -287,6 +290,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.c_fc", # starcoder2
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
"model.layers.{bid}.feed_forward.up_proj", # jamba
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone
),
@ -320,6 +324,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
"model.layers.{bid}.feed_forward.gate_proj", # jamba
"transformer.h.{bid}.mlp.c_fc_0", # exaone
),
@ -359,6 +364,7 @@ class TensorNameMap:
"transformer.layers.{bid}.ffn.proj_2", # openelm
"model.layers.{bid}.residual_mlp.w2", # arctic
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"model.layers.{bid}.feed_forward.down_proj", # jamba
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
),
@ -406,38 +412,59 @@ class TensorNameMap:
),
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",
"model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba
),
MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d",
"backbone.layers.{bid}.mixer.conv1d",
"model.layers.{bid}.conv1d", # mamba-hf
"backbone.layers.{bid}.mixer.conv1d", # mamba
"model.layers.{bid}.mamba.conv1d", # jamba
),
MODEL_TENSOR.SSM_X: (
"model.layers.{bid}.x_proj",
"backbone.layers.{bid}.mixer.x_proj",
"model.layers.{bid}.x_proj", # mamba-hf
"backbone.layers.{bid}.mixer.x_proj", # mamba
"model.layers.{bid}.mamba.x_proj", # jamba
),
MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj",
"backbone.layers.{bid}.mixer.dt_proj",
"model.layers.{bid}.dt_proj", # mamba-hf
"backbone.layers.{bid}.mixer.dt_proj", # mamba
"model.layers.{bid}.mamba.dt_proj", # jamba
),
MODEL_TENSOR.SSM_DT_NORM: (
"model.layers.{bid}.mamba.dt_layernorm", # jamba
),
MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log",
"backbone.layers.{bid}.mixer.A_log",
"model.layers.{bid}.A_log", # mamba-hf
"backbone.layers.{bid}.mixer.A_log", # mamba
"model.layers.{bid}.mamba.A_log", # jamba
),
MODEL_TENSOR.SSM_B_NORM: (
"model.layers.{bid}.mamba.b_layernorm", # jamba
"model.layers.{bid}.mamba.B_layernorm", # mini-jamba
),
MODEL_TENSOR.SSM_C_NORM: (
"model.layers.{bid}.mamba.c_layernorm", # jamba
"model.layers.{bid}.mamba.C_layernorm", # mini-jamba
),
MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D",
"backbone.layers.{bid}.mixer.D",
"model.layers.{bid}.D", # mamba-hf
"backbone.layers.{bid}.mixer.D", # mamba
"model.layers.{bid}.mamba.D", # jamba
),
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj",
"model.layers.{bid}.out_proj", # mamba-hf
"backbone.layers.{bid}.mixer.out_proj", # mamba
"model.layers.{bid}.mamba.out_proj", # jamba
),
MODEL_TENSOR.TIME_MIX_W1: (

View file

@ -38,10 +38,10 @@
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 8
#define LLAMA_SESSION_VERSION 9
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 2
#define LLAMA_STATE_SEQ_VERSION 3
#ifdef __cplusplus
extern "C" {
@ -621,6 +621,12 @@ extern "C" {
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
// Rebuild and check the validity of the recurrent state cache's tree of sequences.
// (slow, use only for debugging purposes)
// Returns whether or not the rs cache was valid.
// The errors are always corrected, but only logged when debug is true.
LLAMA_API bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug);
// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
@ -628,36 +634,62 @@ extern "C" {
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_cache_clear(
// Returns the number of used recurrent state cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed
LLAMA_API void llama_past_clear(
struct llama_context * ctx);
LLAMA_API DEPRECATED(void llama_kv_cache_clear(
struct llama_context * ctx),
"use llama_past_clear instead");
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API bool llama_kv_cache_seq_rm(
// Returns n_past (one more than the largest remaining pos in the seq_id)
// which is only meaningful to handle for partial removals.
LLAMA_API llama_pos llama_past_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1),
"use llama_past_seq_rm instead, and handle its return value for partial removals");
// Copy all tokens that belong to the specified sequence to another sequence
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
// Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_cp(
// Returns n_past (one more than the largest remaining pos in the destination seq_id)
// which is only meaningful to handle when partially copying.
LLAMA_API llama_pos llama_past_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp(
struct llama_context * ctx,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1),
"use llama_past_seq_cp instead, and handle its return value for partial copies");
// Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_cache_seq_keep(
LLAMA_API void llama_past_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_past_seq_keep instead");
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly:
@ -665,12 +697,19 @@ extern "C" {
// - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_add(
LLAMA_API void llama_past_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_add(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta),
"use llama_past_seq_add instead");
// Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly:
@ -678,17 +717,28 @@ extern "C" {
// - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_div(
LLAMA_API void llama_past_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
LLAMA_API DEPRECATED(void llama_kv_cache_seq_div(
struct llama_context * ctx,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d),
"use llama_past_seq_div instead");
// Returns the largest position present in the KV cache for the specified sequence
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
// Returns the largest position present in the KV and/or RS cache for the specified sequence
LLAMA_API llama_pos llama_past_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id);
LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx,
llama_seq_id seq_id),
"use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells");
// Defragment the KV cache
// This will be applied:

File diff suppressed because it is too large.