Copy sampler parameters from chat template
This commit is contained in:
parent
a30111bef3
commit
a726adaef7
3 changed files with 65 additions and 10 deletions
|
@ -1830,12 +1830,51 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||||
return res >= 0;
|
return res >= 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void copy_chat_params(const common_chat_params & src, common_chat_sampling_updater * update_sparams)
|
||||||
|
{
|
||||||
|
GGML_ASSERT(update_sparams && update_sparams->sparams && update_sparams->vocab);
|
||||||
|
|
||||||
|
auto & dst = *update_sparams->sparams;
|
||||||
|
auto vocab = update_sparams->vocab;
|
||||||
|
|
||||||
|
dst.grammar = src.grammar;
|
||||||
|
dst.grammar_lazy = src.grammar_lazy;
|
||||||
|
|
||||||
|
for (const auto & trigger : src.grammar_triggers) {
|
||||||
|
auto ids = common_tokenize(vocab, trigger.word, false, true);
|
||||||
|
|
||||||
|
if (ids.size() == 1) {
|
||||||
|
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
|
||||||
|
dst.grammar_trigger_tokens.push_back(ids[0]);
|
||||||
|
dst.preserved_tokens.insert(ids[0]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
|
||||||
|
dst.grammar_trigger_words.push_back(trigger);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & preserved : src.preserved_tokens) {
|
||||||
|
auto ids = common_tokenize(vocab, preserved, false, true);
|
||||||
|
if (ids.size() == 1) {
|
||||||
|
LOG_DBG("Preserved token: %d\n", ids[0]);
|
||||||
|
dst.preserved_tokens.insert(ids[0]);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// This may happen when using a tool call style meant for a model
|
||||||
|
// with special tokens to preserve on a model without said tokens.
|
||||||
|
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n",
|
||||||
|
preserved.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string common_chat_apply_template(
|
std::string common_chat_apply_template(
|
||||||
const common_chat_templates & tmpl,
|
const common_chat_templates & tmpl,
|
||||||
const std::vector<common_chat_msg> & msgs,
|
const std::vector<common_chat_msg> & msgs,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
const common_params_tools & tools)
|
const common_params_tools & tools,
|
||||||
|
common_chat_sampling_updater * update_sparams)
|
||||||
{
|
{
|
||||||
const auto & tmpl_selected =
|
const auto & tmpl_selected =
|
||||||
tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
|
tools.tools() && tmpl.template_tool_use ? *tmpl.template_tool_use : *tmpl.template_default;
|
||||||
|
@ -1865,7 +1904,11 @@ std::string common_chat_apply_template(
|
||||||
|
|
||||||
inputs.messages = messages;
|
inputs.messages = messages;
|
||||||
inputs.add_generation_prompt = add_ass;
|
inputs.add_generation_prompt = add_ass;
|
||||||
return common_chat_params_init(tmpl_selected, inputs).prompt;
|
auto chat_params = common_chat_params_init(tmpl_selected, inputs);
|
||||||
|
if (update_sparams) {
|
||||||
|
copy_chat_params(chat_params, update_sparams);
|
||||||
|
}
|
||||||
|
return chat_params.prompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
int alloc_size = 0;
|
int alloc_size = 0;
|
||||||
|
@ -1903,11 +1946,12 @@ std::string common_chat_format_single(
|
||||||
const common_chat_msg & new_msg,
|
const common_chat_msg & new_msg,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
const common_params_tools & tools)
|
const common_params_tools & tools,
|
||||||
|
common_chat_sampling_updater * update_sparams)
|
||||||
{
|
{
|
||||||
std::ostringstream ss;
|
std::ostringstream ss;
|
||||||
auto fmt_past_msg = past_msg.empty() ? ""
|
auto fmt_past_msg = past_msg.empty() ? ""
|
||||||
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools);
|
: common_chat_apply_template(tmpl, past_msg, false, use_jinja, tools, update_sparams);
|
||||||
|
|
||||||
std::vector<common_chat_msg> chat_new(past_msg);
|
std::vector<common_chat_msg> chat_new(past_msg);
|
||||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||||
|
@ -1916,7 +1960,7 @@ std::string common_chat_format_single(
|
||||||
};
|
};
|
||||||
// format chat with new_msg
|
// format chat with new_msg
|
||||||
chat_new.push_back(new_msg);
|
chat_new.push_back(new_msg);
|
||||||
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools);
|
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja, tools, update_sparams);
|
||||||
// get the diff part
|
// get the diff part
|
||||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||||
return ss.str();
|
return ss.str();
|
||||||
|
|
|
@ -671,6 +671,11 @@ struct common_chat_templates {
|
||||||
std::unique_ptr<common_chat_template> template_tool_use;
|
std::unique_ptr<common_chat_template> template_tool_use;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct common_chat_sampling_updater {
|
||||||
|
common_params_sampling * sparams;
|
||||||
|
const llama_vocab * vocab;
|
||||||
|
};
|
||||||
|
|
||||||
// CPP wrapper for llama_chat_apply_template
|
// CPP wrapper for llama_chat_apply_template
|
||||||
// If the built-in template is not supported, we default to chatml
|
// If the built-in template is not supported, we default to chatml
|
||||||
// If the custom "tmpl" is not supported, we throw an error
|
// If the custom "tmpl" is not supported, we throw an error
|
||||||
|
@ -679,7 +684,8 @@ std::string common_chat_apply_template(
|
||||||
const std::vector<common_chat_msg> & chat,
|
const std::vector<common_chat_msg> & chat,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
const common_params_tools & tools = common_params_tools());
|
const common_params_tools & tools = common_params_tools(),
|
||||||
|
common_chat_sampling_updater * update_sparams = nullptr);
|
||||||
|
|
||||||
// Format single message, while taking into account the position of that message in chat history
|
// Format single message, while taking into account the position of that message in chat history
|
||||||
std::string common_chat_format_single(
|
std::string common_chat_format_single(
|
||||||
|
@ -688,7 +694,8 @@ std::string common_chat_format_single(
|
||||||
const common_chat_msg & new_msg,
|
const common_chat_msg & new_msg,
|
||||||
bool add_ass,
|
bool add_ass,
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
const common_params_tools & tools = common_params_tools());
|
const common_params_tools & tools = common_params_tools(),
|
||||||
|
common_chat_sampling_updater * update_sparams = nullptr);
|
||||||
|
|
||||||
// Returns an example of formatted chat
|
// Returns an example of formatted chat
|
||||||
std::string common_chat_format_example(
|
std::string common_chat_format_example(
|
||||||
|
|
|
@ -263,14 +263,18 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
std::vector<llama_token> embd_inp;
|
std::vector<llama_token> embd_inp;
|
||||||
|
|
||||||
auto chat_add_and_format = [&chat_msgs, &chat_templates](
|
auto chat_add_and_format = [&chat_msgs, &chat_templates, &sparams, vocab](
|
||||||
const std::string & role, const std::string & content,
|
const std::string & role, const std::string & content,
|
||||||
const common_params_tools & tools = common_params_tools())
|
const common_params_tools & tools = common_params_tools())
|
||||||
{
|
{
|
||||||
|
bool add_ass = (role == "user");
|
||||||
|
|
||||||
common_chat_msg new_msg{role, content, {}};
|
common_chat_msg new_msg{role, content, {}};
|
||||||
|
|
||||||
auto formatted = common_chat_format_single(chat_templates, chat_msgs,
|
common_chat_sampling_updater updater{&sparams, vocab};
|
||||||
new_msg, role == "user", g_params->use_jinja, tools);
|
auto formatted =
|
||||||
|
common_chat_format_single(chat_templates, chat_msgs, new_msg, add_ass, g_params->use_jinja,
|
||||||
|
tools, &updater);
|
||||||
|
|
||||||
chat_msgs.push_back({role, content, {}});
|
chat_msgs.push_back({role, content, {}});
|
||||||
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue