cont : move kv_self update to llama_context
ggml-ci
parent f2524c0e41
commit b4ec1d4429

3 changed files with 157 additions and 154 deletions
@@ -32,6 +32,38 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
+enum ggml_status llama_context::compute_graph(
+        ggml_cgraph * graph,
+        bool batched) {
+    int n_threads        = batched ? cparams.n_threads_batch : cparams.n_threads;
+    ggml_threadpool_t tp = batched ? threadpool_batch        : threadpool;
+
+    if (backend_cpu != nullptr) {
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
+        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
+        set_threadpool_fn(backend_cpu, tp);
+    }
+
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
+    }
+
+    auto status = ggml_backend_sched_graph_compute_async(sched.get(), graph);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
+    }
+
+    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
+
+    return status;
+}
+
+llama_pos llama_context::pos_max() const {
+    return kv_self.pos_max();
+}
+
 // TODO: improve
 void llama_context::reset() {
     inp_tokens = nullptr;
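Note: `compute_graph()` is the low-level entry point that both the regular decode path and the KV-cache maintenance graphs in the next hunk go through. A minimal caller sketch, mirroring the pattern used in `kv_self_update()` below (the `ctx0`, `gf`, and `n_tokens` names here are illustrative, not part of this diff):

    // illustrative sketch, not part of this commit: build a graph, let the
    // scheduler allocate it, then run it through compute_graph()
    ggml_backend_sched_reset(sched.get());

    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
    // ... add nodes to gf ...

    ggml_backend_sched_alloc_graph(sched.get(), gf);

    const enum ggml_status status = compute_graph(gf, /*batched =*/ n_tokens > 1);
    if (status != GGML_STATUS_SUCCESS) {
        // compute_graph() has already logged the error; propagate or abort here
    }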
@@ -540,6 +572,93 @@ ggml_tensor * llama_context::build_lora_mm_id(
     return res;
 }
 
+bool llama_context::kv_self_update() {
+    bool need_reserve = false;
+
+    auto & kv = kv_self;
+
+    if (kv.has_shift) {
+        if (!kv.can_shift) {
+            GGML_ABORT("The current context does not support K-shift");
+        }
+
+        // apply K-shift if needed
+        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            prepare_k_shift();
+
+            ggml_backend_sched_reset(sched.get());
+
+            struct ggml_init_params params = {
+                /*.mem_size   =*/ buf_compute_meta.size(),
+                /*.mem_buffer =*/ buf_compute_meta.data(),
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx0 = ggml_init(params);
+
+            reset();
+
+            ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+            build_k_shift(ctx0, gf);
+
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+            set_inputs({});
+
+            compute_graph(gf, false);
+
+            ggml_free(ctx0);
+
+            need_reserve = true;
+        }
+
+        {
+            kv.has_shift = false;
+
+            for (uint32_t i = 0; i < kv.size; ++i) {
+                kv.cells[i].delta = 0;
+            }
+        }
+    }
+
+    // defragment the KV cache if needed
+    if (kv.do_defrag) {
+        prepare_defrag();
+
+        ggml_backend_sched_reset(sched.get());
+
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ buf_compute_meta.size(),
+            /*.mem_buffer =*/ buf_compute_meta.data(),
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context * ctx0 = ggml_init(params);
+
+        reset();
+
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        build_defrag(ctx0, gf);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        // no input
+        //set_inputs({});
+
+        compute_graph(gf, false);
+
+        ggml_free(ctx0);
+
+        need_reserve = true;
+
+        kv.do_defrag = false;
+    }
+
+    return need_reserve;
+}
+
 void llama_context::build_attn_inp(
         ggml_context * ctx0,
         int32_t n_tokens,
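Note: the boolean returned by `kv_self_update()` tells the caller that the K-shift/defrag graphs invalidated the scheduler's reserved compute buffers. A plausible consumer in the decode path might look like the sketch below; `build_worst_case_graph()` is an illustrative stand-in for however the caller constructs its largest decode graph, not part of this diff:

    // illustrative sketch, not part of this commit: after the maintenance
    // graphs ran, re-reserve the worst-case graph so that subsequent
    // ggml_backend_sched_alloc_graph() calls have enough splits/buffers
    if (kv_self_update()) {
        ggml_backend_sched_reset(sched.get());

        ggml_cgraph * gf = build_worst_case_graph();  // hypothetical helper

        if (!ggml_backend_sched_reserve(sched.get(), gf)) {
            LLAMA_LOG_ERROR("%s: failed to reserve the compute buffers\n", __func__);
        }
    }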