From 8c6a5fc92bac786f0bb0737a2c98d96096a28ea1 Mon Sep 17 00:00:00 2001
From: Henri Vasserman
Date: Thu, 1 Jun 2023 13:18:12 +0300
Subject: [PATCH 1/2] last tokens fixes

---
 examples/server/server.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 3e27a7bbb..fc24f9c13 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -88,7 +88,6 @@ struct llama_server_context
         n_remain = 0;
         n_past = 0;
         n_consumed = 0;
-        last_n_tokens.clear();
     }
 
     bool loadModel(const gpt_params &params_)
@@ -120,7 +119,12 @@
             const int n_left = (params.n_ctx - params.n_keep)/2;
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
             new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
+            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
             prompt_tokens = new_tokens;
+        } else {
+            size_t ps = prompt_tokens.size();
+            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
+            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
         }
 
         // compare the evaluated prompt with the new prompt
@@ -251,10 +255,7 @@ struct llama_server_context
                     id = llama_sample_token(ctx, &candidates_p);
                 }
             }
-            if (!last_n_tokens.empty())
-            {
-                last_n_tokens.erase(last_n_tokens.begin());
-            }
+            last_n_tokens.erase(last_n_tokens.begin());
             last_n_tokens.push_back(id);
             num_tokens_predicted++;
         }

From 9531ae60dbd21c19260258cfd19e71fff18bcf7a Mon Sep 17 00:00:00 2001
From: Henri Vasserman
Date: Thu, 1 Jun 2023 13:57:47 +0300
Subject: [PATCH 2/2] Add logit bias support

---
 examples/server/server.cpp | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index fc24f9c13..04a6af47a 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -655,6 +655,16 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
     } else {
         llama.params.logit_bias.erase(llama_token_eos());
     }
+    if (body["logit_bias"].is_array()) {
+        int n_vocab = llama_n_vocab(llama.ctx);
+        for (const auto &el : body["logit_bias"]) {
+            if (el.is_array() && el.size() == 2 && el[0].is_number_integer() && el[1].is_number_float()) {
+                llama_token tok = el[0].get<llama_token>();
+                if (tok < 0 || tok >= n_vocab) continue;
+                llama.params.logit_bias[tok] = el[1].get<float>();
+            }
+        }
+    }
     if (!body["prompt"].is_null()) {
         llama.params.prompt = body["prompt"].get<std::string>();
     } else {
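
Note (not part of the patches): after PATCH 1/2, last_n_tokens is kept as a fixed-size window of n_ctx entries that is never cleared, which is why the unconditional erase/push_back pair in the sampling loop is safe. Below is a minimal standalone sketch of the zero-pad-and-right-align priming from the new else branch, using made-up values (n_ctx = 8, a 3-token prompt):

#include <algorithm>
#include <cstdio>
#include <vector>

typedef int llama_token; // stand-in for the llama.cpp typedef

int main() {
    const int n_ctx = 8;                                    // hypothetical context size
    std::vector<llama_token> prompt_tokens = {11, 22, 33};  // hypothetical prompt
    std::vector<llama_token> last_n_tokens(n_ctx);

    // Same pattern as the patch: zero-fill the front of the window,
    // then copy the prompt tokens right-aligned to its back.
    size_t ps = prompt_tokens.size();
    std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
    std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);

    for (llama_token t : last_n_tokens) printf("%d ", t);   // prints: 0 0 0 0 0 11 22 33
    printf("\n");
    return 0;
}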
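
Note (not part of the patches): PATCH 2/2 makes the completion endpoint accept a logit_bias field as a JSON array of [token_id, bias] pairs; out-of-range token IDs are skipped, and because the check is el[1].is_number_float(), the bias must be written as a JSON float (2.0, not 2) or the pair is silently ignored. Below is a standalone sketch of a body the new parsing would accept, using nlohmann/json (the json type the server already uses); the token IDs are arbitrary placeholders:

#include <nlohmann/json.hpp>
#include <cstdio>

int main() {
    // A request body in the shape the new code expects.
    auto body = nlohmann::json::parse(R"({
        "prompt": "Hello",
        "logit_bias": [[15043, -1.0], [29871, 2.5]]
    })");

    // Mirrors the patch's validation: each entry must be an [integer, float] pair.
    for (const auto &el : body["logit_bias"]) {
        if (el.is_array() && el.size() == 2 &&
            el[0].is_number_integer() && el[1].is_number_float()) {
            printf("token %d -> bias %f\n", el[0].get<int>(), el[1].get<float>());
        }
    }
    return 0;
}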