jinja: only add special tokens if template doesn't seem to handle them

ochafik 2025-01-27 14:28:11 +00:00
parent da606d8d41
commit 15ec01e896
2 changed files with 8 additions and 3 deletions
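
In short: when a Jinja chat template is in use, the server now only asks the tokenizer to add special tokens (BOS/EOS) if the template source never mentions bos_token or eos_token, on the assumption that a template which references them already emits them itself; the legacy non-Jinja path keeps add_special = true, so plain prompts are tokenized exactly as before. A minimal standalone sketch of the heuristic (hypothetical helper name; the real check lives inline in the server code below):

    #include <string>

    // Only ask the tokenizer to insert BOS/EOS when the chat template
    // source never references them, i.e. it does not appear to handle
    // the special tokens itself.
    static bool should_add_special_tokens(const std::string & template_source) {
        return template_source.find("eos_token") == std::string::npos &&
               template_source.find("bos_token") == std::string::npos;
    }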


@@ -100,7 +100,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
         if (!parse_json(it, end, arguments)) {
             if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
                 std::string src(it, end);
                 result.tool_calls.push_back({name, src, /* id= */ ""});
                 break;
             }
             throw std::runtime_error("Failed to parse json tool call arguments");
@@ -373,7 +373,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
     fprintf(stderr, "[%s]\n", __func__);
     auto builtin_tools = json {"wolfram_alpha", "brave_search"};
     common_chat_data data;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;


@@ -3768,6 +3768,7 @@ int main(int argc, char ** argv) {
     try {
         common_chat_data chat_data;
+        bool add_special = false;
         if (tmpl && ctx_server.params_base.use_jinja) {
             chat_data = common_chat_init(*tmpl, {
                 /* .messages = */ json_value(data, "messages", json::array()),
@@ -3784,7 +3785,11 @@
                 }
                 chat_data.grammar = data.at("grammar");
             }
+            // TODO: move inside minja:chat_template?
+            add_special = tmpl->source().find("eos_token") == std::string::npos &&
+                          tmpl->source().find("bos_token") == std::string::npos;
         } else {
+            add_special = true;
             chat_data.prompt = data.at("prompt");
             if (data.contains("grammar")) {
                 chat_data.grammar = data.at("grammar");
@@ -3792,7 +3797,7 @@
                 chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
             }
         }
-        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
         tasks.reserve(tokenized_prompts.size());
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
             server_task task = server_task(type);
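
As a usage illustration (the template snippets below are hypothetical, not taken from any particular model), a template that renders {{ bos_token }} itself would now be tokenized without an extra tokenizer-side BOS, which could otherwise end up duplicated, while a template that never mentions the special tokens keeps the previous behaviour:

    #include <cstdio>
    #include <string>

    // Same heuristic as in the commit: add special tokens only if the
    // template does not seem to handle them.
    static bool should_add_special_tokens(const std::string & template_source) {
        return template_source.find("eos_token") == std::string::npos &&
               template_source.find("bos_token") == std::string::npos;
    }

    int main() {
        // Hypothetical template that emits the BOS token itself.
        const std::string handles_bos = "{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}";
        // Hypothetical template that never references bos_token/eos_token.
        const std::string plain       = "{% for m in messages %}{{ m.content }}\n{% endfor %}";

        std::printf("handles_bos -> add_special = %s\n", should_add_special_tokens(handles_bos) ? "true" : "false"); // false
        std::printf("plain       -> add_special = %s\n", should_add_special_tokens(plain) ? "true" : "false");       // true
        return 0;
    }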