Revising GBNF validator program to be much simpler.
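Adds a stand-alone gbnf-validator example (a Makefile target, a CMake target, and examples/gbnf-validator/gbnf-validator.cpp) and moves the internal grammar structs and the decode_utf8 / llama_grammar_accept declarations from llama.cpp into llama.h under LLAMA_API_INTERNAL so the example can use them.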
This commit is contained in:
parent 924ce1dce7
commit f28bfa3876

5 changed files with 161 additions and 20 deletions
Makefile (4 changed lines)

@@ -818,6 +818,10 @@ passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
+gbnf-validator: examples/gbnf-validator/gbnf-validator.cpp ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
 ifeq ($(UNAME_S),Darwin)
 swift: examples/batched.swift
 	(cd examples/batched.swift; make build)
examples/gbnf-validator/CMakeLists.txt (5 changed lines, new file)

@@ -0,0 +1,5 @@
+set(TARGET gbnf-validator)
+add_executable(${TARGET} gbnf-validator.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common grammar-parser llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
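Together these give the example both build paths: `make gbnf-validator` from the repository root, or the regular CMake build, which also installs the gbnf-validator binary through the install(TARGETS ...) rule.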
examples/gbnf-validator/gbnf-validator.cpp (122 changed lines, new file)

@@ -0,0 +1,122 @@
+#define LLAMA_API_INTERNAL
+
+#include "grammar-parser.h"
+#include "ggml.h"
+#include "llama.h"
+#include "unicode.h"
+
+#include <iostream>
+#include <fstream>
+#include <string>
+#include <vector>
+
+static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
+    auto decoded = decode_utf8(input_str, {});
+    const auto & code_points = decoded.first;
+
+    size_t pos = 0;
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        auto prev_stacks = grammar->stacks;
+        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+        if (grammar->stacks.empty()) {
+            error_pos = pos;
+            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
+            grammar->stacks = prev_stacks;
+            return false;
+        }
+        ++pos;
+    }
+
+    for (const auto & stack : grammar->stacks) {
+        if (stack.empty()) {
+            return true;
+        }
+    }
+
+    error_pos = pos;
+    error_msg = "Unexpected end of input";
+    return false;
+}
+
+static void print_error_message(const std::string & input_str, size_t error_pos, const std::string & error_msg) {
+    std::cout << "Input string is invalid according to the grammar." << std::endl;
+    std::cout << "Error: " << error_msg << " at position " << std::to_string(error_pos) << std::endl;
+    std::cout << std::endl;
+    std::cout << "Input string:" << std::endl;
+    std::cout << input_str.substr(0, error_pos);
+    if (error_pos < input_str.size()) {
+        std::cout << "\033[1;31m" << input_str[error_pos];
+        if (error_pos+1 < input_str.size()) {
+            std::cout << "\033[0;31m" << input_str.substr(error_pos+1);
+        }
+        std::cout << "\033[0m" << std::endl;
+    }
+}
+
+int main(int argc, char** argv) {
+    if (argc != 3) {
+        std::cerr << "Usage: " << argv[0] << " <grammar_file> <input_file>" << std::endl;
+        return 1;
+    }
+
+    const std::string grammar_file = argv[1];
+    const std::string input_file = argv[2];
+
+    // Read the GBNF grammar file
+    std::ifstream grammar_stream(grammar_file);
+    if (!grammar_stream.is_open()) {
+        std::cerr << "Failed to open grammar file: " << grammar_file << std::endl;
+        return 1;
+    }
+
+    std::string grammar_str((std::istreambuf_iterator<char>(grammar_stream)), std::istreambuf_iterator<char>());
+    grammar_stream.close();
+
+    // Parse the GBNF grammar
+    auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+
+    // will be empty (default) if there are parse errors
+    if (parsed_grammar.rules.empty()) {
+        fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+        return 1;
+    }
+
+    // Ensure that there is a "root" node.
+    if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) {
+        fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
+        return 1;
+    }
+
+    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+
+    // Create the LLAMA grammar
+    auto grammar = llama_grammar_init(
+        grammar_rules.data(),
+        grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+
+    // Read the input file
+    std::ifstream input_stream(input_file);
+    if (!input_stream.is_open()) {
+        std::cerr << "Failed to open input file: " << input_file << std::endl;
+        return 1;
+    }
+
+    std::string input_str((std::istreambuf_iterator<char>(input_stream)), std::istreambuf_iterator<char>());
+    input_stream.close();
+
+    // Validate the input string against the grammar
+    size_t error_pos;
+    std::string error_msg;
+    bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg);
+
+    if (is_valid) {
+        std::cout << "Input string is valid according to the grammar." << std::endl;
+    } else {
+        print_error_message(input_str, error_pos, error_msg);
+    }
+
+    // Clean up
+    llama_grammar_free(grammar);
+
+    return 0;
+}
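Usage follows the two-argument form printed above: gbnf-validator <grammar_file> <input_file>. When every code point of the input is accepted and at least one grammar stack is empty at the end, the tool prints "Input string is valid according to the grammar."; otherwise print_error_message reports the error and its position and highlights the offending character with ANSI color codes. Any GBNF file works as long as it defines a root rule, since main() rejects grammars without a "root" symbol.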
llama.cpp (22 changed lines)

@@ -10508,28 +10508,10 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 // grammar - internal
 //
 
-struct llama_partial_utf8 {
-    uint32_t value;    // bit value so far (unshifted)
-    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
-};
-
-struct llama_grammar {
-    const std::vector<std::vector<llama_grammar_element>>   rules;
-    std::vector<std::vector<const llama_grammar_element *>> stacks;
-
-    // buffer for partially generated UTF-8 sequence from accepted tokens
-    llama_partial_utf8 partial_utf8;
-};
-
-struct llama_grammar_candidate {
-    size_t               index;
-    const uint32_t     * code_points;
-    llama_partial_utf8   partial_utf8;
-};
 
 // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
 // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
-static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
         llama_partial_utf8 partial_start) {
     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };

@@ -10731,7 +10713,7 @@ static void llama_grammar_advance_stack(
 // be positioned at a character range (see `llama_grammar_advance_stack`), and
 // produces the N possible stacks if the given char is accepted at those
 // positions
-static std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
+std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
         const std::vector<std::vector<llama_grammar_element>>         & rules,
         const std::vector<std::vector<const llama_grammar_element *>> & stacks,
         const uint32_t                                                   chr) {
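The llama.cpp changes are minimal: the three grammar structs move out of the translation unit (they now live in llama.h, below), and the static qualifier is dropped from decode_utf8 and llama_grammar_accept so both functions get external linkage and can be reached through the declarations added under LLAMA_API_INTERNAL.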
llama.h (28 changed lines)

@@ -987,10 +987,38 @@ extern "C" {
 
 struct ggml_tensor;
 
+struct llama_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct llama_grammar {
+    const std::vector<std::vector<llama_grammar_element>>   rules;
+    std::vector<std::vector<const llama_grammar_element *>> stacks;
+
+    // buffer for partially generated UTF-8 sequence from accepted tokens
+    llama_partial_utf8 partial_utf8;
+};
+
+struct llama_grammar_candidate {
+    size_t               index;
+    const uint32_t     * code_points;
+    llama_partial_utf8   partial_utf8;
+};
+
 const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
     struct llama_context * ctx
 );
 
+std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
+        const std::vector<std::vector<llama_grammar_element>>         & rules,
+        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
+        const uint32_t                                                   chr);
+
+std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+        const std::string & src,
+        llama_partial_utf8  partial_start);
+
 #endif // LLAMA_API_INTERNAL
 
 #endif // LLAMA_H
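Since llama.h now exposes these grammar internals whenever LLAMA_API_INTERNAL is defined, any example or test linked against llama.cpp can use them the same way gbnf-validator does. Below is a minimal, illustrative sketch (not part of the commit; the string literal and output formatting are made up) showing the decode_utf8 convention: it returns the decoded code points plus a llama_partial_utf8 state whose n_remain is -1 for an invalid sequence.

// Illustrative only: assumes compilation with llama.cpp and the declarations above.
#define LLAMA_API_INTERNAL
#include "llama.h"

#include <cstdio>
#include <string>

int main() {
    // Decode a UTF-8 string; the returned vector ends with a 0 terminator
    // (see the decode_utf8 comment in llama.cpp).
    auto decoded = decode_utf8(std::string("root"), {});
    for (uint32_t cpt : decoded.first) {
        std::printf("U+%04X\n", (unsigned) cpt);
    }
    // n_remain == -1 would indicate an invalid or truncated UTF-8 sequence.
    std::printf("n_remain = %d\n", decoded.second.n_remain);
    return 0;
}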