cont : simplify logit_bias + add ignore_eos flag

ggml-ci
Georgi Gerganov 2024-08-10 14:55:39 +03:00
parent 6b7103cccd
commit e08100c851
7 changed files with 38 additions and 27 deletions
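In short: per-token logit biases move from an unordered_map keyed by token id to a flat vector of {token, bias} records (matching the new llama_logit_bias struct introduced further down), and the old trick of forcing an -INFINITY bias on the EOS token is replaced by an explicit ignore_eos sampling flag. A rough before/after sketch of the container change, not taken verbatim from any single file:

    // before: one bias per token id, a later assignment overwrites an earlier one
    std::unordered_map<llama_token, float> logit_bias;
    logit_bias[token] = bias;

    // after: an ordered list of entries, each applied on its own
    std::vector<llama_logit_bias> logit_bias; // llama_logit_bias == { llama_token token; float bias; }
    logit_bias.push_back({token, bias});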

View file

@@ -1013,7 +1013,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
return true;
}
if (arg == "--ignore-eos") {
params.ignore_eos = true;
sparams.ignore_eos = true;
return true;
}
if (arg == "--penalize-nl") {
@@ -1028,7 +1028,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
std::string value_str;
try {
if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
sparams.logit_bias.push_back({key, bias});
}
else {
throw std::exception();
@@ -2133,8 +2134,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
llama_lora_adapters_apply(lctx, iparams.lora_adapters);
}
if (params.ignore_eos) {
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sparams.ignore_eos = false;
}
if (params.warmup) {
@@ -3179,10 +3181,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");
yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
@@ -3193,11 +3192,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
fprintf(stream, "logit_bias:\n");
for (std::pair<llama_token, float> lb : sparams.logit_bias) {
if (ignore_eos && lb.first == logit_bias_eos->first) {
continue;
}
fprintf(stream, " %d: %f", lb.first, lb.second);
for (const auto & logit_bias : sparams.logit_bias) {
fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias);
}
fprintf(stream, "lora:\n");

View file

@@ -172,7 +172,6 @@ struct gpt_params {
bool flash_attn = false; // flash attention
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory

View file

@@ -332,8 +332,12 @@ static llama_token_data_array llama_sampling_prepare_impl(
}
// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
for (const auto & logit_bias : params.logit_bias) {
logits[logit_bias.token] += logit_bias.bias;
}
if (params.ignore_eos) {
logits[llama_token_eos(llama_get_model(ctx_main))] = -INFINITY;
}
if (ctx_cfg) {
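One behavioural detail worth noting: each vector entry is applied with +=, so repeated entries for the same token now accumulate, whereas the old map kept only the last assignment. A tiny sketch of the accumulation (the token id is made up):

    // both entries are applied at sampling time, for a net bias of +2.0f on token 15043
    sparams.logit_bias.push_back({15043, 0.5f});
    sparams.logit_bias.push_back({15043, 1.5f});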

View file

@@ -3,7 +3,6 @@
#include "llama.h"
#include <string>
#include <unordered_map>
#include <vector>
// sampler types
@@ -37,6 +36,7 @@ typedef struct gpt_sampling_params {
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
bool ignore_eos = false;
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
std::vector<llama_sampler_type> samplers_sequence = {
@@ -55,7 +55,7 @@ typedef struct gpt_sampling_params {
std::string cfg_negative_prompt; // string to help guidance
float cfg_scale = 1.f; // how strong is guidance
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_token> penalty_prompt_tokens;
bool use_penalty_prompt_tokens = false;
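Together with the removal of the flag from struct gpt_params above, this moves ignore_eos next to the other sampling options, so callers now reach it through the nested sampling params. A one-line sketch of the caller-side change (names as used in the llama_init_from_gpt_params hunk):

    gpt_params params;
    // before: params.ignore_eos = true;
    params.sparams.ignore_eos = true; // the flag now lives in gpt_sampling_params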

View file

@@ -1035,7 +1035,7 @@ struct server_context {
slot.sparams.logit_bias.clear();
if (json_value(data, "ignore_eos", false) && has_eos_token) {
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
}
const auto & logit_bias = data.find("logit_bias");
@@ -1056,12 +1056,12 @@
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
slot.sparams.logit_bias[tok] = bias;
slot.sparams.logit_bias.push_back({tok, bias});
}
} else if (el[0].is_string()) {
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) {
slot.sparams.logit_bias[tok] = bias;
slot.sparams.logit_bias.push_back({tok, bias});
}
}
}
@@ -1323,9 +1323,6 @@
}
json get_formated_generation(const server_slot & slot) const {
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
std::vector<std::string> samplers_sequence;
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
for (const auto & sampler_type : slot.sparams.samplers_sequence) {
@@ -1359,13 +1356,13 @@
{"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
{"ignore_eos", ignore_eos},
{"ignore_eos", slot.sparams.ignore_eos},
{"stream", slot.params.stream},
{"logit_bias", slot.sparams.logit_bias},
//{"logit_bias", slot.sparams.logit_bias},
{"n_probs", slot.sparams.n_probs},
{"min_keep", slot.sparams.min_keep},
{"grammar", slot.sparams.grammar},
{"samplers", samplers_sequence}
{"samplers", samplers_sequence},
};
}
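For the server, the same knobs are driven by the request JSON: "ignore_eos": true appends a {llama_token_eos(model), -INFINITY} entry (only when the model actually has an EOS token), and each element of the request's "logit_bias" array (a [token-or-string, bias] pair in the existing API) becomes one or more vector entries, with string entries tokenized first. An illustrative request body (values made up) would be {"ignore_eos": true, "logit_bias": [[15043, 2.0], ["Hello", -1.5]]}.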

View file

@@ -356,6 +356,11 @@ extern "C" {
void * kv_overrides; // pointer to vector containing overrides
} llama_model_quantize_params;
typedef struct llama_logit_bias {
llama_token token;
float bias;
} llama_logit_bias;
// parameters for sampling the logits
typedef struct llama_sampling_params {
uint32_t seed; // the seed used to initialize llama_sampling_context
@@ -378,6 +383,12 @@ extern "C" {
float mirostat_tau; // target entropy
float mirostat_eta; // learning rate
bool penalize_nl; // consider newlines as a repeatable token
bool ignore_eos; // ignore the end-of-sequence token
const char * grammar;
int32_t n_logit_bias;
const llama_logit_bias * logit_bias;
} llama_sampling_params;
// performance timing information
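On the public C API side, biases are now described by a caller-owned array plus a count, alongside the ignore_eos flag. How the struct is consumed downstream is not part of this commit, so the following is only a sketch of filling it in:

    struct llama_sampling_params sparams = llama_sampling_default_params();

    static const struct llama_logit_bias biases[] = {
        { /* token */ 42, /* bias */ -1.0f }, // token id chosen only for illustration
    };

    sparams.logit_bias   = biases;
    sparams.n_logit_bias = 1;
    sparams.ignore_eos   = true;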

View file

@@ -16512,6 +16512,10 @@ struct llama_sampling_params llama_sampling_default_params() {
/*.mirostat_tau =*/ 5.00f,
/*.mirostat_eta =*/ 0.10f,
/*.penalize_nl =*/ false,
/*.ignore_eos =*/ false,
/*.grammar =*/ nullptr,
/*.n_logit_bias =*/ 0,
/*.logit_bias =*/ nullptr,
};
return result;
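Note that the defaults keep all of this opt-in: ignore_eos starts out false, grammar is null and the bias list is empty, so existing callers of llama_sampling_default_params() observe no change in behaviour.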