sync: minja (https://github.com/google/minja/pull/52)
This commit is contained in:
parent
95cddfd8fb
commit
e598e7aa10
1 changed files with 21 additions and 7 deletions
|
@ -249,16 +249,30 @@ class chat_template {
|
||||||
inputs.add_generation_prompt = false;
|
inputs.add_generation_prompt = false;
|
||||||
full = apply(inputs);
|
full = apply(inputs);
|
||||||
}
|
}
|
||||||
|
auto eos_pos_last = full.rfind(eos_token_);
|
||||||
if (full.find(prefix) != 0) {
|
if (eos_pos_last == prefix.size() - eos_token_.size() ||
|
||||||
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
|
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
|
||||||
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
|
full = full.substr(0, eos_pos_last);
|
||||||
}
|
}
|
||||||
|
size_t common_prefix_length = 0;
|
||||||
|
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
|
||||||
|
if (prefix[i] != full[i]) {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
if (full.find(prefix) != 0) {
|
if (prefix[i] == '<') {
|
||||||
|
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
|
||||||
|
// but it removes thinking tags for past messages.
|
||||||
|
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
common_prefix_length = i + 1;
|
||||||
|
}
|
||||||
|
auto example = full.substr(common_prefix_length);
|
||||||
|
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
|
||||||
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
|
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
|
||||||
|
} else {
|
||||||
|
tool_call_example_ = example;
|
||||||
}
|
}
|
||||||
tool_call_example_ = full.substr(prefix.size());
|
|
||||||
}
|
}
|
||||||
} catch (const std::exception & e) {
|
} catch (const std::exception & e) {
|
||||||
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
|
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
|
||||||
|
@ -363,7 +377,7 @@ class chat_template {
|
||||||
if (polyfill_tools) {
|
if (polyfill_tools) {
|
||||||
adjusted_messages = add_system(inputs.messages,
|
adjusted_messages = add_system(inputs.messages,
|
||||||
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
|
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
|
||||||
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
|
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
|
||||||
} else {
|
} else {
|
||||||
adjusted_messages = inputs.messages;
|
adjusted_messages = inputs.messages;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue