cont : simplify logit_bias + add ignore_eos flag
ggml-ci
This commit is contained in:
parent
6b7103cccd
commit
e08100c851
7 changed files with 38 additions and 27 deletions
|
@ -1013,7 +1013,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||
return true;
|
||||
}
|
||||
if (arg == "--ignore-eos") {
|
||||
params.ignore_eos = true;
|
||||
sparams.ignore_eos = true;
|
||||
return true;
|
||||
}
|
||||
if (arg == "--penalize-nl") {
|
||||
|
@ -1028,7 +1028,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|||
std::string value_str;
|
||||
try {
|
||||
if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
|
||||
sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
|
||||
const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
|
||||
sparams.logit_bias.push_back({key, bias});
|
||||
}
|
||||
else {
|
||||
throw std::exception();
|
||||
|
@ -2133,8 +2134,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
|
|||
llama_lora_adapters_apply(lctx, iparams.lora_adapters);
|
||||
}
|
||||
|
||||
if (params.ignore_eos) {
|
||||
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
||||
if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
|
||||
fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
params.sparams.ignore_eos = false;
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
|
@ -3179,10 +3181,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
|
|||
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
|
||||
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
|
||||
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
|
||||
|
||||
const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
|
||||
const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
|
||||
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
|
||||
fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");
|
||||
|
||||
yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
|
||||
fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
|
||||
|
@ -3193,11 +3192,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
|
|||
fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
|
||||
|
||||
fprintf(stream, "logit_bias:\n");
|
||||
for (std::pair<llama_token, float> lb : sparams.logit_bias) {
|
||||
if (ignore_eos && lb.first == logit_bias_eos->first) {
|
||||
continue;
|
||||
}
|
||||
fprintf(stream, " %d: %f", lb.first, lb.second);
|
||||
for (const auto & logit_bias : sparams.logit_bias) {
|
||||
fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias);
|
||||
}
|
||||
|
||||
fprintf(stream, "lora:\n");
|
||||
|
|
|
@ -172,7 +172,6 @@ struct gpt_params {
|
|||
bool flash_attn = false; // flash attention
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool ignore_eos = false; // ignore generated EOS tokens
|
||||
bool logits_all = false; // return logits for all tokens in the batch
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
|
|
|
@ -332,8 +332,12 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||
}
|
||||
|
||||
// apply params.logit_bias map
|
||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||
logits[it->first] += it->second;
|
||||
for (const auto & logit_bias : params.logit_bias) {
|
||||
logits[logit_bias.token] += logit_bias.bias;
|
||||
}
|
||||
|
||||
if (params.ignore_eos) {
|
||||
logits[llama_token_eos(llama_get_model(ctx_main))] = -INFINITY;
|
||||
}
|
||||
|
||||
if (ctx_cfg) {
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#include "llama.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
// sampler types
|
||||
|
@ -37,6 +36,7 @@ typedef struct gpt_sampling_params {
|
|||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||
bool ignore_eos = false;
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||
|
||||
std::vector<llama_sampler_type> samplers_sequence = {
|
||||
|
@ -55,7 +55,7 @@ typedef struct gpt_sampling_params {
|
|||
std::string cfg_negative_prompt; // string to help guidance
|
||||
float cfg_scale = 1.f; // how strong is guidance
|
||||
|
||||
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
|
||||
std::vector<llama_token> penalty_prompt_tokens;
|
||||
bool use_penalty_prompt_tokens = false;
|
||||
|
|
|
@ -1035,7 +1035,7 @@ struct server_context {
|
|||
slot.sparams.logit_bias.clear();
|
||||
|
||||
if (json_value(data, "ignore_eos", false) && has_eos_token) {
|
||||
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
||||
slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
|
||||
}
|
||||
|
||||
const auto & logit_bias = data.find("logit_bias");
|
||||
|
@ -1056,12 +1056,12 @@ struct server_context {
|
|||
if (el[0].is_number_integer()) {
|
||||
llama_token tok = el[0].get<llama_token>();
|
||||
if (tok >= 0 && tok < n_vocab) {
|
||||
slot.sparams.logit_bias[tok] = bias;
|
||||
slot.sparams.logit_bias.push_back({tok, bias});
|
||||
}
|
||||
} else if (el[0].is_string()) {
|
||||
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
|
||||
for (auto tok : toks) {
|
||||
slot.sparams.logit_bias[tok] = bias;
|
||||
slot.sparams.logit_bias.push_back({tok, bias});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1323,9 +1323,6 @@ struct server_context {
|
|||
}
|
||||
|
||||
json get_formated_generation(const server_slot & slot) const {
|
||||
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
||||
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
||||
|
||||
std::vector<std::string> samplers_sequence;
|
||||
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
|
||||
for (const auto & sampler_type : slot.sparams.samplers_sequence) {
|
||||
|
@ -1359,13 +1356,13 @@ struct server_context {
|
|||
{"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
|
||||
{"n_keep", slot.params.n_keep},
|
||||
{"n_discard", slot.params.n_discard},
|
||||
{"ignore_eos", ignore_eos},
|
||||
{"ignore_eos", slot.sparams.ignore_eos},
|
||||
{"stream", slot.params.stream},
|
||||
{"logit_bias", slot.sparams.logit_bias},
|
||||
//{"logit_bias", slot.sparams.logit_bias},
|
||||
{"n_probs", slot.sparams.n_probs},
|
||||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
{"samplers", samplers_sequence}
|
||||
{"samplers", samplers_sequence},
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -356,6 +356,11 @@ extern "C" {
|
|||
void * kv_overrides; // pointer to vector containing overrides
|
||||
} llama_model_quantize_params;
|
||||
|
||||
typedef struct llama_logit_bias {
|
||||
llama_token token;
|
||||
float bias;
|
||||
} llama_logit_bias;
|
||||
|
||||
// parameters for sampling the logits
|
||||
typedef struct llama_sampling_params {
|
||||
uint32_t seed; // the seed used to initialize llama_sampling_context
|
||||
|
@ -378,6 +383,12 @@ extern "C" {
|
|||
float mirostat_tau; // target entropy
|
||||
float mirostat_eta; // learning rate
|
||||
bool penalize_nl; // consider newlines as a repeatable token
|
||||
bool ignore_eos; // ignore the end-of-sequence token
|
||||
|
||||
const char * grammar;
|
||||
|
||||
int32_t n_logit_bias;
|
||||
const llama_logit_bias * logit_bias;
|
||||
} llama_sampling_params;
|
||||
|
||||
// performance timing information
|
||||
|
|
|
@ -16512,6 +16512,10 @@ struct llama_sampling_params llama_sampling_default_params() {
|
|||
/*.mirostat_tau =*/ 5.00f,
|
||||
/*.mirostat_eta =*/ 0.10f,
|
||||
/*.penalize_nl =*/ false,
|
||||
/*.ignore_eos =*/ false,
|
||||
/*.grammar =*/ nullptr,
|
||||
/*.n_logit_bias =*/ 0,
|
||||
/*.logit_bias =*/ nullptr,
|
||||
};
|
||||
|
||||
return result;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue