add streaming via server-sent events

Removes the /next-token endpoint and adds a "stream" parameter to the
/completion one.
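With this change, a client asks for streaming by POSTing {"stream": true} to /completion and reading the chunked text/event-stream response: each token arrives as a "data: {...}" event terminated by a blank line, carrying {"content", "stop"}, and the final event sets "stop": true and adds the generation settings and the full generated_text. The sketch below is a minimal client using libcurl; it is not part of this commit, and the host, port, and request body are assumptions for illustration.

// Minimal SSE client sketch (assumes the server listens on localhost:8080;
// adjust host, port, and request body as needed). Build with: g++ client.cpp -lcurl
#include <curl/curl.h>
#include <cstdio>
#include <string>

// libcurl hands us raw chunks of the chunked response; each complete event
// is a line of the form `data: {...}` followed by a blank line.
static size_t on_chunk(char *ptr, size_t size, size_t nmemb, void *) {
    fwrite(ptr, size, nmemb, stdout);
    fflush(stdout);
    return size * nmemb;
}

int main() {
    curl_global_init(CURL_GLOBAL_DEFAULT);
    CURL *curl = curl_easy_init();
    if (!curl) return 1;

    // Example request body; "prompt" and "n_predict" values are illustrative.
    const std::string body =
        R"({"prompt": "Building a website can be done in", "n_predict": 16, "stream": true})";

    struct curl_slist *headers = nullptr;
    headers = curl_slist_append(headers, "Content-Type: application/json");

    curl_easy_setopt(curl, CURLOPT_URL, "http://localhost:8080/completion");
    curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
    curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str());
    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, on_chunk);

    CURLcode rc = curl_easy_perform(curl);

    curl_slist_free_all(headers);
    curl_easy_cleanup(curl);
    curl_global_cleanup();
    return rc == CURLE_OK ? 0 : 1;
}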
anon 2023-05-30 19:33:33 -03:00
parent 03ea8f013a
commit d6fff56e22


@@ -13,7 +13,7 @@ struct server_params
 struct llama_server_context
 {
-    bool as_loop = false;
+    bool stream = false;
     bool has_next_token = false;
     std::string generated_text = "";
@@ -35,7 +35,6 @@ struct llama_server_context
     std::string stopping_word;

     void rewind() {
-        as_loop = false;
        params.antiprompt.clear();
        num_tokens_predicted = 0;
        generated_text = "";
@@ -253,9 +252,6 @@ struct llama_server_context
        if (token == -1) {
            return "";
        }
-        if(as_loop) {
-            generated_text = "";
-        }

        std::string token_text = llama_token_to_str(ctx, token);
        generated_text += token_text;
@@ -270,7 +266,7 @@ struct llama_server_context
            }
        }

-        return generated_text;
+        return token_text;
    }

    std::vector<float> embedding(std::string content, int threads) {
@@ -478,9 +474,13 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para

 bool parse_options_completion(json body, llama_server_context& llama, Response &res) {
    gpt_params default_params;

-    if (!body["as_loop"].is_null())
+    if (!body["stream"].is_null())
    {
-        llama.as_loop = body["as_loop"].get<bool>();
+        llama.stream = body["stream"].get<bool>();
+    }
+    else
+    {
+        llama.stream = false;
    }
    if (!body["n_predict"].is_null())
    {
@@ -671,8 +671,6 @@ int main(int argc, char **argv)
    llama_server_context llama;
    params.model = "ggml-model.bin";

-    std::string final_text;
-
    if (server_params_parse(argc, argv, sparams, params) == false)
    {
        return 1;
@@ -689,65 +687,80 @@ int main(int argc, char **argv)
    svr.Get("/", [](const Request &, Response &res)
            { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });

-    svr.Post("/completion", [&llama, &final_text](const Request &req, Response &res)
-             {
+    svr.Post("/completion", [&llama](const Request &req, Response &res) {
        if (llama.params.embedding) {
            json data = {
                {"status", "error"},
                {"reason", "To use completion function, disable embedding mode"}};
-            res.set_content(data.dump(), "application/json");
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
            res.status = 400;
            return;
        }

        llama.rewind();
-        final_text = "";

        if (parse_options_completion(json::parse(req.body), llama, res) == false) {
            return;
        }

-        if (!llama.loadPrompt())
-        {
-            json data = {
-                {"status", "error"},
-                {"reason", "Context too long."}};
-            res.set_content(data.dump(), "application/json");
+        if (!llama.loadPrompt()) {
+            json data = {{"status", "error"}, {"reason", "Context too long."}};
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
            res.status = 400;
            return;
        }

        llama.beginCompletion();
-        if(llama.as_loop) {
-            json data = {
-                {"status", "done" } };
-            return res.set_content(data.dump(), "application/json");
-        } else {
-            // loop inference until finish completion
-            while (llama.has_next_token)
-            {
-                llama.doCompletion();
-            }
-            try
-            {
-                json data = {
-                    {"model", llama.params.model_alias },
-                    {"content", llama.generated_text },
-                    {"tokens_predicted", llama.num_tokens_predicted},
-                    {"generation_settings", format_generation_settings(llama)},
-                    {"prompt", llama.params.prompt},
-                    {"stopping_word", llama.stopping_word}};
-                return res.set_content(data.dump(), "application/json");
-            }
-            catch (const json::exception &e)
-            {
-                // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-                json data = {
-                    {"content", "Bad encoding token"},
-                    {"tokens_predicted", 0}};
-                return res.set_content(data.dump(), "application/json");
-            }
-        } });
+
+        if (!llama.stream) {
+            while (llama.has_next_token) {
+                llama.doCompletion();
+            }
+
+            json data = {{"content", llama.generated_text},
+                         {"stop", true},
+                         {"model", llama.params.model_alias },
+                         {"tokens_predicted", llama.num_tokens_predicted},
+                         {"generation_settings", format_generation_settings(llama)},
+                         {"prompt", llama.params.prompt},
+                         {"stopping_word", llama.stopping_word}};
+
+            return res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json");
+        } else {
+            const auto chunked_content_provider = [&](size_t, DataSink &sink) {
+                while (llama.has_next_token) {
+                    std::string token_text = llama.doCompletion();
+
+                    json data;
+                    if (llama.has_next_token) {
+                        data = {{"content", token_text}, {"stop", false}};
+                    } else {
+                        // Generation is done, send extra information.
+                        data = {
+                            {"content", token_text},
+                            {"stop", true},
+                            {"model", llama.params.model_alias},
+                            {"tokens_predicted", llama.num_tokens_predicted},
+                            {"generation_settings", format_generation_settings(llama)},
+                            {"prompt", llama.params.prompt},
+                            {"stopping_word", llama.stopping_word},
+                            {"generated_text", llama.generated_text}};
+                    }
+
+                    std::string str =
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
+                    sink.write(str.data(), str.size());
+                }
+
+                sink.done();
+                return true;
+            };
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+        }
+    });

    svr.Post("/tokenize", [&llama](const Request &req, Response &res)
             {
@@ -774,72 +787,6 @@ int main(int argc, char **argv)
        return res.set_content(data.dump(), "application/json");
    });

-    svr.Get("/next-token", [&llama, &final_text](const Request &req, Response &res)
-            {
-                if(llama.params.embedding) {
-                    res.set_content("{}", "application/json");
-                    return;
-                }
-                std::string result = "";
-                if (req.has_param("stop")) {
-                    llama.has_next_token = false;
-                } else {
-                    result = llama.doCompletion(); // inference next token
-                }
-                try {
-                    json data;
-                    if (llama.has_next_token)
-                    {
-                        //fprintf(stdout, "Result: %s\n", result);
-                        final_text += result;
-                        data = {
-                            {"content", result },
-                            {"stop", false }
-                        };
-                    }
-                    else
-                    {
-                        // Generation is done, send extra information.
-                        data = {
-                            {"content", result },
-                            {"stop", true },
-                            {"tokens_predicted", llama.num_tokens_predicted},
-                            {"generation_settings", format_generation_settings(llama)},
-                            {"prompt", llama.params.prompt},
-                            {"stopping_word", llama.stopping_word},
-                            {"generated_text", final_text}
-                        };
-                    }
-                    return res.set_content(data.dump(), "application/json");
-                } catch (const json::exception &e) {
-                    // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-                    json data;
-                    if (llama.has_next_token)
-                    {
-                        final_text += u8"\uFFFD";
-                        data = {
-                            {"content", result },
-                            {"stop", false }
-                        };
-                    }
-                    else
-                    {
-                        // Generation is done, send extra information.
-                        data = {
-                            {"content", u8"\uFFFD" },
-                            {"stop", true },
-                            {"tokens_predicted", llama.num_tokens_predicted},
-                            {"generation_settings", format_generation_settings(llama)},
-                            {"prompt", llama.params.prompt},
-                            {"stopping_word", llama.stopping_word},
-                            {"generated_text", final_text}
-                        };
-                    }
-                    return res.set_content(data.dump(), "application/json");
-                }
-            });

    fprintf(stderr, "%s: http server Listening at http://%s:%i\n", __func__, sparams.hostname.c_str(), sparams.port);

    if(params.embedding) {