Initial timeout code and expanded json return on completion.
Now passing server params to the help printer so the defaults are output. Bad UTF-8 while streaming now returns a replacement character (\uFFFD). Changed some error language very slightly. The JSON now returns extra values, but only on `stop` for streaming requests.

New JSON return values:
- tokens_predicted (added to streaming)
- seed (just pulled from params, might return -1)
- prompt (might be useful)
- generated_text (full generated response for streaming requests)
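For illustration, the final chunk of a streamed response (the one where `stop` is true) would look roughly like this; the field values below are made up, not taken from an actual run:

    {
        "content": "",
        "stop": true,
        "tokens_predicted": 128,
        "seed": 42,
        "prompt": "Building a website can be done in 10 simple steps:",
        "generated_text": "Step 1: ..."
    }

Intermediate streamed chunks carry only `content` and `stop: false`; non-streaming completions return `content`, `tokens_predicted`, `seed`, and `prompt`.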
parent 177868e68a
commit e8efd75492
1 changed file with 75 additions and 24 deletions
@@ -7,6 +7,8 @@ struct server_params
 {
     std::string hostname = "127.0.0.1";
     int32_t port = 8080;
+    int32_t read_timeout = 600;
+    int32_t write_timeout = 600;
 };

 struct llama_server_context
@@ -287,7 +289,7 @@ using namespace httplib;

 using json = nlohmann::json;

-void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
+void server_print_usage(int /*argc*/, char **argv, const gpt_params &params, const server_params &sparams)
 {
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
@@ -311,14 +313,16 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
fprintf(stderr, " --host ip address to listen (default 127.0.0.1)\n");
|
fprintf(stderr, " --host ip address to listen (default (default: %d)\n", sparams.hostname);
|
||||||
fprintf(stderr, " --port PORT port to listen (default 8080)\n");
|
fprintf(stderr, " --port PORT port to listen (default (default: %d)\n", sparams.port);
|
||||||
+    fprintf(stderr, "  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     fprintf(stderr, "\n");
 }

 bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
 {
     gpt_params default_params;
+    server_params default_sparams;
     std::string arg;
     bool invalid_param = false;

@@ -343,6 +347,15 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
             }
             sparams.hostname = argv[i];
         }
+        else if (arg == "--timeout" || arg == "-to")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.read_timeout = std::stoi(argv[i]);
+            sparams.write_timeout = std::stoi(argv[i]);
+        }
         else if (arg == "-m" || arg == "--model")
         {
             if (++i >= argc)
@@ -358,7 +371,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
         }
         else if (arg == "-h" || arg == "--help")
         {
-            server_print_usage(argc, argv, default_params);
+            server_print_usage(argc, argv, default_params, default_sparams);
             exit(0);
         }
         else if (arg == "-c" || arg == "--ctx_size")
@@ -421,7 +434,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
         else
         {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
-            server_print_usage(argc, argv, default_params);
+            server_print_usage(argc, argv, default_params, default_sparams);
             exit(1);
         }
     }
@@ -429,7 +442,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
     if (invalid_param)
     {
         fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
-        server_print_usage(argc, argv, default_params);
+        server_print_usage(argc, argv, default_params, default_sparams);
         exit(1);
     }
     return true;
@@ -538,18 +551,13 @@ int main(int argc, char **argv)
     llama_server_context llama;
     params.model = "ggml-model.bin";

+    std::string final_text = "";
+
     if (server_params_parse(argc, argv, sparams, params) == false)
     {
         return 1;
     }

-    if (params.seed <= 0)
-    {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
-
     // load the model
     if (!llama.loadModel(params))
     {
@@ -561,18 +569,19 @@ int main(int argc, char **argv)
     svr.Get("/", [](const Request &, Response &res)
             { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });

-    svr.Post("/completion", [&llama](const Request &req, Response &res)
+    svr.Post("/completion", [&llama, &final_text](const Request &req, Response &res)
              {
         if(llama.params.embedding) {
             json data = {
                 {"status", "error"},
-                {"reason", "To use completion function disable embedding mode"}};
+                {"reason", "To use completion function, disable embedding mode"}};
             res.set_content(data.dump(), "application/json");
             res.status = 400;
             return;
         }

         llama.rewind();
+        final_text = "";

         if(parse_options_completion(json::parse(req.body), llama, res) == false){
             return;
@@ -582,7 +591,7 @@ int main(int argc, char **argv)
         {
             json data = {
                 {"status", "error"},
-                {"reason", "Context too long, please be more specific"}};
+                {"reason", "Context too long."}};
             res.set_content(data.dump(), "application/json");
             res.status = 400;
             return;
@@ -603,7 +612,9 @@ int main(int argc, char **argv)
         {
             json data = {
                 {"content", llama.generated_text },
-                {"tokens_predicted", llama.num_tokens_predicted}};
+                {"tokens_predicted", llama.num_tokens_predicted},
+                {"seed", llama.params.seed},
+                {"prompt", llama.params.prompt} };
             return res.set_content(data.dump(), "application/json");
         }
         catch (const json::exception &e)
@@ -641,7 +652,7 @@ int main(int argc, char **argv)
         return res.set_content(data.dump(), "application/json");
     });

-    svr.Get("/next-token", [&llama](const Request &req, Response &res)
+    svr.Get("/next-token", [&llama, &final_text](const Request &req, Response &res)
             {
         if(llama.params.embedding) {
             res.set_content("{}", "application/json");
@@ -654,15 +665,52 @@ int main(int argc, char **argv)
             result = llama.doCompletion(); // inference next token
         }
         try {
-            json data = {
+            json data;
+            if (llama.has_next_token)
+            {
+                final_text += result;
+                data = {
                     {"content", result },
-                    {"stop", !llama.has_next_token }};
+                    {"stop", false }
+                };
+            }
+            else
+            {
+                // Generation is done, send extra information.
+                data = {
+                    {"content", result },
+                    {"stop", true },
+                    {"tokens_predicted", llama.num_tokens_predicted},
+                    {"seed", llama.params.seed},
+                    {"prompt", llama.params.prompt},
+                    {"generated_text", final_text}
+                };
+            }
+
             return res.set_content(data.dump(), "application/json");
         } catch (const json::exception &e) {
             // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-            json data = {
-                {"content", "" },
-                {"stop", !llama.has_next_token }};
+            json data;
+            if (llama.has_next_token)
+            {
+                final_text += u8"\uFFFD";
+                data = {
+                    {"content", "\uFFFD" },
+                    {"stop", false }
+                };
+            }
+            else
+            {
+                // Generation is done, send extra information.
+                data = {
+                    {"content", "\uFFFD" },
+                    {"stop", true },
+                    {"tokens_predicted", llama.num_tokens_predicted},
+                    {"seed", llama.params.seed},
+                    {"prompt", llama.params.prompt},
+                    {"generated_text", final_text}
+                };
+            }
             return res.set_content(data.dump(), "application/json");
         }
     });
@@ -673,6 +721,9 @@ int main(int argc, char **argv)
         fprintf(stderr, "NOTE: Mode embedding enabled. Completion function doesn't work in this mode.\n");
     }

-    // change hostname and port
+    // set timeouts and change hostname and port
+    svr.set_read_timeout(sparams.read_timeout);
+    svr.set_write_timeout(sparams.write_timeout);
     svr.listen(sparams.hostname, sparams.port);
+
 }