From 88cc7bb6f7f581ef00fd0c753881cb3227caabc0 Mon Sep 17 00:00:00 2001
From: Henri Vasserman
Date: Fri, 2 Jun 2023 13:29:57 +0300
Subject: [PATCH] Stuff with logits

---
 examples/server/server.cpp | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index afa52a286..2021130d5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -538,7 +538,10 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
 }
 
 json format_generation_settings(llama_server_context &llama) {
-    const bool ignore_eos = -INFINITY == llama.params.logit_bias[llama_token_eos()];
+    const auto eos_bias = llama.params.logit_bias.find(llama_token_eos());
+    const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
+        eos_bias->second < 0.0f && std::isinf(eos_bias->second);
+
     return json {
         { "seed", llama.params.seed },
         { "temp", llama.params.temp },
@@ -659,10 +662,15 @@ bool parse_options_completion(json body, llama_server_context& llama, Response &
     if (body["logit_bias"].is_array()) {
         int n_vocab = llama_n_vocab(llama.ctx);
         for (const auto &el : body["logit_bias"]) {
-            if (el.is_array() && el.size() == 2 && el[0].is_number_integer() && el[1].is_number_float()) {
+            if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) {
                 llama_token tok = el[0].get<llama_token>();
-                if (tok < 0 || tok >= n_vocab) continue;
-                llama.params.logit_bias[tok] = el[1].get<float>();
+                if (tok >= 0 && tok < n_vocab) {
+                    if (el[1].is_number_float()) {
+                        llama.params.logit_bias[tok] = el[1].get<float>();
+                    } else if (el[1].is_boolean() && !el[1].get<bool>()) {
+                        llama.params.logit_bias[tok] = -INFINITY;
+                    }
+                }
             }
         }
     }