json parsing improvements
This commit is contained in:
parent
4148b9bd03
commit
dff11a14d2
1 changed files with 23 additions and 97 deletions
|
@ -607,98 +607,33 @@ json format_generation_settings(llama_server_context & llama) {
|
|||
};
|
||||
}
|
||||
|
||||
bool parse_options_completion(json body, llama_server_context & llama, Response & res) {
|
||||
bool parse_options_completion(json body, llama_server_context & llama) {
|
||||
gpt_params default_params;
|
||||
if (!body["stream"].is_null()) {
|
||||
llama.stream = body["stream"].get<bool>();
|
||||
} else {
|
||||
llama.stream = false;
|
||||
}
|
||||
if (!body["n_predict"].is_null()) {
|
||||
llama.params.n_predict = body["n_predict"].get<int32_t>();
|
||||
} else {
|
||||
llama.params.n_predict = default_params.n_predict;
|
||||
}
|
||||
if (!body["top_k"].is_null()) {
|
||||
llama.params.top_k = body["top_k"].get<int32_t>();
|
||||
} else {
|
||||
llama.params.top_k = default_params.top_k;
|
||||
}
|
||||
if (!body["top_p"].is_null()) {
|
||||
llama.params.top_p = body["top_p"].get<float>();
|
||||
} else {
|
||||
llama.params.top_p = default_params.top_p;
|
||||
}
|
||||
if (!body["tfs_z"].is_null()) {
|
||||
llama.params.tfs_z = body["tfs_z"].get<float>();
|
||||
} else {
|
||||
llama.params.tfs_z = default_params.tfs_z;
|
||||
}
|
||||
if (!body["typical_p"].is_null()) {
|
||||
llama.params.typical_p = body["typical_p"].get<float>();
|
||||
} else {
|
||||
llama.params.typical_p = default_params.typical_p;
|
||||
}
|
||||
if (!body["repeat_last_n"].is_null()) {
|
||||
llama.params.repeat_last_n = body["repeat_last_n"].get<int32_t>();
|
||||
} else {
|
||||
llama.params.repeat_last_n = default_params.repeat_last_n;
|
||||
}
|
||||
if (!body["temperature"].is_null()) {
|
||||
llama.params.temp = body["temperature"].get<float>();
|
||||
} else {
|
||||
llama.params.temp = default_params.temp;
|
||||
}
|
||||
if (!body["repeat_penalty"].is_null()) {
|
||||
llama.params.repeat_penalty = body["repeat_penalty"].get<float>();
|
||||
} else {
|
||||
llama.params.repeat_penalty = default_params.repeat_penalty;
|
||||
}
|
||||
if (!body["presence_penalty"].is_null()) {
|
||||
llama.params.presence_penalty = body["presence_penalty"].get<float>();
|
||||
} else {
|
||||
llama.params.presence_penalty = default_params.presence_penalty;
|
||||
}
|
||||
if (!body["frequency_penalty"].is_null()) {
|
||||
llama.params.frequency_penalty = body["frequency_penalty"].get<float>();
|
||||
} else {
|
||||
llama.params.frequency_penalty = default_params.frequency_penalty;
|
||||
}
|
||||
if (!body["mirostat"].is_null()) {
|
||||
llama.params.mirostat = body["mirostat"].get<int>();
|
||||
} else {
|
||||
llama.params.mirostat = default_params.mirostat;
|
||||
}
|
||||
if (!body["mirostat_tau"].is_null()) {
|
||||
llama.params.mirostat_tau = body["mirostat_tau"].get<float>();
|
||||
} else {
|
||||
llama.params.mirostat_tau = default_params.mirostat_tau;
|
||||
}
|
||||
if (!body["mirostat_eta"].is_null()) {
|
||||
llama.params.mirostat_eta = body["mirostat_eta"].get<float>();
|
||||
} else {
|
||||
llama.params.mirostat_eta = default_params.mirostat_eta;
|
||||
}
|
||||
if (!body["penalize_nl"].is_null()) {
|
||||
llama.params.penalize_nl = body["penalize_nl"].get<bool>();
|
||||
} else {
|
||||
llama.params.penalize_nl = default_params.penalize_nl;
|
||||
}
|
||||
if (!body["n_keep"].is_null()) {
|
||||
llama.params.n_keep = body["n_keep"].get<int32_t>();
|
||||
} else {
|
||||
llama.params.n_keep = default_params.n_keep;
|
||||
}
|
||||
if (!body["seed"].is_null()) {
|
||||
llama.params.seed = body["seed"].get<int32_t>();
|
||||
} else {
|
||||
llama.params.seed = time(NULL);
|
||||
}
|
||||
|
||||
llama.stream = body.value("stream", false);
|
||||
llama.params.n_predict = body.value("n_predict", default_params.n_predict);
|
||||
llama.params.top_k = body.value("top_k", default_params.top_k);
|
||||
llama.params.top_p = body.value("top_p", default_params.top_p);
|
||||
llama.params.tfs_z = body.value("tfs_z", default_params.tfs_z);
|
||||
llama.params.typical_p = body.value("typical_p", default_params.typical_p);
|
||||
llama.params.repeat_last_n = body.value("repeat_last_n", default_params.repeat_last_n);
|
||||
llama.params.temp = body.value("temperature", default_params.temp);
|
||||
llama.params.repeat_penalty = body.value("repeat_penalty", default_params.repeat_penalty);
|
||||
llama.params.presence_penalty = body.value("presence_penalty", default_params.presence_penalty);
|
||||
llama.params.frequency_penalty = body.value("frequency_penalty", default_params.frequency_penalty);
|
||||
llama.params.mirostat = body.value("mirostat", default_params.mirostat);
|
||||
llama.params.mirostat_tau = body.value("mirostat_tau", default_params.mirostat_tau);
|
||||
llama.params.mirostat_eta = body.value("mirostat_eta", default_params.mirostat_eta);
|
||||
llama.params.penalize_nl = body.value("penalize_nl", default_params.penalize_nl);
|
||||
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
|
||||
llama.params.seed = body.value("seed", default_params.seed);
|
||||
llama.params.prompt = body.value("prompt", default_params.prompt);
|
||||
|
||||
llama.params.logit_bias.clear();
|
||||
if (!body["ignore_eos"].is_null() && body["ignore_eos"].get<bool>()) {
|
||||
if (body.value("ignore_eos", false)) {
|
||||
llama.params.logit_bias[llama_token_eos()] = -INFINITY;
|
||||
}
|
||||
|
||||
if (body["logit_bias"].is_array()) {
|
||||
int n_vocab = llama_n_vocab(llama.ctx);
|
||||
for (const auto & el : body["logit_bias"]) {
|
||||
|
@ -715,15 +650,6 @@ bool parse_options_completion(json body, llama_server_context & llama, Response
|
|||
}
|
||||
}
|
||||
|
||||
if (!body["prompt"].is_null()) {
|
||||
llama.params.prompt = body["prompt"].get<std::string>();
|
||||
} else {
|
||||
json data = { {"status", "error"}, {"reason", "You need to provide a prompt"} };
|
||||
res.set_content(data.dump(llama.json_indent), "application/json");
|
||||
res.status = 400;
|
||||
return false;
|
||||
}
|
||||
|
||||
llama.params.antiprompt.clear();
|
||||
if (!body["stop"].is_null()) {
|
||||
const auto stop = body["stop"].get<std::vector<std::string>>();
|
||||
|
@ -788,7 +714,7 @@ int main(int argc, char ** argv) {
|
|||
llama.rewind();
|
||||
llama_reset_timings(llama.ctx);
|
||||
|
||||
if (!parse_options_completion(json::parse(req.body), llama, res)) {
|
||||
if (!parse_options_completion(json::parse(req.body), llama)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue