llama : cont
ggml-ci
This commit is contained in:
parent
f78b396ee7
commit
e4550fbafc
19 changed files with 128 additions and 79 deletions
|
@ -90,6 +90,9 @@ int main(int argc, char ** argv) {
|
|||
model_dft = llama_init_dft.model.get();
|
||||
ctx_dft = llama_init_dft.context.get();
|
||||
|
||||
llama_kv_cache * kv_tgt = llama_get_kv_cache(ctx_tgt);
|
||||
llama_kv_cache * kv_dft = llama_get_kv_cache(ctx_dft);
|
||||
|
||||
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
|
||||
|
@ -420,14 +423,14 @@ int main(int argc, char ** argv) {
|
|||
{
|
||||
LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
|
||||
|
||||
llama_kv_cache_seq_keep(ctx_dft, s_keep);
|
||||
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(ctx_dft, 0);
|
||||
llama_kv_cache_seq_keep(kv_dft, s_keep);
|
||||
llama_kv_cache_seq_cp (kv_dft, s_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(kv_dft, 0);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
|
||||
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
|
||||
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||
llama_kv_cache_seq_rm (kv_tgt, s_keep, n_past_tgt, -1);
|
||||
llama_kv_cache_seq_keep(kv_tgt, s_keep);
|
||||
llama_kv_cache_seq_cp (kv_tgt, s_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(kv_tgt, 0);
|
||||
}
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
|
@ -444,8 +447,8 @@ int main(int argc, char ** argv) {
|
|||
common_batch_clear(batch_dft);
|
||||
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||
llama_kv_cache_seq_rm(kv_dft, 0, n_past_dft, -1);
|
||||
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(kv_dft, batch_dft).c_str());
|
||||
llama_decode(ctx_dft, batch_dft);
|
||||
|
||||
++n_past_dft;
|
||||
|
@ -503,8 +506,8 @@ int main(int argc, char ** argv) {
|
|||
if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
|
||||
LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
|
||||
llama_kv_cache_seq_rm(kv_dft, n_seq_cur, -1, -1);
|
||||
llama_kv_cache_seq_cp(kv_dft, s, n_seq_cur, -1, -1);
|
||||
|
||||
// all previous tokens from this branch are now also part of the new branch
|
||||
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
|
||||
|
@ -585,9 +588,9 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// evaluate the target model on the drafted tokens
|
||||
{
|
||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||
llama_kv_cache_seq_keep(kv_tgt, 0);
|
||||
for (int s = 1; s < n_seq_dft; ++s) {
|
||||
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
|
||||
llama_kv_cache_seq_cp(kv_tgt, 0, s, -1, -1);
|
||||
}
|
||||
|
||||
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue