Initial timeout code and expanded json return on completion.

Now passing server params to the help printer so their defaults are output.
Bad UTF-8 while streaming now returns a replacement character (\uFFFD).
Changed some error language very slightly.
The JSON now returns extra values, only on `stop` for streaming requests.
New JSON return values (see the example below):
  - tokens_predicted (added to streaming)
  - seed (just pulled from params; may return -1)
  - prompt (might be useful)
  - generated_text (full generated response for streaming requests)
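
For illustration, a final streaming chunk (the one where `stop` is true) would now look something like this; the field values below are hypothetical:

    {
        "content": "",
        "stop": true,
        "tokens_predicted": 42,
        "seed": -1,
        "prompt": "Building a website can be done in",
        "generated_text": " ten simple steps. ..."
    }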
digiwombat 2023-05-28 07:44:31 -04:00
parent 177868e68a
commit e8efd75492

@@ -7,6 +7,8 @@ struct server_params
 {
     std::string hostname = "127.0.0.1";
     int32_t port = 8080;
+    int32_t read_timeout = 600;
+    int32_t write_timeout = 600;
 };

 struct llama_server_context
@@ -287,7 +289,7 @@ using namespace httplib;
 using json = nlohmann::json;

-void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
+void server_print_usage(int /*argc*/, char **argv, const gpt_params &params, const server_params &sparams)
 {
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
@@ -311,14 +313,16 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
-    fprintf(stderr, "  --host                ip address to listen (default 127.0.0.1)\n");
-    fprintf(stderr, "  --port PORT           port to listen (default 8080)\n");
+    fprintf(stderr, "  --host                ip address to listen (default: %s)\n", sparams.hostname.c_str());
+    fprintf(stderr, "  --port PORT           port to listen (default: %d)\n", sparams.port);
+    fprintf(stderr, "  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     fprintf(stderr, "\n");
 }

 bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
 {
     gpt_params default_params;
+    server_params default_sparams;

     std::string arg;
     bool invalid_param = false;
@@ -343,6 +347,15 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
         }
         sparams.hostname = argv[i];
     }
+    else if (arg == "--timeout" || arg == "-to")
+    {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+        sparams.read_timeout = std::stoi(argv[i]);
+        sparams.write_timeout = std::stoi(argv[i]);
+    }
     else if (arg == "-m" || arg == "--model")
     {
         if (++i >= argc)
@@ -358,7 +371,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
     }
     else if (arg == "-h" || arg == "--help")
     {
-        server_print_usage(argc, argv, default_params);
+        server_print_usage(argc, argv, default_params, default_sparams);
         exit(0);
     }
     else if (arg == "-c" || arg == "--ctx_size")
@@ -421,7 +434,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
     else
     {
         fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
-        server_print_usage(argc, argv, default_params);
+        server_print_usage(argc, argv, default_params, default_sparams);
         exit(1);
     }
 }
@@ -429,7 +442,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
     if (invalid_param)
     {
         fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
-        server_print_usage(argc, argv, default_params);
+        server_print_usage(argc, argv, default_params, default_sparams);
         exit(1);
     }
     return true;
@@ -538,18 +551,13 @@ int main(int argc, char **argv)
     llama_server_context llama;

     params.model = "ggml-model.bin";
+    std::string final_text = "";

     if (server_params_parse(argc, argv, sparams, params) == false)
     {
         return 1;
     }

-    if (params.seed <= 0)
-    {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
-
     // load the model
     if (!llama.loadModel(params))
     {
@@ -561,18 +569,19 @@ int main(int argc, char **argv)
     svr.Get("/", [](const Request &, Response &res)
             { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });

-    svr.Post("/completion", [&llama](const Request &req, Response &res)
+    svr.Post("/completion", [&llama, &final_text](const Request &req, Response &res)
             {
         if(llama.params.embedding) {
             json data = {
                 {"status", "error"},
-                {"reason", "To use completion function disable embedding mode"}};
+                {"reason", "To use completion function, disable embedding mode"}};
             res.set_content(data.dump(), "application/json");
             res.status = 400;
             return;
         }

         llama.rewind();
+        final_text = "";

         if(parse_options_completion(json::parse(req.body), llama, res) == false){
             return;
@@ -582,7 +591,7 @@ int main(int argc, char **argv)
         {
             json data = {
                 {"status", "error"},
-                {"reason", "Context too long, please be more specific"}};
+                {"reason", "Context too long."}};
             res.set_content(data.dump(), "application/json");
             res.status = 400;
             return;
@@ -603,7 +612,9 @@ int main(int argc, char **argv)
         {
             json data = {
                 {"content", llama.generated_text },
-                {"tokens_predicted", llama.num_tokens_predicted}};
+                {"tokens_predicted", llama.num_tokens_predicted},
+                {"seed", llama.params.seed},
+                {"prompt", llama.params.prompt} };
             return res.set_content(data.dump(), "application/json");
         }
         catch (const json::exception &e)
@@ -641,7 +652,7 @@ int main(int argc, char **argv)
         return res.set_content(data.dump(), "application/json");
     });

-    svr.Get("/next-token", [&llama](const Request &req, Response &res)
+    svr.Get("/next-token", [&llama, &final_text](const Request &req, Response &res)
             {
         if(llama.params.embedding) {
             res.set_content("{}", "application/json");
@@ -654,15 +665,52 @@ int main(int argc, char **argv)
             result = llama.doCompletion(); // inference next token
         }
         try {
-            json data = {
-                {"content", result },
-                {"stop", !llama.has_next_token }};
+            json data;
+            if (llama.has_next_token)
+            {
+                final_text += result;
+                data = {
+                    {"content", result },
+                    {"stop", false }
+                };
+            }
+            else
+            {
+                // Generation is done, send extra information.
+                data = {
+                    {"content", result },
+                    {"stop", true },
+                    {"tokens_predicted", llama.num_tokens_predicted},
+                    {"seed", llama.params.seed},
+                    {"prompt", llama.params.prompt},
+                    {"generated_text", final_text}
+                };
+            }
             return res.set_content(data.dump(), "application/json");
         } catch (const json::exception &e) {
             // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-            json data = {
-                {"content", "" },
-                {"stop", !llama.has_next_token }};
+            json data;
+            if (llama.has_next_token)
+            {
+                final_text += u8"\uFFFD";
+                data = {
+                    {"content", u8"\uFFFD" },
+                    {"stop", false }
+                };
+            }
+            else
+            {
+                // Generation is done, send extra information.
+                data = {
+                    {"content", u8"\uFFFD" },
+                    {"stop", true },
+                    {"tokens_predicted", llama.num_tokens_predicted},
+                    {"seed", llama.params.seed},
+                    {"prompt", llama.params.prompt},
+                    {"generated_text", final_text}
+                };
+            }
             return res.set_content(data.dump(), "application/json");
         }
     });
@@ -673,6 +721,9 @@ int main(int argc, char **argv)
         fprintf(stderr, "NOTE: Mode embedding enabled. Completion function doesn't work in this mode.\n");
     }

-    // change hostname and port
+    // set timeouts and change hostname and port
+    svr.set_read_timeout(sparams.read_timeout);
+    svr.set_write_timeout(sparams.write_timeout);
     svr.listen(sparams.hostname, sparams.port);
 }
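
For reference, the new timeout flag would be exercised with an invocation along these lines (the binary name is illustrative; the model path matches the default set above):

    ./server -m ggml-model.bin --host 127.0.0.1 --port 8080 -to 600

A streaming client would then POST to /completion and poll GET /next-token, receiving {"content": ..., "stop": false} chunks until the final chunk, where "stop" is true and tokens_predicted, seed, prompt, and generated_text are included.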