server: use tokenizePrompt(json) and default "" if empty prompt

This commit is contained in:
Xiao-Yong Jin 2023-07-24 21:39:35 -05:00
parent 1a61c1a5e1
commit 97deb25398

View file

@ -258,14 +258,14 @@ struct llama_server_context
return true; return true;
} }
std::vector<llama_token> tokenizePrompt(void) std::vector<llama_token> tokenizePrompt(json json_prompt)
{ {
std::vector<llama_token> prompt_tokens; std::vector<llama_token> prompt_tokens;
if (prompt.is_array()) if (json_prompt.is_array())
{ {
bool first = true; bool first = true;
for (const auto& p : prompt) for (const auto& p : json_prompt)
{ {
if (p.is_string()) if (p.is_string())
{ {
@ -295,7 +295,7 @@ struct llama_server_context
} }
else else
{ {
auto s = prompt.template get<std::string>(); auto s = json_prompt.template get<std::string>();
s.insert(0, 1, ' '); // always add a first space s.insert(0, 1, ' '); // always add a first space
prompt_tokens = ::llama_tokenize(ctx, s, true); prompt_tokens = ::llama_tokenize(ctx, s, true);
} }
@ -305,7 +305,7 @@ struct llama_server_context
void loadPrompt() void loadPrompt()
{ {
auto prompt_tokens = tokenizePrompt(); auto prompt_tokens = tokenizePrompt(prompt);
num_prompt_tokens = prompt_tokens.size(); num_prompt_tokens = prompt_tokens.size();
@ -1062,7 +1062,15 @@ static void parse_options_completion(const json &body, llama_server_context &lla
llama.params.n_keep = body.value("n_keep", default_params.n_keep); llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed); llama.params.seed = body.value("seed", default_params.seed);
llama.params.n_probs = body.value("n_probs", default_params.n_probs); llama.params.n_probs = body.value("n_probs", default_params.n_probs);
llama.prompt = body["prompt"];
if (body.count("content") != 0)
{
llama.prompt = body["prompt"];
}
else
{
llama.prompt = "";
}
llama.params.logit_bias.clear(); llama.params.logit_bias.clear();
if (body.value("ignore_eos", false)) if (body.value("ignore_eos", false))
@ -1304,8 +1312,11 @@ int main(int argc, char **argv)
auto lock = llama.lock(); auto lock = llama.lock();
const json body = json::parse(req.body); const json body = json::parse(req.body);
llama.prompt = body["content"]; std::vector<llama_token> tokens;
const std::vector<llama_token> tokens = llama.tokenizePrompt(); if (body.count("content") != 0)
{
tokens = llama.tokenizePrompt(body["content"]);
}
const json data = format_tokenizer_response(tokens); const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json"); }); return res.set_content(data.dump(), "application/json"); });
@ -1317,7 +1328,14 @@ int main(int argc, char **argv)
llama.rewind(); llama.rewind();
llama_reset_timings(llama.ctx); llama_reset_timings(llama.ctx);
llama.prompt = body["content"]; if (body.count("content") != 0)
{
llama.prompt = body["content"];
}
else
{
llama.prompt = "";
}
llama.params.n_predict = 0; llama.params.n_predict = 0;
llama.loadPrompt(); llama.loadPrompt();
llama.beginCompletion(); llama.beginCompletion();