llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch (#9745)
* refactor llama_batch_get_one * adapt all examples * fix simple.cpp * fix llama_bench * fix * fix context shifting * free batch before return * use common_batch_add, reuse llama_batch in loop * null terminated seq_id list * fix save-load-state example * fix perplexity * correct token pos in llama_batch_allocr
This commit is contained in:
parent
afd9909a64
commit
cda0e4b648
22 changed files with 205 additions and 118 deletions
|
@ -1428,7 +1428,7 @@ struct sql_printer : public printer {
|
|||
}
|
||||
};
|
||||
|
||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
|
||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
|
||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
@ -1444,14 +1444,14 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
|||
for (int i = 1; i < n_tokens; i++) {
|
||||
tokens[i] = std::rand() % n_vocab;
|
||||
}
|
||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
|
||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
|
||||
n_processed += n_tokens;
|
||||
}
|
||||
|
||||
llama_synchronize(ctx);
|
||||
}
|
||||
|
||||
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
||||
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
@ -1460,7 +1460,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
|
|||
llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
|
||||
|
||||
for (int i = 0; i < n_gen; i++) {
|
||||
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
|
||||
llama_decode(ctx, llama_batch_get_one(&token, 1));
|
||||
llama_synchronize(ctx);
|
||||
token = std::rand() % n_vocab;
|
||||
}
|
||||
|
@ -1596,13 +1596,13 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count);
|
||||
}
|
||||
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
|
||||
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
|
||||
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
}
|
||||
if (t.n_gen > 0) {
|
||||
if (params.progress) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count);
|
||||
}
|
||||
test_gen(ctx, 1, 0, t.n_threads);
|
||||
test_gen(ctx, 1, t.n_threads);
|
||||
}
|
||||
|
||||
for (int i = 0; i < params.reps; i++) {
|
||||
|
@ -1614,13 +1614,13 @@ int main(int argc, char ** argv) {
|
|||
if (params.progress) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps);
|
||||
}
|
||||
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
|
||||
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||
}
|
||||
if (t.n_gen > 0) {
|
||||
if (params.progress) {
|
||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps);
|
||||
}
|
||||
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
|
||||
test_gen(ctx, t.n_gen, t.n_threads);
|
||||
}
|
||||
|
||||
uint64_t t_ns = get_time_ns() - t_start;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue