Update server.cpp

This commit is contained in:
Maximilian Winter 2024-01-27 14:13:34 +01:00
parent d56cecc2cd
commit 826e6dcad9

View file

@ -219,7 +219,7 @@ struct llama_client_slot
sent_token_probs_index = 0;
infill = false;
ga_i = 0;
n_past_self_extension = 0;
n_past_se = 0;
generated_token_probs.clear();
for (slot_image & img : images)
@ -1428,7 +1428,7 @@ struct llama_server_context
slot.i_batch = batch.n_tokens;
const int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past;
const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
slot.n_past += 1;
@ -1527,7 +1527,7 @@ struct llama_server_context
llama_sampling_reset(slot.ctx_sampling);
slot.n_past = 0;
slot.n_past_self_extension = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
}
@ -1557,7 +1557,7 @@ struct llama_server_context
}
slot_npast++;
}
slot.n_past_self_extension = slot_npast;
slot.n_past_se = slot_npast;
slot.ga_i = ga_i;
}
@ -1577,7 +1577,7 @@ struct llama_server_context
slot.n_past--;
if (slot.ga_i > 0)
{
slot.n_past_self_extension--;
slot.n_past_se--;
}
}
@ -1591,7 +1591,7 @@ struct llama_server_context
// process the prefix of first image
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past;
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
int ga_i = slot.ga_i;
int32_t ga_n = slot.ga_n;
int32_t ga_w = slot.ga_w;
@ -1642,28 +1642,28 @@ struct llama_server_context
if (slot.ga_n != 1)
{
// context extension via Self-Extend
while (slot.n_past_self_extension >= slot.ga_i + slot.ga_w)
while (slot.n_past_se >= slot.ga_i + slot.ga_w)
{
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
LOG_TEE("\n");
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_self_extension, ib * bd, slot.ga_i + ib * bd, slot.n_past_self_extension + ib * bd);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_self_extension + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_self_extension + ib * bd + dd);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_self_extension, ib * bd);
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_self_extension + ib * bd, dd);
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
slot.n_past_self_extension -= bd;
slot.n_past_se -= bd;
slot.ga_i += slot.ga_w / slot.ga_n;
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_self_extension + bd, slot.n_past_self_extension, slot.ga_i);
LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
}
slot.n_past_self_extension += n_tokens;
slot.n_past_se += n_tokens;
}
}
llama_batch batch_view =