parallel : process system prompt once + configurable paramters + llama API

This commit is contained in:
Georgi Gerganov 2023-09-19 17:00:42 +03:00
parent 82e20e9ba0
commit 4b5f3cd6bf
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
9 changed files with 187 additions and 93 deletions

View file

@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
LOG("out of drafted tokens\n");
}
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
++n_past_dft;
@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
}
// evaluate the drafted token on the draft model
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
++n_past_cur;
@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
}
// evaluate the target model on the drafted tokens
llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
++n_past_tgt;