add eos_id_list to llama.cpp

This commit is contained in:
toyer 2024-06-24 12:27:02 +00:00
parent 4b65b648ce
commit 3a4d5790bf
13 changed files with 122 additions and 55 deletions

View file

@ -2417,14 +2417,21 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}
}
const int n_eos = llama_n_eos(llama_get_model(lctx));
std::vector<int32_t> eos_tokens(n_eos, 0);
int32_t* eos_ptr = eos_tokens.data();
llama_token_eos(llama_get_model(lctx), eos_ptr);
if (params.ignore_eos) {
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
for (int32_t i = 0; i < n_eos; ++i) {
params.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
}
}
if (params.warmup) {
LOG("warming up the model with an empty run\n");
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
std::vector<llama_token> tmp = { llama_token_bos(model) };
tmp.insert(tmp.end(), eos_tokens.begin(), eos_tokens.end());
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
@ -3357,8 +3364,17 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
const int n_eos = llama_n_eos(llama_get_model(lctx));
std::vector<int32_t> eos_tokens(n_eos, 0);
int32_t* eos_ptr = eos_tokens.data();
llama_token_eos(llama_get_model(lctx), eos_ptr);
bool ignore_eos = false;
for (auto eos: eos_tokens) {
const auto logit_bias_eos = sparams.logit_bias.find(eos);
if (logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY) {
ignore_eos = true;
}
}
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
@ -3371,7 +3387,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "logit_bias:\n");
for (std::pair<llama_token, float> lb : sparams.logit_bias) {
if (ignore_eos && lb.first == logit_bias_eos->first) {
if (ignore_eos && std::count(eos_tokens.begin(), eos_tokens.end(), lb.first)) {
continue;
}
fprintf(stream, " %d: %f", lb.first, lb.second);