change token count method

Author: jwj7140
Date:   2023-06-29 02:04:37 +09:00
Parent: e1abf636a4
Commit: a4149aa0c8
2 changed files with 25 additions and 12 deletions

Changed file 1 of 2 (the Python OpenAI-compatible API wrapper):

@@ -81,13 +81,14 @@ def make_resData(data, chat=False, promptToken=[]):
         "object": "chat.completion" if (chat) else "text_completion",
         "created": int(time.time()),
         "model": "LLaMA_CPP",
-        "promptToken": promptToken,
         "usage": {
-            "prompt_tokens": len(promptToken),
+            "prompt_tokens": data["tokens_evaluated"],
             "completion_tokens": data["tokens_predicted"],
-            "total_tokens": len(promptToken) + data["tokens_predicted"]
+            "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"]
         }
     }
+    if (len(promptToken) != 0):
+        resData["promptToken"] = promptToken
     if (chat):
         #only one choice is supported
         resData["choices"] = [{
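With this change the usage block is filled from counters the C++ server reports itself (tokens_evaluated and tokens_predicted) instead of being derived from a client-side tokenization of the prompt. A minimal sketch of the new accounting, with made-up counter values; only the field names come from the diff above:

```python
# Hypothetical /completion result from the llama.cpp server; values are illustrative.
data = {"tokens_evaluated": 10, "tokens_predicted": 32}

usage = {
    "prompt_tokens": data["tokens_evaluated"],      # counted by the server while evaluating the prompt
    "completion_tokens": data["tokens_predicted"],  # counted while generating
    "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"],
}
assert usage == {"prompt_tokens": 10, "completion_tokens": 32, "total_tokens": 42}
```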
@@ -146,11 +147,15 @@ def chat_completions():
         return Response(status=403)
     body = request.get_json()
     stream = False
+    tokenize = False
     if(is_present(body, "stream")): stream = body["stream"]
+    if(is_present(body, "tokenize")): tokenize = body["tokenize"]
     postData = make_postData(body, chat=True, stream=stream)
-    tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
-    promptToken = tokenData["tokens"]
+    promptToken = []
+    if (tokenize):
+        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
+        promptToken = tokenData["tokens"]
     if (not stream):
         data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
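Clients that still want the prompt's token ids now have to ask for them explicitly. A hedged sketch of such a request against the proxy; the host and port are placeholders, and only the "tokenize" field comes from the code above:

```python
# Hypothetical client call to the proxy's OpenAI-style chat endpoint.
import requests

resp = requests.post(
    "http://127.0.0.1:8081/v1/chat/completions",  # placeholder: wherever the proxy listens
    json={
        "messages": [{"role": "user", "content": "Hello"}],
        "tokenize": True,  # opt in: the proxy will also return the prompt's token ids
    },
)
print(resp.json().get("promptToken", []))  # present only when tokenize was requested
```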
@@ -176,11 +181,15 @@ def completion():
         return Response(status=403)
     body = request.get_json()
     stream = False
+    tokenize = False
     if(is_present(body, "stream")): stream = body["stream"]
+    if(is_present(body, "tokenize")): tokenize = body["tokenize"]
     postData = make_postData(body, chat=False, stream=stream)
-    tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
-    promptToken = tokenData["tokens"]
+    promptToken = []
+    if (tokenize):
+        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
+        promptToken = tokenData["tokens"]
     if (not stream):
         data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
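When tokenize is set, the proxy falls back to the llama.cpp server's /tokenize endpoint, exactly as in the two hunks above. Reduced to its essentials (the base URL here stands in for args.llama_api):

```python
# The tokenize fallback, isolated for illustration.
import json
import urllib.parse
import requests

llama_api = "http://127.0.0.1:8080"  # placeholder for args.llama_api
prompt = " Hello"

tokenData = requests.post(
    urllib.parse.urljoin(llama_api, "/tokenize"),
    data=json.dumps({"content": prompt}),
).json()
promptToken = tokenData.get("tokens", [])  # token ids for the prompt, [] on an unexpected reply
```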

Changed file 2 of 2 (the C++ server, llama_server_context):

@@ -108,6 +108,7 @@ struct llama_server_context {
     bool has_next_token = false;
     std::string generated_text;
+    size_t num_prompt_tokens = 0;
     size_t num_tokens_predicted = 0;
     size_t n_past = 0;
     size_t n_remain = 0;
@@ -139,6 +140,7 @@ struct llama_server_context {
     void rewind() {
         params.antiprompt.clear();
+        num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
         generated_text.reserve(params.n_ctx);
@@ -169,17 +171,18 @@ struct llama_server_context {
     void loadPrompt() {
         params.prompt.insert(0, 1, ' '); // always add a first space
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
+        num_prompt_tokens = prompt_tokens.size();
         if (params.n_keep < 0) {
-            params.n_keep = (int)prompt_tokens.size();
+            params.n_keep = (int)num_prompt_tokens;
         }
         params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
         // if input prompt is too big, truncate like normal
-        if (prompt_tokens.size() >= (size_t)params.n_ctx) {
+        if (num_prompt_tokens >= (size_t)params.n_ctx) {
             const int n_left = (params.n_ctx - params.n_keep) / 2;
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_left - 1) / n_left;
+            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
             new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
             std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
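num_prompt_tokens simply caches prompt_tokens.size(); the surrounding truncation logic is unchanged. For reference, a small worked example of that arithmetic, transcribed to Python with made-up sizes:

```python
# Worked example of the truncation arithmetic above; n_ctx, n_keep and the token ids are illustrative.
n_ctx, n_keep = 8, 2
prompt_tokens = list(range(12))                 # 12 tokens, too long for an 8-token context
num_prompt_tokens = len(prompt_tokens)

if num_prompt_tokens >= n_ctx:
    n_left = (n_ctx - n_keep) // 2              # 3
    new_tokens = prompt_tokens[:n_keep]         # always keep the first n_keep tokens
    erased_blocks = (num_prompt_tokens - n_keep - n_left - 1) // n_left   # (12-2-3-1)//3 = 2
    new_tokens += prompt_tokens[n_keep + erased_blocks * n_left:]         # drop 2 blocks of 3 tokens
    assert new_tokens == [0, 1, 8, 9, 10, 11]
```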
@@ -193,7 +196,7 @@
             truncated = true;
             prompt_tokens = new_tokens;
         } else {
-            const size_t ps = prompt_tokens.size();
+            const size_t ps = num_prompt_tokens;
             std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
             std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
         }
@@ -201,7 +204,7 @@
         // compare the evaluated prompt with the new prompt
         n_past = common_part(embd, prompt_tokens);
         embd = prompt_tokens;
-        if (n_past == prompt_tokens.size()) {
+        if (n_past == num_prompt_tokens) {
             // we have to evaluate at least 1 token to generate logits.
             n_past--;
         }
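The same cached count is used when deciding how much of a previously evaluated prompt can be reused: keep the common prefix as n_past, but always leave at least one token to evaluate so fresh logits are produced. A Python analogue of that rule, with common_part rewritten here purely for illustration:

```python
# Sketch of the prompt-reuse rule shown above; common_part is a stand-in for the C++ helper.
def common_part(a, b):
    n = 0
    while n < len(a) and n < len(b) and a[n] == b[n]:
        n += 1
    return n

prompt_tokens = [1, 15043, 3186, 29991]    # hypothetical token ids
embd = [1, 15043, 3186, 29991]             # tokens already evaluated for an identical prompt
num_prompt_tokens = len(prompt_tokens)

n_past = common_part(embd, prompt_tokens)  # 4: the whole prompt is already evaluated
if n_past == num_prompt_tokens:
    n_past -= 1                            # step back one token so at least one gets evaluated
assert n_past == 3
```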
@@ -684,6 +687,7 @@ static json format_final_response(llama_server_context & llama, const std::strin
         { "stop", true },
         { "model", llama.params.model_alias },
         { "tokens_predicted", llama.num_tokens_predicted },
+        { "tokens_evaluated", llama.num_prompt_tokens },
         { "generation_settings", format_generation_settings(llama) },
         { "prompt", llama.params.prompt },
         { "truncated", llama.truncated },