diff --git a/examples/server/server.cpp b/examples/server/server.cpp index bf0220cff..c10708392 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -184,12 +184,13 @@ struct llama_client_slot struct llama_sampling_params sparams; llama_sampling_context *ctx_sampling = nullptr; - int ga_i = 0; // group-attention state + int ga_i = 0; // group-attention state - int32_t ga_n = 1; // group-attention factor - int32_t ga_w = 512; // group-attention width + int32_t ga_n = 1; // group-attention factor + int32_t ga_w = 512; // group-attention width int32_t n_past_self_extension = 0; + // multimodal std::vector images; @@ -218,8 +219,8 @@ struct llama_client_slot sent_count = 0; sent_token_probs_index = 0; infill = false; - ga_i = 0; - n_past_self_extension = 0; + ga_i = 0; + n_past_self_extension = 0; generated_token_probs.clear(); for (slot_image & img : images) @@ -406,6 +407,7 @@ struct llama_server_context slot.id = i; slot.n_ctx = n_ctx_slot; + LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot); const int ga_n = params.grp_attn_n; @@ -425,7 +427,6 @@ struct llama_server_context slot.reset(); - slots.push_back(slot); } @@ -1377,14 +1378,12 @@ struct llama_server_context { if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx) { - // Shift context - const int n_left = slot.n_past - slot.params.n_keep - 1; + const int n_left = slot.n_past - slot.params.n_keep - 1; const int n_discard = n_left / 2; - LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, - slot.params.n_keep, n_left, n_discard); - llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1, slot.params.n_keep + n_discard + 1); + LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard); + llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1); llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard); for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) @@ -1429,8 +1428,10 @@ struct llama_server_context } slot.i_batch = batch.n_tokens; - int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past; + + const int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past; llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true); + slot.n_past += 1; } @@ -1542,7 +1543,7 @@ struct llama_server_context slot.n_past = common_part(slot.cache_tokens, prompt_tokens); slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past; - if(slot.ga_n != 1) + if (slot.ga_n != 1) { int ga_i = 0; int32_t ga_n = slot.ga_n; @@ -1559,7 +1560,6 @@ struct llama_server_context } slot.n_past_self_extension = slot_npast; slot.ga_i = ga_i; - } LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); @@ -1576,7 +1576,7 @@ struct llama_server_context // we have to evaluate at least 1 token to generate logits. LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id); slot.n_past--; - if(slot.ga_i > 0) + if (slot.ga_i > 0) { slot.n_past_self_extension--; } @@ -1598,7 +1598,7 @@ struct llama_server_context int32_t ga_w = slot.ga_w; for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) { - if(slot.ga_n != 1) + if (slot.ga_n != 1) { while (slot_npast >= ga_i + ga_w) { const int bd = (ga_w/ga_n)*(ga_n - 1); @@ -1642,7 +1642,6 @@ struct llama_server_context { if (slot.ga_n != 1) { - // context extension via Self-Extend while (slot.n_past_self_extension >= slot.ga_i + slot.ga_w) { @@ -1752,7 +1751,6 @@ struct llama_server_context slot.i_batch = -1; } } - return true; } @@ -2015,14 +2013,16 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.n_threads = std::stoi(argv[i]); } - else if (arg == "--grp-attn-n" || arg == "-gan") { + else if (arg == "--grp-attn-n" || arg == "-gan") + { if (++i >= argc) { invalid_param = true; break; } params.grp_attn_n = std::stoi(argv[i]); - } else if (arg == "--grp-attn-w" || arg == "-gaw") + } + else if (arg == "--grp-attn-w" || arg == "-gaw") { if (++i >= argc) {