diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 7422abfc2..ba3d0ccab 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -7,10 +7,37 @@ #include +static const char * type_str(llama_gretype type) { + switch (type) { + case LLAMA_GRETYPE_CHAR: return "LLAMA_GRETYPE_CHAR"; + case LLAMA_GRETYPE_CHAR_NOT: return "LLAMA_GRETYPE_CHAR_NOT"; + case LLAMA_GRETYPE_CHAR_ALT: return "LLAMA_GRETYPE_CHAR_ALT"; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return "LLAMA_GRETYPE_CHAR_RNG_UPPER"; + case LLAMA_GRETYPE_RULE_REF: return "LLAMA_GRETYPE_RULE_REF"; + case LLAMA_GRETYPE_ALT: return "LLAMA_GRETYPE_ALT"; + case LLAMA_GRETYPE_END: return "LLAMA_GRETYPE_END"; + default: return "?"; + } +} static void verify_parsing(const char *grammar_bytes, const std::vector> expected, const std::vector &expected_rules) { uint32_t index = 0; grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes); + + auto print_all = [&]() { + fprintf(stderr, "Code to update expectation:\n"); + fprintf(stderr, " verify_parsing(R\"\"\"(%s)\"\"\", {\n", grammar_bytes); + for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { + fprintf(stderr, " {\"%s\", %u},\n", it->first.c_str(), it->second); + } + fprintf(stderr, " }, {\n"); + for (auto rule : parsed_grammar.rules) { + for (uint32_t i = 0; i < rule.size(); i++) { + fprintf(stderr, " {%s, %u},\n", type_str(rule[i].type), rule[i].value); + } + } + fprintf(stderr, " });\n"); + }; for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { std::string key = it->first; @@ -20,9 +47,11 @@ static void verify_parsing(const char *grammar_bytes, const std::vector