update to minja's new API
commit 30ea3591c9
parent 11c1f0c7d4

4 changed files with 170 additions and 12 deletions
@@ -270,6 +270,28 @@ class chat_template {
     const std::string & eos_token() const { return eos_token_; }
     const chat_template_caps & original_caps() const { return caps_; }
 
+    // Deprecated, please use the form with chat_template_inputs and chat_template_options
+    std::string apply(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
+        bool apply_polyfills = true)
+    {
+        fprintf(stderr, "[%s] Deprecated!\n", __func__);
+        chat_template_inputs inputs;
+        inputs.messages = messages;
+        inputs.tools = tools;
+        inputs.add_generation_prompt = add_generation_prompt;
+        inputs.extra_context = extra_context;
+        inputs.now = std::chrono::system_clock::now();
+
+        chat_template_options opts;
+        opts.apply_polyfills = apply_polyfills;
+
+        return apply(inputs, opts);
+    }
+
     std::string apply(
         const chat_template_inputs & inputs,
         const chat_template_options & opts = chat_template_options()) const
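For reference, this first hunk shows the shape of minja's new API: the long positional argument list is replaced by two structs, chat_template_inputs and chat_template_options. A minimal sketch of a caller using the new form directly (the render_prompt wrapper, header path, and message literal are illustrative, not part of this commit):

#include <chrono>
#include <string>
#include <nlohmann/json.hpp>
#include "chat-template.hpp" // assumed header for minja::chat_template

// Sketch: render a prompt through the struct-based API.
static std::string render_prompt(const minja::chat_template & tmpl) {
    minja::chat_template_inputs inputs;
    inputs.messages = nlohmann::ordered_json::parse(
        R"([{"role": "user", "content": "Hello"}])");
    inputs.add_generation_prompt = true;
    inputs.now = std::chrono::system_clock::now(); // some templates render the current date

    minja::chat_template_options opts; // defaults leave polyfills enabled
    return tmpl.apply(inputs, opts);
}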
@@ -175,6 +175,28 @@ static void foreach_function(const json & tools, const std::function<void(const
     }
 }
 
+static std::string apply(
+    const common_chat_template & tmpl,
+    const nlohmann::ordered_json & messages,
+    const nlohmann::ordered_json & tools,
+    bool add_generation_prompt,
+    const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
+{
+    minja::chat_template_inputs tmpl_inputs;
+    tmpl_inputs.messages = messages;
+    tmpl_inputs.tools = tools;
+    tmpl_inputs.add_generation_prompt = add_generation_prompt;
+    tmpl_inputs.extra_context = extra_context;
+    // TODO: add flag to control date/time, if only for testing purposes.
+    // tmpl_inputs.now = std::chrono::system_clock::now();
+
+    minja::chat_template_options tmpl_opts;
+    tmpl_opts.use_bos_token = false;
+    tmpl_opts.use_eos_token = false;
+
+    return tmpl.apply(tmpl_inputs, tmpl_opts);
+}
+
 static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
 
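This static apply helper is the pivot of the change: the call sites in the hunks below keep their old positional style while the helper forwards everything to the struct-based API, with use_bos_token/use_eos_token disabled, presumably because llama.cpp adds those special tokens during tokenization rather than in the rendered text. A sketch of how a call site invokes it (the argument values are illustrative):

// Sketch: invoking the new file-local helper (illustrative values).
nlohmann::ordered_json messages = nlohmann::ordered_json::parse(
    R"([{"role": "user", "content": "What is the weather?"}])");
nlohmann::ordered_json tools; // empty: forwarded as-is, like the json() fallbacks below
std::string prompt = apply(tmpl, messages, tools, /* add_generation_prompt= */ true);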
@@ -256,7 +278,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
         inputs.messages,
         "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
 
-    data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_GENERIC;
     return data;
 }
@@ -322,7 +344,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
         builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
     }, grammar_options);
     data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true});
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
@@ -372,7 +394,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
         "<|END_THINKING|>",
         "<|END_ACTION|>",
     };
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
     return data;
 }
@@ -489,7 +511,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
         builder.add_rule("root", string_join(tool_rules, " | "));
     }, grammar_options);
     data.additional_stops.push_back("<|eom_id|>");
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
@@ -568,7 +590,7 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
             "<|tool▁call▁end|>",
         };
     }, grammar_options);
-    auto prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
 
     // Hacks to fix the official (broken) prompt.
     // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
@@ -614,10 +636,10 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input)
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     fprintf(stderr, "%s\n", __func__);
     common_chat_params data;
-    data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+    data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
         {"datetime", "Jan 29 2025 13:00:00 GMT"},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
-    }, /* adjust_inputs= */ false);
+    });
     if (!inputs.tools.is_null() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != "required";
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
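Note that this hunk also drops the trailing /* adjust_inputs= */ false argument: judging by the new header, input adjustment is now controlled through chat_template_options rather than a positional boolean. If a caller still needs to render a template without any input rewriting, the apply_polyfills option shown in the first hunk expresses that explicitly (a sketch, not code from this commit):

// Sketch: disabling polyfills via the options struct.
minja::chat_template_inputs tmpl_inputs;
tmpl_inputs.messages = messages; // assumes a `messages` json object in scope
minja::chat_template_options tmpl_opts;
tmpl_opts.apply_polyfills = false; // render the raw template, no input rewriting
std::string prompt = tmpl.apply(tmpl_inputs, tmpl_opts);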
@@ -661,7 +683,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_params data;
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (!inputs.tools.is_null() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != "required";
@@ -788,7 +810,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
         data.grammar_triggers.push_back({"<function=", /* .at_start = */ false});
     }, grammar_options);
 
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     // TODO: if (has_raw_python)
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
     return data;
@@ -843,7 +865,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
         data.preserved_tokens = { "</tool_call>" };
     }, grammar_options);
 
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO;
     return data;
 }
@@ -904,7 +926,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
     common_chat_params data;
-    data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+    data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
     data.grammar_lazy = false;
     if (!inputs.json_schema.is_null()) {
@@ -848,7 +848,15 @@ static int apply_chat_template(const common_chat_template & tmpl, LlamaData & ll
         });
     }
     try {
-        auto result = tmpl.apply(messages, /* tools= */ json(), append);
+        minja::chat_template_inputs tmpl_inputs;
+        tmpl_inputs.messages = messages;
+        tmpl_inputs.add_generation_prompt = append;
+
+        minja::chat_template_options tmpl_opts;
+        tmpl_opts.use_bos_token = false;
+        tmpl_opts.use_eos_token = false;
+
+        auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
         llama_data.fmtted.resize(result.size() + 1);
        memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
         return result.size();
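In this hunk the conversion sets only the fields it needs; tools, now, and extra_context stay at their struct defaults, whereas the old call had to pass /* tools= */ json() positionally. The same before/after, condensed (a sketch assuming the surrounding messages and append variables):

// Sketch: old positional form (removed above):
//   auto result = tmpl.apply(messages, /* tools= */ json(), append);
// New struct form; unset fields keep their defaults:
minja::chat_template_inputs tmpl_inputs;
tmpl_inputs.messages = messages;
tmpl_inputs.add_generation_prompt = append;
minja::chat_template_options tmpl_opts;
tmpl_opts.use_bos_token = false; // presumably because tokenization adds BOS/EOS itself
tmpl_opts.use_eos_token = false;
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);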
@@ -340,6 +340,112 @@ def test_weather_tool_call(hf_repo: str, template_override: str | Tuple[str, str
     assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize("hf_repo,template_override", [
+    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+
+    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
+    ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
+    ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
+    ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
+    ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+
+    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
+    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
+
+    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
+    ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
+
+    # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
+    ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
+
+    # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
+])
+def test_calc_result(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+    global server
+    n_predict = 512
+    server.n_slots = 1
+    server.jinja = True
+    server.n_ctx = 8192
+    server.n_predict = n_predict
+    server.model_hf_repo = hf_repo
+    server.model_hf_file = None
+    if isinstance(template_override, tuple):
+        (template_hf_repo, template_variant) = template_override
+        server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
+        assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
+    elif isinstance(template_override, str):
+        server.chat_template = template_override
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predict,
+        "messages": [
+            {"role": "system", "content": "You are a chatbot that uses tools/functions. Don't overthink things."},
+            {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
+            {
+                "role": "assistant",
+                "content": None,
+                "tool_calls": [
+                    {
+                        "function": {
+                            "name": "calculate",
+                            "arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
+                        }
+                    }
+                ]
+            },
+            {
+                "role": "tool",
+                "name": "calculate",
+                "content": "0.5"
+            }
+        ],
+        "tools": [
+            {
+                "type":"function",
+                "function":{
+                    "name":"calculate",
+                    "description":"A calculator function that computes values of arithmetic expressions in the Python syntax",
+                    "parameters":{
+                        "type":"object",
+                        "properties":{
+                            "expression":{
+                                "type":"string",
+                                "description":"An arithmetic expression to compute the value of (Python syntax, assuming all floats)"
+                            }
+                        },
+                        "required":["expression"]
+                    }
+                }
+            }
+        ]
+    }, timeout=TIMEOUT_HTTP_REQUEST)
+    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
+    choice = res.body["choices"][0]
+    tool_calls = choice["message"].get("tool_calls")
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
+    tool_call = tool_calls[0]
+    assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"]
+    actual_arguments = json.loads(tool_call["function"]["arguments"])
+    assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
+    location = actual_arguments["location"]
+    assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
+    assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
     (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),