server : do not process more than n_batch tokens per iter
This commit is contained in:
parent
aef02b11ec
commit
bfb121fd2e
1 changed files with 172 additions and 126 deletions
|
@ -148,8 +148,9 @@ struct server_slot {
|
|||
int32_t n_prompt_tokens_processed = 0;
|
||||
|
||||
json prompt;
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
|
||||
std::string generated_text;
|
||||
llama_token sampled;
|
||||
std::vector<llama_token> cache_tokens;
|
||||
std::vector<completion_token_output> generated_token_probs;
|
||||
|
||||
|
@ -167,6 +168,7 @@ struct server_slot {
|
|||
std::string stopping_word;
|
||||
|
||||
// sampling
|
||||
llama_token sampled;
|
||||
struct llama_sampling_params sparams;
|
||||
llama_sampling_context * ctx_sampling = nullptr;
|
||||
|
||||
|
@ -181,7 +183,7 @@ struct server_slot {
|
|||
size_t n_sent_token_probs = 0;
|
||||
|
||||
int64_t t_start_process_prompt;
|
||||
int64_t t_start_genereration;
|
||||
int64_t t_start_generation;
|
||||
|
||||
double t_prompt_processing; // ms
|
||||
double t_token_generation; // ms
|
||||
|
@ -232,13 +234,12 @@ struct server_slot {
|
|||
if (command == RELEASE) {
|
||||
return;
|
||||
}
|
||||
cache_tokens.push_back(token.tok);
|
||||
generated_token_probs.push_back(token);
|
||||
}
|
||||
|
||||
void release() {
|
||||
if (state == PROCESSING) {
|
||||
t_token_generation = (ggml_time_us() - t_start_genereration) / 1e3;
|
||||
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
|
||||
command = RELEASE;
|
||||
}
|
||||
}
|
||||
|
@ -968,6 +969,7 @@ struct llama_server_context {
|
|||
}
|
||||
|
||||
slot.command = LOAD_PROMPT;
|
||||
slot.prompt_tokens.clear();
|
||||
|
||||
LOG_INFO("slot is processing task", {
|
||||
{"id_slot", slot.id},
|
||||
|
@ -1426,9 +1428,7 @@ struct llama_server_context {
|
|||
if (task.data.contains("system_prompt")) {
|
||||
system_prompt_set(task.data["system_prompt"]);
|
||||
|
||||
// reset cache_tokens for all slots
|
||||
for (server_slot & slot : slots) {
|
||||
slot.cache_tokens.clear();
|
||||
slot.n_past = 0;
|
||||
slot.n_past_se = 0;
|
||||
}
|
||||
|
@ -1614,7 +1614,7 @@ struct llama_server_context {
|
|||
|
||||
for (server_slot & slot : slots) {
|
||||
if (slot.ga_n == 1) {
|
||||
if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) {
|
||||
if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
|
||||
// Shift context
|
||||
const int n_keep = slot.params.n_keep + add_bos_token;
|
||||
const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
|
||||
|
@ -1635,11 +1635,13 @@ struct llama_server_context {
|
|||
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
|
||||
|
||||
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
|
||||
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
|
||||
}
|
||||
if (slot.params.cache_prompt) {
|
||||
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
|
||||
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
|
||||
}
|
||||
|
||||
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
|
||||
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
|
||||
}
|
||||
|
||||
slot.n_past -= n_discard;
|
||||
|
||||
|
@ -1663,8 +1665,13 @@ struct llama_server_context {
|
|||
// TODO: we always have to take into account the "system_tokens"
|
||||
// this is not great and needs to be improved somehow
|
||||
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
|
||||
|
||||
slot.n_past += 1;
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(slot.sampled);
|
||||
}
|
||||
|
||||
LOG_VERBOSE("slot decode token", {
|
||||
{"id_slot", slot.id},
|
||||
{"id_task", slot.id_task},
|
||||
|
@ -1681,6 +1688,8 @@ struct llama_server_context {
|
|||
|
||||
// assign workload to the slots
|
||||
if (params.cont_batching || batch.n_tokens == 0) {
|
||||
int n_available = n_batch;
|
||||
|
||||
for (auto & slot : slots) {
|
||||
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty());
|
||||
|
||||
|
@ -1697,122 +1706,132 @@ struct llama_server_context {
|
|||
|
||||
// need process the prompt
|
||||
if (slot.state == IDLE && slot.command == LOAD_PROMPT) {
|
||||
slot.state = PROCESSING;
|
||||
slot.command = NONE;
|
||||
auto & prompt_tokens = slot.prompt_tokens;
|
||||
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_genereration = 0;
|
||||
|
||||
if (slot.infill) {
|
||||
bool suff_rm_leading_spc = true;
|
||||
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
|
||||
params.input_suffix.erase(0, 1);
|
||||
suff_rm_leading_spc = false;
|
||||
}
|
||||
|
||||
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
|
||||
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
|
||||
|
||||
const int space_token = 29871; // TODO: this should not be hardcoded
|
||||
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
|
||||
suffix_tokens.erase(suffix_tokens.begin());
|
||||
}
|
||||
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
|
||||
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
|
||||
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
|
||||
prefix_tokens.push_back(llama_token_middle(model));
|
||||
prompt_tokens = prefix_tokens;
|
||||
} else {
|
||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
if (slot.params.n_keep < 0) {
|
||||
slot.params.n_keep = slot.n_prompt_tokens;
|
||||
}
|
||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||
|
||||
// if input prompt is too big, truncate it, if group attention self-extend is disabled
|
||||
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||
const int n_block_size = n_left / 2;
|
||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||
|
||||
std::vector<llama_token> new_tokens(
|
||||
prompt_tokens.begin(),
|
||||
prompt_tokens.begin() + slot.params.n_keep);
|
||||
|
||||
new_tokens.insert(
|
||||
new_tokens.end(),
|
||||
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
|
||||
prompt_tokens.end());
|
||||
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", slot.n_ctx},
|
||||
{"n_keep", slot.params.n_keep},
|
||||
{"n_left", n_left},
|
||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
||||
if (prompt_tokens.empty()) {
|
||||
LOG_VERBOSE("tokenizing prompt", {
|
||||
{"id_slot", slot.id},
|
||||
{"id_task", slot.id_task}
|
||||
});
|
||||
|
||||
slot.truncated = true;
|
||||
prompt_tokens = new_tokens;
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_generation = 0;
|
||||
|
||||
if (slot.infill) {
|
||||
bool suff_rm_leading_spc = true;
|
||||
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
|
||||
params.input_suffix.erase(0, 1);
|
||||
suff_rm_leading_spc = false;
|
||||
}
|
||||
|
||||
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
|
||||
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
|
||||
|
||||
const int space_token = 29871; // TODO: this should not be hardcoded
|
||||
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
|
||||
suffix_tokens.erase(suffix_tokens.begin());
|
||||
}
|
||||
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
|
||||
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
|
||||
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
|
||||
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
|
||||
prefix_tokens.push_back(llama_token_middle(model));
|
||||
prompt_tokens = prefix_tokens;
|
||||
} else {
|
||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
|
||||
}
|
||||
|
||||
slot.n_past = 0;
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||
}
|
||||
|
||||
if (!slot.params.cache_prompt) {
|
||||
llama_sampling_reset(slot.ctx_sampling);
|
||||
if (slot.embedding) {
|
||||
// this prompt is too large to process - discard it
|
||||
if (slot.n_prompt_tokens > n_batch) {
|
||||
slot.state = PROCESSING;
|
||||
slot.command = NONE;
|
||||
slot.release();
|
||||
slot.print_timings();
|
||||
send_final_response(slot);
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (slot.params.n_keep < 0) {
|
||||
slot.params.n_keep = slot.n_prompt_tokens;
|
||||
}
|
||||
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
|
||||
|
||||
slot.n_past = 0;
|
||||
slot.n_past_se = 0;
|
||||
slot.ga_i = 0;
|
||||
// if input prompt is too big, truncate it (if group attention self-extend is disabled)
|
||||
if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
|
||||
const int n_left = slot.n_ctx - slot.params.n_keep;
|
||||
|
||||
slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
|
||||
} else {
|
||||
GGML_ASSERT(slot.ga_n == 1);
|
||||
const int n_block_size = n_left / 2;
|
||||
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
|
||||
|
||||
// push the prompt into the sampling context (do not apply grammar)
|
||||
for (auto & token : prompt_tokens) {
|
||||
llama_sampling_accept(slot.ctx_sampling, ctx, token, false);
|
||||
std::vector<llama_token> new_tokens(
|
||||
prompt_tokens.begin(),
|
||||
prompt_tokens.begin() + slot.params.n_keep);
|
||||
|
||||
new_tokens.insert(
|
||||
new_tokens.end(),
|
||||
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
|
||||
prompt_tokens.end());
|
||||
|
||||
prompt_tokens = std::move(new_tokens);
|
||||
|
||||
slot.truncated = true;
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", slot.n_ctx},
|
||||
{"n_keep", slot.params.n_keep},
|
||||
{"n_left", n_left},
|
||||
{"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
|
||||
});
|
||||
|
||||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||
}
|
||||
|
||||
llama_sampling_reset(slot.ctx_sampling);
|
||||
|
||||
if (!slot.params.cache_prompt) {
|
||||
slot.n_past_se = 0;
|
||||
slot.ga_i = 0;
|
||||
} else {
|
||||
GGML_ASSERT(slot.ga_n == 1);
|
||||
|
||||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||
|
||||
// remove the non-common part from the cache
|
||||
slot.cache_tokens.resize(slot.n_past);
|
||||
|
||||
// push the prompt into the sampling context (do not apply grammar)
|
||||
for (int i = 0; i < slot.n_past; ++i) {
|
||||
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
|
||||
// we have to evaluate at least 1 token to generate logits.
|
||||
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
|
||||
{ "id_slot", slot.id },
|
||||
{ "id_task", slot.id_task }
|
||||
});
|
||||
|
||||
// the last token of the cache is not in the KV cache until the next call to llama_decode
|
||||
// (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
|
||||
if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size()) {
|
||||
slot.n_past -= 1;
|
||||
slot.n_past--;
|
||||
if (slot.ga_i > 0) {
|
||||
slot.n_past_se--;
|
||||
}
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_processed = slot.n_prompt_tokens - slot.n_past;
|
||||
slot.n_prompt_tokens_processed = 0;
|
||||
}
|
||||
|
||||
LOG_INFO("slot progression", {
|
||||
{ "id_slot", slot.id },
|
||||
{ "id_task", slot.id_task },
|
||||
{ "n_past", slot.n_past },
|
||||
{ "n_past_se", slot.n_past_se },
|
||||
{ "ga_i", slot.ga_i },
|
||||
{ "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
|
||||
});
|
||||
|
||||
slot.cache_tokens = prompt_tokens;
|
||||
|
||||
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
|
||||
// we have to evaluate at least 1 token to generate logits.
|
||||
LOG_INFO("we have to evaluate at least 1 token to generate logits", {
|
||||
{ "id_slot", slot.id },
|
||||
{ "id_task", slot.id_task }
|
||||
});
|
||||
|
||||
slot.n_past--;
|
||||
if (slot.ga_i > 0) {
|
||||
slot.n_past_se--;
|
||||
if (slot.embedding) {
|
||||
// cannot fit the prompt in the current batch - will try next iter
|
||||
if (slot.n_prompt_tokens > n_available) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1825,19 +1844,13 @@ struct llama_server_context {
|
|||
{ "p0", p0 }
|
||||
});
|
||||
|
||||
LOG_VERBOSE("prompt ingested", {
|
||||
{"n_past", slot.n_past},
|
||||
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
|
||||
{"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
|
||||
});
|
||||
|
||||
int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
|
||||
|
||||
int32_t ga_i = slot.ga_i;
|
||||
int32_t ga_n = slot.ga_n;
|
||||
int32_t ga_w = slot.ga_w;
|
||||
|
||||
for (; slot.n_past < (int) prompt_tokens.size(); ++slot.n_past) {
|
||||
for (; slot.n_past < slot.n_prompt_tokens && n_available > 0; ++slot.n_past, --n_available) {
|
||||
if (slot.ga_n != 1) {
|
||||
while (slot_npast >= ga_i + ga_w) {
|
||||
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
||||
|
@ -1848,16 +1861,45 @@ struct llama_server_context {
|
|||
|
||||
llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
slot_npast++;
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
if (batch.n_tokens > 0) {
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
}
|
||||
// entire prompt has been processed
|
||||
if (slot.n_past == slot.n_prompt_tokens) {
|
||||
slot.state = PROCESSING;
|
||||
slot.command = NONE;
|
||||
|
||||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
|
||||
// extract the logits only for the last token
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
|
||||
LOG_VERBOSE("prompt processed", {
|
||||
{"id_slot", slot.id},
|
||||
{"n_past", slot.n_past},
|
||||
{"n_ctx", n_ctx},
|
||||
{"n_tokens", batch.n_tokens},
|
||||
});
|
||||
} else {
|
||||
LOG_VERBOSE("prompt still processing", {
|
||||
{"id_slot", slot.id},
|
||||
{"n_past", slot.n_past},
|
||||
{"n_ctx", n_ctx},
|
||||
{"n_tokens", batch.n_tokens},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (n_available == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1868,6 +1910,10 @@ struct llama_server_context {
|
|||
return true;
|
||||
}
|
||||
|
||||
LOG_VERBOSE("decoding batch", {
|
||||
{"n_tokens", batch.n_tokens},
|
||||
});
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||
|
||||
|
@ -1948,8 +1994,8 @@ struct llama_server_context {
|
|||
|
||||
slot.n_decoded += 1;
|
||||
if (slot.n_decoded == 1) {
|
||||
slot.t_start_genereration = ggml_time_us();
|
||||
slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
|
||||
slot.t_start_generation = ggml_time_us();
|
||||
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
|
||||
metrics.on_prompt_eval(slot);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue