llama : limit max batch size to n_batch

slaren 2024-03-12 20:44:40 +01:00
parent 937966d75e
commit 00a415d19b
2 changed files with 3 additions and 7 deletions

ggml-backend.c

@@ -1609,7 +1609,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
                     ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
                 } else {
-                    //printf("%s: sync %s\n", __func__, ggml_backend_name(split_backend));
                     ggml_backend_synchronize(split_backend);
                 }
                 ggml_backend_tensor_copy(input, input_cpy);
@@ -1617,12 +1616,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
                     ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
                 } else {
-                    //printf("%s: sync %s %s\n", __func__, ggml_backend_name(split_backend), ggml_backend_name(input_backend));
                     ggml_backend_synchronize(split_backend);
                     ggml_backend_synchronize(input_backend);
                 }
                 // split_backend waits on input_backend and then copies the data
                 ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
             }
         }
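
Note: the hunks above drop leftover debug printfs from the scheduler's synchronization fallback: when an event was recorded for this copy slot, the consumer backend waits on it device-side; when not, the backends are fully drained on the host before the copy. A minimal self-contained sketch of that pattern, with hypothetical stand-in types rather than the real ggml-backend API:

#include <stddef.h>
#include <stdio.h>

typedef struct { const char * name; } backend_t; /* hypothetical */
typedef struct { int recorded;      } event_t;   /* hypothetical */

static void backend_synchronize(backend_t * b) {
    printf("host blocks until %s is idle\n", b->name);
}

static void backend_event_wait(backend_t * b, event_t * e) {
    (void) e;
    printf("%s enqueues a wait, host keeps running\n", b->name);
}

static void sync_before_copy(backend_t * split_backend, backend_t * input_backend, event_t * ev) {
    if (ev != NULL) {
        backend_event_wait(split_backend, ev);   // cheap device-side wait
    } else {
        backend_synchronize(split_backend);      // fallback: full host-side sync
        backend_synchronize(input_backend);
    }
    // split_backend now waits on input_backend and then copies the data
}

int main(void) {
    backend_t cpu = { "CPU" }, gpu = { "GPU" };
    event_t   ev  = { 1 };
    sync_before_copy(&gpu, &cpu, &ev);  // event path
    sync_before_copy(&gpu, &cpu, NULL); // fallback path
    return 0;
}

The event path is what makes pipeline parallelism worthwhile: the host keeps queuing work while the devices order themselves.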

llama.cpp

@@ -8770,9 +8770,8 @@ static int llama_decode_internal(
     GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
-    GGML_ASSERT(n_tokens_all <= cparams.n_ctx);
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-    //const int64_t t_start_us = ggml_time_us();
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
     }
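
With the new assert, llama_decode_internal rejects batches larger than cparams.n_batch instead of anything up to n_ctx, so callers have to split long prompts into n_batch-sized chunks themselves. A hedged sketch of such a loop against the llama.h API of this era (llama_batch_get_one still took explicit pos/seq_id arguments); the helper name and error handling are illustrative, not part of llama.cpp:

#include "llama.h"

// illustrative helper: feed a prompt in chunks of at most n_batch tokens
// so the new GGML_ASSERT can never trip
static int eval_prompt(struct llama_context * ctx, llama_token * tokens, int n_tokens) {
    const int n_batch = (int) llama_n_batch(ctx);
    for (int i = 0; i < n_tokens; i += n_batch) {
        int n_eval = n_tokens - i;
        if (n_eval > n_batch) {
            n_eval = n_batch;
        }
        // positions start at i here; a real caller tracks n_past across calls
        struct llama_batch batch = llama_batch_get_one(tokens + i, n_eval, i, 0);
        if (llama_decode(ctx, batch) != 0) {
            return 1; // decode failed
        }
    }
    return 0;
}
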
@@ -12959,8 +12958,8 @@ struct llama_context * llama_new_context_with_model(
         // graph outputs buffer
         {
             // resized during inference, reserve maximum
-            ctx->logits_size = hparams.n_vocab*cparams.n_ctx;
-            ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_ctx : 0;
+            ctx->logits_size = hparams.n_vocab*cparams.n_batch;
+            ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_batch : 0;
             const size_t buf_output_size = (ctx->logits_size + ctx->embd_size)*sizeof(float);
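
Sizing the output buffer by n_batch instead of n_ctx is safe because a single llama_decode call now never produces logits for more than n_batch tokens, and it can save a lot of memory when n_ctx is much larger than n_batch. Back-of-the-envelope check with assumed, typical values (n_vocab = 32000, n_ctx = 4096, n_batch = 512):

#include <stdio.h>

int main(void) {
    const size_t n_vocab = 32000, n_ctx = 4096, n_batch = 512, f = sizeof(float);
    printf("logits before: %.1f MiB\n", n_vocab*n_ctx  *f/(1024.0*1024.0)); // 500.0 MiB
    printf("logits after:  %.1f MiB\n", n_vocab*n_batch*f/(1024.0*1024.0)); //  62.5 MiB
    return 0;
}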