Add logit_bias to the OpenAI api (#577)
* Add logit_bias to the OpenAI api * Cleanup and refactor, test in swagger. --------- Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
parent
5006b23099
commit
e733a9e425
4 changed files with 70 additions and 5 deletions
6
expose.h
6
expose.h
|
@ -3,6 +3,7 @@
|
||||||
const int stop_token_max = 16;
|
const int stop_token_max = 16;
|
||||||
const int ban_token_max = 16;
|
const int ban_token_max = 16;
|
||||||
const int tensor_split_max = 16;
|
const int tensor_split_max = 16;
|
||||||
|
const int logit_bias_max = 16;
|
||||||
// match kobold's sampler list and order
|
// match kobold's sampler list and order
|
||||||
enum samplers
|
enum samplers
|
||||||
{
|
{
|
||||||
|
@ -22,6 +23,10 @@ enum stop_reason
|
||||||
EOS_TOKEN=1,
|
EOS_TOKEN=1,
|
||||||
CUSTOM_STOPPER=2,
|
CUSTOM_STOPPER=2,
|
||||||
};
|
};
|
||||||
|
struct logit_bias {
|
||||||
|
int32_t token_id;
|
||||||
|
float bias;
|
||||||
|
};
|
||||||
struct load_model_inputs
|
struct load_model_inputs
|
||||||
{
|
{
|
||||||
const int threads;
|
const int threads;
|
||||||
|
@ -76,6 +81,7 @@ struct generation_inputs
|
||||||
const char * grammar;
|
const char * grammar;
|
||||||
const bool grammar_retain_state;
|
const bool grammar_retain_state;
|
||||||
const bool quiet = false;
|
const bool quiet = false;
|
||||||
|
const logit_bias logit_biases[logit_bias_max];
|
||||||
};
|
};
|
||||||
struct generation_outputs
|
struct generation_outputs
|
||||||
{
|
{
|
||||||
|
|
|
@ -101,6 +101,7 @@ static int stopper_unused_tokens = 0;
|
||||||
static std::mutex concat_output_mtx;
|
static std::mutex concat_output_mtx;
|
||||||
static std::string concat_output = "";
|
static std::string concat_output = "";
|
||||||
static std::string concat_output_reader_copy = "";
|
static std::string concat_output_reader_copy = "";
|
||||||
|
static std::vector<logit_bias> logit_biases;
|
||||||
|
|
||||||
const int extra_context_handle_fragmentation = 80;
|
const int extra_context_handle_fragmentation = 80;
|
||||||
|
|
||||||
|
@ -489,6 +490,12 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
|
||||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for(int i=0;i<logit_biases.size();++i)
|
||||||
|
{
|
||||||
|
auto & itm = logit_biases[i];
|
||||||
|
candidates[itm.token_id].logit += itm.bias;
|
||||||
|
}
|
||||||
|
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
if (grammar != nullptr) {
|
if (grammar != nullptr) {
|
||||||
|
@ -1437,6 +1444,17 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logit_biases.clear();
|
||||||
|
for(int x=0;x<logit_bias_max;++x)
|
||||||
|
{
|
||||||
|
int32_t t_id = inputs.logit_biases[x].token_id;
|
||||||
|
float bias = inputs.logit_biases[x].bias;
|
||||||
|
if(t_id >= 0 && t_id < n_vocab && bias!=0)
|
||||||
|
{
|
||||||
|
logit_biases.push_back(inputs.logit_biases[x]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string addedmemory = inputs.memory;
|
std::string addedmemory = inputs.memory;
|
||||||
params.prompt = inputs.prompt;
|
params.prompt = inputs.prompt;
|
||||||
params.seed = inputs.seed;
|
params.seed = inputs.seed;
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
|
|
||||||
<!-- schema -->
|
<!-- schema -->
|
||||||
<script>
|
<script>
|
||||||
let spec = {
|
let spec = {
|
||||||
"components": {
|
"components": {
|
||||||
"schemas": {
|
"schemas": {
|
||||||
"BasicError": {
|
"BasicError": {
|
||||||
|
@ -176,7 +176,17 @@
|
||||||
"default": false,
|
"default": false,
|
||||||
"description": "KoboldCpp ONLY. If true, also removes detected stop_sequences from the output and truncates all text after them. Does not work with SSE streaming.",
|
"description": "KoboldCpp ONLY. If true, also removes detected stop_sequences from the output and truncates all text after them. Does not work with SSE streaming.",
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
}
|
},
|
||||||
|
"logit_bias": {
|
||||||
|
"default": {},
|
||||||
|
"description": "KoboldCpp ONLY. An dictionary of key-value pairs, which indicate the token IDs (int) and logit bias (float) to apply for that token. Up to 16 value can be provided.",
|
||||||
|
"type": "object",
|
||||||
|
"example": {
|
||||||
|
"2": -20,
|
||||||
|
"145": -1.4,
|
||||||
|
"3105": 3.2
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": [
|
"required": [
|
||||||
"prompt"
|
"prompt"
|
||||||
|
|
37
koboldcpp.py
37
koboldcpp.py
|
@ -18,6 +18,9 @@ sampler_order_max = 7
|
||||||
stop_token_max = 16
|
stop_token_max = 16
|
||||||
ban_token_max = 16
|
ban_token_max = 16
|
||||||
tensor_split_max = 16
|
tensor_split_max = 16
|
||||||
|
logit_bias_max = 16
|
||||||
|
bias_min_value = -100.0
|
||||||
|
bias_max_value = 100.0
|
||||||
|
|
||||||
class load_model_inputs(ctypes.Structure):
|
class load_model_inputs(ctypes.Structure):
|
||||||
_fields_ = [("threads", ctypes.c_int),
|
_fields_ = [("threads", ctypes.c_int),
|
||||||
|
@ -44,6 +47,10 @@ class load_model_inputs(ctypes.Structure):
|
||||||
("banned_tokens", ctypes.c_char_p * ban_token_max),
|
("banned_tokens", ctypes.c_char_p * ban_token_max),
|
||||||
("tensor_split", ctypes.c_float * tensor_split_max)]
|
("tensor_split", ctypes.c_float * tensor_split_max)]
|
||||||
|
|
||||||
|
class logit_bias(ctypes.Structure):
|
||||||
|
_fields_ = [("token_id", ctypes.c_int32),
|
||||||
|
("bias", ctypes.c_float)]
|
||||||
|
|
||||||
class generation_inputs(ctypes.Structure):
|
class generation_inputs(ctypes.Structure):
|
||||||
_fields_ = [("seed", ctypes.c_int),
|
_fields_ = [("seed", ctypes.c_int),
|
||||||
("prompt", ctypes.c_char_p),
|
("prompt", ctypes.c_char_p),
|
||||||
|
@ -70,7 +77,8 @@ class generation_inputs(ctypes.Structure):
|
||||||
("stream_sse", ctypes.c_bool),
|
("stream_sse", ctypes.c_bool),
|
||||||
("grammar", ctypes.c_char_p),
|
("grammar", ctypes.c_char_p),
|
||||||
("grammar_retain_state", ctypes.c_bool),
|
("grammar_retain_state", ctypes.c_bool),
|
||||||
("quiet", ctypes.c_bool)]
|
("quiet", ctypes.c_bool),
|
||||||
|
("logit_biases", logit_bias * logit_bias_max)]
|
||||||
|
|
||||||
class generation_outputs(ctypes.Structure):
|
class generation_outputs(ctypes.Structure):
|
||||||
_fields_ = [("status", ctypes.c_int),
|
_fields_ = [("status", ctypes.c_int),
|
||||||
|
@ -301,7 +309,7 @@ def load_model(model_filename):
|
||||||
ret = handle.load_model(inputs)
|
ret = handle.load_model(inputs)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False):
|
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, logit_biases={}):
|
||||||
global maxctx, args, currentusergenkey, totalgens
|
global maxctx, args, currentusergenkey, totalgens
|
||||||
inputs = generation_inputs()
|
inputs = generation_inputs()
|
||||||
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
||||||
|
@ -355,6 +363,28 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
|
||||||
inputs.stop_sequence[n] = "".encode("UTF-8")
|
inputs.stop_sequence[n] = "".encode("UTF-8")
|
||||||
else:
|
else:
|
||||||
inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
|
inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
|
||||||
|
|
||||||
|
bias_list = []
|
||||||
|
try:
|
||||||
|
if logit_biases and len(logit_biases) > 0:
|
||||||
|
bias_list = [{"key": key, "value": value} for key, value in logit_biases.items()]
|
||||||
|
except Exception as ex:
|
||||||
|
print(f"Logit bias dictionary is invalid: {ex}")
|
||||||
|
|
||||||
|
for n in range(logit_bias_max):
|
||||||
|
if n >= len(bias_list):
|
||||||
|
inputs.logit_biases[n] = logit_bias(-1, 0.0)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
t_id = int(bias_list[n]['key'])
|
||||||
|
bias = float(bias_list[n]['value'])
|
||||||
|
t_id = -1 if t_id < 0 else t_id
|
||||||
|
bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
|
||||||
|
inputs.logit_biases[n] = logit_bias(t_id, bias)
|
||||||
|
except Exception as ex:
|
||||||
|
inputs.logit_biases[n] = logit_bias(-1, 0.0)
|
||||||
|
print(f"Skipped unparsable logit bias:{ex}")
|
||||||
|
|
||||||
currentusergenkey = genkey
|
currentusergenkey = genkey
|
||||||
totalgens += 1
|
totalgens += 1
|
||||||
ret = handle.generate(inputs,outputs)
|
ret = handle.generate(inputs,outputs)
|
||||||
|
@ -515,7 +545,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
grammar_retain_state = genparams.get('grammar_retain_state', False),
|
grammar_retain_state = genparams.get('grammar_retain_state', False),
|
||||||
genkey=genparams.get('genkey', ''),
|
genkey=genparams.get('genkey', ''),
|
||||||
trimstop=genparams.get('trim_stop', False),
|
trimstop=genparams.get('trim_stop', False),
|
||||||
quiet=is_quiet)
|
quiet=is_quiet,
|
||||||
|
logit_biases=genparams.get('logit_bias', {}))
|
||||||
|
|
||||||
recvtxt = ""
|
recvtxt = ""
|
||||||
if stream_flag:
|
if stream_flag:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue