diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d9ca2fdd9..0db99c167 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -189,6 +189,7 @@ struct llama_client_slot int32_t ga_n = 1; // group-attention factor int32_t ga_w = 512; // group-attention width + int32_t n_past_self_extension = 0; // multimodal std::vector images; @@ -218,6 +219,7 @@ struct llama_client_slot sent_token_probs_index = 0; infill = false; ga_i = 0; + n_past_self_extension = 0; generated_token_probs.clear(); for (slot_image & img : images) @@ -1427,9 +1429,8 @@ struct llama_server_context } slot.i_batch = batch.n_tokens; - - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot.n_past, { slot.id }, true); - + int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past; + llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true); slot.n_past += 1; } @@ -1526,6 +1527,7 @@ struct llama_server_context llama_sampling_reset(slot.ctx_sampling); slot.n_past = 0; + slot.n_past_self_extension = 0; slot.num_prompt_tokens_processed = slot.num_prompt_tokens; } else @@ -1553,6 +1555,10 @@ struct llama_server_context // we have to evaluate at least 1 token to generate logits. LOG_TEE("slot %d : we have to evaluate at least 1 token to generate logits\n", slot.id); slot.n_past--; + if(slot.n_past_self_extension > 0) + { + slot.n_past_self_extension--; + } } LOG_VERBOSE("prompt ingested", { @@ -1565,10 +1571,10 @@ struct llama_server_context // process the prefix of first image std::vector prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens; - int32_t slot_npast = slot.n_past; + int32_t slot_npast = slot.n_past_self_extension > 0 ? slot.n_past_self_extension : slot.n_past; int ga_i = slot.ga_i; - int ga_n = slot.ga_n; - int ga_w = slot.ga_w; + int32_t ga_n = slot.ga_n; + int32_t ga_w = slot.ga_w; for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past) { if(slot.ga_n != 1) @@ -1579,11 +1585,10 @@ struct llama_server_context ga_i += ga_w/ga_n; } } - llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false); + llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false); slot_npast += 1; } - if (has_images && !ingest_images(slot, n_batch)) { LOG_TEE("failed processing images\n"); @@ -1617,38 +1622,37 @@ struct llama_server_context for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); - int slot_id = 0; + for (auto & slot : slots) { - - if(slot.ga_n != 1) + if (slot.ga_n != 1) { + // context extension via Self-Extend - while (slot_npasts[slot_id] >= slot.ga_i + slot.ga_w) { + while (slot.n_past_self_extension >= slot.ga_i + slot.ga_w) + { const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; LOG_TEE("\n"); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot_npasts[slot_id], ib * bd, slot.ga_i + ib * bd, slot_npasts[slot_id] + ib * bd); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_self_extension, ib * bd, slot.ga_i + ib * bd, slot.n_past_self_extension + ib * bd); LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot_npasts[slot_id] + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot_npasts[slot_id] + ib * bd + dd); + LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_self_extension + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_self_extension + ib * bd + dd); - llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot_npasts[slot_id], ib * bd); - llama_kv_cache_seq_div (ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); - llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot_npasts[slot_id] + ib * bd, dd); + llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_self_extension, ib * bd); + llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n); + llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_self_extension + ib * bd, dd); - slot_npasts[slot_id] -= bd; + slot.n_past_self_extension -= bd; slot.ga_i += slot.ga_w / slot.ga_n; - LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot_npasts[slot_id] + bd, slot_npasts[slot_id], slot.ga_i); + LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_self_extension + bd, slot.n_past_self_extension, slot.ga_i); } + slot.n_past_self_extension += n_tokens; } - slot_npasts[slot_id] += n_tokens; - slot_id++; } - llama_batch batch_view = { n_tokens, @@ -1682,7 +1686,6 @@ struct llama_server_context for (auto & slot : slots) { - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { continue;