diff --git a/common/common.cpp b/common/common.cpp
index 15d025846..13e03e88f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1830,12 +1830,51 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     return res >= 0;
 }
 
+static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams)
+{
+    GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
+
+    auto & dst = *update_sparams->sparams;
+    auto vocab = update_sparams->vocab;
+
+    dst.grammar = src.grammar;
+    dst.grammar_lazy = src.grammar_lazy;
+
+    for (const auto & trigger : src.grammar_triggers) {
+        auto ids = common_tokenize(vocab, trigger.word, false, true);
+
+        if (ids.size() == 1) {
+            LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
+            dst.grammar_trigger_tokens.push_back(ids[0]);
+            dst.preserved_tokens.insert(ids[0]);
+            continue;
+        }
+        LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
+        dst.grammar_trigger_words.push_back(trigger);
+    }
+
+    for (const auto & preserved : src.preserved_tokens) {
+        auto ids = common_tokenize(vocab, preserved, false, true);
+        if (ids.size() == 1) {
+            LOG_DBG("Preserved token: %d\n", ids[0]);
+            dst.preserved_tokens.insert(ids[0]);
+
+        } else {
+            // This may happen when using a tool call style meant for a model
+            // with special tokens to preserve on a model without said tokens.
+            LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n",
+                    preserved.c_str());
+        }
+    }
+}
+
 std::string common_chat_apply_template(
     const common_chat_templates & tmpl,
     const std::vector<common_chat_msg> & msgs,
     bool add_ass,
     bool use_jinja,
-    const common_params_tools & tools)
+    const common_params_tools & tools,
+    common_chat_sampling_updater * update_sparams)
 {
     const auto & tmpl_selected = tools.tools() && tmpl.template_tool_use
         ? *tmpl.template_tool_use : *tmpl.template_default;
@@ -1865,7 +1904,11 @@ std::string common_chat_apply_template(
         inputs.messages = messages;
         inputs.add_generation_prompt = add_ass;
 
-        return common_chat_params_init(tmpl_selected, inputs).prompt;
+        auto chat_params = common_chat_params_init(tmpl_selected, inputs);
+        if (update_sparams) {
+            copy_chat_params(chat_params, update_sparams);
+        }
+        return chat_params.prompt;
     }
 
     int alloc_size = 0;
@@ -1903,11 +1946,12 @@ std::string common_chat_format_single(
     const common_chat_msg & new_msg,
     bool add_ass,
     bool use_jinja,
-    const common_params_tools & tools)
+    const common_params_tools & tools,
+    common_chat_sampling_updater * update_sparams)
 {
     std::ostringstream ss;
     auto fmt_past_msg = past_msg.empty() ?
         ""
-        : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools);
+        : common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams);
     std::vector<common_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
@@ -1916,7 +1960,7 @@
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools);
+    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
diff --git a/common/common.h b/common/common.h
index c3f4f66ee..f7bdb4e69 100644
--- a/common/common.h
+++ b/common/common.h
@@ -671,6 +671,11 @@ struct common_chat_templates {
     std::unique_ptr<common_chat_template> template_tool_use;
 };
 
+struct common_chat_sampling_updater {
+    common_params_sampling * sparams;
+    const llama_vocab * vocab;
+};
+
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
 // If the custom "tmpl" is not supported, we throw an error
@@ -679,7 +684,8 @@ std::string common_chat_apply_template(
     const std::vector<common_chat_msg> & chat,
     bool add_ass,
     bool use_jinja,
-    const common_params_tools & tools = common_params_tools());
+    const common_params_tools & tools = common_params_tools(),
+    common_chat_sampling_updater * update_sparams = nullptr);
 
 // Format single message, while taking into account the position of that message in chat history
 std::string common_chat_format_single(
@@ -688,7 +694,8 @@
     const common_chat_msg & new_msg,
     bool add_ass,
     bool use_jinja,
-    const common_params_tools & tools = common_params_tools());
+    const common_params_tools & tools = common_params_tools(),
+    common_chat_sampling_updater * update_sparams = nullptr);
 
 // Returns an example of formatted chat
 std::string common_chat_format_example(
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index d562522a4..0ea614bcd 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -263,14 +263,18 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    auto chat_add_and_format = [&chat_msgs, &chat_templates](
+    auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab](
             const std::string & role, const std::string & content,
             const common_params_tools & tools = common_params_tools()) {
+        bool add_ass = (role == "user");
+
         common_chat_msg new_msg{role, content, {}};
-        auto formatted = common_chat_format_single(chat_templates, chat_msgs,
-            new_msg, role == "user", g_params->use_jinja, tools);
+        common_chat_sampling_updater updater{&sparams, vocab};
+        auto formatted =
+            common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja,
+                tools, &updater);
         chat_msgs.push_back({role, content, {}});
         LOG_DBG("formatted: '%s'\n", formatted.c_str());