server : do not process more than n_batch tokens per iter
parent aef02b11ec
commit bfb121fd2e

1 changed file with 172 additions and 126 deletions
@@ -148,8 +148,9 @@ struct server_slot {
     int32_t n_prompt_tokens_processed = 0;
 
     json prompt;
+    std::vector<llama_token> prompt_tokens;
 
     std::string generated_text;
-    llama_token sampled;
+
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
@@ -167,6 +168,7 @@ struct server_slot {
     std::string stopping_word;
 
     // sampling
+    llama_token sampled;
     struct llama_sampling_params sparams;
     llama_sampling_context * ctx_sampling = nullptr;
 
@@ -181,7 +183,7 @@ struct server_slot {
     size_t n_sent_token_probs = 0;
 
     int64_t t_start_process_prompt;
-    int64_t t_start_genereration;
+    int64_t t_start_generation;
 
     double t_prompt_processing; // ms
     double t_token_generation; // ms
@@ -232,13 +234,12 @@ struct server_slot {
         if (command == RELEASE) {
             return;
         }
-        cache_tokens.push_back(token.tok);
         generated_token_probs.push_back(token);
     }
 
     void release() {
         if (state == PROCESSING) {
-            t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3;
+            t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
             command = RELEASE;
         }
     }
@@ -968,6 +969,7 @@ struct llama_server_context {
         }
 
         slot.command = LOAD_PROMPT;
+        slot.prompt_tokens.clear();
 
         LOG_INFO("slot is processing task", {
             {"id_slot", slot.id},
@@ -1426,9 +1428,7 @@ struct llama_server_context {
             if (task.data.contains("system_prompt")) {
                 system_prompt_set(task.data["system_prompt"]);
 
-                // reset cache_tokens for all slots
                 for (server_slot & slot : slots) {
-                    slot.cache_tokens.clear();
                     slot.n_past    = 0;
                     slot.n_past_se = 0;
                 }
@@ -1614,7 +1614,7 @@ struct llama_server_context {
 
         for (server_slot & slot : slots) {
             if (slot.ga_n == 1) {
-                if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) {
+                if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
                     // Shift context
                     const int n_keep = slot.params.n_keep + add_bos_token;
                     const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
@@ -1635,11 +1635,13 @@ struct llama_server_context {
                     llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
                     llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
+                    if (slot.params.cache_prompt) {
                         for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
                             slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
                         }
 
                         slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
+                    }
 
                     slot.n_past -= n_discard;
 
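The hunk above only changes when the cache mirror is updated: during a context shift, n_discard tokens after the first n_keep are dropped from the KV cache, and the slot's cache_tokens vector has to slide left by the same amount, but now only if prompt caching is enabled. Below is a minimal standalone sketch of that vector shift; shift_tokens is an illustrative name, not code from the repository.

#include <cassert>
#include <vector>

// Drop n_discard tokens after the first n_keep and slide the rest left,
// mirroring what the KV-cache shift does to the sequence.
static void shift_tokens(std::vector<int> & tokens, int n_keep, int n_discard) {
    assert(n_keep >= 0 && n_discard > 0);
    assert((size_t)(n_keep + n_discard) <= tokens.size());

    for (size_t i = n_keep + n_discard; i < tokens.size(); i++) {
        tokens[i - n_discard] = tokens[i];
    }
    tokens.resize(tokens.size() - n_discard);
}

int main() {
    std::vector<int> cache = {0, 1, 2, 3, 4, 5, 6, 7};
    shift_tokens(cache, /*n_keep=*/2, /*n_discard=*/3);
    // cache is now {0, 1, 5, 6, 7}: tokens 2..4 were discarded, the tail slid left
    assert((cache == std::vector<int>{0, 1, 5, 6, 7}));
}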
@@ -1663,8 +1665,13 @@ struct llama_server_context {
             // TODO: we always have to take into account the "system_tokens"
             //       this is not great and needs to be improved somehow
             llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
 
             slot.n_past += 1;
 
+            if (slot.params.cache_prompt) {
+                slot.cache_tokens.push_back(slot.sampled);
+            }
 
             LOG_VERBOSE("slot decode token", {
                 {"id_slot", slot.id},
                 {"id_task", slot.id_task},
@@ -1681,6 +1688,8 @@ struct llama_server_context {
 
         // assign workload to the slots
         if (params.cont_batching || batch.n_tokens == 0) {
+            int n_available = n_batch;
+
             for (auto & slot : slots) {
                 const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
 
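This hunk is the core of the change: every scheduling pass starts with a token budget of n_batch (n_available) and the later hunks only consume from it while ingesting prompts, so a single iteration can never queue more than n_batch tokens. A minimal standalone sketch of the same budgeting pattern follows; the Slot struct and build_batch function are illustrative stand-ins, not code from the repository.

#include <cstdio>
#include <vector>

// Illustrative stand-in for the server's per-slot state.
struct Slot {
    std::vector<int> prompt_tokens; // tokenized prompt waiting to be ingested
    int n_past = 0;                 // how many of those tokens are already queued
};

// Ingest pending prompt tokens from all slots without exceeding the per-iteration budget.
// Returns the tokens that would be submitted to the next decode call.
static std::vector<int> build_batch(std::vector<Slot> & slots, int n_batch) {
    std::vector<int> batch;
    int n_available = n_batch; // per-iteration token budget, as in the patch

    for (auto & slot : slots) {
        // take as many tokens as the remaining budget allows
        while (slot.n_past < (int) slot.prompt_tokens.size() && n_available > 0) {
            batch.push_back(slot.prompt_tokens[slot.n_past]);
            slot.n_past++;
            n_available--;
        }
        if (n_available == 0) {
            break; // budget exhausted - remaining slots resume next iteration
        }
    }
    return batch;
}

int main() {
    std::vector<Slot> slots(2);
    slots[0].prompt_tokens = std::vector<int>(300, 1);
    slots[1].prompt_tokens = std::vector<int>(300, 2);

    // with n_batch = 256, the first pass takes 256 tokens from slot 0,
    // the second finishes slot 0 and starts on slot 1, and so on
    for (int iter = 0; iter < 3; ++iter) {
        const auto batch = build_batch(slots, 256);
        std::printf("iter %d: batch of %zu tokens\n", iter, batch.size());
    }
}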
@@ -1697,12 +1706,16 @@ struct llama_server_context {
 
                 // need process the prompt
                 if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
-                    slot.state = PROCESSING;
-                    slot.command = NONE;
+                    auto & prompt_tokens = slot.prompt_tokens;
+
+                    if (prompt_tokens.empty()) {
+                        LOG_VERBOSE("tokenizing prompt", {
+                            {"id_slot", slot.id},
+                            {"id_task", slot.id_task}
+                        });
 
-                    std::vector<llama_token> prompt_tokens;
                     slot.t_start_process_prompt = ggml_time_us();
-                    slot.t_start_genereration = 0;
+                    slot.t_start_generation = 0;
 
                     if (slot.infill) {
                         bool suff_rm_leading_spc = true;
@@ -1729,16 +1742,29 @@ struct llama_server_context {
                         prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
                     }
 
+                    slot.n_past = 0;
                     slot.n_prompt_tokens = prompt_tokens.size();
 
+                    if (slot.embedding) {
+                        // this prompt is too large to process - discard it
+                        if (slot.n_prompt_tokens > n_batch) {
+                            slot.state = PROCESSING;
+                            slot.command = NONE;
+                            slot.release();
+                            slot.print_timings();
+                            send_final_response(slot);
+                            continue;
+                        }
+                    } else {
                         if (slot.params.n_keep < 0) {
                             slot.params.n_keep = slot.n_prompt_tokens;
                         }
                         slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-                        // if input prompt is too big, truncate it, if group attention self-extend is disabled
+                        // if input prompt is too big, truncate it (if group attention self-extend is disabled)
                         if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
                             const int n_left = slot.n_ctx - slot.params.n_keep;
 
                             const int n_block_size = n_left / 2;
                             const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
 
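Two things happen in the hunk above. Embedding prompts cannot be split across decode calls, so anything longer than n_batch is released with a final response instead of being truncated. Completion prompts that exceed the context are truncated: keep the first n_keep tokens, drop whole blocks of size n_left/2 from the middle, keep the tail (the reassembly into new_tokens is shown at the top of the next hunk). A worked sketch of that arithmetic, with illustrative numbers and dummy token values:

#include <cassert>
#include <vector>

// Truncate a prompt that does not fit in n_ctx: keep the first n_keep tokens,
// drop whole blocks of size n_left/2 from the middle, keep the tail.
// Standalone sketch of the arithmetic used by the patch, not the server code itself.
static std::vector<int> truncate_prompt(const std::vector<int> & prompt, int n_ctx, int n_keep) {
    const int n_prompt      = (int) prompt.size();
    const int n_left        = n_ctx - n_keep;
    const int n_block_size  = n_left / 2;
    const int erased_blocks = (n_prompt - n_keep - n_block_size) / n_block_size;

    std::vector<int> new_tokens(prompt.begin(), prompt.begin() + n_keep);
    new_tokens.insert(new_tokens.end(),
                      prompt.begin() + n_keep + erased_blocks * n_block_size,
                      prompt.end());
    return new_tokens;
}

int main() {
    // e.g. a 100-token prompt, a 64-token context, keep the first 8 tokens
    std::vector<int> prompt(100);
    for (int i = 0; i < 100; ++i) prompt[i] = i;

    const auto truncated = truncate_prompt(prompt, /*n_ctx=*/64, /*n_keep=*/8);

    // n_left = 56, block = 28, erased_blocks = (100 - 8 - 28) / 28 = 2
    // -> keep tokens [0, 8) and [64, 100): 8 + 36 = 44 tokens, which fits in n_ctx
    assert(truncated.size() == 44);
    assert(truncated[7] == 7 && truncated[8] == 64);
}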
@@ -1751,57 +1777,40 @@ struct llama_server_context {
                                 prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
                                 prompt_tokens.end());
 
+                            prompt_tokens = std::move(new_tokens);
+
+                            slot.truncated = true;
+                            slot.n_prompt_tokens = prompt_tokens.size();
 
                             LOG_VERBOSE("input truncated", {
                                 {"n_ctx", slot.n_ctx},
                                 {"n_keep", slot.params.n_keep},
                                 {"n_left", n_left},
-                                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                                {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
                             });
 
-                            slot.truncated = true;
-                            prompt_tokens = new_tokens;
-
-                            slot.n_prompt_tokens = prompt_tokens.size();
-
                             GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                         }
 
-                        if (!slot.params.cache_prompt) {
                         llama_sampling_reset(slot.ctx_sampling);
 
-                            slot.n_past = 0;
+                        if (!slot.params.cache_prompt) {
                             slot.n_past_se = 0;
                             slot.ga_i = 0;
-
-                            slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
                         } else {
                             GGML_ASSERT(slot.ga_n == 1);
 
-                            // push the prompt into the sampling context (do not apply grammar)
-                            for (auto & token : prompt_tokens) {
-                                llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
-                            }
-
                             slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
-                            // the last token of the cache is not in the KV cache until the next call to llama_decode
-                            // (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
-                            if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size()) {
-                                slot.n_past -= 1;
+                            // remove the non-common part from the cache
+                            slot.cache_tokens.resize(slot.n_past);
+
+                            // push the prompt into the sampling context (do not apply grammar)
+                            for (int i = 0; i < slot.n_past; ++i) {
+                                llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
                             }
-
-                            slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
                         }
 
-                        LOG_INFO("slot progression", {
-                            { "id_slot", slot.id },
-                            { "id_task", slot.id_task },
-                            { "n_past", slot.n_past },
-                            { "n_past_se", slot.n_past_se },
-                            { "ga_i", slot.ga_i },
-                            { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
-                        });
-
-                        slot.cache_tokens = prompt_tokens;
-
                         if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
                             // we have to evaluate at least 1 token to generate logits.
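With prompt caching enabled, the hunk above compares the new prompt against cache_tokens, keeps the shared prefix in the KV cache (n_past), trims the stale tail of the mirror, and replays only the cached prefix into the sampling context; only the remaining tokens are re-evaluated. A standalone sketch of the prefix logic follows; common_part here is a local stand-in for the helper used by the server code, not the same implementation.

#include <cassert>
#include <vector>

// Length of the shared prefix between the cached tokens and the new prompt.
static size_t common_part(const std::vector<int> & a, const std::vector<int> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

int main() {
    std::vector<int> cache_tokens  = {1, 2, 3, 4, 5};      // what is already in the KV cache
    std::vector<int> prompt_tokens = {1, 2, 3, 9, 10, 11}; // the new request

    const size_t n_past = common_part(cache_tokens, prompt_tokens);
    assert(n_past == 3);

    // drop the non-common tail from the cache mirror; the KV cache entries past
    // this point are removed separately (llama_kv_cache_seq_rm in the patch)
    cache_tokens.resize(n_past);

    // only prompt_tokens[n_past..] still need to be evaluated
    assert(prompt_tokens.size() - n_past == 3);
}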
|
@ -1816,6 +1825,16 @@ struct llama_server_context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slot.n_prompt_tokens_processed = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (slot.embedding) {
|
||||||
|
// cannot fit the prompt in the current batch - will try next iter
|
||||||
|
if (slot.n_prompt_tokens > n_available) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const int p0 = (int) system_tokens.size() + slot.n_past;
|
const int p0 = (int) system_tokens.size() + slot.n_past;
|
||||||
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
|
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
|
||||||
|
|
||||||
|
@ -1825,19 +1844,13 @@ struct llama_server_context {
|
||||||
{ "p0", p0 }
|
{ "p0", p0 }
|
||||||
});
|
});
|
||||||
|
|
||||||
LOG_VERBOSE("prompt ingested", {
|
|
||||||
{"n_past", slot.n_past},
|
|
||||||
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
|
|
||||||
{"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
|
|
||||||
});
|
|
||||||
|
|
||||||
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
||||||
|
|
||||||
int32_t ga_i = slot.ga_i;
|
int32_t ga_i = slot.ga_i;
|
||||||
int32_t ga_n = slot.ga_n;
|
int32_t ga_n = slot.ga_n;
|
||||||
int32_t ga_w = slot.ga_w;
|
int32_t ga_w = slot.ga_w;
|
||||||
|
|
||||||
for (; slot.n_past < (int) prompt_tokens.size(); ++slot.n_past) {
|
for (; slot.n_past < slot.n_prompt_tokens && n_available > 0; ++slot.n_past, --n_available) {
|
||||||
if (slot.ga_n != 1) {
|
if (slot.ga_n != 1) {
|
||||||
while (slot_npast >= ga_i + ga_w) {
|
while (slot_npast >= ga_i + ga_w) {
|
||||||
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
||||||
|
@ -1848,16 +1861,45 @@ struct llama_server_context {
|
||||||
|
|
||||||
llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
|
llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
|
||||||
|
|
||||||
|
if (slot.params.cache_prompt) {
|
||||||
|
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
||||||
|
}
|
||||||
|
|
||||||
|
slot.n_prompt_tokens_processed++;
|
||||||
slot_npast++;
|
slot_npast++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// entire prompt has been processed
|
||||||
|
if (slot.n_past == slot.n_prompt_tokens) {
|
||||||
|
slot.state = PROCESSING;
|
||||||
|
slot.command = NONE;
|
||||||
|
|
||||||
|
GGML_ASSERT(batch.n_tokens > 0);
|
||||||
|
|
||||||
// extract the logits only for the last token
|
// extract the logits only for the last token
|
||||||
if (batch.n_tokens > 0) {
|
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
}
|
|
||||||
|
|
||||||
slot.n_decoded = 0;
|
slot.n_decoded = 0;
|
||||||
slot.i_batch = batch.n_tokens - 1;
|
slot.i_batch = batch.n_tokens - 1;
|
||||||
|
|
||||||
|
LOG_VERBOSE("prompt processed", {
|
||||||
|
{"id_slot", slot.id},
|
||||||
|
{"n_past", slot.n_past},
|
||||||
|
{"n_ctx", n_ctx},
|
||||||
|
{"n_tokens", batch.n_tokens},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
LOG_VERBOSE("prompt still processing", {
|
||||||
|
{"id_slot", slot.id},
|
||||||
|
{"n_past", slot.n_past},
|
||||||
|
{"n_ctx", n_ctx},
|
||||||
|
{"n_tokens", batch.n_tokens},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_available == 0) {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
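After the budget-limited ingestion loop above, a slot is flipped to PROCESSING only when its entire prompt has been queued; otherwise it stays pending and resumes on the next iteration. Only in the completed case is the logits flag set, and only on the last position of the batch, since that is the position the first generated token is sampled from. A small sketch of that "logits only for the last token" idea, using a toy batch type rather than the llama.cpp API:

#include <cassert>
#include <vector>

// Illustrative toy batch: one logits flag per queued position.
struct toy_batch {
    std::vector<int>  tokens;
    std::vector<bool> logits; // request logits for this position?
};

static void add_token(toy_batch & batch, int token, bool need_logits) {
    batch.tokens.push_back(token);
    batch.logits.push_back(need_logits);
}

int main() {
    toy_batch batch;
    const std::vector<int> prompt = {11, 22, 33, 44};

    // while ingesting the prompt, no position needs logits
    for (int tok : prompt) {
        add_token(batch, tok, /*need_logits=*/false);
    }

    // once the whole prompt is queued, flip the flag on the last position only
    batch.logits.back() = true;

    assert((batch.logits == std::vector<bool>{false, false, false, true}));
}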
@@ -1868,6 +1910,10 @@ struct llama_server_context {
             return true;
         }
 
+        LOG_VERBOSE("decoding batch", {
+            {"n_tokens", batch.n_tokens},
+        });
+
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
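The decode loop above is unchanged apart from the new log line: whatever was queued is still split into n_batch-sized chunks before each decode call, and with this commit the queued total is itself bounded by n_batch. A minimal sketch of the same chunking pattern; process_chunk is a placeholder for the decode call, not a real API.

#include <algorithm>
#include <cstdio>
#include <vector>

// Placeholder for decoding a view of the batch: here we just report the chunk bounds.
static void process_chunk(const std::vector<int> & tokens, int first, int count) {
    (void) tokens;
    std::printf("decode tokens [%d, %d)\n", first, first + count);
}

int main() {
    const int n_batch = 32;
    std::vector<int> batch(70, 0); // pretend 70 queued tokens

    // same chunking pattern as the server's decode loop
    for (int i = 0; i < (int) batch.size(); i += n_batch) {
        const int n_tokens = std::min(n_batch, (int) batch.size() - i);
        process_chunk(batch, i, n_tokens);
    }
}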
|
@ -1948,8 +1994,8 @@ struct llama_server_context {
|
||||||
|
|
||||||
slot.n_decoded += 1;
|
slot.n_decoded += 1;
|
||||||
if (slot.n_decoded == 1) {
|
if (slot.n_decoded == 1) {
|
||||||
slot.t_start_genereration = ggml_time_us();
|
slot.t_start_generation = ggml_time_us();
|
||||||
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
|
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||||
metrics.on_prompt_eval(slot);
|
metrics.on_prompt_eval(slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
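This last hunk is the remainder of the t_start_genereration -> t_start_generation rename: the timestamp is taken on the first decoded token, and prompt-processing time is the difference of two microsecond readings divided by 1e3 to get milliseconds. An equivalent standalone measurement with std::chrono, purely for illustration (time_us mimics a microsecond clock like ggml_time_us, but is not the same function):

#include <chrono>
#include <cstdio>
#include <thread>

// Microsecond timestamp from a monotonic clock.
static int64_t time_us() {
    using namespace std::chrono;
    return duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count();
}

int main() {
    const int64_t t_start_process_prompt = time_us();

    // pretend prompt processing takes ~50 ms
    std::this_thread::sleep_for(std::chrono::milliseconds(50));

    const int64_t t_start_generation = time_us();

    // same conversion as in the patch: microseconds -> milliseconds
    const double t_prompt_processing = (t_start_generation - t_start_process_prompt) / 1e3;
    std::printf("prompt processing took %.2f ms\n", t_prompt_processing);
}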