llama : add new llama_decode() API that works with llama_batch

Georgi Gerganov 2023-09-18 14:23:52 +03:00
parent 58bb5110ca
commit 9f42e75489
13 changed files with 146 additions and 75 deletions


@@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
     while (!embd_inp.empty()) {
         int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
-        if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
+        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return 1;
         }
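
For reference, a minimal sketch of the prompt-evaluation loop as it reads after this hunk, assuming the signatures visible in the diff (llama_decode() still takes an n_threads argument here, and llama_batch_get_one() takes tokens, n_tokens, pos_0, seq_id; both were revised in later versions of the library). The eval_prompt() helper name is illustrative and not part of the commit:

#include <algorithm>
#include <cstdio>
#include <vector>

#include "llama.h"

// Sketch only: feed a tokenized prompt to the model in chunks of at most
// n_batch tokens, using the llama_batch-based llama_decode() from this commit.
static int eval_prompt(llama_context * ctx, std::vector<llama_token> & embd_inp,
                       int & n_past, int n_batch, int n_threads) {
    while (!embd_inp.empty()) {
        // process at most n_batch tokens per call
        const int n_tokens = std::min(n_batch, (int) embd_inp.size());

        // llama_batch_get_one() wraps a raw token array in a llama_batch,
        // placing the tokens at positions starting at n_past in sequence 0
        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return 1;
        }

        // advance the position counter and drop the consumed tokens
        n_past += n_tokens;
        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
    }
    return 0;
}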