Merge c2b26000c3
into 39509fb082
This commit is contained in:
commit ec9ede2d5f
3 changed files with 248 additions and 45 deletions
examples/main/CMakeLists.txt
@@ -1,5 +1,5 @@
 set(TARGET llama-cli)
-add_executable(${TARGET} main.cpp)
+add_executable(${TARGET} main.cpp chat.hpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
examples/main/chat.hpp (new file, 224 lines)
@@ -0,0 +1,224 @@
+#include "arg.h"
+#include "common.h"
+#include "console.h"
+#include "log.h"
+#include "sampling.h"
+#include "llama.h"
+
+#include <fstream>
+
+struct llama_cli_chat {
+    struct llama_context * ctx;
+    const struct llama_model * model;
+    const struct llama_vocab * vocab;
+    struct common_sampler * smpl;
+    struct common_params params;
+
+    bool interacting = false;
+    std::vector<common_chat_msg> chat_msgs;
+    std::ostringstream pending_input;
+
+    struct llama_batch batch;
+    llama_tokens cache_tokens;
+    int n_past = 0;
+
+    llama_cli_chat(
+            struct common_params & params,
+            struct llama_context * ctx,
+            struct common_sampler * smpl) : ctx(ctx), smpl(smpl), params(params) {
+        model = llama_get_model(ctx);
+        vocab = llama_model_get_vocab(model);
+        batch = llama_batch_init(params.n_batch, 0, 1);
+    }
+
+    void decode(llama_tokens & eval_tokens, bool is_generating) {
+        if (is_generating) {
+            GGML_ASSERT(eval_tokens.size() == 1);
+        } else {
+            n_past = common_lcp(cache_tokens, eval_tokens);
+            // in case we do a re-generation, we need to prevent eval_tokens from being empty
+            if ((int) eval_tokens.size() == n_past) {
+                n_past--;
+            }
+            if (n_past > 0) {
+                eval_tokens.erase(eval_tokens.begin(), eval_tokens.begin() + n_past);
+                cache_tokens.erase(cache_tokens.begin() + n_past, cache_tokens.end());
+                LOG_DBG("remove from cache [%d, inf)\n", n_past);
+                LOG_DBG("in cache: %s\n", common_detokenize(ctx, cache_tokens, true).c_str());
+                LOG_DBG("to decode %d tokens\n", (int) eval_tokens.size());
+                llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+            }
+        }
+
+        // decode
+        for (size_t i = 0; i < eval_tokens.size(); i += params.n_batch) {
+            if (interacting) {
+                break;
+            }
+
+            common_batch_clear(batch);
+            for (int j = 0; j < params.n_batch && i + j < eval_tokens.size(); ++j) {
+                n_past++;
+                bool is_last_token = i + j == eval_tokens.size() - 1;
+                common_batch_add(batch, eval_tokens[i + j], n_past, {0}, is_last_token);
+            }
+
+            if (llama_decode(ctx, batch)) {
+                GGML_ABORT("failed to decode\n");
+            }
+        }
+
+        // update cache tokens
+        if (is_generating) {
+            cache_tokens.push_back(eval_tokens[0]);
+        } else {
+            cache_tokens.insert(cache_tokens.end(), eval_tokens.begin(), eval_tokens.end());
+        }
+    }
+
+    [[noreturn]] void run() {
+        while (true) {
+            interacting = true;
+            LOG("\n> ");
+
+            // color user input only
+            console::set_display(console::user_input);
+            std::string line;
+            bool another_line = true;
+            bool continue_input = false;
+            do {
+                another_line = console::readline(line, params.multiline_input);
+                if (handle_command(line, continue_input)) {
+                    continue; // do not add this line to pending_input
+                }
+                pending_input << line;
+            } while (another_line);
+
+            if (continue_input) {
+                continue;
+            }
+
+            if (pending_input.tellp() == 0) {
+                LOG_DBG("empty line, passing control back\n");
+                continue;
+            }
+
+            // done taking input, reset color
+            console::set_display(console::reset);
+            interacting = false;
+
+            // add message and format chat
+            if (!chat_msgs.empty() && chat_msgs.back().role == "user") {
+                chat_msgs.pop_back();
+            }
+            chat_msgs.push_back({"user", string_strip(pending_input.str())});
+            pending_input.str(""); // clear
+            auto formatted = common_chat_apply_template(model, params.chat_template, chat_msgs, true);
+
+            // tokenize the new chat history and decode
+            llama_tokens prompt_tokens = common_tokenize(ctx, formatted, true, true);
+            decode(prompt_tokens, false);
+
+            // generate response
+            llama_token new_token_id = LLAMA_TOKEN_NULL;
+            llama_tokens generated_tokens;
+            common_sampler_reset(smpl);
+            while (true) {
+                if (interacting) {
+                    break;
+                }
+
+                // sample the next token
+                new_token_id = common_sampler_sample(smpl, ctx, -1);
+
+                // is it an end of generation?
+                if (llama_vocab_is_eog(vocab, new_token_id)) {
+                    break;
+                }
+
+                // print the token, then decode it
+                printf("%s", common_token_to_piece(ctx, new_token_id, params.special).c_str());
+                fflush(stdout);
+                generated_tokens.push_back(new_token_id);
+                llama_tokens new_tok = {new_token_id};
+                decode(new_tok, true);
+            }
+
+            // add the generated tokens to the chat history
+            std::string response = common_detokenize(ctx, generated_tokens, true);
+            chat_msgs.push_back({"assistant", response});
+
+            // print a new line if needed
+            if (!response.empty() && response.back() != '\n') {
+                printf("\n");
+            }
+        }
+    }
+
+    void interrupt() {
+        if (interacting) {
+            // exit
+            printf("\n");
+            console::cleanup();
+            common_perf_print(ctx, smpl);
+            common_log_pause(common_log_main());
+            exit(0);
+        }
+        interacting = true;
+    }
+
+    bool handle_command(std::string & inp, bool & continue_input) {
+        if (inp.empty() || inp[0] != '/') {
+            return false; // not a command
+        }
+        auto parts = string_split<std::string>(string_strip(inp), ' ');
+        std::string & cmd = parts[0];
+        if (cmd == "/help") {
+            LOG("TODO\n");
+            continue_input = true;
+        } else if (cmd == "/history") {
+            display_history();
+            continue_input = true;
+        } else if (cmd == "/regen") {
+            if (chat_msgs.empty()) {
+                LOG_ERR("no chat history to regenerate\n");
+                continue_input = true;
+                return true;
+            }
+            if (chat_msgs.back().role == "assistant") {
+                chat_msgs.pop_back();
+            }
+            if (chat_msgs.back().role == "user") {
+                pending_input.str(""); // clear
+                pending_input << chat_msgs.back().content;
+                chat_msgs.pop_back();
+            }
+            continue_input = false;
+        } else if (cmd == "/readfile") {
+            const std::string filename = parts[1];
+            LOG_DBG("reading file: '%s'\n", filename.c_str());
+            std::ifstream text_file(filename);
+            if (!text_file) {
+                LOG("failed to open file '%s'\n", filename.c_str());
+            } else {
+                pending_input << text_file.rdbuf() << "\n\n";
+                LOG("read %zu characters from file\n", (size_t) text_file.tellg());
+            }
+            continue_input = true;
+        } else {
+            LOG_ERR("unknown command: %s\n", cmd.c_str());
+            continue_input = true;
+        }
+        return true;
+    }
+
+    void display_history() {
+        for (const auto & msg : chat_msgs) {
+            LOG("%s: %s\n\n", msg.role.c_str(), msg.content.c_str());
+        }
+    }
+
+    ~llama_cli_chat() {
+        llama_batch_free(batch);
+    }
+};
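Aside: the core idea in decode() above is prompt-prefix reuse. common_lcp() measures how many leading tokens of the newly formatted chat are already in cache_tokens, only the divergent suffix gets decoded, and the stale tail of the KV cache is dropped via llama_kv_cache_seq_rm(). Below is a minimal standalone sketch of that bookkeeping, using plain ints instead of llama tokens; longest_common_prefix is a hypothetical stand-in for common_lcp and nothing here calls the llama.cpp API.

#include <algorithm>
#include <cstdio>
#include <vector>

using tokens = std::vector<int>;

// length of the shared prefix of two token sequences (what common_lcp computes)
static size_t longest_common_prefix(const tokens & a, const tokens & b) {
    size_t n = std::min(a.size(), b.size());
    size_t i = 0;
    while (i < n && a[i] == b[i]) {
        i++;
    }
    return i;
}

int main() {
    tokens cache  = {1, 2, 3, 4, 5};    // tokens already in the KV cache
    tokens prompt = {1, 2, 3, 9, 10};   // new formatted prompt to decode

    size_t n_past = longest_common_prefix(cache, prompt);
    if (n_past == prompt.size()) {
        n_past--;                       // re-generation: force at least one token to be decoded
    }

    // keep the shared prefix, drop the stale tail, decode only the new suffix
    tokens to_decode(prompt.begin() + n_past, prompt.end());
    cache.resize(n_past);
    cache.insert(cache.end(), to_decode.begin(), to_decode.end());

    std::printf("reused %zu cached tokens, decoded %zu new tokens\n", n_past, to_decode.size());
    // the real decode() additionally calls llama_kv_cache_seq_rm(ctx, 0, n_past, -1) at this point
}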
examples/main/main.cpp
@@ -4,6 +4,7 @@
 #include "log.h"
 #include "sampling.h"
 #include "llama.h"
+#include "chat.hpp"

 #include <cstdio>
 #include <cstring>
@@ -34,11 +35,11 @@ static llama_context ** g_ctx;
 static llama_model ** g_model;
 static common_sampler ** g_smpl;
 static common_params * g_params;
+static llama_cli_chat * g_chat;
 static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
-static bool need_insert_eot = false;

 static void print_usage(int argc, char ** argv) {
     (void) argc;
@@ -64,9 +65,10 @@ static bool file_is_empty(const std::string & path) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 static void sigint_handler(int signo) {
     if (signo == SIGINT) {
-        if (!is_interacting && g_params->interactive) {
+        if (g_chat) {
+            g_chat->interrupt();
+        } else if (!is_interacting && g_params->interactive) {
             is_interacting = true;
-            need_insert_eot = true;
         } else {
             console::cleanup();
             LOG("\n");
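Aside: with this change, Ctrl+C is routed to llama_cli_chat::interrupt(). The first interrupt during generation only flips the interacting flag so the token loop stops at the next iteration; a second interrupt while waiting at the prompt exits the program. Below is a minimal standalone sketch of that two-stage pattern using std::signal and a sig_atomic_t flag; it is illustrative only and none of it is llama.cpp code.

#include <chrono>
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <thread>

// 1 while waiting at the prompt, 0 while a response is being generated
static volatile std::sig_atomic_t interacting = 0;

static void on_sigint(int) {
    if (interacting) {
        std::_Exit(0);  // second Ctrl+C, already at the prompt: leave the program
    }
    interacting = 1;    // first Ctrl+C: tell the generation loop to stop
}

int main() {
    std::signal(SIGINT, on_sigint);
    while (true) {
        interacting = 0;
        std::printf("generating... (Ctrl+C stops generation)\n");
        while (!interacting) {                                   // stands in for the token loop
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }
        interacting = 1;
        std::printf("\n> back at the prompt (Ctrl+C again exits)\n");
        std::this_thread::sleep_for(std::chrono::seconds(2));    // stands in for console::readline()
    }
}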
@@ -82,14 +84,6 @@ static void sigint_handler(int signo) {
 }
 #endif

-static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
-    common_chat_msg new_msg{role, content};
-    auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
-    chat_msgs.push_back({role, content});
-    LOG_DBG("formatted: '%s'\n", formatted.c_str());
-    return formatted;
-}
-
 int main(int argc, char ** argv) {
     common_params params;
     g_params = &params;
@@ -204,6 +198,12 @@ int main(int argc, char ** argv) {
         LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
     }

+    // switch on conversation mode if chat template is present
+    if (!params.chat_template.empty() || !common_get_builtin_chat_template(model).empty()) {
+        LOG("%s: using chat mode\n", __func__);
+        params.conversation = true;
+    }
+
     // print chat template example in conversation mode
     if (params.conversation) {
         if (params.enable_chat_template) {
@@ -252,18 +252,15 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;

     {
-        auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
-            : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
             LOG_DBG("tokenize the prompt\n");
-            embd_inp = common_tokenize(ctx, prompt, true, true);
+            embd_inp = common_tokenize(ctx, params.prompt, true, true);
         } else {
             LOG_DBG("use session tokens\n");
             embd_inp = session_tokens;
         }

-        LOG_DBG("prompt: \"%s\"\n", prompt.c_str());
+        LOG_DBG("prompt: \"%s\"\n", params.prompt.c_str());
         LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str());
     }

@@ -421,6 +418,12 @@ int main(int argc, char ** argv) {

     LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);

+    if (params.conversation) {
+        llama_cli_chat chat(params, ctx, smpl);
+        g_chat = &chat;
+        chat.run();
+    }
+
     // group-attention state
     // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
     int ga_i = 0;
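Aside: run() is declared [[noreturn]], so once params.conversation is set the block above hands the process over to llama_cli_chat and the rest of main() only ever executes in completion (non-conversation) mode. A toy sketch of that control flow is shown below; toy_chat and its behavior are hypothetical stand-ins, not the llama.cpp API.

#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <string>

struct toy_chat {
    [[noreturn]] void run() {
        std::string line;
        while (true) {
            std::printf("\n> ");
            if (!std::getline(std::cin, line) || line == "/exit") {
                std::exit(0);               // the only way out, mirroring interrupt()
            }
            std::printf("(echo) %s\n", line.c_str());
        }
    }
};

int main(int argc, char **) {
    const bool conversation = (argc < 2);   // stands in for params.conversation
    if (conversation) {
        toy_chat chat;
        chat.run();                         // never returns
    }
    std::puts("completion mode");           // only reached in non-conversation mode
}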
@@ -753,35 +756,21 @@ int main(int argc, char ** argv) {
                     embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                     is_antiprompt = true;
                 }
-
-                if (params.enable_chat_template) {
-                    chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
-                }
                 is_interacting = true;
                 LOG("\n");
             }
         }

-        // if current token is not EOG, we add it to current assistant message
-        if (params.conversation) {
-            const auto id = common_sampler_last(smpl);
-            assistant_ss << common_token_to_piece(ctx, id, false);
-        }
-
         if (n_past > 0 && is_interacting) {
             LOG_DBG("waiting for user input\n");

-            if (params.conversation) {
-                LOG("\n> ");
-            }
-
             if (params.input_prefix_bos) {
                 LOG_DBG("adding input prefix BOS token\n");
                 embd_inp.push_back(llama_vocab_bos(vocab));
             }

             std::string buffer;
-            if (!params.input_prefix.empty() && !params.conversation) {
+            if (!params.input_prefix.empty()) {
                 LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str());
                 LOG("%s", params.input_prefix.c_str());
             }
@@ -805,7 +794,7 @@ int main(int argc, char ** argv) {
             // Entering a empty line lets the user pass control back
             if (buffer.length() > 1) {
                 // append input suffix if any
-                if (!params.input_suffix.empty() && !params.conversation) {
+                if (!params.input_suffix.empty()) {
                     LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
                     LOG("%s", params.input_suffix.c_str());
                 }
@@ -818,24 +807,14 @@ int main(int argc, char ** argv) {
                     string_process_escapes(buffer);
                 }

-                bool format_chat = params.conversation && params.enable_chat_template;
-                std::string user_inp = format_chat
-                    ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
-                    : std::move(buffer);
+                std::string user_inp = std::move(buffer);
                 // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                 const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
-                const auto line_inp = common_tokenize(ctx, user_inp, false, format_chat);
+                const auto line_inp = common_tokenize(ctx, user_inp, false, true);
                 const auto line_sfx = common_tokenize(ctx, params.input_suffix, false, true);

                 LOG_DBG("input tokens: %s\n", string_from(ctx, line_inp).c_str());

-                // if user stop generation mid-way, we must add EOT to finish model's last response
-                if (need_insert_eot && format_chat) {
-                    llama_token eot = llama_vocab_eot(vocab);
-                    embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_vocab_eos(vocab) : eot);
-                    need_insert_eot = false;
-                }
-
                 embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
                 embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
                 embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());