server : formatting

This commit is contained in:
Georgi Gerganov 2024-01-27 12:39:33 +02:00 committed by GitHub
parent 9d7b7e686c
commit 26f95fb079
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -190,6 +190,7 @@ struct llama_client_slot
int32_t ga_w = 512; // group-attention width int32_t ga_w = 512; // group-attention width
int32_t n_past_self_extension = 0; int32_t n_past_self_extension = 0;
// multimodal // multimodal
std::vector<slot_image> images; std::vector<slot_image> images;
@ -406,6 +407,7 @@ struct llama_server_context
slot.id = i; slot.id = i;
slot.n_ctx = n_ctx_slot; slot.n_ctx = n_ctx_slot;
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot); LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
const int ga_n = params.grp_attn_n; const int ga_n = params.grp_attn_n;
@ -425,7 +427,6 @@ struct llama_server_context
slot.reset(); slot.reset();
slots.push_back(slot); slots.push_back(slot);
} }
@ -1377,13 +1378,11 @@ struct llama_server_context
{ {
if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx) if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
{ {
// Shift context // Shift context
const int n_left = slot.n_past - slot.params.n_keep - 1; const int n_left = slot.n_past - slot.params.n_keep - 1;
const int n_discard = n_left / 2; const int n_discard = n_left / 2;
LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
slot.params.n_keep, n_left, n_discard);
llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1); llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1);
llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard); llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
@ -1429,8 +1428,10 @@ struct llama_server_context
} }
slot.i_batch = batch.n_tokens; slot.i_batch = batch.n_tokens;
int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past;
const int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past;
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true); llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
slot.n_past += 1; slot.n_past += 1;
} }
@ -1559,7 +1560,6 @@ struct llama_server_context
} }
slot.n_past_self_extension = slot_npast; slot.n_past_self_extension = slot_npast;
slot.ga_i = ga_i; slot.ga_i = ga_i;
} }
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
@ -1642,7 +1642,6 @@ struct llama_server_context
{ {
if (slot.ga_n != 1) if (slot.ga_n != 1)
{ {
// context extension via Self-Extend // context extension via Self-Extend
while (slot.n_past_self_extension >= slot.ga_i + slot.ga_w) while (slot.n_past_self_extension >= slot.ga_i + slot.ga_w)
{ {
@ -1752,7 +1751,6 @@ struct llama_server_context
slot.i_batch = -1; slot.i_batch = -1;
} }
} }
return true; return true;
} }
@ -2015,14 +2013,16 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
params.n_threads = std::stoi(argv[i]); params.n_threads = std::stoi(argv[i]);
} }
else if (arg == "--grp-attn-n" || arg == "-gan") { else if (arg == "--grp-attn-n" || arg == "-gan")
{
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.grp_attn_n = std::stoi(argv[i]); params.grp_attn_n = std::stoi(argv[i]);
} else if (arg == "--grp-attn-w" || arg == "-gaw") }
else if (arg == "--grp-attn-w" || arg == "-gaw")
{ {
if (++i >= argc) if (++i >= argc)
{ {