add streaming via server-sent events
Removes the /next-token endpoint and adds a "stream" parameter to the /completion endpoint.
parent 03ea8f013a
commit d6fff56e22
1 changed file with 80 additions and 133 deletions
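For orientation, here is a minimal sketch of a request body for the reworked endpoint, built with nlohmann::json as the server itself uses. Only the "stream" field is the interface introduced by this commit; the prompt and n_predict values are illustrative and not taken from the diff.

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Body for POST /completion. "stream": true selects the new
    // server-sent-events path added by this commit; "as_loop" is gone.
    json body = {
        {"prompt", "Building a website can be done in 10 simple steps:"}, // illustrative
        {"n_predict", 128},                                               // illustrative
        {"stream", true}
    };
    std::cout << body.dump(2) << std::endl;
    return 0;
}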
@@ -13,7 +13,7 @@ struct server_params

 struct llama_server_context
 {
-    bool as_loop = false;
+    bool stream = false;
     bool has_next_token = false;
     std::string generated_text = "";
@@ -35,7 +35,6 @@ struct llama_server_context
     std::string stopping_word;

     void rewind() {
-        as_loop = false;
         params.antiprompt.clear();
         num_tokens_predicted = 0;
         generated_text = "";
@@ -253,9 +252,6 @@ struct llama_server_context
         if (token == -1) {
             return "";
         }
-        if(as_loop) {
-            generated_text = "";
-        }

         std::string token_text = llama_token_to_str(ctx, token);
         generated_text += token_text;
@@ -270,7 +266,7 @@ struct llama_server_context
             }
         }

-        return generated_text;
+        return token_text;
     }

     std::vector<float> embedding(std::string content, int threads) {
@@ -478,9 +474,13 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para

 bool parse_options_completion(json body, llama_server_context& llama, Response &res) {
     gpt_params default_params;
-    if (!body["as_loop"].is_null())
+    if (!body["stream"].is_null())
     {
-        llama.as_loop = body["as_loop"].get<bool>();
+        llama.stream = body["stream"].get<bool>();
+    }
+    else
+    {
+        llama.stream = false;
     }
     if (!body["n_predict"].is_null())
     {
@@ -671,8 +671,6 @@ int main(int argc, char **argv)
     llama_server_context llama;
     params.model = "ggml-model.bin";

-    std::string final_text;
-
     if (server_params_parse(argc, argv, sparams, params) == false)
     {
         return 1;
@@ -689,65 +687,80 @@ int main(int argc, char **argv)
     svr.Get("/", [](const Request &, Response &res)
             { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });

-    svr.Post("/completion", [&llama, &final_text](const Request &req, Response &res)
-             {
-        if(llama.params.embedding) {
+    svr.Post("/completion", [&llama](const Request &req, Response &res) {
+        if (llama.params.embedding) {
             json data = {
                 {"status", "error"},
                 {"reason", "To use completion function, disable embedding mode"}};
-            res.set_content(data.dump(), "application/json");
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
             res.status = 400;
             return;
         }

         llama.rewind();
-        final_text = "";

-        if(parse_options_completion(json::parse(req.body), llama, res) == false){
+        if (parse_options_completion(json::parse(req.body), llama, res) == false) {
             return;
         }

-        if (!llama.loadPrompt())
-        {
-            json data = {
-                {"status", "error"},
-                {"reason", "Context too long."}};
-            res.set_content(data.dump(), "application/json");
+        if (!llama.loadPrompt()) {
+            json data = {{"status", "error"}, {"reason", "Context too long."}};
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
             res.status = 400;
             return;
         }

         llama.beginCompletion();
-        if(llama.as_loop) {
-            json data = {
-                {"status", "done" } };
-            return res.set_content(data.dump(), "application/json");
-        } else {
-            // loop inference until finish completion
-            while (llama.has_next_token)
-            {
+
+        if (!llama.stream) {
+            while (llama.has_next_token) {
                 llama.doCompletion();
             }
-            try
-            {
-                json data = {
+
+            json data = {{"content", llama.generated_text},
+                {"stop", true},
                 {"model", llama.params.model_alias },
-                {"content", llama.generated_text },
                 {"tokens_predicted", llama.num_tokens_predicted},
                 {"generation_settings", format_generation_settings(llama)},
                 {"prompt", llama.params.prompt},
-                {"stopping_word", llama.stopping_word} };
-            return res.set_content(data.dump(), "application/json");
+                {"stopping_word", llama.stopping_word}};
+            return res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json");
+        } else {
+            const auto chunked_content_provider = [&](size_t, DataSink &sink) {
+                while (llama.has_next_token) {
+                    std::string token_text = llama.doCompletion();
+
+                    json data;
+                    if (llama.has_next_token) {
+                        data = {{"content", token_text}, {"stop", false}};
+                    } else {
+                        // Generation is done, send extra information.
+                        data = {
+                            {"content", token_text},
+                            {"stop", true},
+                            {"model", llama.params.model_alias},
+                            {"tokens_predicted", llama.num_tokens_predicted},
+                            {"generation_settings", format_generation_settings(llama)},
+                            {"prompt", llama.params.prompt},
+                            {"stopping_word", llama.stopping_word},
+                            {"generated_text", llama.generated_text}};
                     }
-            catch (const json::exception &e)
-            {
-                // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-                json data = {
-                    {"content", "Bad encoding token"},
-                    {"tokens_predicted", 0}};
-                return res.set_content(data.dump(), "application/json");
+
+                    std::string str =
+                        "data: " +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
+                    sink.write(str.data(), str.size());
                 }
-        } });
+
+                sink.done();
+                return true;
+            };
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
+        }
+    });

     svr.Post("/tokenize", [&llama](const Request &req, Response &res)
              {
@@ -774,72 +787,6 @@ int main(int argc, char **argv)
         return res.set_content(data.dump(), "application/json");
     });

-    svr.Get("/next-token", [&llama, &final_text](const Request &req, Response &res)
-    {
-        if(llama.params.embedding) {
-            res.set_content("{}", "application/json");
-            return;
-        }
-        std::string result = "";
-        if (req.has_param("stop")) {
-            llama.has_next_token = false;
-        } else {
-            result = llama.doCompletion(); // inference next token
-        }
-        try {
-            json data;
-            if (llama.has_next_token)
-            {
-                //fprintf(stdout, "Result: %s\n", result);
-                final_text += result;
-                data = {
-                    {"content", result },
-                    {"stop", false }
-                };
-            }
-            else
-            {
-                // Generation is done, send extra information.
-                data = {
-                    {"content", result },
-                    {"stop", true },
-                    {"tokens_predicted", llama.num_tokens_predicted},
-                    {"generation_settings", format_generation_settings(llama)},
-                    {"prompt", llama.params.prompt},
-                    {"stopping_word", llama.stopping_word},
-                    {"generated_text", final_text}
-                };
-            }
-
-            return res.set_content(data.dump(), "application/json");
-        } catch (const json::exception &e) {
-            // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-            json data;
-            if (llama.has_next_token)
-            {
-                final_text += u8"\uFFFD";
-                data = {
-                    {"content", result },
-                    {"stop", false }
-                };
-            }
-            else
-            {
-                // Generation is done, send extra information.
-                data = {
-                    {"content", u8"\uFFFD" },
-                    {"stop", true },
-                    {"tokens_predicted", llama.num_tokens_predicted},
-                    {"generation_settings", format_generation_settings(llama)},
-                    {"prompt", llama.params.prompt},
-                    {"stopping_word", llama.stopping_word},
-                    {"generated_text", final_text}
-                };
-            }
-            return res.set_content(data.dump(), "application/json");
-        }
-    });
-
     fprintf(stderr, "%s: http server Listening at http://%s:%i\n", __func__, sparams.hostname.c_str(), sparams.port);

     if(params.embedding) {
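For reference, a rough sketch of how a client might consume the text/event-stream body produced by the new chunked content provider: each event is the literal prefix "data: ", one JSON object, and a blank line, with "stop": true (plus the extra fields) on the final event. The buffer below is hand-written stand-in data, not captured server output.

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Stand-in for a streamed response body; the real one arrives chunk by chunk.
    std::string body =
        "data: {\"content\": \"Hello\", \"stop\": false}\n\n"
        "data: {\"content\": \" world\", \"stop\": false}\n\n"
        "data: {\"content\": \"\", \"stop\": true, \"generated_text\": \"Hello world\"}\n\n";

    std::string completion;
    size_t pos = 0;
    while (true) {
        size_t end = body.find("\n\n", pos);       // events are separated by a blank line
        if (end == std::string::npos) break;
        std::string frame = body.substr(pos, end - pos);
        pos = end + 2;
        const std::string prefix = "data: ";
        if (frame.rfind(prefix, 0) != 0) continue;  // skip anything that is not a data frame
        json event = json::parse(frame.substr(prefix.size()));
        completion += event.value("content", "");
        if (event.value("stop", false)) break;      // final event carries the extra fields
    }
    std::cout << completion << std::endl;           // prints "Hello world"
    return 0;
}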