jinja: only add special tokens if template doesn't seem to handle them
This commit is contained in:
parent
da606d8d41
commit
15ec01e896
2 changed files with 8 additions and 3 deletions
|
@ -100,7 +100,7 @@ static common_chat_msg parse_json_tool_calls(const json & tools, const std::stri
|
||||||
if (!parse_json(it, end, arguments)) {
|
if (!parse_json(it, end, arguments)) {
|
||||||
if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
|
if (allow_raw_python && name == "python" && std::regex_match("", close_regex)) {
|
||||||
std::string src(it, end);
|
std::string src(it, end);
|
||||||
result.tool_calls.push_back({name, src, /* id= */ ""});
|
result.tool_calls.push_back({name, src, /* id= */ ""});
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
throw std::runtime_error("Failed to parse json tool call arguments");
|
throw std::runtime_error("Failed to parse json tool call arguments");
|
||||||
|
@ -373,7 +373,7 @@ static common_chat_data common_chat_init_llama_3_1_python_tag_tool_calls(const c
|
||||||
fprintf(stderr, "[%s]\n", __func__);
|
fprintf(stderr, "[%s]\n", __func__);
|
||||||
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
auto builtin_tools = json {"wolfram_alpha", "brave_search"};
|
||||||
common_chat_data data;
|
common_chat_data data;
|
||||||
|
|
||||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||||
std::vector<std::string> tool_rules;
|
std::vector<std::string> tool_rules;
|
||||||
|
|
||||||
|
|
|
@ -3768,6 +3768,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
common_chat_data chat_data;
|
common_chat_data chat_data;
|
||||||
|
bool add_special = false;
|
||||||
if (tmpl && ctx_server.params_base.use_jinja) {
|
if (tmpl && ctx_server.params_base.use_jinja) {
|
||||||
chat_data = common_chat_init(*tmpl, {
|
chat_data = common_chat_init(*tmpl, {
|
||||||
/* .messages = */ json_value(data, "messages", json::array()),
|
/* .messages = */ json_value(data, "messages", json::array()),
|
||||||
|
@ -3784,7 +3785,11 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
chat_data.grammar = data.at("grammar");
|
chat_data.grammar = data.at("grammar");
|
||||||
}
|
}
|
||||||
|
// TODO: move inside minja:chat_template?
|
||||||
|
add_special = tmpl->source().find("eos_token") == std::string::npos &&
|
||||||
|
tmpl->source().find("bos_token") == std::string::npos;
|
||||||
} else {
|
} else {
|
||||||
|
add_special = true;
|
||||||
chat_data.prompt = data.at("prompt");
|
chat_data.prompt = data.at("prompt");
|
||||||
if (data.contains("grammar")) {
|
if (data.contains("grammar")) {
|
||||||
chat_data.grammar = data.at("grammar");
|
chat_data.grammar = data.at("grammar");
|
||||||
|
@ -3792,7 +3797,7 @@ int main(int argc, char ** argv) {
|
||||||
chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
|
chat_data.grammar = json_schema_to_grammar(data.at("json_schema"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, true, true);
|
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, chat_data.prompt, add_special, true);
|
||||||
tasks.reserve(tokenized_prompts.size());
|
tasks.reserve(tokenized_prompts.size());
|
||||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||||
server_task task = server_task(type);
|
server_task task = server_task(type);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue