llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch (#9745)

* refactor llama_batch_get_one

* adapt all examples

* fix simple.cpp

* fix llama_bench

* fix

* fix context shifting

* free batch before return

* use common_batch_add, reuse llama_batch in loop

* null terminated seq_id list

* fix save-load-state example

* fix perplexity

* correct token pos in llama_batch_allocr
This commit is contained in:
Xuan Son Nguyen 2024-10-18 23:18:01 +02:00 committed by GitHub
parent afd9909a64
commit cda0e4b648
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 205 additions and 118 deletions

View file

@ -1428,7 +1428,7 @@ struct sql_printer : public printer {
}
};
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx);
@ -1444,14 +1444,14 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab;
}
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
n_processed += n_tokens;
}
llama_synchronize(ctx);
}
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx);
@ -1460,7 +1460,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
llama_decode(ctx, llama_batch_get_one(&token, 1));
llama_synchronize(ctx);
token = std::rand() % n_vocab;
}
@ -1596,13 +1596,13 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count);
}
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count);
}
test_gen(ctx, 1, 0, t.n_threads);
test_gen(ctx, 1, t.n_threads);
}
for (int i = 0; i < params.reps; i++) {
@ -1614,13 +1614,13 @@ int main(int argc, char ** argv) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps);
}
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps);
}
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
test_gen(ctx, t.n_gen, t.n_threads);
}
uint64_t t_ns = get_time_ns() - t_start;