diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7f287e1c7..fd11798a2 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -533,6 +533,7 @@ json format_generation_settings(llama_server_context &llama) {
     const auto eos_bias = llama.params.logit_bias.find(llama_token_eos());
     const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);
+
     return json {
         { "seed", llama.params.seed },
         { "temp", llama.params.temp },
@@ -653,10 +654,15 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
     if (body["logit_bias"].is_array()) {
         int n_vocab = llama_n_vocab(llama.ctx);
         for (const auto &el : body["logit_bias"]) {
-            if (el.is_array() && el.size() == 2 && el[0].is_number_integer() && el[1].is_number_float()) {
+            if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) {
                 llama_token tok = el[0].get<llama_token>();
-                if (tok < 0 || tok >= n_vocab) continue;
-                llama.params.logit_bias[tok] = el[1].get<float>();
+                if (tok >= 0 && tok < n_vocab) {
+                    if (el[1].is_number_float()) {
+                        llama.params.logit_bias[tok] = el[1].get<float>();
+                    } else if (el[1].is_boolean() && !el[1].get<bool>()) {
+                        llama.params.logit_bias[tok] = -INFINITY;
+                    }
+                }
             }
         }
     }
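
With this change, each `logit_bias` entry pairs a token id with either a float bias or the boolean `false`; `false` maps the token's bias to `-INFINITY`, effectively banning it from sampling. A minimal sketch of a completion request body exercising both branches (the token ids 15043 and 2 are illustrative placeholders, not guaranteed vocabulary entries for any particular model):

    {
        "prompt": "Hello",
        "logit_bias": [
            [15043, 1.5],
            [2, false]
        ]
    }

Entries whose second element is neither a float nor `false` (e.g. `true` or a string) are silently skipped, as are out-of-range or non-integer token ids, so malformed entries cannot corrupt `llama.params.logit_bias`.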