Add allowed response regex, response bias regex, and response bias value to main example

This commit is contained in:
Branden Butler 2023-05-10 14:39:40 -05:00
parent 779d7969c0
commit 30754bbaf9
4 changed files with 49 additions and 1 deletions

View file

@ -143,6 +143,11 @@ ifdef LLAMA_PERF
CFLAGS += -DGGML_PERF CFLAGS += -DGGML_PERF
CXXFLAGS += -DGGML_PERF CXXFLAGS += -DGGML_PERF
endif endif
ifdef LLAMA_USE_BOOST
LDFLAGS += -L/usr/lib/x86_64-linux-gnu/ -lboost_regex
CFLAGS += -DLLAMA_USE_BOOST
CXXFLAGS += -DLLAMA_USE_BOOST
endif
ifneq ($(filter aarch64%,$(UNAME_M)),) ifneq ($(filter aarch64%,$(UNAME_M)),)
# Apple M1, M2, etc. # Apple M1, M2, etc.
# Raspberry Pi 3, 4, Zero 2 (64-bit) # Raspberry Pi 3, 4, Zero 2 (64-bit)

View file

@ -333,6 +333,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.input_suffix = argv[i]; params.input_suffix = argv[i];
} else if (arg == "--allowed-response-regex") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.allowed_regex = argv[i];
} else if (arg == "--response-bias-regex") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.bias_regex = argv[i];
} else if (arg == "--response-bias-value") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.bias_regex_value = std::stof(argv[i]);
} else { } else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, default_params); gpt_print_usage(argc, argv, default_params);

View file

@ -49,6 +49,9 @@ struct gpt_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with std::string input_suffix = ""; // string to suffix user inputs with
std::string allowed_regex = ""; // regex string used to force prediction of matching tokens
std::string bias_regex = ""; // matching tokens are biased by bias_regex_value
float bias_regex_value = -1; // value to bias tokens matching bias_regex by
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string lora_adapter = ""; // lora adapter path std::string lora_adapter = ""; // lora adapter path

View file

@ -18,6 +18,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#if defined(LLAMA_USE_BOOST)
#include <boost/regex.hpp>
#endif
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h> #include <signal.h>
#include <unistd.h> #include <unistd.h>
@ -51,6 +55,8 @@ void sigint_handler(int signo) {
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
params.model = "models/llama-7B/ggml-model.bin"; params.model = "models/llama-7B/ggml-model.bin";
// boost::regex regex = boost::regex("(?:(?:\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))(?:<\\|>\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))*)|NONE");
// boost::regex negative_bias_regex = boost::regex("^NONE");
if (gpt_params_parse(argc, argv, params) == false) { if (gpt_params_parse(argc, argv, params) == false) {
return 1; return 1;
@ -97,6 +103,12 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng); params.prompt = gpt_random_prompt(rng);
} }
#if defined(LLAMA_USE_BOOST)
boost::regex response_allowed_regex = boost::regex(params.allowed_regex);
boost::regex response_bias_regex = boost::regex(params.bias_regex);
#endif
// params.prompt = R"(// this function checks if the number n is prime // params.prompt = R"(// this function checks if the number n is prime
//bool is_prime(int n) {)"; //bool is_prime(int n) {)";
@ -305,7 +317,7 @@ int main(int argc, char ** argv) {
console_set_color(con_st, CONSOLE_COLOR_PROMPT); console_set_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd; std::vector<llama_token> embd;
std::string partial_completion;
while (n_remain != 0 || params.interactive) { while (n_remain != 0 || params.interactive) {
// predict // predict
if (embd.size() > 0) { if (embd.size() > 0) {
@ -410,6 +422,15 @@ int main(int argc, char ** argv) {
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second; logits[it->first] += it->second;
} }
#if defined(LLAMA_USE_BOOST)
for (size_t i = 0; i < llama_get_num_logits(ctx); i++) {
if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
logits[i] = -INFINITY;
else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
logits[i] += params.bias_regex_value;
}
}
#endif
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -459,6 +480,7 @@ int main(int argc, char ** argv) {
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id); last_n_tokens.push_back(id);
partial_completion += llama_token_to_str(ctx, id);
} }
// replace end of text token with newline token when in interactive mode // replace end of text token with newline token when in interactive mode