diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d677cdd1a..3a87f5116 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -24,7 +24,6 @@ struct server_params {
     int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
-    bool verbose = false;
 };
 
 static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
@@ -38,12 +37,13 @@ enum stop_type {
     STOP_PARTIAL,
 };
 
-bool ends_with(const std::string & str, const std::string & suffix) {
+static bool ends_with(const std::string & str, const std::string & suffix) {
     return str.size() >= suffix.size() &&
         0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }
 
-size_t find_partial_stop_string(const std::string & stop, const std::string & text) {
+static size_t find_partial_stop_string(const std::string & stop,
+                                       const std::string & text) {
     if (!text.empty() && !stop.empty()) {
         const char text_last_char = text.back();
         for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
@@ -67,9 +67,10 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
     return ret;
 }
 
-static void server_log(const char * level, const char * function, int line, const char * message, nlohmann::ordered_json extra) {
+static void server_log(const char * level, const char * function, int line,
+                       const char * message, const nlohmann::ordered_json & extra) {
     nlohmann::ordered_json log {
-        { "timestamp", time(NULL) },
+        { "timestamp", time(nullptr) },
        { "level", level },
        { "function", function },
        { "line", line },
@@ -80,19 +81,23 @@ static void server_log(const char * level, const char * function, int line, const char * message, nlohmann::ordered_json extra) {
         log.merge_patch(extra);
     }
 
-    std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
+    const std::string str = log.dump(-1, ' ', false, json::error_handler_t::replace);
     fprintf(stdout, "%.*s\n", (int)str.size(), str.data());
     fflush(stdout);
 }
 
 static bool server_verbose = false;
 
-#define LOG_VERBOSE(MSG, ...) \
+#if SERVER_VERBOSE != 1
+# define LOG_VERBOSE(MSG, ...)
+#else
+# define LOG_VERBOSE(MSG, ...) \
     do { \
         if (server_verbose) { \
             server_log("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
         } \
     } while(0)
+#endif
 
 #define LOG_ERROR(MSG, ...) server_log("ERROR", __func__, __LINE__, MSG, __VA_ARGS__)
 #define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
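For reference, a minimal, self-contained sketch of the two-level logging gate introduced above: the preprocessor removes LOG_VERBOSE call sites entirely unless the translation unit is built with SERVER_VERBOSE=1, and the server_verbose flag then gates output at runtime. LOG_DEMO and demo_verbose are illustrative stand-ins, not names from the patch.

#include <cstdio>

#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 0    // assume the flag is passed as -DSERVER_VERBOSE=1 at build time
#endif

static bool demo_verbose = false;

#if SERVER_VERBOSE != 1
#define LOG_DEMO(MSG)       // compiled out: call sites cost nothing
#else
#define LOG_DEMO(MSG) \
    do { \
        if (demo_verbose) { \
            fprintf(stdout, "VERBOSE %s:%d %s\n", __func__, __LINE__, MSG); \
        } \
    } while (0)
#endif

int main() {
    demo_verbose = true;    // runtime opt-in, like -v/--verbose below
    LOG_DEMO("only printed when built with -DSERVER_VERBOSE=1");
    return 0;
}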
server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__) @@ -101,7 +106,7 @@ static bool server_verbose = false; struct llama_server_context { bool stream = false; bool has_next_token = false; - std::string generated_text = ""; + std::string generated_text; size_t num_tokens_predicted = 0; size_t n_past = 0; @@ -118,8 +123,6 @@ struct llama_server_context { bool stopped_word = false; bool stopped_limit = false; std::string stopping_word; - - int json_indent = -1; int32_t multibyte_pending = 0; ~llama_server_context() { @@ -148,7 +151,7 @@ struct llama_server_context { bool loadModel(const gpt_params & params_) { params = params_; ctx = llama_init_from_gpt_params(params); - if (ctx == NULL) { + if (ctx == nullptr) { LOG_ERROR("unable to load model", { { "model", params_.model } }); return false; } @@ -265,7 +268,9 @@ struct llama_server_context { const float mirostat_tau = params.mirostat_tau; const float mirostat_eta = params.mirostat_eta; const bool penalize_nl = params.penalize_nl; - llama_token id = 0; { + llama_token id = 0; + + { auto * logits = llama_get_logits(ctx); auto n_vocab = llama_n_vocab(ctx); @@ -342,7 +347,7 @@ struct llama_server_context { } size_t findStoppingStrings(const std::string & text, const size_t last_token_size, - const stop_type type) { + const stop_type type) { size_t stop_pos = std::string::npos; for (const std::string & word : params.antiprompt) { size_t pos; @@ -368,9 +373,9 @@ struct llama_server_context { } std::string doCompletion() { - llama_token token = nextToken(); + const llama_token token = nextToken(); - std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token); + const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token); generated_text += token_text; if (multibyte_pending > 0) { @@ -380,10 +385,10 @@ struct llama_server_context { // 2-byte characters: 110xxxxx 10xxxxxx if ((c & 0xE0) == 0xC0) { multibyte_pending = 1; - // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx + // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx } else if ((c & 0xF0) == 0xE0) { multibyte_pending = 2; - // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx } else if ((c & 0xF8) == 0xF0) { multibyte_pending = 3; } else { @@ -416,12 +421,13 @@ struct llama_server_context { } }; -void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) { +static void server_print_usage(const char * argv0, const gpt_params & params, + const server_params & sparams) { fprintf(stderr, "usage: %s [options]\n", argv0); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " -v, --verbose verbose output (default: %s)\n", sparams.verbose ? "enabled" : "disabled"); + fprintf(stderr, " -v, --verbose verbose output (default: %s)\n", server_verbose ? 
"enabled" : "disabled"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); @@ -453,8 +459,8 @@ void server_print_usage(const char * argv0, const gpt_params & params, const ser fprintf(stderr, "\n"); } -void server_params_parse(int argc, char ** argv, server_params & sparams, - gpt_params & params) { +static void server_params_parse(int argc, char ** argv, server_params & sparams, + gpt_params & params) { gpt_params default_params; server_params default_sparams; std::string arg; @@ -543,12 +549,12 @@ void server_params_parse(int argc, char ** argv, server_params & sparams, std::vector split_arg{ it, {} }; GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES); - for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) { - if (i < split_arg.size()) { - params.tensor_split[i] = std::stof(split_arg[i]); + for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device) { + if (i_device < split_arg.size()) { + params.tensor_split[i_device] = std::stof(split_arg[i_device]); } else { - params.tensor_split[i] = 0.0f; + params.tensor_split[i_device] = 0.0f; } } #else @@ -579,9 +585,10 @@ void server_params_parse(int argc, char ** argv, server_params & sparams, } params.lora_base = argv[i]; } else if (arg == "-v" || arg == "--verbose") { - sparams.verbose = true; #if SERVER_VERBOSE != 1 LOG_WARNING("server.cpp is not built with verbose logging.", {}); +#else + server_verbose = true; #endif } else if (arg == "--mlock") { params.use_mlock = true; @@ -659,7 +666,7 @@ static json format_tokenizer_response(const std::vector & tokens) { }; } -bool parse_options_completion(json body, llama_server_context & llama) { +static void parse_options_completion(const json & body, llama_server_context & llama) { gpt_params default_params; llama.stream = body.value("stream", false); @@ -687,7 +694,7 @@ bool parse_options_completion(json body, llama_server_context & llama) { } if (body["logit_bias"].is_array()) { - int n_vocab = llama_n_vocab(llama.ctx); + const int n_vocab = llama_n_vocab(llama.ctx); for (const auto & el : body["logit_bias"]) { if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) { llama_token tok = el[0].get(); @@ -711,8 +718,6 @@ bool parse_options_completion(json body, llama_server_context & llama) { } LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); - - return true; } static void log_server_request(const Request & req, const Response & res) { @@ -736,10 +741,6 @@ int main(int argc, char ** argv) { server_params_parse(argc, argv, sparams, params); -#if SERVER_VERBOSE == 1 - server_verbose = sparams.verbose; -#endif - if (params.model_alias == "unknown") { params.model_alias = params.model; } @@ -773,13 +774,10 @@ int main(int argc, char ** argv) { }); svr.Post("/completion", [&llama](const Request & req, Response & res) { - llama.rewind(); llama_reset_timings(llama.ctx); - if (!parse_options_completion(json::parse(req.body), llama)) { - return; - } + parse_options_completion(json::parse(req.body), llama); llama.loadPrompt(); llama.beginCompletion(); @@ -802,15 +800,13 @@ int main(int argc, char ** argv) { llama.generated_text.end()); } - json data = format_final_response(llama, llama.generated_text); + const json data = format_final_response(llama, llama.generated_text); 
@@ -802,15 +800,13 @@ int main(int argc, char ** argv) {
                                        llama.generated_text.end());
             }
 
-            json data = format_final_response(llama, llama.generated_text);
+            const json data = format_final_response(llama, llama.generated_text);
 
             llama_print_timings(llama.ctx);
 
-            res.set_content(
-                data.dump(llama.json_indent, ' ', false, json::error_handler_t::replace),
-                "application/json");
-        }
-        else {
+            res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
+                            "application/json");
+        } else {
             const auto chunked_content_provider = [&](size_t, DataSink & sink) {
                 size_t sent_count = 0;
 
@@ -822,7 +818,7 @@ int main(int argc, char ** argv) {
 
                     size_t pos = std::min(sent_count, llama.generated_text.size());
 
-                    const char* str_test = llama.generated_text.c_str() + pos;
+                    const std::string str_test = llama.generated_text.substr(pos);
                     size_t stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
                                                                 STOP_FULL);
                     if (stop_pos != std::string::npos) {
@@ -835,21 +831,17 @@ int main(int argc, char ** argv) {
                                                                 STOP_PARTIAL);
                     }
 
-                    std::string to_send = llama.generated_text.substr(pos, stop_pos);
+                    const std::string to_send = llama.generated_text.substr(pos, stop_pos);
                     sent_count += to_send.size();
 
-                    json data;
-                    if (llama.has_next_token) {
-                        data = format_partial_response(to_send);
-                    } else {
-                        // Generation is done, send extra information.
-                        data = format_final_response(llama, to_send);
-                    }
+                    const json data = llama.has_next_token
+                                          ? format_partial_response(to_send)
+                                          // Generation is done, send extra information.
+                                          : format_final_response(llama, to_send);
 
-                    std::string str =
+                    const std::string str =
                         "data: " +
-                        data.dump(llama.has_next_token ? -1 : llama.json_indent, ' ', false,
-                                  json::error_handler_t::replace) +
+                        data.dump(-1, ' ', false, json::error_handler_t::replace) +
                         "\n\n";
 
                     LOG_VERBOSE("data stream", {
@@ -876,11 +868,11 @@ int main(int argc, char ** argv) {
     });
 
     svr.Post("/tokenize", [&llama](const Request & req, Response & res) {
-        json body = json::parse(req.body);
-        std::string content = body["content"].get<std::string>();
-        std::vector<llama_token> tokens = ::llama_tokenize(llama.ctx, content, false);
-        json data = format_tokenizer_response(tokens);
-        return res.set_content(data.dump(llama.json_indent), "application/json");
+        const json body = json::parse(req.body);
+        const std::string content = body["content"].get<std::string>();
+        const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
+        const json data = format_tokenizer_response(tokens);
+        return res.set_content(data.dump(), "application/json");
     });
 
     svr.set_logger(log_server_request);
@@ -890,14 +882,14 @@ int main(int argc, char ** argv) {
         char buf[BUFSIZ];
         try {
             std::rethrow_exception(std::move(ep));
-        } catch (std::exception& e) {
+        } catch (std::exception & e) {
             snprintf(buf, sizeof(buf), fmt, e.what());
         } catch (...) {
             snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
         }
         res.set_content(buf, "text/plain");
         res.status = 500;
-        });
+    });
 
     // set timeouts and change hostname and port
     svr.set_read_timeout(sparams.read_timeout);
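The "data: ...\n\n" string assembled in the streaming branch is server-sent-events framing, which is what clients split the chunked response on. A minimal sketch of the frame layout; the hand-written JSON payloads below stand in for what the server actually produces with nlohmann::json dump():

#include <cstdio>
#include <string>

static std::string sse_frame(const std::string & payload_json) {
    // one SSE event: "data: " prefix, the serialized payload, then a blank line
    return "data: " + payload_json + "\n\n";
}

int main() {
    fputs(sse_frame("{\"content\":\"Hel\",\"stop\":false}").c_str(), stdout);
    fputs(sse_frame("{\"content\":\"lo\",\"stop\":true}").c_str(), stdout);
    return 0;
}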