minor
This commit is contained in:
parent
587c114b09
commit
d113cf2a13
2 changed files with 33 additions and 18 deletions
|
@ -1,7 +1,7 @@
|
|||
# llama.cpp/example/simple-chat
|
||||
|
||||
The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the built-in chat template in GGUF files.
|
||||
The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file.
|
||||
|
||||
```bash
|
||||
./llama-simple-chat -m ./models/llama-7b-v2/ggml-model-f16.gguf -c 2048
|
||||
./llama-simple-chat -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
|
||||
...
|
||||
|
|
|
@ -12,9 +12,7 @@ static void print_usage(int, char ** argv) {
|
|||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
// path to the model gguf file
|
||||
std::string model_path;
|
||||
// number of layers to offload to the GPU
|
||||
int ngl = 99;
|
||||
int n_ctx = 2048;
|
||||
|
||||
|
@ -91,13 +89,13 @@ int main(int argc, char ** argv) {
|
|||
llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
|
||||
|
||||
// generation helper
|
||||
// helper function to evaluate a prompt and generate a response
|
||||
auto generate = [&](const std::string & prompt) {
|
||||
std::string response;
|
||||
|
||||
// tokenize the prompt
|
||||
const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
|
||||
std::vector<llama_token> prompt_tokens(n_prompt);
|
||||
const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
|
||||
std::vector<llama_token> prompt_tokens(n_prompt_tokens);
|
||||
if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
|
||||
GGML_ABORT("failed to tokenize the prompt\n");
|
||||
}
|
||||
|
@ -106,7 +104,7 @@ int main(int argc, char ** argv) {
|
|||
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
||||
llama_token new_token_id;
|
||||
while (true) {
|
||||
// check if we have enough context space to evaluate this batch
|
||||
// check if we have enough space in the context to evaluate this batch
|
||||
int n_ctx = llama_n_ctx(ctx);
|
||||
int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
|
||||
if (n_ctx_used + batch.n_tokens > n_ctx) {
|
||||
|
@ -116,7 +114,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
GGML_ABORT("failed to eval\n");
|
||||
GGML_ABORT("failed to decode\n");
|
||||
}
|
||||
|
||||
// sample the next token
|
||||
|
@ -127,16 +125,16 @@ int main(int argc, char ** argv) {
|
|||
break;
|
||||
}
|
||||
|
||||
// add the token to the response
|
||||
char buf[128];
|
||||
// convert the token to a string, print it and add it to the response
|
||||
char buf[256];
|
||||
int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
|
||||
if (n < 0) {
|
||||
GGML_ABORT("failed to convert token to piece\n");
|
||||
}
|
||||
std::string piece(buf, n);
|
||||
response += piece;
|
||||
printf("%s", piece.c_str());
|
||||
fflush(stdout);
|
||||
response += piece;
|
||||
|
||||
// prepare the next batch with the sampled token
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
|
@ -146,34 +144,51 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
|
||||
std::vector<llama_chat_message> messages;
|
||||
std::vector<char> formatted(2048);
|
||||
std::vector<char> formatted(llama_n_ctx(ctx));
|
||||
int prev_len = 0;
|
||||
while (true) {
|
||||
// get user input
|
||||
printf("\033[32m> \033[0m");
|
||||
std::string user;
|
||||
std::getline(std::cin, user);
|
||||
messages.push_back({"user", strdup(user.c_str())});
|
||||
|
||||
// format the messages
|
||||
if (user.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// add the user input to the message list and format it
|
||||
messages.push_back({"user", strdup(user.c_str())});
|
||||
int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
|
||||
if (new_len > (int)formatted.size()) {
|
||||
formatted.resize(new_len);
|
||||
new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
|
||||
}
|
||||
if (new_len < 0) {
|
||||
fprintf(stderr, "failed to apply the chat template\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// remove previous messages and obtain a prompt
|
||||
// remove previous messages to obtain the prompt to generate the response
|
||||
std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
|
||||
|
||||
// generate a response
|
||||
printf("\033[31m");
|
||||
printf("\033[33m");
|
||||
std::string response = generate(prompt);
|
||||
printf("\n\033[0m");
|
||||
|
||||
// add the response to the messages
|
||||
messages.push_back({"assistant", strdup(response.c_str())});
|
||||
prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, formatted.data(), formatted.size());
|
||||
if (prev_len < 0) {
|
||||
fprintf(stderr, "failed to apply the chat template\n");
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// free resources
|
||||
for (auto & msg : messages) {
|
||||
free(const_cast<char *>(msg.content));
|
||||
}
|
||||
llama_sampler_free(smpl);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue