From 00a415d19b2b8b165bacaf1554043ec103b6086c Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 12 Mar 2024 20:44:40 +0100 Subject: [PATCH] llama : limit max batch size to n_batch --- ggml-backend.c | 3 --- llama.cpp | 7 +++---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/ggml-backend.c b/ggml-backend.c index 625b5c8b5..b9ada70f7 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -1609,7 +1609,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); } else { - //printf("%s: sync %s\n", __func__, ggml_backend_name(split_backend)); ggml_backend_synchronize(split_backend); } ggml_backend_tensor_copy(input, input_cpy); @@ -1617,12 +1616,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); } else { - //printf("%s: sync %s %s\n", __func__, ggml_backend_name(split_backend), ggml_backend_name(input_backend)); ggml_backend_synchronize(split_backend); ggml_backend_synchronize(input_backend); } - // split_backend waits on input_backend and then copies the data ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); } } diff --git a/llama.cpp b/llama.cpp index 2e018dad1..468d382e8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8770,9 +8770,8 @@ static int llama_decode_internal( GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT - GGML_ASSERT(n_tokens_all <= cparams.n_ctx); + GGML_ASSERT(n_tokens_all <= cparams.n_batch); - //const int64_t t_start_us = ggml_time_us(); if (lctx.t_compute_start_us == 0) { lctx.t_compute_start_us = ggml_time_us(); } @@ -12959,8 +12958,8 @@ struct llama_context * llama_new_context_with_model( // graph outputs buffer { // resized during inference, reserve maximum - ctx->logits_size = hparams.n_vocab*cparams.n_ctx; - ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_ctx : 0; + ctx->logits_size = hparams.n_vocab*cparams.n_batch; + ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_batch : 0; const size_t buf_output_size = (ctx->logits_size + ctx->embd_size)*sizeof(float);