fix tests when llg is enabled
This commit is contained in:
parent
8cb12d43d6
commit
de269a1833
4 changed files with 9 additions and 7 deletions
|
@ -990,16 +990,17 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
std::string json_schema_to_grammar(const json & schema) {
|
||||
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
||||
#ifdef LLAMA_USE_LLGUIDANCE
|
||||
if (!force_gbnf) {
|
||||
return "%llguidance {}\nstart: %json " + schema.dump();
|
||||
#else
|
||||
}
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
return build_grammar([&](const llama_grammar_builder & callbacks) {
|
||||
auto copy = schema;
|
||||
callbacks.resolve_refs(copy);
|
||||
callbacks.add_schema("", copy);
|
||||
});
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
}
|
||||
|
||||
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
||||
bool force_gbnf = false);
|
||||
|
||||
struct llama_grammar_builder {
|
||||
std::function<std::string(const std::string &, const std::string &)> add_rule;
|
||||
|
|
|
@ -129,7 +129,7 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
|
|||
test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
|
||||
}
|
||||
static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
|
||||
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings);
|
||||
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str), true), passing_strings, failing_strings);
|
||||
}
|
||||
|
||||
static void test_simple_grammar() {
|
||||
|
|
|
@ -1246,7 +1246,7 @@ int main() {
|
|||
|
||||
test_all("C++", [](const TestCase & tc) {
|
||||
try {
|
||||
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema)));
|
||||
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
|
||||
tc.verify_status(SUCCESS);
|
||||
} catch (const std::runtime_error & ex) {
|
||||
fprintf(stderr, "Error: %s\n", ex.what());
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue