wip: chat cli
This commit is contained in:
parent
c05e8c9934
commit
95e0afb977
3 changed files with 243 additions and 22 deletions
|
@ -1,5 +1,5 @@
|
||||||
set(TARGET llama-cli)
|
set(TARGET llama-cli)
|
||||||
add_executable(${TARGET} main.cpp)
|
add_executable(${TARGET} main.cpp chat.hpp)
|
||||||
install(TARGETS ${TARGET} RUNTIME)
|
install(TARGETS ${TARGET} RUNTIME)
|
||||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||||
|
|
222
examples/main/chat.hpp
Normal file
222
examples/main/chat.hpp
Normal file
|
@ -0,0 +1,222 @@
|
||||||
|
#include "arg.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "console.h"
|
||||||
|
#include "log.h"
|
||||||
|
#include "sampling.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
struct llama_cli_chat {
|
||||||
|
struct llama_context * ctx;
|
||||||
|
const struct llama_model * model;
|
||||||
|
struct common_sampler * smpl;
|
||||||
|
struct common_params params;
|
||||||
|
|
||||||
|
bool interacting = false;
|
||||||
|
std::vector<common_chat_msg> chat_msgs;
|
||||||
|
std::ostringstream pending_input;
|
||||||
|
|
||||||
|
struct llama_batch batch;
|
||||||
|
llama_tokens cache_tokens;
|
||||||
|
int n_past = 0;
|
||||||
|
|
||||||
|
llama_cli_chat(
|
||||||
|
struct common_params & params,
|
||||||
|
struct llama_context * ctx,
|
||||||
|
struct common_sampler * smpl) : ctx(ctx), smpl(smpl), params(params) {
|
||||||
|
model = llama_get_model(ctx);
|
||||||
|
batch = llama_batch_init(params.n_batch, 0, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void decode(llama_tokens & eval_tokens, bool is_generating) {
|
||||||
|
if (is_generating) {
|
||||||
|
GGML_ASSERT(eval_tokens.size() == 1);
|
||||||
|
} else {
|
||||||
|
n_past = common_lcp(cache_tokens, eval_tokens);
|
||||||
|
// in case we do a re-generation, we need to prevent eval_tokens from being empty
|
||||||
|
if ((int) eval_tokens.size() == n_past) {
|
||||||
|
n_past--;
|
||||||
|
}
|
||||||
|
if (n_past > 0) {
|
||||||
|
eval_tokens.erase(eval_tokens.begin(), eval_tokens.begin() + n_past);
|
||||||
|
cache_tokens.erase(cache_tokens.begin() + n_past, cache_tokens.end());
|
||||||
|
LOG_DBG("remove from cache [%d, inf)\n", n_past);
|
||||||
|
LOG_DBG("in cache: %s\n", common_detokenize(ctx, cache_tokens, true).c_str());
|
||||||
|
LOG_DBG("to decode %d tokens\n", (int) eval_tokens.size());
|
||||||
|
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode
|
||||||
|
for (size_t i = 0; i < eval_tokens.size(); i += params.n_batch) {
|
||||||
|
if (interacting) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_batch_clear(batch);
|
||||||
|
for (int j = 0; j < params.n_batch && i + j < eval_tokens.size(); ++j) {
|
||||||
|
n_past++;
|
||||||
|
bool is_last_token = i + j == eval_tokens.size() - 1;
|
||||||
|
common_batch_add(batch, eval_tokens[i + j], n_past, {0}, is_last_token);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llama_decode(ctx, batch)) {
|
||||||
|
GGML_ABORT("failed to decode\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// update cache tokens
|
||||||
|
if (is_generating) {
|
||||||
|
cache_tokens.push_back(eval_tokens[0]);
|
||||||
|
} else {
|
||||||
|
cache_tokens.insert(cache_tokens.end(), eval_tokens.begin(), eval_tokens.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[[noreturn]] void run() {
|
||||||
|
while (true) {
|
||||||
|
interacting = true;
|
||||||
|
LOG("\n> ");
|
||||||
|
|
||||||
|
// color user input only
|
||||||
|
console::set_display(console::user_input);
|
||||||
|
std::string line;
|
||||||
|
bool another_line = true;
|
||||||
|
bool continue_input = false;
|
||||||
|
do {
|
||||||
|
another_line = console::readline(line, params.multiline_input);
|
||||||
|
if (handle_command(line, continue_input)) {
|
||||||
|
continue; // do not add this line to pending_input
|
||||||
|
}
|
||||||
|
pending_input << line;
|
||||||
|
} while (another_line);
|
||||||
|
|
||||||
|
if (continue_input) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pending_input.tellp() == 0) {
|
||||||
|
LOG_DBG("empty line, passing control back\n");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// done taking input, reset color
|
||||||
|
console::set_display(console::reset);
|
||||||
|
interacting = false;
|
||||||
|
|
||||||
|
// add message and format chat
|
||||||
|
if (!chat_msgs.empty() && chat_msgs.back().role == "user") {
|
||||||
|
chat_msgs.pop_back();
|
||||||
|
}
|
||||||
|
chat_msgs.push_back({"user", string_strip(pending_input.str())});
|
||||||
|
pending_input.str(""); // clear
|
||||||
|
auto formatted = common_chat_apply_template(model, params.chat_template, chat_msgs, true);
|
||||||
|
|
||||||
|
// tokenize the new chat history and decode
|
||||||
|
llama_tokens prompt_tokens = common_tokenize(ctx, formatted, true, true);
|
||||||
|
decode(prompt_tokens, false);
|
||||||
|
|
||||||
|
// generate response
|
||||||
|
llama_token new_token_id = LLAMA_TOKEN_NULL;
|
||||||
|
llama_tokens generated_tokens;
|
||||||
|
common_sampler_reset(smpl);
|
||||||
|
while (true) {
|
||||||
|
if (interacting) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// sample the next token
|
||||||
|
new_token_id = common_sampler_sample(smpl, ctx, -1);
|
||||||
|
|
||||||
|
// is it an end of generation?
|
||||||
|
if (llama_token_is_eog(model, new_token_id)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// print the token, then decode it
|
||||||
|
printf("%s", common_token_to_piece(ctx, new_token_id, params.special).c_str());
|
||||||
|
fflush(stdout);
|
||||||
|
generated_tokens.push_back(new_token_id);
|
||||||
|
llama_tokens new_tok = {new_token_id};
|
||||||
|
decode(new_tok, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// add the generated tokens to the chat history
|
||||||
|
std::string response = common_detokenize(ctx, generated_tokens, true);
|
||||||
|
chat_msgs.push_back({"assistant", response});
|
||||||
|
|
||||||
|
// print a new line if needed
|
||||||
|
if (!response.empty() && response.back() != '\n') {
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void interrupt() {
|
||||||
|
if (interacting) {
|
||||||
|
// exit
|
||||||
|
printf("\n");
|
||||||
|
console::cleanup();
|
||||||
|
common_perf_print(ctx, smpl);
|
||||||
|
common_log_pause(common_log_main());
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
interacting = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool handle_command(std::string & inp, bool & continue_input) {
|
||||||
|
if (inp.empty() || inp[0] != '/') {
|
||||||
|
return false; // not a command
|
||||||
|
}
|
||||||
|
auto parts = string_split<std::string>(string_strip(inp), ' ');
|
||||||
|
std::string & cmd = parts[0];
|
||||||
|
if (cmd == "/help") {
|
||||||
|
LOG("TODO\n");
|
||||||
|
continue_input = true;
|
||||||
|
} else if (cmd == "/history") {
|
||||||
|
display_history();
|
||||||
|
continue_input = true;
|
||||||
|
} else if (cmd == "/regen") {
|
||||||
|
if (chat_msgs.empty()) {
|
||||||
|
LOG_ERR("no chat history to regenerate\n");
|
||||||
|
continue_input = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (chat_msgs.back().role == "assistant") {
|
||||||
|
chat_msgs.pop_back();
|
||||||
|
}
|
||||||
|
if (chat_msgs.back().role == "user") {
|
||||||
|
pending_input.str(""); // clear
|
||||||
|
pending_input << chat_msgs.back().content;
|
||||||
|
chat_msgs.pop_back();
|
||||||
|
}
|
||||||
|
continue_input = false;
|
||||||
|
} else if (cmd == "/readfile") {
|
||||||
|
const std::string filename = parts[1];
|
||||||
|
LOG_DBG("reading file: '%s'\n", filename.c_str());
|
||||||
|
std::ifstream text_file(filename);
|
||||||
|
if (!text_file) {
|
||||||
|
LOG("failed to open file '%s'\n", filename.c_str());
|
||||||
|
} else {
|
||||||
|
pending_input << text_file.rdbuf() << "\n\n";
|
||||||
|
LOG("read %zu characters from file\n", (size_t) text_file.tellg());
|
||||||
|
}
|
||||||
|
continue_input = true;
|
||||||
|
} else {
|
||||||
|
LOG_ERR("unknown command: %s\n", cmd.c_str());
|
||||||
|
continue_input = true;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void display_history() {
|
||||||
|
for (const auto & msg : chat_msgs) {
|
||||||
|
LOG("%s: %s\n\n", msg.role.c_str(), msg.content.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~llama_cli_chat() {
|
||||||
|
llama_batch_free(batch);
|
||||||
|
}
|
||||||
|
};
|
|
@ -4,6 +4,7 @@
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
#include "chat.hpp"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
@ -35,6 +36,7 @@ static llama_context ** g_ctx;
|
||||||
static llama_model ** g_model;
|
static llama_model ** g_model;
|
||||||
static common_sampler ** g_smpl;
|
static common_sampler ** g_smpl;
|
||||||
static common_params * g_params;
|
static common_params * g_params;
|
||||||
|
static llama_cli_chat * g_chat;
|
||||||
static std::vector<llama_token> * g_input_tokens;
|
static std::vector<llama_token> * g_input_tokens;
|
||||||
static std::ostringstream * g_output_ss;
|
static std::ostringstream * g_output_ss;
|
||||||
static std::vector<llama_token> * g_output_tokens;
|
static std::vector<llama_token> * g_output_tokens;
|
||||||
|
@ -65,7 +67,9 @@ static bool file_is_empty(const std::string & path) {
|
||||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||||
static void sigint_handler(int signo) {
|
static void sigint_handler(int signo) {
|
||||||
if (signo == SIGINT) {
|
if (signo == SIGINT) {
|
||||||
if (!is_interacting && g_params->interactive) {
|
if (g_chat) {
|
||||||
|
g_chat->interrupt();
|
||||||
|
} else if (!is_interacting && g_params->interactive) {
|
||||||
is_interacting = true;
|
is_interacting = true;
|
||||||
need_insert_eot = true;
|
need_insert_eot = true;
|
||||||
} else {
|
} else {
|
||||||
|
@ -83,14 +87,6 @@ static void sigint_handler(int signo) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
|
|
||||||
common_chat_msg new_msg{role, content};
|
|
||||||
auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
|
|
||||||
chat_msgs.push_back({role, content});
|
|
||||||
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
|
||||||
return formatted;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
common_params params;
|
common_params params;
|
||||||
g_params = ¶ms;
|
g_params = ¶ms;
|
||||||
|
@ -203,6 +199,12 @@ int main(int argc, char ** argv) {
|
||||||
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
|
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// switch on conversation mode if chat template is present
|
||||||
|
if (!params.chat_template.empty() || !common_get_builtin_chat_template(model).empty()) {
|
||||||
|
LOG("%s: using chat mode\n", __func__);
|
||||||
|
params.conversation = true;
|
||||||
|
}
|
||||||
|
|
||||||
// print chat template example in conversation mode
|
// print chat template example in conversation mode
|
||||||
if (params.conversation) {
|
if (params.conversation) {
|
||||||
if (params.enable_chat_template) {
|
if (params.enable_chat_template) {
|
||||||
|
@ -251,18 +253,15 @@ int main(int argc, char ** argv) {
|
||||||
std::vector<llama_token> embd_inp;
|
std::vector<llama_token> embd_inp;
|
||||||
|
|
||||||
{
|
{
|
||||||
auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
|
|
||||||
? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
|
|
||||||
: params.prompt;
|
|
||||||
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
|
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
|
||||||
LOG_DBG("tokenize the prompt\n");
|
LOG_DBG("tokenize the prompt\n");
|
||||||
embd_inp = common_tokenize(ctx, prompt, true, true);
|
embd_inp = common_tokenize(ctx, params.prompt, true, true);
|
||||||
} else {
|
} else {
|
||||||
LOG_DBG("use session tokens\n");
|
LOG_DBG("use session tokens\n");
|
||||||
embd_inp = session_tokens;
|
embd_inp = session_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_DBG("prompt: \"%s\"\n", prompt.c_str());
|
LOG_DBG("prompt: \"%s\"\n", params.prompt.c_str());
|
||||||
LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str());
|
LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -420,6 +419,12 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||||
|
|
||||||
|
if (params.conversation) {
|
||||||
|
llama_cli_chat chat(params, ctx, smpl);
|
||||||
|
g_chat = &chat;
|
||||||
|
chat.run();
|
||||||
|
}
|
||||||
|
|
||||||
// group-attention state
|
// group-attention state
|
||||||
// number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
|
// number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
|
||||||
int ga_i = 0;
|
int ga_i = 0;
|
||||||
|
@ -752,10 +757,6 @@ int main(int argc, char ** argv) {
|
||||||
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
|
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
|
||||||
is_antiprompt = true;
|
is_antiprompt = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.enable_chat_template) {
|
|
||||||
chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
|
|
||||||
}
|
|
||||||
is_interacting = true;
|
is_interacting = true;
|
||||||
LOG("\n");
|
LOG("\n");
|
||||||
}
|
}
|
||||||
|
@ -818,9 +819,7 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool format_chat = params.conversation && params.enable_chat_template;
|
bool format_chat = params.conversation && params.enable_chat_template;
|
||||||
std::string user_inp = format_chat
|
std::string user_inp = std::move(buffer);
|
||||||
? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
|
|
||||||
: std::move(buffer);
|
|
||||||
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
|
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
|
||||||
const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
|
const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
|
||||||
const auto line_inp = common_tokenize(ctx, user_inp, false, format_chat);
|
const auto line_inp = common_tokenize(ctx, user_inp, false, format_chat);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue