Changed to a single API endpoint for streaming and non-streaming completion.
The /next-token endpoint was removed, and the "as_loop" setting was renamed to "streaming".
parent 03ea8f013a
commit 3292f057dc
1 changed file with 45 additions and 75 deletions
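For context (not part of the commit), here is a minimal sketch of the request body a client would now send to the single completion endpoint, using the nlohmann::json library the server already uses in the hunks below. The field names "streaming" and "n_predict" come from the parse_options_completion() hunk; the prompt text and the exact endpoint path are assumptions for illustration only.

    // Hypothetical client-side construction of the unified request body.
    // Where the old API set {"as_loop": true} and then polled /next-token,
    // the client now asks for streaming on the one completion request.
    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        json body = {
            {"prompt", "Building a website can be done in 10 simple steps:"}, // example prompt (assumption)
            {"n_predict", 128},   // handled by parse_options_completion()
            {"streaming", true}   // replaces the removed "as_loop" flag
        };
        std::cout << body.dump(4) << std::endl; // send this as the POST body
        return 0;
    }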
@@ -13,7 +13,7 @@ struct server_params
 struct llama_server_context
 {
-    bool as_loop = false;
+    bool streaming = false;
     bool has_next_token = false;
     std::string generated_text = "";
@@ -35,7 +35,7 @@ struct llama_server_context
     std::string stopping_word;

     void rewind() {
-        as_loop = false;
+        streaming = false;
         params.antiprompt.clear();
         num_tokens_predicted = 0;
         generated_text = "";
@@ -253,7 +253,7 @@ struct llama_server_context
         if (token == -1) {
             return "";
         }
-        if(as_loop) {
+        if(streaming) {
             generated_text = "";
         }
@@ -478,9 +478,9 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para

 bool parse_options_completion(json body, llama_server_context& llama, Response &res) {
     gpt_params default_params;
-    if (!body["as_loop"].is_null())
+    if (!body["streaming"].is_null())
     {
-        llama.as_loop = body["as_loop"].get<bool>();
+        llama.streaming = body["streaming"].get<bool>();
     }
     if (!body["n_predict"].is_null())
     {
@@ -718,11 +718,46 @@ int main(int argc, char **argv)
         }

         llama.beginCompletion();
-        if(llama.as_loop) {
-            json data = {
-                {"status", "done" } };
-            return res.set_content(data.dump(), "application/json");
-        } else {
+        if(llama.streaming)
+        {
+            fprintf(stdout, "In streaming\n");
+            res.set_chunked_content_provider("text/event-stream", [&](size_t /*offset*/,
+                                                                       DataSink& sink) {
+                std::string final_text = "";
+                // loop inference until finish completion
+                while (llama.has_next_token) {
+                    std::string result = llama.doCompletion();
+                    json data;
+                    final_text += result;
+                    fprintf(stdout, "Result: %s\n", result);
+                    if (llama.has_next_token)
+                    {
+                        data = { {"content", result}, {"stop", false} };
+                    }
+                    else
+                    {
+                        // Generation is done, send extra information.
+                        data = { {"content", result},
+                                 {"stop", true},
+                                 {"tokens_predicted", llama.num_tokens_predicted},
+                                 {"generation_settings", format_generation_settings(llama)},
+                                 {"prompt", llama.params.prompt},
+                                 {"stopping_word", llama.stopping_word},
+                                 {"generated_text", final_text} };
+                    }
+
+                    std::string str =
+                        "data: " + data.dump(4, ' ', false, json::error_handler_t::replace) +
+                        "\n\n";
+                    sink.write(str.data(), str.size());
+                }
+
+                sink.done();
+                return true;
+            });
+        }
+        else
+        {
             // loop inference until finish completion
             while (llama.has_next_token)
             {
@@ -774,71 +809,6 @@ int main(int argc, char **argv)
             return res.set_content(data.dump(), "application/json");
         });

-    svr.Get("/next-token", [&llama, &final_text](const Request &req, Response &res)
-            {
-                if(llama.params.embedding) {
-                    res.set_content("{}", "application/json");
-                    return;
-                }
-                std::string result = "";
-                if (req.has_param("stop")) {
-                    llama.has_next_token = false;
-                } else {
-                    result = llama.doCompletion(); // inference next token
-                }
-                try {
-                    json data;
-                    if (llama.has_next_token)
-                    {
-                        //fprintf(stdout, "Result: %s\n", result);
-                        final_text += result;
-                        data = {
-                            {"content", result },
-                            {"stop", false }
-                        };
-                    }
-                    else
-                    {
-                        // Generation is done, send extra information.
-                        data = {
-                            {"content", result },
-                            {"stop", true },
-                            {"tokens_predicted", llama.num_tokens_predicted},
-                            {"generation_settings", format_generation_settings(llama)},
-                            {"prompt", llama.params.prompt},
-                            {"stopping_word", llama.stopping_word},
-                            {"generated_text", final_text}
-                        };
-                    }
-
-                    return res.set_content(data.dump(), "application/json");
-                } catch (const json::exception &e) {
-                    // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-                    json data;
-                    if (llama.has_next_token)
-                    {
-                        final_text += u8"\uFFFD";
-                        data = {
-                            {"content", result },
-                            {"stop", false }
-                        };
-                    }
-                    else
-                    {
-                        // Generation is done, send extra information.
-                        data = {
-                            {"content", u8"\uFFFD" },
-                            {"stop", true },
-                            {"tokens_predicted", llama.num_tokens_predicted},
-                            {"generation_settings", format_generation_settings(llama)},
-                            {"prompt", llama.params.prompt},
-                            {"stopping_word", llama.stopping_word},
-                            {"generated_text", final_text}
-                        };
-                    }
-                    return res.set_content(data.dump(), "application/json");
-                }
-            });

     fprintf(stderr, "%s: http server Listening at http://%s:%i\n", __func__, sparams.hostname.c_str(), sparams.port);
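As a rough illustration of consuming the stream produced by the new chunked content provider above (not part of the commit): each event is a frame of the form "data: <json>\n\n", and the final event carries "stop": true plus the extra fields shown in the diff. The frame string below is made up, and a real client must also read the chunked HTTP response and cope with invalid UTF-8 sequences the way the removed /next-token handler did.

    // Sketch: decode one server-sent event frame with nlohmann::json.
    #include <iostream>
    #include <string>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // Example frame as emitted by the streaming branch: "data: " + JSON + "\n\n".
        std::string frame = "data: {\"content\": \" Hello\", \"stop\": false}\n\n";

        const std::string prefix = "data: ";
        if (frame.rfind(prefix, 0) == 0) {                        // starts with "data: "?
            json data = json::parse(frame.substr(prefix.size()));
            if (data["stop"].get<bool>()) {
                // Final event: also contains tokens_predicted, generation_settings,
                // prompt, stopping_word and generated_text.
                std::cout << "generation finished" << std::endl;
            } else {
                std::cout << data["content"].get<std::string>();  // partial token text
            }
        }
        return 0;
    }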