Merge pull request #12 from anon998/clear-logit-bias

Clear logit bias between requests.
This commit is contained in:
Randall Fitzgerald 2023-06-01 08:58:35 -04:00 committed by GitHub
commit d29b6d5f55
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -559,6 +559,7 @@ json format_generation_settings(llama_server_context &llama) {
{ "n_keep", llama.params.n_keep }, { "n_keep", llama.params.n_keep },
{ "ignore_eos", ignore_eos }, { "ignore_eos", ignore_eos },
{ "stream", llama.stream }, { "stream", llama.stream },
{ "logit_bias", llama.params.logit_bias },
}; };
} }
@@ -638,7 +639,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
if (!body["penalize_nl"].is_null()) { if (!body["penalize_nl"].is_null()) {
llama.params.penalize_nl = body["penalize_nl"].get<float>(); llama.params.penalize_nl = body["penalize_nl"].get<float>();
} else { } else {
llama.params.penalize_nl = false; llama.params.penalize_nl = default_params.penalize_nl;
} }
if (!body["n_keep"].is_null()) { if (!body["n_keep"].is_null()) {
llama.params.n_keep = body["n_keep"].get<int>(); llama.params.n_keep = body["n_keep"].get<int>();
@@ -650,10 +651,10 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
} else { } else {
llama.params.seed = time(NULL); llama.params.seed = time(NULL);
} }
llama.params.logit_bias.clear();
if (!body["ignore_eos"].is_null() && body["ignore_eos"].get<bool>()) { if (!body["ignore_eos"].is_null() && body["ignore_eos"].get<bool>()) {
llama.params.logit_bias[llama_token_eos()] = -INFINITY; llama.params.logit_bias[llama_token_eos()] = -INFINITY;
} else {
llama.params.logit_bias.erase(llama_token_eos());
} }
if (body["logit_bias"].is_array()) { if (body["logit_bias"].is_array()) {
int n_vocab = llama_n_vocab(llama.ctx); int n_vocab = llama_n_vocab(llama.ctx);
@@ -665,6 +666,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
} }
} }
} }
if (!body["prompt"].is_null()) { if (!body["prompt"].is_null()) {
llama.params.prompt = body["prompt"].get<std::string>(); llama.params.prompt = body["prompt"].get<std::string>();
} else { } else {
@@ -673,6 +675,7 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
res.status = 400; res.status = 400;
return false; return false;
} }
llama.params.antiprompt.clear(); llama.params.antiprompt.clear();
if (!body["stop"].is_null()) { if (!body["stop"].is_null()) {
const auto stop = body["stop"].get<std::vector<std::string>>(); const auto stop = body["stop"].get<std::vector<std::string>>();
@@ -888,7 +891,7 @@ int main(int argc, char **argv)
} }
}); });
svr.Options(R"(/.*)", [&llama](const Request &req, Response &res) svr.Options(R"(/.*)", [&llama](const Request &, Response &res)
{ {
return res.set_content("", "application/json"); return res.set_content("", "application/json");
}); });