diff --git a/common/chaton.hpp b/common/chaton.hpp
index a9af7a565..c62ab146e 100644
--- a/common/chaton.hpp
+++ b/common/chaton.hpp
@@ -50,3 +50,26 @@ inline void chaton_meta_dump() {
     }
     LOG_TEELN("\n\nINFO:%s:ChatOn Meta\n%s", __func__, conMeta.dump(4).c_str());
 }
+
+// Wrap `content` in the tags that template `tmpl` defines for `role`:
+// global prefix + role prefix + content + role suffix + global suffix.
+// Each part is converted to std::string before concatenation; streaming the
+// conMeta entries into an ostream would presumably write their serialized
+// JSON form (quoted and escaped) rather than the raw tag text.
+inline std::string chaton_tmpl_apply(const std::string &tmpl, const std::string &role, const std::string &content) {
+    const std::string globalPrefix = conMeta[tmpl]["global"]["prefix"];
+    const std::string rolePrefix = conMeta[tmpl][role]["prefix"];
+    const std::string roleSuffix = conMeta[tmpl][role]["suffix"];
+    const std::string globalSuffix = conMeta[tmpl]["global"]["suffix"];
+    return globalPrefix + rolePrefix + content + roleSuffix + globalSuffix;
+}
+
+// Return the `part` entry (e.g. "prefix" or "suffix") that template `tmpl` defines for `role`.
+inline std::string chaton_tmpl_role_part(const std::string &tmpl, const std::string &role, const std::string &part) {
+    return conMeta[tmpl][role][part];
+}
+
+// Return the template-level `part` entry (e.g. "reverse-prompt") of template `tmpl`.
+inline std::string chaton_tmpl_part(const std::string &tmpl, const std::string &part) {
+    return conMeta[tmpl][part];
+}
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 3432cb9f1..f58888e99 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -256,11 +256,14 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
+    if (params.interactive_first || params.instruct || params.chatml || params.chaton || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
         if (params.chatml) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
+        if (params.chaton) {
+            params.prompt = chaton_tmpl_apply(params.chaton_template_id, "system", params.prompt);
+        }
         embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
     } else {
         LOG("use session tokens\n");
@@ -338,7 +341,7 @@ int main(int argc, char ** argv) {
     }
 
     // number of tokens to keep when resetting context
-    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
+    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml || params.chaton) {
         params.n_keep = (int)embd_inp.size();
     } else {
         params.n_keep += add_bos; // always keep the BOS token
@@ -369,6 +372,19 @@ int main(int argc, char ** argv) {
         params.antiprompt.emplace_back("<|im_start|>user\n");
     }
 
+    // chaton mode
+    // The assistant prefix is looked up and tokenized only when chaton mode is enabled;
+    // otherwise chaton_template_id may not name a loaded template and the lookup must be skipped.
+    std::vector<llama_token> chaton_assistant_prefix;
+    if (params.chaton) {
+        chaton_assistant_prefix = ::llama_tokenize(ctx, chaton_tmpl_role_part(params.chaton_template_id, "assistant", "prefix"), false, true);
+        params.interactive = true; // may remove later, by requiring user to explicitly request interactive mode
+        params.interactive_first = true;
+        params.input_prefix = chaton_tmpl_role_part(params.chaton_template_id, "user", "prefix");
+        params.input_suffix = chaton_tmpl_role_part(params.chaton_template_id, "user", "suffix");
+        params.antiprompt.emplace_back(chaton_tmpl_part(params.chaton_template_id, "reverse-prompt"));
+    }
+
     // enable interactive mode if interactive start is specified
     if (params.interactive_first) {
         params.interactive = true;
@@ -822,7 +838,7 @@ int main(int argc, char ** argv) {
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
-                if (params.instruct || params.chatml) {
+                if (params.instruct || params.chatml || params.chaton) {
                     printf("\n> ");
                 }
 
@@ -901,6 +917,11 @@ int main(int argc, char ** argv) {
                     LOG("inserting chatml suffix\n");
                     embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
                 }
+                // chaton mode: insert assistant prefix
+                if (params.chaton) {
+                    LOG("inserting chaton assistant prefix\n");
+                    embd_inp.insert(embd_inp.end(), chaton_assistant_prefix.begin(), chaton_assistant_prefix.end());
+                }
 
                 for (size_t i = original_size; i < embd_inp.size(); ++i) {
                     const llama_token token = embd_inp[i];