Add logit_bias to the OpenAI api (#577)

* Add logit_bias to the OpenAI api

* Cleanup and refactor, test in swagger.

---------

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
This commit is contained in:
DebuggingLife46 2023-12-26 21:56:19 +05:30 committed by GitHub
parent 5006b23099
commit e733a9e425
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 5 deletions

View file

@ -3,6 +3,7 @@
const int stop_token_max = 16; const int stop_token_max = 16;
const int ban_token_max = 16; const int ban_token_max = 16;
const int tensor_split_max = 16; const int tensor_split_max = 16;
const int logit_bias_max = 16;
// match kobold's sampler list and order // match kobold's sampler list and order
enum samplers enum samplers
{ {
@ -22,6 +23,10 @@ enum stop_reason
EOS_TOKEN=1, EOS_TOKEN=1,
CUSTOM_STOPPER=2, CUSTOM_STOPPER=2,
}; };
struct logit_bias {
int32_t token_id;
float bias;
};
struct load_model_inputs struct load_model_inputs
{ {
const int threads; const int threads;
@ -76,6 +81,7 @@ struct generation_inputs
const char * grammar; const char * grammar;
const bool grammar_retain_state; const bool grammar_retain_state;
const bool quiet = false; const bool quiet = false;
const logit_bias logit_biases[logit_bias_max];
}; };
struct generation_outputs struct generation_outputs
{ {

View file

@ -101,6 +101,7 @@ static int stopper_unused_tokens = 0;
static std::mutex concat_output_mtx; static std::mutex concat_output_mtx;
static std::string concat_output = ""; static std::string concat_output = "";
static std::string concat_output_reader_copy = ""; static std::string concat_output_reader_copy = "";
static std::vector<logit_bias> logit_biases;
const int extra_context_handle_fragmentation = 80; const int extra_context_handle_fragmentation = 80;
@ -489,6 +490,12 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
} }
for(int i=0;i<logit_biases.size();++i)
{
auto & itm = logit_biases[i];
candidates[itm.token_id].logit += itm.bias;
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
if (grammar != nullptr) { if (grammar != nullptr) {
@ -1437,6 +1444,17 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
} }
} }
logit_biases.clear();
for(int x=0;x<logit_bias_max;++x)
{
int32_t t_id = inputs.logit_biases[x].token_id;
float bias = inputs.logit_biases[x].bias;
if(t_id >= 0 && t_id < n_vocab && bias!=0)
{
logit_biases.push_back(inputs.logit_biases[x]);
}
}
std::string addedmemory = inputs.memory; std::string addedmemory = inputs.memory;
params.prompt = inputs.prompt; params.prompt = inputs.prompt;
params.seed = inputs.seed; params.seed = inputs.seed;

View file

@ -176,7 +176,17 @@
"default": false, "default": false,
"description": "KoboldCpp ONLY. If true, also removes detected stop_sequences from the output and truncates all text after them. Does not work with SSE streaming.", "description": "KoboldCpp ONLY. If true, also removes detected stop_sequences from the output and truncates all text after them. Does not work with SSE streaming.",
"type": "boolean" "type": "boolean"
} },
"logit_bias": {
"default": {},
"description": "KoboldCpp ONLY. An dictionary of key-value pairs, which indicate the token IDs (int) and logit bias (float) to apply for that token. Up to 16 value can be provided.",
"type": "object",
"example": {
"2": -20,
"145": -1.4,
"3105": 3.2
},
},
}, },
"required": [ "required": [
"prompt" "prompt"

View file

@ -18,6 +18,9 @@ sampler_order_max = 7
stop_token_max = 16 stop_token_max = 16
ban_token_max = 16 ban_token_max = 16
tensor_split_max = 16 tensor_split_max = 16
logit_bias_max = 16
bias_min_value = -100.0
bias_max_value = 100.0
class load_model_inputs(ctypes.Structure): class load_model_inputs(ctypes.Structure):
_fields_ = [("threads", ctypes.c_int), _fields_ = [("threads", ctypes.c_int),
@ -44,6 +47,10 @@ class load_model_inputs(ctypes.Structure):
("banned_tokens", ctypes.c_char_p * ban_token_max), ("banned_tokens", ctypes.c_char_p * ban_token_max),
("tensor_split", ctypes.c_float * tensor_split_max)] ("tensor_split", ctypes.c_float * tensor_split_max)]
class logit_bias(ctypes.Structure):
_fields_ = [("token_id", ctypes.c_int32),
("bias", ctypes.c_float)]
class generation_inputs(ctypes.Structure): class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int), _fields_ = [("seed", ctypes.c_int),
("prompt", ctypes.c_char_p), ("prompt", ctypes.c_char_p),
@ -70,7 +77,8 @@ class generation_inputs(ctypes.Structure):
("stream_sse", ctypes.c_bool), ("stream_sse", ctypes.c_bool),
("grammar", ctypes.c_char_p), ("grammar", ctypes.c_char_p),
("grammar_retain_state", ctypes.c_bool), ("grammar_retain_state", ctypes.c_bool),
("quiet", ctypes.c_bool)] ("quiet", ctypes.c_bool),
("logit_biases", logit_bias * logit_bias_max)]
class generation_outputs(ctypes.Structure): class generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int), _fields_ = [("status", ctypes.c_int),
@ -301,7 +309,7 @@ def load_model(model_filename):
ret = handle.load_model(inputs) ret = handle.load_model(inputs)
return ret return ret
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False): def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, logit_biases={}):
global maxctx, args, currentusergenkey, totalgens global maxctx, args, currentusergenkey, totalgens
inputs = generation_inputs() inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs)) outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
@ -355,6 +363,28 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
inputs.stop_sequence[n] = "".encode("UTF-8") inputs.stop_sequence[n] = "".encode("UTF-8")
else: else:
inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8") inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
bias_list = []
try:
if logit_biases and len(logit_biases) > 0:
bias_list = [{"key": key, "value": value} for key, value in logit_biases.items()]
except Exception as ex:
print(f"Logit bias dictionary is invalid: {ex}")
for n in range(logit_bias_max):
if n >= len(bias_list):
inputs.logit_biases[n] = logit_bias(-1, 0.0)
else:
try:
t_id = int(bias_list[n]['key'])
bias = float(bias_list[n]['value'])
t_id = -1 if t_id < 0 else t_id
bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
inputs.logit_biases[n] = logit_bias(t_id, bias)
except Exception as ex:
inputs.logit_biases[n] = logit_bias(-1, 0.0)
print(f"Skipped unparsable logit bias:{ex}")
currentusergenkey = genkey currentusergenkey = genkey
totalgens += 1 totalgens += 1
ret = handle.generate(inputs,outputs) ret = handle.generate(inputs,outputs)
@ -515,7 +545,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
grammar_retain_state = genparams.get('grammar_retain_state', False), grammar_retain_state = genparams.get('grammar_retain_state', False),
genkey=genparams.get('genkey', ''), genkey=genparams.get('genkey', ''),
trimstop=genparams.get('trim_stop', False), trimstop=genparams.get('trim_stop', False),
quiet=is_quiet) quiet=is_quiet,
logit_biases=genparams.get('logit_bias', {}))
recvtxt = "" recvtxt = ""
if stream_flag: if stream_flag: