llama : unified KV cache + batch inference API

Georgi Gerganov 2023-09-18 10:08:22 +03:00
parent fad56936d4
commit d29e76937c
10 changed files with 315 additions and 236 deletions


@@ -73,10 +73,12 @@ int main(int argc, char ** argv) {
     const int n_gen = std::min(32, max_context_size);
 
-    while (llama_get_kv_cache_token_count(ctx) < n_gen) {
+    int n_cur = 0;
+
+    while (n_cur < n_gen) {
         // evaluate the transformer
 
-        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
+        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return 1;
         }
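
For context, here is a minimal sketch of how the example's full generation loop might look after this change: the caller now tracks the decoding position in n_cur instead of deriving it from llama_get_kv_cache_token_count(ctx). This assumes the llama.h API of this period (llama_eval, llama_get_logits, llama_n_vocab taking a context), and the greedy argmax stands in for whatever sampling the real example uses:

    // sketch: a caller-maintained position replaces the KV-cache token count
    int n_cur = 0;

    while (n_cur < n_gen) {
        // evaluate the pending tokens starting at position n_cur
        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return 1;
        }

        // the position advances by however many tokens were just evaluated
        n_cur += int(tokens_list.size());

        // greedily pick the most likely next token (illustrative only)
        const float * logits  = llama_get_logits(ctx);
        const int     n_vocab = llama_n_vocab(ctx);

        llama_token new_token = 0;
        for (llama_token id = 1; id < n_vocab; ++id) {
            if (logits[id] > logits[new_token]) {
                new_token = id;
            }
        }

        // feed the sampled token back as the sole input of the next iteration
        tokens_list.clear();
        tokens_list.push_back(new_token);
    }

The substance of the change is that the position is now explicit in each llama_eval call rather than implied by global KV-cache state, which is what opens the door to the batched, multi-sequence inference named in the commit title.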