fix infinite generation loop
This commit is contained in:
parent 60d4194bfe
commit b550011be3
1 changed file with 120 additions and 116 deletions
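In short: before this commit, launching a task put the slot straight into SLOT_STATE_PROCESSING_PROMPT, and the per-prompt setup in update_slots() (timers, n_past = 0, truncation, prompt-cache reuse) was re-run on every scheduling pass for as long as the slot stayed in that state, which could undo progress (e.g. resetting slot.n_past) and loop forever. The commit adds a SLOT_STATE_STARTED state and runs that setup exactly once, when the slot leaves it. The snippet below is only a minimal sketch of the gating idea with toy types (slot, update); it is not the server code:

    // Minimal sketch of the gating idea; hypothetical types, not server.cpp.
    #include <cstdio>

    enum slot_state {
        SLOT_STATE_IDLE,
        SLOT_STATE_STARTED,           // task accepted, one-time setup still pending
        SLOT_STATE_PROCESSING_PROMPT, // setup done, prompt being decoded in batches
    };

    struct slot {
        slot_state state = SLOT_STATE_IDLE;
        int n_past = 0;
    };

    // Called repeatedly by the scheduler; the setup must run exactly once.
    void update(slot & s) {
        if (s.state == SLOT_STATE_PROCESSING_PROMPT || s.state == SLOT_STATE_STARTED) {
            if (s.state == SLOT_STATE_STARTED) {
                s.n_past = 0; // one-time init: without the gate this would reset on every call
                s.state = SLOT_STATE_PROCESSING_PROMPT;
            }
            // ... append more prompt tokens to the batch, possibly over several calls ...
        }
    }

    int main() {
        slot s;
        s.state = SLOT_STATE_STARTED;
        update(s); // runs the one-time init
        update(s); // init is skipped on later passes
        std::printf("state = %d, n_past = %d\n", (int) s.state, s.n_past);
    }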
@@ -68,6 +68,7 @@ enum stop_type {
 // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
+    SLOT_STATE_STARTED,
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
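The new SLOT_STATE_STARTED value sits between IDLE and PROCESSING_PROMPT; the state diagram linked in the comment (PR #9283) documents the full lifecycle. As a rough illustration of the intended happy-path ordering only (the real transitions are spread through server.cpp and include shortcuts, e.g. embedding tasks skip GENERATING), a hypothetical next() helper:

    // Hypothetical helper showing the intended happy-path slot lifecycle; not part of server.cpp.
    #include <cassert>

    enum slot_state {
        SLOT_STATE_IDLE,
        SLOT_STATE_STARTED,
        SLOT_STATE_PROCESSING_PROMPT,
        SLOT_STATE_DONE_PROMPT,
        SLOT_STATE_GENERATING,
    };

    // IDLE -> STARTED -> PROCESSING_PROMPT -> DONE_PROMPT -> GENERATING -> IDLE
    constexpr slot_state next(slot_state s) {
        switch (s) {
            case SLOT_STATE_IDLE:              return SLOT_STATE_STARTED;
            case SLOT_STATE_STARTED:           return SLOT_STATE_PROCESSING_PROMPT;
            case SLOT_STATE_PROCESSING_PROMPT: return SLOT_STATE_DONE_PROMPT;
            case SLOT_STATE_DONE_PROMPT:       return SLOT_STATE_GENERATING;
            case SLOT_STATE_GENERATING:        return SLOT_STATE_IDLE;
        }
        return SLOT_STATE_IDLE;
    }

    int main() {
        assert(next(SLOT_STATE_STARTED) == SLOT_STATE_PROCESSING_PROMPT);
    }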
@@ -950,7 +951,7 @@ struct server_context {
             }
         }
 
-        slot.state = SLOT_STATE_PROCESSING_PROMPT;
+        slot.state = SLOT_STATE_STARTED;
 
         SLT_INF(slot, "%s", "processing task\n");
 
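With this hunk the launch path only marks the slot as started; the actual per-prompt initialization is deferred to update_slots(), which the next hunk restructures. Reusing the toy types from the first sketch above, the shape of the change is roughly:

    // Hypothetical launch path after the change (toy types from the first sketch, not the real function).
    void launch(slot & s) {
        // was: s.state = SLOT_STATE_PROCESSING_PROMPT;
        s.state = SLOT_STATE_STARTED; // one-time setup now happens later, in update()
    }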
@@ -1867,149 +1868,152 @@ struct server_context {
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // this slot still has a prompt to be processed
-                if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
+                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                     auto & prompt_tokens = slot.prompt_tokens;
 
-                    slot.t_start_process_prompt = ggml_time_us();
-                    slot.t_start_generation = 0;
-                    slot.n_past = 0;
-                    slot.n_prompt_tokens = prompt_tokens.size();
-
-                    // empty prompt passed -> release the slot and send empty response
-                    if (prompt_tokens.empty()) {
-                        SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
-
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        continue;
-                    }
-
-                    SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
-
-                    // print prompt tokens (for debugging)
-                    if (1) {
-                        // first 16 tokens (avoid flooding logs)
-                        for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
-                            SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
-                        }
-                    } else {
-                        // all
-                        for (int i = 0; i < (int) prompt_tokens.size(); i++) {
-                            SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
-                        }
-                    }
-
-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                        // this prompt is too large to process - discard it
-                        if (slot.n_prompt_tokens > n_ubatch) {
-                            slot.release();
-                            send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
-                            continue;
-                        }
-                    } else {
-                        if (!params.ctx_shift) {
-                            // if context shift is disabled, we make sure prompt size is smaller than KV size
-                            // TODO: there should be a separate parameter that control prompt truncation
-                            //       context shift should be applied only during the generation phase
-                            if (slot.n_prompt_tokens >= slot.n_ctx) {
-                                slot.release();
-                                send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
-                                continue;
-                            }
-                        }
-                        if (slot.params.n_keep < 0) {
-                            slot.params.n_keep = slot.n_prompt_tokens;
-                        }
-                        slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
-                        // if input prompt is too big, truncate it
-                        if (slot.n_prompt_tokens >= slot.n_ctx) {
-                            const int n_left = slot.n_ctx - slot.params.n_keep;
-
-                            const int n_block_size = n_left / 2;
-                            const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
-                            std::vector<llama_token> new_tokens(
-                                    prompt_tokens.begin(),
-                                    prompt_tokens.begin() + slot.params.n_keep);
-
-                            new_tokens.insert(
-                                    new_tokens.end(),
-                                    prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
-                                    prompt_tokens.end());
-
-                            prompt_tokens = std::move(new_tokens);
-
-                            slot.truncated = true;
-                            slot.n_prompt_tokens = prompt_tokens.size();
-
-                            SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
-
-                            GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
-                        }
-
-                        common_sampler_reset(slot.smpl);
-                        if (slot.params.cache_prompt) {
-                            // reuse any previously computed tokens that are common with the new prompt
-                            slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
-
-                            // reuse chunks from the cached prompt by shifting their KV cache in the new position
-                            if (params.n_cache_reuse > 0) {
-                                size_t head_c = slot.n_past; // cache
-                                size_t head_p = slot.n_past; // current prompt
-
-                                SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
-
-                                while (head_c < slot.cache_tokens.size() &&
-                                       head_p < prompt_tokens.size()) {
-                                    size_t n_match = 0;
-                                    while (head_c + n_match < slot.cache_tokens.size() &&
-                                           head_p + n_match < prompt_tokens.size() &&
-                                           slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
-                                        n_match++;
-                                    }
-
-                                    if (n_match >= (size_t) params.n_cache_reuse) {
-                                        SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
-                                        //for (size_t i = head_p; i < head_p + n_match; i++) {
-                                        //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
-                                        //}
-
-                                        const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
-
-                                        llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
-                                        llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
-
-                                        for (size_t i = 0; i < n_match; i++) {
-                                            slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
-                                            slot.n_past++;
-                                        }
-
-                                        head_c += n_match;
-                                        head_p += n_match;
-                                    } else {
-                                        head_c += 1;
-                                    }
-                                }
-
-                                SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
-                            }
-                        }
-                    }
-
-                    if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-                        // we have to evaluate at least 1 token to generate logits.
-                        SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
-                        slot.n_past--;
-                    }
-
-                    slot.n_prompt_tokens_processed = 0;
+                    // TODO: maybe move branch to outside of this loop in the future
+                    if (slot.state == SLOT_STATE_STARTED) {
+                        slot.t_start_process_prompt = ggml_time_us();
+                        slot.t_start_generation = 0;
+                        slot.n_past = 0;
+                        slot.n_prompt_tokens = prompt_tokens.size();
+                        slot.state = SLOT_STATE_PROCESSING_PROMPT;
+
+                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
+
+                        // print prompt tokens (for debugging)
+                        if (1) {
+                            // first 16 tokens (avoid flooding logs)
+                            for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
+                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                            }
+                        } else {
+                            // all
+                            for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                            }
+                        }
+
+                        // empty prompt passed -> release the slot and send empty response
+                        if (prompt_tokens.empty()) {
+                            SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
+
+                            slot.release();
+                            slot.print_timings();
+                            send_final_response(slot);
+                            continue;
+                        }
+
+                        if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                            // this prompt is too large to process - discard it
+                            if (slot.n_prompt_tokens > n_ubatch) {
+                                slot.release();
+                                send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
+                                continue;
+                            }
+                        } else {
+                            if (!params.ctx_shift) {
+                                // if context shift is disabled, we make sure prompt size is smaller than KV size
+                                // TODO: there should be a separate parameter that control prompt truncation
+                                //       context shift should be applied only during the generation phase
+                                if (slot.n_prompt_tokens >= slot.n_ctx) {
+                                    slot.release();
+                                    send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+                                    continue;
+                                }
+                            }
+                            if (slot.params.n_keep < 0) {
+                                slot.params.n_keep = slot.n_prompt_tokens;
+                            }
+                            slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+
+                            // if input prompt is too big, truncate it
+                            if (slot.n_prompt_tokens >= slot.n_ctx) {
+                                const int n_left = slot.n_ctx - slot.params.n_keep;
+
+                                const int n_block_size = n_left / 2;
+                                const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+
+                                std::vector<llama_token> new_tokens(
+                                        prompt_tokens.begin(),
+                                        prompt_tokens.begin() + slot.params.n_keep);
+
+                                new_tokens.insert(
+                                        new_tokens.end(),
+                                        prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                                        prompt_tokens.end());
+
+                                prompt_tokens = std::move(new_tokens);
+
+                                slot.truncated = true;
+                                slot.n_prompt_tokens = prompt_tokens.size();
+
+                                SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
+
+                                GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
+                            }
+
+                            if (slot.params.cache_prompt) {
+                                // reuse any previously computed tokens that are common with the new prompt
+                                slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
+
+                                // reuse chunks from the cached prompt by shifting their KV cache in the new position
+                                if (params.n_cache_reuse > 0) {
+                                    size_t head_c = slot.n_past; // cache
+                                    size_t head_p = slot.n_past; // current prompt
+
+                                    SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+
+                                    while (head_c < slot.cache_tokens.size() &&
+                                           head_p < prompt_tokens.size()) {
+                                        size_t n_match = 0;
+                                        while (head_c + n_match < slot.cache_tokens.size() &&
+                                               head_p + n_match < prompt_tokens.size() &&
+                                               slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+
+                                            n_match++;
+                                        }
+
+                                        if (n_match >= (size_t) params.n_cache_reuse) {
+                                            SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                                            //for (size_t i = head_p; i < head_p + n_match; i++) {
+                                            //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                                            //}
+
+                                            const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                                            llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                                            llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                                            for (size_t i = 0; i < n_match; i++) {
+                                                slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                                                slot.n_past++;
+                                            }
+
+                                            head_c += n_match;
+                                            head_p += n_match;
+                                        } else {
+                                            head_c += 1;
+                                        }
+                                    }
+
+                                    SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+                                }
+                            }
+                        }
+
+                        if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
+                            // we have to evaluate at least 1 token to generate logits.
+                            SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
+
+                            slot.n_past--;
+                        }
+
+                        slot.n_prompt_tokens_processed = 0;
+                    }
 
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                         // cannot fit the prompt in the current batch - will try next iter
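Taken together, this hunk is mostly a re-indentation: the one-time setup (timers, n_past, the empty-prompt shortcut, the oversized-prompt checks, truncation, and prompt-cache / KV-chunk reuse) moves inside an if (slot.state == SLOT_STATE_STARTED) branch, the empty-prompt shortcut moves after the "new prompt" logging, and common_sampler_reset() is dropped here (it reappears in the last hunk, right before the prompt is replayed through the sampler). A compressed skeleton of the resulting control flow, again using the toy types from the first sketch (process_pending_prompt is a made-up name, not a server.cpp function):

    // Skeleton only; the real body is the block added in the hunk above.
    void process_pending_prompt(slot & s) {
        if (s.state == SLOT_STATE_PROCESSING_PROMPT || s.state == SLOT_STATE_STARTED) {
            if (s.state == SLOT_STATE_STARTED) {
                // runs once per prompt: timers, s.n_past = 0, empty-prompt check,
                // truncation, prompt-cache and KV-chunk reuse
                s.state = SLOT_STATE_PROCESSING_PROMPT;
            }
            // runs on every pass: append more prompt tokens to the current batch
        }
    }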
@@ -2036,8 +2040,6 @@ struct server_context {
 
                         // there is no common part left
                         slot.n_past = 0;
-
-                        common_sampler_reset(slot.smpl);
                     }
 
                     SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
@@ -2047,10 +2049,10 @@ struct server_context {
 
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                        common_batch_add(batch, slot.prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
 
                         if (slot.params.cache_prompt) {
-                            slot.cache_tokens.push_back(slot.prompt_tokens[slot.n_past]);
+                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
                         }
 
                         slot.n_prompt_tokens_processed++;
@@ -2065,9 +2067,11 @@ struct server_context {
 
                         GGML_ASSERT(batch.n_tokens > 0);
 
+                        common_sampler_reset(slot.smpl);
+
                         // Process all prompt tokens through sampler system
                         for (int i = 0; i < slot.n_prompt_tokens; ++i) {
-                            common_sampler_accept(slot.smpl, slot.prompt_tokens[i], false);
+                            common_sampler_accept(slot.smpl, prompt_tokens[i], false);
                         }
 
                         // extract the logits only for the last token
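The last three hunks also relocate the sampler reset: instead of resetting in the (re-runnable) setup path and again in the "no common part left" branch, the sampler is now reset once, immediately before the prompt tokens are replayed through it. A self-contained toy illustration of that ordering follows; toy_sampler and prime_sampler are made up here and are not the common_sampler_* API:

    // Toy illustration of "reset once, then replay the whole prompt"; hypothetical types.
    #include <cstdint>
    #include <vector>

    struct toy_sampler {
        std::vector<int32_t> accepted;
        void reset() { accepted.clear(); }
        void accept(int32_t tok) { accepted.push_back(tok); }
    };

    void prime_sampler(toy_sampler & smpl, const std::vector<int32_t> & prompt_tokens) {
        smpl.reset(); // reset exactly once, right before replaying the prompt
        for (int32_t tok : prompt_tokens) {
            smpl.accept(tok); // replay so penalties/grammar state see the full prompt
        }
    }

    int main() {
        toy_sampler smpl;
        prime_sampler(smpl, {1, 2, 3});
        return smpl.accepted.size() == 3 ? 0 : 1;
    }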