Automate Context resetting and minor fixes

Fixed top_k still not being set.
Removed an unnecessary loop.
Randall Fitzgerald 2023-05-27 16:43:08 -07:00 committed by GitHub
parent 66ed19d01f
commit 36c86d794d


@@ -29,36 +29,17 @@ struct llama_server_context
std::vector<std::vector<llama_token>> no_show_words;
std::vector<llama_token> tokens_predicted;
std::vector<llama_token> last_prompt_tokens;
llama_context *ctx;
gpt_params params;
bool reload_ctx = false;
void rewind() {
as_loop = false;
params.antiprompt.clear();
no_show_words.clear();
num_tokens_predicted = 0;
generated_text = "";
if(reload_ctx)
{
if(!processed_tokens.empty())
{
processed_tokens.erase(processed_tokens.begin() + 1, processed_tokens.end());
}
if(!embd_inp.empty())
{
embd_inp.erase(embd_inp.begin() + 1, embd_inp.end());
}
n_remain = 0;
n_past = 0;
n_consumed = 0;
reload_ctx = false;
}
}
bool loadModel(gpt_params params_)
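With the reload_ctx branch gone, rewind() is presumably left with only the unconditional resets. A minimal sketch of the resulting method, assuming the member names shown in the hunk above and that nothing else in the body changes:

// Sketch only: rewind() without the removed reload_ctx branch. Member names
// (as_loop, params, no_show_words, num_tokens_predicted, generated_text) are
// taken from the hunk above.
void rewind() {
    as_loop = false;
    params.antiprompt.clear();
    no_show_words.clear();
    num_tokens_predicted = 0;
    generated_text = "";
}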
@@ -82,6 +63,28 @@ struct llama_server_context
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
// compare the evaluated prompt with the new prompt
int new_prompt_len = 0;
if (last_prompt_tokens == prompt_tokens)
{
//fprintf(stdout, "Context matched.\n");
processed_tokens = last_prompt_tokens;
embd_inp = last_prompt_tokens;
n_past = processed_tokens.size();
n_consumed = last_prompt_tokens.size() - 2;
last_prompt_tokens = prompt_tokens;
has_next_token = true;
return true;
}
else
{
if (!processed_tokens.empty() && !embd_inp.empty())
{
//fprintf(stdout, "Resetting context.\n");
processed_tokens.erase(processed_tokens.begin() + 1, processed_tokens.end());
embd_inp.erase(embd_inp.begin() + 1, embd_inp.end());
n_consumed = 0;
n_past = 0;
}
}
for (size_t i = 0; i < prompt_tokens.size(); i++) {
if (i < processed_tokens.size() &&
processed_tokens[i] == prompt_tokens[i])
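The for loop that begins above walks the new prompt against the previously processed tokens to find the longest shared prefix, so only the unmatched tail has to be evaluated again and no client-supplied reload flag is needed. A standalone sketch of that comparison, assuming llama_token is the plain integer id it is typedef'd to in llama.h:

// Sketch only: count how many leading tokens of the new prompt are already
// covered by the previously processed tokens; everything past this point
// still has to be fed through the model.
#include <cstddef>
#include <vector>

using llama_token = int; // stand-in for the llama.h typedef

static size_t common_prefix_len(const std::vector<llama_token> & processed,
                                const std::vector<llama_token> & prompt) {
    size_t n = 0;
    while (n < processed.size() && n < prompt.size() && processed[n] == prompt[n]) {
        ++n;
    }
    return n;
}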
@@ -159,6 +162,7 @@ struct llama_server_context
const float temp = params.temp;
// const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const float top_p = params.top_p;
const int32_t top_k = params.top_k;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
@@ -229,6 +233,7 @@ struct llama_server_context
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
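For readability, the same temperature-based sampling chain, including the newly applied top_k step, can be read as one helper. This is a sketch only: candidates_p is assumed to be a llama_token_data_array already filled from the current logits, and the repetition-penalty and mirostat paths are omitted:

// Sketch only: the non-mirostat sampling chain from the hunk above, wrapped in
// a helper. Parameter values would come from gpt_params as shown earlier.
#include "llama.h"

static llama_token sample_next_token(llama_context * ctx,
                                     llama_token_data_array & candidates_p,
                                     int32_t top_k, float top_p, float tfs_z,
                                     float typical_p, float temp) {
    llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
    llama_sample_typical(ctx, &candidates_p, typical_p, 1);
    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
    llama_sample_top_k(ctx, &candidates_p, top_k, 1); // the newly applied top_k filter
    llama_sample_temperature(ctx, &candidates_p, temp);
    return llama_sample_token(ctx, &candidates_p);
}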
@@ -253,10 +258,7 @@ struct llama_server_context
// add it to the context
embd.push_back(id);
for (auto id : embd)
{
result = id;
}
result = id;
// decrement remaining sampling budget
--n_remain;
}
@@ -619,10 +621,6 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
{
llama.params.interactive = body["interactive"].get<bool>();
}
if (!body["reload_ctx"].is_null())
{
llama.reload_ctx = body["reload_ctx"].get<int>();
}
if (!body["prompt"].is_null())
{
llama.params.prompt = body["prompt"].get<std::string>();