From a4149aa0c84d6489bd34983f5c578625ef373793 Mon Sep 17 00:00:00 2001
From: jwj7140
Date: Thu, 29 Jun 2023 02:04:37 +0900
Subject: [PATCH] change token count method

---
 examples/server/api_like_OAI.py | 23 ++++++++++++++++-------
 examples/server/server.cpp      | 14 +++++++++-----
 2 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py
index 5adcb1c5d..02a846320 100755
--- a/examples/server/api_like_OAI.py
+++ b/examples/server/api_like_OAI.py
@@ -81,13 +81,14 @@ def make_resData(data, chat=False, promptToken=[]):
         "object": "chat.completion" if (chat) else "text_completion",
         "created": int(time.time()),
         "model": "LLaMA_CPP",
-        "promptToken": promptToken,
         "usage": {
-            "prompt_tokens": len(promptToken),
+            "prompt_tokens": data["tokens_evaluated"],
             "completion_tokens": data["tokens_predicted"],
-            "total_tokens": len(promptToken) + data["tokens_predicted"]
+            "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"]
         }
     }
+    if (len(promptToken) != 0):
+        resData["promptToken"] = promptToken
     if (chat):
         #only one choice is supported
         resData["choices"] = [{
@@ -146,11 +147,15 @@ def chat_completions():
         return Response(status=403)
     body = request.get_json()
     stream = False
+    tokenize = False
     if(is_present(body, "stream")): stream = body["stream"]
+    if(is_present(body, "tokenize")): tokenize = body["tokenize"]
     postData = make_postData(body, chat=True, stream=stream)
 
-    tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
-    promptToken = tokenData["tokens"]
+    promptToken = []
+    if (tokenize):
+        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
+        promptToken = tokenData["tokens"]
 
     if (not stream):
         data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
@@ -176,11 +181,15 @@ def completion():
         return Response(status=403)
     body = request.get_json()
     stream = False
+    tokenize = False
     if(is_present(body, "stream")): stream = body["stream"]
+    if(is_present(body, "tokenize")): tokenize = body["tokenize"]
     postData = make_postData(body, chat=False, stream=stream)
 
-    tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
-    promptToken = tokenData["tokens"]
+    promptToken = []
+    if (tokenize):
+        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
+        promptToken = tokenData["tokens"]
 
     if (not stream):
         data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 79df5e847..6ae2e4320 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -108,6 +108,7 @@ struct llama_server_context {
     bool has_next_token = false;
     std::string generated_text;
 
+    size_t num_prompt_tokens = 0;
     size_t num_tokens_predicted = 0;
     size_t n_past = 0;
     size_t n_remain = 0;
@@ -139,6 +140,7 @@ struct llama_server_context {
 
     void rewind() {
         params.antiprompt.clear();
+        num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
         generated_text.reserve(params.n_ctx);
@@ -169,17 +171,18 @@ struct llama_server_context {
     void loadPrompt() {
         params.prompt.insert(0, 1, ' '); // always add a first space
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
+        num_prompt_tokens = prompt_tokens.size();
 
         if (params.n_keep < 0) {
-            params.n_keep = (int)prompt_tokens.size();
+            params.n_keep = (int)num_prompt_tokens;
         }
         params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
 
         // if input prompt is too big, truncate like normal
-        if (prompt_tokens.size() >= (size_t)params.n_ctx) {
+        if (num_prompt_tokens>= (size_t)params.n_ctx) {
             const int n_left = (params.n_ctx - params.n_keep) / 2;
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_left - 1) / n_left;
+            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
             new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
             std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
 
@@ -193,7 +196,7 @@ struct llama_server_context {
             truncated = true;
             prompt_tokens = new_tokens;
         } else {
-            const size_t ps = prompt_tokens.size();
+            const size_t ps = num_prompt_tokens;
             std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
             std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
         }
@@ -201,7 +204,7 @@ struct llama_server_context {
         // compare the evaluated prompt with the new prompt
         n_past = common_part(embd, prompt_tokens);
         embd = prompt_tokens;
-        if (n_past == prompt_tokens.size()) {
+        if (n_past == num_prompt_tokens) {
             // we have to evaluate at least 1 token to generate logits.
             n_past--;
         }
@@ -684,6 +687,7 @@ static json format_final_response(llama_server_context & llama, const std::strin
         { "stop", true },
         { "model", llama.params.model_alias },
         { "tokens_predicted", llama.num_tokens_predicted },
+        { "tokens_evaluated", llama.num_prompt_tokens },
         { "generation_settings", format_generation_settings(llama) },
         { "prompt", llama.params.prompt },
         { "truncated", llama.truncated },
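
Note (not part of the patch): a minimal client-side sketch of the changed behavior of the api_like_OAI.py proxy. The host, port, and prompt below are placeholder assumptions; only the "tokenize", "promptToken", and "usage" fields come from the diff above. After this change, "usage.prompt_tokens" is filled from the server's "tokens_evaluated" count, and the extra /tokenize round-trip (and the "promptToken" list) happens only when the request opts in with "tokenize": true.

    import requests

    # Assumed local address of the api_like_OAI.py proxy; adjust as needed.
    BASE_URL = "http://127.0.0.1:8081"

    resp = requests.post(
        BASE_URL + "/chat/completions",
        json={
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": False,
            # Opt-in flag added by this patch: only when true does the proxy
            # call /tokenize and include "promptToken" in the response.
            "tokenize": True,
        },
    )
    body = resp.json()

    # "prompt_tokens" is now taken from the server's "tokens_evaluated"
    # field rather than len(promptToken).
    print(body["usage"]["prompt_tokens"], body["usage"]["completion_tokens"])
    print(body.get("promptToken", []))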