From 56ddd88d05a8962fb0eb9f647af77dcef877ed28 Mon Sep 17 00:00:00 2001
From: slaren
Date: Tue, 22 Aug 2023 19:23:39 +0200
Subject: [PATCH] add llama2 chat example

---
 .gitignore                           |   1 +
 Makefile                             |   5 +-
 examples/CMakeLists.txt              |   1 +
 examples/llama2-chat/CMakeLists.txt  |   8 ++
 examples/llama2-chat/llama2-chat.cpp | 150 +++++++++++++++++++++++++++
 5 files changed, 164 insertions(+), 1 deletion(-)
 create mode 100644 examples/llama2-chat/CMakeLists.txt
 create mode 100644 examples/llama2-chat/llama2-chat.cpp

diff --git a/.gitignore b/.gitignore
index 8b5f45a2d..d6c28a582 100644
--- a/.gitignore
+++ b/.gitignore
@@ -45,6 +45,7 @@ models-mnt
 /baby-llama
 /beam-search
 /save-load-state
+/llama2-chat
 build-info.h
 arm_neon.h
 compile_commands.json
diff --git a/Makefile b/Makefile
index b750540fe..695fa5cd6 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search tests/test-c.o
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search llama2-chat tests/test-c.o
 
 # Binaries only useful for tests
 TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1
@@ -449,6 +449,9 @@ baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
 beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
+llama2-chat: examples/llama2-chat/llama2-chat.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 ifneq '' '$(or $(filter clean,$(MAKECMDGOALS)),$(LLAMA_METAL))'
 BUILD_TARGETS += metal
 endif
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 6e65eb087..ca66323e3 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -26,6 +26,7 @@ else()
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
    add_subdirectory(beam-search)
+    add_subdirectory(llama2-chat)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()
diff --git a/examples/llama2-chat/CMakeLists.txt b/examples/llama2-chat/CMakeLists.txt
new file mode 100644
index 000000000..dc5be7059
--- /dev/null
+++ b/examples/llama2-chat/CMakeLists.txt
@@ -0,0 +1,8 @@
+set(TARGET llama2-chat)
+add_executable(${TARGET} llama2-chat.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
+endif()
diff --git a/examples/llama2-chat/llama2-chat.cpp b/examples/llama2-chat/llama2-chat.cpp
new file mode 100644
index 000000000..a5937fa91
--- /dev/null
+++ b/examples/llama2-chat/llama2-chat.cpp
@@ -0,0 +1,150 @@
+#include "build-info.h"
+#include "common.h"
+#include "llama.h"
+
+#include <algorithm>
+#include <cstdio>
+#include <iostream>
+#include <string>
+#include <vector>
+
+static const std::string B_INST = "[INST]";
+static const std::string E_INST = "[/INST]";
+static const std::string B_SYS  = "<<SYS>>\n";
+static const std::string E_SYS  = "\n<</SYS>>\n\n";
+
+struct chat {
+    llama_context_params lparams;
+    llama_model * model;
+    llama_context * ctx;
+
+    std::string system;
+    int n_threads = 8;
+
+    chat(const std::string & model_file, const std::string & system) : system(system) {
+        lparams = llama_context_default_params();
+        lparams.n_ctx = 4096;
+        lparams.n_gpu_layers = 99;
+
+        model = llama_load_model_from_file(model_file.c_str(), lparams);
+        if (model == NULL) {
+            fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+            exit(1);
+        }
+
+        ctx = llama_new_context_with_model(model, lparams);
+        if (ctx == NULL) {
+            fprintf(stderr , "%s: error: unable to create context\n" , __func__);
+            exit(1);
+        }
+    }
+
+    std::vector<llama_token> tokenize_dialog(const std::string & user, const std::string & assistant = "") {
+        std::string content;
+        // B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"],
+        if (!system.empty()) {
+            content = B_SYS + system + E_SYS + user;
+            system.clear();
+        } else {
+            content = user;
+        }
+        // f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}"
+        std::string prompt;
+        prompt = B_INST + " " + content + " " + E_INST;
+
+        // f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+        if (!assistant.empty()) {
+            prompt += " " + assistant;
+        }
+
+        // printf("prompt: %s\n", prompt.c_str());
+
+        auto tokens = ::llama_tokenize(ctx, prompt, true);
+
+        if (!assistant.empty()) {
+            tokens.push_back(llama_token_eos(ctx));
+        }
+
+        return tokens;
+    }
+
+    void eval_prompt(std::vector<llama_token> prompt) {
+        while (!prompt.empty()) {
+            int n_tokens = std::min(lparams.n_batch, (int)prompt.size());
+            llama_eval(ctx, prompt.data(), n_tokens, llama_get_kv_cache_token_count(ctx), n_threads);
+            prompt.erase(prompt.begin(), prompt.begin() + n_tokens);
+        }
+    }
+
+    llama_token sample_token() {
+        auto logits = llama_get_logits(ctx);
+        auto n_vocab = llama_n_vocab(ctx);
+
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+        }
+
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+        llama_sample_temperature(ctx, &candidates_p, 0.7f);
+
+        llama_token token = llama_sample_token(ctx , &candidates_p);
+
+        return token;
+    }
+
+    void eval_answer() {
+        std::string answer;
+        do {
+            llama_token id = sample_token();
+            llama_eval(ctx, &id, 1, llama_get_kv_cache_token_count(ctx), n_threads);
+
+            //printf("[%d]%s", id, llama_token_to_str(ctx, id).c_str());
+
+            if (id == llama_token_eos(ctx)) {
+                break;
+            }
+
+            printf("%s", llama_token_to_str(ctx, id).c_str());
+
+        } while (true);
+    }
+
+    void chat_loop() {
+        std::string user_prompt;
+
+        while (true) {
+            printf("\nUser: ");
+            std::getline(std::cin, user_prompt);
+            if (user_prompt.empty())
+                break;
+
+            auto prompt = tokenize_dialog(user_prompt);
+            eval_prompt(prompt);
+
+            eval_answer();
+        }
+    }
+};
+
+int main(int argc, char ** argv) {
+    if (argc == 1 || argv[1][0] == '-') {
+        printf("usage: %s MODEL_PATH [\"system prompt\"]", argv[0]);
+        return 1 ;
+    }
+
+    std::string model_file = argv[1];
+    std::string system;
+
+    if (argc > 2) {
+        system = argv[2];
+    }
+    llama_backend_init(false);
+
+    chat c(model_file, system);
+
+    c.chat_loop();
+}