add eos_id_list to llama.cpp
This commit is contained in:
parent
4b65b648ce
commit
3a4d5790bf
13 changed files with 122 additions and 55 deletions
|
@ -2417,14 +2417,21 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||
}
|
||||
}
|
||||
|
||||
const int n_eos = llama_n_eos(llama_get_model(lctx));
|
||||
std::vector<int32_t> eos_tokens(n_eos, 0);
|
||||
int32_t* eos_ptr = eos_tokens.data();
|
||||
llama_token_eos(llama_get_model(lctx), eos_ptr);
|
||||
if (params.ignore_eos) {
|
||||
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
||||
for (int32_t i = 0; i < n_eos; ++i) {
|
||||
params.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG("warming up the model with an empty run\n");
|
||||
|
||||
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
|
||||
std::vector<llama_token> tmp = { llama_token_bos(model) };
|
||||
tmp.insert(tmp.end(), eos_tokens.begin(), eos_tokens.end());
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
|
||||
llama_kv_cache_clear(lctx);
|
||||
llama_synchronize(lctx);
|
||||
|
@ -3357,8 +3364,17 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
|
|||
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
|
||||
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
|
||||
|
||||
const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
|
||||
const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
|
||||
const int n_eos = llama_n_eos(llama_get_model(lctx));
|
||||
std::vector<int32_t> eos_tokens(n_eos, 0);
|
||||
int32_t* eos_ptr = eos_tokens.data();
|
||||
llama_token_eos(llama_get_model(lctx), eos_ptr);
|
||||
bool ignore_eos = false;
|
||||
for (auto eos: eos_tokens) {
|
||||
const auto logit_bias_eos = sparams.logit_bias.find(eos);
|
||||
if (logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY) {
|
||||
ignore_eos = true;
|
||||
}
|
||||
}
|
||||
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
|
||||
|
||||
yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
|
||||
|
@ -3371,7 +3387,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
|
|||
|
||||
fprintf(stream, "logit_bias:\n");
|
||||
for (std::pair<llama_token, float> lb : sparams.logit_bias) {
|
||||
if (ignore_eos && lb.first == logit_bias_eos->first) {
|
||||
if (ignore_eos && std::count(eos_tokens.begin(), eos_tokens.end(), lb.first)) {
|
||||
continue;
|
||||
}
|
||||
fprintf(stream, " %d: %f", lb.first, lb.second);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue