From de269a18332de2621d0b8ee86d1a207b96a96b59 Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Sun, 26 Jan 2025 08:02:37 -0800 Subject: [PATCH] fix tests when llg is enabled --- common/json-schema-to-grammar.cpp | 9 +++++---- common/json-schema-to-grammar.h | 3 ++- tests/test-grammar-integration.cpp | 2 +- tests/test-json-schema-to-grammar.cpp | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 1e23748ae..021d1aba0 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -990,16 +990,17 @@ public: } }; -std::string json_schema_to_grammar(const json & schema) { +std::string json_schema_to_grammar(const json & schema, bool force_gbnf) { #ifdef LLAMA_USE_LLGUIDANCE - return "%llguidance {}\nstart: %json " + schema.dump(); -#else + if (!force_gbnf) { + return "%llguidance {}\nstart: %json " + schema.dump(); + } +#endif // LLAMA_USE_LLGUIDANCE return build_grammar([&](const llama_grammar_builder & callbacks) { auto copy = schema; callbacks.resolve_refs(copy); callbacks.add_schema("", copy); }); -#endif // LLAMA_USE_LLGUIDANCE } std::string build_grammar(const std::function & cb) { diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 4f43ab3a5..12fb47828 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -5,7 +5,8 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -std::string json_schema_to_grammar(const nlohmann::ordered_json & schema); +std::string json_schema_to_grammar(const nlohmann::ordered_json & schema, + bool force_gbnf = false); struct llama_grammar_builder { std::function add_rule; diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index e1bdbb925..aa3685956 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -129,7 +129,7 @@ static void test_grammar(const std::string & test_desc, const std::string & gram test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings); } static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector & passing_strings, const std::vector & failing_strings) { - test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings); + test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str), true), passing_strings, failing_strings); } static void test_simple_grammar() { diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 9d2db91f5..f38994c92 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1246,7 +1246,7 @@ int main() { test_all("C++", [](const TestCase & tc) { try { - tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema))); + tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true)); tc.verify_status(SUCCESS); } catch (const std::runtime_error & ex) { fprintf(stderr, "Error: %s\n", ex.what());