Ported self extension to server example
This commit is contained in:
parent
26d607608d
commit
d7ac0d3d06
1 changed files with 116 additions and 26 deletions
|
@ -410,6 +410,11 @@ struct llama_client_slot
|
|||
struct llama_sampling_params sparams;
|
||||
llama_sampling_context *ctx_sampling = nullptr;
|
||||
|
||||
int ga_i = 0; // group-attention state
|
||||
|
||||
int32_t ga_n = 1; // group-attention factor
|
||||
int32_t ga_w = 512; // group-attention width
|
||||
|
||||
// multimodal
|
||||
std::vector<slot_image> images;
|
||||
|
||||
|
@ -438,7 +443,7 @@ struct llama_client_slot
|
|||
sent_count = 0;
|
||||
sent_token_probs_index = 0;
|
||||
infill = false;
|
||||
|
||||
ga_i = 0;
|
||||
generated_token_probs.clear();
|
||||
|
||||
for (slot_image & img : images)
|
||||
|
@ -633,9 +638,26 @@ struct llama_server_context
|
|||
|
||||
slot.id = i;
|
||||
slot.n_ctx = n_ctx_slot;
|
||||
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
|
||||
|
||||
const int ga_n = params.grp_attn_n;
|
||||
const int ga_w = params.grp_attn_w;
|
||||
|
||||
if (ga_n != 1) {
|
||||
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
|
||||
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
|
||||
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
|
||||
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
|
||||
LOG_TEE(" -> Slot %i - self-extend: ga_n = %d, ga_w = %d\n", slot.id, ga_n, ga_w);
|
||||
}
|
||||
|
||||
slot.ga_i = 0;
|
||||
slot.ga_n = ga_n;
|
||||
slot.ga_w = ga_w;
|
||||
|
||||
slot.reset();
|
||||
|
||||
LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, n_ctx_slot);
|
||||
|
||||
slots.push_back(slot);
|
||||
}
|
||||
|
||||
|
@ -1691,14 +1713,18 @@ struct llama_server_context
|
|||
}
|
||||
|
||||
for (llama_client_slot &slot : slots)
|
||||
{
|
||||
if (slot.ga_n == 1)
|
||||
{
|
||||
if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
|
||||
{
|
||||
|
||||
// Shift context
|
||||
const int n_left = slot.n_past - slot.params.n_keep - 1;
|
||||
const int n_discard = n_left / 2;
|
||||
|
||||
LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
|
||||
LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id,
|
||||
slot.params.n_keep, n_left, n_discard);
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, slot.params.n_keep + 1, slot.params.n_keep + n_discard + 1);
|
||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
|
||||
|
||||
|
@ -1720,6 +1746,7 @@ struct llama_server_context
|
|||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// decode any currently ongoing sequences
|
||||
for (auto & slot : slots)
|
||||
|
@ -1880,11 +1907,25 @@ struct llama_server_context
|
|||
|
||||
// process the prefix of first image
|
||||
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
|
||||
int32_t slot_npast = 0;
|
||||
int ga_i = slot.ga_i;
|
||||
int ga_n = slot.ga_n;
|
||||
int ga_w = slot.ga_w;
|
||||
for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
|
||||
{
|
||||
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot.n_past, { slot.id }, false);
|
||||
if(slot.ga_n != 1)
|
||||
{
|
||||
while (slot_npast >= ga_i + ga_w) {
|
||||
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
||||
slot_npast -= bd;
|
||||
ga_i += ga_w/ga_n;
|
||||
}
|
||||
}
|
||||
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
|
||||
slot_npast += 1;
|
||||
}
|
||||
|
||||
slot.n_past = 0;
|
||||
if (has_images && !ingest_images(slot, n_batch))
|
||||
{
|
||||
LOG_TEE("failed processing images\n");
|
||||
|
@ -1912,6 +1953,35 @@ struct llama_server_context
|
|||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
|
||||
{
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
for (auto & slot : slots)
|
||||
{
|
||||
if(slot.ga_n != 1)
|
||||
{
|
||||
// context extension via Self-Extend
|
||||
while (slot.n_past >= slot.ga_i + slot.ga_w) {
|
||||
const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
|
||||
const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
|
||||
const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
|
||||
|
||||
LOG_TEE("\n");
|
||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past, ib * bd, slot.ga_i + ib * bd, slot.n_past + ib * bd);
|
||||
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
|
||||
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past + ib * bd + dd);
|
||||
|
||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past, ib * bd);
|
||||
llama_kv_cache_seq_div (ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
|
||||
llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past + ib * bd, dd);
|
||||
|
||||
slot.n_past -= bd;
|
||||
|
||||
slot.ga_i += slot.ga_w / slot.ga_n;
|
||||
|
||||
LOG_TEE("\nslot.n_past_old = %d, slot.n_past = %d, ga_i = %d\n\n", slot.n_past + bd, slot.n_past, slot.ga_i);
|
||||
}
|
||||
}
|
||||
slot.n_past += n_tokens;
|
||||
}
|
||||
|
||||
llama_batch batch_view =
|
||||
{
|
||||
n_tokens,
|
||||
|
@ -1925,6 +1995,7 @@ struct llama_server_context
|
|||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
|
||||
if (ret != 0)
|
||||
{
|
||||
if (n_batch == 1 || ret < 0)
|
||||
|
@ -1944,6 +2015,7 @@ struct llama_server_context
|
|||
|
||||
for (auto & slot : slots)
|
||||
{
|
||||
|
||||
if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens))
|
||||
{
|
||||
continue;
|
||||
|
@ -1995,6 +2067,7 @@ struct llama_server_context
|
|||
slot.i_batch = -1;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
@ -2251,6 +2324,23 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||
}
|
||||
params.n_threads = std::stoi(argv[i]);
|
||||
}
|
||||
else if (arg == "--grp-attn-n" || arg == "-gan") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
|
||||
params.grp_attn_n = std::stoi(argv[i]);
|
||||
} else if (arg == "--grp-attn-w" || arg == "-gaw")
|
||||
{
|
||||
if (++i >= argc)
|
||||
{
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
|
||||
params.grp_attn_w = std::stoi(argv[i]);
|
||||
}
|
||||
else if (arg == "--threads-batch" || arg == "-tb")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue