fix tests when llg is enabled

This commit is contained in:
Michal Moskal 2025-01-26 08:02:37 -08:00
parent 8cb12d43d6
commit de269a1833
4 changed files with 9 additions and 7 deletions

View file

@ -990,16 +990,17 @@ public:
}
};
std::string json_schema_to_grammar(const json & schema) {
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
#ifdef LLAMA_USE_LLGUIDANCE
if (!force_gbnf) {
return "%llguidance {}\nstart: %json " + schema.dump();
#else
}
#endif // LLAMA_USE_LLGUIDANCE
return build_grammar([&](const llama_grammar_builder & callbacks) {
auto copy = schema;
callbacks.resolve_refs(copy);
callbacks.add_schema("", copy);
});
#endif // LLAMA_USE_LLGUIDANCE
}
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {

View file

@ -5,7 +5,8 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
bool force_gbnf = false);
struct llama_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;

View file

@ -129,7 +129,7 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
}
static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings);
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str), true), passing_strings, failing_strings);
}
static void test_simple_grammar() {

View file

@ -1246,7 +1246,7 @@ int main() {
test_all("C++", [](const TestCase & tc) {
try {
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema)));
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
tc.verify_status(SUCCESS);
} catch (const std::runtime_error & ex) {
fprintf(stderr, "Error: %s\n", ex.what());