add eos_id_list to llama.cpp

This commit is contained in:
toyer 2024-06-24 12:27:02 +00:00
parent 4b65b648ce
commit 3a4d5790bf
13 changed files with 122 additions and 55 deletions

View file

@@ -1021,7 +1021,13 @@ struct server_context {
slot.sparams.logit_bias.clear();
if (json_value(data, "ignore_eos", false)) {
    // Suppress generation-ending tokens: the model may declare several
    // end-of-sequence ids, so fetch the full list and bias each one to
    // -infinity so the sampler can never pick them.
    const int n_eos = llama_n_eos(model);
    std::vector<int32_t> eos_tokens(n_eos, 0);
    llama_token_eos(model, eos_tokens.data());
    for (const int32_t eos : eos_tokens) {
        slot.sparams.logit_bias[eos] = -INFINITY;
    }
}
const auto & logit_bias = data.find("logit_bias");
@@ -1308,9 +1314,17 @@ struct server_context {
}
json get_formated_generation(const server_slot & slot) const {
// Report whether EOS is being ignored for this slot: true when any of the
// model's EOS token ids carries a -infinity logit bias (as set by the
// "ignore_eos" request flag above).
const int n_eos = llama_n_eos(model);
std::vector<int32_t> eos_tokens(n_eos, 0);
llama_token_eos(model, eos_tokens.data());
bool ignore_eos = false;
for (const int32_t eos : eos_tokens) {
    const auto logit_bias_eos = slot.sparams.logit_bias.find(eos);
    // Fix: test the stored bias value (->second), not the token id itself,
    // for -infinity. The previous check `eos < 0.0f` compared the token id.
    if (logit_bias_eos != slot.sparams.logit_bias.end() &&
        logit_bias_eos->second < 0.0f && std::isinf(logit_bias_eos->second)) {
        ignore_eos = true;
        break; // one suppressed EOS token is sufficient
    }
}
std::vector<std::string> samplers_sequence;
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
for (const auto & sampler_type : slot.sparams.samplers_sequence) {