From 3292f057dc2165efde77a813ba65d64526580f4d Mon Sep 17 00:00:00 2001
From: digiwombat
Date: Tue, 30 May 2023 19:44:16 -0400
Subject: [PATCH] Changed to single API endpoint for streaming and non.
 next-token endpoint removed. "as_loop" setting changed to "streaming"

---
 examples/server/server.cpp | 120 ++++++++++++++----------------------
 1 file changed, 45 insertions(+), 75 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7be83aa92..5af5fbeaf 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -13,7 +13,7 @@ struct server_params
 
 struct llama_server_context
 {
-    bool as_loop = false;
+    bool streaming = false;
     bool has_next_token = false;
     std::string generated_text = "";
@@ -35,7 +35,7 @@ struct llama_server_context
     std::string stopping_word;
 
     void rewind() {
-        as_loop = false;
+        streaming = false;
         params.antiprompt.clear();
         num_tokens_predicted = 0;
         generated_text = "";
@@ -253,7 +253,7 @@ struct llama_server_context
         if (token == -1) {
             return "";
         }
-        if(as_loop) {
+        if(streaming) {
             generated_text = "";
         }
@@ -478,9 +478,9 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
 bool parse_options_completion(json body, llama_server_context& llama, Response &res)
 {
     gpt_params default_params;
-    if (!body["as_loop"].is_null())
+    if (!body["streaming"].is_null())
     {
-        llama.as_loop = body["as_loop"].get<bool>();
+        llama.streaming = body["streaming"].get<bool>();
     }
     if (!body["n_predict"].is_null())
     {
@@ -718,11 +718,46 @@ int main(int argc, char **argv)
             }
 
             llama.beginCompletion();
-            if(llama.as_loop) {
-                json data = {
-                    {"status", "done" } };
-                return res.set_content(data.dump(), "application/json");
-            } else {
+            if(llama.streaming)
+            {
+                fprintf(stdout, "In streaming\n");
+                res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/,
+                                                                          DataSink& sink) {
+                    std::string final_text = "";
+                    // loop inference until finish completion
+                    while (llama.has_next_token) {
+                        std::string result = llama.doCompletion();
+                        json data;
+                        final_text += result;
+                        fprintf(stdout, "Result: %s\n", result.c_str());
+                        if (llama.has_next_token)
+                        {
+                            data = { {"content", result}, {"stop", false} };
+                        }
+                        else
+                        {
+                            // Generation is done, send extra information.
+                            data = { {"content", result},
+                                     {"stop", true},
+                                     {"tokens_predicted", llama.num_tokens_predicted},
+                                     {"generation_settings", format_generation_settings(llama)},
+                                     {"prompt", llama.params.prompt},
+                                     {"stopping_word", llama.stopping_word},
+                                     {"generated_text", final_text} };
+                        }
+
+                        std::string str =
+                            "data: " + data.dump(4, ' ', false, json::error_handler_t::replace) +
+                            "\n\n";
+                        sink.write(str.data(), str.size());
+                    }
+
+                    sink.done();
+                    return true;
+                });
+            }
+            else
+            {
                 // loop inference until finish completion
                 while (llama.has_next_token) {
@@ -774,71 +809,6 @@ int main(int argc, char **argv)
             }
 
             return res.set_content(data.dump(), "application/json");
         });
 
-    svr.Get("/next-token", [&llama, &final_text](const Request &req, Response &res)
-            {
-                if(llama.params.embedding) {
-                    res.set_content("{}", "application/json");
-                    return;
-                }
-                std::string result = "";
-                if (req.has_param("stop")) {
-                    llama.has_next_token = false;
-                } else {
-                    result = llama.doCompletion(); // inference next token
-                }
-                try {
-                    json data;
-                    if (llama.has_next_token)
-                    {
-                        //fprintf(stdout, "Result: %s\n", result);
-                        final_text += result;
-                        data = {
-                            {"content", result },
-                            {"stop", false }
-                        };
-                    }
-                    else
-                    {
-                        // Generation is done, send extra information.
-                        data = {
-                            {"content", result },
-                            {"stop", true },
-                            {"tokens_predicted", llama.num_tokens_predicted},
-                            {"generation_settings", format_generation_settings(llama)},
-                            {"prompt", llama.params.prompt},
-                            {"stopping_word", llama.stopping_word},
-                            {"generated_text", final_text}
-                        };
-                    }
-
-                    return res.set_content(data.dump(), "application/json");
-                } catch (const json::exception &e) {
-                    // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-                    json data;
-                    if (llama.has_next_token)
-                    {
-                        final_text += u8"\uFFFD";
-                        data = {
-                            {"content", result },
-                            {"stop", false }
-                        };
-                    }
-                    else
-                    {
-                        // Generation is done, send extra information.
-                        data = {
-                            {"content", u8"\uFFFD" },
-                            {"stop", true },
-                            {"tokens_predicted", llama.num_tokens_predicted},
-                            {"generation_settings", format_generation_settings(llama)},
-                            {"prompt", llama.params.prompt},
-                            {"stopping_word", llama.stopping_word},
-                            {"generated_text", final_text}
-                        };
-                    }
-                    return res.set_content(data.dump(), "application/json");
-                }
-            });
 
     fprintf(stderr, "%s: http server Listening at http://%s:%i\n", __func__, sparams.hostname.c_str(), sparams.port);
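
Usage sketch (not part of the patch): with this change a client makes a single POST request and, when "streaming" is true, reads SSE frames of the form "data: <json>\n\n" until a frame carries "stop": true; non-streaming callers still receive one JSON object. Below is a minimal client-side sketch, assuming the nlohmann/json single header ("json.hpp") that the server already uses; the prompt and the example frame contents are illustrative assumptions, and the route name does not appear in the hunks above, so it is not shown here.

// Hypothetical, for illustration only -- not part of the patch.
#include <iostream>
#include <string>
#include "json.hpp"

using json = nlohmann::json;

int main() {
    // Request body: the "streaming" flag replaces the old "as_loop" setting.
    json request = {
        {"prompt", "Hello"},   // example prompt (assumption)
        {"n_predict", 16},
        {"streaming", true}
    };
    std::cout << "POST body: " << request.dump() << "\n";

    // The server emits one SSE frame per token: "data: <json>\n\n".
    // Example frame, shaped like the payload built in the patch above:
    std::string frame = "data: {\"content\": \" world\", \"stop\": false}\n\n";

    const std::string prefix = "data: ";
    if (frame.compare(0, prefix.size(), prefix) == 0) {
        json payload = json::parse(frame.substr(prefix.size()));
        std::cout << "token: " << payload["content"].get<std::string>() << "\n";
        if (payload["stop"].get<bool>()) {
            // The final frame also carries generated_text, tokens_predicted,
            // generation_settings, prompt, and stopping_word.
            std::cout << "generation finished\n";
        }
    }
    return 0;
}

Keeping both modes behind one endpoint means a client only toggles a flag in the request body instead of polling the removed /next-token route.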