From dda915cac4ad2116d87ea4296bdc4757a155dd4a Mon Sep 17 00:00:00 2001
From: digiwombat
Date: Sun, 28 May 2023 08:43:38 -0400
Subject: [PATCH] Added capturing the stopping word and sending it along with
 the final JSON.

Fixed an fprintf warning
Fixed a bug that broke streaming
Properly removed thread changing in json (only grabbed batch_size before)
---
 examples/server/server.cpp | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 0286fcc5b..54be938fc 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -33,13 +33,14 @@ struct llama_server_context
     llama_context *ctx;
     gpt_params params;
 
-    bool reload_ctx = false;
+    std::string stopping_word = "";
 
     void rewind() {
         as_loop = false;
         params.antiprompt.clear();
         num_tokens_predicted = 0;
         generated_text = "";
+        stopping_word = "";
 
         //processed_tokens.clear();
         embd_inp.clear();
@@ -233,6 +234,7 @@
         }
 
         if (!embd.empty() && embd.back() == llama_token_eos()) {
+            stopping_word = llama_token_to_str(ctx, embd.back());
             has_next_token = false;
         }
 
@@ -258,6 +260,7 @@
             size_t i = generated_text.find(word, generated_text.size() - (word.size() + token_text.size()));
             if (i != std::string::npos) {
                 generated_text.erase(generated_text.begin() + i, generated_text.begin() + i + word.size());
+                stopping_word = word;
                 has_next_token = false;
                 break;
             }
@@ -313,7 +316,7 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params, con
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
-    fprintf(stderr, "  --host                ip address to listen (default  (default: %d)\n", sparams.hostname);
+    fprintf(stderr, "  --host                ip address to listen (default  (default: %s)\n", sparams.hostname.c_str());
     fprintf(stderr, "  --port PORT           port to listen (default  (default: %d)\n", sparams.port);
     fprintf(stderr, "  -to N, --timeout N    server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
     fprintf(stderr, "\n");
@@ -449,9 +452,9 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
 }
 
 bool parse_options_completion(json body, llama_server_context& llama, Response &res) {
-    if (!body["threads"].is_null())
+    if (!body["as_loop"].is_null())
     {
-        llama.params.n_threads = body["threads"].get<int>();
+        llama.as_loop = body["as_loop"].get<bool>();
     }
     if (!body["n_predict"].is_null())
     {
@@ -475,7 +478,7 @@
     }
     if (!body["repeat_last_n"].is_null())
     {
-        llama.params.repeat_last_n = body["repeat_last_n"].get<float>();
+        llama.params.repeat_last_n = body["repeat_last_n"].get<int>();
     }
     if (!body["temperature"].is_null())
     {
@@ -630,7 +633,8 @@ int main(int argc, char **argv)
                         {"penalize_nl", llama.params.penalize_nl}
                     }
                 },
-                {"prompt", llama.params.prompt} };
+                {"prompt", llama.params.prompt},
+                {"stopping_word", llama.stopping_word} };
             return res.set_content(data.dump(), "application/json");
         }
         catch (const json::exception &e)
@@ -684,6 +688,7 @@ int main(int argc, char **argv)
                 json data;
                 if (llama.has_next_token)
                 {
+                    //fprintf(stdout, "Result: %s\n", result);
                     final_text += result;
                     data = {
                         {"content", result },
@@ -715,6 +720,7 @@ int main(int argc, char **argv)
                             }
                         },
                         {"prompt", llama.params.prompt},
+                        {"stopping_word", llama.stopping_word},
                         {"generated_text", final_text}
                     };
                 }
@@ -735,7 +741,7 @@ int main(int argc, char **argv)
                 {
                     // Generation is done, send extra information.
                     data = {
-                        {"content", "\uFFFD" },
+                        {"content", u8"\uFFFD" },
                         {"stop", true },
                         {"tokens_predicted", llama.num_tokens_predicted},
                         {"seed", llama.params.seed},
@@ -756,6 +762,7 @@ int main(int argc, char **argv)
                             }
                         },
                         {"prompt", llama.params.prompt},
+                        {"stopping_word", llama.stopping_word},
                         {"generated_text", final_text}
                     };
                 }
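
Note (not part of the patch): a minimal client-side sketch of how the new "stopping_word" field in the final /completion JSON could be consumed. It assumes nlohmann::json (the json.hpp already bundled with the server example) and a hypothetical response_body string standing in for the server's final response; everything besides the field names visible in the hunks above is illustrative.

// Hypothetical consumer of the final /completion JSON, not part of server.cpp.
#include <iostream>
#include <string>
#include "json.hpp" // nlohmann::json, shipped alongside examples/server

using json = nlohmann::json;

int main() {
    // Stand-in for the final JSON body the server sends when generation stops.
    std::string response_body = R"({
        "content": "Hello there.",
        "stop": true,
        "tokens_predicted": 4,
        "stopping_word": "User:",
        "generated_text": "Hello there."
    })";

    json data = json::parse(response_body);

    // An empty stopping_word suggests generation ended for another reason
    // (for example the prediction limit) rather than on a stop string or EOS.
    std::string stopping_word = data.value("stopping_word", std::string());
    if (!stopping_word.empty()) {
        std::cout << "generation stopped on: " << stopping_word << "\n";
    }
    std::cout << data["generated_text"].get<std::string>() << std::endl;
    return 0;
}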