slot.can_batch_with

This commit is contained in:
Xuan Son Nguyen 2024-12-27 11:28:25 +01:00
parent d79d8f39b4
commit 2ba6efc561

View file

@ -1090,6 +1090,10 @@ struct server_slot {
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
} }
bool can_batch_with(server_slot & other_slot) {
return is_non_causal() == other_slot.is_non_causal();
}
bool has_budget(const common_params & global_params) { bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) { if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless return true; // limitless
@ -2564,11 +2568,8 @@ struct server_context {
int32_t n_batch = llama_n_batch(ctx); int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx);
// track if this is an embedding or non-embedding batch // track if given slot can be batched with slots already in the batch
// if we've added sampled tokens above, we are in non-embedding mode server_slot * slot_batched = nullptr;
// -1: none, 0: non-embedding, 1: embedding
// TODO: make enum
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
// next, batch any pending prompts without exceeding n_batch // next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) { if (params_base.cont_batching || batch.n_tokens == 0) {
@ -2733,11 +2734,10 @@ struct server_context {
} }
} }
// check that we are in the right batch_type, if not defer the slot // check if we can batch this slot with the previous one
int slot_type = slot.is_non_causal(); if (!slot_batched) {
if (batch_type == -1) { slot_batched = &slot;
batch_type = slot_type; } else if (slot_batched && !slot_batched->can_batch_with(slot)) {
} else if (batch_type != slot_type) {
continue; continue;
} }
@ -2809,7 +2809,7 @@ struct server_context {
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
// make sure we're in the right embedding mode // make sure we're in the right embedding mode
llama_set_embeddings(ctx, batch_type == 1); llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal());
// process the created batch of tokens // process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {