diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 19aff18ae..dff6c68ec 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -105,6 +105,9 @@ int main(int argc, char ** argv) { ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + // ensure enough sequences are available + ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end()); + llama_context * ctx = llama_new_context_with_model(model, ctx_params); if (ctx == NULL) { @@ -174,10 +177,10 @@ int main(int argc, char ** argv) { llama_batch_clear(batch); - const int n_tokens = is_pp_shared ? pp : pl*pp; - - for (int i = 0; i < n_tokens; ++i) { - llama_batch_add(batch, 0, i, { 0 }, false); + for (int i = 0; i < pp; ++i) { + for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { + llama_batch_add(batch, 0, i, { j }, false); + } } batch.logits[batch.n_tokens - 1] = true; @@ -192,7 +195,7 @@ int main(int argc, char ** argv) { if (is_pp_shared) { for (int32_t i = 1; i < pl; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, 0, pp); + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } } diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 9be7eb56b..dde4d5a06 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -80,6 +80,7 @@ int main(int argc, char ** argv) { ctx_params.seed = 1234; ctx_params.n_ctx = n_kv_req; ctx_params.n_batch = std::max(n_len, n_parallel); + ctx_params.n_parallel = n_parallel; ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; @@ -132,7 +133,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them for (int32_t i = 1; i < n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens); + llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } if (n_parallel > 1) { diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9ec989389..52789ee63 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -809,7 +809,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const int n_batch = params.n_batch; const int max_tasks_per_batch = 32; - const int max_seq = 4*max_tasks_per_batch; + const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx)); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); @@ -1086,7 +1086,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { const int n_batch = params.n_batch; const int max_tasks_per_batch = 128; - const int max_seq = 2*max_tasks_per_batch; + const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_max_seq(ctx)); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); @@ -1438,7 +1438,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params const int n_batch = params.n_batch; const int max_tasks_per_batch = 32; - const int max_seq = 4*max_tasks_per_batch; + const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx)); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); @@ -1815,6 +1815,9 @@ int main(int argc, char ** argv) { llama_model * model; llama_context * ctx; + // ensure there's at least enough seq_ids for HellaSwag + params.n_parallel = std::max(4, params.n_parallel); + // load the model and apply lora adapter, if any std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == NULL) { diff --git a/llama.cpp b/llama.cpp index b3964810c..f437059b2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12844,6 +12844,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) { return ctx->cparams.n_batch; } +uint32_t llama_n_max_seq(const struct llama_context * ctx) { + return ctx->kv_self.size; +} + enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { return model->vocab.type; } diff --git a/llama.h b/llama.h index a4675d4c3..f0aca1fe5 100644 --- a/llama.h +++ b/llama.h @@ -377,6 +377,7 @@ extern "C" { LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_max_seq (const struct llama_context * ctx); LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);