align Command R7B w/ --think / reasoning_content behaviour
This commit is contained in:
parent
3841a163ef
commit
e6d9b52480
9 changed files with 176 additions and 87 deletions
|
@ -12,12 +12,13 @@ std::string common_chat_format_name(common_chat_format format) {
|
|||
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK: return "DeepSeek R1 (extract <think>)";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK: return "DeepSeek R1 (extract reasoning_content)";
|
||||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B_THINK: return "Command R7B (extract reasoning_content)";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
|
@ -469,22 +470,49 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
|
|||
"<|END_THINKING|>",
|
||||
"<|END_ACTION|>",
|
||||
};
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto has_reasoning_content = msg.contains("reasoning_content") && msg["reasoning_content"].is_string();
|
||||
auto has_tool_calls = msg.contains("tool_calls") && msg["tool_calls"].is_array();
|
||||
if (has_reasoning_content && has_tool_calls) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["tool_plan"] = msg["reasoning_content"];
|
||||
adjusted_message.erase("reasoning_content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
}
|
||||
// } else {
|
||||
// adjusted_messages = inputs.messages;
|
||||
// }
|
||||
data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
|
||||
data.format = inputs.think ? COMMON_CHAT_FORMAT_COMMAND_R7B_THINK : COMMON_CHAT_FORMAT_COMMAND_R7B;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
|
||||
static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
|
||||
static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
|
||||
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool think) {
|
||||
static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
|
||||
static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
|
||||
static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
|
||||
|
||||
std::smatch match;
|
||||
|
||||
common_chat_msg result;
|
||||
result.role = "assistant";
|
||||
if (std::regex_match(input, match, response_regex)) {
|
||||
result.content = match[1].str();
|
||||
} else if (std::regex_match(input, match, thought_action_regex)) {
|
||||
result.tool_plan = match[1].str();
|
||||
auto actions_str = match[2].str();
|
||||
|
||||
std::string rest = input;
|
||||
|
||||
if (std::regex_match(rest, match, thought_regex)) {
|
||||
if (think) {
|
||||
result.reasoning_content = match[2].str();
|
||||
} else if (!match[2].str().empty()) {
|
||||
// Let the unparsed thinking tags through in content only if their insides aren't empty.
|
||||
result.content = match[1].str();
|
||||
}
|
||||
rest = match[3].str();
|
||||
}
|
||||
if (std::regex_match(rest, match, action_regex)) {
|
||||
auto actions_str = match[1].str();
|
||||
auto actions = json::parse(actions_str);
|
||||
for (const auto & action : actions) {
|
||||
result.tool_calls.push_back({
|
||||
|
@ -493,9 +521,12 @@ static common_chat_msg common_chat_parse_command_r7b(const std::string & input)
|
|||
/* .id = */ action["tool_call_id"],
|
||||
});
|
||||
}
|
||||
} else if (std::regex_match(rest, match, response_regex)) {
|
||||
auto response = match[1].str();
|
||||
result.content += response;
|
||||
} else {
|
||||
LOG_ERR("Failed to parse command_r output");
|
||||
result.content = input;
|
||||
result.content += rest;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -1038,6 +1069,11 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
|
|||
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Command R7B: : use handler in all cases except json schema (thinking / tools).
|
||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
|
||||
return common_chat_params_init_command_r7b(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Use generic handler when forcing thoughts or JSON schema for final output
|
||||
// TODO: support thinking mode and/or JSON schema in handlers below this.
|
||||
if (inputs.think || (!inputs.tools.is_null() && inputs.json_schema.is_object())) {
|
||||
|
@ -1081,11 +1117,6 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
|
|||
return common_chat_params_init_mistral_nemo(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Command R7B (w/ tools)
|
||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
|
||||
return common_chat_params_init_command_r7b(tmpl, inputs);
|
||||
}
|
||||
|
||||
// Generic fallback
|
||||
return common_chat_params_init_generic(tmpl, inputs);
|
||||
}
|
||||
|
@ -1123,7 +1154,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
|
|||
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
||||
return common_chat_parse_firefunction_v2(input);
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
||||
return common_chat_parse_command_r7b(input);
|
||||
return common_chat_parse_command_r7b(input, /* think= */ false);
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B_THINK:
|
||||
return common_chat_parse_command_r7b(input, /* think= */ true);
|
||||
default:
|
||||
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue