llama : unified KV cache + batch inference API

This commit is contained in:
Georgi Gerganov 2023-09-18 10:08:22 +03:00
parent fad56936d4
commit d29e76937c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
10 changed files with 315 additions and 236 deletions

View file

@ -158,7 +158,8 @@ int main(int argc, char ** argv)
}
std::cout << std::flush;
int n_past = llama_get_kv_cache_token_count(ctx);
int n_past = 0;
if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
{
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );

View file

@ -198,15 +198,6 @@ int main(int argc, char ** argv) {
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
// export the cgraph and exit
if (params.export_cgraph) {
llama_eval_export(ctx, "llama.ggml");
llama_free(ctx);
llama_free_model(model);
return 0;
}
std::string path_session = params.path_prompt_cache;
std::vector<llama_token> session_tokens;

View file

@ -400,7 +400,7 @@ results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
return {tokens, ppl, logit_history, prob_history};
}
std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int> & tokens, int n_past, int n_batch,
int n_vocab, int n_thread) {
std::vector<float> result;
result.reserve(tokens.size() * n_vocab);

View file

@ -73,10 +73,12 @@ int main(int argc, char ** argv) {
const int n_gen = std::min(32, max_context_size);
while (llama_get_kv_cache_token_count(ctx) < n_gen) {
int n_cur = 0;
while (n_cur < n_gen) {
// evaluate the transformer
if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}