From d6fff56e22455c73efb381fbd03f50c97d3ef2db Mon Sep 17 00:00:00 2001
From: anon
Date: Tue, 30 May 2023 19:33:33 -0300
Subject: [PATCH] add streaming via server-sent events

Removes /next-token endpoint and adds a "stream" parameter to the
/completion one.
---
 examples/server/server.cpp | 213 ++++++++++++++-----------------------
 1 file changed, 80 insertions(+), 133 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7be83aa92..5e7d1c357 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -13,7 +13,7 @@ struct server_params
 
 struct llama_server_context
 {
-    bool as_loop = false;
+    bool stream = false;
     bool has_next_token = false;
     std::string generated_text = "";
 
@@ -35,7 +35,6 @@ struct llama_server_context
     std::string stopping_word;
 
     void rewind() {
-        as_loop = false;
         params.antiprompt.clear();
         num_tokens_predicted = 0;
         generated_text = "";
@@ -253,9 +252,6 @@ struct llama_server_context
         if (token == -1) {
             return "";
         }
-        if(as_loop) {
-            generated_text = "";
-        }
 
         std::string token_text = llama_token_to_str(ctx, token);
         generated_text += token_text;
@@ -270,7 +266,7 @@ struct llama_server_context
             }
         }
 
-        return generated_text;
+        return token_text;
     }
 
     std::vector<float> embedding(std::string content, int threads) {
@@ -478,9 +474,13 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
 bool parse_options_completion(json body, llama_server_context& llama, Response &res)
 {
     gpt_params default_params;
-    if (!body["as_loop"].is_null())
+    if (!body["stream"].is_null())
     {
-        llama.as_loop = body["as_loop"].get<bool>();
+        llama.stream = body["stream"].get<bool>();
+    }
+    else
+    {
+        llama.stream = false;
     }
     if (!body["n_predict"].is_null())
     {
@@ -671,8 +671,6 @@ int main(int argc, char **argv)
     llama_server_context llama;
     params.model = "ggml-model.bin";
 
-    std::string final_text;
-
     if (server_params_parse(argc, argv, sparams, params) == false)
     {
         return 1;
@@ -689,65 +687,80 @@ int main(int argc, char **argv)
     svr.Get("/", [](const Request &, Response &res)
             { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });
 
-    svr.Post("/completion", [&llama, &final_text](const Request &req, Response &res)
-            {
-                if(llama.params.embedding) {
-                    json data = {
-                        {"status", "error"},
-                        {"reason", "To use completion function, disable embedding mode"}};
-                    res.set_content(data.dump(), "application/json");
-                    res.status = 400;
-                    return;
+    svr.Post("/completion", [&llama](const Request &req, Response &res) {
+        if (llama.params.embedding) {
+            json data = {
+                {"status", "error"},
+                {"reason", "To use completion function, disable embedding mode"}};
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
+            res.status = 400;
+            return;
+        }
+
+        llama.rewind();
+
+        if (parse_options_completion(json::parse(req.body), llama, res) == false) {
+            return;
+        }
+
+        if (!llama.loadPrompt()) {
+            json data = {{"status", "error"}, {"reason", "Context too long."}};
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
+            res.status = 400;
+            return;
+        }
+
+        llama.beginCompletion();
+
+        if (!llama.stream) {
+            while (llama.has_next_token) {
+                llama.doCompletion();
+            }
+
+            json data = {{"content", llama.generated_text},
+                         {"stop", true},
+                         {"model", llama.params.model_alias },
+                         {"tokens_predicted", llama.num_tokens_predicted},
+                         {"generation_settings", format_generation_settings(llama)},
+                         {"prompt", llama.params.prompt},
+                         {"stopping_word", llama.stopping_word}};
+            return res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json");
+        } else {
+            const auto chunked_content_provider = [&](size_t, DataSink &sink) {
+                while (llama.has_next_token) {
+                    std::string token_text = llama.doCompletion();
+
+                    json data;
+                    if (llama.has_next_token) {
+                        data = {{"content", token_text}, {"stop", false}};
+                    } else {
+                        // Generation is done, send extra information.
+                        data = {
+                            {"content", token_text},
+                            {"stop", true},
+                            {"model", llama.params.model_alias},
+                            {"tokens_predicted", llama.num_tokens_predicted},
+                            {"generation_settings", format_generation_settings(llama)},
+                            {"prompt", llama.params.prompt},
+                            {"stopping_word", llama.stopping_word},
+                            {"generated_text", llama.generated_text}};
+                    }
+
+                    std::string str =
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
+                    sink.write(str.data(), str.size());
                 }
 
-                llama.rewind();
-                final_text = "";
-
-                if(parse_options_completion(json::parse(req.body), llama, res) == false){
-                    return;
-                }
-
-                if (!llama.loadPrompt())
-                {
-                    json data = {
-                        {"status", "error"},
-                        {"reason", "Context too long."}};
-                    res.set_content(data.dump(), "application/json");
-                    res.status = 400;
-                    return;
-                }
-
-                llama.beginCompletion();
-                if(llama.as_loop) {
-                    json data = {
-                        {"status", "done" } };
-                    return res.set_content(data.dump(), "application/json");
-                } else {
-                    // loop inference until finish completion
-                    while (llama.has_next_token)
-                    {
-                        llama.doCompletion();
-                    }
-                    try
-                    {
-                        json data = {
-                            {"model", llama.params.model_alias },
-                            {"content", llama.generated_text },
-                            {"tokens_predicted", llama.num_tokens_predicted},
-                            {"generation_settings", format_generation_settings(llama)},
-                            {"prompt", llama.params.prompt},
-                            {"stopping_word", llama.stopping_word} };
-                        return res.set_content(data.dump(), "application/json");
-                    }
-                    catch (const json::exception &e)
-                    {
-                        // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-                        json data = {
-                            {"content", "Bad encoding token"},
-                            {"tokens_predicted", 0}};
-                        return res.set_content(data.dump(), "application/json");
-                    }
-                } });
+                sink.done();
+                return true;
+            };
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+        }
+    });
 
     svr.Post("/tokenize", [&llama](const Request &req, Response &res)
             {
@@ -774,72 +787,6 @@ int main(int argc, char **argv)
 
                 return res.set_content(data.dump(), "application/json"); });
 
-    svr.Get("/next-token", [&llama, &final_text](const Request &req, Response &res)
-    {
-        if(llama.params.embedding) {
-            res.set_content("{}", "application/json");
-            return;
-        }
-        std::string result = "";
-        if (req.has_param("stop")) {
-            llama.has_next_token = false;
-        } else {
-            result = llama.doCompletion(); // inference next token
-        }
-        try {
-            json data;
-            if (llama.has_next_token)
-            {
-                //fprintf(stdout, "Result: %s\n", result);
-                final_text += result;
-                data = {
-                    {"content", result },
-                    {"stop", false }
-                };
-            }
-            else
-            {
-                // Generation is done, send extra information.
-                data = {
-                    {"content", result },
-                    {"stop", true },
-                    {"tokens_predicted", llama.num_tokens_predicted},
-                    {"generation_settings", format_generation_settings(llama)},
-                    {"prompt", llama.params.prompt},
-                    {"stopping_word", llama.stopping_word},
-                    {"generated_text", final_text}
-                };
-            }
-
-            return res.set_content(data.dump(), "application/json");
-        } catch (const json::exception &e) {
-            // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-            json data;
-            if (llama.has_next_token)
-            {
-                final_text += u8"\uFFFD";
-                data = {
-                    {"content", result },
-                    {"stop", false }
-                };
-            }
-            else
-            {
-                // Generation is done, send extra information.
-                data = {
-                    {"content", u8"\uFFFD" },
-                    {"stop", true },
-                    {"tokens_predicted", llama.num_tokens_predicted},
-                    {"generation_settings", format_generation_settings(llama)},
-                    {"prompt", llama.params.prompt},
-                    {"stopping_word", llama.stopping_word},
-                    {"generated_text", final_text}
-                };
-            }
-            return res.set_content(data.dump(), "application/json");
-        }
-    });
-
     fprintf(stderr, "%s: http server Listening at http://%s:%i\n", __func__, sparams.hostname.c_str(), sparams.port);
 
     if(params.embedding) {
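Usage note (not part of the patch): a client opts into streaming by adding "stream": true to the usual /completion request body. The response is then a chunked text/event-stream in which each predicted token arrives as a line of the form data: {"content": "...", "stop": false} followed by a blank line, and the final event carries "stop": true together with the summary fields (generated_text, tokens_predicted, generation_settings, prompt, stopping_word). The sketch below is one way to consume that stream from C++ with plain POSIX sockets; the 127.0.0.1:8080 address and the prompt/n_predict values are illustrative assumptions, so adjust them to the server's actual host, port and request parameters.

// stream_client.cpp - minimal consumer of the streaming /completion endpoint (illustrative sketch).
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cstdio>
#include <string>

int main() {
    // Same body as a normal /completion request, plus "stream": true.
    // The prompt and n_predict values here are placeholders.
    const std::string body =
        R"({"prompt": "Hello, ", "n_predict": 32, "stream": true})";

    const std::string request =
        "POST /completion HTTP/1.1\r\n"
        "Host: 127.0.0.1:8080\r\n"
        "Content-Type: application/json\r\n"
        "Content-Length: " + std::to_string(body.size()) + "\r\n"
        "Connection: close\r\n"
        "\r\n" + body;

    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) { perror("socket"); return 1; }

    sockaddr_in addr{};
    addr.sin_family = AF_INET;
    addr.sin_port   = htons(8080);                   // adjust to the server's port
    inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr); // adjust to the server's host

    if (connect(fd, (sockaddr *)&addr, sizeof(addr)) != 0) { perror("connect"); return 1; }
    if (send(fd, request.data(), request.size(), 0) < 0)   { perror("send");    return 1; }

    // Print the raw response as it arrives. Because the server uses a chunked
    // content provider, the bytes include HTTP chunk-size lines in addition to
    // the "data: {...}" events; a real client would strip the chunk framing and
    // parse each "data: ..." line as JSON, stopping once an event with
    // "stop": true has been received.
    char buf[4096];
    ssize_t n;
    while ((n = recv(fd, buf, sizeof(buf), 0)) > 0) {
        fwrite(buf, 1, (size_t)n, stdout);
        fflush(stdout);
    }

    close(fd);
    return 0;
}

Requesting Connection: close keeps the sketch simple: the server closes the socket after the final "stop": true event has been written, so the read loop ends together with the stream.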