Add all generation parameters to server.cpp and allow resetting context

sever.cpp left out a few generation parameters and also seems built to assume un-editable chatting with no regens or swipes. I added a simple "reload_ctx" flag that can be passed on generation that will cause the prompt to be reloaded.
This commit is contained in:
Randall Fitzgerald 2023-05-23 06:16:54 -07:00 committed by Henri Vasserman
parent 51e09944ce
commit f93fe36c5b
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986

View file

@ -13,6 +13,8 @@ struct llama_server_context
{
bool as_loop = false;
bool has_next_token = false;
std::string generated_text = "";
size_t num_tokens_predicted = 0;
@ -31,6 +33,8 @@ struct llama_server_context
llama_context *ctx;
gpt_params params;
bool reload_ctx = false;
void rewind() {
as_loop = false;
params.antiprompt.clear();
@ -61,6 +65,21 @@ struct llama_server_context
bool loadPrompt() {
params.prompt.insert(0, 1, ' '); // always add a first space
if(processed_tokens.size() != 0)
{
processed_tokens.erase(processed_tokens.begin() + 1, processed_tokens.end());
}
if(embd_inp.size() != 0)
{
embd_inp.erase(embd_inp.begin() + 1, embd_inp.end());
}
n_remain = 0;
n_past = 0;
n_consumed = 0;
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
// compare the evaluated prompt with the new prompt
for (n_past = 0; n_past < prompt_tokens.size() - 1 && n_past < processed_tokens.size(); n_past++) {
@ -98,6 +117,7 @@ struct llama_server_context
// Reset context
const int n_left = n_past - params.n_keep;
n_past = std::max(1, params.n_keep);
last_n_tokens.erase(last_n_tokens.begin() + n_past, last_n_tokens.end());
processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
}
@ -455,7 +475,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
}
if (!body["repeat_last_n"].is_null())
{
llama.params.repeat_last_n = body["repeat_last_n"].get<int32_t>();
llama.params.repeat_last_n = body["repeat_last_n"].get<float>();
}
if (!body["temperature"].is_null())
{
@ -475,7 +495,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
}
if (!body["mirostat"].is_null())
{
llama.params.mirostat = body["mirostat"].get<int>();
llama.params.mirostat = body["mirostat"].get<float>();
}
if (!body["mirostat_tau"].is_null())
{
@ -487,7 +507,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
}
if (!body["penalize_nl"].is_null())
{
llama.params.penalize_nl = body["penalize_nl"].get<bool>();
llama.params.penalize_nl = body["penalize_nl"].get<float>();
}
if (!body["batch_size"].is_null())
{