Copy sampler parameters from chat template
This commit is contained in:
parent
a30111bef3
commit
a726adaef7
3 changed files with 65 additions and 10 deletions
|
@ -1830,12 +1830,51 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
|||
return res >= 0;
|
||||
}
|
||||
|
||||
static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams)
|
||||
{
|
||||
GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
|
||||
|
||||
auto & dst = *update_sparams->sparams;
|
||||
auto vocab = update_sparams->vocab;
|
||||
|
||||
dst.grammar = src.grammar;
|
||||
dst.grammar_lazy = src.grammar_lazy;
|
||||
|
||||
for (const auto & trigger : src.grammar_triggers) {
|
||||
auto ids = common_tokenize(vocab, trigger.word, false, true);
|
||||
|
||||
if (ids.size() == 1) {
|
||||
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||
dst.grammar_trigger_tokens.push_back(ids[0]);
|
||||
dst.preserved_tokens.insert(ids[0]);
|
||||
continue;
|
||||
}
|
||||
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||
dst.grammar_trigger_words.push_back(trigger);
|
||||
}
|
||||
|
||||
for (const auto & preserved : src.preserved_tokens) {
|
||||
auto ids = common_tokenize(vocab, preserved, false, true);
|
||||
if (ids.size() == 1) {
|
||||
LOG_DBG("Preserved token: %d\n", ids[0]);
|
||||
dst.preserved_tokens.insert(ids[0]);
|
||||
|
||||
} else {
|
||||
// This may happen when using a tool call style meant for a model
|
||||
// with special tokens to preserve on a model without said tokens.
|
||||
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n",
|
||||
preserved.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string common_chat_apply_template(
|
||||
const common_chat_templates & tmpl,
|
||||
const std::vector<common_chat_msg> & msgs,
|
||||
bool add_ass,
|
||||
bool use_jinja,
|
||||
const common_params_tools & tools)
|
||||
const common_params_tools & tools,
|
||||
common_chat_sampling_updater * update_sparams)
|
||||
{
|
||||
const auto & tmpl_selected =
|
||||
tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
|
||||
|
@ -1865,7 +1904,11 @@ std::string common_chat_apply_template(
|
|||
|
||||
inputs.messages = messages;
|
||||
inputs.add_generation_prompt = add_ass;
|
||||
return common_chat_params_init(tmpl_selected, inputs).prompt;
|
||||
auto chat_params = common_chat_params_init(tmpl_selected, inputs);
|
||||
if (update_sparams) {
|
||||
copy_chat_params(chat_params, update_sparams);
|
||||
}
|
||||
return chat_params.prompt;
|
||||
}
|
||||
|
||||
int alloc_size = 0;
|
||||
|
@ -1903,11 +1946,12 @@ std::string common_chat_format_single(
|
|||
const common_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja,
|
||||
const common_params_tools & tools)
|
||||
const common_params_tools & tools,
|
||||
common_chat_sampling_updater * update_sparams)
|
||||
{
|
||||
std::ostringstream ss;
|
||||
auto fmt_past_msg = past_msg.empty() ? ""
|
||||
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools);
|
||||
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams);
|
||||
|
||||
std::vector<common_chat_msg> chat_new(past_msg);
|
||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||
|
@ -1916,7 +1960,7 @@ std::string common_chat_format_single(
|
|||
};
|
||||
// format chat with new_msg
|
||||
chat_new.push_back(new_msg);
|
||||
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools);
|
||||
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams);
|
||||
// get the diff part
|
||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||
return ss.str();
|
||||
|
|
|
@ -671,6 +671,11 @@ struct common_chat_templates {
|
|||
std::unique_ptr<common_chat_template> template_tool_use;
|
||||
};
|
||||
|
||||
struct common_chat_sampling_updater {
|
||||
common_params_sampling * sparams;
|
||||
const llama_vocab * vocab;
|
||||
};
|
||||
|
||||
// CPP wrapper for llama_chat_apply_template
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
// If the custom "tmpl" is not supported, we throw an error
|
||||
|
@ -679,7 +684,8 @@ std::string common_chat_apply_template(
|
|||
const std::vector<common_chat_msg> & chat,
|
||||
bool add_ass,
|
||||
bool use_jinja,
|
||||
const common_params_tools & tools = common_params_tools());
|
||||
const common_params_tools & tools = common_params_tools(),
|
||||
common_chat_sampling_updater * update_sparams = nullptr);
|
||||
|
||||
// Format single message, while taking into account the position of that message in chat history
|
||||
std::string common_chat_format_single(
|
||||
|
@ -688,7 +694,8 @@ std::string common_chat_format_single(
|
|||
const common_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja,
|
||||
const common_params_tools & tools = common_params_tools());
|
||||
const common_params_tools & tools = common_params_tools(),
|
||||
common_chat_sampling_updater * update_sparams = nullptr);
|
||||
|
||||
// Returns an example of formatted chat
|
||||
std::string common_chat_format_example(
|
||||
|
|
|
@ -263,14 +263,18 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::vector<llama_token> embd_inp;
|
||||
|
||||
auto chat_add_and_format = [&chat_msgs, &chat_templates](
|
||||
auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab](
|
||||
const std::string & role, const std::string & content,
|
||||
const common_params_tools & tools = common_params_tools())
|
||||
{
|
||||
bool add_ass = (role == "user");
|
||||
|
||||
common_chat_msg new_msg{role, content, {}};
|
||||
|
||||
auto formatted = common_chat_format_single(chat_templates, chat_msgs,
|
||||
new_msg, role == "user", g_params->use_jinja, tools);
|
||||
common_chat_sampling_updater updater{&sparams, vocab};
|
||||
auto formatted =
|
||||
common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja,
|
||||
tools, &updater);
|
||||
|
||||
chat_msgs.push_back({role, content, {}});
|
||||
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue