This commit is contained in:
slaren 2024-11-01 23:01:33 +01:00
parent 587c114b09
commit d113cf2a13
2 changed files with 33 additions and 18 deletions

View file

@ -1,7 +1,7 @@
# llama.cpp/example/simple-chat
The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the built-in chat template in GGUF files.
The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file.
```bash
./llama-simple-chat -m ./models/llama-7b-v2/ggml-model-f16.gguf -c 2048
./llama-simple-chat -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
...

View file

@ -12,9 +12,7 @@ static void print_usage(int, char ** argv) {
}
int main(int argc, char ** argv) {
// path to the model gguf file
std::string model_path;
// number of layers to offload to the GPU
int ngl = 99;
int n_ctx = 2048;
@ -91,13 +89,13 @@ int main(int argc, char ** argv) {
llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
// generation helper
// helper function to evaluate a prompt and generate a response
auto generate = [&](const std::string & prompt) {
std::string response;
// tokenize the prompt
const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
std::vector<llama_token> prompt_tokens(n_prompt);
const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
std::vector<llama_token> prompt_tokens(n_prompt_tokens);
if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n");
}
@ -106,7 +104,7 @@ int main(int argc, char ** argv) {
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
llama_token new_token_id;
while (true) {
// check if we have enough context space to evaluate this batch
// check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(ctx);
int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
if (n_ctx_used + batch.n_tokens > n_ctx) {
@ -116,7 +114,7 @@ int main(int argc, char ** argv) {
}
if (llama_decode(ctx, batch)) {
GGML_ABORT("failed to eval\n");
GGML_ABORT("failed to decode\n");
}
// sample the next token
@ -127,16 +125,16 @@ int main(int argc, char ** argv) {
break;
}
// add the token to the response
char buf[128];
// convert the token to a string, print it and add it to the response
char buf[256];
int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
if (n < 0) {
GGML_ABORT("failed to convert token to piece\n");
}
std::string piece(buf, n);
response += piece;
printf("%s", piece.c_str());
fflush(stdout);
response += piece;
// prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1);
@ -146,34 +144,51 @@ int main(int argc, char ** argv) {
};
std::vector<llama_chat_message> messages;
std::vector<char> formatted(2048);
std::vector<char> formatted(llama_n_ctx(ctx));
int prev_len = 0;
while (true) {
// get user input
printf("\033[32m> \033[0m");
std::string user;
std::getline(std::cin, user);
messages.push_back({"user", strdup(user.c_str())});
// format the messages
if (user.empty()) {
break;
}
// add the user input to the message list and format it
messages.push_back({"user", strdup(user.c_str())});
int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
if (new_len > (int)formatted.size()) {
formatted.resize(new_len);
new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
}
if (new_len < 0) {
fprintf(stderr, "failed to apply the chat template\n");
return 1;
}
// remove previous messages and obtain a prompt
// remove previous messages to obtain the prompt to generate the response
std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
// generate a response
printf("\033[31m");
printf("\033[33m");
std::string response = generate(prompt);
printf("\n\033[0m");
// add the response to the messages
messages.push_back({"assistant", strdup(response.c_str())});
prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, formatted.data(), formatted.size());
if (prev_len < 0) {
fprintf(stderr, "failed to apply the chat template\n");
return 1;
}
}
// free resources
for (auto & msg : messages) {
free(const_cast<char *>(msg.content));
}
llama_sampler_free(smpl);
llama_free(ctx);
llama_free_model(model);