add llama2 chat example

slaren 2023-08-22 19:23:39 +02:00
parent b532a69b2f
commit 56ddd88d05
5 changed files with 164 additions and 1 deletion

.gitignore

@@ -45,6 +45,7 @@ models-mnt
/baby-llama
/beam-search
/save-load-state
/llama2-chat
build-info.h
arm_neon.h
compile_commands.json

Makefile

@@ -1,5 +1,5 @@
# Define the default target now so that it is always the first target
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search llama2-chat tests/test-c.o
# Binaries only useful for tests
TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1
@@ -449,6 +449,9 @@ baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

llama2-chat: examples/llama2-chat/llama2-chat.cpp build-info.h ggml.o llama.o common.o $(OBJS)
	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

ifneq '' '$(or $(filter clean,$(MAKECMDGOALS)),$(LLAMA_METAL))'
BUILD_TARGETS += metal
endif
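With the new BUILD_TARGETS entry and this rule, the example should build through the usual Makefile workflow (make llama2-chat, or a plain make, which now includes the target by default).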

examples/CMakeLists.txt

@@ -26,6 +26,7 @@ else()
    add_subdirectory(embd-input)
    add_subdirectory(llama-bench)
    add_subdirectory(beam-search)
    add_subdirectory(llama2-chat)
    if (LLAMA_METAL)
        add_subdirectory(metal)
    endif()

examples/llama2-chat/CMakeLists.txt

@@ -0,0 +1,8 @@
set(TARGET llama2-chat)
add_executable(${TARGET} llama2-chat.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
    add_dependencies(${TARGET} BUILD_INFO)
endif()
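With the subdirectory registered above, a CMake build should expose the same target as well, e.g. cmake --build build --target llama2-chat, assuming the project has already been configured into a build directory (the directory name here is illustrative).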

examples/llama2-chat/llama2-chat.cpp

@@ -0,0 +1,150 @@
#include "build-info.h"
#include "common.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
#include <iostream>
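// Prompt format used by the Llama 2 chat models, following Meta's reference
// implementation: each turn is wrapped as
//   <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST] {assistant}</s>
// with the system block only present in the first turn. BOS/EOS are added as
// tokens by the code below rather than as literal text.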
static const std::string B_INST = "[INST]";
static const std::string E_INST = "[/INST]";
static const std::string B_SYS = "<<SYS>>\n";
static const std::string E_SYS = "\n<</SYS>>\n\n";
struct chat {
    llama_context_params lparams;
    llama_model * model;
    llama_context * ctx;
    std::string system;
    int n_threads = 8;

    chat(const std::string & model_file, const std::string & system) : system(system) {
        lparams = llama_context_default_params();
        lparams.n_ctx        = 4096; // enough context for a multi-turn conversation
        lparams.n_gpu_layers = 99;   // offload as many layers as possible to the GPU

        model = llama_load_model_from_file(model_file.c_str(), lparams);
        if (model == NULL) {
            fprintf(stderr, "%s: error: unable to load model\n", __func__);
            exit(1);
        }

        ctx = llama_new_context_with_model(model, lparams);
        if (ctx == NULL) {
            fprintf(stderr, "%s: error: unable to create context\n", __func__);
            exit(1);
        }
    }
    // Build the token list for one dialog turn following the Llama 2 chat
    // template; the system prompt is only injected into the first turn.
    std::vector<llama_token> tokenize_dialog(const std::string & user, const std::string & assistant = "") {
        std::string content;
        // B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"]
        if (!system.empty()) {
            content = B_SYS + system + E_SYS + user;
            system.clear();
        } else {
            content = user;
        }

        // f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}"
        std::string prompt = B_INST + " " + content + " " + E_INST;

        // f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} "
        if (!assistant.empty()) {
            prompt += " " + assistant;
        }

        // printf("prompt: %s\n", prompt.c_str());
        auto tokens = ::llama_tokenize(ctx, prompt, true); // true: prepend BOS

        // a completed assistant turn is terminated with EOS
        if (!assistant.empty()) {
            tokens.push_back(llama_token_eos(ctx));
        }
        return tokens;
    }
    // Feed the prompt tokens to the model in chunks of at most n_batch,
    // appending to the tokens already in the KV cache.
    void eval_prompt(std::vector<llama_token> prompt) {
        while (!prompt.empty()) {
            int n_tokens = std::min(lparams.n_batch, (int) prompt.size());
            llama_eval(ctx, prompt.data(), n_tokens, llama_get_kv_cache_token_count(ctx), n_threads);
            prompt.erase(prompt.begin(), prompt.begin() + n_tokens);
        }
    }
    // Sample the next token from the logits of the last evaluated position,
    // using temperature-only sampling (no top-k/top-p).
    llama_token sample_token() {
        auto logits  = llama_get_logits(ctx);
        auto n_vocab = llama_n_vocab(ctx);

        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        llama_sample_temperature(ctx, &candidates_p, 0.7f);
        return llama_sample_token(ctx, &candidates_p);
    }
    // Generate the assistant's reply token by token until EOS, printing it
    // as it is produced.
    void eval_answer() {
        while (true) {
            llama_token id = sample_token();
            llama_eval(ctx, &id, 1, llama_get_kv_cache_token_count(ctx), n_threads);
            //printf("[%d]%s", id, llama_token_to_str(ctx, id).c_str());
            if (id == llama_token_eos(ctx)) {
                break;
            }
            printf("%s", llama_token_to_str(ctx, id).c_str());
            fflush(stdout);
        }
    }
    // Read user input until an empty line, evaluating each turn and printing
    // the model's reply.
    void chat_loop() {
        std::string user_prompt;
        while (true) {
            printf("\nUser: ");
            fflush(stdout);
            std::getline(std::cin, user_prompt);
            if (user_prompt.empty()) {
                break;
            }
            auto prompt = tokenize_dialog(user_prompt);
            eval_prompt(prompt);
            eval_answer();
        }
    }
};
int main(int argc, char ** argv) {
    if (argc == 1 || argv[1][0] == '-') {
        printf("usage: %s MODEL_PATH [\"system prompt\"]\n", argv[0]);
        return 1;
    }

    std::string model_file = argv[1];
    std::string system;
    if (argc > 2) {
        system = argv[2];
    }

    llama_backend_init(false);

    chat c(model_file, system);
    c.chat_loop();
}
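Once built, the example is invoked as shown in the usage string, for instance ./llama2-chat path/to/llama-2-chat-model.gguf "You are a helpful assistant." (the model path and system prompt here are placeholders); entering an empty line at the User: prompt ends the chat loop.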