diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index fc382c68d..734bde5c2 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2009,75 +2009,75 @@ struct server_context {
                     }
 
                     slot.n_prompt_tokens_processed = 0;
-                }
 
-                // non-causal tasks require to fit the entire prompt in the physical batch
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                    // cannot fit the prompt in the current batch - will try next iter
-                    if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                    // non-causal tasks require to fit the entire prompt in the physical batch
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        // cannot fit the prompt in the current batch - will try next iter
+                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                            continue;
+                        }
+                    }
+
+                    // check that we are in the right batch_type, if not defer the slot
+                    const bool slot_type =
+                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+
+                    if (batch_type == -1) {
+                        batch_type = slot_type;
+                    } else if (batch_type != slot_type) {
                         continue;
                     }
-                }
 
-                // check that we are in the right batch_type, if not defer the slot
-                const bool slot_type =
-                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
-                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+                    // keep only the common part
+                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
+                        // could not partially delete (likely using a non-Transformer model)
+                        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
 
-                if (batch_type == -1) {
-                    batch_type = slot_type;
-                } else if (batch_type != slot_type) {
-                    continue;
-                }
+                        // there is no common part left
+                        slot.n_past = 0;
 
-                // keep only the common part
-                if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
-                    // could not partially delete (likely using a non-Transformer model)
-                    llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
-
-                    // there is no common part left
-                    slot.n_past = 0;
-
-                    common_sampler_reset(slot.smpl);
-                }
-
-                SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
-
-                // remove the non-common part from the cache
-                slot.cache_tokens.resize(slot.n_past);
-
-                // add prompt tokens for processing in the current batch
-                while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                    common_batch_add(batch, slot.prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
-
-                    if (slot.params.cache_prompt) {
-                        slot.cache_tokens.push_back(slot.prompt_tokens[slot.n_past]);
+                        common_sampler_reset(slot.smpl);
                     }
 
-                    slot.n_prompt_tokens_processed++;
-                    slot.n_past++;
-                }
+                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
 
-                SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    // remove the non-common part from the cache
+                    slot.cache_tokens.resize(slot.n_past);
 
-                // entire prompt has been processed
-                if (slot.n_past == slot.n_prompt_tokens) {
-                    slot.state = SLOT_STATE_DONE_PROMPT;
+                    // add prompt tokens for processing in the current batch
+                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                        common_batch_add(batch, slot.prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
 
-                    GGML_ASSERT(batch.n_tokens > 0);
+                        if (slot.params.cache_prompt) {
+                            slot.cache_tokens.push_back(slot.prompt_tokens[slot.n_past]);
+                        }
 
-                    // Process all prompt tokens through sampler system
-                    for (int i = 0; i < slot.n_prompt_tokens; ++i) {
-                        common_sampler_accept(slot.smpl, slot.prompt_tokens[i], false);
+                        slot.n_prompt_tokens_processed++;
+                        slot.n_past++;
                     }
 
-                    // extract the logits only for the last token
-                    batch.logits[batch.n_tokens - 1] = true;
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
-                    slot.n_decoded = 0;
-                    slot.i_batch = batch.n_tokens - 1;
+                    // entire prompt has been processed
+                    if (slot.n_past == slot.n_prompt_tokens) {
+                        slot.state = SLOT_STATE_DONE_PROMPT;
 
-                    SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                        GGML_ASSERT(batch.n_tokens > 0);
+
+                        // Process all prompt tokens through sampler system
+                        for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+                            common_sampler_accept(slot.smpl, slot.prompt_tokens[i], false);
+                        }
+
+                        // extract the logits only for the last token
+                        batch.logits[batch.n_tokens - 1] = true;
+
+                        slot.n_decoded = 0;
+                        slot.i_batch = batch.n_tokens - 1;
+
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                    }
                 }
 
                 if (batch.n_tokens >= n_batch) {
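
For context on the batch_type gating this hunk relocates inside the prompt-processing branch: embedding and rerank prompts are decoded non-causally, so they cannot share a decode batch with causal completion tokens; the first slot scheduled in an iteration fixes the batch type, and slots of the other type are deferred with continue until a later iteration. Below is a minimal standalone sketch of that selection logic only, assuming a hypothetical slot_task enum and slot struct as stand-ins for the server's real types:

#include <cstdio>
#include <vector>

// hypothetical stand-ins for the server's task/slot types (not llama.cpp code)
enum slot_task { TASK_COMPLETION, TASK_EMBEDDING, TASK_RERANK };

struct slot {
    int       id;
    slot_task task;
};

int main() {
    std::vector<slot> slots = {
        {0, TASK_COMPLETION}, {1, TASK_EMBEDDING}, {2, TASK_COMPLETION},
    };

    // -1 = batch type not decided yet; 0 = causal, 1 = non-causal
    // (embedding/rerank), mirroring batch_type/slot_type in the hunk above
    int batch_type = -1;

    for (const slot & s : slots) {
        const int slot_type = (s.task == TASK_EMBEDDING || s.task == TASK_RERANK) ? 1 : 0;

        // the first slot scheduled this iteration decides the batch type;
        // slots of the other type are skipped and retried on a later pass
        if (batch_type == -1) {
            batch_type = slot_type;
        } else if (batch_type != slot_type) {
            printf("slot %d deferred (slot type %d != batch type %d)\n", s.id, slot_type, batch_type);
            continue;
        }

        printf("slot %d added to %s batch\n", s.id, batch_type ? "non-causal" : "causal");
    }

    return 0;
}

With the example input above, slot 0 fixes the batch type to causal, slot 1 is deferred, and slot 2 joins the batch; a deferred slot is picked up on a subsequent scheduling pass once the current batch has been decoded.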