From 733b566bacbee9758d32706c8ac0ebc2f2e104fa Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Fri, 19 May 2023 15:11:14 -0600
Subject: [PATCH] some corrections and added as cmake option

---
 CMakeLists.txt             |   1 +
 examples/CMakeLists.txt    |   4 +-
 examples/server/README.md  |  14 ++-
 examples/server/server.cpp | 222 +++++++++++++++++++++----------------
 4 files changed, 141 insertions(+), 100 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 48e3238df..7cfff27e5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -71,6 +71,7 @@ option(LLAMA_CLBLAST "llama: use CLBlast"
 
 option(LLAMA_BUILD_TESTS    "llama: build tests"          ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES "llama: build examples"       ${LLAMA_STANDALONE})
+option(LLAMA_BUILD_SERVER   "llama: build server example" OFF)
 
 #
 # Build info header
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index e607e6ad6..e4ce5aca7 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -36,6 +36,8 @@ else()
     add_subdirectory(embedding)
     add_subdirectory(save-load-state)
     add_subdirectory(benchmark)
-    add_subdirectory(server)
     add_subdirectory(baby-llama)
+    if(LLAMA_BUILD_SERVER)
+        add_subdirectory(server)
+    endif()
 endif()
diff --git a/examples/server/README.md b/examples/server/README.md
index c4deea8cc..089e8908c 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -73,7 +73,8 @@ You can interact with this API Endpoints. This implementations just support chat
 
 - **POST** `hostname:port/completion`: Setting up the Llama Context to begin the completions tasks.
 
-Options:
+*Options:*
+
 `batch_size`: Set the batch size for prompt processing (default: 512).
 
 `temperature`: Adjust the randomness of the generated text (default: 0.8).
@@ -100,6 +101,8 @@ Options:
 
 - **POST** `hostname:port/embedding`: Generate embedding of a given text
 
+*Options:*
+
 `content`: Set the text to get generate the embedding.
 
 `threads`: Set the number of threads to use during computation.
@@ -108,10 +111,16 @@ To use this endpoint, you need to start the server with the `--embedding` option
 
 - **POST** `hostname:port/tokenize`: Tokenize a given text
 
+*Options:*
+
 `content`: Set the text to tokenize.
 
 - **GET** `hostname:port/next-token`: Receive the next token predicted, execute this request in a loop. Make sure set `as_loop` as `true` in the completion request.
 
+*Options:*
+
+`stop`: Set `hostname:port/next-token?stop=true` to stop the token generation.
+
 ## More examples
 
 ### Interactive mode
@@ -155,6 +164,7 @@ async function ChatCompletion(answer) {
 
     let message = "";
     while (true) {
+        // you can stop the inference by adding '?stop=true', like this: http://127.0.0.1:8080/next-token?stop=true
         result = await axios.get("http://127.0.0.1:8080/next-token");
         process.stdout.write(result.data.content);
         message += result.data.content;
@@ -226,7 +236,7 @@ async function DoInstruction(instruction) {
 }
 
 // This function should be called every time a instruction to the model is needed.
-DoInstruction("Destroy the world");
+DoInstruction("Destroy the world"); // as a joke
 ```
 
 ### Embeddings
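The README changes above document the `/completion` options, but the examples do not show the `stop`/`exclude` fields or the `tokens_predicted` response value that this patch wires up in `parse_options_completion` and the completion handler. Below is a minimal client sketch, illustrative only and not part of the patch: the prompt text, sampling values, and the `127.0.0.1:8080` address are placeholder assumptions in the style of the README's existing axios examples.

```javascript
const axios = require("axios");

// Illustrative one-shot completion request; prompt and sampling values are placeholders.
async function completeOnce() {
    const { data } = await axios.post("http://127.0.0.1:8080/completion", {
        prompt: "### Human: Write a haiku about llamas.\n### Assistant:",
        temperature: 0.8,
        top_k: 40,
        top_p: 0.9,
        n_predict: 128,
        batch_size: 512,
        as_loop: false,               // wait for the whole completion in one response
        stop: ["### Human:"],         // antiprompt strings that end generation
        exclude: ["### Assistant:"]   // token sequences hidden from the returned text
    });

    // The patched handler returns `content` and `tokens_predicted`.
    console.log(data.content);
    console.log("tokens predicted:", data.tokens_predicted);
}

completeOnce().catch(console.error);
```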
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 754131c1b..7209a2b52 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -11,11 +11,11 @@ struct server_params
 
 struct llama_server_context
 {
-  bool context_config = false;
+  bool as_loop = false;
   bool has_next_token = false;
-  bool is_interacting = false;
+  std::string generated_text = "";
 
-  int32_t tokens_completion = 0;
+  int32_t num_tokens_predicted = 0;
   int32_t n_past = 0;
   int32_t n_consumed = 0;
   int32_t n_session_consumed = 0;
@@ -27,10 +27,19 @@
   std::vector<llama_token> llama_token_newline;
   std::vector<llama_token> embd_inp;
   std::vector<std::vector<llama_token>> no_show_words;
+  std::vector<llama_token> tokens_predicted;
 
   llama_context *ctx;
   gpt_params params;
 
+  void rewind() {
+    as_loop = false;
+    params.antiprompt.clear();
+    no_show_words.clear();
+    num_tokens_predicted = 0;
+    generated_text = "";
+  }
+
   bool loadModel(gpt_params params_)
   {
     params = params_;
@@ -123,7 +132,7 @@
      }
    }
    embd.clear();
-   if ((int)embd_inp.size() <= n_consumed && !is_interacting)
+   if ((int)embd_inp.size() <= n_consumed && has_next_token)
    {
      // out of user input, sample next token
      const float temp = params.temp;
@@ -206,6 +215,7 @@
        last_n_tokens.erase(last_n_tokens.begin());
        last_n_tokens.push_back(id);
        processed_tokens.push_back(id);
+       num_tokens_predicted++;
      }
 
      // replace end of text token with newline token when in interactive mode
@@ -225,7 +235,6 @@
      for (auto id : embd)
      {
        result = id;
-       tokens_completion++;
      }
      // decrement remaining sampling budget
      --n_remain;
@@ -262,7 +271,6 @@
        {
          if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos)
          {
-           is_interacting = true;
            has_next_token = false;
            return result;
          }
@@ -270,7 +278,7 @@
        }
      }
      if (n_past > 0)
      {
-       is_interacting = false;
+       has_next_token = true;
      }
    }
@@ -281,35 +289,35 @@
    if (params.interactive && n_remain <= 0 && params.n_predict != -1)
    {
      n_remain = params.n_predict;
-     is_interacting = true;
    }
    has_next_token = n_remain != 0;
    return result;
  }
 
-  std::string inference()
+  std::string doCompletion()
  {
    llama_token token = nextToken();
    if (token == -1) {
      return "";
    }
-   std::vector<llama_token> tokens_completion;
-   tokens_completion.push_back(token);
+   tokens_predicted.clear();
+   tokens_predicted.push_back(token);
 
+   // Avoid adding the no-show words to the response
    for (std::vector<llama_token> word_tokens : no_show_words)
    {
      int match_token = 1;
-     if (tokens_completion[0] == word_tokens[0])
+     if (tokens_predicted.front() == word_tokens.front())
      {
        bool execute_matching = true;
-       if (tokens_completion.size() > 1) { // if previus tokens had been tested
+       if (tokens_predicted.size() > 1) { // if previous tokens have already been tested
          for (int i = 1; i < word_tokens.size(); i++)
          {
-           if (i >= tokens_completion.size()) {
+           if (i >= tokens_predicted.size()) {
              match_token = i;
              break;
            }
-           if (tokens_completion[i] == word_tokens[i])
+           if (tokens_predicted[i] == word_tokens[i])
            {
              continue;
            }
@@ -325,24 +333,26 @@
          return "";
        }
        token = nextToken();
-       tokens_completion.push_back(token);
+       tokens_predicted.push_back(token);
        if (token == word_tokens[match_token])
        { // the token follow the sequence
          match_token++;
        }
        else if (match_token < word_tokens.size())
-       { // no complete all user tag
+       { // did not complete the whole word sequence
          break;
        }
      }
    }
   }
-   std::string result = "";
-   for (llama_token tkn : tokens_completion)
-   {
-     result += llama_token_to_str(ctx, tkn);
+   if(as_loop) {
+     generated_text = "";
   }
-   return result;
+   for (llama_token tkn : tokens_predicted)
+   {
+     generated_text += llama_token_to_str(ctx, tkn);
+   }
+   return generated_text;
  }
 
  std::vector<float> embedding(std::string content, int threads) {
@@ -491,6 +501,76 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
   return true;
 }
 
+bool parse_options_completion(json body, llama_server_context& llama, Response &res) {
+  if (!body["threads"].is_null())
+  {
+    llama.params.n_threads = body["threads"].get<int>();
+  }
+  if (!body["n_predict"].is_null())
+  {
+    llama.params.n_predict = body["n_predict"].get<int>();
+  }
+  if (!body["top_k"].is_null())
+  {
+    llama.params.top_k = body["top_k"].get<int>();
+  }
+  if (!body["top_p"].is_null())
+  {
+    llama.params.top_p = body["top_p"].get<float>();
+  }
+  if (!body["temperature"].is_null())
+  {
+    llama.params.temp = body["temperature"].get<float>();
+  }
+  if (!body["batch_size"].is_null())
+  {
+    llama.params.n_batch = body["batch_size"].get<int>();
+  }
+  if (!body["n_keep"].is_null())
+  {
+    llama.params.n_keep = body["n_keep"].get<int>();
+  }
+  if (!body["as_loop"].is_null())
+  {
+    llama.as_loop = body["as_loop"].get<bool>();
+  }
+  if (!body["interactive"].is_null())
+  {
+    llama.params.interactive = body["interactive"].get<bool>();
+  }
+  if (!body["prompt"].is_null())
+  {
+    llama.params.prompt = body["prompt"].get<std::string>();
+  }
+  else
+  {
+    json data = {
+        {"status", "error"},
+        {"reason", "You need to pass the prompt"}};
+    res.set_content(data.dump(), "application/json");
+    res.status = 400;
+    return false;
+  }
+  if (!body["stop"].is_null())
+  {
+    std::vector<std::string> stop_words = body["stop"].get<std::vector<std::string>>();
+    for (std::string stop_word : stop_words)
+    {
+      llama.params.antiprompt.push_back(stop_word);
+      llama.no_show_words.push_back(::llama_tokenize(llama.ctx, stop_word, false));
+    }
+  }
+  if (!body["exclude"].is_null())
+  {
+    std::vector<std::string> no_show_words = body["exclude"].get<std::vector<std::string>>();
+    for (std::string no_show : no_show_words)
+    {
+      llama.no_show_words.push_back(::llama_tokenize(llama.ctx, no_show, false));
+    }
+  }
+  return true;
+}
+
 int main(int argc, char **argv)
 {
   // own arguments required by this example
@@ -535,73 +615,12 @@
                return;
              }
 
-             json body = json::parse(req.body);
-             llama.params.antiprompt.clear();
-             llama.no_show_words.clear();
-             bool as_loop = false;
+             llama.rewind();
 
-             if (!body["threads"].is_null())
-             {
-               llama.params.n_threads = body["threads"].get<int>();
-             }
-             if (!body["n_predict"].is_null())
-             {
-               llama.params.n_predict = body["n_predict"].get<int>();
-             }
-             if (!body["top_k"].is_null())
-             {
-               llama.params.top_k = body["top_k"].get<int>();
-             }
-             if (!body["top_p"].is_null())
-             {
-               llama.params.top_p = body["top_p"].get<float>();
-             }
-             if (!body["temperature"].is_null())
-             {
-               llama.params.temp = body["temperature"].get<float>();
-             }
-             if (!body["batch_size"].is_null())
-             {
-               llama.params.n_batch = body["batch_size"].get<int>();
-             }
-             if (!body["n_keep"].is_null())
-             {
-               llama.params.n_keep = body["n_keep"].get<int>();
-             }
-             if (!body["as_loop"].is_null())
-             {
-               as_loop = body["as_loop"].get<bool>();
-             }
-             if (!body["interactive"].is_null())
-             {
-               llama.params.interactive = body["interactive"].get<bool>();
-             }
-             if (!body["prompt"].is_null())
-             {
-               llama.params.prompt = body["prompt"].get<std::string>();
-             }
-             else
-             {
-               json data = {
-                   {"status", "error"},
-                   {"reason", "You need to pass the prompt"}};
-               res.set_content(data.dump(), "application/json");
-               res.status = 400;
+             if(parse_options_completion(json::parse(req.body), llama, res) == false){
                return;
              }
-             if (!body["stop"].is_null()) {
-               std::vector<std::string> stop_words = body["stop"].get<std::vector<std::string>>();
-               for (std::string stop_word : stop_words) {
-                 llama.params.antiprompt.push_back(stop_word);
-                 llama.no_show_words.push_back(::llama_tokenize(llama.ctx, stop_word, false));
-               }
-             }
-             if (!body["exclude"].is_null()) {
-               std::vector<std::string> no_show_words = body["exclude"].get<std::vector<std::string>>();
-               for (std::string no_show : no_show_words) {
-                 llama.no_show_words.push_back(::llama_tokenize(llama.ctx, no_show, false));
-               }
-             }
+
              if (!llama.loadPrompt())
              {
               json data = {
@@ -611,23 +630,33 @@
               res.status = 400;
               return;
              }
 
+             llama.beginCompletion();
-             llama.tokens_completion = 0;
-             if(as_loop) {
+             if(llama.as_loop) {
               json data = {
                   {"status", "done" } };
               return res.set_content(data.dump(), "application/json");
             } else {
-              // Send all completion when finish
-              std::string completion = "";
+              // run inference in a loop until the completion is finished
              while (llama.has_next_token) {
-                completion += llama.inference();
+                llama.doCompletion();
+              }
+              try
+              {
+                json data = {
+                    {"content", llama.generated_text },
+                    {"tokens_predicted", llama.num_tokens_predicted}};
+                return res.set_content(data.dump(), "application/json");
+              }
+              catch (json::exception e)
+              {
+                // Some tokens yield invalid UTF-8 strings, and the json parser is very sensitive to that
+                json data = {
+                    {"content", "Bad encoding token"},
+                    {"tokens_predicted", 0}};
+                return res.set_content(data.dump(), "application/json");
              }
 
-              json data = {
-                  {"content", completion.c_str()},
-                  {"total_tokens", llama.tokens_completion}};
-              return res.set_content(data.dump(), "application/json");
            }
          });
 
  svr.Post("/tokenize", [&llama](const Request &req, Response &res)
@@ -664,9 +693,8 @@
            std::string result = "";
            if (req.has_param("stop")) {
              llama.has_next_token = false;
-             llama.is_interacting = true;
            } else {
-             result = llama.inference();
+             result = llama.doCompletion(); // infer the next token
            }
            try {
              json data = {
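The excerpt ends above, mid-hunk, with the patched `/next-token` handler. For reference, here is a hedged sketch of the streaming flow those handlers implement: `as_loop: true` makes `/completion` return `{"status":"done"}` immediately, `/next-token` is then polled for generated text, and `/next-token?stop=true` halts generation. The address, prompt, and the fixed polling count are placeholder assumptions; the excerpt does not show how the server signals that generation has finished, so this sketch simply stops after a bounded number of requests.

```javascript
const axios = require("axios");

// Illustrative streaming client; host, prompt, and token budget are placeholders.
async function streamCompletion() {
    // With as_loop: true the handler replies {"status":"done"} right away
    // and keeps generating tokens in the background.
    await axios.post("http://127.0.0.1:8080/completion", {
        prompt: "Building a website can be done in 10 simple steps:",
        n_predict: 256,
        as_loop: true
    });

    // Poll /next-token for the generated pieces (each response carries `content`).
    for (let i = 0; i < 32; i++) {
        const result = await axios.get("http://127.0.0.1:8080/next-token");
        process.stdout.write(result.data.content);
    }

    // Ask the server to stop generating; the handler sets has_next_token = false.
    await axios.get("http://127.0.0.1:8080/next-token?stop=true");
}

streamCompletion().catch(console.error);
```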