From e08100c8513241faefd9d2deb0d3a42540ffa667 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 10 Aug 2024 14:55:39 +0300
Subject: [PATCH] cont : simplify logit_bias + add ignore_eos flag

ggml-ci
---
 common/common.cpp          | 22 +++++++++-------------
 common/common.h            |  1 -
 common/sampling.cpp        |  8 ++++++--
 common/sampling.h          |  4 ++--
 examples/server/server.cpp | 15 ++++++---------
 include/llama.h            | 11 +++++++++++
 src/llama.cpp              |  4 ++++
 7 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index c56246b8f..eb1f31e3d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1013,7 +1013,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--ignore-eos") {
-        params.ignore_eos = true;
+        sparams.ignore_eos = true;
         return true;
     }
     if (arg == "--penalize-nl") {
@@ -1028,7 +1028,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
        std::string value_str;
        try {
            if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
-               sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+               const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+               sparams.logit_bias.push_back({key, bias});
            }
            else {
                throw std::exception();
@@ -2133,8 +2134,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
        llama_lora_adapters_apply(lctx, iparams.lora_adapters);
    }

-   if (params.ignore_eos) {
-       params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+   if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
+       fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+       params.sparams.ignore_eos = false;
    }

    if (params.warmup) {
@@ -3179,10 +3181,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
    fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
    fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
    fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
-
-   const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
-   const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
-   fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
+   fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");
    yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
    fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
@@ -3193,11 +3192,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
    fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());

    fprintf(stream, "logit_bias:\n");
-   for (std::pair<llama_token, float> lb : sparams.logit_bias) {
-       if (ignore_eos && lb.first == logit_bias_eos->first) {
-           continue;
-       }
-       fprintf(stream, " %d: %f", lb.first, lb.second);
+   for (const auto & logit_bias : sparams.logit_bias) {
+       fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias);
    }

    fprintf(stream, "lora:\n");
diff --git a/common/common.h b/common/common.h
index acb4c95ed..d5ffecb4b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -172,7 +172,6 @@ struct gpt_params {
    bool flash_attn       = false; // flash attention

    bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
-   bool ignore_eos       = false; // ignore generated EOS tokens
    bool logits_all       = false; // return logits for all tokens in the batch
    bool use_mmap         = true;  // use mmap for faster loads
    bool use_mlock        = false; // use mlock to keep model in memory
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 4d57392c8..0b07ad01b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -332,8 +332,12 @@ static llama_token_data_array llama_sampling_prepare_impl(
    }

    // apply params.logit_bias map
-   for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-       logits[it->first] += it->second;
+   for (const auto & logit_bias : params.logit_bias) {
+       logits[logit_bias.token] += logit_bias.bias;
+   }
+
+   if (params.ignore_eos) {
+       logits[llama_token_eos(llama_get_model(ctx_main))] = -INFINITY;
    }

    if (ctx_cfg) {
diff --git a/common/sampling.h b/common/sampling.h
index 59158ae85..f3ffa090f 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -3,7 +3,6 @@
 #include "llama.h"

 #include <string>
-#include <unordered_map>
 #include <vector>

 // sampler types
@@ -37,6 +36,7 @@ typedef struct gpt_sampling_params {
    float    mirostat_tau = 5.00f; // target entropy
    float    mirostat_eta = 0.10f; // learning rate
    bool     penalize_nl  = false; // consider newlines as a repeatable token
+   bool     ignore_eos   = false;
    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context

    std::vector<llama_sampler_type> samplers_sequence = {
@@ -55,7 +55,7 @@
    std::string cfg_negative_prompt; // string to help guidance
    float       cfg_scale = 1.f;     // how strong is guidance

-   std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+   std::vector<llama_logit_bias> logit_bias; // logit biases to apply

    std::vector<llama_token> penalty_prompt_tokens;
    bool use_penalty_prompt_tokens = false;
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7899ccb75..18e66e6c8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1035,7 +1035,7 @@ struct server_context {
        slot.sparams.logit_bias.clear();

        if (json_value(data, "ignore_eos", false) && has_eos_token) {
-           slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+           slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
        }

        const auto & logit_bias = data.find("logit_bias");
@@ -1056,12 +1056,12 @@ struct server_context {
                    if (el[0].is_number_integer()) {
                        llama_token tok = el[0].get<llama_token>();
                        if (tok >= 0 && tok < n_vocab) {
-                           slot.sparams.logit_bias[tok] = bias;
+                           slot.sparams.logit_bias.push_back({tok, bias});
                        }
                    } else if (el[0].is_string()) {
                        auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
                        for (auto tok : toks) {
-                           slot.sparams.logit_bias[tok] = bias;
+                           slot.sparams.logit_bias.push_back({tok, bias});
                        }
                    }
                }
@@ -1323,9 +1323,6 @@ struct server_context {
    }

    json get_formated_generation(const server_slot & slot) const {
-       const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
-       const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-
        std::vector<std::string> samplers_sequence;
        samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
        for (const auto & sampler_type : slot.sparams.samplers_sequence) {
@@ -1359,13 +1356,13 @@ struct server_context {
            {"n_predict",  slot.params.n_predict}, // TODO: fix duplicate key n_predict
            {"n_keep",     slot.params.n_keep},
            {"n_discard",  slot.params.n_discard},
-           {"ignore_eos", ignore_eos},
+           {"ignore_eos", slot.sparams.ignore_eos},
            {"stream",     slot.params.stream},
-           {"logit_bias", slot.sparams.logit_bias},
+         //{"logit_bias", slot.sparams.logit_bias},
            {"n_probs",    slot.sparams.n_probs},
            {"min_keep",   slot.sparams.min_keep},
            {"grammar",    slot.sparams.grammar},
-           {"samplers",   samplers_sequence}
+           {"samplers",   samplers_sequence},
        };
    }

diff --git a/include/llama.h b/include/llama.h
index c80405f7f..2df6c6ca7 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -356,6 +356,11 @@ extern "C" {
        void * kv_overrides; // pointer to vector containing overrides
    } llama_model_quantize_params;

+   typedef struct llama_logit_bias {
+       llama_token token;
+       float       bias;
+   } llama_logit_bias;
+
    // parameters for sampling the logits
    typedef struct llama_sampling_params {
        uint32_t seed; // the seed used to initialize llama_sampling_context
@@ -378,6 +383,12 @@ extern "C" {
        float   mirostat_tau; // target entropy
        float   mirostat_eta; // learning rate
        bool    penalize_nl;  // consider newlines as a repeatable token
+       bool    ignore_eos;   // ignore the end-of-sequence token
+
+       const char * grammar;
+
+       int32_t n_logit_bias;
+       const llama_logit_bias * logit_bias;
    } llama_sampling_params;

    // performance timing information
diff --git a/src/llama.cpp b/src/llama.cpp
index 0eb3aea71..e2e265228 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16512,6 +16512,10 @@ struct llama_sampling_params llama_sampling_default_params() {
        /*.mirostat_tau =*/ 5.00f,
        /*.mirostat_eta =*/ 0.10f,
        /*.penalize_nl  =*/ false,
+       /*.ignore_eos   =*/ false,
+       /*.grammar      =*/ nullptr,
+       /*.n_logit_bias =*/ 0,
+       /*.logit_bias   =*/ nullptr,
    };

    return result;
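
Usage note (editorial, not part of the patch): a minimal sketch of how the new gpt_sampling_params fields added above might be set by calling code. It assumes common/sampling.h from this patch is on the include path; the token id used is an arbitrary placeholder.

    // Sketch only: gpt_sampling_params and llama_logit_bias as introduced in this patch.
    #include "sampling.h"

    int main() {
        gpt_sampling_params sparams;

        // EOS logits are now forced to -INFINITY inside llama_sampling_prepare_impl.
        sparams.ignore_eos = true;

        // Same {token, bias} shape used by the --logit-bias handling above;
        // 15043 is a hypothetical token id, not taken from the patch.
        sparams.logit_bias.push_back({15043, -1.5f});

        return 0;
    }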