jinja: only add special tokens if template doesn't seem to handle them

This commit is contained in:
ochafik 2025-01-27 14:28:11 +00:00
parent da606d8d41
commit 15ec01e896
2 changed files with 8 additions and 3 deletions

View file

@ -3768,6 +3768,7 @@ int main(int argc, char ** argv) {
try {
common_chat_data chat_data;
bool add_special = false;
if (tmpl && ctx_server.params_base.use_jinja) {
chat_data = common_chat_init(*tmpl, {
/* .messages = */ json_value(data, "messages", json::array()),
@ -3784,7 +3785,11 @@ int main(int argc, char ** argv) {
}
chat_data.grammar = data.at("grammar");
}
// TODO: move inside minja:chat_template?
add_special = tmpl->source().find("eos_token") == std::string::npos &&
tmpl->source().find("bos_token") == std::string::npos;
} else {
add_special = true;
chat_data.prompt = data.at("prompt");
if (data.contains("grammar")) {
chat_data.grammar = data.at("grammar");
@ -3792,7 +3797,7 @@ int main(int argc, char ** argv) {
chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
}
}
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type);