slot.can_batch_with

Xuan Son Nguyen 2024-12-27 11:28:25 +01:00
parent d79d8f39b4
commit 2ba6efc561

@@ -1090,6 +1090,10 @@ struct server_slot {
         return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
     }
 
+    bool can_batch_with(server_slot & other_slot) {
+        return is_non_causal() == other_slot.is_non_causal();
+    }
+
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
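The new predicate captures the rule that causal (text-generation) slots and non-causal (embedding, rerank) slots must not share a batch: two slots are compatible exactly when their is_non_causal() values agree. A minimal sketch of the rule, using a hypothetical simplified Slot type in place of server_slot:

```cpp
// Sketch only: Slot is a hypothetical stand-in for server_slot,
// reduced to the one field the compatibility check depends on.
#include <cassert>

struct Slot {
    bool non_causal; // true for embedding and rerank tasks

    bool can_batch_with(const Slot & other) const {
        // two slots may share a batch only if both are causal
        // or both are non-causal
        return non_causal == other.non_causal;
    }
};

int main() {
    Slot text_a{false}, text_b{false}, embd{true};
    assert( text_a.can_batch_with(text_b)); // causal + causal: ok
    assert(!text_a.can_batch_with(embd));   // causal + non-causal: defer
    return 0;
}
```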
@@ -2564,11 +2568,8 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // track if this is an embedding or non-embedding batch
-        // if we've added sampled tokens above, we are in non-embedding mode
-        // -1: none, 0: non-embedding, 1: embedding
-        // TODO: make enum
-        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
+        // track if given slot can be batched with slots already in the batch
+        server_slot * slot_batched = nullptr;
 
         // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
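Note the change of state representation: the tri-state batch_type integer (-1 for none, 0 for non-embedding, 1 for embedding, with a standing TODO to make it an enum) becomes a single server_slot pointer. nullptr plays the role of -1 ("no batch type chosen yet"), and otherwise the first admitted slot defines the type for the whole batch, while also keeping a handle to that slot for the llama_set_embeddings call below.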
@@ -2733,11 +2734,10 @@ struct server_context {
                 }
             }
 
-            // check that we are in the right batch_type, if not defer the slot
-            int slot_type = slot.is_non_causal();
-            if (batch_type == -1) {
-                batch_type = slot_type;
-            } else if (batch_type != slot_type) {
+            // check if we can batch this slot with the previous one
+            if (!slot_batched) {
+                slot_batched = &slot;
+            } else if (slot_batched && !slot_batched->can_batch_with(slot)) {
                 continue;
             }
 
@@ -2809,7 +2809,7 @@ struct server_context {
         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
         // make sure we're in the right embedding mode
-        llama_set_embeddings(ctx, batch_type == 1);
+        llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal());
 
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
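Taken together: the first slot admitted to a batch becomes slot_batched, later slots that cannot batch with it are skipped by the continue (and picked up in a later batch), and the decode step switches embedding mode based on that first slot alone. A standalone sketch of this flow, again with a hypothetical Slot type and a set_embeddings stub in place of llama_set_embeddings:

```cpp
// Sketch only: hypothetical Slot type; set_embeddings() stands in
// for llama_set_embeddings(), with no real context or decode.
#include <cstdio>
#include <utility>
#include <vector>

struct Slot {
    int  id;
    bool non_causal; // true for embedding and rerank tasks
};

static void set_embeddings(bool on) {
    std::printf("  embeddings mode: %s\n", on ? "on" : "off");
}

int main() {
    std::vector<Slot> pending = { {0, false}, {1, true}, {2, false}, {3, true} };

    while (!pending.empty()) {
        const Slot * slot_batched = nullptr; // first admitted slot fixes the batch type
        std::vector<Slot> deferred;

        std::printf("batch:");
        for (const Slot & slot : pending) {
            if (!slot_batched) {
                slot_batched = &slot;
            } else if (slot_batched->non_causal != slot.non_causal) {
                deferred.push_back(slot);    // incompatible: wait for a later batch
                continue;
            }
            std::printf(" slot %d", slot.id);
        }
        std::printf("\n");

        // one mode switch per decode, driven by the first slot in the batch
        set_embeddings(slot_batched && slot_batched->non_causal);

        pending = std::move(deferred);
    }
    return 0;
}
```

Running the sketch groups slots 0 and 2 into a causal batch (embeddings off), then slots 1 and 3 into a non-causal batch (embeddings on), mirroring how the deferred continue path revisits incompatible slots on a later update.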