Add allowed response regex, response bias regex, and response bias value to main example
This commit is contained in:
parent
779d7969c0
commit
30754bbaf9
4 changed files with 49 additions and 1 deletions
5
Makefile
5
Makefile
|
@ -143,6 +143,11 @@ ifdef LLAMA_PERF
|
||||||
CFLAGS += -DGGML_PERF
|
CFLAGS += -DGGML_PERF
|
||||||
CXXFLAGS += -DGGML_PERF
|
CXXFLAGS += -DGGML_PERF
|
||||||
endif
|
endif
|
||||||
|
ifdef LLAMA_USE_BOOST
|
||||||
|
LDFLAGS += -L/usr/lib/x86_64-linux-gnu/ -lboost_regex
|
||||||
|
CFLAGS += -DLLAMA_USE_BOOST
|
||||||
|
CXXFLAGS += -DLLAMA_USE_BOOST
|
||||||
|
endif
|
||||||
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
||||||
# Apple M1, M2, etc.
|
# Apple M1, M2, etc.
|
||||||
# Raspberry Pi 3, 4, Zero 2 (64-bit)
|
# Raspberry Pi 3, 4, Zero 2 (64-bit)
|
||||||
|
|
|
@ -333,6 +333,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.input_suffix = argv[i];
|
params.input_suffix = argv[i];
|
||||||
|
} else if (arg == "--allowed-response-regex") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.allowed_regex = argv[i];
|
||||||
|
} else if (arg == "--response-bias-regex") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.bias_regex = argv[i];
|
||||||
|
} else if (arg == "--response-bias-value") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.bias_regex_value = std::stof(argv[i]);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
gpt_print_usage(argc, argv, default_params);
|
gpt_print_usage(argc, argv, default_params);
|
||||||
|
|
|
@ -49,6 +49,9 @@ struct gpt_params {
|
||||||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
|
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
|
||||||
std::string input_prefix = ""; // string to prefix user inputs with
|
std::string input_prefix = ""; // string to prefix user inputs with
|
||||||
std::string input_suffix = ""; // string to suffix user inputs with
|
std::string input_suffix = ""; // string to suffix user inputs with
|
||||||
|
std::string allowed_regex = ""; // regex string used to force prediction of matching tokens
|
||||||
|
std::string bias_regex = ""; // matching tokens are biased by bias_regex_value
|
||||||
|
float bias_regex_value = -1; // value to bias tokens matching bias_regex by
|
||||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||||
|
|
||||||
std::string lora_adapter = ""; // lora adapter path
|
std::string lora_adapter = ""; // lora adapter path
|
||||||
|
|
|
@ -18,6 +18,10 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if defined(LLAMA_USE_BOOST)
|
||||||
|
#include <boost/regex.hpp>
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||||
#include <signal.h>
|
#include <signal.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
@ -51,6 +55,8 @@ void sigint_handler(int signo) {
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
params.model = "models/llama-7B/ggml-model.bin";
|
params.model = "models/llama-7B/ggml-model.bin";
|
||||||
|
// boost::regex regex = boost::regex("(?:(?:\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))(?:<\\|>\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))*)|NONE");
|
||||||
|
// boost::regex negative_bias_regex = boost::regex("^NONE");
|
||||||
|
|
||||||
if (gpt_params_parse(argc, argv, params) == false) {
|
if (gpt_params_parse(argc, argv, params) == false) {
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -97,6 +103,12 @@ int main(int argc, char ** argv) {
|
||||||
params.prompt = gpt_random_prompt(rng);
|
params.prompt = gpt_random_prompt(rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#if defined(LLAMA_USE_BOOST)
|
||||||
|
boost::regex response_allowed_regex = boost::regex(params.allowed_regex);
|
||||||
|
boost::regex response_bias_regex = boost::regex(params.bias_regex);
|
||||||
|
#endif
|
||||||
|
|
||||||
// params.prompt = R"(// this function checks if the number n is prime
|
// params.prompt = R"(// this function checks if the number n is prime
|
||||||
//bool is_prime(int n) {)";
|
//bool is_prime(int n) {)";
|
||||||
|
|
||||||
|
@ -305,7 +317,7 @@ int main(int argc, char ** argv) {
|
||||||
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
|
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
|
||||||
|
|
||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
|
std::string partial_completion;
|
||||||
while (n_remain != 0 || params.interactive) {
|
while (n_remain != 0 || params.interactive) {
|
||||||
// predict
|
// predict
|
||||||
if (embd.size() > 0) {
|
if (embd.size() > 0) {
|
||||||
|
@ -410,6 +422,15 @@ int main(int argc, char ** argv) {
|
||||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||||
logits[it->first] += it->second;
|
logits[it->first] += it->second;
|
||||||
}
|
}
|
||||||
|
#if defined(LLAMA_USE_BOOST)
|
||||||
|
for (size_t i = 0; i < llama_get_num_logits(ctx); i++) {
|
||||||
|
if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
|
||||||
|
logits[i] = -INFINITY;
|
||||||
|
else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
|
||||||
|
logits[i] += params.bias_regex_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
candidates.reserve(n_vocab);
|
candidates.reserve(n_vocab);
|
||||||
|
@ -459,6 +480,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
last_n_tokens.erase(last_n_tokens.begin());
|
last_n_tokens.erase(last_n_tokens.begin());
|
||||||
last_n_tokens.push_back(id);
|
last_n_tokens.push_back(id);
|
||||||
|
partial_completion += llama_token_to_str(ctx, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// replace end of text token with newline token when in interactive mode
|
// replace end of text token with newline token when in interactive mode
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue