move can_batch_with check

Xuan Son Nguyen 2024-12-27 20:22:49 +01:00
parent 9947b0776f
commit b9b2b6371a
2 changed files with 8 additions and 9 deletions

@@ -2588,6 +2588,13 @@ struct server_context {
         // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
+                // check if we can batch this slot with the previous one
+                if (!slot_batched) {
+                    slot_batched = &slot;
+                } else if (slot_batched && !slot_batched->can_batch_with(slot)) {
+                    continue;
+                }
+
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                     auto & prompt_tokens = slot.prompt_tokens;
@@ -2748,13 +2755,6 @@ struct server_context {
                     }
                 }

-                // check if we can batch this slot with the previous one
-                if (!slot_batched) {
-                    slot_batched = &slot;
-                } else if (slot_batched && !slot_batched->can_batch_with(slot)) {
-                    continue;
-                }
-
                 // keep only the common part
                 if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
                     // could not partially delete (likely using a non-Transformer model)
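
The point of the move, in a standalone sketch: before this commit the compatibility check sat just above the KV-cache trim, i.e. after prompt handling for the slot had already begun; hoisting it to the top of the per-slot loop rejects an incompatible slot before any of its state is touched. The slot_t struct, the collect_batchable() helper, and the LoRA-equality rule below are illustrative assumptions for this sketch, not the server's actual definitions:

#include <cstdio>
#include <string>
#include <vector>

// Illustrative stand-in for a server slot; the real server_slot has many more
// fields, and the real compatibility rule is not shown in this diff.
struct slot_t {
    int         id;
    bool        has_pending_prompt;
    std::string lora; // assumption: a per-slot setting that must agree batch-wide

    // Hypothetical can_batch_with(): slots may share a batch only if their
    // batch-wide settings match.
    bool can_batch_with(const slot_t & other) const {
        return lora == other.lora;
    }
};

// Mirrors the moved check: the first slot seen anchors the batch, and every
// later slot is tested for compatibility *before* any work is done on it,
// so an incompatible slot is deferred with none of its state modified.
static std::vector<int> collect_batchable(std::vector<slot_t> & slots) {
    std::vector<int> batched_ids;
    slot_t * slot_batched = nullptr;
    for (auto & slot : slots) {
        // check if we can batch this slot with the previous one
        if (!slot_batched) {
            slot_batched = &slot;
        } else if (!slot_batched->can_batch_with(slot)) {
            continue; // deferred to a later pass; no prompt/KV-cache work ran
        }
        // only compatible slots reach the per-slot work below
        if (slot.has_pending_prompt) {
            batched_ids.push_back(slot.id);
        }
    }
    return batched_ids;
}

int main() {
    std::vector<slot_t> slots = {
        {0, true, "lora-A"},
        {1, true, "lora-B"}, // incompatible with the anchor: skipped this pass
        {2, true, "lora-A"},
    };
    for (int id : collect_batchable(slots)) {
        std::printf("batched slot %d\n", id); // prints slots 0 and 2
    }
    return 0;
}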

@@ -68,10 +68,9 @@ def test_lora_per_request():
             "temperature": 0.0,
             "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
         })
-    ) for lora, re_test in lora_config]
+    ) for lora, _ in lora_config]
     results = parallel_function_calls(tasks)
-    print(results)
     assert all([res.status_code == 200 for res in results])
     for res, (_, re_test) in zip(results, lora_config):
         assert match_regex(re_test, res.body["content"])