llama : add new llama_decode() API that works with llama_batch

This commit is contained in:
Georgi Gerganov 2023-09-18 14:23:52 +03:00
parent 58bb5110ca
commit 9f42e75489
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
13 changed files with 146 additions and 75 deletions

View file

@ -79,7 +79,8 @@ bool eval_float(void * model, float * input, int N){
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false };
if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
@ -100,7 +101,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}