diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 567dc4d2d..c52db9c05 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -258,14 +258,14 @@ struct llama_server_context return true; } - std::vector tokenizePrompt(void) + std::vector tokenizePrompt(json json_prompt) { std::vector prompt_tokens; - if (prompt.is_array()) + if (json_prompt.is_array()) { bool first = true; - for (const auto& p : prompt) + for (const auto& p : json_prompt) { if (p.is_string()) { @@ -295,7 +295,7 @@ struct llama_server_context } else { - auto s = prompt.template get(); + auto s = json_prompt.template get(); s.insert(0, 1, ' '); // always add a first space prompt_tokens = ::llama_tokenize(ctx, s, true); } @@ -305,7 +305,7 @@ struct llama_server_context void loadPrompt() { - auto prompt_tokens = tokenizePrompt(); + auto prompt_tokens = tokenizePrompt(prompt); num_prompt_tokens = prompt_tokens.size(); @@ -1062,7 +1062,15 @@ static void parse_options_completion(const json &body, llama_server_context &lla llama.params.n_keep = body.value("n_keep", default_params.n_keep); llama.params.seed = body.value("seed", default_params.seed); llama.params.n_probs = body.value("n_probs", default_params.n_probs); - llama.prompt = body["prompt"]; + + if (body.count("content") != 0) + { + llama.prompt = body["prompt"]; + } + else + { + llama.prompt = ""; + } llama.params.logit_bias.clear(); if (body.value("ignore_eos", false)) @@ -1304,8 +1312,11 @@ int main(int argc, char **argv) auto lock = llama.lock(); const json body = json::parse(req.body); - llama.prompt = body["content"]; - const std::vector tokens = llama.tokenizePrompt(); + std::vector tokens; + if (body.count("content") != 0) + { + tokens = llama.tokenizePrompt(body["content"]); + } const json data = format_tokenizer_response(tokens); return res.set_content(data.dump(), "application/json"); }); @@ -1317,7 +1328,14 @@ int main(int argc, char **argv) llama.rewind(); llama_reset_timings(llama.ctx); - llama.prompt = body["content"]; + if (body.count("content") != 0) + { + llama.prompt = body["content"]; + } + else + { + llama.prompt = ""; + } llama.params.n_predict = 0; llama.loadPrompt(); llama.beginCompletion();