From 13cf6929b7c979b83a755f529fb14434734ce58a Mon Sep 17 00:00:00 2001
From: Henri Vasserman
Date: Mon, 12 Jun 2023 17:29:25 +0300
Subject: [PATCH] more json changes and stop info

---
 examples/server/server.cpp | 132 ++++++++++++++++++++++---------------
 1 file changed, 79 insertions(+), 53 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 24fffbc14..8c02dd977 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -105,6 +105,10 @@ struct llama_server_context {
     llama_context * ctx = nullptr;
     gpt_params params;
 
+    bool truncated = false;
+    bool stopped_eos = false;
+    bool stopped_word = false;
+    bool stopped_limit = false;
     std::string stopping_word;
 
     int json_indent = -1;
@@ -122,6 +126,10 @@ struct llama_server_context {
         num_tokens_predicted = 0;
         generated_text = "";
         generated_text.reserve(params.n_ctx);
+        truncated = false;
+        stopped_eos = false;
+        stopped_word = false;
+        stopped_limit = false;
         stopping_word = "";
         multibyte_pending = 0;
 
@@ -166,6 +174,7 @@ struct llama_server_context {
                 { "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
             });
 
+            truncated = true;
            prompt_tokens = new_tokens;
         } else {
             const size_t ps = prompt_tokens.size();
@@ -207,14 +216,13 @@ struct llama_server_context {
             new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
             embd = new_tokens;
             n_past = params.n_keep;
-            if (server_verbose) {
-                LOG_VERBOSE("input truncated", {
-                    { "n_ctx", params.n_ctx },
-                    { "n_keep", params.n_keep },
-                    { "n_left", n_left },
-                    { "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
-                });
-            }
+            truncated = true;
+            LOG_VERBOSE("input truncated", {
+                { "n_ctx", params.n_ctx },
+                { "n_keep", params.n_keep },
+                { "n_left", n_left },
+                { "new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend()) },
+            });
         }
 
         while (n_past < embd.size()) {
@@ -314,8 +322,9 @@ struct llama_server_context {
         --n_remain;
 
         if (!embd.empty() && embd.back() == llama_token_eos()) {
-            stopping_word = llama_token_to_str(ctx, embd.back());
+            //stopping_word = llama_token_to_str(ctx, embd.back());
             has_next_token = false;
+            stopped_eos = true;
             LOG_VERBOSE("eos token found", {});
             return result;
         }
@@ -341,6 +350,7 @@ struct llama_server_context {
                 (stop_pos == std::string::npos || pos < stop_pos)) {
                 if (type == STOP_FULL) {
                     stopping_word = word;
+                    stopped_word = true;
                     has_next_token = false;
                 }
                 stop_pos = pos;
@@ -378,17 +388,22 @@ struct llama_server_context {
             n_remain++;
         }
 
-        if (server_verbose) {
-            LOG_VERBOSE("next token", {
-                { "token", token },
-                { "token_text", llama_token_to_str(ctx, token) },
-                { "has_next_token", has_next_token },
-                { "n_remain", n_remain },
-                { "num_tokens_predicted", num_tokens_predicted },
-                { "stopping_word", stopping_word },
-            });
+        if (!has_next_token && n_remain == 0) {
+            stopped_limit = true;
         }
 
+        LOG_VERBOSE("next token", {
+            { "token", token },
+            { "token_text", llama_token_to_str(ctx, token) },
+            { "has_next_token", has_next_token },
+            { "n_remain", n_remain },
+            { "num_tokens_predicted", num_tokens_predicted },
+            { "stopped_eos", stopped_eos },
+            { "stopped_word", stopped_word },
+            { "stopped_limit", stopped_limit },
+            { "stopping_word", stopping_word },
+        });
+
         return token_text;
     }
 };
@@ -578,7 +593,7 @@ void server_params_parse(int argc, char ** argv, server_params & sparams,
     }
 }
 
-json format_generation_settings(llama_server_context & llama) {
+static json format_generation_settings(llama_server_context & llama) {
     const auto eos_bias = llama.params.logit_bias.find(llama_token_eos());
     const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@@ -607,6 +622,35 @@ json format_generation_settings(llama_server_context & llama) {
     };
 }
 
+static json format_final_response(llama_server_context & llama, const std::string & content) {
+    return json {
+        { "content", content },
+        { "stop", true },
+        { "model", llama.params.model_alias },
+        { "tokens_predicted", llama.num_tokens_predicted },
+        { "generation_settings", format_generation_settings(llama) },
+        { "prompt", llama.params.prompt },
+        { "truncated", llama.truncated },
+        { "stopped_eos", llama.stopped_eos },
+        { "stopped_word", llama.stopped_word },
+        { "stopped_limit", llama.stopped_limit },
+        { "stopping_word", llama.stopping_word },
+    };
+}
+
+static json format_partial_response(const std::string & content) {
+    return json {
+        { "content", content },
+        { "stop", false },
+    };
+}
+
+static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
+    return json {
+        { "tokens", tokens }
+    };
+}
+
 bool parse_options_completion(json body, llama_server_context & llama) {
     gpt_params default_params;
 
@@ -663,6 +707,17 @@ bool parse_options_completion(json body, llama_server_context & llama) {
     return true;
 }
 
+static void log_server_request(const Request & req, const Response & res) {
+    LOG_INFO("request", {
+        { "remote_addr", req.remote_addr },
+        { "remote_port", req.remote_port },
+        { "status", res.status },
+        { "path", req.path },
+        { "request", req.body },
+        { "response", res.body },
+    });
+}
+
 int main(int argc, char ** argv) {
     // own arguments required by this example
     gpt_params params;
@@ -739,15 +794,7 @@ int main(int argc, char ** argv) {
                     llama.generated_text.end());
             }
 
-            json data {
-                { "content", llama.generated_text },
-                { "stop", true },
-                { "model", llama.params.model_alias },
-                { "tokens_predicted", llama.num_tokens_predicted },
-                { "generation_settings", format_generation_settings(llama) },
-                { "prompt", llama.params.prompt },
-                { "stopping_word", llama.stopping_word },
-            };
+            json data = format_final_response(llama, llama.generated_text);
 
             llama_print_timings(llama.ctx);
 
@@ -785,22 +832,10 @@ int main(int argc, char ** argv) {
 
                     json data;
                     if (llama.has_next_token) {
-                        data = {
-                            { "content", to_send },
-                            { "stop", false },
-                        };
+                        data = format_partial_response(to_send);
                     } else {
                         // Generation is done, send extra information.
-                        data = {
-                            { "content", to_send },
-                            { "stop", true },
-                            { "model", llama.params.model_alias },
-                            { "tokens_predicted", llama.num_tokens_predicted },
-                            { "generation_settings", format_generation_settings(llama) },
-                            { "prompt", llama.params.prompt },
-                            { "stopping_word", llama.stopping_word },
-                            { "generated_text", llama.generated_text },
-                        };
+                        data = format_final_response(llama, to_send);
                     }
 
                     std::string str =
@@ -836,20 +871,11 @@ int main(int argc, char ** argv) {
         json body = json::parse(req.body);
         std::string content = body["content"].get<std::string>();
         std::vector<llama_token> tokens = ::llama_tokenize(llama.ctx, content, false);
-        json data {{ "tokens", tokens }};
+        json data = format_tokenizer_response(tokens);
         return res.set_content(data.dump(llama.json_indent), "application/json");
     });
 
-    svr.set_logger([](const Request & req, const Response & res) {
-        LOG_INFO("request", {
-            { "remote_addr", req.remote_addr },
-            { "remote_port", req.remote_port },
-            { "status", res.status },
-            { "path", req.path },
-            { "request", req.body },
-            { "response", res.body },
-        });
-    });
+    svr.set_logger(log_server_request);
 
     svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) {
         const auto * fmt = "500 Internal Server Error\n%s";