diff --git a/examples/simple-chat/README.md b/examples/simple-chat/README.md
index 6770a7de0..f0099ce3d 100644
--- a/examples/simple-chat/README.md
+++ b/examples/simple-chat/README.md
@@ -1,7 +1,7 @@
 # llama.cpp/example/simple-chat

-The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the built-in chat template in GGUF files.
+The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file.

 ```bash
-./llama-simple-chat -m ./models/llama-7b-v2/ggml-model-f16.gguf -c 2048
+./llama-simple-chat -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
 ...
diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
index a537425af..389848a83 100644
--- a/examples/simple-chat/simple-chat.cpp
+++ b/examples/simple-chat/simple-chat.cpp
@@ -12,9 +12,7 @@ static void print_usage(int, char ** argv) {
 }

 int main(int argc, char ** argv) {
-    // path to the model gguf file
     std::string model_path;
-    // number of layers to offload to the GPU
     int ngl = 99;
     int n_ctx = 2048;

@@ -91,13 +89,13 @@ int main(int argc, char ** argv) {
     llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

-    // generation helper
+    // helper function to evaluate a prompt and generate a response
     auto generate = [&](const std::string & prompt) {
         std::string response;

         // tokenize the prompt
-        const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
-        std::vector<llama_token> prompt_tokens(n_prompt);
+        const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
         if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
             GGML_ABORT("failed to tokenize the prompt\n");
         }
@@ -106,7 +104,7 @@
         llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
         llama_token new_token_id;
         while (true) {
-            // check if we have enough context space to evaluate this batch
+            // check if we have enough space in the context to evaluate this batch
             int n_ctx = llama_n_ctx(ctx);
             int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
             if (n_ctx_used + batch.n_tokens > n_ctx) {
@@ -116,7 +114,7 @@
             }

             if (llama_decode(ctx, batch)) {
-                GGML_ABORT("failed to eval\n");
+                GGML_ABORT("failed to decode\n");
             }

             // sample the next token
@@ -127,16 +125,16 @@
                 break;
             }

-            // add the token to the response
-            char buf[128];
+            // convert the token to a string, print it and add it to the response
+            char buf[256];
             int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
             if (n < 0) {
                 GGML_ABORT("failed to convert token to piece\n");
             }
             std::string piece(buf, n);
-            response += piece;
             printf("%s", piece.c_str());
             fflush(stdout);
+            response += piece;

             // prepare the next batch with the sampled token
             batch = llama_batch_get_one(&new_token_id, 1);
@@ -146,34 +144,51 @@
     };

     std::vector<llama_chat_message> messages;
-    std::vector<char> formatted(2048);
+    std::vector<char> formatted(llama_n_ctx(ctx));
     int prev_len = 0;
     while (true) {
+        // get user input
+        printf("\033[32m> \033[0m");
         std::string user;
         std::getline(std::cin, user);
-        messages.push_back({"user", strdup(user.c_str())});
-        // format the messages
+        if (user.empty()) {
+            break;
+        }
+
+        // add the user input to the message list and format it
+        messages.push_back({"user", strdup(user.c_str())});
         int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         if (new_len > (int)formatted.size()) {
             formatted.resize(new_len);
             new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         }
+        if (new_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }

-        // remove previous messages and obtain a prompt
+        // remove previous messages to obtain the prompt to generate the response
         std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);

         // generate a response
-        printf("\033[31m");
+        printf("\033[33m");
         std::string response = generate(prompt);
         printf("\n\033[0m");

         // add the response to the messages
         messages.push_back({"assistant", strdup(response.c_str())});
         prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, formatted.data(), formatted.size());
+        if (prev_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
     }
-
+    // free resources
+    for (auto & msg : messages) {
+        free(const_cast<char *>(msg.content));
+    }
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
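
The `new_len < 0` and `prev_len < 0` checks added above guard the same resize-and-retry pattern at two call sites. As a minimal sketch of that pattern (assuming the llama.cpp API exactly as it appears in this diff, i.e. `llama_chat_apply_template` taking a `llama_model` pointer and a caller-owned buffer; the helper name `format_chat` is illustrative and not part of the PR):

```cpp
#include <string>
#include <vector>
#include "llama.h"

// Sketch only: wraps the resize-and-retry formatting pattern used in simple-chat.cpp.
static bool format_chat(const llama_model * model,
                        const std::vector<llama_chat_message> & messages,
                        bool add_assistant,          // true when building a prompt for generation
                        std::vector<char> & buf,
                        int & out_len) {
    // first attempt with the current buffer size
    out_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(),
                                        add_assistant, buf.data(), buf.size());
    // a positive result larger than the buffer means "grow the buffer and retry"
    if (out_len > (int) buf.size()) {
        buf.resize(out_len);
        out_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(),
                                            add_assistant, buf.data(), buf.size());
    }
    // a negative result means the template could not be applied
    return out_len >= 0;
}
```

With a helper like this, both call sites in the chat loop reduce to a single boolean check before using `out_len`.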