diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 095ae4bc3..bcb7cd4b9 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -13,7 +13,7 @@ struct server_params
 
 struct llama_server_context
 {
-    bool streaming = false;
+    bool stream = false;
     bool has_next_token = false;
     std::string generated_text = "";
 
@@ -35,7 +35,6 @@ struct llama_server_context
     std::string stopping_word;
 
     void rewind() {
-        streaming = false;
         params.antiprompt.clear();
        num_tokens_predicted = 0;
         generated_text = "";
@@ -253,9 +252,6 @@ struct llama_server_context
         if (token == -1) {
             return "";
         }
-        if(streaming) {
-            generated_text = "";
-        }
 
         std::string token_text = llama_token_to_str(ctx, token);
         generated_text += token_text;
@@ -270,7 +266,7 @@ struct llama_server_context
             }
         }
 
-        return generated_text;
+        return token_text;
     }
 
     std::vector<float> embedding(std::string content, int threads) {
@@ -478,13 +474,13 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
 bool parse_options_completion(json body, llama_server_context& llama, Response &res)
 {
     gpt_params default_params;
-    if (!body["streaming"].is_null())
+    if (!body["stream"].is_null())
     {
-        llama.streaming = body["streaming"].get<bool>();
+        llama.stream = body["stream"].get<bool>();
     }
     else
     {
-        llama.streaming = false;
+        llama.stream = false;
     }
     if (!body["n_predict"].is_null())
     {
@@ -675,8 +671,6 @@ int main(int argc, char **argv)
     llama_server_context llama;
     params.model = "ggml-model.bin";
-    std::string final_text;
-
     if (server_params_parse(argc, argv, sparams, params) == false)
     {
         return 1;
     }
@@ -693,98 +687,81 @@ int main(int argc, char **argv)
     svr.Get("/", [](const Request &, Response &res)
             { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });
 
-    svr.Post("/completion", [&llama, &final_text](const Request &req, Response &res)
-             {
-        if(llama.params.embedding) {
-            json data = {
-                {"status", "error"},
-                {"reason", "To use completion function, disable embedding mode"}};
-            res.set_content(data.dump(), "application/json");
-            res.status = 400;
-            return;
-        }
+    svr.Post("/completion", [&llama](const Request &req, Response &res) {
+        if (llama.params.embedding) {
+            json data = {
+                {"status", "error"},
+                {"reason", "To use completion function, disable embedding mode"}};
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
+            res.status = 400;
+            return;
+        }
 
-        llama.rewind();
-        final_text = "";
+        llama.rewind();
 
-        if(parse_options_completion(json::parse(req.body), llama, res) == false){
-            return;
-        }
+        if (parse_options_completion(json::parse(req.body), llama, res) == false) {
+            return;
+        }
 
-        if (!llama.loadPrompt())
-        {
-            json data = {
-                {"status", "error"},
-                {"reason", "Context too long."}};
-            res.set_content(data.dump(), "application/json");
-            res.status = 400;
-            return;
-        }
+        if (!llama.loadPrompt()) {
+            json data = {{"status", "error"}, {"reason", "Context too long."}};
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
+            res.status = 400;
+            return;
+        }
+
+        llama.beginCompletion();
+
+        if (!llama.stream) {
+            while (llama.has_next_token) {
+                llama.doCompletion();
+            }
+
+            json data = {{"content", llama.generated_text},
+                         {"stop", true},
+                         {"model", llama.params.model_alias },
+                         {"tokens_predicted", llama.num_tokens_predicted},
+                         {"generation_settings", format_generation_settings(llama)},
+                         {"prompt", llama.params.prompt},
+                         {"stopping_word", llama.stopping_word}};
+            return res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json");
+        } else {
+            const auto chunked_content_provider = [&](size_t, DataSink &sink) {
+                while (llama.has_next_token) {
+                    std::string token_text = llama.doCompletion();
 
-        llama.beginCompletion();
-        if(llama.streaming)
-        {
-            res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/,
-                                                                      DataSink& sink) {
-                std::string final_text = "";
-                // loop inference until finish completion
-                while (llama.has_next_token) {
-                    std::string result = llama.doCompletion();
                     json data;
-                    final_text += result;
-                    if (llama.has_next_token)
-                    {
-                        data = { {"content", result}, {"stop", false} };
-                    }
-                    else
-                    {
-                        // Generation is done, send extra information.
-                        data = { {"content", result},
-                                 {"stop", true},
-                                 {"tokens_predicted", llama.num_tokens_predicted},
-                                 {"generation_settings", format_generation_settings(llama)},
-                                 {"prompt", llama.params.prompt},
-                                 {"stopping_word", llama.stopping_word},
-                                 {"generated_text", final_text} };
+                    if (llama.has_next_token) {
+                        data = {{"content", token_text}, {"stop", false}};
+                    } else {
+                        // Generation is done, send extra information.
+                        data = {
+                            {"content", token_text},
+                            {"stop", true},
+                            {"model", llama.params.model_alias},
+                            {"tokens_predicted", llama.num_tokens_predicted},
+                            {"generation_settings", format_generation_settings(llama)},
+                            {"prompt", llama.params.prompt},
+                            {"stopping_word", llama.stopping_word},
+                            {"generated_text", llama.generated_text}};
                     }
 
                     std::string str =
-                        "data: " + data.dump(4, ' ', false, json::error_handler_t::replace) +
-                        "\n\n";
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
                     sink.write(str.data(), str.size());
-                }
-
-                sink.done();
-                return true;
-            });
                 }
-        else
-        {
-            // loop inference until finish completion
-            while (llama.has_next_token)
-            {
-                llama.doCompletion();
-            }
-            try
-            {
-                json data = {
-                    {"model", llama.params.model_alias },
-                    {"content", llama.generated_text },
-                    {"tokens_predicted", llama.num_tokens_predicted},
-                    {"generation_settings", format_generation_settings(llama)},
-                    {"prompt", llama.params.prompt},
-                    {"stopping_word", llama.stopping_word} };
-                return res.set_content(data.dump(), "application/json");
-            }
-            catch (const json::exception &e)
-            {
-                // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-                json data = {
-                    {"content", "Bad encoding token"},
-                    {"tokens_predicted", 0}};
-                return res.set_content(data.dump(), "application/json");
-            }
-        } });
+
+                sink.done();
+                return true;
+            };
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+        }
+    });
+
     svr.Post("/tokenize", [&llama](const Request &req, Response &res)
              {
@@ -811,7 +788,6 @@ int main(int argc, char **argv)
         return res.set_content(data.dump(), "application/json");
     });
 
-
     fprintf(stderr, "%s: http server Listening at http://%s:%i\n", __func__, sparams.hostname.c_str(), sparams.port);
 
     if(params.embedding) {
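For reference, a minimal client-side sketch (not part of the patch) that exercises the renamed `stream` option on `/completion`. It reuses the cpp-httplib and nlohmann::json headers the server example already depends on; the host, port, prompt, and `n_predict` values are illustrative assumptions, not values taken from this diff.

```cpp
// Hypothetical client for the /completion endpoint shown above.
// Address and request values are assumptions for illustration only.
#include "httplib.h"
#include "json.hpp"

#include <cstdio>
#include <string>

using json = nlohmann::json;

int main() {
    httplib::Client cli("127.0.0.1", 8080); // assumed server address

    // With "stream": false the server replies once, with a single JSON object
    // (see the non-streaming branch of the handler above).
    const json body = {
        {"prompt", "Building a website can be done in 10 simple steps:"},
        {"n_predict", 64},
        {"stream", false}
    };

    auto res = cli.Post("/completion", body.dump(), "application/json");
    if (!res || res->status != 200) {
        fprintf(stderr, "completion request failed\n");
        return 1;
    }

    const json data = json::parse(res->body);
    printf("%s\n", data["content"].get<std::string>().c_str());
    return 0;
}
```

With `"stream": true` the same endpoint switches to a `text/event-stream` response: each chunk is a `data: {...}\n\n` line whose object carries `"content"` and `"stop": false`, and the final chunk adds the extra fields (`model`, `tokens_predicted`, `generation_settings`, `prompt`, `stopping_word`, `generated_text`) emitted in the diff above.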