diff --git a/examples/server/server.cpp b/examples/server/server.cpp index bd67c04b6..50123975f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -184,6 +184,12 @@ struct llama_client_slot struct llama_sampling_params sparams; llama_sampling_context *ctx_sampling = nullptr; + int32_t ga_i = 0; // group-attention state + int32_t ga_n = 1;// group-attention factor + int32_t ga_w = 512; // group-attention width + + int32_t n_past_se = 0; // self-extend + // multimodal std::vector images; @@ -212,7 +218,8 @@ struct llama_client_slot sent_count = 0; sent_token_probs_index = 0; infill = false; - + ga_i = 0; + n_past_se = 0; generated_token_probs.clear(); for (slot_image & img : images) @@ -399,9 +406,26 @@ struct llama_server_context slot.id = i; slot.n_ctx = n_ctx_slot; - slot.reset(); LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot); + + const int ga_n = params.grp_attn_n; + const int ga_w = params.grp_attn_w; + + if (ga_n != 1) { + GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT + GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT + //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT + //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT + LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w); + } + + slot.ga_i = 0; + slot.ga_n = ga_n; + slot.ga_w = ga_w; + + slot.reset(); + slots.push_back(slot); } @@ -1202,7 +1226,8 @@ struct llama_server_context (json)(slot.images[image_idx].prefix_prompt); std::vector append_tokens = tokenize(json_prompt, false); // has next image - for (int append_token : append_tokens) { + for (int append_token : append_tokens) + { llama_batch_add(batch, append_token, slot.n_past, { slot.id }, true); slot.n_past += 1; } @@ -1221,12 +1246,12 @@ struct llama_server_context void split_multiprompt_task(int multitask_id, task_server& multiprompt_task) { - std::size_t prompt_count = multiprompt_task.data.at("prompt").size(); + int prompt_count = int(multiprompt_task.data.at("prompt").size()); assert(prompt_count > 1); // generate all the ID for subtask std::vector subtask_ids(prompt_count); - for (std::size_t i = 0; i < prompt_count; i++) + for (int i = 0; i < prompt_count; i++) { subtask_ids[i] = queue_tasks.get_new_id(); } @@ -1235,7 +1260,7 @@ struct llama_server_context queue_tasks.add_multitask(multitask_id, subtask_ids); // add subtasks - for (std::size_t i = 0; i < prompt_count; i++) + for (int i = 0; i < prompt_count; i++) { json subtask_data = multiprompt_task.data; subtask_data["prompt"] = subtask_data["prompt"][i]; @@ -1349,32 +1374,35 @@ struct llama_server_context for (llama_client_slot &slot : slots) { - if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx) + if (slot.ga_n == 1) { - // Shift context - const int n_left = slot.n_past - slot.params.n_keep - 1; - const int n_discard = n_left / 2; - - LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard); - llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard); - - for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) + if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx) { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + // Shift context + const int n_left = slot.n_past - slot.params.n_keep - 1; + const int n_discard = n_left / 2; + + LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard); + llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard); + + for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) + { + slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; + } + + slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + + slot.n_past -= n_discard; + + slot.truncated = true; + + LOG_VERBOSE("context shift", { + { "n_ctx", n_ctx }, + { "n_keep", params.n_keep }, + { "n_left", n_left }, + }); } - - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); - - slot.n_past -= n_discard; - - slot.truncated = true; - - LOG_VERBOSE("context shift", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - }); } } @@ -1401,7 +1429,8 @@ struct llama_server_context slot.i_batch = batch.n_tokens; - llama_batch_add(batch, slot.sampled, llama_pos(system_tokens.size() + slot.n_past), { slot.id }, true); + const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; + llama_batch_add(batch, slot.sampled, int(system_tokens.size() + slot_npast), {slot.id}, true); slot.n_past += 1; } @@ -1499,6 +1528,8 @@ struct llama_server_context llama_sampling_reset(slot.ctx_sampling); slot.n_past = 0; + slot.n_past_se = 0; + slot.ga_i = 0; slot.num_prompt_tokens_processed = slot.num_prompt_tokens; } else @@ -1512,6 +1543,25 @@ struct llama_server_context slot.n_past = int32_t(common_part(slot.cache_tokens, prompt_tokens)); slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past; + if (slot.ga_n != 1) + { + int ga_i = 0; + int32_t ga_n = slot.ga_n; + int32_t ga_w = slot.ga_w; + int32_t slot_npast = 0; + for (int k = 0; k < slot.n_past; ++k) + { + while (slot_npast >= ga_i + ga_w) { + const int bd = (ga_w/ga_n)*(ga_n - 1); + slot_npast -= bd; + ga_i += ga_w/ga_n; + } + slot_npast++; + } + slot.n_past_se = slot_npast; + slot.ga_i = ga_i; + } + LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); } @@ -1526,6 +1576,10 @@ struct llama_server_context // we have to evaluate at least 1 token to generate logits. LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id); slot.n_past--; + if (slot.ga_i > 0) + { + slot.n_past_se--; + } } LOG_VERBOSE("prompt ingested", { @@ -1538,9 +1592,22 @@ struct llama_server_context // process the prefix of first image std::vector prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens; + int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; + int ga_i = slot.ga_i; + int32_t ga_n = slot.ga_n; + int32_t ga_w = slot.ga_w; for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) { - llama_batch_add(batch, prefix_tokens[slot.n_past], llama_pos(system_tokens.size() + slot.n_past), { slot.id }, false); + if (slot.ga_n != 1) + { + while (slot_npast >= ga_i + ga_w) { + const int bd = (ga_w/ga_n)*(ga_n - 1); + slot_npast -= bd; + ga_i += ga_w/ga_n; + } + } + llama_batch_add(batch, prefix_tokens[slot.n_past], llama_pos(system_tokens.size() + slot_npast), {slot.id }, false); + slot_npast += 1; } if (has_images && !ingest_images(slot, n_batch)) @@ -1570,6 +1637,36 @@ struct llama_server_context for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + + for (auto & slot : slots) + { + if (slot.ga_n != 1) + { + // context extension via Self-Extend + while (slot.n_past_se >= slot.ga_i + slot.ga_w) + { + const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; + const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); + const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; + + LOG_TEE("\n"); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd); + LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); + + llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd); + llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n); + llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd); + + slot.n_past_se -= bd; + + slot.ga_i += slot.ga_w / slot.ga_n; + + LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i); + } + slot.n_past_se += n_tokens; + } + } llama_batch batch_view = { n_tokens, @@ -1583,6 +1680,7 @@ struct llama_server_context }; const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { if (n_batch == 1 || ret < 0) @@ -1728,6 +1826,8 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); + printf(" -gan N, --grp-attn-n N Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`"); + printf(" -gaw N, --grp-attn-w N Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`"); printf("\n"); } @@ -1913,6 +2013,25 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.n_threads = std::stoi(argv[i]); } + else if (arg == "--grp-attn-n" || arg == "-gan") + { + if (++i >= argc) { + invalid_param = true; + break; + } + + params.grp_attn_n = std::stoi(argv[i]); + } + else if (arg == "--grp-attn-w" || arg == "-gaw") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + + params.grp_attn_w = std::stoi(argv[i]); + } else if (arg == "--threads-batch" || arg == "-tb") { if (++i >= argc) @@ -2033,7 +2152,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } - params.lora_adapter.emplace_back(argv[i], 1.0f); + params.lora_adapter.push_back(std::make_tuple(argv[i], 1.0f)); params.use_mmap = false; } else if (arg == "--lora-scaled") @@ -2049,7 +2168,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } - params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); + params.lora_adapter.push_back(std::make_tuple(lora_adapter, std::stof(argv[i]))); params.use_mmap = false; } else if (arg == "--lora-base") @@ -2191,7 +2310,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } } if (!params.kv_overrides.empty()) { - params.kv_overrides.emplace_back(); + params.kv_overrides.emplace_back(llama_model_kv_override()); params.kv_overrides.back().key[0] = 0; } @@ -2625,11 +2744,12 @@ int main(int argc, char **argv) if (!llama_result.error) { std::vector result_array = format_partial_response_oaicompat( llama_result); - for (auto& it : result_array) { - if (!it.empty()) { + for (auto it = result_array.begin(); it != result_array.end(); ++it) + { + if (!it->empty()) { const std::string str = "data: " + - it.dump(-1, ' ', false, json::error_handler_t::replace) + + it->dump(-1, ' ', false, json::error_handler_t::replace) + "\n\n"; LOG_VERBOSE("data stream", {{"to_send", str}}); if (!sink.write(str.c_str(), str.size())) {