Add check for whether regex should be used

This commit is contained in:
Branden Butler 2023-05-10 17:48:37 -05:00
parent 58d848dadd
commit 827ac3a457
2 changed files with 8 additions and 6 deletions

View file

@ -51,7 +51,7 @@ struct gpt_params {
std::string input_suffix = ""; // string to suffix user inputs with
std::string allowed_regex = ""; // regex string used to force prediction of matching tokens
std::string bias_regex = ""; // matching tokens are biased by bias_regex_value
float bias_regex_value = -1; // value to bias tokens matching bias_regex by
float bias_regex_value = 0; // value to bias tokens matching bias_regex by
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string lora_adapter = ""; // lora adapter path

View file

@ -423,11 +423,13 @@ int main(int argc, char ** argv) {
logits[it->first] += it->second;
}
#if defined(LLAMA_USE_BOOST)
for (size_t i = 0; i < llama_get_num_logits(ctx); i++) {
if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
logits[i] = -INFINITY;
else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
logits[i] += params.bias_regex_value;
if (params.allowed_regex != "" || params.bias_regex != "") {
for (size_t i = 0; i < llama_n_vocab(ctx); i++) {
if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
logits[i] = -INFINITY;
else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
logits[i] += params.bias_regex_value;
}
}
}
#endif