tool-call: support Command R7B (+ return tool_plan "thoughts" in API) (#11585)

* `tool-call`: support Command R7B (w/ tool_plan return)

* `tool-call`: cleaner preservation of tokens + warn when likely bad chat template override

* `tool-call`: test cleanup / handle lazy grammar triggers
This commit is contained in:
Olivier Chafik 2025-02-02 09:25:38 +00:00 committed by GitHub
parent 69804487e0
commit bfcce4d693
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 420 additions and 56 deletions

View file

@ -1128,6 +1128,7 @@ curl http://localhost:8080/v1/chat/completions \
- Hermes 2/3, Qwen 2.5
- Mistral Nemo
- Firefunction v2
- Command R7B
- DeepSeek R1 (WIP / seems reluctant to call any tools?)
<details>
@ -1202,21 +1203,28 @@ curl http://localhost:8080/v1/chat/completions \
```shell
# Native support:
llama-server --jinja -fa -hf bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Llama-3.2-3B-Instruct-GGUF:Q6_K
llama-server --jinja -fa -hf bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q6_K_L
llama-server --jinja -fa -hf bartowski/functionary-small-v3.2-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B )
llama-server --jinja -fa -hf bartowski/Llama-3.3-70B-Instruct-GGUF:Q4_K_M
# Native support requires the right template for these GGUFs:
llama-server --jinja -fa -hf bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use )
llama-server --jinja -fa -hf bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M \
--chat-template-file <( python scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use )
llama-server --jinja -fa -hf bartowski/firefunction-v2-GGUF -hff firefunction-v2-IQ1_M.gguf \
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/firellama-3-firefunction-v2 )
--chat-template-file <( python scripts/get_chat_template.py fireworks-ai/llama-3-firefunction-v2 tool_use )
llama-server --jinja -fa -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
--chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
# Generic format support
llama-server --jinja -fa -hf bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q4_K_M
llama-server --jinja -fa -hf bartowski/phi-4-GGUF:Q4_0
llama-server --jinja -fa -hf bartowski/gemma-2-2b-it-GGUF:Q8_0
llama-server --jinja -fa -hf bartowski/c4ai-command-r-v01-GGUF:Q2_K
```
- Test in CLI:

View file

@ -131,6 +131,11 @@ struct slot_params {
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
}
std::vector<std::string> grammar_trigger_words;
for (const auto & trigger : sampling.grammar_trigger_words) {
grammar_trigger_words.push_back(trigger.word);
}
return json {
{"n_predict", n_predict}, // Server configured n_predict
{"seed", sampling.seed},
@ -165,8 +170,9 @@ struct slot_params {
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"grammar", sampling.grammar},
// {"grammar_trigger_words", sampling.grammar_trigger_words},
{"grammar_trigger_words", grammar_trigger_words},
{"grammar_trigger_tokens", sampling.grammar_trigger_tokens},
{"preserved_tokens", sampling.preserved_tokens},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
@ -363,12 +369,26 @@ struct server_task {
if (ids.size() == 1) {
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
continue;
}
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
for (const auto & t : *preserved_tokens) {
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
LOG_DBG("Preserved token: %d\n", ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
}
}
}
if (params.sampling.grammar_lazy) {
GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0);
}
@ -695,19 +715,19 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat() {
std::string finish_reason = "length";
common_chat_msg message;
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
LOG_DBG("Parsing chat message: %s\n", content.c_str());
message = common_chat_parse(content, oaicompat_chat_format);
finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls";
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
} else {
message.content = content;
msg.content = content;
}
json tool_calls;
if (!message.tool_calls.empty()) {
if (!msg.tool_calls.empty()) {
tool_calls = json::array();
for (const auto & tc : message.tool_calls) {
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
@ -719,14 +739,19 @@ struct server_task_result_cmpl_final : server_task_result {
}
}
json message {
{"content", msg.content},
{"tool_calls", tool_calls},
{"role", "assistant"},
};
if (!msg.tool_plan.empty()) {
message["tool_plan"] = msg.tool_plan;
}
json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", json {
{"content", message.content},
{"tool_calls", tool_calls},
{"role", "assistant"},
}},
{"message", message},
};
if (!stream && probs_output.size() > 0) {
@ -2833,8 +2858,7 @@ struct server_context {
server_slot * slot_batched = nullptr;
auto accept_special_token = [&](server_slot & slot, llama_token token) {
const auto & trigger_tokens = slot.params.sampling.grammar_trigger_tokens;
return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end();
return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
};
// frist, add sampled tokens from any ongoing sequences

View file

@ -662,6 +662,7 @@ static json oaicompat_completion_params_parse(
});
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}