llama : add new llama_decode() API that works with llama_batch
This commit is contained in:
parent
58bb5110ca
commit
9f42e75489
13 changed files with 146 additions and 75 deletions
|
@@ -199,7 +199,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
|||
const int batch_size = std::min(end - batch_start, n_batch);
|
||||
|
||||
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
||||
- if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
|
||||
+ if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
|
||||
//fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return {tokens, -1, logit_history, prob_history};
|
||||
}
|
||||
|
@@ -331,7 +331,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
tokens[batch_start] = llama_token_bos(ctx);
|
||||
}
|
||||
|
||||
- if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
|
||||
+ if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return {tokens, -1, logit_history, prob_history};
|
||||
}
|
||||
|
@@ -409,7 +409,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
|
|||
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
|
||||
size_t n_tokens = tokens.size() - i_chunk * n_batch;
|
||||
n_tokens = std::min(n_tokens, size_t(n_batch));
|
||||
- if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
|
||||
+ if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return {};
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue