diff --git a/common/common.cpp b/common/common.cpp index 467fb014e..a065478a7 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -915,6 +915,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.chatml = true; return true; } + if (arg == "--chaton-json") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.chaton_json = argv[i]; + return true; + } + if (arg == "--chaton-template-id") { + if (++i >= argc) { + invalid_param = true; + return true; + } + std::string got = argv[i]; + std::regex whitespaces(R"(\s+)"); + std::string trimmed = std::regex_replace(got, whitespaces, ""); + if (!trimmed.empty()) { + params.chaton_template_id = trimmed; + params.chaton = true; + } + return true; + } if (arg == "--infill") { params.infill = true; return true; @@ -1419,6 +1441,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --interactive-first run in interactive mode and wait for input right away\n"); printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n"); + printf(" --chaton-json specify the json file containing chat-handshake-template-standard(s)"); + printf(" --chaton-template-id specify the specific template standard to use from loaded json file"); printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); printf(" -r PROMPT, --reverse-prompt PROMPT\n"); printf(" halt generation at PROMPT, return control in interactive mode\n"); diff --git a/common/common.h b/common/common.h index 9252a4b63..f0709d38a 100644 --- a/common/common.h +++ b/common/common.h @@ -141,6 +141,9 @@ struct gpt_params { bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode bool chatml = false; // chatml mode (used for models trained on chatml syntax) + bool chaton = false; // whether chaton is enabled or disabled + std::string chaton_json = ""; // name of the json file containing the chaton templates + std::string chaton_template_id = ""; // the specific chat-handshake-template-standard to use bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it