Main:Update to support chaton mode

Glanced through the existing interactive and chatml flows to work out how to
incorporate this flow. Need to look deeper later.

NOTE: Up to this point this is a reapplication of my initial go at chaton,
simplifying the amount of change made to the existing code a bit more.
HanishKVC 2024-04-20 18:40:55 +05:30
parent efbcdc1caf
commit 0a8797b28e

@@ -1,4 +1,5 @@
 #include "common.h"
+#include "chaton.hpp"
 #include "console.h"
 #include "llama.h"
@@ -251,11 +252,14 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
-    if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
-        LOG("tokenize the prompt\n");
+    if (params.interactive_first || params.instruct || params.chatml || params.chaton || !params.prompt.empty() || session_tokens.empty()) {
+        LOG("tokenize the prompt: %s\n", params.prompt.c_str());
         if (params.chatml) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
+        if (params.chaton) {
+            params.prompt = llama_chat_apply_template_simple(params.chaton_template_id, "system", params.prompt, false);
+        }
         embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
     } else {
         LOG("use session tokens\n");
@@ -333,7 +337,7 @@ int main(int argc, char ** argv) {
     }
     // number of tokens to keep when resetting context
-    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
+    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml || params.chaton) {
         params.n_keep = (int)embd_inp.size();
     } else {
         params.n_keep += add_bos; // always keep the BOS token
@@ -363,6 +367,19 @@ int main(int argc, char ** argv) {
         params.interactive_first = true;
         params.antiprompt.emplace_back("<|im_start|>user\n");
     }
+    // handle chaton mode, it adds on to any reverse prompt specified explicitly by the user
+    if (params.chaton) {
+        params.interactive_first = true;
+        std::vector<std::string> resp_ends = llama_chat_reverse_prompt(params.chaton_template_id);
+        if (resp_ends.size() == 0) {
+            LOG_TEELN("ERRR:%s:ChatOn:Unsupported ChatType:%s", __func__, params.chaton_template_id.c_str());
+            exit(1);
+        }
+        for (size_t i = 0; i < resp_ends.size(); i++)
+        {
+            params.antiprompt.emplace_back(resp_ends[i]);
+        }
+    }

     // enable interactive mode if interactive start is specified
     if (params.interactive_first) {
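Note: llama_chat_reverse_prompt() is also part of chaton.hpp and not shown in
this diff. From the call site above it maps a template id to the marker strings
that terminate a model response, so they can be registered as antiprompts; an
empty vector signals an unsupported template and main() exits. A hedged sketch
of that contract:

    #include <string>
    #include <vector>

    // Hypothetical sketch, inferred from the call site: return the strings that
    // mark the end of a model response (i.e. the start of the next user turn)
    // for the given template id; empty means the template is not supported.
    static std::vector<std::string> llama_chat_reverse_prompt(const std::string & tmpl_id) {
        if (tmpl_id == "chatml") {
            return { "<|im_start|>user\n" };
        }
        return {}; // unknown template: caller logs an error and exits
    }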
@@ -817,7 +834,7 @@ int main(int argc, char ** argv) {
         if (n_past > 0 && is_interacting) {
             LOG("waiting for user input\n");

-            if (params.instruct || params.chatml) {
+            if (params.instruct || params.chatml || params.chaton) {
                 printf("\n> ");
             }
@@ -876,15 +893,23 @@ int main(int argc, char ** argv) {
                     process_escapes(buffer);
                 }
-                const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
-                const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
-                LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
-                embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
-                embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
-                embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
+                std::vector<int> line_inp;
+                if (params.chaton) {
+                    std::string f_chat = llama_chat_apply_template_simple(params.chaton_template_id, "user", buffer.c_str(), true);
+                    line_inp = ::llama_tokenize(ctx, f_chat, false, true);
+                    LOG("formatted input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
+                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+                } else {
+                    const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
+                    line_inp = ::llama_tokenize(ctx, buffer, false, false);
+                    const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
+                    LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
+                    embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
+                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+                    embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
+                }

                 // instruct mode: insert response suffix
                 if (params.instruct) {
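Note: line_inp is declared here as std::vector<int> while the tokenizer returns
std::vector<llama_token>; llama_token is an int32_t typedef, so this compiles on
the usual targets, but keeping the typedef would be cleaner. For comparison,
llama.h around this time already exposed the C-level llama_chat_apply_template();
formatting a single user turn through it would look roughly like this (error
handling simplified, the helper name format_user_turn is made up):

    #include <string>
    #include <vector>
    #include "llama.h"

    // Rough sketch using the upstream C API (signature as of this era of
    // llama.cpp, which still took the model pointer); tmpl == nullptr selects
    // the chat template embedded in the model's metadata.
    static std::string format_user_turn(const llama_model * model, const std::string & text) {
        llama_chat_message msg = { "user", text.c_str() };
        std::vector<char> buf(text.size() + 256); // headroom for template tags
        int32_t n = llama_chat_apply_template(model, nullptr, &msg, 1,
                                              /*add_ass=*/true, buf.data(), (int32_t) buf.size());
        if (n < 0) {
            return text; // template not recognized: fall back to the raw input
        }
        if ((size_t) n > buf.size()) { // undersized buffer: retry with the exact size
            buf.resize(n);
            n = llama_chat_apply_template(model, nullptr, &msg, 1, true, buf.data(), (int32_t) buf.size());
        }
        return std::string(buf.data(), n);
    }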
@@ -921,6 +946,7 @@ int main(int argc, char ** argv) {
         }
         // end of text token
+        // chaton expected to be used along with interactive argument, so not checking for chaton separately
         if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
             LOG_TEE(" [end of text]\n");
             break;