Compare commits

41 commits: master ... compilade/

375de5b1f8
5f62db790b
9d3f44dad4
a03e32a3c9
fcb889cf7f
bc320ef66d
9b38f8bf65
10c3c419e9
33425a7e1e
ff794f5535
43d8d4bf9e
372482dffe
6840ac0bca
fee3c1d740
17f6c1ef3b
8fb57ac0fb
eb589d5e36
61200ef29f
18d1c14047
72eea49224
5d3c7b9585
3587a94987
4e4c41e553
3a414b0be2
181dadf294
fc59407efe
ea2e63e9d2
61a88a1da3
0fd13e9473
cbc743e600
7e13f19fb5
3b57b55c6f
b7ec12ebf7
b6fafd1747
c460ff1a1c
a09db95eab
d66849f628
0c8b3b2095
0028010d01
8db1e4d45f
271104c65c

28 changed files with 2608 additions and 802 deletions
@@ -2541,7 +2541,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         if (llama_model_has_decoder(model)) {
             llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         }
-        llama_kv_cache_clear(lctx);
+        llama_past_clear(lctx);
         llama_synchronize(lctx);
         llama_reset_timings(lctx);
     }
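For orientation, the bulk of this comparison is a mechanical rename of the sequence-level cache calls from `llama_kv_cache_*` to `llama_past_*`, which now cover recurrent-state cells as well as the KV cache. A minimal illustrative sketch of a caller after the rename (my own example, not part of the diff):

    #include "llama.h"

    // Reset the cached state and mirror sequence 0 into sequence 1,
    // using the renamed entry points introduced on this branch.
    static void reset_and_fork(llama_context * ctx) {
        llama_past_clear(ctx);                  // was: llama_kv_cache_clear(ctx);
        llama_past_seq_cp(ctx, 0, 1, -1, -1);   // was: llama_kv_cache_seq_cp(ctx, 0, 1, -1, -1);
    }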
@@ -2874,6 +2874,120 @@ class MambaModel(Model):
         return [(new_name, data_torch)]
 
 
+@Model.register("JambaForCausalLM")
+class JambaModel(Model):
+    model_arch = gguf.MODEL_ARCH.JAMBA
+
+    def get_vocab_base_pre(self, tokenizer) -> str:
+        del tokenizer  # unused
+
+        return "gpt-2"
+
+    def set_vocab(self):
+        if (self.dir_model / "tokenizer.model").is_file():
+            # Using Jamba's tokenizer.json causes errors on model load
+            # (something about "byte not found in vocab"),
+            # but there's a working tokenizer.model
+            self._set_vocab_sentencepiece()
+        else:
+            # Some Jamba models only have a tokenizer.json, which works.
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
+        d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
+        d_inner = self.hparams["mamba_expand"] * d_model
+        d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
+        n_kv_head = self.hparams["num_key_value_heads"]
+        attn_offset = self.hparams["attn_layer_offset"]
+        attn_period = self.hparams["attn_layer_period"]
+        n_kv_vec = [0 for _ in range(attn_offset)] + [
+            n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
+        ]
+
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(n_kv_vec)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+        # Mini-Jamba
+        name = name.replace(".moe.", ".feed_forward.")
+        if bid is not None:
+            moe_offset = self.hparams["expert_layer_offset"]
+            moe_period = self.hparams["expert_layer_period"]
+
+            if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
+                name = name.replace(".experts.0.", ".")
+
+        # process the experts separately
+        if ".feed_forward.experts." in name:
+            n_experts = self.hparams["num_experts"]
+
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+
+                # merge the experts into a single 3d tensor
+                for wid in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    # using the same merged name as qwen2moe
+                    merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    yield new_name, data_torch
+            return
+
+        new_name = self.map_tensor_name(name)
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield new_name, data_torch
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
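The `n_kv_vec` computed in `set_gguf_parameters` above encodes Jamba's hybrid layer layout: zero KV heads for Mamba layers and `num_key_value_heads` for the periodic attention layers. A small self-contained sketch of the same computation (my own C++ rendering of the Python above, with made-up hyperparameter values for illustration):

    #include <cstdio>
    #include <vector>

    int main() {
        // hypothetical Jamba-style hyperparameters (illustration only)
        const int block_count = 32;  // total layers
        const int attn_offset = 4;   // hparams["attn_layer_offset"]
        const int attn_period = 8;   // hparams["attn_layer_period"]
        const int n_kv_head   = 8;   // hparams["num_key_value_heads"]

        std::vector<int> n_kv_vec(block_count, 0);
        for (int i = attn_offset; i < block_count; ++i) {
            if ((i - attn_offset) % attn_period == 0) {
                n_kv_vec[i] = n_kv_head;  // this layer is an attention layer
            }
        }

        // with these values: n_kv = 8 at layers 4, 12, 20, 28 and 0 elsewhere
        for (int i = 0; i < block_count; ++i) {
            printf("layer %2d: n_kv = %d\n", i, n_kv_vec[i]);
        }
        return 0;
    }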
@@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
 
     const auto t_pp_start = ggml_time_us();
 
-    llama_kv_cache_clear(ctx);
+    llama_past_clear(ctx);
 
     if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);
@@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
 
     if (is_pp_shared) {
         for (int32_t i = 1; i < pl; ++i) {
-            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+            llama_past_seq_cp(ctx, 0, i, -1, -1);
         }
     }
@@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 {
 }
 
 for i in 1 ..< n_parallel {
-    llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
+    llama_past_seq_cp(context, 0, Int32(i), -1, -1)
 }
 
 if n_parallel > 1 {
@@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
     //// assign the system KV cache to all parallel sequences
     //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
     //for (int32_t i = 1; i < n_parallel; ++i) {
-    //    llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+    //    llama_past_seq_cp(ctx, 0, i, -1, -1);
     //}
 
     if (n_parallel > 1) {
@@ -338,7 +338,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 }
 
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
-    llama_kv_cache_clear(ctx);
+    llama_past_clear(ctx);
     if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;
@@ -35,7 +35,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     const struct llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_kv_cache_clear(ctx);
+    llama_past_clear(ctx);
 
     // run model
     fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
@@ -43,7 +43,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     }
 
     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_kv_cache_clear(ctx);
+    llama_past_clear(ctx);
     llama_set_embeddings(ctx, true);
     llama_set_causal_attn(ctx, false);
 
@@ -98,7 +98,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     const llama_model * mdl = llama_get_model(ctx);
     llama_token eos_token = llama_token_eos(mdl);
 
-    llama_kv_cache_clear(ctx);
+    llama_past_clear(ctx);
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
 
@@ -499,7 +499,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_past_clear(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -385,8 +385,8 @@ int main(int argc, char ** argv) {
                 LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                     n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+                llama_past_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
                 n_past -= n_discard;
 
@@ -1515,7 +1515,7 @@ int main(int argc, char ** argv) {
 
         test t(inst, lmodel, ctx);
 
-        llama_kv_cache_clear(ctx);
+        llama_past_clear(ctx);
 
         // cool off before the test
         if (params.delay) {
@@ -1549,7 +1549,7 @@ int main(int argc, char ** argv) {
         }
 
         for (int i = 0; i < params.reps; i++) {
-            llama_kv_cache_clear(ctx);
+            llama_past_clear(ctx);
 
             uint64_t t_start = get_time_ns();
 
@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
     }
 
     batch->logits[batch->n_tokens - 1] = true;
-    llama_kv_cache_clear(context);
+    llama_past_clear(context);
 
     const auto t_pp_start = ggml_time_us();
     if (llama_decode(context, *batch) != 0) {
@@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 
     LOGi("Benchmark text generation (tg)");
 
-    llama_kv_cache_clear(context);
+    llama_past_clear(context);
     const auto t_tg_start = ggml_time_us();
     for (i = 0; i < tg; i++) {
 
@@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 
     const auto t_tg_end = ggml_time_us();
 
-    llama_kv_cache_clear(context);
+    llama_past_clear(context);
 
     const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
     const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -439,5 +439,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
-    llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
+    llama_past_clear(reinterpret_cast<llama_context *>(context));
 }
@@ -216,7 +216,7 @@ actor LlamaContext {
        }
        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
 
-        llama_kv_cache_clear(context)
+        llama_past_clear(context)
 
        let t_pp_start = ggml_time_us()
 
@@ -229,7 +229,7 @@ actor LlamaContext {
 
        // bench text generation
 
-        llama_kv_cache_clear(context)
+        llama_past_clear(context)
 
        let t_tg_start = ggml_time_us()
 
@@ -248,7 +248,7 @@ actor LlamaContext {
 
        let t_tg_end = ggml_time_us()
 
-        llama_kv_cache_clear(context)
+        llama_past_clear(context)
 
        let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
        let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -298,7 +298,7 @@ actor LlamaContext {
    func clear() {
        tokens_list.removeAll()
        temporary_invalid_cchars.removeAll()
-        llama_kv_cache_clear(context)
+        llama_past_clear(context)
    }
 
    private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
@@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
     llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
 
     for (int s = 1; s < W + G + 1; ++s) {
-        llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+        llama_past_seq_cp(ctx, 0, s, -1, -1);
     }
 
     const auto t_enc_end = ggml_time_us();
@@ -438,17 +438,18 @@ int main(int argc, char ** argv) {
 
         // KV cache management
         // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
-        llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
+        // FIXME: recurrent and hybrid models
+        llama_past_seq_rm(ctx, -1, n_past, -1);
 
         if (seq_id_best != 0) {
             // if a verification token matched, we keep the best sequence and remove the rest
             // this leads to some KV cache fragmentation
-            llama_kv_cache_seq_keep(ctx, seq_id_best);
-            llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
-            llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
+            llama_past_seq_keep(ctx, seq_id_best);
+            llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1);
+            llama_past_seq_rm (ctx, seq_id_best, -1, -1);
 
             for (int s = 1; s < W + G + 1; ++s) {
-                llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+                llama_past_seq_cp(ctx, 0, s, -1, -1);
             }
         }
     }
@@ -194,7 +194,8 @@ int main(int argc, char ** argv){
 
         // KV cache management
         // clean the cache of draft tokens that weren't accepted
-        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+        // FIXME: recurrent and hybrid models
+        llama_past_seq_rm(ctx, 0, n_past, -1);
 
         llama_batch_clear(batch_tgt);
         llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
@@ -369,6 +369,10 @@ int main(int argc, char ** argv) {
            }
            n_matching_session_tokens++;
        }
+
+        // remove any "future" tokens that we might have inherited from the previous session
+        n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1);
+
        if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
            LOG_TEE("%s: using full prompt from session file\n", __func__);
        } else if (n_matching_session_tokens >= embd_inp.size()) {
@@ -380,9 +384,6 @@ int main(int argc, char ** argv) {
            LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n",
                    __func__, n_matching_session_tokens, embd_inp.size());
        }
-
-        // remove any "future" tokens that we might have inherited from the previous session
-        llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
    }
 
    LOGLN(
@@ -395,6 +396,8 @@ int main(int argc, char ** argv) {
            LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1);
 
            session_tokens.resize(embd_inp.size() - 1);
+        } else {
+            session_tokens.resize(n_matching_session_tokens);
        }
 
    // number of tokens to keep when resetting context
@@ -624,8 +627,8 @@ int main(int argc, char ** argv) {
                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                    n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                llama_past_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                llama_past_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
                n_past -= n_discard;
 
@@ -652,9 +655,9 @@ int main(int argc, char ** argv) {
                LOG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
                LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
 
-                llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
-                llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
-                llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
+                llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd);
+                llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
+                llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
 
                n_past -= bd;
 
@@ -668,6 +671,8 @@ int main(int argc, char ** argv) {
            if (n_session_consumed < (int) session_tokens.size()) {
                size_t i = 0;
                for ( ; i < embd.size(); i++) {
+                    // TODO: are the session tokens guaranteed to all be matching here?
+                    // Should n_matching_session_tokens be re-used instead?
                    if (embd[i] != session_tokens[n_session_consumed]) {
                        session_tokens.resize(n_session_consumed);
                        break;
@@ -200,7 +200,7 @@ int main(int argc, char ** argv) {
 
        // assign the system KV cache to all parallel sequences
        for (int32_t i = 1; i <= n_clients; ++i) {
-            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+            llama_past_seq_cp(ctx, 0, i, -1, -1);
        }
 
        LOG_TEE("\n");
@@ -232,9 +232,9 @@ int main(int argc, char ** argv) {
        if (batch.n_tokens == 0) {
            // all sequences have ended - clear the entire KV cache
            for (int i = 1; i <= n_clients; ++i) {
-                llama_kv_cache_seq_rm(ctx, i, -1, -1);
+                llama_past_seq_rm(ctx, i, -1, -1);
                // but keep the system prompt
-                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+                llama_past_seq_cp(ctx, 0, i, -1, -1);
            }
 
            LOG_TEE("%s: clearing the KV cache\n", __func__);
@@ -371,8 +371,8 @@ int main(int argc, char ** argv) {
            }
 
            // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-            llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
-            llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
+            llama_past_seq_rm(ctx, client.id + 1, -1, -1);
+            llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1);
 
            const auto t_main_end = ggml_time_us();
 
@@ -126,11 +126,11 @@ int main(int argc, char ** argv) {
            const int ib = i/n_batch - 1;
            const int bd = n_batch_grp*(n_grp - 1);
 
-            llama_kv_cache_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
-            llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
-            llama_kv_cache_update  (ctx);
+            llama_past_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_update(ctx);
 
-            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
+            n_past = llama_past_seq_pos_max(ctx, 0) + 1;
        }
 
        llama_batch_clear(batch);
@@ -160,12 +160,12 @@ int main(int argc, char ** argv) {
 
            LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
 
-            llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
-            llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
-            //llama_kv_cache_defrag (ctx);
-            llama_kv_cache_update (ctx);
+            llama_past_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+            llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            //llama_kv_cache_defrag(ctx);
+            llama_kv_cache_update(ctx);
 
-            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
+            n_past = llama_past_seq_pos_max(ctx, 0) + 1;
 
            llama_batch_clear(batch);
 
@@ -191,12 +191,12 @@ int main(int argc, char ** argv) {
        if (n_discard > 0) {
            LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
 
-            llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
-            llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
-            //llama_kv_cache_defrag (ctx);
-            llama_kv_cache_update (ctx);
+            llama_past_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+            llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            //llama_kv_cache_defrag(ctx);
+            llama_kv_cache_update(ctx);
 
-            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
+            n_past = llama_past_seq_pos_max(ctx, 0) + 1;
        }
    }
 
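The `llama_past_seq_pos_max(ctx, 0) + 1` idiom above recomputes `n_past` after shifting the cache. Per the deprecation note in the llama.h hunk further down, the new function returns -1 (rather than 0) when the sequence has no cells, so the `+ 1` still yields 0 for an empty sequence. A small illustrative sketch of that recomputation (not part of the diff):

    #include "llama.h"

    // Recompute n_past for sequence 0 after cache shifts;
    // an empty sequence gives pos_max == -1, hence n_past == 0.
    static int recompute_n_past(llama_context * ctx) {
        const llama_pos pos_max = llama_past_seq_pos_max(ctx, 0);
        return (int) pos_max + 1;
    }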
@@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
        const auto t_start = std::chrono::high_resolution_clock::now();
 
        // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_past_clear(ctx);
 
        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
@@ -575,7 +575,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
        const auto t_start = std::chrono::high_resolution_clock::now();
 
        // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_past_clear(ctx);
 
        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
@@ -944,7 +944,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
            return;
        }
 
-        llama_kv_cache_clear(ctx);
+        llama_past_clear(ctx);
 
        // decode all tasks [i0, i1)
        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1221,7 +1221,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
            return;
        }
 
-        llama_kv_cache_clear(ctx);
+        llama_past_clear(ctx);
 
        // decode all tasks [i0, i1)
        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1594,7 +1594,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
            return;
        }
 
-        llama_kv_cache_clear(ctx);
+        llama_past_clear(ctx);
 
        // decode all tasks [i0, i1)
        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@@ -1780,7 +1780,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
        }
 
        // clear the KV cache
-        llama_kv_cache_clear(ctx);
+        llama_past_clear(ctx);
 
        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
@@ -82,7 +82,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
 
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
-    llama_kv_cache_clear(ctx);
+    llama_past_clear(ctx);
 
     // run model
     fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
@@ -199,7 +199,7 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
 
     // erase whole kv
-    llama_kv_cache_clear(ctx3);
+    llama_past_clear(ctx3);
     fprintf(stderr, "%s : kv cache cleared\n", __func__);
 
     // restore kv into seq 1
@@ -1123,7 +1123,7 @@ struct server_context {
        LOG_VERBOSE("clearing KV cache", {});
 
        // clear the entire KV cache
-        llama_kv_cache_clear(ctx);
+        llama_past_clear(ctx);
        clean_kv_cache = false;
    }
 
@@ -1158,7 +1158,7 @@ struct server_context {
 
            // assign the system KV cache to all parallel sequences
            for (int32_t i = 1; i <= params.n_parallel; ++i) {
-                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+                llama_past_seq_cp(ctx, 0, i, -1, -1);
            }
        }
 
@@ -1835,7 +1835,7 @@ struct server_context {
 
            // Erase token cache
            const size_t n_erased = slot->cache_tokens.size();
-            llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+            llama_past_seq_rm(ctx, slot->id + 1, -1, -1);
            slot->cache_tokens.clear();
 
            server_task_result result;
@@ -1960,8 +1960,8 @@ struct server_context {
                    {"n_cache_tokens", slot.cache_tokens.size()}
                });
 
-                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+                llama_past_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
+                llama_past_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
                if (slot.params.cache_prompt) {
                    for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -2200,23 +2200,28 @@ struct server_context {
                    }
 
                    // keep only the common part
-                    int p0 = (int) system_tokens.size() + slot.n_past;
-                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
-                        // could not partially delete (likely using a non-Transformer model)
-                        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+                    llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past;
 
-                        p0 = (int) system_tokens.size();
-                        if (p0 != 0) {
-                            // copy over the system prompt when there is one
-                            llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
-                        }
+                    // for recurrent and hybrid models, sometimes it goes back further than asked
+                    llama_pos new_p0 = llama_past_seq_rm(ctx, slot.id + 1, p0, -1);
+
+                    if (new_p0 < p0) {
+                        GGML_ASSERT(new_p0 >= (llama_pos) system_tokens.size());
 
-                        // there is no common part left (except for the system prompt)
-                        slot.n_past = 0;
-                        slot.n_past_se = 0;
-                        slot.ga_i = 0;
-                        // TODO: is the system prompt ever in the sampling context?
+                        slot.n_past -= p0 - new_p0;
+                        if (slot.ga_i > 0) {
+                            // TODO: test with an hybrid model (e.g. Jamba)
+                            slot.n_past_se -= p0 - new_p0;
+                        }
+
+                        // TODO: find a way to avoid rolling back the sampling context twice
                        llama_sampling_reset(slot.ctx_sampling);
+                        // push the prompt into the sampling context (do not apply grammar)
+                        for (int i = 0; i < slot.n_past; ++i) {
+                            llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
+                        }
+
+                        p0 = new_p0;
                    }
 
                    // remove the non-common part from the cache
@@ -2321,9 +2326,9 @@ struct server_context {
                    LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
                    LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                    llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
-                    llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                    llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
+                    llama_past_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
+                    llama_past_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
+                    llama_past_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
 
                    slot.n_past_se -= bd;
 
@@ -399,14 +399,15 @@ int main(int argc, char ** argv) {
        {
            LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
-            llama_kv_cache_seq_keep(ctx_dft, s_keep);
-            llama_kv_cache_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
-            llama_kv_cache_seq_keep(ctx_dft, 0);
+            llama_past_seq_keep(ctx_dft, s_keep);
+            llama_past_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
+            llama_past_seq_keep(ctx_dft, 0);
 
-            llama_kv_cache_seq_rm  (ctx_tgt, s_keep, n_past_tgt, -1);
-            llama_kv_cache_seq_keep(ctx_tgt, s_keep);
-            llama_kv_cache_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
-            llama_kv_cache_seq_keep(ctx_tgt, 0);
+            // FIXME: recurrent and hybrid models
+            llama_past_seq_rm  (ctx_tgt, s_keep, n_past_tgt, -1);
+            llama_past_seq_keep(ctx_tgt, s_keep);
+            llama_past_seq_cp  (ctx_tgt, s_keep, 0, -1, -1);
+            llama_past_seq_keep(ctx_tgt, 0);
        }
 
        for (int s = 0; s < n_seq_dft; ++s) {
@@ -423,7 +424,8 @@ int main(int argc, char ** argv) {
            llama_batch_clear(batch_dft);
            llama_batch_add  (batch_dft, token_id, n_past_dft, { 0 }, true);
 
-            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
+            // FIXME: recurrent and hybrid models
+            llama_past_seq_rm(ctx_dft, 0, n_past_dft, -1);
            // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
            llama_decode(ctx_dft, batch_dft);
 
@@ -479,8 +481,8 @@ int main(int argc, char ** argv) {
                if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
                    LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
 
-                    llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
-                    llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
+                    llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1);
+                    llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
 
                    // all previous tokens from this branch are now also part of the new branch
                    for (int t = 0; t < batch_tgt.n_tokens; ++t) {
@@ -558,9 +560,9 @@ int main(int argc, char ** argv) {
 
        // evaluate the target model on the drafted tokens
        {
-            llama_kv_cache_seq_keep(ctx_tgt, 0);
+            llama_past_seq_keep(ctx_tgt, 0);
            for (int s = 1; s < n_seq_dft; ++s) {
-                llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
+                llama_past_seq_cp(ctx_tgt, 0, s, -1, -1);
            }
 
            // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
@@ -19817,7 +19817,6 @@ struct ggml_cplan ggml_graph_plan(
                    cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
                }
            } break;
-
        case GGML_OP_CROSS_ENTROPY_LOSS:
            {
                cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
@@ -215,6 +215,7 @@ class MODEL_ARCH(IntEnum):
    STARCODER2 = auto()
    RWKV6 = auto()
    MAMBA = auto()
+    JAMBA = auto()
    XVERSE = auto()
    COMMAND_R = auto()
    DBRX = auto()
@@ -274,7 +275,10 @@ class MODEL_TENSOR(IntEnum):
    SSM_CONV1D = auto()
    SSM_X = auto()
    SSM_DT = auto()
+    SSM_DT_NORM = auto()
    SSM_A = auto()
+    SSM_B_NORM = auto()
+    SSM_C_NORM = auto()
    SSM_D = auto()
    SSM_OUT = auto()
    TIME_MIX_W1 = auto()
@@ -369,6 +373,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.STARCODER2: "starcoder2",
    MODEL_ARCH.RWKV6: "rwkv6",
    MODEL_ARCH.MAMBA: "mamba",
+    MODEL_ARCH.JAMBA: "jamba",
    MODEL_ARCH.XVERSE: "xverse",
    MODEL_ARCH.COMMAND_R: "command-r",
    MODEL_ARCH.DBRX: "dbrx",
@@ -428,7 +433,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
    MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
    MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
+    MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
    MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
+    MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
+    MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
    MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
    MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
    MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
@@ -954,6 +962,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.SSM_D,
        MODEL_TENSOR.SSM_OUT,
    ],
+    MODEL_ARCH.JAMBA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_X,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_DT_NORM,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_B_NORM,
+        MODEL_TENSOR.SSM_C_NORM,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_OUT,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
    MODEL_ARCH.XVERSE: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
@@ -238,6 +238,8 @@ class TensorNameMap:
            "transformer.decoder_layer.{bid}.rms_norm_2",       # Grok
            "encoder.layers.{bid}.post_attention_layernorm",    # chatglm
            "transformer.layers.{bid}.ffn_norm",                # openelm
+            "model.layers.{bid}.pre_ff_layernorm",              # jamba
+            "model.layers.{bid}.pre_moe_layernorm",             # mini-jamba
        ),
 
        # Post feed-forward norm
@@ -256,6 +258,7 @@ class TensorNameMap:
            "model.layers.{bid}.mlp.gate",               # qwen2moe
            "transformer.decoder_layer.{bid}.router",    # Grok
            "transformer.blocks.{bid}.ffn.router.layer", # dbrx
+            "model.layers.{bid}.feed_forward.router",    # jamba
        ),
 
        MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -287,6 +290,7 @@ class TensorNameMap:
            "model.layers.{bid}.mlp.c_fc",                # starcoder2
            "encoder.layer.{bid}.mlp.gated_layers_v",     # jina-bert-v2
            "model.layers.{bid}.residual_mlp.w3",         # arctic
+            "model.layers.{bid}.feed_forward.up_proj",    # jamba
            "encoder.layers.{bid}.mlp.dense_h_to_4h",     # chatglm
            "transformer.h.{bid}.mlp.c_fc_1",             # exaone
        ),
@@ -320,6 +324,7 @@ class TensorNameMap:
            "encoder.layer.{bid}.mlp.gated_layers_w",     # jina-bert-v2
            "transformer.h.{bid}.mlp.linear_1",           # refact
            "model.layers.{bid}.residual_mlp.w1",         # arctic
+            "model.layers.{bid}.feed_forward.gate_proj",  # jamba
            "transformer.h.{bid}.mlp.c_fc_0",             # exaone
        ),
 
@@ -359,6 +364,7 @@ class TensorNameMap:
            "transformer.layers.{bid}.ffn.proj_2",        # openelm
            "model.layers.{bid}.residual_mlp.w2",         # arctic
            "encoder.layer.{bid}.mlp.down_layer",         # jina-bert-v2
+            "model.layers.{bid}.feed_forward.down_proj",  # jamba
            "encoder.layers.{bid}.mlp.dense_4h_to_h",     # chatglm
            "model.layers.h.{bid}.mlp.c_proj",            # exaone
        ),
@@ -406,38 +412,59 @@ class TensorNameMap:
        ),
 
        MODEL_TENSOR.SSM_IN: (
-            "model.layers.{bid}.in_proj",
-            "backbone.layers.{bid}.mixer.in_proj",
+            "model.layers.{bid}.in_proj",           # mamba-hf
+            "backbone.layers.{bid}.mixer.in_proj",  # mamba
+            "model.layers.{bid}.mamba.in_proj",     # jamba
        ),
 
        MODEL_TENSOR.SSM_CONV1D: (
-            "model.layers.{bid}.conv1d",
-            "backbone.layers.{bid}.mixer.conv1d",
+            "model.layers.{bid}.conv1d",            # mamba-hf
+            "backbone.layers.{bid}.mixer.conv1d",   # mamba
+            "model.layers.{bid}.mamba.conv1d",      # jamba
        ),
 
        MODEL_TENSOR.SSM_X: (
-            "model.layers.{bid}.x_proj",
-            "backbone.layers.{bid}.mixer.x_proj",
+            "model.layers.{bid}.x_proj",            # mamba-hf
+            "backbone.layers.{bid}.mixer.x_proj",   # mamba
+            "model.layers.{bid}.mamba.x_proj",      # jamba
        ),
 
        MODEL_TENSOR.SSM_DT: (
-            "model.layers.{bid}.dt_proj",
-            "backbone.layers.{bid}.mixer.dt_proj",
+            "model.layers.{bid}.dt_proj",           # mamba-hf
+            "backbone.layers.{bid}.mixer.dt_proj",  # mamba
+            "model.layers.{bid}.mamba.dt_proj",     # jamba
+        ),
+
+        MODEL_TENSOR.SSM_DT_NORM: (
+            "model.layers.{bid}.mamba.dt_layernorm", # jamba
        ),
 
        MODEL_TENSOR.SSM_A: (
-            "model.layers.{bid}.A_log",
-            "backbone.layers.{bid}.mixer.A_log",
+            "model.layers.{bid}.A_log",             # mamba-hf
+            "backbone.layers.{bid}.mixer.A_log",    # mamba
+            "model.layers.{bid}.mamba.A_log",       # jamba
+        ),
+
+        MODEL_TENSOR.SSM_B_NORM: (
+            "model.layers.{bid}.mamba.b_layernorm", # jamba
+            "model.layers.{bid}.mamba.B_layernorm", # mini-jamba
+        ),
+
+        MODEL_TENSOR.SSM_C_NORM: (
+            "model.layers.{bid}.mamba.c_layernorm", # jamba
+            "model.layers.{bid}.mamba.C_layernorm", # mini-jamba
        ),
 
        MODEL_TENSOR.SSM_D: (
-            "model.layers.{bid}.D",
-            "backbone.layers.{bid}.mixer.D",
+            "model.layers.{bid}.D",                 # mamba-hf
+            "backbone.layers.{bid}.mixer.D",        # mamba
+            "model.layers.{bid}.mamba.D",           # jamba
        ),
 
        MODEL_TENSOR.SSM_OUT: (
-            "model.layers.{bid}.out_proj",
-            "backbone.layers.{bid}.mixer.out_proj",
+            "model.layers.{bid}.out_proj",          # mamba-hf
+            "backbone.layers.{bid}.mixer.out_proj", # mamba
+            "model.layers.{bid}.mamba.out_proj",    # jamba
        ),
 
        MODEL_TENSOR.TIME_MIX_W1: (
@@ -38,10 +38,10 @@
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 8
+#define LLAMA_SESSION_VERSION 9
 
 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
-#define LLAMA_STATE_SEQ_VERSION 2
+#define LLAMA_STATE_SEQ_VERSION 3
 
 #ifdef __cplusplus
 extern "C" {
@@ -621,6 +621,12 @@ extern "C" {
    // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
    LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
 
+    // Rebuild and check the validity of the recurrent state cache's tree of sequences.
+    // (slow, use only for debugging purposes)
+    // Returns whether or not the rs cache was valid.
+    // The errors are always corrected, but only logged when debug is true.
+    LLAMA_API bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug);
+
    // Returns the number of tokens in the KV cache (slow, use only for debug)
    // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
    LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
@@ -628,36 +634,62 @@ extern "C" {
    // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
    LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
 
-    // Clear the KV cache - both cell info is erased and KV data is zeroed
-    LLAMA_API void llama_kv_cache_clear(
+    // Returns the number of used recurrent state cells (i.e. have at least one sequence assigned to them)
+    LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx);
+
+    // Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed
+    LLAMA_API void llama_past_clear(
            struct llama_context * ctx);
+    LLAMA_API DEPRECATED(void llama_kv_cache_clear(
+            struct llama_context * ctx),
+        "use llama_past_clear instead");
 
    // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
-    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
    // seq_id < 0 : match any sequence
    // p0 < 0 : [0, p1]
    // p1 < 0 : [p0, inf)
-    LLAMA_API bool llama_kv_cache_seq_rm(
+    // Returns n_past (one more than the largest remaining pos in the seq_id)
+    // which is only meaningful to handle for partial removals.
+    LLAMA_API llama_pos llama_past_seq_rm(
            struct llama_context * ctx,
            llama_seq_id seq_id,
            llama_pos p0,
            llama_pos p1);
+    LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1),
+        "use llama_past_seq_rm instead, and handle its return value for partial removals");
 
    // Copy all tokens that belong to the specified sequence to another sequence
-    // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
+    // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence
    // p0 < 0 : [0, p1]
    // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_cp(
+    // Returns n_past (one more than the largest remaining pos in the destination seq_id)
+    // which is only meaningful to handle when partially copying.
+    LLAMA_API llama_pos llama_past_seq_cp(
            struct llama_context * ctx,
            llama_seq_id seq_id_src,
            llama_seq_id seq_id_dst,
            llama_pos p0,
            llama_pos p1);
+    LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp(
+            struct llama_context * ctx,
+            llama_seq_id seq_id_src,
+            llama_seq_id seq_id_dst,
+            llama_pos p0,
+            llama_pos p1),
+        "use llama_past_seq_cp instead, and handle its return value for partial copies");
 
    // Removes all tokens that do not belong to the specified sequence
-    LLAMA_API void llama_kv_cache_seq_keep(
+    LLAMA_API void llama_past_seq_keep(
            struct llama_context * ctx,
            llama_seq_id seq_id);
+    LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep(
+            struct llama_context * ctx,
+            llama_seq_id seq_id),
+        "use llama_past_seq_keep instead");
 
    // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
    // If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -665,12 +697,19 @@ extern "C" {
    //   - explicitly with llama_kv_cache_update()
    // p0 < 0 : [0, p1]
    // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_add(
+    LLAMA_API void llama_past_seq_add(
            struct llama_context * ctx,
            llama_seq_id seq_id,
            llama_pos p0,
            llama_pos p1,
            llama_pos delta);
+    LLAMA_API DEPRECATED(void llama_kv_cache_seq_add(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1,
+            llama_pos delta),
+        "use llama_past_seq_add instead");
 
    // Integer division of the positions by factor of `d > 1`
    // If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -678,17 +717,28 @@ extern "C" {
    //   - explicitly with llama_kv_cache_update()
    // p0 < 0 : [0, p1]
    // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_div(
+    LLAMA_API void llama_past_seq_div(
            struct llama_context * ctx,
            llama_seq_id seq_id,
            llama_pos p0,
            llama_pos p1,
            int d);
+    LLAMA_API DEPRECATED(void llama_kv_cache_seq_div(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1,
+            int d),
+        "use llama_past_seq_div instead");
 
-    // Returns the largest position present in the KV cache for the specified sequence
-    LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+    // Returns the largest position present in the KV and/or RS cache for the specified sequence
+    LLAMA_API llama_pos llama_past_seq_pos_max(
            struct llama_context * ctx,
            llama_seq_id seq_id);
+    LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max(
+            struct llama_context * ctx,
+            llama_seq_id seq_id),
+        "use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells");
 
    // Defragment the KV cache
    // This will be applied:
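Taken together, the header changes above replace the boolean `llama_kv_cache_seq_rm` with `llama_past_seq_rm`, which returns the resulting n_past so callers can detect partial removals: recurrent and hybrid models may roll the state back further than requested, as the server hunk earlier in this comparison handles. A hedged sketch of the caller-side migration (illustrative only; the real call sites are the example hunks above):

    #include "llama.h"

    // Drop everything past `keep` tokens in sequence `seq` and report how many
    // tokens actually remain. With recurrent/hybrid models the cache may only
    // be able to roll back to an earlier checkpoint, so the returned n_past can
    // be smaller than `keep`.
    static llama_pos truncate_seq(llama_context * ctx, llama_seq_id seq, llama_pos keep) {
        const llama_pos n_past = llama_past_seq_rm(ctx, seq, keep, -1);
        if (n_past < keep) {
            // fewer tokens remain than requested: the caller should re-decode
            // the tokens in [n_past, keep) before continuing generation
        }
        return n_past;
    }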
src/llama.cpp: 2920 changed lines
File diff suppressed because it is too large