From 827ac3a457b5588d4695a2d1dd54da4dcf31e6e5 Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Wed, 10 May 2023 17:48:37 -0500 Subject: [PATCH] Add check for whether regex should be used --- examples/common.h | 2 +- examples/main/main.cpp | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/common.h b/examples/common.h index e7293cd8f..0c89e586b 100644 --- a/examples/common.h +++ b/examples/common.h @@ -51,7 +51,7 @@ struct gpt_params { std::string input_suffix = ""; // string to suffix user inputs with std::string allowed_regex = ""; // regex string used to force prediction of matching tokens std::string bias_regex = ""; // matching tokens are biased by bias_regex_value - float bias_regex_value = -1; // value to bias tokens matching bias_regex by + float bias_regex_value = 0; // value to bias tokens matching bias_regex by std::vector antiprompt; // string upon seeing which more user input is prompted std::string lora_adapter = ""; // lora adapter path diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e79540d02..2e98504fe 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -423,11 +423,13 @@ int main(int argc, char ** argv) { logits[it->first] += it->second; } #if defined(LLAMA_USE_BOOST) - for (size_t i = 0; i < llama_get_num_logits(ctx); i++) { - if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial)) - logits[i] = -INFINITY; - else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) { - logits[i] += params.bias_regex_value; + if (params.allowed_regex != "" || params.bias_regex != "") { + for (size_t i = 0; i < llama_n_vocab(ctx); i++) { + if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial)) + logits[i] = -INFINITY; + else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) { + logits[i] += params.bias_regex_value; + } } } #endif