cont : move kv_self update to llama_context

ggml-ci
Georgi Gerganov 2025-01-16 21:55:12 +02:00
parent f2524c0e41
commit b4ec1d4429
3 changed files with 157 additions and 154 deletions


@@ -32,6 +32,38 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
    return relative_bucket;
}

enum ggml_status llama_context::compute_graph(
        ggml_cgraph * graph,
        bool batched) {
    int n_threads        = batched ? cparams.n_threads_batch : cparams.n_threads;
    ggml_threadpool_t tp = batched ? threadpool_batch        : threadpool;
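
    // the CPU backend exposes its threadpool setter through the backend registry,
    // so look it up by name and apply the selected threadpool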
    if (backend_cpu != nullptr) {
        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
        set_threadpool_fn(backend_cpu, tp);
    }

    // set the number of threads for all the backends
    for (const auto & set_n_threads_fn : set_n_threads_fns) {
        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
    }

    auto status = ggml_backend_sched_graph_compute_async(sched.get(), graph);
    if (status != GGML_STATUS_SUCCESS) {
        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
    }

    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));

    return status;
}

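// return the largest position currently stored in the KV cache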
llama_pos llama_context::pos_max() const {
    return kv_self.pos_max();
}

// TODO: improve
void llama_context::reset() {
    inp_tokens = nullptr;
@@ -540,6 +572,93 @@ ggml_tensor * llama_context::build_lora_mm_id(
    return res;
}

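// apply any pending K-shift and defragmentation to the KV cache;
// returns true when a graph was computed, in which case the caller likely needs to re-reserve the scheduler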
bool llama_context::kv_self_update() {
    bool need_reserve = false;

    auto & kv = kv_self;

    if (kv.has_shift) {
        if (!kv.can_shift) {
            GGML_ABORT("The current context does not support K-shift");
        }

        // apply K-shift if needed
        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
            prepare_k_shift();

            ggml_backend_sched_reset(sched.get());

            struct ggml_init_params params = {
                /*.mem_size   =*/ buf_compute_meta.size(),
                /*.mem_buffer =*/ buf_compute_meta.data(),
                /*.no_alloc   =*/ true,
            };

            ggml_context * ctx0 = ggml_init(params);

            reset();
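
            // build and run a minimal graph that applies the pending shift to the cached K data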
            ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);

            build_k_shift(ctx0, gf);

            ggml_backend_sched_alloc_graph(sched.get(), gf);
            set_inputs({});

            compute_graph(gf, false);
            ggml_free(ctx0);

            need_reserve = true;
        }
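
        // the shift has been applied - clear the shift flag and the per-cell deltas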
        {
            kv.has_shift = false;

            for (uint32_t i = 0; i < kv.size; ++i) {
                kv.cells[i].delta = 0;
            }
        }
    }

    // defragment the KV cache if needed
    if (kv.do_defrag) {
        prepare_defrag();

        ggml_backend_sched_reset(sched.get());

        struct ggml_init_params params = {
            /*.mem_size   =*/ buf_compute_meta.size(),
            /*.mem_buffer =*/ buf_compute_meta.data(),
            /*.no_alloc   =*/ true,
        };

        ggml_context * ctx0 = ggml_init(params);

        reset();
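
        // build and run the defragmentation graph (it rearranges the KV cells in place, so no batch inputs are set)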
        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);

        build_defrag(ctx0, gf);

        ggml_backend_sched_alloc_graph(sched.get(), gf);

        // no input
        //set_inputs({});

        compute_graph(gf, false);
        ggml_free(ctx0);

        need_reserve = true;

        kv.do_defrag = false;
    }

    return need_reserve;
}

void llama_context::build_attn_inp(
        ggml_context * ctx0,
        int32_t n_tokens,