Add left recursion check: quit early instead of going into an infinite loop
This commit is contained in:
parent
7bd4ffb780
commit
65176e78ec
2 changed files with 109 additions and 0 deletions
72
llama.cpp
72
llama.cpp
|
@ -13178,6 +13178,25 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
|
|||
// grammar - external
|
||||
//
|
||||
|
||||
// Visit state of one nonterminal during the left-recursion search.
// The numeric values are spelled out on purpose: a value-initialized element
// is 0, which must mean "not searched yet".
enum detect_left_recursion_status {
    LLAMA_LEFT_REC_NOT_SEARCHED    = 0, // this nonterminal has not been looked at
    LLAMA_LEFT_REC_IN_PROGRESS     = 1, // the search of this nonterminal is still running
    LLAMA_LEFT_REC_FINISHED_SEARCH = 2, // the search of this nonterminal has completed
    LLAMA_LEFT_REC_FOUND_CYCLE     = 3, // a cycle was detected at this nonterminal
};
|
||||
|
||||
static void detect_left_recursion(
|
||||
const std::vector<std::vector<llama_grammar_element>> & rules,
|
||||
size_t rule_index,
|
||||
std::vector<enum detect_left_recursion_status> * rules_visited);
|
||||
|
||||
struct llama_grammar * llama_grammar_init(
|
||||
const llama_grammar_element ** rules,
|
||||
size_t n_rules,
|
||||
|
@ -13193,6 +13212,17 @@ struct llama_grammar * llama_grammar_init(
|
|||
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
|
||||
}
|
||||
|
||||
// Check for left recursion
|
||||
std::vector<enum detect_left_recursion_status> rules_visited(n_rules);
|
||||
for (size_t i = 0; i < n_rules; i++) {
|
||||
detect_left_recursion(vec_rules, i, &rules_visited);
|
||||
}
|
||||
|
||||
auto iter = std::find(rules_visited.begin(), rules_visited.end(), LLAMA_LEFT_REC_FOUND_CYCLE);
|
||||
if (iter != rules_visited.end()) {
|
||||
throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %d", (int)(iter - rules_visited.begin())));
|
||||
}
|
||||
|
||||
// loop over alternates of start rule to build initial stacks
|
||||
std::vector<std::vector<const llama_grammar_element *>> stacks;
|
||||
pos = vec_rules[start_rule_index].data();
|
||||
|
@ -13215,9 +13245,51 @@ struct llama_grammar * llama_grammar_init(
|
|||
}
|
||||
} while (true);
|
||||
|
||||
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
|
||||
}
|
||||
|
||||
static void detect_left_recursion(
|
||||
const std::vector<std::vector<llama_grammar_element>> & rules,
|
||||
size_t rule_index,
|
||||
std::vector<enum detect_left_recursion_status> * rules_visited) {
|
||||
|
||||
int visit_status = (*rules_visited)[rule_index];
|
||||
if (visit_status == LLAMA_LEFT_REC_IN_PROGRESS) {
|
||||
// in progress -- we're in a cycle
|
||||
(*rules_visited)[rule_index] = LLAMA_LEFT_REC_FOUND_CYCLE;
|
||||
return;
|
||||
} else if (visit_status == LLAMA_LEFT_REC_NOT_SEARCHED) {
|
||||
// haven't visited yet. mark in progress, recurse, then mark complete.
|
||||
// mark in progress
|
||||
(*rules_visited)[rule_index] = LLAMA_LEFT_REC_IN_PROGRESS;
|
||||
|
||||
// recurse
|
||||
const std::vector<llama_grammar_element> & rule = rules[rule_index];
|
||||
size_t i = 0;
|
||||
do {
|
||||
if (rule[i].type == LLAMA_GRETYPE_RULE_REF) {
|
||||
detect_left_recursion(rules, (size_t)rule[i].value, rules_visited);
|
||||
}
|
||||
while (!llama_grammar_is_end_of_sequence(&rule[i])) {
|
||||
i++;
|
||||
}
|
||||
i++;
|
||||
} while (i < rule.size());
|
||||
|
||||
// mark complete, but only if the recursive call didn't mark a cycle.
|
||||
// that doesn't mean there's definitely no cycle *for this rule* -- the recursive call
|
||||
// might have found a different cycle and stopped early.
|
||||
if ((*rules_visited)[rule_index] == LLAMA_LEFT_REC_IN_PROGRESS) {
|
||||
(*rules_visited)[rule_index] = LLAMA_LEFT_REC_FINISHED_SEARCH;
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Destroys a grammar object. llama_grammar_init allocates the grammar with
// `new`, so it is released with `delete`; a null pointer is accepted and does
// nothing.
void llama_grammar_free(struct llama_grammar * grammar) {
    delete grammar;
}
|
||||
|
|
|
@ -28,6 +28,18 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
|
|||
return grammar;
|
||||
}
|
||||
|
||||
static bool test_build_grammar_fails(const std::string & grammar_str) {
|
||||
bool grammar_fails = false;
|
||||
try {
|
||||
build_grammar(grammar_str);
|
||||
fprintf(stderr, "❌ Expected build failure, but succeeded: %s\n", grammar_str.c_str());
|
||||
} catch (const std::exception & err) {
|
||||
grammar_fails = true;
|
||||
fprintf(stdout, "✅︎\n");
|
||||
}
|
||||
return grammar_fails;
|
||||
}
|
||||
|
||||
static bool match_string(const std::string & input, llama_grammar* grammar) {
|
||||
auto decoded = decode_utf8(input, {});
|
||||
|
||||
|
@ -320,6 +332,30 @@ number ::= [0-9]+)""";
|
|||
fprintf(stderr, " ✅︎ Passed\n");
|
||||
}
|
||||
|
||||
// Verifies that grammars containing left recursion are rejected at build time:
// a directly self-recursive root, a self-recursive rule reached from root, and
// a cycle spanning two rules.
//
// Each check is evaluated outside of assert() so that it still runs (and still
// prints its result) when assertions are compiled out with NDEBUG.
static void test_failure_left_recursion() {
    fprintf(stderr, "⚫ Testing left recursion detection:\n");

    // Test simple left recursion detection
    const std::string simple_str = R"""(root ::= "a" | root "a")""";
    const bool simple_fails = test_build_grammar_fails(simple_str);
    assert(simple_fails);
    (void) simple_fails;

    // Test more complicated left recursion detection
    const std::string medium_str = R"""(
root ::= asdf
asdf ::= "a" | asdf "a"
)""";
    const bool medium_fails = test_build_grammar_fails(medium_str);
    assert(medium_fails);
    (void) medium_fails;

    // Test even more complicated left recursion detection
    const std::string hard_str = R"""(
root ::= asdf
asdf ::= "a" | foo "b"
foo ::= "c" | asdf "d" | "e")""";
    const bool hard_fails = test_build_grammar_fails(hard_str);
    assert(hard_fails);
    (void) hard_fails;

    fprintf(stderr, "  ✅︎ Passed\n");
}
|
||||
|
||||
int main() {
|
||||
fprintf(stdout, "Running grammar integration tests...\n");
|
||||
test_simple_grammar();
|
||||
|
@ -327,6 +363,7 @@ int main() {
|
|||
test_quantifiers();
|
||||
test_failure_missing_root();
|
||||
test_failure_missing_reference();
|
||||
test_failure_left_recursion();
|
||||
fprintf(stdout, "All tests passed.\n");
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue