server : do not process more than n_batch tokens per iter

This commit is contained in:
Georgi Gerganov 2024-03-06 21:04:09 +02:00
parent aef02b11ec
commit bfb121fd2e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -148,8 +148,9 @@ struct server_slot {
int32_t n_prompt_tokens_processed = 0;
json prompt;
std::vector<llama_token> prompt_tokens;
std::string generated_text;
llama_token sampled;
std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs;
@ -167,6 +168,7 @@ struct server_slot {
std::string stopping_word;
// sampling
llama_token sampled;
struct llama_sampling_params sparams;
llama_sampling_context * ctx_sampling = nullptr;
@ -181,7 +183,7 @@ struct server_slot {
size_t n_sent_token_probs = 0;
int64_t t_start_process_prompt;
int64_t t_start_genereration;
int64_t t_start_generation;
double t_prompt_processing; // ms
double t_token_generation; // ms
@ -232,13 +234,12 @@ struct server_slot {
if (command == RELEASE) {
return;
}
cache_tokens.push_back(token.tok);
generated_token_probs.push_back(token);
}
void release() {
if (state == PROCESSING) {
t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3;
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
command = RELEASE;
}
}
@ -968,6 +969,7 @@ struct llama_server_context {
}
slot.command = LOAD_PROMPT;
slot.prompt_tokens.clear();
LOG_INFO("slot is processing task", {
{"id_slot", slot.id},
@ -1426,9 +1428,7 @@ struct llama_server_context {
if (task.data.contains("system_prompt")) {
system_prompt_set(task.data["system_prompt"]);
// reset cache_tokens for all slots
for (server_slot & slot : slots) {
slot.cache_tokens.clear();
slot.n_past = 0;
slot.n_past_se = 0;
}
@ -1614,7 +1614,7 @@ struct llama_server_context {
for (server_slot & slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) {
if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
@ -1635,11 +1635,13 @@ struct llama_server_context {
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
}
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
}
slot.n_past -= n_discard;
@ -1663,8 +1665,13 @@ struct llama_server_context {
// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
slot.n_past += 1;
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(slot.sampled);
}
LOG_VERBOSE("slot decode token", {
{"id_slot", slot.id},
{"id_task", slot.id_task},
@ -1681,6 +1688,8 @@ struct llama_server_context {
// assign workload to the slots
if (params.cont_batching || batch.n_tokens == 0) {
int n_available = n_batch;
for (auto & slot : slots) {
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
@ -1697,12 +1706,16 @@ struct llama_server_context {
// need process the prompt
if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
slot.state = PROCESSING;
slot.command = NONE;
auto & prompt_tokens = slot.prompt_tokens;
if (prompt_tokens.empty()) {
LOG_VERBOSE("tokenizing prompt", {
{"id_slot", slot.id},
{"id_task", slot.id_task}
});
std::vector<llama_token> prompt_tokens;
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_genereration = 0;
slot.t_start_generation = 0;
if (slot.infill) {
bool suff_rm_leading_spc = true;
@ -1729,16 +1742,29 @@ struct llama_server_context {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
}
slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size();
if (slot.embedding) {
// this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_batch) {
slot.state = PROCESSING;
slot.command = NONE;
slot.release();
slot.print_timings();
send_final_response(slot);
continue;
}
} else {
if (slot.params.n_keep < 0) {
slot.params.n_keep = slot.n_prompt_tokens;
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
// if input prompt is too big, truncate it, if group attention self-extend is disabled
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
@ -1751,57 +1777,40 @@ struct llama_server_context {
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
prompt_tokens.end());
prompt_tokens = std::move(new_tokens);
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();
LOG_VERBOSE("input truncated", {
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
});
slot.truncated = true;
prompt_tokens = new_tokens;
slot.n_prompt_tokens = prompt_tokens.size();
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
if (!slot.params.cache_prompt) {
llama_sampling_reset(slot.ctx_sampling);
slot.n_past = 0;
if (!slot.params.cache_prompt) {
slot.n_past_se = 0;
slot.ga_i = 0;
slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
} else {
GGML_ASSERT(slot.ga_n == 1);
// push the prompt into the sampling context (do not apply grammar)
for (auto & token : prompt_tokens) {
llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
}
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
// the last token of the cache is not in the KV cache until the next call to llama_decode
// (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size()) {
slot.n_past -= 1;
// remove the non-common part from the cache
slot.cache_tokens.resize(slot.n_past);
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
}
}
slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
}
LOG_INFO("slot progression", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task },
{ "n_past", slot.n_past },
{ "n_past_se", slot.n_past_se },
{ "ga_i", slot.ga_i },
{ "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
});
slot.cache_tokens = prompt_tokens;
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
// we have to evaluate at least 1 token to generate logits.
@ -1816,6 +1825,16 @@ struct llama_server_context {
}
}
slot.n_prompt_tokens_processed = 0;
}
if (slot.embedding) {
// cannot fit the prompt in the current batch - will try next iter
if (slot.n_prompt_tokens > n_available) {
continue;
}
}
const int p0 = (int) system_tokens.size() + slot.n_past;
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
@ -1825,19 +1844,13 @@ struct llama_server_context {
{ "p0", p0 }
});
LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
{"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
});
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
int32_t ga_i = slot.ga_i;
int32_t ga_n = slot.ga_n;
int32_t ga_w = slot.ga_w;
for (; slot.n_past < (int) prompt_tokens.size(); ++slot.n_past) {
for (; slot.n_past < slot.n_prompt_tokens && n_available > 0; ++slot.n_past, --n_available) {
if (slot.ga_n != 1) {
while (slot_npast >= ga_i + ga_w) {
const int bd = (ga_w/ga_n)*(ga_n - 1);
@ -1848,16 +1861,45 @@ struct llama_server_context {
llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
}
slot.n_prompt_tokens_processed++;
slot_npast++;
}
// entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens) {
slot.state = PROCESSING;
slot.command = NONE;
GGML_ASSERT(batch.n_tokens > 0);
// extract the logits only for the last token
if (batch.n_tokens > 0) {
batch.logits[batch.n_tokens - 1] = true;
}
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
LOG_VERBOSE("prompt processed", {
{"id_slot", slot.id},
{"n_past", slot.n_past},
{"n_ctx", n_ctx},
{"n_tokens", batch.n_tokens},
});
} else {
LOG_VERBOSE("prompt still processing", {
{"id_slot", slot.id},
{"n_past", slot.n_past},
{"n_ctx", n_ctx},
{"n_tokens", batch.n_tokens},
});
}
}
if (n_available == 0) {
break;
}
}
}
@ -1868,6 +1910,10 @@ struct llama_server_context {
return true;
}
LOG_VERBOSE("decoding batch", {
{"n_tokens", batch.n_tokens},
});
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@ -1948,8 +1994,8 @@ struct llama_server_context {
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_genereration = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}