Initial timeout code and expanded json return on completion.
Now passing server params to the help printer so the defaults are output. Bad UTF-8 while streaming now returns a replacement character (\uFFFD). Changed some error language very slightly. The JSON now returns extra values, only on `stop` for streaming requests. New JSON Return Values: - tokens_predicted (added to streaming) - seed (just pulls it from params, might return -1) - prompt (Might be useful) - generated_text (Full generated response for streaming requests)
This commit is contained in:
parent
177868e68a
commit
e8efd75492
1 changed files with 75 additions and 24 deletions
|
@ -7,6 +7,8 @@ struct server_params
|
|||
{
|
||||
std::string hostname = "127.0.0.1";
|
||||
int32_t port = 8080;
|
||||
int32_t read_timeout = 600;
|
||||
int32_t write_timeout = 600;
|
||||
};
|
||||
|
||||
struct llama_server_context
|
||||
|
@ -287,7 +289,7 @@ using namespace httplib;
|
|||
|
||||
using json = nlohmann::json;
|
||||
|
||||
void server_print_usage(int /*argc*/, char **argv, const gpt_params ¶ms)
|
||||
void server_print_usage(int /*argc*/, char **argv, const gpt_params ¶ms, const server_params &sparams)
|
||||
{
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
|
@ -311,14 +313,16 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params ¶ms)
|
|||
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
||||
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
|
||||
fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
|
||||
fprintf(stderr, " --host ip address to listen (default 127.0.0.1)\n");
|
||||
fprintf(stderr, " --port PORT port to listen (default 8080)\n");
|
||||
fprintf(stderr, " --host ip address to listen (default (default: %d)\n", sparams.hostname);
|
||||
fprintf(stderr, " --port PORT port to listen (default: %d)\n", sparams.port); // fixed doubled "(default (default:" and unbalanced parenthesis in usage text
|
||||
fprintf(stderr, " -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params ¶ms)
|
||||
{
|
||||
gpt_params default_params;
|
||||
server_params default_sparams;
|
||||
std::string arg;
|
||||
bool invalid_param = false;
|
||||
|
||||
|
@ -343,6 +347,15 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
|
|||
}
|
||||
sparams.hostname = argv[i];
|
||||
}
|
||||
else if (arg == "--timeout" || arg == "-to")
|
||||
{
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.read_timeout = std::stoi(argv[i]);
|
||||
sparams.write_timeout = std::stoi(argv[i]);
|
||||
}
|
||||
else if (arg == "-m" || arg == "--model")
|
||||
{
|
||||
if (++i >= argc)
|
||||
|
@ -358,7 +371,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
|
|||
}
|
||||
else if (arg == "-h" || arg == "--help")
|
||||
{
|
||||
server_print_usage(argc, argv, default_params);
|
||||
server_print_usage(argc, argv, default_params, default_sparams);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-c" || arg == "--ctx_size")
|
||||
|
@ -421,7 +434,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
|
|||
else
|
||||
{
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
server_print_usage(argc, argv, default_params);
|
||||
server_print_usage(argc, argv, default_params, default_sparams);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
@ -429,7 +442,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
|
|||
if (invalid_param)
|
||||
{
|
||||
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
|
||||
server_print_usage(argc, argv, default_params);
|
||||
server_print_usage(argc, argv, default_params, default_sparams);
|
||||
exit(1);
|
||||
}
|
||||
return true;
|
||||
|
@ -538,18 +551,13 @@ int main(int argc, char **argv)
|
|||
llama_server_context llama;
|
||||
params.model = "ggml-model.bin";
|
||||
|
||||
std::string final_text = "";
|
||||
|
||||
if (server_params_parse(argc, argv, sparams, params) == false)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.seed <= 0)
|
||||
{
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
|
||||
|
||||
// load the model
|
||||
if (!llama.loadModel(params))
|
||||
{
|
||||
|
@ -561,18 +569,19 @@ int main(int argc, char **argv)
|
|||
svr.Get("/", [](const Request &, Response &res)
|
||||
{ res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });
|
||||
|
||||
svr.Post("/completion", [&llama](const Request &req, Response &res)
|
||||
svr.Post("/completion", [&llama, &final_text](const Request &req, Response &res)
|
||||
{
|
||||
if(llama.params.embedding) {
|
||||
json data = {
|
||||
{"status", "error"},
|
||||
{"reason", "To use completion function disable embedding mode"}};
|
||||
{"reason", "To use completion function, disable embedding mode"}};
|
||||
res.set_content(data.dump(), "application/json");
|
||||
res.status = 400;
|
||||
return;
|
||||
}
|
||||
|
||||
llama.rewind();
|
||||
final_text = "";
|
||||
|
||||
if(parse_options_completion(json::parse(req.body), llama, res) == false){
|
||||
return;
|
||||
|
@ -582,7 +591,7 @@ int main(int argc, char **argv)
|
|||
{
|
||||
json data = {
|
||||
{"status", "error"},
|
||||
{"reason", "Context too long, please be more specific"}};
|
||||
{"reason", "Context too long."}};
|
||||
res.set_content(data.dump(), "application/json");
|
||||
res.status = 400;
|
||||
return;
|
||||
|
@ -603,7 +612,9 @@ int main(int argc, char **argv)
|
|||
{
|
||||
json data = {
|
||||
{"content", llama.generated_text },
|
||||
{"tokens_predicted", llama.num_tokens_predicted}};
|
||||
{"tokens_predicted", llama.num_tokens_predicted},
|
||||
{"seed", llama.params.seed},
|
||||
{"prompt", llama.params.prompt} };
|
||||
return res.set_content(data.dump(), "application/json");
|
||||
}
|
||||
catch (const json::exception &e)
|
||||
|
@ -641,7 +652,7 @@ int main(int argc, char **argv)
|
|||
return res.set_content(data.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/next-token", [&llama](const Request &req, Response &res)
|
||||
svr.Get("/next-token", [&llama, &final_text](const Request &req, Response &res)
|
||||
{
|
||||
if(llama.params.embedding) {
|
||||
res.set_content("{}", "application/json");
|
||||
|
@ -654,15 +665,52 @@ int main(int argc, char **argv)
|
|||
result = llama.doCompletion(); // inference next token
|
||||
}
|
||||
try {
|
||||
json data = {
|
||||
json data;
|
||||
if (llama.has_next_token)
|
||||
{
|
||||
final_text += result;
|
||||
data = {
|
||||
{"content", result },
|
||||
{"stop", !llama.has_next_token }};
|
||||
{"stop", false }
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
// Generation is done, send extra information.
|
||||
data = {
|
||||
{"content", result },
|
||||
{"stop", true },
|
||||
{"tokens_predicted", llama.num_tokens_predicted},
|
||||
{"seed", llama.params.seed},
|
||||
{"prompt", llama.params.prompt},
|
||||
{"generated_text", final_text}
|
||||
};
|
||||
}
|
||||
|
||||
return res.set_content(data.dump(), "application/json");
|
||||
} catch (const json::exception &e) {
|
||||
// Some tokens have bad UTF-8 strings, the json parser is very sensitive
|
||||
json data = {
|
||||
{"content", "" },
|
||||
{"stop", !llama.has_next_token }};
|
||||
json data;
|
||||
if (llama.has_next_token)
|
||||
{
|
||||
final_text += u8"\uFFFD";
|
||||
data = {
|
||||
{"content", u8"\uFFFD" }, // must NOT re-insert `result` here: its invalid UTF-8 is what threw; dump() would throw again, uncaught
|
||||
{"stop", false }
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
// Generation is done, send extra information.
|
||||
data = {
|
||||
{"content", "\uFFFD" },
|
||||
{"stop", true },
|
||||
{"tokens_predicted", llama.num_tokens_predicted},
|
||||
{"seed", llama.params.seed},
|
||||
{"prompt", llama.params.prompt},
|
||||
{"generated_text", final_text}
|
||||
};
|
||||
}
|
||||
return res.set_content(data.dump(), "application/json");
|
||||
}
|
||||
});
|
||||
|
@ -673,6 +721,9 @@ int main(int argc, char **argv)
|
|||
fprintf(stderr, "NOTE: Mode embedding enabled. Completion function doesn't work in this mode.\n");
|
||||
}
|
||||
|
||||
// change hostname and port
|
||||
// set timeouts and change hostname and port
|
||||
svr.set_read_timeout(sparams.read_timeout);
|
||||
svr.set_write_timeout(sparams.write_timeout);
|
||||
svr.listen(sparams.hostname, sparams.port);
|
||||
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue