jinja: only add special tokens if template doesn't seem to handle them

This commit is contained in:
ochafik 2025-01-27 14:28:11 +00:00
parent da606d8d41
commit 15ec01e896
2 changed files with 8 additions and 3 deletions

View file

@ -3768,6 +3768,7 @@ int main(int argc, char ** argv) {
try { try {
common_chat_data chat_data; common_chat_data chat_data;
bool add_special = false;
if (tmpl && ctx_server.params_base.use_jinja) { if (tmpl && ctx_server.params_base.use_jinja) {
chat_data = common_chat_init(*tmpl, { chat_data = common_chat_init(*tmpl, {
/* .messages = */ json_value(data, "messages", json::array()), /* .messages = */ json_value(data, "messages", json::array()),
@ -3784,7 +3785,11 @@ int main(int argc, char ** argv) {
} }
chat_data.grammar = data.at("grammar"); chat_data.grammar = data.at("grammar");
} }
// TODO: move inside minja::chat_template?
add_special = tmpl->source().find("eos_token") == std::string::npos &&
tmpl->source().find("bos_token") == std::string::npos;
} else { } else {
add_special = true;
chat_data.prompt = data.at("prompt"); chat_data.prompt = data.at("prompt");
if (data.contains("grammar")) { if (data.contains("grammar")) {
chat_data.grammar = data.at("grammar"); chat_data.grammar = data.at("grammar");
@ -3792,7 +3797,7 @@ int main(int argc, char ** argv) {
chat_data.grammar = json_schema_to_grammar(data.at("json_schema")); chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
} }
} }
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true); std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
tasks.reserve(tokenized_prompts.size()); tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) { for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type); server_task task = server_task(type);