From 30754bbaf98733e1a6df29d80843edfecbadef69 Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Wed, 10 May 2023 14:39:40 -0500 Subject: [PATCH] Add allowed response regex, response bias regex, and response bias value to main example --- Makefile | 5 +++++ examples/common.cpp | 18 ++++++++++++++++++ examples/common.h | 3 +++ examples/main/main.cpp | 24 +++++++++++++++++++++++- 4 files changed, 49 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 0ddff9961..e4c338a58 100644 --- a/Makefile +++ b/Makefile @@ -143,6 +143,11 @@ ifdef LLAMA_PERF CFLAGS += -DGGML_PERF CXXFLAGS += -DGGML_PERF endif +ifdef LLAMA_USE_BOOST + LDFLAGS += -L/usr/lib/x86_64-linux-gnu/ -lboost_regex + CFLAGS += -DLLAMA_USE_BOOST + CXXFLAGS += -DLLAMA_USE_BOOST +endif ifneq ($(filter aarch64%,$(UNAME_M)),) # Apple M1, M2, etc. # Raspberry Pi 3, 4, Zero 2 (64-bit) diff --git a/examples/common.cpp b/examples/common.cpp index f3085b08e..84380e460 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -333,6 +333,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.input_suffix = argv[i]; + } else if (arg == "--allowed-response-regex") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.allowed_regex = argv[i]; + } else if (arg == "--response-bias-regex") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.bias_regex = argv[i]; + } else if (arg == "--response-bias-value") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.bias_regex_value = std::stof(argv[i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, default_params); diff --git a/examples/common.h b/examples/common.h index 499671b2e..e7293cd8f 100644 --- a/examples/common.h +++ b/examples/common.h @@ -49,6 +49,9 @@ struct gpt_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user 
inputs with std::string input_suffix = ""; // string to suffix user inputs with + std::string allowed_regex = ""; // regex string used to force prediction of matching tokens + std::string bias_regex = ""; // matching tokens are biased by bias_regex_value + float bias_regex_value = -1; // value to bias tokens matching bias_regex by std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted std::string lora_adapter = ""; // lora adapter path diff --git a/examples/main/main.cpp b/examples/main/main.cpp index bd1c4ab55..af31952e4 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -18,6 +18,10 @@ #include <string> #include <vector> +#if defined(LLAMA_USE_BOOST) + #include <boost/regex.hpp> +#endif + #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include <signal.h> #include <unistd.h> @@ -51,6 +55,8 @@ void sigint_handler(int signo) { int main(int argc, char ** argv) { gpt_params params; params.model = "models/llama-7B/ggml-model.bin"; +// boost::regex regex = boost::regex("(?:(?:\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))(?:<\\|>\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))*)|NONE"); +// boost::regex negative_bias_regex = boost::regex("^NONE"); if (gpt_params_parse(argc, argv, params) == false) { return 1; @@ -97,6 +103,12 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } + +#if defined(LLAMA_USE_BOOST) + boost::regex response_allowed_regex = boost::regex(params.allowed_regex); + boost::regex response_bias_regex = boost::regex(params.bias_regex); +#endif + // params.prompt = R"(// this function checks if the number n is prime //bool is_prime(int n) {)"; @@ -305,7 +317,7 @@ int main(int argc, char ** argv) { console_set_color(con_st, CONSOLE_COLOR_PROMPT); std::vector<llama_token> embd; - + std::string partial_completion; while (n_remain != 0 || params.interactive) { // predict if (embd.size() > 0) { @@ -410,6 +422,15 @@ int main(int argc, char ** argv) { for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
{ logits[it->first] += it->second; } +#if defined(LLAMA_USE_BOOST) + for (size_t i = 0; i < llama_get_num_logits(ctx); i++) { + if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial)) + logits[i] = -INFINITY; + else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) { + logits[i] += params.bias_regex_value; + } + } +#endif std::vector<llama_token_data> candidates; candidates.reserve(n_vocab); @@ -459,6 +480,7 @@ int main(int argc, char ** argv) { last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); + partial_completion += llama_token_to_str(ctx, id); } // replace end of text token with newline token when in interactive mode