wip: chat cli

Xuan Son Nguyen 2025-01-12 13:16:37 +01:00
parent c05e8c9934
commit 95e0afb977
3 changed files with 243 additions and 22 deletions

examples/main/CMakeLists.txt

@@ -1,5 +1,5 @@
set(TARGET llama-cli)
add_executable(${TARGET} main.cpp)
add_executable(${TARGET} main.cpp chat.hpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

examples/main/chat.hpp (new file, 222 lines)

@@ -0,0 +1,222 @@
#include "arg.h"
#include "common.h"
#include "console.h"
#include "log.h"
#include "sampling.h"
#include "llama.h"
#include <cstdio>
#include <fstream>
#include <sstream>
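// llama_cli_chat holds the state of the interactive chat loop used by llama-cli
// in conversation mode: the chat history, the pending user input, the decode
// batch and the token cache used to reuse the KV cache between turns.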
struct llama_cli_chat {
struct llama_context * ctx;
const struct llama_model * model;
struct common_sampler * smpl;
struct common_params params;
bool interacting = false;               // true while reading user input; also set by interrupt() to stop generation
std::vector<common_chat_msg> chat_msgs; // full conversation history
std::ostringstream pending_input;       // user input accumulated for the next turn
struct llama_batch batch;               // reusable decode batch (capacity n_batch, 1 sequence)
llama_tokens cache_tokens;              // tokens whose KV data is currently in the cache (seq 0)
int n_past = 0;                         // number of KV cache positions in use
llama_cli_chat(
struct common_params & params,
struct llama_context * ctx,
struct common_sampler * smpl) : ctx(ctx), smpl(smpl), params(params) {
model = llama_get_model(ctx);
batch = llama_batch_init(params.n_batch, 0, 1);
}
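// evaluate eval_tokens with llama_decode in chunks of up to n_batch tokens.
// for a new prompt (is_generating == false) the longest common prefix with
// cache_tokens is kept in the KV cache and only the suffix is re-decoded;
// during generation (is_generating == true) eval_tokens holds exactly one
// freshly sampled token.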
void decode(llama_tokens & eval_tokens, bool is_generating) {
if (is_generating) {
GGML_ASSERT(eval_tokens.size() == 1);
} else {
n_past = common_lcp(cache_tokens, eval_tokens);
// on re-generation (/regen) the new prompt may be fully contained in the cache;
// back off by one token so there is always at least one token left to decode
if ((int) eval_tokens.size() == n_past) {
n_past--;
}
if (n_past > 0) {
eval_tokens.erase(eval_tokens.begin(), eval_tokens.begin() + n_past);
cache_tokens.erase(cache_tokens.begin() + n_past, cache_tokens.end());
LOG_DBG("remove from cache [%d, inf)\n", n_past);
LOG_DBG("in cache: %s\n", common_detokenize(ctx, cache_tokens, true).c_str());
LOG_DBG("to decode %d tokens\n", (int) eval_tokens.size());
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
}
}
// decode
for (size_t i = 0; i < eval_tokens.size(); i += params.n_batch) {
if (interacting) {
break;
}
common_batch_clear(batch);
for (int j = 0; j < params.n_batch && i + j < eval_tokens.size(); ++j) {
bool is_last_token = i + j == eval_tokens.size() - 1;
// add the token at position n_past (0-based), then advance n_past, keeping
// positions consistent with common_lcp() and llama_kv_cache_seq_rm() above
common_batch_add(batch, eval_tokens[i + j], n_past, {0}, is_last_token);
n_past++;
}
if (llama_decode(ctx, batch)) {
GGML_ABORT("failed to decode\n");
}
}
// update cache tokens
if (is_generating) {
cache_tokens.push_back(eval_tokens[0]);
} else {
cache_tokens.insert(cache_tokens.end(), eval_tokens.begin(), eval_tokens.end());
}
}
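// main chat loop: read user input (and slash commands), apply the chat template
// to the full history, decode the formatted prompt, then stream sampled tokens
// until an end-of-generation token or Ctrl+C. never returns; the process exits
// through interrupt().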
[[noreturn]] void run() {
while (true) {
interacting = true;
LOG("\n> ");
// color user input only
console::set_display(console::user_input);
std::string line;
bool another_line = true;
bool continue_input = false;
do {
another_line = console::readline(line, params.multiline_input);
if (handle_command(line, continue_input)) {
continue; // do not add this line to pending_input
}
pending_input << line;
} while (another_line);
if (continue_input) {
continue;
}
if (pending_input.tellp() == 0) {
LOG_DBG("empty line, passing control back\n");
continue;
}
// done taking input, reset color
console::set_display(console::reset);
interacting = false;
// add message and format chat
if (!chat_msgs.empty() && chat_msgs.back().role == "user") {
chat_msgs.pop_back();
}
chat_msgs.push_back({"user", string_strip(pending_input.str())});
pending_input.str(""); // clear
auto formatted = common_chat_apply_template(model, params.chat_template, chat_msgs, true);
// tokenize the new chat history and decode
llama_tokens prompt_tokens = common_tokenize(ctx, formatted, true, true);
decode(prompt_tokens, false);
// generate response
llama_token new_token_id = LLAMA_TOKEN_NULL;
llama_tokens generated_tokens;
common_sampler_reset(smpl);
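// generation loop: sample one token at a time until EOG or Ctrl+C interrupts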
while (true) {
if (interacting) {
break;
}
// sample the next token
new_token_id = common_sampler_sample(smpl, ctx, -1);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id)) {
break;
}
// print the token, then decode it
printf("%s", common_token_to_piece(ctx, new_token_id, params.special).c_str());
fflush(stdout);
generated_tokens.push_back(new_token_id);
llama_tokens new_tok = {new_token_id};
decode(new_tok, true);
}
// add the generated tokens to the chat history
std::string response = common_detokenize(ctx, generated_tokens, true);
chat_msgs.push_back({"assistant", response});
// print a new line if needed
if (!response.empty() && response.back() != '\n') {
printf("\n");
}
}
}
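// called from the SIGINT handler: during generation it stops the current
// response and hands control back to the prompt; while already waiting for
// input it prints perf stats and exits.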
void interrupt() {
if (interacting) {
// exit
printf("\n");
console::cleanup();
common_perf_print(ctx, smpl);
common_log_pause(common_log_main());
exit(0);
}
interacting = true;
}
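// treat a line starting with '/' as a command (/help, /history, /regen, /readfile).
// returns true if the line was consumed as a command; continue_input tells the
// caller to keep reading input instead of starting generation.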
bool handle_command(std::string & inp, bool & continue_input) {
if (inp.empty() || inp[0] != '/') {
return false; // not a command
}
auto parts = string_split<std::string>(string_strip(inp), ' ');
std::string & cmd = parts[0];
if (cmd == "/help") {
LOG("TODO\n");
continue_input = true;
} else if (cmd == "/history") {
display_history();
continue_input = true;
} else if (cmd == "/regen") {
if (chat_msgs.empty()) {
LOG_ERR("no chat history to regenerate\n");
continue_input = true;
return true;
}
if (chat_msgs.back().role == "assistant") {
chat_msgs.pop_back();
}
if (chat_msgs.back().role == "user") {
pending_input.str(""); // clear
pending_input << chat_msgs.back().content;
chat_msgs.pop_back();
}
continue_input = false;
} else if (cmd == "/readfile") {
const std::string filename = parts[1];
LOG_DBG("reading file: '%s'\n", filename.c_str());
std::ifstream text_file(filename);
if (!text_file) {
LOG("failed to open file '%s'\n", filename.c_str());
} else {
pending_input << text_file.rdbuf() << "\n\n";
LOG("read %zu characters from file\n", (size_t) text_file.tellg());
}
continue_input = true;
} else {
LOG_ERR("unknown command: %s\n", cmd.c_str());
continue_input = true;
}
return true;
}
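// print the accumulated chat history, one "role: content" entry per message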
void display_history() {
for (const auto & msg : chat_msgs) {
LOG("%s: %s\n\n", msg.role.c_str(), msg.content.c_str());
}
}
~llama_cli_chat() {
llama_batch_free(batch);
}
};

examples/main/main.cpp

@@ -4,6 +4,7 @@
#include "log.h"
#include "sampling.h"
#include "llama.h"
#include "chat.hpp"
#include <cassert>
#include <cstdio>
@@ -35,6 +36,7 @@ static llama_context ** g_ctx;
static llama_model ** g_model;
static common_sampler ** g_smpl;
static common_params * g_params;
static llama_cli_chat * g_chat;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
@@ -65,7 +67,9 @@ static bool file_is_empty(const std::string & path) {
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) {
if (signo == SIGINT) {
if (!is_interacting && g_params->interactive) {
if (g_chat) {
g_chat->interrupt();
} else if (!is_interacting && g_params->interactive) {
is_interacting = true;
need_insert_eot = true;
} else {
@@ -83,14 +87,6 @@ static void sigint_handler(int signo) {
}
#endif
static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content};
auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content});
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
}
int main(int argc, char ** argv) {
common_params params;
g_params = &params;
@@ -203,6 +199,12 @@ int main(int argc, char ** argv) {
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
}
// switch on conversation mode if chat template is present
if (!params.chat_template.empty() || !common_get_builtin_chat_template(model).empty()) {
LOG("%s: using chat mode\n", __func__);
params.conversation = true;
}
// print chat template example in conversation mode
if (params.conversation) {
if (params.enable_chat_template) {
@@ -251,18 +253,15 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp;
{
auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
: params.prompt;
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
LOG_DBG("tokenize the prompt\n");
embd_inp = common_tokenize(ctx, prompt, true, true);
embd_inp = common_tokenize(ctx, params.prompt, true, true);
} else {
LOG_DBG("use session tokens\n");
embd_inp = session_tokens;
}
LOG_DBG("prompt: \"%s\"\n", prompt.c_str());
LOG_DBG("prompt: \"%s\"\n", params.prompt.c_str());
LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str());
}
@@ -420,6 +419,12 @@ int main(int argc, char ** argv) {
LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
if (params.conversation) {
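// conversation mode is now handled entirely by llama_cli_chat; run() never returns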
llama_cli_chat chat(params, ctx, smpl);
g_chat = &chat;
chat.run();
}
// group-attention state
// number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
int ga_i = 0;
@@ -752,10 +757,6 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
is_antiprompt = true;
}
if (params.enable_chat_template) {
chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
}
is_interacting = true;
LOG("\n");
}
@@ -818,9 +819,7 @@ int main(int argc, char ** argv) {
}
bool format_chat = params.conversation && params.enable_chat_template;
std::string user_inp = format_chat
? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
: std::move(buffer);
std::string user_inp = std::move(buffer);
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
const auto line_inp = common_tokenize(ctx, user_inp, false, format_chat);