fix incorrect if branch
parent 3abc33962e
commit 60d4194bfe
1 changed file with 53 additions and 53 deletions
@@ -2009,75 +2009,75 @@ struct server_context {
                     }
 
                     slot.n_prompt_tokens_processed = 0;
-                }
 
-                // non-causal tasks require to fit the entire prompt in the physical batch
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                    // cannot fit the prompt in the current batch - will try next iter
-                    if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
-                        continue;
-                    }
-                }
+                    // non-causal tasks require to fit the entire prompt in the physical batch
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        // cannot fit the prompt in the current batch - will try next iter
+                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                            continue;
+                        }
+                    }
 
-                // check that we are in the right batch_type, if not defer the slot
-                const bool slot_type =
-                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
-                    slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+                    // check that we are in the right batch_type, if not defer the slot
+                    const bool slot_type =
+                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
 
-                if (batch_type == -1) {
-                    batch_type = slot_type;
-                } else if (batch_type != slot_type) {
-                    continue;
-                }
+                    if (batch_type == -1) {
+                        batch_type = slot_type;
+                    } else if (batch_type != slot_type) {
+                        continue;
+                    }
 
-                // keep only the common part
-                if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
-                    // could not partially delete (likely using a non-Transformer model)
-                    llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+                    // keep only the common part
+                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
+                        // could not partially delete (likely using a non-Transformer model)
+                        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
 
-                    // there is no common part left
-                    slot.n_past = 0;
+                        // there is no common part left
+                        slot.n_past = 0;
 
-                    common_sampler_reset(slot.smpl);
-                }
+                        common_sampler_reset(slot.smpl);
+                    }
 
-                SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
 
-                // remove the non-common part from the cache
-                slot.cache_tokens.resize(slot.n_past);
+                    // remove the non-common part from the cache
+                    slot.cache_tokens.resize(slot.n_past);
 
-                // add prompt tokens for processing in the current batch
-                while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                    common_batch_add(batch, slot.prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
+                    // add prompt tokens for processing in the current batch
+                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                        common_batch_add(batch, slot.prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
 
-                    if (slot.params.cache_prompt) {
-                        slot.cache_tokens.push_back(slot.prompt_tokens[slot.n_past]);
-                    }
+                        if (slot.params.cache_prompt) {
+                            slot.cache_tokens.push_back(slot.prompt_tokens[slot.n_past]);
+                        }
 
-                    slot.n_prompt_tokens_processed++;
-                    slot.n_past++;
-                }
+                        slot.n_prompt_tokens_processed++;
+                        slot.n_past++;
+                    }
 
-                SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
-                // entire prompt has been processed
-                if (slot.n_past == slot.n_prompt_tokens) {
-                    slot.state = SLOT_STATE_DONE_PROMPT;
+                    // entire prompt has been processed
+                    if (slot.n_past == slot.n_prompt_tokens) {
+                        slot.state = SLOT_STATE_DONE_PROMPT;
 
-                    GGML_ASSERT(batch.n_tokens > 0);
+                        GGML_ASSERT(batch.n_tokens > 0);
 
-                    // Process all prompt tokens through sampler system
-                    for (int i = 0; i < slot.n_prompt_tokens; ++i) {
-                        common_sampler_accept(slot.smpl, slot.prompt_tokens[i], false);
-                    }
+                        // Process all prompt tokens through sampler system
+                        for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+                            common_sampler_accept(slot.smpl, slot.prompt_tokens[i], false);
+                        }
 
-                    // extract the logits only for the last token
-                    batch.logits[batch.n_tokens - 1] = true;
+                        // extract the logits only for the last token
+                        batch.logits[batch.n_tokens - 1] = true;
 
-                    slot.n_decoded = 0;
-                    slot.i_batch = batch.n_tokens - 1;
+                        slot.n_decoded = 0;
+                        slot.i_batch = batch.n_tokens - 1;
 
-                    SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                    }
                 }
 
                 if (batch.n_tokens >= n_batch) {
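A note on what the fix changes in practice, since the commit message is terse: in the parent commit, the closing brace of the enclosing per-slot branch (judging by the surrounding code, likely the check that the slot still has a prompt to process) sat right after `slot.n_prompt_tokens_processed = 0;`. Everything from the non-causal batch check down to the "prompt done" logging therefore ran for every slot on every pass, not only for slots with a pending prompt. The hunk deletes that early brace, re-adds it after the prompt-done block, and re-indents the body one level. A minimal, self-contained C++ sketch of the bug pattern, using simplified stand-in types rather than the server's actual `server_slot`:

    #include <cstdio>
    #include <vector>

    enum slot_state { SLOT_STATE_IDLE, SLOT_STATE_PROCESSING_PROMPT };

    struct slot_t {
        int        id;
        slot_state state;
    };

    int main() {
        std::vector<slot_t> slots = { {0, SLOT_STATE_IDLE}, {1, SLOT_STATE_PROCESSING_PROMPT} };

        // buggy shape (parent commit): the branch closes too early, so the
        // prompt work below the brace runs for the idle slot as well
        for (const slot_t & slot : slots) {
            if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
                // prompt setup bookkeeping
            } // <- misplaced closing brace ends the branch here

            std::printf("buggy: batching prompt tokens for slot %d\n", slot.id);
        }

        // fixed shape (this commit): the whole prompt path stays inside the branch
        for (const slot_t & slot : slots) {
            if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
                // prompt setup bookkeeping

                std::printf("fixed: batching prompt tokens for slot %d\n", slot.id);
            }
        }

        return 0;
    }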
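For readers skimming the hunk, the block that moves back inside the branch is the server's prompt-batching path: trim the KV cache to the prefix shared with the new prompt, append the remaining prompt tokens to the batch up to `n_batch`, and request logits only for the last token once the whole prompt is in. A condensed sketch of that flow with plain containers standing in for `llama_batch` and the slot; the prefix-matching step is folded in here for illustration, whereas in the server `slot.n_past` is computed earlier:

    #include <cstdio>
    #include <vector>

    // Plain containers standing in for llama_batch and the server slot.
    struct batch_t {
        std::vector<int>  tokens;
        std::vector<bool> logits; // per-token flag: compute logits for this token?
    };

    int main() {
        const size_t n_batch = 8;

        std::vector<int> cache_tokens  = {1, 2, 3, 4};         // tokens already in the KV cache
        std::vector<int> prompt_tokens = {1, 2, 3, 9, 10, 11}; // incoming prompt

        // "keep only the common part": find the prefix shared with the cache
        size_t n_past = 0;
        while (n_past < cache_tokens.size() && n_past < prompt_tokens.size() &&
               cache_tokens[n_past] == prompt_tokens[n_past]) {
            n_past++;
        }

        // remove the non-common part from the cache,
        // mirroring slot.cache_tokens.resize(slot.n_past)
        cache_tokens.resize(n_past);
        const size_t n_reused = n_past;

        // add the remaining prompt tokens for processing, up to the batch limit
        batch_t batch;
        while (n_past < prompt_tokens.size() && batch.tokens.size() < n_batch) {
            batch.tokens.push_back(prompt_tokens[n_past]);
            batch.logits.push_back(false);
            cache_tokens.push_back(prompt_tokens[n_past]);
            n_past++;
        }

        // entire prompt has been processed: extract logits only for the last token
        if (n_past == prompt_tokens.size() && !batch.tokens.empty()) {
            batch.logits.back() = true;
        }

        std::printf("reused %zu cached tokens, batched %zu new ones\n", n_reused, batch.tokens.size());
        return 0;
    }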