diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 30ff3b149..055e2c5b8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1090,6 +1090,10 @@ struct server_slot {
         return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
     }
 
+    bool can_batch_with(server_slot & other_slot) {
+        return is_non_causal() == other_slot.is_non_causal();
+    }
+
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
@@ -2564,11 +2568,8 @@ struct server_context {
         int32_t n_batch = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // track if this is an embedding or non-embedding batch
-        // if we've added sampled tokens above, we are in non-embedding mode
-        // -1: none, 0: non-embedding, 1: embedding
-        // TODO: make enum
-        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
+        // track if given slot can be batched with slots already in the batch
+        server_slot * slot_batched = nullptr;
 
         // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
@@ -2733,11 +2734,10 @@ struct server_context {
                     }
                 }
 
-                // check that we are in the right batch_type, if not defer the slot
-                int slot_type = slot.is_non_causal();
-                if (batch_type == -1) {
-                    batch_type = slot_type;
-                } else if (batch_type != slot_type) {
+                // check if we can batch this slot with the previous one
+                if (!slot_batched) {
+                    slot_batched = &slot;
+                } else if (slot_batched && !slot_batched->can_batch_with(slot)) {
                     continue;
                 }
 
@@ -2809,7 +2809,7 @@ struct server_context {
         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
         // make sure we're in the right embedding mode
-        llama_set_embeddings(ctx, batch_type == 1);
+        llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal());
 
         // process the created batch of tokens
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
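
For reference, here is a minimal standalone sketch of the batching rule this patch introduces: a slot may join the current batch only if it agrees with the first batched slot on whether it is non-causal (embedding/rerank) or causal (completion), and the first batched slot then determines the embeddings mode for decoding. The `pending_slot` struct and `main` driver below are hypothetical stand-ins for illustration, not code from `server.cpp`.

```cpp
#include <cstdio>
#include <vector>

// Hypothetical stand-in for server_slot, reduced to the one property that
// decides batch compatibility in this patch.
struct pending_slot {
    bool non_causal; // true for embedding/rerank tasks, false for completions

    bool can_batch_with(const pending_slot & other) const {
        // slots can share a batch only if they are in the same decoding mode
        return non_causal == other.non_causal;
    }
};

int main() {
    // two completion slots, one embedding slot, one more completion slot
    std::vector<pending_slot> slots = { {false}, {true}, {false}, {false} };

    const pending_slot * slot_batched = nullptr; // first admitted slot fixes the mode
    for (const pending_slot & slot : slots) {
        if (!slot_batched) {
            slot_batched = &slot;
        } else if (!slot_batched->can_batch_with(slot)) {
            continue; // incompatible slot is deferred to a later batch
        }
        std::printf("added slot to batch (non_causal=%d)\n", (int) slot.non_causal);
    }

    // analogous to llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal())
    const bool embeddings_mode = slot_batched && slot_batched->non_causal;
    std::printf("embeddings mode: %d\n", (int) embeddings_mode);
    return 0;
}
```

The simplification over the old `batch_type` integer is that the decoding mode is carried by the first batched slot itself, so the embeddings flag can be derived from that slot when the batch is decoded instead of being tracked in a separate sentinel-coded variable.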