Automate Context resetting and minor fixes
Fixed top_k still not being set. Removed an unnecessary loop.
parent 66ed19d01f
commit 36c86d794d

1 changed file with 27 additions and 29 deletions
@@ -29,36 +29,17 @@ struct llama_server_context
     std::vector<std::vector<llama_token>> no_show_words;
     std::vector<llama_token> tokens_predicted;
 
+    std::vector<llama_token> last_prompt_tokens;
+
     llama_context *ctx;
     gpt_params params;
 
-    bool reload_ctx = false;
-
     void rewind() {
         as_loop = false;
         params.antiprompt.clear();
         no_show_words.clear();
         num_tokens_predicted = 0;
         generated_text = "";
-
-        if(reload_ctx)
-        {
-            if(!processed_tokens.empty())
-            {
-                processed_tokens.erase(processed_tokens.begin() + 1, processed_tokens.end());
-            }
-
-            if(!embd_inp.empty())
-            {
-                embd_inp.erase(embd_inp.begin() + 1, embd_inp.end());
-            }
-
-            n_remain = 0;
-            n_past = 0;
-            n_consumed = 0;
-
-            reload_ctx = false;
-        }
     }
 
     bool loadModel(gpt_params params_)
@@ -82,6 +63,28 @@ struct llama_server_context
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
         // compare the evaluated prompt with the new prompt
         int new_prompt_len = 0;
+        if (last_prompt_tokens == prompt_tokens)
+        {
+            //fprintf(stdout, "Context matched.\n");
+            processed_tokens = last_prompt_tokens;
+            embd_inp = last_prompt_tokens;
+            n_past = processed_tokens.size();
+            n_consumed = last_prompt_tokens.size() - 2;
+            last_prompt_tokens = prompt_tokens;
+            has_next_token = true;
+            return true;
+        }
+        else
+        {
+            if (!processed_tokens.empty() && !embd_inp.empty())
+            {
+                //fprintf(stdout, "Resetting context.\n");
+                processed_tokens.erase(processed_tokens.begin() + 1, processed_tokens.end());
+                embd_inp.erase(embd_inp.begin() + 1, embd_inp.end());
+                n_consumed = 0;
+                n_past = 0;
+            }
+        }
         for (size_t i = 0; i < prompt_tokens.size(); i++) {
             if (i < processed_tokens.size() &&
                 processed_tokens[i] == prompt_tokens[i])
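The block added above is what makes the reset automatic: instead of acting on a client-supplied reload_ctx flag, the server keeps the previously tokenized prompt in last_prompt_tokens and compares it against the incoming one. Below is a simplified, standalone sketch of that decision using the member names visible in this diff; the wrapper struct, method name, and types are illustrative rather than the actual server code.

// Simplified sketch of the automatic reuse-or-reset decision (names/types illustrative).
#include <vector>

using llama_token = int; // stand-in for the llama.cpp token type

struct ctx_reuse_sketch {
    std::vector<llama_token> last_prompt_tokens; // prompt evaluated for the previous request
    std::vector<llama_token> processed_tokens;   // tokens already fed through the model
    std::vector<llama_token> embd_inp;           // pending input tokens
    size_t n_past = 0;
    size_t n_consumed = 0;

    // Returns true when the previously evaluated context can be reused verbatim.
    bool prepare(const std::vector<llama_token> & prompt_tokens) {
        if (prompt_tokens == last_prompt_tokens) {
            // Same prompt as last time: restore the bookkeeping and skip re-evaluation.
            processed_tokens = last_prompt_tokens;
            embd_inp         = last_prompt_tokens;
            n_past           = processed_tokens.size();
            n_consumed       = last_prompt_tokens.size() - 2;
            last_prompt_tokens = prompt_tokens; // kept from the patch; a no-op since both sides already compare equal
            return true;
        }
        // Different prompt: keep only the first token (the BOS added by llama_tokenize(..., true))
        // and start over; the shared-prefix loop that follows in the diff context then reconciles
        // the remaining tokens.
        if (!processed_tokens.empty() && !embd_inp.empty()) {
            processed_tokens.erase(processed_tokens.begin() + 1, processed_tokens.end());
            embd_inp.erase(embd_inp.begin() + 1, embd_inp.end());
            n_consumed = 0;
            n_past     = 0;
        }
        return false;
    }
};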
@@ -159,6 +162,7 @@ struct llama_server_context
         const float temp = params.temp;
         // const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
         const float top_p = params.top_p;
+        const float top_k = params.top_k;
         const float tfs_z = params.tfs_z;
         const float typical_p = params.typical_p;
         const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
@@ -229,6 +233,7 @@ struct llama_server_context
                     llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
                     llama_sample_typical(ctx, &candidates_p, typical_p, 1);
                     llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
                     llama_sample_temperature(ctx, &candidates_p, temp);
                     id = llama_sample_token(ctx, &candidates_p);
                 }
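The two hunks above are the top_k fix from the commit message: params.top_k could be configured but had no effect, because the sampling chain contained no llama_sample_top_k call. The sketch below condenses that chain into a standalone helper, assuming the llama.cpp C API of this period; the candidate-array construction is not part of the diff, so treat it as an illustration. (The patch stores top_k in a float, while llama_sample_top_k takes an integer k, so the value is converted at the call.)

// Sketch of the sampler chain after this commit (assumes the llama.cpp/common API of this era).
#include <vector>
#include "llama.h"
#include "common.h" // gpt_params

static llama_token sample_next_token_sketch(llama_context * ctx, const gpt_params & params) {
    // Build the candidate array from the raw logits of the last evaluated token.
    const int n_vocab = llama_n_vocab(ctx);
    float * logits    = llama_get_logits(ctx);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.push_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // Same order as the patched code; the llama_sample_top_k line is what this commit adds.
    llama_sample_tail_free  (ctx, &candidates_p, params.tfs_z,     1);
    llama_sample_typical    (ctx, &candidates_p, params.typical_p, 1);
    llama_sample_top_p      (ctx, &candidates_p, params.top_p,     1);
    llama_sample_top_k      (ctx, &candidates_p, params.top_k,     1);
    llama_sample_temperature(ctx, &candidates_p, params.temp);

    return llama_sample_token(ctx, &candidates_p);
}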
@@ -253,10 +258,7 @@ struct llama_server_context
 
             // add it to the context
             embd.push_back(id);
-            for (auto id : embd)
-            {
-                result = id;
-            }
+            result = id;
             // decrement remaining sampling budget
             --n_remain;
         }
@@ -619,10 +621,6 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
     {
         llama.params.interactive = body["interactive"].get<bool>();
     }
-    if (!body["reload_ctx"].is_null())
-    {
-        llama.reload_ctx = body["reload_ctx"].get<int>();
-    }
     if (!body["prompt"].is_null())
     {
         llama.params.prompt = body["prompt"].get<std::string>();
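The final hunk drops the reload_ctx field from request parsing, so clients no longer need to ask for a reset: the server decides on its own, as shown earlier. A minimal set of completion options built with the json type the handler already receives might look like this; the field values are illustrative, and only "prompt" and "interactive" are confirmed by this diff.

// Illustrative completion options; "reload_ctx" is no longer read by parse_options_completion.
json body = {
    {"prompt",      "Hello, my name is"},
    {"interactive", true},
    // {"reload_ctx", 1} // previously forced a context reset; now ignored
};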