add eos_id_list to llama.cpp
This commit is contained in:
parent
4b65b648ce
commit
3a4d5790bf
13 changed files with 122 additions and 55 deletions
|
@ -1021,7 +1021,13 @@ struct server_context {
|
|||
slot.sparams.logit_bias.clear();
|
||||
|
||||
if (json_value(data, "ignore_eos", false)) {
|
||||
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
||||
const int n_eos = llama_n_eos(model);
|
||||
std::vector<int32_t> eos_tokens(n_eos, 0);
|
||||
int32_t* eos_ptr = eos_tokens.data();
|
||||
llama_token_eos(model, eos_ptr);
|
||||
for (int32_t i = 0; i < n_eos; ++i) {
|
||||
slot.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
const auto & logit_bias = data.find("logit_bias");
|
||||
|
@ -1308,9 +1314,17 @@ struct server_context {
|
|||
}
|
||||
|
||||
json get_formated_generation(const server_slot & slot) const {
|
||||
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
||||
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
||||
|
||||
const int n_eos = llama_n_eos(model);
|
||||
std::vector<int32_t> eos_tokens(n_eos, 0);
|
||||
int32_t* eos_ptr = eos_tokens.data();
|
||||
llama_token_eos(model, eos_ptr);
|
||||
bool ignore_eos = false;
|
||||
// Report ignore_eos when any of the model's EOS tokens carries a negative-infinite
// logit bias. The sign and infinity tests are applied to the stored bias value
// (logit_bias_eos->second); comparing the token id itself against 0.0f would never
// match for a non-negative id.
for (const auto eos : eos_tokens) {
    const auto logit_bias_eos = slot.sparams.logit_bias.find(eos);
    if (logit_bias_eos != slot.sparams.logit_bias.end() && logit_bias_eos->second < 0.0f && std::isinf(logit_bias_eos->second)) {
        ignore_eos = true;
        break; // one match is enough; the flag is only ever set to true
    }
}
|
||||
std::vector<std::string> samplers_sequence;
|
||||
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
|
||||
for (const auto & sampler_type : slot.sparams.samplers_sequence) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue