jinja: only add special tokens if template doesn't seem to handle them

This commit is contained in:
ochafik 2025-01-27 14:28:11 +00:00
parent da606d8d41
commit 15ec01e896
2 changed files with 8 additions and 3 deletions

View file

@ -100,7 +100,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
if (!parse_json(it, end, arguments)) {
if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
std::string src(it, end);
result.tool_calls.push_back({name, src, /* id= */ ""});
result.tool_calls.push_back({name, src, /* id= */ ""});
break;
}
throw std::runtime_error("Failed to parse json tool call arguments");
@ -373,7 +373,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
fprintf(stderr, "[%s]\n", __func__);
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
common_chat_data data;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;

View file

@ -3768,6 +3768,7 @@ int main(int argc, char ** argv) {
try {
common_chat_data chat_data;
bool add_special = false;
if (tmpl && ctx_server.params_base.use_jinja) {
chat_data = common_chat_init(*tmpl, {
/* .messages = */ json_value(data, "messages", json::array()),
@ -3784,7 +3785,11 @@ int main(int argc, char ** argv) {
}
chat_data.grammar = data.at("grammar");
}
// TODO: move inside minja:chat_template?
add_special = tmpl->source().find("eos_token") == std::string::npos &&
tmpl->source().find("bos_token") == std::string::npos;
} else {
add_special = true;
chat_data.prompt = data.at("prompt");
if (data.contains("grammar")) {
chat_data.grammar = data.at("grammar");
@ -3792,7 +3797,7 @@ int main(int argc, char ** argv) {
chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
}
}
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type);