diff --git a/common/chat-handler.cpp b/common/chat-handler.cpp
index 74805f222..2e44b69b3 100644
--- a/common/chat-handler.cpp
+++ b/common/chat-handler.cpp
@@ -100,7 +100,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
         if (!parse_json(it, end, arguments)) {
             if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
                 std::string src(it, end);
-                result.tool_calls.push_back({name, src, /* id= */ ""});
+                result.tool_calls.push_back({name, src, /* id= */ ""}); break;
             }
             throw std::runtime_error("Failed to parse json tool call arguments");
@@ -373,7 +373,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
     fprintf(stderr, "[%s]\n", __func__);
     auto builtin_tools = json {"wolfram_alpha", "brave_search"};
     common_chat_data data;
-    
+
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a75a7b01f..a96552dff 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3768,6 +3768,7 @@ int main(int argc, char ** argv) {
         try {
             common_chat_data chat_data;
+            bool add_special = false;
             if (tmpl && ctx_server.params_base.use_jinja) {
                 chat_data = common_chat_init(*tmpl, {
                     /* .messages = */ json_value(data, "messages", json::array()),
@@ -3784,7 +3785,11 @@
                 }
                 chat_data.grammar = data.at("grammar");
             }
+            // TODO: move inside minja:chat_template?
+            add_special = tmpl->source().find("eos_token") == std::string::npos &&
+                tmpl->source().find("bos_token") == std::string::npos;
         } else {
+            add_special = true;
             chat_data.prompt = data.at("prompt");
             if (data.contains("grammar")) {
                 chat_data.grammar = data.at("grammar");
@@ -3792,7 +3797,7 @@
                 chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
             }
         }
-        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
         tasks.reserve(tokenized_prompts.size());
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
             server_task task = server_task(type);