fix llama_bench
This commit is contained in:
parent
92769503dc
commit
59fd6b6119
2 changed files with 7 additions and 7 deletions
|
@ -1428,7 +1428,7 @@ struct sql_printer : public printer {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
|
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
|
||||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||||
|
|
||||||
const llama_model * model = llama_get_model(ctx);
|
const llama_model * model = llama_get_model(ctx);
|
||||||
|
@ -1451,7 +1451,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
|
||||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||||
|
|
||||||
const llama_model * model = llama_get_model(ctx);
|
const llama_model * model = llama_get_model(ctx);
|
||||||
|
@ -1596,13 +1596,13 @@ int main(int argc, char ** argv) {
|
||||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count);
|
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count);
|
||||||
}
|
}
|
||||||
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
|
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
|
||||||
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
|
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||||
}
|
}
|
||||||
if (t.n_gen > 0) {
|
if (t.n_gen > 0) {
|
||||||
if (params.progress) {
|
if (params.progress) {
|
||||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count);
|
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count);
|
||||||
}
|
}
|
||||||
test_gen(ctx, 1, 0, t.n_threads);
|
test_gen(ctx, 1, t.n_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < params.reps; i++) {
|
for (int i = 0; i < params.reps; i++) {
|
||||||
|
@ -1614,13 +1614,13 @@ int main(int argc, char ** argv) {
|
||||||
if (params.progress) {
|
if (params.progress) {
|
||||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps);
|
fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps);
|
||||||
}
|
}
|
||||||
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
|
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
|
||||||
}
|
}
|
||||||
if (t.n_gen > 0) {
|
if (t.n_gen > 0) {
|
||||||
if (params.progress) {
|
if (params.progress) {
|
||||||
fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps);
|
fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps);
|
||||||
}
|
}
|
||||||
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
|
test_gen(ctx, t.n_gen, t.n_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t t_ns = get_time_ns() - t_start;
|
uint64_t t_ns = get_time_ns() - t_start;
|
||||||
|
|
|
@ -21093,7 +21093,7 @@ struct llama_batch_allocr {
|
||||||
struct llama_batch batch = in_batch;
|
struct llama_batch batch = in_batch;
|
||||||
if (!batch.pos) {
|
if (!batch.pos) {
|
||||||
// determine the last position in KV cache
|
// determine the last position in KV cache
|
||||||
llama_pos last_pos;
|
llama_pos last_pos = 0;
|
||||||
for (const auto & cell : ctx->kv_self.cells) {
|
for (const auto & cell : ctx->kv_self.cells) {
|
||||||
if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) {
|
if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) {
|
||||||
last_pos = std::max(last_pos, cell.pos);
|
last_pos = std::max(last_pos, cell.pos);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue