mamba : adapt perplexity, batched, and batched-bench examples
* perplexity : limit the max number of sequences
  This adapts to what the loaded model can provide.
* llama : add llama_n_max_seq to get the upper limit for seq_ids
  Used by the perplexity example.
* batched : pass n_parallel to the model's context params
  This should have been there already, but it wasn't.
* batched-bench : reserve sequences to support Mamba
* batched-bench : fix tokens being put in wrong sequences
  Generation quality isn't what is measured there anyway, but at least using the correct sequences avoids non-consecutive token positions.
parent 79d636cc7e
commit 8f605cfe0d

5 changed files with 21 additions and 9 deletions
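Before the diffs, a minimal sketch (not part of the commit; the helper name is made up) of the pattern the batched and batched-bench changes rely on: the number of parallel sequences must be declared in the context params before the context is created, because a recurrent model like Mamba reserves one state slot per seq_id at that point. The llama_* calls used are the existing llama.h API also visible in the diffs below.

// Sketch only: reserve enough sequences up front.
// For recurrent models (Mamba), each distinct seq_id needs its own state
// slot, and that count is fixed when the context is created.
#include "llama.h"

static llama_context * make_ctx_with_sequences(llama_model * model, uint32_t n_ctx, uint32_t n_parallel) {
    llama_context_params ctx_params = llama_context_default_params();

    ctx_params.n_ctx      = n_ctx;
    ctx_params.n_parallel = n_parallel; // how many seq_ids will be used concurrently

    return llama_new_context_with_model(model, ctx_params);
}

This is essentially what batched.cpp now does with ctx_params.n_parallel = n_parallel, and what batched-bench does with the maximum entry of its -npl list.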
examples/batched-bench/batched-bench.cpp

@@ -105,6 +105,9 @@ int main(int argc, char ** argv) {
     ctx_params.n_threads = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
+    // ensure enough sequences are available
+    ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end());
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
@@ -174,10 +177,10 @@ int main(int argc, char ** argv) {
 
         llama_batch_clear(batch);
 
-        const int n_tokens = is_pp_shared ? pp : pl*pp;
-
-        for (int i = 0; i < n_tokens; ++i) {
-            llama_batch_add(batch, 0, i, { 0 }, false);
+        for (int i = 0; i < pp; ++i) {
+            for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+                llama_batch_add(batch, 0, i, { j }, false);
+            }
         }
         batch.logits[batch.n_tokens - 1] = true;
 
@@ -192,7 +195,7 @@ int main(int argc, char ** argv) {
 
         if (is_pp_shared) {
             for (int32_t i = 1; i < pl; ++i) {
-                llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
+                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
             }
         }
 
examples/batched/batched.cpp

@@ -80,6 +80,7 @@ int main(int argc, char ** argv) {
     ctx_params.seed = 1234;
     ctx_params.n_ctx = n_kv_req;
     ctx_params.n_batch = std::max(n_len, n_parallel);
+    ctx_params.n_parallel = n_parallel;
     ctx_params.n_threads = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
@@ -132,7 +133,7 @@ int main(int argc, char ** argv) {
     // assign the system KV cache to all parallel sequences
     // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
     for (int32_t i = 1; i < n_parallel; ++i) {
-        llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
+        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
     }
 
     if (n_parallel > 1) {
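A note on the two kv-cache copies above, which now pass -1 for both position bounds: llama_kv_cache_seq_cp treats a negative bound as open (p0 < 0 means from the start, p1 < 0 means to the end), so the call copies the whole source sequence rather than a position interval. For Mamba the per-sequence state is not addressable by token position, so only the whole-sequence copy is meaningful. A small sketch with an invented helper name:

// Sketch only: share sequence 0's prompt state with all other sequences.
// Passing -1, -1 copies the entire source sequence, independent of token
// positions, which is what a Mamba-style recurrent state requires.
#include "llama.h"

static void copy_prompt_to_all_sequences(llama_context * ctx, int32_t n_parallel) {
    for (int32_t i = 1; i < n_parallel; ++i) {
        llama_kv_cache_seq_cp(ctx, /*seq_id_src=*/0, /*seq_id_dst=*/i, /*p0=*/-1, /*p1=*/-1);
    }
}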
examples/perplexity/perplexity.cpp

@@ -809,7 +809,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const int n_batch = params.n_batch;
 
     const int max_tasks_per_batch = 32;
-    const int max_seq = 4*max_tasks_per_batch;
+    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
@@ -1086,7 +1086,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
     const int n_batch = params.n_batch;
 
     const int max_tasks_per_batch = 128;
-    const int max_seq = 2*max_tasks_per_batch;
+    const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
@@ -1438,7 +1438,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params) {
     const int n_batch = params.n_batch;
 
     const int max_tasks_per_batch = 32;
-    const int max_seq = 4*max_tasks_per_batch;
+    const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
@@ -1815,6 +1815,9 @@ int main(int argc, char ** argv) {
     llama_model * model;
     llama_context * ctx;
 
+    // ensure there's at least enough seq_ids for HellaSwag
+    params.n_parallel = std::max(4, params.n_parallel);
+
     // load the model and apply lora adapter, if any
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
     if (model == NULL) {
llama.cpp

@@ -12844,6 +12844,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
     return ctx->cparams.n_batch;
 }
 
+uint32_t llama_n_max_seq(const struct llama_context * ctx) {
+    return ctx->kv_self.size;
+}
+
 enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }
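For completeness, a sketch of how a caller can use the new getter (the helper name is invented; the pattern mirrors the perplexity hunks above): query the seq_id limit and clamp before allocating a batch.

// Sketch only: clamp a desired sequence count to what the context provides.
// llama_n_max_seq returns ctx->kv_self.size; for a recurrent model (Mamba)
// that corresponds to the number of per-sequence state slots.
#include "llama.h"

#include <algorithm>

static llama_batch make_batch_for_sequences(llama_context * ctx, int wanted_seq, int tokens_per_seq) {
    const int max_seq = std::min(wanted_seq, (int) llama_n_max_seq(ctx));

    // n_tokens slots, no embedding input, up to max_seq seq_ids per token slot
    return llama_batch_init(max_seq * tokens_per_seq, 0, max_seq);
}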
llama.h

@@ -377,6 +377,7 @@ extern "C" {
 
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_max_seq  (const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
     LLAMA_API enum llama_rope_type  llama_rope_type (const struct llama_model * model);