Add allowed response regex, response bias regex, and response bias value to main example
This commit is contained in:
parent
779d7969c0
commit
30754bbaf9
4 changed files with 49 additions and 1 deletions
5
Makefile
5
Makefile
|
@ -143,6 +143,11 @@ ifdef LLAMA_PERF
|
|||
CFLAGS += -DGGML_PERF
|
||||
CXXFLAGS += -DGGML_PERF
|
||||
endif
|
||||
ifdef LLAMA_USE_BOOST
|
||||
LDFLAGS += -L/usr/lib/x86_64-linux-gnu/ -lboost_regex
|
||||
CFLAGS += -DLLAMA_USE_BOOST
|
||||
CXXFLAGS += -DLLAMA_USE_BOOST
|
||||
endif
|
||||
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
||||
# Apple M1, M2, etc.
|
||||
# Raspberry Pi 3, 4, Zero 2 (64-bit)
|
||||
|
|
|
@ -333,6 +333,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.input_suffix = argv[i];
|
||||
} else if (arg == "--allowed-response-regex") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.allowed_regex = argv[i];
|
||||
} else if (arg == "--response-bias-regex") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.bias_regex = argv[i];
|
||||
} else if (arg == "--response-bias-value") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.bias_regex_value = std::stof(argv[i]);
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
gpt_print_usage(argc, argv, default_params);
|
||||
|
|
|
@ -49,6 +49,9 @@ struct gpt_params {
|
|||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
|
||||
std::string input_prefix = ""; // string to prefix user inputs with
|
||||
std::string input_suffix = ""; // string to suffix user inputs with
|
||||
std::string allowed_regex = ""; // regex string used to force prediction of matching tokens
|
||||
std::string bias_regex = ""; // matching tokens are biased by bias_regex_value
|
||||
float bias_regex_value = -1; // value to bias tokens matching bias_regex by
|
||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||
|
||||
std::string lora_adapter = ""; // lora adapter path
|
||||
|
|
|
@ -18,6 +18,10 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if defined(LLAMA_USE_BOOST)
|
||||
#include <boost/regex.hpp>
|
||||
#endif
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
|
@ -51,6 +55,8 @@ void sigint_handler(int signo) {
|
|||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
params.model = "models/llama-7B/ggml-model.bin";
|
||||
// boost::regex regex = boost::regex("(?:(?:\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))(?:<\\|>\\([a-z A-Z 0-9]*, [a-z A-Z 0-9]*, [a-z A-Z 0-9]*\\))*)|NONE");
|
||||
// boost::regex negative_bias_regex = boost::regex("^NONE");
|
||||
|
||||
if (gpt_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
|
@ -97,6 +103,12 @@ int main(int argc, char ** argv) {
|
|||
params.prompt = gpt_random_prompt(rng);
|
||||
}
|
||||
|
||||
|
||||
#if defined(LLAMA_USE_BOOST)
|
||||
boost::regex response_allowed_regex = boost::regex(params.allowed_regex);
|
||||
boost::regex response_bias_regex = boost::regex(params.bias_regex);
|
||||
#endif
|
||||
|
||||
// params.prompt = R"(// this function checks if the number n is prime
|
||||
//bool is_prime(int n) {)";
|
||||
|
||||
|
@ -305,7 +317,7 @@ int main(int argc, char ** argv) {
|
|||
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
|
||||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
std::string partial_completion;
|
||||
while (n_remain != 0 || params.interactive) {
|
||||
// predict
|
||||
if (embd.size() > 0) {
|
||||
|
@ -410,6 +422,15 @@ int main(int argc, char ** argv) {
|
|||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||
logits[it->first] += it->second;
|
||||
}
|
||||
#if defined(LLAMA_USE_BOOST)
|
||||
for (size_t i = 0; i < llama_get_num_logits(ctx); i++) {
|
||||
if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
|
||||
logits[i] = -INFINITY;
|
||||
else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
|
||||
logits[i] += params.bias_regex_value;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
@ -459,6 +480,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
last_n_tokens.erase(last_n_tokens.begin());
|
||||
last_n_tokens.push_back(id);
|
||||
partial_completion += llama_token_to_str(ctx, id);
|
||||
}
|
||||
|
||||
// replace end of text token with newline token when in interactive mode
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue