Add logit_bias to the OpenAI api (#577)
* Add logit_bias to the OpenAI api
* Cleanup and refactor, test in swagger.

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
parent 5006b23099
commit e733a9e425

4 changed files with 70 additions and 5 deletions
expose.h (+6 −0)

@@ -3,6 +3,7 @@
 const int stop_token_max = 16;
 const int ban_token_max = 16;
 const int tensor_split_max = 16;
+const int logit_bias_max = 16;
 // match kobold's sampler list and order
 enum samplers
 {

@@ -22,6 +23,10 @@ enum stop_reason
     EOS_TOKEN=1,
     CUSTOM_STOPPER=2,
 };
+struct logit_bias {
+    int32_t token_id;
+    float bias;
+};
 struct load_model_inputs
 {
     const int threads;

@@ -76,6 +81,7 @@ struct generation_inputs
     const char * grammar;
     const bool grammar_retain_state;
     const bool quiet = false;
+    const logit_bias logit_biases[logit_bias_max];
 };
 struct generation_outputs
 {
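The new struct is plain old data: a 32-bit token id plus a float bias, carried inside generation_inputs as a fixed array of logit_bias_max entries. A quick sketch of the resulting layout, using a hypothetical ctypes mirror of the struct (the commit's real Python binding appears in koboldcpp.py below):

```python
import ctypes

# Hypothetical mirror of the new C struct, for layout illustration only.
class logit_bias(ctypes.Structure):
    _fields_ = [("token_id", ctypes.c_int32),
                ("bias", ctypes.c_float)]

print(ctypes.sizeof(logit_bias))       # 8 bytes: int32 + float32
print(ctypes.sizeof(logit_bias * 16))  # 128 bytes for the full fixed-size array
```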
gpttype_adapter.cpp (+18 −0)

@@ -101,6 +101,7 @@ static int stopper_unused_tokens = 0;
 static std::mutex concat_output_mtx;
 static std::string concat_output = "";
 static std::string concat_output_reader_copy = "";
+static std::vector<logit_bias> logit_biases;

 const int extra_context_handle_fragmentation = 80;

@@ -489,6 +490,12 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
         candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
     }

+    for(int i=0;i<logit_biases.size();++i)
+    {
+        auto & itm = logit_biases[i];
+        candidates[itm.token_id].logit += itm.bias;
+    }
+
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

     if (grammar != nullptr) {
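The loop simply adds each retained bias to the raw logit of its token before the sampler chain runs, which is the standard logit-bias semantics. A minimal Python sketch of the same arithmetic, with made-up token ids and scores:

```python
# Toy logits, one raw score per vocabulary token (values invented).
logits = [0.1, -2.3, 4.0, 0.7]
# token_id -> bias, as held in the logit_biases vector after validation.
biases = {2: -20.0, 3: 3.2}

for token_id, bias in biases.items():
    logits[token_id] += bias

print(logits)  # [0.1, -2.3, -16.0, 3.9]: token 2 suppressed, token 3 boosted
```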
@@ -1437,6 +1444,17 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
     }

+    logit_biases.clear();
+    for(int x=0;x<logit_bias_max;++x)
+    {
+        int32_t t_id = inputs.logit_biases[x].token_id;
+        float bias = inputs.logit_biases[x].bias;
+        if(t_id >= 0 && t_id < n_vocab && bias!=0)
+        {
+            logit_biases.push_back(inputs.logit_biases[x]);
+        }
+    }
+
     std::string addedmemory = inputs.memory;
     params.prompt = inputs.prompt;
     params.seed = inputs.seed;
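Only entries with an in-vocabulary token id and a nonzero bias survive; unused slots arrive from Python as (-1, 0.0) sentinels and are skipped. The same filter expressed in Python, on invented inputs:

```python
n_vocab = 32000  # illustrative vocabulary size
raw = [(2, -20.0), (-1, 0.0), (40000, 1.0), (145, 0.0)]

# Keep only in-vocab tokens with a nonzero bias, as the C++ loop does.
kept = [(t, b) for (t, b) in raw if 0 <= t < n_vocab and b != 0]
print(kept)  # [(2, -20.0)]
```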
kcpp_docs.embd (+12 −2)

@@ -8,7 +8,7 @@

 <!-- schema -->
 <script>
-let spec = {
+let spec = {
 "components": {
 "schemas": {
 "BasicError": {

@@ -176,7 +176,17 @@
 "default": false,
 "description": "KoboldCpp ONLY. If true, also removes detected stop_sequences from the output and truncates all text after them. Does not work with SSE streaming.",
 "type": "boolean"
-}
+},
+"logit_bias": {
+    "default": {},
+    "description": "KoboldCpp ONLY. A dictionary of key-value pairs, which indicate the token IDs (int) and the logit bias (float) to apply for each token. Up to 16 values can be provided.",
+    "type": "object",
+    "example": {
+        "2": -20,
+        "145": -1.4,
+        "3105": 3.2
+    },
+},
 },
 "required": [
 "prompt"
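With the schema in place, the field can be exercised against a running instance, as the commit message's "test in swagger" suggests. A hedged example against the generate endpoint, assuming KoboldCpp's default local port 5001 (the token ids are the schema's example values, not meaningful for any particular model):

```python
import json
import urllib.request

payload = {
    "prompt": "The quick brown fox",
    "max_length": 32,
    # token_id (as a string key) -> bias; at most 16 entries are honored.
    "logit_bias": {"2": -20, "145": -1.4, "3105": 3.2},
}
req = urllib.request.Request(
    "http://localhost:5001/api/v1/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))
```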
koboldcpp.py (+34 −3)

@@ -18,6 +18,9 @@ sampler_order_max = 7
 stop_token_max = 16
 ban_token_max = 16
 tensor_split_max = 16
+logit_bias_max = 16
+bias_min_value = -100.0
+bias_max_value = 100.0

 class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
@@ -44,6 +47,10 @@ class load_model_inputs(ctypes.Structure):
                 ("banned_tokens", ctypes.c_char_p * ban_token_max),
                 ("tensor_split", ctypes.c_float * tensor_split_max)]

+class logit_bias(ctypes.Structure):
+    _fields_ = [("token_id", ctypes.c_int32),
+                ("bias", ctypes.c_float)]
+
 class generation_inputs(ctypes.Structure):
     _fields_ = [("seed", ctypes.c_int),
                 ("prompt", ctypes.c_char_p),
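These definitions let Python build the exact array the C side expects. A small usage sketch (values illustrative):

```python
import ctypes

class logit_bias(ctypes.Structure):
    _fields_ = [("token_id", ctypes.c_int32),
                ("bias", ctypes.c_float)]

arr = (logit_bias * 16)()                # zero-initialized fixed array
arr[0] = logit_bias(3105, 3.2)           # one active bias
print(arr[0].token_id, arr[1].token_id)  # 3105 0
```

Note that an untouched slot defaults to token_id 0, which is a valid token; that is why generate() below explicitly fills unused slots with the -1 sentinel instead of relying on zero-initialization.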
@@ -70,7 +77,8 @@ class generation_inputs(ctypes.Structure):
                 ("stream_sse", ctypes.c_bool),
                 ("grammar", ctypes.c_char_p),
                 ("grammar_retain_state", ctypes.c_bool),
-                ("quiet", ctypes.c_bool)]
+                ("quiet", ctypes.c_bool),
+                ("logit_biases", logit_bias * logit_bias_max)]

 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
@@ -301,7 +309,7 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret

-def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False):
+def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, logit_biases={}):
     global maxctx, args, currentusergenkey, totalgens
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
@@ -355,6 +363,28 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
             inputs.stop_sequence[n] = "".encode("UTF-8")
         else:
             inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
+
+    bias_list = []
+    try:
+        if logit_biases and len(logit_biases) > 0:
+            bias_list = [{"key": key, "value": value} for key, value in logit_biases.items()]
+    except Exception as ex:
+        print(f"Logit bias dictionary is invalid: {ex}")
+
+    for n in range(logit_bias_max):
+        if n >= len(bias_list):
+            inputs.logit_biases[n] = logit_bias(-1, 0.0)
+        else:
+            try:
+                t_id = int(bias_list[n]['key'])
+                bias = float(bias_list[n]['value'])
+                t_id = -1 if t_id < 0 else t_id
+                bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
+                inputs.logit_biases[n] = logit_bias(t_id, bias)
+            except Exception as ex:
+                inputs.logit_biases[n] = logit_bias(-1, 0.0)
+                print(f"Skipped unparsable logit bias:{ex}")
+
     currentusergenkey = genkey
     totalgens += 1
     ret = handle.generate(inputs,outputs)
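Concretely, here is what the loop above produces for a small input dict, including the clamp to [bias_min_value, bias_max_value] and the -1 sentinel fill (a walk-through of the commit's logic, not code from the commit):

```python
# Input as received from JSON: string keys, numeric values.
logit_biases = {"2": -20, "145": 250.0}
bias_list = [{"key": k, "value": v} for k, v in logit_biases.items()]
# Resulting ctypes array slots:
#   slot 0      -> logit_bias(2, -20.0)
#   slot 1      -> logit_bias(145, 100.0)  # 250.0 clamped to bias_max_value
#   slots 2..15 -> logit_bias(-1, 0.0)     # sentinels, filtered out on the C++ side
```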
@@ -515,7 +545,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 grammar_retain_state = genparams.get('grammar_retain_state', False),
                 genkey=genparams.get('genkey', ''),
                 trimstop=genparams.get('trim_stop', False),
-                quiet=is_quiet)
+                quiet=is_quiet,
+                logit_biases=genparams.get('logit_bias', {}))

         recvtxt = ""
         if stream_flag:
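Since the point of the change is OpenAI-style compatibility, the same dictionary shape can also be sent through KoboldCpp's OpenAI-compatible completions route. A hedged sketch, assuming the default local port and the /v1/completions path of the OpenAI emulation layer (the model field is a placeholder, as KoboldCpp serves whatever model is loaded):

```python
import json
import urllib.request

payload = {
    "model": "koboldcpp",        # placeholder; the loaded model is used regardless
    "prompt": "Once upon a time",
    "max_tokens": 32,
    "logit_bias": {"2": -20},    # token_id -> bias, OpenAI-style
}
req = urllib.request.Request(
    "http://localhost:5001/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))
```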