fix tests when llg is enabled
This commit is contained in:
parent
8cb12d43d6
commit
de269a1833
4 changed files with 9 additions and 7 deletions
|
@ -990,16 +990,17 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string json_schema_to_grammar(const json & schema) {
|
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
||||||
#ifdef LLAMA_USE_LLGUIDANCE
|
#ifdef LLAMA_USE_LLGUIDANCE
|
||||||
return "%llguidance {}\nstart: %json " + schema.dump();
|
if (!force_gbnf) {
|
||||||
#else
|
return "%llguidance {}\nstart: %json " + schema.dump();
|
||||||
|
}
|
||||||
|
#endif // LLAMA_USE_LLGUIDANCE
|
||||||
return build_grammar([&](const llama_grammar_builder & callbacks) {
|
return build_grammar([&](const llama_grammar_builder & callbacks) {
|
||||||
auto copy = schema;
|
auto copy = schema;
|
||||||
callbacks.resolve_refs(copy);
|
callbacks.resolve_refs(copy);
|
||||||
callbacks.add_schema("", copy);
|
callbacks.add_schema("", copy);
|
||||||
});
|
});
|
||||||
#endif // LLAMA_USE_LLGUIDANCE
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
|
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
|
||||||
|
|
|
@ -5,7 +5,8 @@
|
||||||
#define JSON_ASSERT GGML_ASSERT
|
#define JSON_ASSERT GGML_ASSERT
|
||||||
#include "json.hpp"
|
#include "json.hpp"
|
||||||
|
|
||||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
|
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
||||||
|
bool force_gbnf = false);
|
||||||
|
|
||||||
struct llama_grammar_builder {
|
struct llama_grammar_builder {
|
||||||
std::function<std::string(const std::string &, const std::string &)> add_rule;
|
std::function<std::string(const std::string &, const std::string &)> add_rule;
|
||||||
|
|
|
@ -129,7 +129,7 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
|
||||||
test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
|
test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
|
||||||
}
|
}
|
||||||
static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
|
static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
|
||||||
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings);
|
test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str), true), passing_strings, failing_strings);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void test_simple_grammar() {
|
static void test_simple_grammar() {
|
||||||
|
|
|
@ -1246,7 +1246,7 @@ int main() {
|
||||||
|
|
||||||
test_all("C++", [](const TestCase & tc) {
|
test_all("C++", [](const TestCase & tc) {
|
||||||
try {
|
try {
|
||||||
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema)));
|
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
|
||||||
tc.verify_status(SUCCESS);
|
tc.verify_status(SUCCESS);
|
||||||
} catch (const std::runtime_error & ex) {
|
} catch (const std::runtime_error & ex) {
|
||||||
fprintf(stderr, "Error: %s\n", ex.what());
|
fprintf(stderr, "Error: %s\n", ex.what());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue