From e733a9e425d7824d8adad98e9e3501b6a0eff07b Mon Sep 17 00:00:00 2001 From: DebuggingLife46 <141757134+DebuggingLife46@users.noreply.github.com> Date: Tue, 26 Dec 2023 21:56:19 +0530 Subject: [PATCH] Add logit_bias to the OpenAI api (#577) * Add logit_bias to the OpenAI api * Cleanup and refactor, test in swagger. --------- Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com> --- expose.h | 6 ++++++ gpttype_adapter.cpp | 18 ++++++++++++++++++ kcpp_docs.embd | 14 ++++++++++++-- koboldcpp.py | 37 ++++++++++++++++++++++++++++++++++--- 4 files changed, 70 insertions(+), 5 deletions(-) diff --git a/expose.h b/expose.h index 40a4aff94..0c9df8bb6 100644 --- a/expose.h +++ b/expose.h @@ -3,6 +3,7 @@ const int stop_token_max = 16; const int ban_token_max = 16; const int tensor_split_max = 16; +const int logit_bias_max = 16; // match kobold's sampler list and order enum samplers { @@ -22,6 +23,10 @@ enum stop_reason EOS_TOKEN=1, CUSTOM_STOPPER=2, }; +struct logit_bias { + int32_t token_id; + float bias; +}; struct load_model_inputs { const int threads; @@ -76,6 +81,7 @@ struct generation_inputs const char * grammar; const bool grammar_retain_state; const bool quiet = false; + const logit_bias logit_biases[logit_bias_max]; }; struct generation_outputs { diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index d14376521..705f7958d 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -101,6 +101,7 @@ static int stopper_unused_tokens = 0; static std::mutex concat_output_mtx; static std::string concat_output = ""; static std::string concat_output_reader_copy = ""; +static std::vector logit_biases; const int extra_context_handle_fragmentation = 80; @@ -489,6 +490,12 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector= 0 && t_id < n_vocab && bias!=0) + { + logit_biases.push_back(inputs.logit_biases[x]); + } + } + std::string addedmemory = inputs.memory; params.prompt = inputs.prompt; params.seed = inputs.seed; diff --git a/kcpp_docs.embd b/kcpp_docs.embd index 5d1930b4c..1e5ca85d2 100644 --- a/kcpp_docs.embd +++ b/kcpp_docs.embd @@ -8,7 +8,7 @@