Add allowed response regex, response bias regex, and response bias value to main example

This commit is contained in:
Branden Butler 2023-05-10 14:39:40 -05:00
parent 779d7969c0
commit 30754bbaf9
4 changed files with 49 additions and 1 deletions

View file

@ -143,6 +143,11 @@ ifdef LLAMA_PERF
CFLAGS += -DGGML_PERF
CXXFLAGS += -DGGML_PERF
endif
ifdef LLAMA_USE_BOOST
LDFLAGS += -L/usr/lib/x86_64-linux-gnu/ -lboost_regex
CFLAGS += -DLLAMA_USE_BOOST
CXXFLAGS += -DLLAMA_USE_BOOST
endif
ifneq ($(filter aarch64%,$(UNAME_M)),)
# Apple M1, M2, etc.
# Raspberry Pi 3, 4, Zero 2 (64-bit)

View file

@ -333,6 +333,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.input_suffix = argv[i];
} else if (arg == "--allowed-response-regex") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.allowed_regex = argv[i];
} else if (arg == "--response-bias-regex") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.bias_regex = argv[i];
} else if (arg == "--response-bias-value") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.bias_regex_value = std::stof(argv[i]);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, default_params);

View file

@ -49,6 +49,9 @@ struct gpt_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::string allowed_regex = ""; // regex string used to force prediction of matching tokens
std::string bias_regex = ""; // matching tokens are biased by bias_regex_value
float bias_regex_value = -1; // value to bias tokens matching bias_regex by
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string lora_adapter = ""; // lora adapter path

View file

@ -18,6 +18,10 @@
#include <string>
#include <vector>
#if defined(LLAMA_USE_BOOST)
#include <boost/regex.hpp>
#endif
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
@ -51,6 +55,8 @@ void sigint_handler(int signo) {
int main(int argc, char ** argv) {
gpt_params params;
params.model = "models/llama-7B/ggml-model.bin";
// boost::regex regex = boost::regex("(?:(?:\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))(?:<\\|>\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))*)|NONE");
// boost::regex negative_bias_regex = boost::regex("^NONE");
if (gpt_params_parse(argc, argv, params) == false) {
return 1;
@ -97,6 +103,12 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng);
}
#if defined(LLAMA_USE_BOOST)
boost::regex response_allowed_regex = boost::regex(params.allowed_regex);
boost::regex response_bias_regex = boost::regex(params.bias_regex);
#endif
// params.prompt = R"(// this function checks if the number n is prime
//bool is_prime(int n) {)";
@ -305,7 +317,7 @@ int main(int argc, char ** argv) {
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd;
std::string partial_completion;
while (n_remain != 0 || params.interactive) {
// predict
if (embd.size() > 0) {
@ -410,6 +422,15 @@ int main(int argc, char ** argv) {
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
#if defined(LLAMA_USE_BOOST)
for (size_t i = 0; i < llama_get_num_logits(ctx); i++) {
if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
logits[i] = -INFINITY;
else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
logits[i] += params.bias_regex_value;
}
}
#endif
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@ -459,6 +480,7 @@ int main(int argc, char ** argv) {
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
partial_completion += llama_token_to_str(ctx, id);
}
// replace end of text token with newline token when in interactive mode