align Command R7B w/ --think / reasoning_content behaviour

This commit is contained in:
Olivier Chafik 2025-02-05 15:47:37 +00:00
parent 3841a163ef
commit e6d9b52480
9 changed files with 176 additions and 87 deletions

View file

@ -12,12 +12,13 @@ std::string common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK: return "DeepSeek R1 (extract <think>)";
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_THINK: return "DeepSeek R1 (extract reasoning_content)";
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_COMMAND_R7B_THINK: return "Command R7B (extract reasoning_content)";
default:
throw std::runtime_error("Unknown chat format");
}
@ -469,22 +470,49 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
"<|END_THINKING|>",
"<|END_ACTION|>",
};
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_COMMAND_R7B;
auto adjusted_messages = json::array();
for (const auto & msg : inputs.messages) {
auto has_reasoning_content = msg.contains("reasoning_content") && msg["reasoning_content"].is_string();
auto has_tool_calls = msg.contains("tool_calls") && msg["tool_calls"].is_array();
if (has_reasoning_content && has_tool_calls) {
auto adjusted_message = msg;
adjusted_message["tool_plan"] = msg["reasoning_content"];
adjusted_message.erase("reasoning_content");
adjusted_messages.push_back(adjusted_message);
} else {
adjusted_messages.push_back(msg);
}
}
// } else {
// adjusted_messages = inputs.messages;
// }
data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
data.format = inputs.think ? COMMON_CHAT_FORMAT_COMMAND_R7B_THINK : COMMON_CHAT_FORMAT_COMMAND_R7B;
return data;
}
static common_chat_msg common_chat_parse_command_r7b(const std::string & input) {
static std::regex response_regex("<\\|START_RESPONSE\\|>([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
static std::regex thought_action_regex("<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|><\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool think) {
static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)");
static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>");
static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>");
std::smatch match;
common_chat_msg result;
result.role = "assistant";
if (std::regex_match(input, match, response_regex)) {
result.content = match[1].str();
} else if (std::regex_match(input, match, thought_action_regex)) {
result.tool_plan = match[1].str();
auto actions_str = match[2].str();
std::string rest = input;
if (std::regex_match(rest, match, thought_regex)) {
if (think) {
result.reasoning_content = match[2].str();
} else if (!match[2].str().empty()) {
// Let the unparsed thinking tags through in content only if their insides aren't empty.
result.content = match[1].str();
}
rest = match[3].str();
}
if (std::regex_match(rest, match, action_regex)) {
auto actions_str = match[1].str();
auto actions = json::parse(actions_str);
for (const auto & action : actions) {
result.tool_calls.push_back({
@ -493,9 +521,12 @@ static common_chat_msg common_chat_parse_command_r7b(const std::string & input)
/* .id = */ action["tool_call_id"],
});
}
} else if (std::regex_match(rest, match, response_regex)) {
auto response = match[1].str();
result.content += response;
} else {
LOG_ERR("Failed to parse command_r output");
result.content = input;
result.content += rest;
}
return result;
}
@ -1038,6 +1069,11 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
return common_chat_params_init_deepseek_r1(tmpl, inputs);
}
// Command R7B: : use handler in all cases except json schema (thinking / tools).
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
return common_chat_params_init_command_r7b(tmpl, inputs);
}
// Use generic handler when forcing thoughts or JSON schema for final output
// TODO: support thinking mode and/or JSON schema in handlers below this.
if (inputs.think || (!inputs.tools.is_null() && inputs.json_schema.is_object())) {
@ -1081,11 +1117,6 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
return common_chat_params_init_mistral_nemo(tmpl, inputs);
}
// Command R7B (w/ tools)
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos) {
return common_chat_params_init_command_r7b(tmpl, inputs);
}
// Generic fallback
return common_chat_params_init_generic(tmpl, inputs);
}
@ -1123,7 +1154,9 @@ common_chat_msg common_chat_parse(const std::string & input, common_chat_format
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
return common_chat_parse_firefunction_v2(input);
case COMMON_CHAT_FORMAT_COMMAND_R7B:
return common_chat_parse_command_r7b(input);
return common_chat_parse_command_r7b(input, /* think= */ false);
case COMMON_CHAT_FORMAT_COMMAND_R7B_THINK:
return common_chat_parse_command_r7b(input, /* think= */ true);
default:
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
}