slot.can_batch_with
This commit is contained in:
parent
d79d8f39b4
commit
2ba6efc561
1 changed files with 11 additions and 11 deletions
|
@ -1090,6 +1090,10 @@ struct server_slot {
|
||||||
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
|
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool can_batch_with(server_slot & other_slot) {
|
||||||
|
return is_non_causal() == other_slot.is_non_causal();
|
||||||
|
}
|
||||||
|
|
||||||
bool has_budget(const common_params & global_params) {
|
bool has_budget(const common_params & global_params) {
|
||||||
if (params.n_predict == -1 && global_params.n_predict == -1) {
|
if (params.n_predict == -1 && global_params.n_predict == -1) {
|
||||||
return true; // limitless
|
return true; // limitless
|
||||||
|
@ -2564,11 +2568,8 @@ struct server_context {
|
||||||
int32_t n_batch = llama_n_batch(ctx);
|
int32_t n_batch = llama_n_batch(ctx);
|
||||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||||
|
|
||||||
// track if this is an embedding or non-embedding batch
|
// track if given slot can be batched with slots already in the batch
|
||||||
// if we've added sampled tokens above, we are in non-embedding mode
|
server_slot * slot_batched = nullptr;
|
||||||
// -1: none, 0: non-embedding, 1: embedding
|
|
||||||
// TODO: make enum
|
|
||||||
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
|
||||||
|
|
||||||
// next, batch any pending prompts without exceeding n_batch
|
// next, batch any pending prompts without exceeding n_batch
|
||||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||||
|
@ -2733,11 +2734,10 @@ struct server_context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check that we are in the right batch_type, if not defer the slot
|
// check if we can batch this slot with the previous one
|
||||||
int slot_type = slot.is_non_causal();
|
if (!slot_batched) {
|
||||||
if (batch_type == -1) {
|
slot_batched = &slot;
|
||||||
batch_type = slot_type;
|
} else if (slot_batched && !slot_batched->can_batch_with(slot)) {
|
||||||
} else if (batch_type != slot_type) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2809,7 +2809,7 @@ struct server_context {
|
||||||
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
||||||
|
|
||||||
// make sure we're in the right embedding mode
|
// make sure we're in the right embedding mode
|
||||||
llama_set_embeddings(ctx, batch_type == 1);
|
llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal());
|
||||||
|
|
||||||
// process the created batch of tokens
|
// process the created batch of tokens
|
||||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue