Add check for whether regex should be used
This commit is contained in:
parent
58d848dadd
commit
827ac3a457
2 changed files with 8 additions and 6 deletions
|
@ -51,7 +51,7 @@ struct gpt_params {
|
||||||
std::string input_suffix = ""; // string to suffix user inputs with
|
std::string input_suffix = ""; // string to suffix user inputs with
|
||||||
std::string allowed_regex = ""; // regex string used to force prediction of matching tokens
|
std::string allowed_regex = ""; // regex string used to force prediction of matching tokens
|
||||||
std::string bias_regex = ""; // matching tokens are biased by bias_regex_value
|
std::string bias_regex = ""; // matching tokens are biased by bias_regex_value
|
||||||
float bias_regex_value = -1; // value to bias tokens matching bias_regex by
|
float bias_regex_value = 0; // value to bias tokens matching bias_regex by
|
||||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||||
|
|
||||||
std::string lora_adapter = ""; // lora adapter path
|
std::string lora_adapter = ""; // lora adapter path
|
||||||
|
|
|
@ -423,13 +423,15 @@ int main(int argc, char ** argv) {
|
||||||
logits[it->first] += it->second;
|
logits[it->first] += it->second;
|
||||||
}
|
}
|
||||||
#if defined(LLAMA_USE_BOOST)
|
#if defined(LLAMA_USE_BOOST)
|
||||||
for (size_t i = 0; i < llama_get_num_logits(ctx); i++) {
|
if (params.allowed_regex != "" || params.bias_regex != "") {
|
||||||
|
for (size_t i = 0; i < llama_n_vocab(ctx); i++) {
|
||||||
if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
|
if (!boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_allowed_regex, boost::match_partial))
|
||||||
logits[i] = -INFINITY;
|
logits[i] = -INFINITY;
|
||||||
else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
|
else if (boost::regex_match(partial_completion + llama_token_to_str(ctx, i), response_bias_regex, boost::match_partial)) {
|
||||||
logits[i] += params.bias_regex_value;
|
logits[i] += params.bias_regex_value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue