add llama2 chat example

commit 56ddd88d05
parent b532a69b2f

5 changed files with 164 additions and 1 deletion
.gitignore (vendored): +1

@@ -45,6 +45,7 @@ models-mnt
 /baby-llama
 /beam-search
 /save-load-state
+/llama2-chat
 build-info.h
 arm_neon.h
 compile_commands.json
Makefile: +4 -1

@@ -1,5 +1,5 @@
 # Define the default target now so that it is always the first target
-BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search tests/test-c.o
+BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple save-load-state server embd-input-test gguf llama-bench baby-llama beam-search llama2-chat tests/test-c.o
 
 # Binaries only useful for tests
 TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama tests/test-tokenizer-0-falcon tests/test-tokenizer-1

@@ -449,6 +449,9 @@ baby-llama: examples/baby-llama/baby-llama.cpp ggml.o llama.o common.o $(OBJS)
 beam-search: examples/beam-search/beam-search.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
+llama2-chat: examples/llama2-chat/llama2-chat.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 ifneq '' '$(or $(filter clean,$(MAKECMDGOALS)),$(LLAMA_METAL))'
 BUILD_TARGETS += metal
 endif
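With this rule in place, the example is built and started like the other standalone examples. A minimal usage sketch (the model path is only a placeholder for a chat-tuned model converted for llama.cpp; the arguments follow the usage string in main() below):

    make llama2-chat
    ./llama2-chat models/llama-2-7b-chat/ggml-model-q4_0.gguf "You are a helpful assistant."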
examples/CMakeLists.txt: +1

@@ -26,6 +26,7 @@ else()
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
     add_subdirectory(beam-search)
+    add_subdirectory(llama2-chat)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()
examples/llama2-chat/CMakeLists.txt (new file): +8

set(TARGET llama2-chat)
add_executable(${TARGET} llama2-chat.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
    add_dependencies(${TARGET} BUILD_INFO)
endif()
examples/llama2-chat/llama2-chat.cpp (new file): +150

#include "build-info.h"

#include "common.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
#include <iostream>

// Llama-2 chat prompt delimiters
static const std::string B_INST = "[INST]";
static const std::string E_INST = "[/INST]";
static const std::string B_SYS  = "<<SYS>>\n";
static const std::string E_SYS  = "\n<</SYS>>\n\n";

struct chat {
    llama_context_params lparams;
    llama_model * model;
    llama_context * ctx;

    std::string system;
    int n_threads = 8;

    chat(const std::string & model_file, const std::string & system) : system(system) {
        lparams = llama_context_default_params();
        lparams.n_ctx        = 4096;
        lparams.n_gpu_layers = 99;

        model = llama_load_model_from_file(model_file.c_str(), lparams);
        if (model == NULL) {
            fprintf(stderr, "%s: error: unable to load model\n", __func__);
            exit(1);
        }

        ctx = llama_new_context_with_model(model, lparams);
        if (ctx == NULL) {
            fprintf(stderr, "%s: error: unable to create context\n", __func__);
            exit(1);
        }
    }

    std::vector<llama_token> tokenize_dialog(const std::string & user, const std::string & assistant = "") {
        std::string content;
        // B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"]
        if (!system.empty()) {
            // the system prompt is folded into the first user message only
            content = B_SYS + system + E_SYS + user;
            system.clear();
        } else {
            content = user;
        }

        // f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}"
        std::string prompt;
        prompt = B_INST + " " + content + " " + E_INST;

        // f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
        if (!assistant.empty()) {
            prompt += " " + assistant;
        }

        // printf("prompt: %s\n", prompt.c_str());

        auto tokens = ::llama_tokenize(ctx, prompt, true);

        if (!assistant.empty()) {
            tokens.push_back(llama_token_eos(ctx));
        }

        return tokens;
    }

    void eval_prompt(std::vector<llama_token> prompt) {
        // feed the prompt to the model in chunks of at most n_batch tokens
        while (!prompt.empty()) {
            int n_tokens = std::min(lparams.n_batch, (int) prompt.size());
            llama_eval(ctx, prompt.data(), n_tokens, llama_get_kv_cache_token_count(ctx), n_threads);
            prompt.erase(prompt.begin(), prompt.begin() + n_tokens);
        }
    }

    llama_token sample_token() {
        // temperature-only sampling over the full vocabulary
        auto logits  = llama_get_logits(ctx);
        auto n_vocab = llama_n_vocab(ctx);

        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);

        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
        }

        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        llama_sample_temperature(ctx, &candidates_p, 0.7f);

        llama_token token = llama_sample_token(ctx, &candidates_p);

        return token;
    }

    void eval_answer() {
        std::string answer;
        do {
            llama_token id = sample_token();
            llama_eval(ctx, &id, 1, llama_get_kv_cache_token_count(ctx), n_threads);

            //printf("[%d]%s", id, llama_token_to_str(ctx, id).c_str());

            if (id == llama_token_eos(ctx)) {
                break;
            }

            printf("%s", llama_token_to_str(ctx, id).c_str());

        } while (true);
    }

    void chat_loop() {
        std::string user_prompt;

        while (true) {
            printf("\nUser: ");
            std::getline(std::cin, user_prompt);
            if (user_prompt.empty()) {
                break;
            }

            auto prompt = tokenize_dialog(user_prompt);
            eval_prompt(prompt);

            eval_answer();
        }
    }
};

int main(int argc, char ** argv) {
    if (argc == 1 || argv[1][0] == '-') {
        printf("usage: %s MODEL_PATH [\"system prompt\"]\n", argv[0]);
        return 1;
    }

    std::string model_file = argv[1];
    std::string system;

    if (argc > 2) {
        system = argv[2];
    }

    llama_backend_init(false);

    chat c(model_file, system);

    c.chat_loop();
}
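For reference, the prompt layout that tokenize_dialog() assembles for the first turn can be illustrated without any llama.cpp dependencies. This is a standalone sketch, not part of the commit; first_turn_prompt() is a hypothetical helper that only mirrors the string concatenation above:

// Standalone sketch: the first-turn prompt string built by tokenize_dialog()
// when a system prompt is set (hypothetical helper, not part of the example).
#include <cstdio>
#include <string>

static const std::string B_INST = "[INST]";
static const std::string E_INST = "[/INST]";
static const std::string B_SYS  = "<<SYS>>\n";
static const std::string E_SYS  = "\n<</SYS>>\n\n";

static std::string first_turn_prompt(const std::string & system, const std::string & user) {
    // the system prompt is folded into the first user message
    return B_INST + " " + B_SYS + system + E_SYS + user + " " + E_INST;
}

int main() {
    printf("%s\n", first_turn_prompt("You are a helpful assistant.", "Hello!").c_str());
    // prints:
    // [INST] <<SYS>>
    // You are a helpful assistant.
    // <</SYS>>
    //
    // Hello! [/INST]
}

Later turns omit the <<SYS>> block, since tokenize_dialog() clears system after its first use, and are wrapped only in [INST] ... [/INST].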