Stuff with logits
This commit is contained in:
parent
0bc047730f
commit
88cc7bb6f7
1 changed file with 12 additions and 4 deletions
|
@@ -538,7 +538,10 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
|
|||
}
|
||||
|
||||
json format_generation_settings(llama_server_context &llama) {
|
||||
const bool ignore_eos = -INFINITY == llama.params.logit_bias[llama_token_eos()];
|
||||
const auto eos_bias = llama.params.logit_bias.find(llama_token_eos());
|
||||
const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
|
||||
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
||||
|
||||
return json {
|
||||
{ "seed", llama.params.seed },
|
||||
{ "temp", llama.params.temp },
|
||||
|
@@ -659,10 +662,15 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
|
|||
if (body["logit_bias"].is_array()) {
|
||||
int n_vocab = llama_n_vocab(llama.ctx);
|
||||
for (const auto &el : body["logit_bias"]) {
|
||||
if (el.is_array() && el.size() == 2 && el[0].is_number_integer() && el[1].is_number_float()) {
|
||||
if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) {
|
||||
llama_token tok = el[0].get<llama_token>();
|
||||
if (tok < 0 || tok >= n_vocab) continue;
|
||||
llama.params.logit_bias[tok] = el[1].get<float>();
|
||||
if (tok >= 0 && tok < n_vocab) {
|
||||
if (el[1].is_number_float()) {
|
||||
llama.params.logit_bias[tok] = el[1].get<float>();
|
||||
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
|
||||
llama.params.logit_bias[tok] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue