From b550011be3e181690583cc20be4cfa44f1b7befb Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 24 Oct 2024 00:03:00 +0200
Subject: [PATCH] fix infinite generation loop

---
 examples/server/server.cpp | 236 +++++++++++++++++++------------------
 1 file changed, 120 insertions(+), 116 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 734bde5c2..cdb97419f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -68,6 +68,7 @@ enum stop_type {
 // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
+    SLOT_STATE_STARTED,
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
@@ -950,7 +951,7 @@ struct server_context {
             }
         }
 
-        slot.state = SLOT_STATE_PROCESSING_PROMPT;
+        slot.state = SLOT_STATE_STARTED;
 
         SLT_INF(slot, "%s", "processing task\n");
 
@@ -1867,149 +1868,152 @@ struct server_context {
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // this slot still has a prompt to be processed
-                if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
+                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                     auto & prompt_tokens = slot.prompt_tokens;
 
-                    slot.t_start_process_prompt = ggml_time_us();
-                    slot.t_start_generation = 0;
-                    slot.n_past = 0;
-                    slot.n_prompt_tokens = prompt_tokens.size();
+                    // TODO: maybe move branch to outside of this loop in the future
+                    if (slot.state == SLOT_STATE_STARTED) {
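+                        // one-time initialization for the task; the state is switched to
+                        // SLOT_STATE_PROCESSING_PROMPT below, so this branch is not re-entered
+                        // when the prompt spans multiple batches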
+                        slot.t_start_process_prompt = ggml_time_us();
+                        slot.t_start_generation = 0;
+                        slot.n_past = 0;
+                        slot.n_prompt_tokens = prompt_tokens.size();
+                        slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
-                    // empty prompt passed -> release the slot and send empty response
-                    if (prompt_tokens.empty()) {
-                        SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
+                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
 
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        continue;
-                    }
-
-                    SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
-
-                    // print prompt tokens (for debugging)
-                    if (1) {
-                        // first 16 tokens (avoid flooding logs)
-                        for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
-                            SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                        // print prompt tokens (for debugging)
+                        if (1) {
+                            // first 16 tokens (avoid flooding logs)
+                            for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
+                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                            }
+                        } else {
+                            // all
+                            for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                            }
                         }
-                    } else {
-                        // all
-                        for (int i = 0; i < (int) prompt_tokens.size(); i++) {
-                            SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
-                        }
-                    }
 
-                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                        // this prompt is too large to process - discard it
-                        if (slot.n_prompt_tokens > n_ubatch) {
+                        // empty prompt passed -> release the slot and send empty response
+                        if (prompt_tokens.empty()) {
+                            SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
+
                             slot.release();
-                            send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
+                            slot.print_timings();
+                            send_final_response(slot);
                             continue;
                         }
-                    } else {
-                        if (!params.ctx_shift) {
-                            // if context shift is disabled, we make sure prompt size is smaller than KV size
-                            // TODO: there should be a separate parameter that control prompt truncation
-                            //       context shift should be applied only during the generation phase
-                            if (slot.n_prompt_tokens >= slot.n_ctx) {
+
+                        if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                            // this prompt is too large to process - discard it
+                            if (slot.n_prompt_tokens > n_ubatch) {
                                 slot.release();
-                                send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+                                send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
                                 continue;
                             }
-                        }
+                        } else {
+                            if (!params.ctx_shift) {
+                                // if context shift is disabled, we make sure prompt size is smaller than KV size
+                                // TODO: there should be a separate parameter that control prompt truncation
+                                //       context shift should be applied only during the generation phase
+                                if (slot.n_prompt_tokens >= slot.n_ctx) {
+                                    slot.release();
+                                    send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+                                    continue;
+                                }
+                            }
 
-                        if (slot.params.n_keep < 0) {
-                            slot.params.n_keep = slot.n_prompt_tokens;
-                        }
-                        slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+                            if (slot.params.n_keep < 0) {
+                                slot.params.n_keep = slot.n_prompt_tokens;
+                            }
+                            slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-                        // if input prompt is too big, truncate it
-                        if (slot.n_prompt_tokens >= slot.n_ctx) {
-                            const int n_left = slot.n_ctx - slot.params.n_keep;
+                            // if input prompt is too big, truncate it
+                            if (slot.n_prompt_tokens >= slot.n_ctx) {
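+                                // keep the first n_keep tokens, drop whole blocks of
+                                // n_left/2 tokens after them and keep the tail of the prompt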
+                                const int n_left = slot.n_ctx - slot.params.n_keep;
 
-                            const int n_block_size = n_left / 2;
-                            const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+                                const int n_block_size = n_left / 2;
+                                const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
 
-                            std::vector<llama_token> new_tokens(
-                                    prompt_tokens.begin(),
-                                    prompt_tokens.begin() + slot.params.n_keep);
+                                std::vector<llama_token> new_tokens(
+                                        prompt_tokens.begin(),
+                                        prompt_tokens.begin() + slot.params.n_keep);
 
-                            new_tokens.insert(
-                                    new_tokens.end(),
-                                    prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
-                                    prompt_tokens.end());
+                                new_tokens.insert(
+                                        new_tokens.end(),
+                                        prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+                                        prompt_tokens.end());
 
-                            prompt_tokens = std::move(new_tokens);
+                                prompt_tokens = std::move(new_tokens);
 
-                            slot.truncated = true;
-                            slot.n_prompt_tokens = prompt_tokens.size();
+                                slot.truncated = true;
+                                slot.n_prompt_tokens = prompt_tokens.size();
 
-                            SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
+                                SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
 
-                            GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
-                        }
+                                GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
+                            }
 
-                        common_sampler_reset(slot.smpl);
 
-                        if (slot.params.cache_prompt) {
-                            // reuse any previously computed tokens that are common with the new prompt
-                            slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
+                            if (slot.params.cache_prompt) {
+                                // reuse any previously computed tokens that are common with the new prompt
+                                slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
 
-                            // reuse chunks from the cached prompt by shifting their KV cache in the new position
-                            if (params.n_cache_reuse > 0) {
-                                size_t head_c = slot.n_past; // cache
-                                size_t head_p = slot.n_past; // current prompt
+                                // reuse chunks from the cached prompt by shifting their KV cache in the new position
+                                if (params.n_cache_reuse > 0) {
+                                    size_t head_c = slot.n_past; // cache
+                                    size_t head_p = slot.n_past; // current prompt
 
-                                SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+                                    SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
 
-                                while (head_c < slot.cache_tokens.size() &&
-                                       head_p < prompt_tokens.size()) {
+                                    while (head_c < slot.cache_tokens.size() &&
+                                           head_p < prompt_tokens.size()) {
 
-                                    size_t n_match = 0;
-                                    while (head_c + n_match < slot.cache_tokens.size() &&
-                                           head_p + n_match < prompt_tokens.size() &&
-                                           slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                                        size_t n_match = 0;
+                                        while (head_c + n_match < slot.cache_tokens.size() &&
+                                               head_p + n_match < prompt_tokens.size() &&
+                                               slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
 
-                                        n_match++;
-                                    }
+                                            n_match++;
+                                        }
 
-                                    if (n_match >= (size_t) params.n_cache_reuse) {
-                                        SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
-                                        //for (size_t i = head_p; i < head_p + n_match; i++) {
-                                        //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
-                                        //}
+                                        if (n_match >= (size_t) params.n_cache_reuse) {
+                                            SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                                            //for (size_t i = head_p; i < head_p + n_match; i++) {
+                                            //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                                            //}
 
-                                        const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+                                            const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
 
-                                        llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
-                                        llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
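+                                            // remove the cells between the two heads, then shift the
+                                            // reused chunk from head_c down to its new position head_p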
+                                            llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                                            llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
 
-                                        for (size_t i = 0; i < n_match; i++) {
-                                            slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
-                                            slot.n_past++;
-                                        }
+                                            for (size_t i = 0; i < n_match; i++) {
+                                                slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                                                slot.n_past++;
+                                            }
 
-                                        head_c += n_match;
-                                        head_p += n_match;
-                                    } else {
-                                        head_c += 1;
-                                    }
-                                }
+                                            head_c += n_match;
+                                            head_p += n_match;
+                                        } else {
+                                            head_c += 1;
+                                        }
+                                    }
 
-                                SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+                                    SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+                                }
                             }
                         }
-                    }
 
-                    if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
-                        // we have to evaluate at least 1 token to generate logits.
-                        SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
+                        if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
+                            // we have to evaluate at least 1 token to generate logits.
+                            SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
 
-                        slot.n_past--;
-                    }
+                            slot.n_past--;
+                        }
 
-                    slot.n_prompt_tokens_processed = 0;
+                        slot.n_prompt_tokens_processed = 0;
+                    }
 
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                         // cannot fit the prompt in the current batch - will try next iter
@@ -2036,8 +2040,6 @@ struct server_context {
 
                         // there is no common part left
                         slot.n_past = 0;
-
-                        common_sampler_reset(slot.smpl);
                     }
 
                     SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
@@ -2047,10 +2049,10 @@ struct server_context {
 
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                        common_batch_add(batch, slot.prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
 
                         if (slot.params.cache_prompt) {
-                            slot.cache_tokens.push_back(slot.prompt_tokens[slot.n_past]);
+                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
                         }
 
                         slot.n_prompt_tokens_processed++;
@@ -2065,9 +2067,11 @@ struct server_context {
 
                     GGML_ASSERT(batch.n_tokens > 0);
 
+                    common_sampler_reset(slot.smpl);
+
                     // Process all prompt tokens through sampler system
                    for (int i = 0; i < slot.n_prompt_tokens; ++i) {
-                        common_sampler_accept(slot.smpl, slot.prompt_tokens[i], false);
+                        common_sampler_accept(slot.smpl, prompt_tokens[i], false);
                     }
 
                     // extract the logits only for the last token