diff --git a/common/common.cpp b/common/common.cpp index 6a00d25be..4b65f1668 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1014,16 +1014,19 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } if (arg == "--in-prefix-bos") { params.input_prefix_bos = true; + params.enable_chat_template = false; return true; } if (arg == "--in-prefix") { CHECK_ARG params.input_prefix = argv[i]; + params.enable_chat_template = false; return true; } if (arg == "--in-suffix") { CHECK_ARG params.input_suffix = argv[i]; + params.enable_chat_template = false; return true; } if (arg == "--spm-infill") { @@ -1170,13 +1173,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } if (arg == "--chat-template") { CHECK_ARG - if (!llama_chat_verify_template(argv[i])) { - fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]); + std::string value(argv[i]); + if (!llama_chat_verify_template(value)) { + fprintf(stderr, "error: the supplied chat template is not supported: %s\n", value.c_str()); fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n"); invalid_param = true; return true; } - params.chat_template = argv[i]; + params.chat_template = value; return true; } if (arg == "--slot-prompt-similarity" || arg == "-sps") { @@ -1406,7 +1410,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "halt generation at PROMPT, return control in interactive mode\n" "can be specified more than once for multiple prompts" }); options.push_back({ "main", "-sp, --special", "special tokens output enabled (default: %s)", params.special ? "true" : "false" }); - options.push_back({ "main", "-cnv, --conversation", "run in conversation mode (does not print special tokens and suffix/prefix) (default: %s)", params.conversation ? "true" : "false" }); + options.push_back({ "main", "-cnv, --conversation", "run in conversation mode (does not print special tokens and suffix/prefix, use default chat template) (default: %s)", params.conversation ? "true" : "false" }); options.push_back({ "main infill", "-i, --interactive", "run in interactive mode (default: %s)", params.interactive ? "true" : "false" }); options.push_back({ "main infill", "-if, --interactive-first", "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" }); options.push_back({ "main infill", "-mli, --multiline-input", "allows you to write or paste multiple lines without ending each in '\\'" }); @@ -2668,12 +2672,19 @@ std::string llama_chat_format_single(const struct llama_model * model, const std::vector & past_msg, const llama_chat_msg & new_msg, bool add_ass) { + std::ostringstream ss; auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false); std::vector chat_new(past_msg); + // if the past_msg ends with a newline, we must preserve it in the formatted version + if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { + ss << "\n"; + }; + // format chat with new_msg chat_new.push_back(new_msg); auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass); - auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); - return formatted; + // get the diff part + ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); + return ss.str(); } std::string llama_chat_format_example(const struct llama_model * model, diff --git a/common/common.h b/common/common.h index d6cb814b9..627b7ed85 100644 --- a/common/common.h +++ b/common/common.h @@ -200,6 +200,7 @@ struct gpt_params { std::string public_path = ""; std::string chat_template = ""; std::string system_prompt = ""; + bool enable_chat_template = true; std::vector api_keys; diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index b154038b2..03f536910 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -142,9 +142,9 @@ int main(void) { std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n"; return output; }; - assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n"); + assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n"); assert(fmt_single("llama2") == "[INST] How are you [/INST]"); - assert(fmt_single("gemma") == "user\nHow are you\nmodel\n"); + assert(fmt_single("gemma") == "\nuser\nHow are you\nmodel\n"); assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); return 0;