fix infinite generation loop

parent 60d4194bfe
commit b550011be3

1 changed file with 120 additions and 116 deletions
@@ -68,6 +68,7 @@ enum stop_type {
 // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
+    SLOT_STATE_STARTED,
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
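For orientation, here is the slot life cycle implied by this patch as a minimal, self-contained C++ sketch. The enum values come from the hunk above; the transition comments and the tiny driver are assumptions for illustration, not server.cpp code.

// Minimal sketch of the intended slot life cycle (assumed from the hunks in
// this commit; not the actual server implementation).
#include <cstdio>

enum slot_state {
    SLOT_STATE_IDLE,
    SLOT_STATE_STARTED,            // new: task assigned, per-prompt setup not done yet
    SLOT_STATE_PROCESSING_PROMPT,  // prompt tokens fed to the batch, possibly over several passes
    SLOT_STATE_DONE_PROMPT,
    SLOT_STATE_GENERATING,
};

static const char * name(slot_state s) {
    switch (s) {
        case SLOT_STATE_IDLE:              return "IDLE";
        case SLOT_STATE_STARTED:           return "STARTED";
        case SLOT_STATE_PROCESSING_PROMPT: return "PROCESSING_PROMPT";
        case SLOT_STATE_DONE_PROMPT:       return "DONE_PROMPT";
        case SLOT_STATE_GENERATING:        return "GENERATING";
    }
    return "?";
}

int main() {
    // launch_slot_with_task()    : IDLE              -> STARTED
    // first update_slots() pass  : STARTED           -> PROCESSING_PROMPT (one-time init)
    // all prompt tokens decoded  : PROCESSING_PROMPT -> DONE_PROMPT
    // first token sampled        : DONE_PROMPT       -> GENERATING
    // stop condition / release   : GENERATING        -> IDLE
    const slot_state flow[] = {
        SLOT_STATE_IDLE, SLOT_STATE_STARTED, SLOT_STATE_PROCESSING_PROMPT,
        SLOT_STATE_DONE_PROMPT, SLOT_STATE_GENERATING, SLOT_STATE_IDLE,
    };
    for (slot_state s : flow) {
        printf("-> %s\n", name(s));
    }
    return 0;
}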
@@ -950,7 +951,7 @@ struct server_context {
             }
         }
 
-        slot.state = SLOT_STATE_PROCESSING_PROMPT;
+        slot.state = SLOT_STATE_STARTED;
 
         SLT_INF(slot, "%s", "processing task\n");
 
@@ -1867,23 +1868,16 @@ struct server_context {
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // this slot still has a prompt to be processed
-                if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
+                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                     auto & prompt_tokens = slot.prompt_tokens;
 
+                    // TODO: maybe move branch to outside of this loop in the future
+                    if (slot.state == SLOT_STATE_STARTED) {
                     slot.t_start_process_prompt = ggml_time_us();
                     slot.t_start_generation = 0;
                     slot.n_past = 0;
                     slot.n_prompt_tokens = prompt_tokens.size();
+                    slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
-                    // empty prompt passed -> release the slot and send empty response
-                    if (prompt_tokens.empty()) {
-                        SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
-
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        continue;
-                    }
-
                     SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
 
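This hunk is the core of the fix. When a prompt does not fit into a single batch, update_slots() has to revisit the slot in SLOT_STATE_PROCESSING_PROMPT on later passes. Before this change the per-prompt setup above (slot.n_past = 0, slot.n_prompt_tokens = ...) ran on every visit, so n_past was wound back to 0 each pass and prompt processing could never finish, which is the infinite generation loop named in the commit title. A standalone sketch of that difference, with simplified counters and hypothetical helper names:

#include <cstdio>

// Feed a prompt of n_prompt tokens in batches of n_batch tokens per update
// pass. If guard_init is false, the per-prompt setup runs on every pass
// (the pre-fix behaviour); if true, it runs only once.
static int passes_until_done(int n_prompt, int n_batch, bool guard_init, int max_passes) {
    bool started = true;   // slot was just launched with a task
    int  n_past  = 0;
    for (int pass = 1; pass <= max_passes; ++pass) {
        if (started || !guard_init) {
            n_past  = 0;   // per-prompt init (slot.n_past = 0 in the real code)
            started = false;
        }
        for (int n_added = 0; n_past < n_prompt && n_added < n_batch; ++n_added) {
            n_past++;      // one prompt token scheduled into the current batch
        }
        if (n_past >= n_prompt) {
            return pass;   // prompt fully processed, generation can start
        }
    }
    return -1;             // never finished
}

int main() {
    const int n_prompt = 100, n_batch = 32;
    const int a = passes_until_done(n_prompt, n_batch, true,  16);
    const int b = passes_until_done(n_prompt, n_batch, false, 16);
    printf("init guarded by STARTED : finished after %d passes\n", a);
    printf("init re-run every pass  : %s\n", b < 0 ? "never finishes (infinite loop)" : "finished");
    return 0;
}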
@@ -1900,6 +1894,16 @@ struct server_context {
                         }
                     }
 
+                    // empty prompt passed -> release the slot and send empty response
+                    if (prompt_tokens.empty()) {
+                        SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
+
+                        slot.release();
+                        slot.print_timings();
+                        send_final_response(slot);
+                        continue;
+                    }
+
                     if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                         // this prompt is too large to process - discard it
                         if (slot.n_prompt_tokens > n_ubatch) {
@@ -1949,8 +1953,6 @@ struct server_context {
                         GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                     }
 
-                    common_sampler_reset(slot.smpl);
-
                     if (slot.params.cache_prompt) {
                         // reuse any previously computed tokens that are common with the new prompt
                         slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
@@ -1986,6 +1988,7 @@ struct server_context {
 
                                     for (size_t i = 0; i < n_match; i++) {
                                         slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
                                         slot.n_past++;
                                     }
 
@@ -2009,6 +2012,7 @@ struct server_context {
                     }
 
                     slot.n_prompt_tokens_processed = 0;
+                    }
 
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
@@ -2036,8 +2040,6 @@ struct server_context {
 
                             // there is no common part left
                             slot.n_past = 0;
-
-                            common_sampler_reset(slot.smpl);
                         }
 
                         SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
@@ -2047,10 +2049,10 @@ struct server_context {
 
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                        common_batch_add(batch, slot.prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
 
                         if (slot.params.cache_prompt) {
-                            slot.cache_tokens.push_back(slot.prompt_tokens[slot.n_past]);
+                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
                         }
 
                         slot.n_prompt_tokens_processed++;
@@ -2065,9 +2067,11 @@ struct server_context {
 
                         GGML_ASSERT(batch.n_tokens > 0);
 
+                        common_sampler_reset(slot.smpl);
+
                         // Process all prompt tokens through sampler system
                         for (int i = 0; i < slot.n_prompt_tokens; ++i) {
-                            common_sampler_accept(slot.smpl, slot.prompt_tokens[i], false);
+                            common_sampler_accept(slot.smpl, prompt_tokens[i], false);
                         }
 
                         // extract the logits only for the last token
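Together with the removals at -1949 and -2036 above, this hunk moves common_sampler_reset() out of the prompt-setup path and places it directly before the loop that replays the prompt into the sampler, so the reset and the rebuild of the sampler's token history happen together, exactly once, after the whole prompt has been scheduled. A minimal sketch of that pattern using a toy stand-in type (not the real common_sampler API):

#include <cstdio>
#include <vector>

// Toy stand-in for the sampler state: it only tracks the tokens it has
// "accepted", the history that penalties or grammar would later consult.
struct toy_sampler {
    std::vector<int> prev;
    void reset()           { prev.clear(); }
    void accept(int token) { prev.push_back(token); }
};

int main() {
    const std::vector<int> prompt_tokens = { 11, 42, 7, 99, 3 };

    toy_sampler smpl;
    smpl.accept(1234);   // leftover state from a previous request on this slot

    // Pattern used after this patch: reset immediately before replaying the
    // whole prompt, so the history always matches exactly the current prompt.
    smpl.reset();
    for (int tok : prompt_tokens) {
        smpl.accept(tok);
    }

    printf("sampler history size = %zu (prompt size = %zu)\n",
           smpl.prev.size(), prompt_tokens.size());
    return 0;
}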