server : do not process more than n_batch tokens per iter

This commit is contained in:
Georgi Gerganov 2024-03-06 21:04:09 +02:00
parent aef02b11ec
commit bfb121fd2e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -148,8 +148,9 @@ struct server_slot {
int32_t n_prompt_tokens_processed = 0; int32_t n_prompt_tokens_processed = 0;
json prompt; json prompt;
std::vector<llama_token> prompt_tokens;
std::string generated_text; std::string generated_text;
llama_token sampled;
std::vector<llama_token> cache_tokens; std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs; std::vector<completion_token_output> generated_token_probs;
@ -167,6 +168,7 @@ struct server_slot {
std::string stopping_word; std::string stopping_word;
// sampling // sampling
llama_token sampled;
struct llama_sampling_params sparams; struct llama_sampling_params sparams;
llama_sampling_context * ctx_sampling = nullptr; llama_sampling_context * ctx_sampling = nullptr;
@ -181,7 +183,7 @@ struct server_slot {
size_t n_sent_token_probs = 0; size_t n_sent_token_probs = 0;
int64_t t_start_process_prompt; int64_t t_start_process_prompt;
int64_t t_start_genereration; int64_t t_start_generation;
double t_prompt_processing; // ms double t_prompt_processing; // ms
double t_token_generation; // ms double t_token_generation; // ms
@ -232,13 +234,12 @@ struct server_slot {
if (command == RELEASE) { if (command == RELEASE) {
return; return;
} }
cache_tokens.push_back(token.tok);
generated_token_probs.push_back(token); generated_token_probs.push_back(token);
} }
void release() { void release() {
if (state == PROCESSING) { if (state == PROCESSING) {
t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3; t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
command = RELEASE; command = RELEASE;
} }
} }
@ -968,6 +969,7 @@ struct llama_server_context {
} }
slot.command = LOAD_PROMPT; slot.command = LOAD_PROMPT;
slot.prompt_tokens.clear();
LOG_INFO("slot is processing task", { LOG_INFO("slot is processing task", {
{"id_slot", slot.id}, {"id_slot", slot.id},
@ -1426,9 +1428,7 @@ struct llama_server_context {
if (task.data.contains("system_prompt")) { if (task.data.contains("system_prompt")) {
system_prompt_set(task.data["system_prompt"]); system_prompt_set(task.data["system_prompt"]);
// reset cache_tokens for all slots
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
slot.cache_tokens.clear();
slot.n_past = 0; slot.n_past = 0;
slot.n_past_se = 0; slot.n_past_se = 0;
} }
@ -1614,7 +1614,7 @@ struct llama_server_context {
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
if (slot.ga_n == 1) { if (slot.ga_n == 1) {
if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) { if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
// Shift context // Shift context
const int n_keep = slot.params.n_keep + add_bos_token; const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep; const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
@ -1635,11 +1635,13 @@ struct llama_server_context {
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { if (slot.params.cache_prompt) {
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
} slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
}
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
}
slot.n_past -= n_discard; slot.n_past -= n_discard;
@ -1663,8 +1665,13 @@ struct llama_server_context {
// TODO: we always have to take into account the "system_tokens" // TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow // this is not great and needs to be improved somehow
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true); llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
slot.n_past += 1; slot.n_past += 1;
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(slot.sampled);
}
LOG_VERBOSE("slot decode token", { LOG_VERBOSE("slot decode token", {
{"id_slot", slot.id}, {"id_slot", slot.id},
{"id_task", slot.id_task}, {"id_task", slot.id_task},
@ -1681,6 +1688,8 @@ struct llama_server_context {
// assign workload to the slots // assign workload to the slots
if (params.cont_batching || batch.n_tokens == 0) { if (params.cont_batching || batch.n_tokens == 0) {
int n_available = n_batch;
for (auto & slot : slots) { for (auto & slot : slots) {
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()); const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
@ -1697,122 +1706,132 @@ struct llama_server_context {
// need process the prompt // need process the prompt
if (slot.state == IDLE && slot.command == LOAD_PROMPT) { if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
slot.state = PROCESSING; auto & prompt_tokens = slot.prompt_tokens;
slot.command = NONE;
std::vector<llama_token> prompt_tokens; if (prompt_tokens.empty()) {
slot.t_start_process_prompt = ggml_time_us(); LOG_VERBOSE("tokenizing prompt", {
slot.t_start_genereration = 0; {"id_slot", slot.id},
{"id_task", slot.id_task}
if (slot.infill) {
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
const int space_token = 29871; // TODO: this should not be hardcoded
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens;
} else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
}
slot.n_prompt_tokens = prompt_tokens.size();
if (slot.params.n_keep < 0) {
slot.params.n_keep = slot.n_prompt_tokens;
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
// if input prompt is too big, truncate it, if group attention self-extend is disabled
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
std::vector<llama_token> new_tokens(
prompt_tokens.begin(),
prompt_tokens.begin() + slot.params.n_keep);
new_tokens.insert(
new_tokens.end(),
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
prompt_tokens.end());
LOG_VERBOSE("input truncated", {
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
}); });
slot.truncated = true; slot.t_start_process_prompt = ggml_time_us();
prompt_tokens = new_tokens; slot.t_start_generation = 0;
if (slot.infill) {
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
const int space_token = 29871; // TODO: this should not be hardcoded
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens;
} else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
}
slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size(); slot.n_prompt_tokens = prompt_tokens.size();
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
if (!slot.params.cache_prompt) { if (slot.embedding) {
llama_sampling_reset(slot.ctx_sampling); // this prompt is too large to process - discard it
if (slot.n_prompt_tokens > n_batch) {
slot.state = PROCESSING;
slot.command = NONE;
slot.release();
slot.print_timings();
send_final_response(slot);
continue;
}
} else {
if (slot.params.n_keep < 0) {
slot.params.n_keep = slot.n_prompt_tokens;
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
slot.n_past = 0; // if input prompt is too big, truncate it (if group attention self-extend is disabled)
slot.n_past_se = 0; if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
slot.ga_i = 0; const int n_left = slot.n_ctx - slot.params.n_keep;
slot.n_prompt_tokens_processed = slot.n_prompt_tokens; const int n_block_size = n_left / 2;
} else { const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
GGML_ASSERT(slot.ga_n == 1);
// push the prompt into the sampling context (do not apply grammar) std::vector<llama_token> new_tokens(
for (auto & token : prompt_tokens) { prompt_tokens.begin(),
llama_sampling_accept(slot.ctx_sampling, ctx, token, false); prompt_tokens.begin() + slot.params.n_keep);
new_tokens.insert(
new_tokens.end(),
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
prompt_tokens.end());
prompt_tokens = std::move(new_tokens);
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();
LOG_VERBOSE("input truncated", {
{"n_ctx", slot.n_ctx},
{"n_keep", slot.params.n_keep},
{"n_left", n_left},
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
});
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
llama_sampling_reset(slot.ctx_sampling);
if (!slot.params.cache_prompt) {
slot.n_past_se = 0;
slot.ga_i = 0;
} else {
GGML_ASSERT(slot.ga_n == 1);
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
// remove the non-common part from the cache
slot.cache_tokens.resize(slot.n_past);
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
}
}
} }
slot.n_past = common_part(slot.cache_tokens, prompt_tokens); if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
// we have to evaluate at least 1 token to generate logits.
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task }
});
// the last token of the cache is not in the KV cache until the next call to llama_decode slot.n_past--;
// (it was sampled, pushed into the "cache_tokens", but not yet put in the context) if (slot.ga_i > 0) {
if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size()) { slot.n_past_se--;
slot.n_past -= 1; }
} }
slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past; slot.n_prompt_tokens_processed = 0;
} }
LOG_INFO("slot progression", { if (slot.embedding) {
{ "id_slot", slot.id }, // cannot fit the prompt in the current batch - will try next iter
{ "id_task", slot.id_task }, if (slot.n_prompt_tokens > n_available) {
{ "n_past", slot.n_past }, continue;
{ "n_past_se", slot.n_past_se },
{ "ga_i", slot.ga_i },
{ "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
});
slot.cache_tokens = prompt_tokens;
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
// we have to evaluate at least 1 token to generate logits.
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
{ "id_slot", slot.id },
{ "id_task", slot.id_task }
});
slot.n_past--;
if (slot.ga_i > 0) {
slot.n_past_se--;
} }
} }
@ -1825,19 +1844,13 @@ struct llama_server_context {
{ "p0", p0 } { "p0", p0 }
}); });
LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
{"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
});
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
int32_t ga_i = slot.ga_i; int32_t ga_i = slot.ga_i;
int32_t ga_n = slot.ga_n; int32_t ga_n = slot.ga_n;
int32_t ga_w = slot.ga_w; int32_t ga_w = slot.ga_w;
for (; slot.n_past < (int) prompt_tokens.size(); ++slot.n_past) { for (; slot.n_past < slot.n_prompt_tokens && n_available > 0; ++slot.n_past, --n_available) {
if (slot.ga_n != 1) { if (slot.ga_n != 1) {
while (slot_npast >= ga_i + ga_w) { while (slot_npast >= ga_i + ga_w) {
const int bd = (ga_w/ga_n)*(ga_n - 1); const int bd = (ga_w/ga_n)*(ga_n - 1);
@ -1848,16 +1861,45 @@ struct llama_server_context {
llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false); llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
}
slot.n_prompt_tokens_processed++;
slot_npast++; slot_npast++;
} }
// extract the logits only for the last token // entire prompt has been processed
if (batch.n_tokens > 0) { if (slot.n_past == slot.n_prompt_tokens) {
batch.logits[batch.n_tokens - 1] = true; slot.state = PROCESSING;
} slot.command = NONE;
slot.n_decoded = 0; GGML_ASSERT(batch.n_tokens > 0);
slot.i_batch = batch.n_tokens - 1;
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
LOG_VERBOSE("prompt processed", {
{"id_slot", slot.id},
{"n_past", slot.n_past},
{"n_ctx", n_ctx},
{"n_tokens", batch.n_tokens},
});
} else {
LOG_VERBOSE("prompt still processing", {
{"id_slot", slot.id},
{"n_past", slot.n_past},
{"n_ctx", n_ctx},
{"n_tokens", batch.n_tokens},
});
}
}
if (n_available == 0) {
break;
} }
} }
} }
@ -1868,6 +1910,10 @@ struct llama_server_context {
return true; return true;
} }
LOG_VERBOSE("decoding batch", {
{"n_tokens", batch.n_tokens},
});
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@ -1948,8 +1994,8 @@ struct llama_server_context {
slot.n_decoded += 1; slot.n_decoded += 1;
if (slot.n_decoded == 1) { if (slot.n_decoded == 1) {
slot.t_start_genereration = ggml_time_us(); slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3; slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot); metrics.on_prompt_eval(slot);
} }