diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 2ab532763..01bbd9236 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -7,6 +7,8 @@ struct server_params
 {
     std::string hostname = "127.0.0.1";
     int32_t port = 8080;
+    int32_t read_timeout = 600;
+    int32_t write_timeout = 600;
 };
 
 struct llama_server_context
@@ -287,7 +289,7 @@ using namespace httplib;
 
 using json = nlohmann::json;
 
-void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
+void server_print_usage(int /*argc*/, char **argv, const gpt_params &params, const server_params &sparams)
 {
     fprintf(stderr, "usage: %s [options]\n", argv[0]);
     fprintf(stderr, "\n");
@@ -311,14 +313,16 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
-    fprintf(stderr, "  --host                ip address to listen (default 127.0.0.1)\n");
-    fprintf(stderr, "  --port PORT           port to listen (default 8080)\n");
+    fprintf(stderr, "  --host                ip address to listen (default: %s)\n", sparams.hostname.c_str());
+    fprintf(stderr, "  --port PORT           port to listen (default: %d)\n", sparams.port);
+    fprintf(stderr, "  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     fprintf(stderr, "\n");
 }
 
 bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
 {
     gpt_params default_params;
+    server_params default_sparams;
     std::string arg;
     bool invalid_param = false;
 
@@ -343,6 +347,15 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
             }
             sparams.hostname = argv[i];
         }
+        else if (arg == "--timeout" || arg == "-to")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.read_timeout = std::stoi(argv[i]);
+            sparams.write_timeout = std::stoi(argv[i]);
+        }
         else if (arg == "-m" || arg == "--model")
         {
             if (++i >= argc)
@@ -358,7 +371,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
         }
         else if (arg == "-h" || arg == "--help")
         {
-            server_print_usage(argc, argv, default_params);
+            server_print_usage(argc, argv, default_params, default_sparams);
             exit(0);
         }
         else if (arg == "-c" || arg == "--ctx_size")
@@ -421,7 +434,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
         else
         {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
-            server_print_usage(argc, argv, default_params);
+            server_print_usage(argc, argv, default_params, default_sparams);
             exit(1);
         }
     }
@@ -429,7 +442,7 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
     if (invalid_param)
     {
         fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
-        server_print_usage(argc, argv, default_params);
+        server_print_usage(argc, argv, default_params, default_sparams);
         exit(1);
     }
     return true;
@@ -538,18 +551,13 @@ int main(int argc, char **argv)
     llama_server_context llama;
     params.model = "ggml-model.bin";
 
+    std::string final_text = "";
+
     if (server_params_parse(argc, argv, sparams, params) == false)
    {
         return 1;
     }
 
-    if (params.seed <= 0)
-    {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
-
     // load the model
     if (!llama.loadModel(params))
     {
@@ -561,18 +569,19 @@ int main(int argc, char **argv)
     svr.Get("/", [](const Request &, Response &res)
             { res.set_content("<h1>llama.cpp server works</h1>", "text/html"); });
 
-    svr.Post("/completion", [&llama](const Request &req, Response &res)
+    svr.Post("/completion", [&llama, &final_text](const Request &req, Response &res)
             {
         if(llama.params.embedding) {
             json data = {
                 {"status", "error"},
-                {"reason", "To use completion function disable embedding mode"}};
+                {"reason", "To use completion function, disable embedding mode"}};
             res.set_content(data.dump(), "application/json");
             res.status = 400;
             return;
         }
 
         llama.rewind();
+        final_text = "";
 
         if(parse_options_completion(json::parse(req.body), llama, res) == false){
             return;
@@ -582,7 +591,7 @@ int main(int argc, char **argv)
         {
             json data = {
                 {"status", "error"},
-                {"reason", "Context too long, please be more specific"}};
+                {"reason", "Context too long."}};
             res.set_content(data.dump(), "application/json");
             res.status = 400;
             return;
@@ -603,7 +612,9 @@ int main(int argc, char **argv)
         {
             json data = {
                 {"content", llama.generated_text },
-                {"tokens_predicted", llama.num_tokens_predicted}};
+                {"tokens_predicted", llama.num_tokens_predicted},
+                {"seed", llama.params.seed},
+                {"prompt", llama.params.prompt} };
             return res.set_content(data.dump(), "application/json");
         }
         catch (const json::exception &e)
@@ -641,7 +652,7 @@ int main(int argc, char **argv)
             {
         return res.set_content(data.dump(), "application/json");
     });
-    svr.Get("/next-token", [&llama](const Request &req, Response &res)
+    svr.Get("/next-token", [&llama, &final_text](const Request &req, Response &res)
             {
         if(llama.params.embedding) {
             res.set_content("{}", "application/json");
@@ -654,15 +665,52 @@ int main(int argc, char **argv)
         {
             result = llama.doCompletion(); // inference next token
         }
         try
         {
-            json data = {
+            json data;
+            if (llama.has_next_token)
+            {
+                final_text += result;
+                data = {
                 {"content", result },
-                {"stop", !llama.has_next_token }};
+                    {"stop", false }
+                };
+            }
+            else
+            {
+                // Generation is done, send extra information.
+                data = {
+                    {"content", result },
+                    {"stop", true },
+                    {"tokens_predicted", llama.num_tokens_predicted},
+                    {"seed", llama.params.seed},
+                    {"prompt", llama.params.prompt},
+                    {"generated_text", final_text}
+                };
+            }
+
             return res.set_content(data.dump(), "application/json");
         }
         catch (const json::exception &e)
         {
             // Some tokens have bad UTF-8 strings, the json parser is very sensitive
-            json data = {
-                {"content", "" },
-                {"stop", !llama.has_next_token }};
+            json data;
+            if (llama.has_next_token)
+            {
+                final_text += u8"\uFFFD";
+                data = {
+                    {"content", u8"\uFFFD" },
+                    {"stop", false }
+                };
+            }
+            else
+            {
+                // Generation is done, send extra information.
+                data = {
+                    {"content", u8"\uFFFD" },
+                    {"stop", true },
+                    {"tokens_predicted", llama.num_tokens_predicted},
+                    {"seed", llama.params.seed},
+                    {"prompt", llama.params.prompt},
+                    {"generated_text", final_text}
+                };
+            }
             return res.set_content(data.dump(), "application/json");
         }
     });
@@ -673,6 +721,9 @@ int main(int argc, char **argv)
     {
         fprintf(stderr, "NOTE: Mode embedding enabled. Completion function doesn't work in this mode.\n");
     }
-    // change hostname and port
+    // set timeouts and change hostname and port
+    svr.set_read_timeout(sparams.read_timeout);
+    svr.set_write_timeout(sparams.write_timeout);
     svr.listen(sparams.hostname, sparams.port);
+
 }