llama : add new llama_decode() API that works with llama_batch

This commit is contained in:
Georgi Gerganov 2023-09-18 14:23:52 +03:00
parent 58bb5110ca
commit 9f42e75489
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
13 changed files with 146 additions and 75 deletions

View file

@ -891,7 +891,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
int n_processed = 0;
while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch);
llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads);
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads);
n_processed += n_tokens;
}
}
@ -899,7 +899,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_token token = llama_token_bos(ctx);
for (int i = 0; i < n_gen; i++) {
llama_eval(ctx, &token, 1, n_past + i, n_threads);
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads);
}
}