llama : limit max batch size to n_batch
This commit is contained in:
parent
937966d75e
commit
00a415d19b
2 changed files with 3 additions and 7 deletions
|
@ -1609,7 +1609,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
|||
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
|
||||
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
|
||||
} else {
|
||||
//printf("%s: sync %s\n", __func__, ggml_backend_name(split_backend));
|
||||
ggml_backend_synchronize(split_backend);
|
||||
}
|
||||
ggml_backend_tensor_copy(input, input_cpy);
|
||||
|
@ -1617,12 +1616,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
|||
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
|
||||
ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
|
||||
} else {
|
||||
//printf("%s: sync %s %s\n", __func__, ggml_backend_name(split_backend), ggml_backend_name(input_backend));
|
||||
ggml_backend_synchronize(split_backend);
|
||||
ggml_backend_synchronize(input_backend);
|
||||
}
|
||||
|
||||
// split_backend waits on input_backend and then copies the data
|
||||
ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8770,9 +8770,8 @@ static int llama_decode_internal(
|
|||
|
||||
GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
|
||||
|
||||
GGML_ASSERT(n_tokens_all <= cparams.n_ctx);
|
||||
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
||||
|
||||
//const int64_t t_start_us = ggml_time_us();
|
||||
if (lctx.t_compute_start_us == 0) {
|
||||
lctx.t_compute_start_us = ggml_time_us();
|
||||
}
|
||||
|
@ -12959,8 +12958,8 @@ struct llama_context * llama_new_context_with_model(
|
|||
// graph outputs buffer
|
||||
{
|
||||
// resized during inference, reserve maximum
|
||||
ctx->logits_size = hparams.n_vocab*cparams.n_ctx;
|
||||
ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_ctx : 0;
|
||||
ctx->logits_size = hparams.n_vocab*cparams.n_batch;
|
||||
ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_batch : 0;
|
||||
|
||||
const size_t buf_output_size = (ctx->logits_size + ctx->embd_size)*sizeof(float);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue