From 36c86d794dd3981160c9ec786ac6825231435137 Mon Sep 17 00:00:00 2001
From: Randall Fitzgerald
Date: Sat, 27 May 2023 16:43:08 -0700
Subject: [PATCH] Automate Context resetting and minor fixes

Fixed top_k still not being set.
Removed an unnecessary loop.
---
 examples/server/server.cpp | 56 ++++++++++++++++++--------------------
 1 file changed, 27 insertions(+), 29 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 4dddf50d3..d5a1473f1 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -29,36 +29,17 @@ struct llama_server_context
     std::vector<std::vector<llama_token>> no_show_words;
     std::vector<llama_token> tokens_predicted;
 
+    std::vector<llama_token> last_prompt_tokens;
+
     llama_context *ctx;
     gpt_params params;
 
-    bool reload_ctx = false;
-
     void rewind() {
         as_loop = false;
         params.antiprompt.clear();
        no_show_words.clear();
         num_tokens_predicted = 0;
         generated_text = "";
-
-        if(reload_ctx)
-        {
-            if(!processed_tokens.empty())
-            {
-                processed_tokens.erase(processed_tokens.begin() + 1, processed_tokens.end());
-            }
-
-            if(!embd_inp.empty())
-            {
-                embd_inp.erase(embd_inp.begin() + 1, embd_inp.end());
-            }
-
-            n_remain = 0;
-            n_past = 0;
-            n_consumed = 0;
-
-            reload_ctx = false;
-        }
     }
 
     bool loadModel(gpt_params params_)
@@ -82,6 +63,28 @@ struct llama_server_context
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
         // compare the evaluated prompt with the new prompt
         int new_prompt_len = 0;
+        if (last_prompt_tokens == prompt_tokens)
+        {
+            //fprintf(stdout, "Context matched.\n");
+            processed_tokens = last_prompt_tokens;
+            embd_inp = last_prompt_tokens;
+            n_past = processed_tokens.size();
+            n_consumed = last_prompt_tokens.size() - 2;
+            last_prompt_tokens = prompt_tokens;
+            has_next_token = true;
+            return true;
+        }
+        else
+        {
+            if (!processed_tokens.empty() && !embd_inp.empty())
+            {
+                //fprintf(stdout, "Resetting context.\n");
+                processed_tokens.erase(processed_tokens.begin() + 1, processed_tokens.end());
+                embd_inp.erase(embd_inp.begin() + 1, embd_inp.end());
+                n_consumed = 0;
+                n_past = 0;
+            }
+        }
         for (size_t i = 0; i < prompt_tokens.size(); i++)
         {
             if (i < processed_tokens.size() && processed_tokens[i] == prompt_tokens[i])
@@ -159,6 +162,7 @@ struct llama_server_context
         const float temp = params.temp;
         // const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
         const float top_p = params.top_p;
+        const float top_k = params.top_k;
         const float tfs_z = params.tfs_z;
         const float typical_p = params.typical_p;
         const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
@@ -229,6 +233,7 @@ struct llama_server_context
                 llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
                 llama_sample_typical(ctx, &candidates_p, typical_p, 1);
                 llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                llama_sample_top_k(ctx, &candidates_p, top_k, 1);
                 llama_sample_temperature(ctx, &candidates_p, temp);
                 id = llama_sample_token(ctx, &candidates_p);
             }
@@ -253,10 +258,7 @@ struct llama_server_context
             // add it to the context
             embd.push_back(id);
 
-            for (auto id : embd)
-            {
-                result = id;
-            }
+            result = id;
             // decrement remaining sampling budget
             --n_remain;
         }
@@ -619,10 +621,6 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
     {
         llama.params.interactive = body["interactive"].get<bool>();
     }
-    if (!body["reload_ctx"].is_null())
-    {
-        llama.reload_ctx = body["reload_ctx"].get<bool>();
-    }
     if (!body["prompt"].is_null())
     {
         llama.params.prompt = body["prompt"].get<std::string>();