fix incorrect if branch

Xuan Son Nguyen 2024-10-23 23:48:08 +02:00
parent 3abc33962e
commit 60d4194bfe

@@ -2009,75 +2009,75 @@ struct server_context {
        }

        slot.n_prompt_tokens_processed = 0;
    }

    // non-causal tasks require to fit the entire prompt in the physical batch
    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
        // cannot fit the prompt in the current batch - will try next iter
        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
            continue;
        }
    }

    // check that we are in the right batch_type, if not defer the slot
    const bool slot_type =
        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;

    if (batch_type == -1) {
        batch_type = slot_type;
    } else if (batch_type != slot_type) {
        continue;
    }

    // keep only the common part
    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
        // could not partially delete (likely using a non-Transformer model)
        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);

        // there is no common part left
        slot.n_past = 0;

        common_sampler_reset(slot.smpl);
    }

    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);

    // remove the non-common part from the cache
    slot.cache_tokens.resize(slot.n_past);

    // add prompt tokens for processing in the current batch
    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
        common_batch_add(batch, slot.prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);

        if (slot.params.cache_prompt) {
            slot.cache_tokens.push_back(slot.prompt_tokens[slot.n_past]);
        }

        slot.n_prompt_tokens_processed++;
        slot.n_past++;
    }

    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);

    // entire prompt has been processed
    if (slot.n_past == slot.n_prompt_tokens) {
        slot.state = SLOT_STATE_DONE_PROMPT;

        GGML_ASSERT(batch.n_tokens > 0);

        // Process all prompt tokens through sampler system
        for (int i = 0; i < slot.n_prompt_tokens; ++i) {
            common_sampler_accept(slot.smpl, slot.prompt_tokens[i], false);
        }

        // extract the logits only for the last token
        batch.logits[batch.n_tokens - 1] = true;

        slot.n_decoded = 0;
        slot.i_batch = batch.n_tokens - 1;

        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
    }
}

if (batch.n_tokens >= n_batch) {
    break;
}
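
The two checks at the top of the hunk encode the server's batching rules: an embedding or rerank prompt is non-causal, so it must land whole in a single physical batch, and causal and non-causal slots must never share a batch (batch_type records which kind the in-flight batch holds, with -1 meaning undecided). A minimal standalone sketch of that scheduling decision follows; ToySlot, the slot list, and all values are invented for illustration and are not llama.cpp API:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    struct ToySlot {
        int  id;
        int  n_prompt_tokens;
        bool non_causal;      // embedding or rerank
    };

    int main() {
        const int n_batch = 8;

        std::vector<ToySlot> slots = {
            {0, 5, false}, {1, 6, true}, {2, 3, true}, {3, 2, false},
        };

        int n_tokens   = 0;   // tokens already placed in the current batch
        int batch_type = -1;  // -1 = undecided, 0 = causal, 1 = non-causal

        for (const auto & slot : slots) {
            // a non-causal prompt must fit entirely in the remaining space
            if (slot.non_causal && n_tokens + slot.n_prompt_tokens > n_batch) {
                printf("slot %d deferred: prompt does not fit\n", slot.id);
                continue;
            }

            // never mix causal and non-causal work in one batch
            const int slot_type = slot.non_causal ? 1 : 0;
            if (batch_type == -1) {
                batch_type = slot_type;
            } else if (batch_type != slot_type) {
                printf("slot %d deferred: wrong batch type\n", slot.id);
                continue;
            }

            // causal slots may be chunked, so take whatever space is left
            n_tokens += std::min(slot.n_prompt_tokens, n_batch - n_tokens);
            printf("slot %d scheduled, batch now %d tokens\n", slot.id, n_tokens);
        }
        return 0;
    }

A deferred slot is not dropped: it is simply skipped this round and retried on the next scheduling pass, which is what the bare continue achieves in the real loop.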
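In the hunk, slot.n_past is the length of the prefix shared between the slot's cached tokens and the new prompt (computed earlier in the function, outside this hunk); llama_kv_cache_seq_rm then drops everything past it, falling back to a full clear and a sampler reset for models whose cache cannot be truncated mid-sequence. A toy illustration of the common-prefix idea, using plain std::vector instead of the real KV cache; common_prefix is a hypothetical helper, not a llama.cpp function:

    #include <cstdio>
    #include <vector>

    using llama_token = int;

    // length of the shared prefix between the cached tokens and the new prompt
    static size_t common_prefix(const std::vector<llama_token> & cache,
                                const std::vector<llama_token> & prompt) {
        size_t n = 0;
        while (n < cache.size() && n < prompt.size() && cache[n] == prompt[n]) {
            n++;
        }
        return n;
    }

    int main() {
        std::vector<llama_token> cache_tokens = {1, 7, 7, 3, 9};
        std::vector<llama_token> prompt       = {1, 7, 7, 5, 2, 8};

        size_t n_past = common_prefix(cache_tokens, prompt);

        // drop the non-common tail, as the real code does with
        // llama_kv_cache_seq_rm(ctx, seq_id, n_past, -1) plus cache_tokens.resize
        cache_tokens.resize(n_past);

        printf("reusing %zu cached tokens, %zu left to process\n",
               n_past, prompt.size() - n_past);
        return 0;
    }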
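The while loop feeds the prompt in chunks of at most n_batch tokens; a slot whose prompt is longer than one batch keeps its progress in slot.n_past and resumes on a later pass. Only the very last prompt token has its logits flag set, because that is the only position sampling needs. A self-contained sketch of that chunking pattern, with toy sizes and no llama.cpp types:

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_prompt_tokens = 10;
        const int n_batch         = 4;

        int n_past = 0; // tokens already submitted for this slot

        while (n_past < n_prompt_tokens) {
            std::vector<bool> logits; // one flag per token in this chunk

            // fill the current physical batch
            while (n_past < n_prompt_tokens && (int) logits.size() < n_batch) {
                logits.push_back(false); // prompt tokens do not need logits
                n_past++;
            }

            // only once the entire prompt is in, sample from its last token
            if (n_past == n_prompt_tokens) {
                logits.back() = true;
            }

            printf("decode chunk of %zu tokens, logits for last: %s\n",
                   logits.size(), logits.back() ? "yes" : "no");
        }
        return 0;
    }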
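Once the prompt is complete, every prompt token is replayed through common_sampler_accept with the final argument false (do not advance the grammar), so that history-based samplers such as repetition and frequency penalties see the prompt as context. A toy sampler showing why that history matters; ToySampler and its fields are invented for illustration:

    #include <cstdio>
    #include <deque>

    struct ToySampler {
        std::deque<int> prev;               // recent token history (penalty window)
        size_t          penalty_last_n = 64;

        void accept(int token) {
            prev.push_back(token);
            if (prev.size() > penalty_last_n) {
                prev.pop_front();
            }
        }
    };

    int main() {
        ToySampler smpl;
        const int prompt[] = {12, 42, 42, 7};

        // mirror of: for (i = 0; i < n_prompt_tokens; ++i) accept(prompt[i])
        for (int t : prompt) {
            smpl.accept(t);
        }

        printf("sampler history now holds %zu prompt tokens\n", smpl.prev.size());
        return 0;
    }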