Copy sampler parameters from chat template

This commit is contained in:
Mason M 2025-02-05 12:07:22 -04:00
parent a30111bef3
commit a726adaef7
3 changed files with 65 additions and 10 deletions

View file

@ -1830,12 +1830,51 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
return res >= 0;
}
static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams)
{
GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
auto & dst = *update_sparams->sparams;
auto vocab = update_sparams->vocab;
dst.grammar = src.grammar;
dst.grammar_lazy = src.grammar_lazy;
for (const auto & trigger : src.grammar_triggers) {
auto ids = common_tokenize(vocab, trigger.word, false, true);
if (ids.size() == 1) {
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
dst.grammar_trigger_tokens.push_back(ids[0]);
dst.preserved_tokens.insert(ids[0]);
continue;
}
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
dst.grammar_trigger_words.push_back(trigger);
}
for (const auto & preserved : src.preserved_tokens) {
auto ids = common_tokenize(vocab, preserved, false, true);
if (ids.size() == 1) {
LOG_DBG("Preserved token: %d\n", ids[0]);
dst.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model
// with special tokens to preserve on a model without said tokens.
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n",
preserved.c_str());
}
}
}
std::string common_chat_apply_template(
const common_chat_templates & tmpl,
const std::vector<common_chat_msg> & msgs,
bool add_ass,
bool use_jinja,
const common_params_tools & tools)
const common_params_tools & tools,
common_chat_sampling_updater * update_sparams)
{
const auto & tmpl_selected =
tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
@ -1865,7 +1904,11 @@ std::string common_chat_apply_template(
inputs.messages = messages;
inputs.add_generation_prompt = add_ass;
return common_chat_params_init(tmpl_selected, inputs).prompt;
auto chat_params = common_chat_params_init(tmpl_selected, inputs);
if (update_sparams) {
copy_chat_params(chat_params, update_sparams);
}
return chat_params.prompt;
}
int alloc_size = 0;
@ -1903,11 +1946,12 @@ std::string common_chat_format_single(
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja,
const common_params_tools & tools)
const common_params_tools & tools,
common_chat_sampling_updater * update_sparams)
{
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? ""
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools);
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams);
std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
@ -1916,7 +1960,7 @@ std::string common_chat_format_single(
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools);
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();

View file

@ -671,6 +671,11 @@ struct common_chat_templates {
std::unique_ptr<common_chat_template> template_tool_use;
};
struct common_chat_sampling_updater {
common_params_sampling * sparams;
const llama_vocab * vocab;
};
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
@ -679,7 +684,8 @@ std::string common_chat_apply_template(
const std::vector<common_chat_msg> & chat,
bool add_ass,
bool use_jinja,
const common_params_tools & tools = common_params_tools());
const common_params_tools & tools = common_params_tools(),
common_chat_sampling_updater * update_sparams = nullptr);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
@ -688,7 +694,8 @@ std::string common_chat_format_single(
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja,
const common_params_tools & tools = common_params_tools());
const common_params_tools & tools = common_params_tools(),
common_chat_sampling_updater * update_sparams = nullptr);
// Returns an example of formatted chat
std::string common_chat_format_example(

View file

@ -263,14 +263,18 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp;
auto chat_add_and_format = [&chat_msgs, &chat_templates](
auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab](
const std::string & role, const std::string & content,
const common_params_tools & tools = common_params_tools())
{
bool add_ass = (role == "user");
common_chat_msg new_msg{role, content, {}};
auto formatted = common_chat_format_single(chat_templates, chat_msgs,
new_msg, role == "user", g_params->use_jinja, tools);
common_chat_sampling_updater updater{&sparams, vocab};
auto formatted =
common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja,
tools, &updater);
chat_msgs.push_back({role, content, {}});
LOG_DBG("formatted: '%s'\n", formatted.c_str());