Merge branch 'master' into sycl_readme_update

This commit is contained in:
Ouadie EL FAROUKI 2024-03-21 13:12:32 +00:00
commit a80267a110
40 changed files with 12162 additions and 6703 deletions

View file

@ -21,6 +21,118 @@ env:
GGML_N_THREADS: 1
jobs:
macOS-latest-cmake-arm64:
runs-on: macos-14
steps:
- name: Clone
id: checkout
uses: actions/checkout@v3
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
- name: Build
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_METAL_EMBED_LIBRARY=ON ..
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
- name: Determine tag name
id: tag
shell: bash
run: |
BUILD_NUMBER="$(git rev-list --count HEAD)"
SHORT_HASH="$(git rev-parse --short=7 HEAD)"
if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
else
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
fi
- name: Pack artifacts
id: pack_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
cp LICENSE ./build/bin/
zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/*
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v3
with:
path: |
llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip
macOS-latest-cmake-x64:
runs-on: macos-latest
steps:
- name: Clone
id: checkout
uses: actions/checkout@v3
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
- name: Build
id: cmake_build
run: |
sysctl -a
mkdir build
cd build
cmake -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_METAL_EMBED_LIBRARY=ON ..
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900
- name: Determine tag name
id: tag
shell: bash
run: |
BUILD_NUMBER="$(git rev-list --count HEAD)"
SHORT_HASH="$(git rev-parse --short=7 HEAD)"
if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
else
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
fi
- name: Pack artifacts
id: pack_artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
run: |
cp LICENSE ./build/bin/
zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/*
- name: Upload artifacts
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
uses: actions/upload-artifact@v3
with:
path: |
llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip
ubuntu-focal-make:
runs-on: ubuntu-20.04
@ -748,6 +860,8 @@ jobs:
- macOS-latest-cmake
- windows-latest-cmake
- windows-latest-cmake-cublas
- macOS-latest-cmake-arm64
- macOS-latest-cmake-x64
steps:
- name: Clone

3
.gitignore vendored
View file

@ -11,7 +11,10 @@
*.gcda
*.dot
*.bat
*.tmp
*.metallib
*.etag
*.lastModified
.DS_Store
.build/
.cache/

View file

@ -9,7 +9,8 @@ TEST_TARGETS = \
tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease \
tests/test-json-schema-to-grammar
# Code coverage output files
COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
@ -666,6 +667,9 @@ console.o: common/console.cpp common/console.h
grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
$(CXX) $(CXXFLAGS) -c $< -o $@
json-schema-to-grammar.o: common/json-schema-to-grammar.cpp common/json-schema-to-grammar.h
$(CXX) $(CXXFLAGS) -c $< -o $@
train.o: common/train.cpp common/train.h
$(CXX) $(CXXFLAGS) -c $< -o $@
@ -745,7 +749,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp json-schema-to-grammar.o common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
@ -865,6 +869,10 @@ tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp json-schema-to-grammar.o ggml.o llama.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
tests/test-grad0: tests/test-grad0.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

View file

@ -122,6 +122,7 @@ pub fn build(b: *std.build.Builder) !void {
const console = make.obj("console", "common/console.cpp");
const sampling = make.obj("sampling", "common/sampling.cpp");
const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp");
const json_schema_to_grammar = make.obj("json-schema-to-grammar", "common/json-schema-to-grammar.cpp");
const train = make.obj("train", "common/train.cpp");
const clip = make.obj("clip", "examples/llava/clip.cpp");
const llava = make.obj("llava", "examples/llava/llava.cpp");
@ -133,7 +134,7 @@ pub fn build(b: *std.build.Builder) !void {
_ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train });
_ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train });
const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, sampling, grammar_parser, clip, llava });
const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, sampling, grammar_parser, json_schema_to_grammar, clip, llava });
if (server.target.isWindows()) {
server.linkSystemLibrary("ws2_32");
}

View file

@ -47,6 +47,8 @@ if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
set(TARGET json-schema-to-grammar)
add_library(${TARGET} OBJECT json-schema-to-grammar.cpp json-schema-to-grammar.h)
set(TARGET common)
@ -60,6 +62,7 @@ add_library(${TARGET} STATIC
console.cpp
grammar-parser.h
grammar-parser.cpp
json.hpp
train.h
train.cpp
)

View file

@ -1590,6 +1590,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "q4_1") {
return GGML_TYPE_Q4_1;
}
if (s == "iq4_nl") {
return GGML_TYPE_IQ4_NL;
}
if (s == "q5_0") {
return GGML_TYPE_Q5_0;
}

View file

@ -0,0 +1,725 @@
#include "json-schema-to-grammar.h"
#include <algorithm>
#include <fstream>
#include <map>
#include <regex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using json = nlohmann::json;
const std::string SPACE_RULE = "\" \"?";
std::unordered_map<std::string, std::string> PRIMITIVE_RULES = {
{"boolean", "(\"true\" | \"false\") space"},
{"number", "(\"-\"? ([0-9] | [1-9] [0-9]*)) (\".\" [0-9]+)? ([eE] [-+]? [0-9]+)? space"},
{"integer", "(\"-\"? ([0-9] | [1-9] [0-9]*)) space"},
{"value", "object | array | string | number | boolean"},
{"object", "\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space"},
{"array", "\"[\" space ( value (\",\" space value)* )? \"]\" space"},
{"uuid", "\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space"},
{"string", " \"\\\"\" (\n"
" [^\"\\\\] |\n"
" \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])\n"
" )* \"\\\"\" space"},
{"null", "\"null\" space"}
};
std::vector<std::string> OBJECT_RULE_NAMES = {"object", "array", "string", "number", "boolean", "null", "value"};
std::unordered_map<std::string, std::string> DATE_RULES = {
{"date", "[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )"},
{"time", "([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9] [0-9] [0-9] )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )"},
{"date-time", "date \"T\" time"},
{"date-string", "\"\\\"\" date \"\\\"\" space"},
{"time-string", "\"\\\"\" time \"\\\"\" space"},
{"date-time-string", "\"\\\"\" date-time \"\\\"\" space"}
};
static bool is_reserved_name(const std::string & name) {
static std::unordered_set<std::string> RESERVED_NAMES;
if (RESERVED_NAMES.empty()) {
RESERVED_NAMES.insert("root");
for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first);
for (const auto &p : DATE_RULES) RESERVED_NAMES.insert(p.first);
}
return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
}
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
};
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
template <typename Iterator>
std::string join(Iterator begin, Iterator end, const std::string & separator) {
std::ostringstream result;
if (begin != end) {
result << *begin;
for (Iterator it = begin + 1; it != end; ++it) {
result << separator << *it;
}
}
return result.str();
}
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> tokens;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
tokens.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
tokens.push_back(str.substr(start));
return tokens;
}
static std::string repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}
std::string result;
result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) {
result += str;
}
return result;
}
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
std::smatch match;
std::string result;
std::string::const_iterator searchStart(input.cbegin());
std::string::const_iterator searchEnd(input.cend());
while (std::regex_search(searchStart, searchEnd, match, regex)) {
result.append(searchStart, searchStart + match.position());
result.append(replacement(match));
searchStart = match.suffix().first;
}
result.append(searchStart, searchEnd);
return result;
}
static std::string format_literal(const std::string & literal) {
std::string escaped = replacePattern(json(literal).dump(), GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
char c = match.str()[0];
return GRAMMAR_LITERAL_ESCAPES.at(c);
});
return "\"" + escaped + "\"";
}
class SchemaConverter {
private:
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::map<std::string, std::string> _rules;
std::unordered_map<std::string, nlohmann::json> _refs;
std::unordered_set<std::string> _refs_being_resolved;
std::vector<std::string> _errors;
std::vector<std::string> _warnings;
std::string _add_rule(const std::string & name, const std::string & rule) {
std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
_rules[esc_name] = rule;
return esc_name;
} else {
int i = 0;
while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
i++;
}
std::string key = esc_name + std::to_string(i);
_rules[key] = rule;
return key;
}
}
std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
std::vector<std::string> rules;
for (size_t i = 0; i < alt_schemas.size(); i++) {
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
}
return join(rules.begin(), rules.end(), " | ");
}
std::string _visit_pattern(const std::string & pattern, const std::string & name) {
if (!(pattern.front() == '^' && pattern.back() == '$')) {
_errors.push_back("Pattern must start with '^' and end with '$'");
return "";
}
std::string sub_pattern = pattern.substr(1, pattern.length() - 2);
std::unordered_map<std::string, std::string> sub_rule_ids;
size_t i = 0;
size_t length = sub_pattern.length();
using literal_or_rule = std::pair<std::string, bool>;
auto to_rule = [&](const literal_or_rule & ls) {
auto is_literal = ls.second;
auto s = ls.first;
return is_literal ? "\"" + s + "\"" : s;
};
std::function<literal_or_rule()> transform = [&]() -> literal_or_rule {
size_t start = i;
std::vector<literal_or_rule> seq;
auto get_dot = [&]() {
std::string rule;
if (_dotall) {
rule = "[\\U00000000-\\U0010FFFF]";
} else {
rule = "[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]";
}
return _add_rule("dot", rule);
};
// Joins the sequence, merging consecutive literals together.
auto join_seq = [&]() {
std::vector<literal_or_rule> ret;
std::string literal;
auto flush_literal = [&]() {
if (literal.empty()) {
return false;
}
ret.push_back(std::make_pair(literal, true));
literal.clear();
return true;
};
for (const auto & item : seq) {
auto is_literal = item.second;
if (is_literal) {
literal += item.first;
} else {
flush_literal();
ret.push_back(item);
}
}
flush_literal();
std::vector<std::string> results;
for (const auto & item : ret) {
results.push_back(to_rule(item));
}
return std::make_pair(join(results.begin(), results.end(), " "), false);
};
while (i < length) {
char c = sub_pattern[i];
if (c == '.') {
seq.push_back(std::make_pair(get_dot(), false));
i++;
} else if (c == '(') {
i++;
if (i < length) {
if (sub_pattern[i] == '?') {
_warnings.push_back("Unsupported pattern syntax");
}
}
seq.push_back(std::make_pair("(" + to_rule(transform()) + ")", false));
} else if (c == ')') {
i++;
if (start > 0 && sub_pattern[start - 1] != '(') {
_errors.push_back("Unbalanced parentheses");
}
return join_seq();
} else if (c == '[') {
std::string square_brackets = std::string(1, c);
i++;
while (i < length && sub_pattern[i] != ']') {
if (sub_pattern[i] == '\\') {
square_brackets += sub_pattern.substr(i, 2);
i += 2;
} else {
square_brackets += sub_pattern[i];
i++;
}
}
if (i >= length) {
_errors.push_back("Unbalanced square brackets");
}
square_brackets += ']';
i++;
seq.push_back(std::make_pair(square_brackets, false));
} else if (c == '|') {
seq.push_back(std::make_pair("|", false));
i++;
} else if (c == '*' || c == '+' || c == '?') {
seq.back() = std::make_pair(to_rule(seq.back()) + c, false);
i++;
} else if (c == '{') {
std::string curly_brackets = std::string(1, c);
i++;
while (i < length && sub_pattern[i] != '}') {
curly_brackets += sub_pattern[i];
i++;
}
if (i >= length) {
_errors.push_back("Unbalanced curly brackets");
}
curly_brackets += '}';
i++;
auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
int min_times = 0;
int max_times = std::numeric_limits<int>::max();
try {
if (nums.size() == 1) {
min_times = max_times = std::stoi(nums[0]);
} else if (nums.size() != 2) {
_errors.push_back("Wrong number of values in curly brackets");
} else {
if (!nums[0].empty()) {
min_times = std::stoi(nums[0]);
}
if (!nums[1].empty()) {
max_times = std::stoi(nums[1]);
}
}
} catch (const std::invalid_argument & e) {
_errors.push_back("Invalid number in curly brackets");
return std::make_pair("", false);
}
auto &last = seq.back();
auto &sub = last.first;
auto sub_is_literal = last.second;
if (min_times == 0 && max_times == std::numeric_limits<int>::max()) {
sub += "*";
} else if (min_times == 0 && max_times == 1) {
sub += "?";
} else if (min_times == 1 && max_times == std::numeric_limits<int>::max()) {
sub += "+";
} else {
if (!sub_is_literal) {
std::string & sub_id = sub_rule_ids[sub];
if (sub_id.empty()) {
sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
}
sub = sub_id;
}
std::string result;
if (sub_is_literal && min_times > 0) {
result = "\"" + repeat(sub.substr(1, sub.length() - 2), min_times) + "\"";
} else {
for (int j = 0; j < min_times; j++) {
if (j > 0) {
result += " ";
}
result += sub;
}
}
if (min_times > 0 && min_times < max_times) {
result += " ";
}
if (max_times == std::numeric_limits<int>::max()) {
result += sub + "*";
} else {
for (int j = min_times; j < max_times; j++) {
if (j > min_times) {
result += " ";
}
result += sub + "?";
}
}
seq.back().first = result;
seq.back().second = false;
}
} else {
std::string literal;
auto is_non_literal = [&](char c) {
return NON_LITERAL_SET.find(c) != NON_LITERAL_SET.end();
};
while (i < length) {
if (sub_pattern[i] == '\\' && i < length - 1) {
char next = sub_pattern[i + 1];
if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.find(next) != ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.end()) {
i++;
literal += sub_pattern[i];
i++;
} else {
literal += sub_pattern.substr(i, 2);
i += 2;
}
} else if (sub_pattern[i] == '"') {
literal += "\\\"";
i++;
} else if (!is_non_literal(sub_pattern[i]) &&
(i == length - 1 || literal.empty() || sub_pattern[i + 1] == '.' || !is_non_literal(sub_pattern[i + 1]))) {
literal += sub_pattern[i];
i++;
} else {
break;
}
}
if (!literal.empty()) {
seq.push_back(std::make_pair(literal, true));
}
}
}
return join_seq();
};
return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
}
std::string _resolve_ref(const std::string & ref) {
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
_refs_being_resolved.insert(ref);
json resolved = _refs[ref];
ref_name = visit(resolved, ref_name);
_refs_being_resolved.erase(ref);
}
return ref_name;
}
std::string _build_object_rule(
const std::vector<std::pair<std::string, json>> & properties,
const std::unordered_set<std::string> & required,
const std::string & name,
const json & additional_properties)
{
std::vector<std::string> required_props;
std::vector<std::string> optional_props;
std::unordered_map<std::string, std::string> prop_kv_rule_names;
for (const auto & kv : properties) {
const auto &prop_name = kv.first;
const auto &prop_schema = kv.second;
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
prop_kv_rule_names[prop_name] = _add_rule(
name + (name.empty() ? "" : "-") + prop_name + "-kv",
format_literal(prop_name) + " space \":\" space " + prop_rule_name
);
if (required.find(prop_name) != required.end()) {
required_props.push_back(prop_name);
} else {
optional_props.push_back(prop_name);
}
}
if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get<bool>())) {
std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value");
std::string kv_rule = _add_rule(sub_name + "-kv", _add_rule("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
prop_kv_rule_names["*"] = kv_rule;
optional_props.push_back("*");
}
std::string rule = "\"{\" space ";
for (size_t i = 0; i < required_props.size(); i++) {
if (i > 0) {
rule += " \",\" space ";
}
rule += prop_kv_rule_names[required_props[i]];
}
if (!optional_props.empty()) {
rule += " (";
if (!required_props.empty()) {
rule += " \",\" space ( ";
}
std::function<std::string(const std::vector<std::string> &, bool)> get_recursive_refs = [&](const std::vector<std::string> & ks, bool first_is_optional) {
std::string res;
if (ks.empty()) {
return res;
}
std::string k = ks[0];
std::string kv_rule_name = prop_kv_rule_names[k];
if (k == "*") {
res = _add_rule(
name + (name.empty() ? "" : "-") + "additional-kvs",
kv_rule_name + " ( \",\" space " + kv_rule_name + " )*"
);
} else if (first_is_optional) {
res = "( \",\" space " + kv_rule_name + " )?";
} else {
res = kv_rule_name;
}
if (ks.size() > 1) {
res += " " + _add_rule(
name + (name.empty() ? "" : "-") + k + "-rest",
get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
);
}
return res;
};
for (size_t i = 0; i < optional_props.size(); i++) {
if (i > 0) {
rule += " | ";
}
rule += get_recursive_refs(std::vector<std::string>(optional_props.begin() + i, optional_props.end()), false);
}
if (!required_props.empty()) {
rule += " )";
}
rule += " )?";
}
rule += " \"}\" space";
return rule;
}
public:
SchemaConverter(
const std::function<json(const std::string &)> & fetch_json,
bool dotall)
: _fetch_json(fetch_json), _dotall(dotall)
{
_rules["space"] = SPACE_RULE;
}
void resolve_refs(nlohmann::json & schema, const std::string & url) {
/*
* Resolves all $ref fields in the given schema, fetching any remote schemas,
* replacing each $ref with absolute reference URL and populates _refs with the
* respective referenced (sub)schema dictionaries.
*/
std::function<void(json &)> visit_refs = [&](json & n) {
if (n.is_array()) {
for (auto & x : n) {
visit_refs(x);
}
} else if (n.is_object()) {
if (n.contains("$ref")) {
std::string ref = n["$ref"];
if (_refs.find(ref) == _refs.end()) {
json target;
if (ref.find("https://") == 0) {
std::string base_url = ref.substr(0, ref.find('#'));
auto it = _refs.find(base_url);
if (it != _refs.end()) {
target = it->second;
} else {
// Fetch the referenced schema and resolve its refs
auto referenced = _fetch_json(ref);
resolve_refs(referenced, base_url);
_refs[base_url] = referenced;
}
if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
return;
}
} else if (ref.find("#/") == 0) {
target = schema;
n["$ref"] = url + ref;
ref = url + ref;
} else {
_errors.push_back("Unsupported ref: " + ref);
return;
}
std::string pointer = ref.substr(ref.find('#') + 1);
std::vector<std::string> tokens = split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel];
}
_refs[ref] = target;
}
} else {
for (auto & kv : n.items()) {
visit_refs(kv.value());
}
}
}
};
visit_refs(schema);
}
std::string _generate_constant_rule(const json & value) {
if (!value.is_string()) {
_errors.push_back("Only std::string constants are supported, got " + value.dump());
return "";
}
return format_literal(value.get<std::string>());
}
std::string visit(const json & schema, const std::string & name) {
json schema_type = schema.contains("type") ? schema["type"] : json();
std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
if (schema.contains("$ref")) {
return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
} else if (schema_type.is_array()) {
std::vector<json> schema_types;
for (const auto & t : schema_type) {
schema_types.push_back({{"type", t}});
}
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
} else if (schema.contains("const")) {
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
} else if (schema.contains("enum")) {
std::vector<std::string> enum_values;
for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v));
}
return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | "));
} else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
std::unordered_set<std::string> required;
if (schema.contains("required") && schema["required"].is_array()) {
for (const auto & item : schema["required"]) {
if (item.is_string()) {
required.insert(item.get<std::string>());
}
}
}
std::vector<std::pair<std::string, json>> properties;
if (schema.contains("properties")) {
for (const auto & prop : schema["properties"].items()) {
properties.emplace_back(prop.key(), prop.value());
}
}
return _add_rule(rule_name,
_build_object_rule(
properties, required, name,
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
std::unordered_set<std::string> required;
std::vector<std::pair<std::string, json>> properties;
std::string hybrid_name = name;
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
if (comp_schema.contains("$ref")) {
add_component(_refs[comp_schema["$ref"]], is_required);
} else if (comp_schema.contains("properties")) {
for (const auto & prop : comp_schema["properties"].items()) {
properties.emplace_back(prop.key(), prop.value());
if (is_required) {
required.insert(prop.key());
}
}
} else {
// todo warning
}
};
for (auto & t : schema["allOf"]) {
if (t.contains("anyOf")) {
for (auto & tt : t["anyOf"]) {
add_component(tt, false);
}
} else {
add_component(t, true);
}
}
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
if (items.is_array()) {
std::string rule = "\"[\" space ";
for (size_t i = 0; i < items.size(); i++) {
if (i > 0) {
rule += " \",\" space ";
}
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
}
rule += " \"]\" space";
return _add_rule(rule_name, rule);
} else {
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
std::string list_item_operator = "( \",\" space " + item_rule_name + " )";
std::string successive_items;
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : -1;
if (min_items > 0) {
successive_items += repeat(list_item_operator, min_items - 1);
min_items--;
}
if (max_items >= 0 && max_items > min_items) {
successive_items += repeat(list_item_operator + "?", max_items - min_items - 1);
} else {
successive_items += list_item_operator + "*";
}
std::string rule;
if (min_items == 0) {
rule = "\"[\" space ( " + item_rule_name + " " + successive_items + " )? \"]\" space";
} else {
rule = "\"[\" space " + item_rule_name + " " + successive_items + " \"]\" space";
}
return _add_rule(rule_name, rule);
}
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
return _visit_pattern(schema["pattern"], rule_name);
} else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
return _add_rule(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
} else if ((schema_type.is_null() || schema_type == "string") && DATE_RULES.find(schema_format) != DATE_RULES.end()) {
for (const auto & kv : DATE_RULES) {
_add_rule(kv.first, kv.second);
}
return schema_format + "-string";
} else if (schema.empty() || schema_type == "object") {
for (const auto & n : OBJECT_RULE_NAMES) {
_add_rule(n, PRIMITIVE_RULES.at(n));
}
return _add_rule(rule_name, "object");
} else {
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
_errors.push_back("Unrecognized schema: " + schema.dump());
return "";
}
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return _add_rule(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
}
}
void check_errors() {
if (!_errors.empty()) {
throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
}
if (!_warnings.empty()) {
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
}
}
std::string format_grammar() {
std::stringstream ss;
for (const auto & kv : _rules) {
ss << kv.first << " ::= " << kv.second << std::endl;
}
return ss.str();
}
};
std::string json_schema_to_grammar(const json & schema) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
auto copy = schema;
converter.resolve_refs(copy, "input");
converter.visit(copy, "");
converter.check_errors();
return converter.format_grammar();
}

View file

@ -0,0 +1,4 @@
#pragma once
#include "json.hpp"
std::string json_schema_to_grammar(const nlohmann::json& schema);

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,74 @@
# Usage:
#! ./server -m some-model.gguf &
#! pip install pydantic
#! python json-schema-pydantic-example.py
from pydantic import BaseModel, TypeAdapter
from annotated_types import MinLen
from typing import Annotated, List, Optional
import json, requests
if True:
def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1/chat/completions", messages, **kwargs):
'''
Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support
(llama.cpp server, llama-cpp-python, Anyscale / Together...)
The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
'''
if response_model:
type_adapter = TypeAdapter(response_model)
schema = type_adapter.json_schema()
messages = [{
"role": "system",
"content": f"You respond in JSON format with the following schema: {json.dumps(schema, indent=2)}"
}] + messages
response_format={"type": "json_object", "schema": schema}
data = requests.post(endpoint, headers={"Content-Type": "application/json"},
json=dict(messages=messages, response_format=response_format, **kwargs)).json()
if 'error' in data:
raise Exception(data['error']['message'])
content = data["choices"][0]["message"]["content"]
return type_adapter.validate_json(content) if type_adapter else content
else:
# This alternative branch uses Instructor + OpenAI client lib.
# Instructor support streamed iterable responses, retry & more.
# (see https://python.useinstructor.com/)
#! pip install instructor openai
import instructor, openai
client = instructor.patch(
openai.OpenAI(api_key="123", base_url="http://localhost:8080"),
mode=instructor.Mode.JSON_SCHEMA)
create_completion = client.chat.completions.create
if __name__ == '__main__':
class QAPair(BaseModel):
question: str
concise_answer: str
justification: str
class PyramidalSummary(BaseModel):
title: str
summary: str
question_answers: Annotated[List[QAPair], MinLen(2)]
sub_sections: Optional[Annotated[List['PyramidalSummary'], MinLen(2)]]
print("# Summary\n", create_completion(
model="...",
response_model=PyramidalSummary,
messages=[{
"role": "user",
"content": f"""
You are a highly efficient corporate document summarizer.
Create a pyramidal summary of an imaginary internal document about our company processes
(starting high-level, going down to each sub sections).
Keep questions short, and answers even shorter (trivia / quizz style).
"""
}]))

View file

@ -1,8 +1,10 @@
#!/usr/bin/env python3
import argparse
import itertools
import json
import re
import sys
from typing import Any, Dict, List, Set, Tuple, Union
# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
@ -12,22 +14,50 @@ PRIMITIVE_RULES = {
'boolean': '("true" | "false") space',
'number': '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
'integer': '("-"? ([0-9] | [1-9] [0-9]*)) space',
'value' : 'object | array | string | number | boolean',
'object' : '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
'array' : '"[" space ( value ("," space value)* )? "]" space',
'uuid' : '"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space',
'string': r''' "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space''',
'null': '"null" space',
}
OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value']
# TODO: support "uri", "email" string formats
DATE_RULES = {
'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )',
'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
'date-time': 'date "T" time',
'date-string': '"\\"" date "\\"" space',
'time-string': '"\\"" time "\\"" space',
'date-time-string': '"\\"" date-time "\\"" space',
}
RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *DATE_RULES.keys()])
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"'}
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
DATE_PATTERN = '[0-9]{4}-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])'
TIME_PATTERN = '([01][0-9]|2[0-3])(:[0-5][0-9]){2}(\\.[0-9]{1,3})?(Z|[+-](([01][0-9]|2[0-3]):[0-5][0-9]))' # Cap millisecond precision w/ 3 digits
class SchemaConverter:
def __init__(self, prop_order):
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
self._prop_order = prop_order
self._allow_fetch = allow_fetch
self._dotall = dotall
self._raw_pattern = raw_pattern
self._rules = {'space': SPACE_RULE}
self._refs = {}
self._refs_being_resolved = set()
def _format_literal(self, literal):
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
@ -41,78 +71,421 @@ class SchemaConverter:
key = esc_name
else:
i = 0
while f'{esc_name}{i}' in self._rules:
while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
i += 1
key = f'{esc_name}{i}'
self._rules[key] = rule
return key
def resolve_refs(self, schema: dict, url: str):
'''
Resolves all $ref fields in the given schema, fetching any remote schemas,
replacing $ref with absolute reference URL and populating self._refs with the
respective referenced (sub)schema dictionaries.
'''
def visit(n: dict):
if isinstance(n, list):
return [visit(x) for x in n]
elif isinstance(n, dict):
ref = n.get('$ref')
if ref is not None and ref not in self._refs:
if ref.startswith('https://'):
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
import requests
frag_split = ref.split('#')
base_url = frag_split[0]
target = self._refs.get(base_url)
if target is None:
target = self.resolve_refs(requests.get(ref).json(), base_url)
self._refs[base_url] = target
if len(frag_split) == 1 or frag_split[-1] == '':
return target
elif ref.startswith('#/'):
target = schema
ref = f'{url}{ref}'
n['$ref'] = ref
else:
raise ValueError(f'Unsupported ref {ref}')
for sel in ref.split('#')[-1].split('/')[1:]:
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
self._refs[ref] = target
else:
for v in n.values():
visit(v)
return n
return visit(schema)
def _generate_union_rule(self, name, alt_schemas):
return ' | '.join((
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
for i, alt_schema in enumerate(alt_schemas)
))
def _visit_pattern(self, pattern, name):
'''
Transforms a regular expression pattern into a GBNF rule.
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
we define sub-rules to keep the output lean.
'''
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1]
sub_rule_ids = {}
i = 0
length = len(pattern)
def to_rule(s: Tuple[str, bool]) -> str:
(txt, is_literal) = s
return "\"" + txt + "\"" if is_literal else txt
def transform() -> Tuple[str, bool]:
'''
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
'''
nonlocal i
nonlocal pattern
nonlocal sub_rule_ids
start = i
# For each component of this sequence, store its string representation and whether it's a literal.
# We only need a flat structure here to apply repetition operators to the last item, and
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
# (GBNF's syntax is luckily very close to regular expressions!)
seq: list[Tuple[str, bool]] = []
def get_dot():
if self._dotall:
rule = '[\\U00000000-\\U0010FFFF]'
else:
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]'
return self._add_rule(f'dot', rule)
def join_seq():
nonlocal seq
ret = []
for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
if is_literal:
ret.append((''.join(x[0] for x in g), True))
else:
ret.extend(g)
if len(ret) == 1:
return ret[0]
return (' '.join(to_rule(x) for x in seq), False)
while i < length:
c = pattern[i]
if c == '.':
seq.append((get_dot(), False))
i += 1
elif c == '(':
i += 1
if i < length:
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
seq.append((f'({to_rule(transform())})', False))
elif c == ')':
i += 1
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
return join_seq()
elif c == '[':
square_brackets = c
i += 1
while i < length and pattern[i] != ']':
if pattern[i] == '\\':
square_brackets += pattern[i:i+2]
i += 2
else:
square_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
square_brackets += ']'
i += 1
seq.append((square_brackets, False))
elif c == '|':
seq.append(('|', False))
i += 1
elif c in ('*', '+', '?'):
seq[-1] = (to_rule(seq[-1]) + c, False)
i += 1
elif c == '{':
curly_brackets = c
i += 1
while i < length and pattern[i] != '}':
curly_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
curly_brackets += '}'
i += 1
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
min_times = 0
max_times = None
try:
if len(nums) == 1:
min_times = int(nums[0])
max_times = min_times
else:
assert len(nums) == 2
min_times = int(nums[0]) if nums[0] else 0
max_times = int(nums[1]) if nums[1] else None
except ValueError:
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
(sub, sub_is_literal) = seq[-1]
if min_times == 0 and max_times is None:
seq[-1] = (f'{sub}*', False)
elif min_times == 0 and max_times == 1:
seq[-1] = (f'{sub}?', False)
elif min_times == 1 and max_times is None:
seq[-1] = (f'{sub}+', False)
else:
if not sub_is_literal:
id = sub_rule_ids.get(sub)
if id is None:
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
sub_rule_ids[sub] = id
sub = id
seq[-1] = (
' '.join(
([f'"{sub[1:-1] * min_times}"'] if sub_is_literal else [sub] * min_times) +
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])),
False
)
else:
literal = ''
while i < length:
if pattern[i] == '\\' and i < length - 1:
next = pattern[i + 1]
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
i += 1
literal += pattern[i]
i += 1
else:
literal += pattern[i:i+2]
i += 2
elif pattern[i] == '"' and not self._raw_pattern:
literal += '\\"'
i += 1
elif pattern[i] not in NON_LITERAL_SET and \
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
literal += pattern[i]
i += 1
else:
break
if literal:
seq.append((literal, True))
return join_seq()
return self._add_rule(
name,
to_rule(transform()) if self._raw_pattern \
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
def _resolve_ref(self, ref):
ref_name = ref.split('/')[-1]
if ref_name not in self._rules and ref not in self._refs_being_resolved:
self._refs_being_resolved.add(ref)
resolved = self._refs[ref]
ref_name = self.visit(resolved, ref_name)
self._refs_being_resolved.remove(ref)
return ref_name
def _generate_constant_rule(self, value):
assert isinstance(value, str), f'Only string constants are supported, got {value}'
return self._format_literal(value)
def visit(self, schema, name):
schema_type = schema.get('type')
rule_name = name or 'root'
schema_format = schema.get('format')
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
if 'oneOf' in schema or 'anyOf' in schema:
rule = ' | '.join((
self.visit(alt_schema, f'{name}{"-" if name else ""}{i}')
for i, alt_schema in enumerate(schema.get('oneOf') or schema['anyOf'])
))
return self._add_rule(rule_name, rule)
if (ref := schema.get('$ref')) is not None:
return self._add_rule(rule_name, self._resolve_ref(ref))
elif 'oneOf' in schema or 'anyOf' in schema:
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
elif isinstance(schema_type, list):
return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))
elif 'const' in schema:
return self._add_rule(rule_name, self._format_literal(schema['const']))
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
elif 'enum' in schema:
rule = ' | '.join((self._format_literal(v) for v in schema['enum']))
rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
return self._add_rule(rule_name, rule)
elif schema_type == 'object' and 'properties' in schema:
# TODO: `required` keyword
prop_order = self._prop_order
prop_pairs = sorted(
schema['properties'].items(),
# sort by position in prop_order (if specified) then by key
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
)
elif schema_type in (None, 'object') and \
('properties' in schema or \
('additionalProperties' in schema and schema['additionalProperties'] is not True)):
required = set(schema.get('required', []))
properties = list(schema.get('properties', {}).items())
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
rule = '"{" space'
for i, (prop_name, prop_schema) in enumerate(prop_pairs):
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
if i > 0:
rule += ' "," space'
rule += fr' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
rule += ' "}" space'
elif schema_type in (None, 'object') and 'allOf' in schema:
required = set()
properties = []
hybrid_name = name
def add_component(comp_schema, is_required):
if (ref := comp_schema.get('$ref')) is not None:
comp_schema = self._refs[ref]
return self._add_rule(rule_name, rule)
if 'properties' in comp_schema:
for prop_name, prop_schema in comp_schema['properties'].items():
properties.append((prop_name, prop_schema))
if is_required:
required.add(prop_name)
elif schema_type == 'array' and 'items' in schema:
# TODO `prefixItems` keyword
item_rule_name = self.visit(schema['items'], f'{name}{"-" if name else ""}item')
for t in schema['allOf']:
if 'anyOf' in t:
for tt in t['anyOf']:
add_component(tt, is_required=False)
else:
add_component(t, is_required=True)
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems']
if isinstance(items, list):
return self._add_rule(
rule_name,
'"[" space ' +
' "," space '.join(
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
for i, item in enumerate(items)) +
' "]" space')
else:
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
list_item_operator = f'( "," space {item_rule_name} )'
successive_items = ""
min_items = schema.get("minItems", 0)
max_items = schema.get("maxItems")
if min_items > 0:
first_item = f"({item_rule_name})"
successive_items = list_item_operator * (min_items - 1)
min_items -= 1
else:
first_item = f"({item_rule_name})?"
max_items = schema.get("maxItems")
if max_items is not None and max_items > min_items:
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
else:
successive_items += list_item_operator + "*"
rule = f'"[" space {first_item} {successive_items} "]" space'
if min_items == 0:
rule = f'"[" space ( {item_rule_name} {successive_items} )? "]" space'
else:
rule = f'"[" space {item_rule_name} {successive_items} "]" space'
return self._add_rule(rule_name, rule)
elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name)
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
return self._add_rule(
'root' if rule_name == 'root' else schema_format,
PRIMITIVE_RULES['uuid']
)
elif schema_type in (None, 'string') and schema_format in DATE_RULES:
for t, r in DATE_RULES.items():
self._add_rule(t, r)
return schema_format + '-string'
elif (schema_type == 'object') or (len(schema) == 0):
for n in OBJECT_RULE_NAMES:
self._add_rule(n, PRIMITIVE_RULES[n])
return self._add_rule(rule_name, 'object')
else:
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return self._add_rule(
'root' if rule_name == 'root' else schema_type,
PRIMITIVE_RULES[schema_type]
)
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
prop_order = self._prop_order
# sort by position in prop_order (if specified) then by original order
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
prop_kv_rule_names = {}
for prop_name, prop_schema in properties:
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
prop_kv_rule_names[prop_name] = self._add_rule(
f'{name}{"-" if name else ""}{prop_name}-kv',
fr'{self._format_literal(prop_name)} space ":" space {prop_rule_name}'
)
required_props = [k for k in sorted_props if k in required]
optional_props = [k for k in sorted_props if k not in required]
if additional_properties == True or isinstance(additional_properties, dict):
sub_name = f'{name}{"-" if name else ""}additional'
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
prop_kv_rule_names["*"] = self._add_rule(
f'{sub_name}-kv',
self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
)
optional_props.append("*")
rule = '"{" space '
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
if optional_props:
rule += ' ('
if required_props:
rule += ' "," space ( '
def get_recursive_refs(ks, first_is_optional):
[k, *rest] = ks
kv_rule_name = prop_kv_rule_names[k]
if k == '*':
res = self._add_rule(
f'{name}{"-" if name else ""}additional-kvs',
f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
)
elif first_is_optional:
res = f'( "," space {kv_rule_name} )?'
else:
res = kv_rule_name
if len(rest) > 0:
res += ' ' + self._add_rule(
f'{name}{"-" if name else ""}{k}-rest',
get_recursive_refs(rest, first_is_optional=True)
)
return res
rule += ' | '.join(
get_recursive_refs(optional_props[i:], first_is_optional=False)
for i in range(len(optional_props))
)
if required_props:
rule += ' )'
rule += ' )?'
rule += ' "}" space'
return rule
def format_grammar(self):
return '\n'.join((f'{name} ::= {rule}' for name, rule in self._rules.items()))
return '\n'.join(
f'{name} ::= {rule}'
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
)
def main(args_in = None):
@ -129,16 +502,47 @@ def main(args_in = None):
type=lambda s: s.split(','),
help='''
comma-separated property names defining the order of precedence for object properties;
properties not specified here are given lower precedence than those that are, and are
sorted alphabetically
properties not specified here are given lower precedence than those that are, and
are kept in their original order from the schema. Required properties are always
given precedence over optional properties.
'''
)
parser.add_argument(
'--allow-fetch',
action='store_true',
default=False,
help='Whether to allow fetching referenced schemas over HTTPS')
parser.add_argument(
'--dotall',
action='store_true',
default=False,
help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
parser.add_argument(
'--raw-pattern',
action='store_true',
default=False,
help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')
parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
args = parser.parse_args(args_in)
schema = json.load(sys.stdin if args.schema == '-' else open(args.schema))
prop_order = {name: idx for idx, name in enumerate(args.prop_order)}
converter = SchemaConverter(prop_order)
if args.schema.startswith('https://'):
url = args.schema
import requests
schema = requests.get(url).json()
elif args.schema == '-':
url = 'stdin'
schema = json.load(sys.stdin)
else:
url = f'file://{args.schema}'
with open(args.schema) as f:
schema = json.load(f)
converter = SchemaConverter(
prop_order={name: idx for idx, name in enumerate(args.prop_order)},
allow_fetch=args.allow_fetch,
dotall=args.dotall,
raw_pattern=args.raw_pattern)
schema = converter.resolve_refs(schema, url)
converter.visit(schema, '')
print(converter.format_grammar())

View file

@ -249,6 +249,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "q5_1") {
return GGML_TYPE_Q5_1;
}
if (s == "iq4_nl") {
return GGML_TYPE_IQ4_NL;
}
return GGML_TYPE_COUNT;
}

View file

@ -1,11 +1,13 @@
# MobileVLM
Currently this implementation supports [MobileVLM-v1.7](https://huggingface.co/mtgv/MobileVLM-1.7B) variants.
Currently this implementation supports [MobileVLM-1.7B](https://huggingface.co/mtgv/MobileVLM-1.7B) / [MobileVLM_V2-1.7B](https://huggingface.co/mtgv/MobileVLM_V2-1.7B) variants.
for more information, please go to [Meituan-AutoML/MobileVLM](https://github.com/Meituan-AutoML/MobileVLM)
The implementation is based on llava, and is compatible with llava and mobileVLM. The usage is basically same as llava.
Notice: The overall process of model inference for both **MobileVLM** and **MobileVLM_V2** models is the same, but the process of model conversion is a little different. Therefore, using MobiVLM as an example, the different conversion step will be shown.
## Usage
Build with cmake or run `make llava-cli` to build it.
@ -34,7 +36,7 @@ git clone https://huggingface.co/openai/clip-vit-large-patch14-336
python ./examples/llava/llava-surgery.py -m path/to/MobileVLM-1.7B
```
3. Use `convert-image-encoder-to-gguf.py` with `--projector-type ldp` to convert the LLaVA image encoder to GGUF:
3. Use `convert-image-encoder-to-gguf.py` with `--projector-type ldp` (for **V2** the arg is `--projector-type ldpv2`) to convert the LLaVA image encoder to GGUF:
```sh
python ./examples/llava/convert-image-encoder-to-gguf \
@ -44,6 +46,14 @@ python ./examples/llava/convert-image-encoder-to-gguf \
--projector-type ldp
```
```sh
python ./examples/llava/convert-image-encoder-to-gguf \
-m path/to/clip-vit-large-patch14-336 \
--llava-projector path/to/MobileVLM-1.7B_V2/llava.projector \
--output-dir path/to/MobileVLM-1.7B_V2 \
--projector-type ldpv2
```
4. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF:
```sh

View file

@ -119,6 +119,7 @@ static std::string format(const char * fmt, ...) {
#define TN_LLAVA_PROJ "mm.%d.%s"
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
@ -126,12 +127,14 @@ enum projector_type {
PROJECTOR_TYPE_MLP,
PROJECTOR_TYPE_MLP_NORM,
PROJECTOR_TYPE_LDP,
PROJECTOR_TYPE_LDPV2,
PROJECTOR_TYPE_UNKNOWN,
};
static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_MLP, "mlp" },
{ PROJECTOR_TYPE_LDP, "ldp" },
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
};
@ -475,6 +478,14 @@ struct clip_vision_model {
struct ggml_tensor * mm_model_block_2_block_2_0_w;
struct ggml_tensor * mm_model_block_2_block_2_1_w;
struct ggml_tensor * mm_model_block_2_block_2_1_b;
// MobileVLM_V2 projection
struct ggml_tensor * mm_model_mlp_0_w;
struct ggml_tensor * mm_model_mlp_0_b;
struct ggml_tensor * mm_model_mlp_2_w;
struct ggml_tensor * mm_model_mlp_2_b;
struct ggml_tensor * mm_model_peg_0_w;
struct ggml_tensor * mm_model_peg_0_b;
};
struct clip_ctx {
@ -807,6 +818,29 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
}
embeddings = block_1;
}
else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2)
{
int n_patch = 24;
struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
mlp_0 = ggml_gelu(ctx0, mlp_0);
struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
// mlp_2 ne = [2048, 576, 1, 1]
// // AVG Pool Layer 2*2, strides = 2
mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
// mlp_2 ne = [576, 2048, 1, 1]
mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
// mlp_2 ne [24, 24, 2048, 1]
mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
// weight ne = [3, 3, 2048, 1]
struct ggml_tensor * peg_0 = ggml_conv_depthwise_2d(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
peg_0 = ggml_add(ctx0, peg_0, mlp_2);
peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
embeddings = peg_0;
}
else {
GGML_ASSERT(false);
}
@ -1177,7 +1211,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
vision_model.mm_model_block_2_block_2_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
vision_model.mm_model_block_2_block_2_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
vision_model.mm_model_block_2_block_2_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
} else {
}
else if (new_clip->proj_type == PROJECTOR_TYPE_LDPV2)
{
// MobilVLM_V2 projection
vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "weight"));
vision_model.mm_model_mlp_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "bias"));
vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "weight"));
vision_model.mm_model_mlp_2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "bias"));
vision_model.mm_model_peg_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "weight"));
vision_model.mm_model_peg_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "bias"));
}
else {
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
}
@ -1966,6 +2011,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
}
if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
return ctx->vision_model.mm_model_peg_0_b->ne[0];
}
if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
return ctx->vision_model.mm_2_b->ne[0];
}

View file

@ -1,6 +1,7 @@
import argparse
import os
import json
import re
import torch
import numpy as np
@ -38,9 +39,11 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b
def get_tensor_name(name: str) -> str:
if "projection" in name:
return name
if "mm_projector" in name:
return name.replace("model.mm_projector", "mm")
name = name.replace("model.mm_projector", "mm")
name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
return name
return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
@ -83,7 +86,7 @@ ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
help="The clip model is from openclip (for ViT-SO400M type))")
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp", choices=["mlp", "ldp"], default="mlp")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5

View file

@ -0,0 +1,20 @@
import json, subprocess, sys, os
assert len(sys.argv) >= 2
[_, pattern, *rest] = sys.argv
print(subprocess.check_output(
[
"python",
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"json-schema-to-grammar.py"),
*rest,
"-",
"--raw-pattern",
],
text=True,
input=json.dumps({
"type": "string",
"pattern": pattern,
}, indent=2)))

View file

@ -2,12 +2,16 @@ set(TARGET server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
add_executable(${TARGET}
server.cpp
utils.hpp
httplib.h
)
install(TARGETS ${TARGET} RUNTIME)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common json-schema-to-grammar ${CMAKE_THREAD_LIBS_INIT})
if (LLAMA_SERVER_SSL)
find_package(OpenSSL REQUIRED)
target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)

View file

@ -26,8 +26,9 @@ const propOrder = grammarJsonSchemaPropOrder
let grammar = null
if (grammarJsonSchemaFile) {
const schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8'))
const converter = new SchemaConverter(propOrder)
let schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8'))
const converter = new SchemaConverter({prop_order: propOrder, allow_fetch: true})
schema = await converter.resolveRefs(schema, grammarJsonSchemaFile)
converter.visit(schema, '')
grammar = converter.formatGrammar()
}

View file

@ -483,4 +483,4 @@ unsigned char completion_js[] = {
0x20, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f,
0x73, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x3b, 0x0a, 0x7d, 0x0a
};
unsigned int completion_js_len = 5796;
size_t completion_js_len = 5796;

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -630,14 +630,16 @@
const grammarJsonSchemaPropOrder = signal('')
const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value
const convertJSONSchemaGrammar = () => {
const convertJSONSchemaGrammar = async () => {
try {
const schema = JSON.parse(params.value.grammar)
const converter = new SchemaConverter(
grammarJsonSchemaPropOrder.value
let schema = JSON.parse(params.value.grammar)
const converter = new SchemaConverter({
prop_order: grammarJsonSchemaPropOrder.value
.split(',')
.reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {})
)
.reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}),
allow_fetch: true,
})
schema = await converter.resolveRefs(schema, 'input')
converter.visit(schema, '')
params.value = {
...params.value,

File diff suppressed because one or more lines are too long

View file

@ -1,25 +1,50 @@
// WARNING: This file was ported from json-schema-to-grammar.py, please fix bugs / add features there first.
const SPACE_RULE = '" "?';
const PRIMITIVE_RULES = {
boolean: '("true" | "false") space',
number: '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
integer: '("-"? ([0-9] | [1-9] [0-9]*)) space',
value: 'object | array | string | number | boolean',
object: '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
array: '"[" space ( value ("," space value)* )? "]" space',
uuid: '"\\"" ' + [8, 4, 4, 4, 12].map(n => [...new Array(n)].map(_ => '[0-9a-fA-F]').join('')).join(' "-" ') + ' "\\"" space',
string: ` "\\"" (
[^"\\\\] |
"\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\\"" space`,
null: '"null" space',
};
const OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value'];
// TODO: support "uri", "email" string formats
const DATE_RULES = {
'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )',
'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
'date-time': 'date "T" time',
'date-string': '"\\"" date "\\"" space',
'time-string': '"\\"" time "\\"" space',
'date-time-string': '"\\"" date-time "\\"" space',
};
const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...DATE_RULES};
const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g;
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g;
const GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"'};
const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g;
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' };
const NON_LITERAL_SET = new Set('|.()[]{}*+?');
const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('[]()|{}*+?');
export class SchemaConverter {
constructor(propOrder) {
this._propOrder = propOrder || {};
this._rules = new Map();
this._rules.set('space', SPACE_RULE);
constructor(options) {
this._propOrder = options.prop_order || {};
this._allowFetch = options.allow_fetch || false;
this._dotall = options.dotall || false;
this._rules = {'space': SPACE_RULE};
this._refs = {};
this._refsBeingResolved = new Set();
}
_formatLiteral(literal) {
@ -30,83 +55,490 @@ export class SchemaConverter {
return `"${escaped}"`;
}
_formatRangeChar(literal) {
return JSON.stringify(literal).slice(1, -1).replace(
GRAMMAR_RANGE_LITERAL_ESCAPE_RE,
m => GRAMMAR_LITERAL_ESCAPES[m]
);
}
_addRule(name, rule) {
let escName = name.replace(INVALID_RULE_CHARS_RE, '-');
let key = escName;
if (this._rules.has(escName)) {
if (this._rules.get(escName) === rule) {
if (escName in this._rules) {
if (this._rules[escName] === rule) {
return key;
}
let i = 0;
while (this._rules.has(`${escName}${i}`)) {
while ((`${escName}${i}` in this._rules) && (this._rules[`${escName}${i}`] !== rule)) {
i += 1;
}
key = `${escName}${i}`;
}
this._rules.set(key, rule);
this._rules[key] = rule;
return key;
}
async resolveRefs(schema, url) {
const visit = async (n) => {
if (Array.isArray(n)) {
return Promise.all(n.map(visit));
} else if (typeof n === 'object' && n !== null) {
let ref = n.$ref;
let target;
if (ref !== undefined && !this._refs[ref]) {
if (ref.startsWith('https://')) {
if (!this._allowFetch) {
throw new Error('Fetching remote schemas is not allowed (use --allow-fetch for force)');
}
const fetch = (await import('node-fetch')).default;
const fragSplit = ref.split('#');
const baseUrl = fragSplit[0];
target = this._refs[baseUrl];
if (!target) {
target = await this.resolveRefs(await fetch(ref).then(res => res.json()), baseUrl);
this._refs[baseUrl] = target;
}
if (fragSplit.length === 1 || fragSplit[fragSplit.length - 1] === '') {
return target;
}
} else if (ref.startsWith('#/')) {
target = schema;
ref = `${url}${ref}`;
n.$ref = ref;
} else {
throw new Error(`Unsupported ref ${ref}`);
}
const selectors = ref.split('#')[1].split('/').slice(1);
for (const sel of selectors) {
if (!target || !(sel in target)) {
throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`);
}
target = target[sel];
}
this._refs[ref] = target;
} else {
await Promise.all(Object.values(n).map(visit));
}
}
return n;
};
return visit(schema);
}
_generateUnionRule(name, altSchemas) {
return altSchemas
.map((altSchema, i) => this.visit(altSchema, `${name ?? ''}${name ? '-' : 'alternative-'}${i}`))
.join(' | ');
}
_visitPattern(pattern, name) {
if (!pattern.startsWith('^') || !pattern.endsWith('$')) {
throw new Error('Pattern must start with "^" and end with "$"');
}
pattern = pattern.slice(1, -1);
const subRuleIds = {};
let i = 0;
const length = pattern.length;
const getDot = () => {
let rule;
if (this._dotall) {
rule = '[\\U00000000-\\U0010FFFF]';
} else {
// Accept any character... except \n and \r line break chars (\x0A and \xOD)
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]';
}
return this._addRule('dot', rule);
};
const toRule = ([s, isLiteral]) => isLiteral ? "\"" + s + "\"" : s;
const transform = () => {
const start = i;
// For each component of this sequence, store its string representation and whether it's a literal.
// We only need a flat structure here to apply repetition operators to the last item, and
// to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
// (GBNF's syntax is luckily very close to regular expressions!)
const seq = [];
const joinSeq = () => {
const ret = [];
for (const [isLiteral, g] of groupBy(seq, x => x[1])) {
if (isLiteral) {
ret.push([[...g].map(x => x[0]).join(''), true]);
} else {
ret.push(...g);
}
}
if (ret.length === 1) {
return ret[0];
}
return [ret.map(x => toRule(x)).join(' '), false];
};
while (i < length) {
const c = pattern[i];
if (c === '.') {
seq.push([getDot(), false]);
i += 1;
} else if (c === '(') {
i += 1;
if (i < length) {
if (pattern[i] === '?') {
throw new Error(`Unsupported pattern syntax "${pattern[i]}" at index ${i} of /${pattern}/`);
}
}
seq.push([`(${toRule(transform())})`, false]);
} else if (c === ')') {
i += 1;
if (start <= 0 || pattern[start - 1] !== '(') {
throw new Error(`Unbalanced parentheses; start = ${start}, i = ${i}, pattern = ${pattern}`);
}
return joinSeq();
} else if (c === '[') {
let squareBrackets = c;
i += 1;
while (i < length && pattern[i] !== ']') {
if (pattern[i] === '\\') {
squareBrackets += pattern.slice(i, i + 2);
i += 2;
} else {
squareBrackets += pattern[i];
i += 1;
}
}
if (i >= length) {
throw new Error(`Unbalanced square brackets; start = ${start}, i = ${i}, pattern = ${pattern}`);
}
squareBrackets += ']';
i += 1;
seq.push([squareBrackets, false]);
} else if (c === '|') {
seq.push(['|', false]);
i += 1;
} else if (c === '*' || c === '+' || c === '?') {
seq[seq.length - 1] = [toRule(seq[seq.length - 1]) + c, false];
i += 1;
} else if (c === '{') {
let curlyBrackets = c;
i += 1;
while (i < length && pattern[i] !== '}') {
curlyBrackets += pattern[i];
i += 1;
}
if (i >= length) {
throw new Error(`Unbalanced curly brackets; start = ${start}, i = ${i}, pattern = ${pattern}`);
}
curlyBrackets += '}';
i += 1;
const nums = curlyBrackets.slice(1, -1).split(',').map(s => s.trim());
let minTimes, maxTimes;
if (nums.length === 1) {
minTimes = parseInt(nums[0], 10);
maxTimes = minTimes;
} else {
if (nums.length !== 2) {
throw new Error(`Invalid quantifier ${curlyBrackets}`);
}
minTimes = nums[0] ? parseInt(nums[0], 10) : 0;
maxTimes = nums[1] ? parseInt(nums[1], 10) : Infinity;
}
let [sub, subIsLiteral] = seq[seq.length - 1];
if (minTimes === 0 && maxTimes === Infinity) {
seq[seq.length - 1] = [`${sub}*`, false];
} else if (minTimes === 0 && maxTimes === 1) {
seq[seq.length - 1] = [`${sub}?`, false];
} else if (minTimes === 1 && maxTimes === Infinity) {
seq[seq.length - 1] = [`${sub}+`, false];
} else {
if (!subIsLiteral) {
let id = subRuleIds[sub];
if (id === undefined) {
id = this._addRule(`${name}-${Object.keys(subRuleIds).length + 1}`, sub);
subRuleIds[sub] = id;
}
sub = id;
}
const repeatedSub = Array.from({ length: minTimes }, () => subIsLiteral ? `"${sub.slice(1, -1).repeat(minTimes)}"` : sub);
const optionalSub = maxTimes !== undefined ? Array.from({ length: maxTimes - minTimes }, () => `${sub}?`) : [`${sub}*`];
seq[seq.length - 1] = [repeatedSub.concat(optionalSub).join(' '), false];
}
} else {
let literal = '';
while (i < length) {
if (pattern[i] === '\\' && i < length - 1) {
const next = pattern[i + 1];
if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.has(next)) {
i += 1;
literal += pattern[i];
i += 1;
} else {
literal += pattern.slice(i, i + 2);
i += 2;
}
} else if (pattern[i] === '"') {
literal += '\\"';
i += 1;
} else if (!NON_LITERAL_SET.has(pattern[i]) &&
(i === length - 1 || literal === '' || pattern[i + 1] === '.' || !NON_LITERAL_SET.has(pattern[i+1]))) {
literal += pattern[i];
i += 1;
} else {
break;
}
}
if (literal !== '') {
seq.push([literal, true]);
}
}
}
return joinSeq();
};
return this._addRule(name, "\"\\\"\" " + toRule(transform()) + " \"\\\"\" space")
}
_resolveRef(ref) {
let refName = ref.split('/').pop();
if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) {
this._refsBeingResolved.add(ref);
const resolved = this._refs[ref];
refName = this.visit(resolved, refName);
this._refsBeingResolved.delete(ref);
}
return refName;
}
_generateConstantRule(value) {
if (typeof value !== 'string') {
throw new Error('Only string constants are supported, got ' + JSON.stringify(value));
}
return this._formatLiteral(value);
}
visit(schema, name) {
const schemaType = schema.type;
const ruleName = name || 'root';
const schemaFormat = schema.format;
const ruleName = name in RESERVED_NAMES ? name + '-' : name == '' ? 'root' : name;
if (schema.oneOf || schema.anyOf) {
const rule = (schema.oneOf || schema.anyOf).map((altSchema, i) =>
this.visit(altSchema, `${name}${name ? "-" : ""}${i}`)
).join(' | ');
return this._addRule(ruleName, rule);
const ref = schema.$ref;
if (ref !== undefined) {
return this._addRule(ruleName, this._resolveRef(ref));
} else if (schema.oneOf || schema.anyOf) {
return this._addRule(ruleName, this._generateUnionRule(name, schema.oneOf || schema.anyOf));
} else if (Array.isArray(schemaType)) {
return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({ type: t }))));
} else if ('const' in schema) {
return this._addRule(ruleName, this._formatLiteral(schema.const));
} else if ('enum' in schema) {
const rule = schema.enum.map(v => this._formatLiteral(v)).join(' | ');
return this._addRule(ruleName, rule);
} else if (schemaType === 'object' && 'properties' in schema) {
// TODO: `required` keyword (from python implementation)
const propOrder = this._propOrder;
const propPairs = Object.entries(schema.properties).sort((a, b) => {
// sort by position in prop_order (if specified) then by key
const orderA = typeof propOrder[a[0]] === 'number' ? propOrder[a[0]] : Infinity;
const orderB = typeof propOrder[b[0]] === 'number' ? propOrder[b[0]] : Infinity;
return orderA - orderB || a[0].localeCompare(b[0]);
});
let rule = '"{" space';
propPairs.forEach(([propName, propSchema], i) => {
const propRuleName = this.visit(propSchema, `${name}${name ? "-" : ""}${propName}`);
if (i > 0) {
rule += ' "," space';
if (typeof schema.const !== 'string') {
throw new Error('Only string constants are supported, got ' + JSON.stringify(schema.const));
}
return this._addRule(ruleName, this._generateConstantRule(schema.const));
} else if ('enum' in schema) {
const rule = schema.enum.map(v => this._generateConstantRule(v)).join(' | ');
return this._addRule(ruleName, rule);
} else if ((schemaType === undefined || schemaType === 'object') &&
('properties' in schema ||
('additionalProperties' in schema && schema.additionalProperties !== true))) {
const required = new Set(schema.required || []);
const properties = Object.entries(schema.properties ?? {});
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
} else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
const required = new Set();
const properties = [];
const addComponent = (compSchema, isRequired) => {
const ref = compSchema.$ref;
if (ref !== undefined) {
compSchema = this._refs[ref];
}
rule += ` ${this._formatLiteral(propName)} space ":" space ${propRuleName}`;
});
rule += ' "}" space';
return this._addRule(ruleName, rule);
} else if (schemaType === 'array' && 'items' in schema) {
// TODO `prefixItems` keyword (from python implementation)
const itemRuleName = this.visit(schema.items, `${name}${name ? "-" : ""}item`);
const rule = `"[" space (${itemRuleName} ("," space ${itemRuleName})*)? "]" space`;
return this._addRule(ruleName, rule);
if ('properties' in compSchema) {
for (const [propName, propSchema] of Object.entries(compSchema.properties)) {
properties.push([propName, propSchema]);
if (isRequired) {
required.add(propName);
}
}
}
};
for (const t of schema.allOf) {
if ('anyOf' in t) {
for (const tt of t.anyOf) {
addComponent(tt, false);
}
} else {
if (!PRIMITIVE_RULES[schemaType]) {
addComponent(t, true);
}
}
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, /* additionalProperties= */ false));
} else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
const items = schema.items ?? schema.prefixItems;
if (Array.isArray(items)) {
return this._addRule(
ruleName,
'"[" space ' +
items.map((item, i) => this.visit(item, `${name ?? ''}${name ? '-' : ''}tuple-${i}`)).join(' "," space ') +
' "]" space'
);
} else {
const itemRuleName = this.visit(items, `${name ?? ''}${name ? '-' : ''}item`);
const listItemOperator = `( "," space ${itemRuleName} )`;
let successiveItems = '';
let minItems = schema.minItems || 0;
const maxItems = schema.maxItems;
if (minItems > 0) {
successiveItems = listItemOperator.repeat(minItems - 1);
minItems--;
}
if (maxItems !== undefined && maxItems > minItems) {
successiveItems += `${listItemOperator}?`.repeat(maxItems - minItems - 1);
} else {
successiveItems += `${listItemOperator}*`;
}
const rule = minItems === 0
? `"[" space ( ${itemRuleName} ${successiveItems} )? "]" space`
: `"[" space ${itemRuleName} ${successiveItems} "]" space`;
return this._addRule(ruleName, rule);
}
} else if ((schemaType === undefined || schemaType === 'string') && 'pattern' in schema) {
return this._visitPattern(schema.pattern, ruleName);
} else if ((schemaType === undefined || schemaType === 'string') && /^uuid[1-5]?$/.test(schema.format || '')) {
return this._addRule(
ruleName === 'root' ? 'root' : schemaFormat,
PRIMITIVE_RULES['uuid'])
} else if ((schemaType === undefined || schemaType === 'string') && schema.format in DATE_RULES) {
for (const [t, r] of Object.entries(DATE_RULES)) {
this._addRule(t, r);
}
return schemaFormat + '-string';
} else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) {
for (const n of OBJECT_RULE_NAMES) {
this._addRule(n, PRIMITIVE_RULES[n]);
}
return this._addRule(ruleName, 'object');
} else {
if (!(schemaType in PRIMITIVE_RULES)) {
throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`);
}
return this._addRule(
ruleName === 'root' ? 'root' : schemaType,
PRIMITIVE_RULES[schemaType]
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return this._addRule(ruleName === 'root' ? 'root' : schemaType, PRIMITIVE_RULES[schemaType]);
}
}
_buildObjectRule(properties, required, name, additionalProperties) {
const propOrder = this._propOrder;
// sort by position in prop_order (if specified) then by original order
const sortedProps = properties.map(([k]) => k).sort((a, b) => {
const orderA = propOrder[a] || Infinity;
const orderB = propOrder[b] || Infinity;
return orderA - orderB || properties.findIndex(([k]) => k === a) - properties.findIndex(([k]) => k === b);
});
const propKvRuleNames = {};
for (const [propName, propSchema] of properties) {
const propRuleName = this.visit(propSchema, `${name ?? ''}${name ? '-' : ''}${propName}`);
propKvRuleNames[propName] = this._addRule(
`${name ?? ''}${name ? '-' : ''}${propName}-kv`,
`${this._formatLiteral(propName)} space ":" space ${propRuleName}`
);
}
const requiredProps = sortedProps.filter(k => required.has(k));
const optionalProps = sortedProps.filter(k => !required.has(k));
if (typeof additionalProperties === 'object' || additionalProperties === true) {
const subName = `${name ?? ''}${name ? '-' : ''}additional`;
const valueRule = this.visit(additionalProperties === true ? {} : additionalProperties, `${subName}-value`);
propKvRuleNames['*'] = this._addRule(
`${subName}-kv`,
`${this._addRule('string', PRIMITIVE_RULES['string'])} ":" space ${valueRule}`);
optionalProps.push('*');
}
let rule = '"{" space ';
rule += requiredProps.map(k => propKvRuleNames[k]).join(' "," space ');
if (optionalProps.length > 0) {
rule += ' (';
if (requiredProps.length > 0) {
rule += ' "," space ( ';
}
const getRecursiveRefs = (ks, firstIsOptional) => {
const [k, ...rest] = ks;
const kvRuleName = propKvRuleNames[k];
let res;
if (k === '*') {
res = this._addRule(
`${name ?? ''}${name ? '-' : ''}additional-kvs`,
`${kvRuleName} ( "," space ` + kvRuleName + ` )*`
)
} else if (firstIsOptional) {
res = `( "," space ${kvRuleName} )?`;
} else {
res = kvRuleName;
}
if (rest.length > 0) {
res += ' ' + this._addRule(
`${name ?? ''}${name ? '-' : ''}${k}-rest`,
getRecursiveRefs(rest, true)
);
}
return res;
};
rule += optionalProps.map((_, i) => getRecursiveRefs(optionalProps.slice(i), false)).join(' | ');
if (requiredProps.length > 0) {
rule += ' )';
}
rule += ' )?';
}
rule += ' "}" space';
return rule;
}
formatGrammar() {
let grammar = '';
this._rules.forEach((rule, name) => {
for (const [name, rule] of Object.entries(this._rules).sort(([a], [b]) => a.localeCompare(b))) {
grammar += `${name} ::= ${rule}\n`;
});
}
return grammar;
}
}
// Helper function to group elements by a key function
function* groupBy(iterable, keyFn) {
let lastKey = null;
let group = [];
for (const element of iterable) {
const key = keyFn(element);
if (lastKey !== null && key !== lastKey) {
yield [lastKey, group];
group = [];
}
group.push(element);
lastKey = key;
}
if (group.length > 0) {
yield [lastKey, group];
}
}

View file

@ -1,6 +1,7 @@
#include "utils.hpp"
#include "common.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "grammar-parser.h"
@ -178,6 +179,7 @@ struct server_slot {
llama_token sampled;
struct llama_sampling_params sparams;
llama_sampling_context * ctx_sampling = nullptr;
json json_schema;
int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1; // group-attention factor
@ -845,7 +847,17 @@ struct server_context {
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.seed = json_value(data, "seed", default_params.seed);
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
slot.sparams.grammar = json_schema_to_grammar(schema);
} catch (const std::exception & e) {
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
return false;
}
} else {
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
}
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

View file

@ -37,6 +37,22 @@ Feature: Security
| llama.cpp | no |
| hackme | raised |
Scenario Outline: OAI Compatibility (invalid response formats)
Given a system prompt test
And a user prompt test
And a response format <response_format>
And a model test
And 2 max tokens to predict
And streaming is disabled
Given an OAI compatible chat completions request with raised api error
Examples: Prompts
| response_format |
| {"type": "sound"} |
| {"type": "json_object", "schema": 123} |
| {"type": "json_object", "schema": {"type": 123}} |
| {"type": "json_object", "schema": {"type": "hiccup"}} |
Scenario Outline: CORS Options
Given a user api key llama.cpp

View file

@ -70,6 +70,22 @@ Feature: llama.cpp server
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird\|Annabyear)+ | -1 | 64 | enabled | |
Scenario Outline: OAI Compatibility w/ response format
Given a model test
And a system prompt test
And a user prompt test
And a response format <response_format>
And 10 max tokens to predict
Given an OAI compatible chat completions request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
Examples: Prompts
| response_format | n_predicted | re_content |
| {"type": "json_object", "schema": {"const": "42"}} | 5 | "42" |
| {"type": "json_object", "schema": {"items": [{"type": "integer"}]}} | 10 | \[ -300 \] |
| {"type": "json_object"} | 10 | \{ " Jacky. |
Scenario: Tokenize / Detokenize
When tokenizing:
"""

View file

@ -24,12 +24,16 @@ from prometheus_client import parser
def step_server_config(context, server_fqdn, server_port):
context.server_fqdn = server_fqdn
context.server_port = int(server_port)
context.n_gpu_layer = None
if 'PORT' in os.environ:
context.server_port = int(os.environ['PORT'])
print(f"$PORT set, overriding server port with to {context.server_port}")
if 'FQDN' in os.environ:
context.server_fqdn = os.environ['FQDN']
print(f"$FQDN set, overriding server fqdn with to {context.server_fqdn}")
if 'N_GPU_LAYERS' in os.environ:
context.n_gpu_layer = int(os.environ['N_GPU_LAYERS'])
print(f"$N_GPU_LAYERS set, overriding n_gpu_layer with to {context.n_gpu_layer}")
context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
@ -41,7 +45,6 @@ def step_server_config(context, server_fqdn, server_port):
context.n_ctx = None
context.n_ga = None
context.n_ga_w = None
context.n_gpu_layer = None
context.n_predict = None
context.n_prompts = 0
context.n_server_predict = None
@ -56,6 +59,7 @@ def step_server_config(context, server_fqdn, server_port):
context.seed = None
context.server_seed = None
context.user_api_key = None
context.response_format = None
context.tasks_result = []
context.concurrent_tasks = []
@ -266,6 +270,11 @@ def step_max_tokens(context, max_tokens):
context.n_predict = max_tokens
@step('a response format {response_format}')
def step_response_format(context, response_format):
context.response_format = json.loads(response_format)
@step('streaming is {enable_streaming}')
def step_streaming(context, enable_streaming):
context.enable_streaming = enable_streaming == 'enabled'
@ -381,6 +390,9 @@ async def step_oai_chat_completions(context, api_error):
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
seed=await completions_seed(context),
user_api_key=context.user_api_key
@ -440,6 +452,8 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
seed=await completions_seed(context),
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)
@ -460,6 +474,8 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
response_format=context.response_format
if hasattr(context, 'response_format') else None,
seed=context.seed
if hasattr(context, 'seed') else
context.server_seed
@ -742,6 +758,7 @@ async def oai_chat_completions(user_prompt,
model=None,
n_predict=None,
enable_streaming=None,
response_format=None,
seed=None,
user_api_key=None,
expect_api_error=None):
@ -767,6 +784,8 @@ async def oai_chat_completions(user_prompt,
"stream": enable_streaming,
"seed": seed
}
if response_format is not None:
payload['response_format'] = response_format
completion_response = {
'content': '',
'timings': {
@ -827,6 +846,7 @@ async def oai_chat_completions(user_prompt,
model=model,
max_tokens=n_predict,
stream=enable_streaming,
response_format=payload.get('response_format'),
seed=seed
)
except openai.error.AuthenticationError as e:

View file

@ -373,10 +373,21 @@ static json oaicompat_completion_params_parse(
llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z);
llama_params["n_keep"] = json_value(body, "n_keep", 0);
if (body.count("grammar") != 0) {
if (body.contains("grammar")) {
llama_params["grammar"] = json_value(body, "grammar", json::object());
}
if (body.contains("response_format")) {
auto response_format = json_value(body, "response_format", json::object());
if (response_format.contains("type")) {
if (response_format["type"] == "json_object") {
llama_params["json_schema"] = json_value(response_format, "schema", json::object());
} else {
throw std::runtime_error("response_format type not supported: " + response_format["type"].dump());
}
}
}
// Handle 'stop' field
if (body.contains("stop") && body["stop"].is_string()) {
llama_params["stop"] = json::array({body["stop"].get<std::string>()});

28
examples/ts-type-to-grammar.sh Executable file
View file

@ -0,0 +1,28 @@
#!/bin/bash
#
# ./examples/ts-type-to-grammar.sh "{a:string,b:string,c?:string}"
# python examples/json-schema-to-grammar.py https://json.schemastore.org/tsconfig.json
#
set -euo pipefail
readonly type="$1"
# Create a temporary directory
TMPDIR=""
trap 'rm -fR "$TMPDIR"' EXIT
TMPDIR=$(mktemp -d)
DTS_FILE="$TMPDIR/type.d.ts"
SCHEMA_FILE="$TMPDIR/schema.json"
echo "export type MyType = $type" > "$DTS_FILE"
# This is a fork of typescript-json-schema, actively maintained as of March 2024:
# https://github.com/vega/ts-json-schema-generator
npx ts-json-schema-generator --unstable --no-top-ref --path "$DTS_FILE" --type MyType -e none > "$SCHEMA_FILE"
# Alternative, not actively maintained as of March 2024:
# https://github.com/YousefED/typescript-json-schema
# npx typescript-json-schema --defaultProps --required "$DTS_FILE" MyType | tee "$SCHEMA_FILE" >&2
./examples/json-schema-to-grammar.py "$SCHEMA_FILE"

File diff suppressed because it is too large Load diff

View file

@ -173,8 +173,9 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
GGML_METAL_KERNEL_TYPE_CONCAT,
@ -598,8 +599,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
@ -739,6 +741,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
return true;
default:
return false;
@ -2436,8 +2441,9 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
//case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;

View file

@ -2388,6 +2388,242 @@ kernel void kernel_cpy_f32_q4_1(
}
}
kernel void kernel_cpy_f32_q5_0(
device const float * src0,
device void * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK5_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -16;
const float id = d ? 1.0f/d : 0.0f;
dst_data[i00/QK5_0].d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK5_0/2 + j]*id;
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst_data[i00/QK5_0].qh[j] = qh8[j];
}
}
}
kernel void kernel_cpy_f32_q5_1(
device const float * src0,
device void * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float max = src[0];
float min = src[0];
for (int j = 1; j < QK5_1; j++) {
const float v = src[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dst_data[i00/QK5_1].d = d;
dst_data[i00/QK5_1].m = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst_data[i00/QK5_1].qh[j] = qh8[j];
}
}
}
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
kernel void kernel_cpy_f32_iq4_nl(
device const float * src0,
device void * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / kvalues_iq4nl_f[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl_f[xi0];
const float v1 = kvalues_iq4nl_f[xi1];
const float w0 = src[0 + j]*src[0 + j];
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
}
}
kernel void kernel_concat(
device const char * src0,
device const char * src1,
@ -4220,10 +4456,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
}
}
constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
void kernel_mul_mv_iq4_nl_f32_impl(
device const void * src0,
device const float * src1,

View file

@ -11705,9 +11705,8 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
float * scales, float * weight, uint8_t * L,
const int8_t * values,
const float * quant_weights) {
const int ntry = 7;
const float * quant_weights,
const int ntry) {
float sigma2 = 0;
for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
@ -11719,6 +11718,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
float max_scale = 0, amax_scale = 0;
for (int ib = 0; ib < super_block_size/block_size; ++ib) {
const float * xb = x + ib*block_size;
uint8_t * Lb = L + ib*block_size;
if (quant_weights) {
const float * qw = quant_weights + ib*block_size;
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
@ -11736,12 +11736,13 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
scales[ib] = 0;
continue;
}
float d = -max/values[0];
float d = ntry > 0 ? -max/values[0] : max/values[0];
float id = 1/d;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < block_size; ++j) {
float al = id*xb[j];
int l = best_index_int8(16, values, al);
Lb[j] = l;
float q = values[l];
float w = weight[j];
sumqx += w*q*xb[j];
@ -11796,11 +11797,13 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
}
} else {
dh[0] = GGML_FP32_TO_FP16(scales[0]);
if (ntry > 0) {
float id = scales[0] ? 1/scales[0] : 0;
for (int j = 0; j < super_block_size; ++j) {
L[j] = best_index_int8(16, values, id*x[j]);
}
}
}
for (int i = 0; i < super_block_size/32; ++i) {
for (int j = 0; j < 16; ++j) {
@ -11823,7 +11826,7 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
for (int ibl = 0; ibl < nblock; ++ibl) {
const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
&scale, weight, L, kvalues_iq4nl, qw);
&scale, weight, L, kvalues_iq4nl, qw, 7);
}
src += n_per_row;
qrow += nblock*sizeof(block_iq4_nl);
@ -11832,14 +11835,23 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
}
void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
assert(k % QK4_NL == 0);
block_iq4_nl * restrict y = vy;
quantize_row_iq4_nl_reference(x, y, k);
GGML_ASSERT(k%QK4_NL == 0);
int nblock = k/QK4_NL;
uint8_t L[QK4_NL];
float weight[QK4_NL];
uint16_t unused_h;
uint8_t * unused_l = NULL;
float scale;
block_iq4_nl * iq4 = (block_iq4_nl *)vy;
for (int ibl = 0; ibl < nblock; ++ibl) {
quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
&scale, weight, L, kvalues_iq4nl, NULL, -1);
}
}
void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
assert(k % QK4_NL == 0);
quantize_iq4_nl(x, y, 1, k, NULL);
quantize_row_iq4_nl(x, y, k);
}
size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
@ -11857,7 +11869,7 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow
for (int ibl = 0; ibl < nblock; ++ibl) {
const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
scales, weight, L, kvalues_iq4nl, qw);
scales, weight, L, kvalues_iq4nl, qw, 7);
}
src += n_per_row;
qrow += nblock*sizeof(block_iq4_xs);

View file

@ -977,8 +977,10 @@ namespace dpct
static int convert_backend_index(std::string & backend) {
if (backend == "ext_oneapi_level_zero:gpu") return 0;
if (backend == "opencl:gpu") return 1;
if (backend == "opencl:cpu") return 2;
if (backend == "opencl:acc") return 3;
if (backend == "ext_oneapi_cuda:gpu") return 2;
if (backend == "ext_oneapi_hip:gpu") return 3;
if (backend == "opencl:cpu") return 4;
if (backend == "opencl:acc") return 5;
printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
GGML_ASSERT(false);
}

View file

@ -1,66 +1,75 @@
function(llama_build_executable source)
# Builds and runs a test source file.
# Optional args:
# - NAME: name of the executable & test target (defaults to the source file name without extension)
# - LABEL: label for the test (defaults to main)
# - ARGS: arguments to pass to the test executable
# - WORKING_DIRECTORY
function(llama_test source)
include(CMakeParseArguments)
set(options)
set(oneValueArgs NAME LABEL WORKING_DIRECTORY)
set(multiValueArgs ARGS)
cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if (NOT DEFINED LLAMA_TEST_LABEL)
set(LLAMA_TEST_LABEL "main")
endif()
if (NOT DEFINED LLAMA_TEST_WORKING_DIRECTORY)
set(LLAMA_TEST_WORKING_DIRECTORY .)
endif()
if (DEFINED LLAMA_TEST_NAME)
set(TEST_TARGET ${LLAMA_TEST_NAME})
else()
get_filename_component(TEST_TARGET ${source} NAME_WE)
endif()
add_executable(${TEST_TARGET} ${source} get-model.cpp)
install(TARGETS ${TEST_TARGET} RUNTIME)
target_link_libraries(${TEST_TARGET} PRIVATE common)
target_link_libraries(${TEST_TARGET} PRIVATE common json-schema-to-grammar)
add_test(
NAME ${TEST_TARGET}
WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}
COMMAND $<TARGET_FILE:${TEST_TARGET}>
${LLAMA_TEST_ARGS})
set_property(TEST ${TEST_TARGET} PROPERTY LABELS ${LLAMA_TEST_LABEL})
endfunction()
function(llama_test_executable name source)
get_filename_component(TEST_TARGET ${source} NAME_WE)
add_test(NAME ${name} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
set_property(TEST ${name} PROPERTY LABELS "main")
endfunction()
# llama_test(test-double-float.cpp) # SLOW
llama_test(test-quantize-fns.cpp)
llama_test(test-quantize-perf.cpp)
llama_test(test-sampling.cpp)
llama_test(test-chat-template.cpp)
function(llama_build_and_test_executable source)
llama_build_and_test_executable_with_label(${source} "main")
endfunction()
llama_test(test-tokenizer-0-llama.cpp NAME test-tokenizer-0-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
llama_test(test-tokenizer-0-falcon.cpp NAME test-tokenizer-0-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
function(llama_build_and_test_executable_with_label source label)
get_filename_component(TEST_TARGET ${source} NAME_WE)
add_executable(${TEST_TARGET} ${source} get-model.cpp)
install(TARGETS ${TEST_TARGET} RUNTIME)
target_link_libraries(${TEST_TARGET} PRIVATE common)
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
set_property(TEST ${TEST_TARGET} PROPERTY LABELS ${label})
endfunction()
llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
# llama_build_and_test_executable(test-double-float.cpp) # SLOW
llama_build_and_test_executable(test-quantize-fns.cpp)
llama_build_and_test_executable(test-quantize-perf.cpp)
llama_build_and_test_executable(test-sampling.cpp)
llama_build_and_test_executable(test-chat-template.cpp)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-stablelm-3b-4e1t ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-stablelm-3b-4e1t.gguf)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt-neox ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt2.gguf)
#llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-bloom ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bloom.gguf) # BIG
llama_build_executable(test-tokenizer-0-llama.cpp)
llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
llama_test(test-grammar-parser.cpp)
llama_test(test-llama-grammar.cpp)
llama_test(test-grad0.cpp)
# llama_test(test-opt.cpp) # SLOW
llama_test(test-backend-ops.cpp)
llama_build_executable(test-tokenizer-0-falcon.cpp)
llama_test_executable (test-tokenizer-0-falcon test-tokenizer-0-falcon.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
llama_test(test-rope.cpp)
llama_build_executable(test-tokenizer-1-llama.cpp)
llama_test_executable (test-tokenizer-1-llama test-tokenizer-1-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
llama_test_executable (test-tokenizer-1-baichuan test-tokenizer-1-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
llama_test(test-model-load-cancel.cpp LABEL "model")
llama_test(test-autorelease.cpp LABEL "model")
llama_build_executable(test-tokenizer-1-bpe.cpp)
llama_test_executable (test-tokenizer-1-falcon test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
llama_test_executable (test-tokenizer-1-aquila test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
llama_test_executable (test-tokenizer-1-mpt test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
llama_test_executable (test-tokenizer-1-stablelm-3b-4e1t test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-stablelm-3b-4e1t.gguf)
llama_test_executable (test-tokenizer-1-gpt-neox test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
llama_test_executable (test-tokenizer-1-refact test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
llama_test_executable (test-tokenizer-1-starcoder test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
llama_test_executable (test-tokenizer-1-gpt2 test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt2.gguf)
# llama_test_executable (test-tokenizer-1-bloom test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bloom.gguf) # BIG
llama_build_and_test_executable(test-grammar-parser.cpp)
llama_build_and_test_executable(test-llama-grammar.cpp)
llama_build_and_test_executable(test-grad0.cpp)
# llama_build_and_test_executable(test-opt.cpp) # SLOW
llama_build_and_test_executable(test-backend-ops.cpp)
llama_build_and_test_executable(test-rope.cpp)
llama_build_and_test_executable_with_label(test-model-load-cancel.cpp "model")
llama_build_and_test_executable_with_label(test-autorelease.cpp "model")
llama_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server)
# dummy executable - not installed
get_filename_component(TEST_TARGET test-c.c NAME_WE)

View file

@ -0,0 +1,10 @@
import { readFileSync } from "fs"
import { SchemaConverter } from "../examples/server/public/json-schema-to-grammar.mjs"
const [, , file] = process.argv
const url = `file://${file}`
let schema = JSON.parse(readFileSync(file, "utf8"));
const converter = new SchemaConverter({})
schema = await converter.resolveRefs(schema, url)
converter.visit(schema, '')
console.log(converter.formatGrammar())

View file

@ -0,0 +1,824 @@
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <fstream>
#include <sstream>
#include <regex>
#include "json-schema-to-grammar.h"
#include "grammar-parser.h"
static std::string trim(const std::string & source) {
std::string s(source);
s.erase(0,s.find_first_not_of(" \n\r\t"));
s.erase(s.find_last_not_of(" \n\r\t")+1);
return std::regex_replace(s, std::regex("(^|\n)[ \t]+"), "$1");
}
enum TestCaseStatus {
SUCCESS,
FAILURE
};
struct TestCase {
TestCaseStatus expected_status;
std::string name;
std::string schema;
std::string expected_grammar;
void _print_failure_header() const {
fprintf(stderr, "#\n# Test '%s' failed.\n#\n%s\n", name.c_str(), schema.c_str());
}
void verify(const std::string & actual_grammar) const {
if (trim(actual_grammar) != trim(expected_grammar)) {
_print_failure_header();
fprintf(stderr, "# EXPECTED:\n%s\n# ACTUAL:\n%s\n", expected_grammar.c_str(), actual_grammar.c_str());
assert(false);
}
}
void verify_expectation_parseable() const {
try {
auto state = grammar_parser::parse(expected_grammar.c_str());
if (state.symbol_ids.find("root") == state.symbol_ids.end()) {
throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar);
}
} catch (const std::runtime_error & ex) {
_print_failure_header();
fprintf(stderr, "# GRAMMAR ERROR: %s\n", ex.what());
assert(false);
}
}
void verify_status(TestCaseStatus status) const {
if (status != expected_status) {
_print_failure_header();
fprintf(stderr, "# EXPECTED STATUS: %s\n", expected_status == SUCCESS ? "SUCCESS" : "FAILURE");
fprintf(stderr, "# ACTUAL STATUS: %s\n", status == SUCCESS ? "SUCCESS" : "FAILURE");
assert(false);
}
}
};
static void write(const std::string & file, const std::string & content) {
std::ofstream f;
f.open(file.c_str());
f << content.c_str();
f.close();
}
static std::string read(const std::string & file) {
std::ostringstream actuals;
actuals << std::ifstream(file.c_str()).rdbuf();
return actuals.str();
}
static void test_all(const std::string & lang, std::function<void(const TestCase &)> runner) {
fprintf(stderr, "#\n# Testing JSON schema conversion (%s)\n#\n", lang.c_str());
auto test = [&](const TestCase & tc) {
fprintf(stderr, "- %s%s\n", tc.name.c_str(), tc.expected_status == FAILURE ? " (failure expected)" : "");
runner(tc);
};
test({
FAILURE,
"unknown type",
R"""({
"type": "kaboom"
})""",
""
});
test({
FAILURE,
"invalid type type",
R"""({
"type": 123
})""",
""
});
test({
SUCCESS,
"empty schema (object)",
"{}",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
null ::= "null" space
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
value ::= object | array | string | number | boolean
)"""
});
test({
SUCCESS,
"exotic formats",
R"""({
"items": [
{ "format": "date" },
{ "format": "uuid" },
{ "format": "time" },
{ "format": "date-time" }
]
})""",
R"""(
date ::= [0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )
date-string ::= "\"" date "\"" space
date-time ::= date "T" time
date-time-string ::= "\"" date-time "\"" space
root ::= "[" space date-string "," space uuid "," space time-string "," space date-time-string "]" space
space ::= " "?
time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
time-string ::= "\"" time "\"" space
uuid ::= "\"" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "\"" space
)"""
});
test({
SUCCESS,
"string",
R"""({
"type": "string"
})""",
R"""(
root ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
space ::= " "?
)"""
});
test({
SUCCESS,
"boolean",
R"""({
"type": "boolean"
})""",
R"""(
root ::= ("true" | "false") space
space ::= " "?
)"""
});
test({
SUCCESS,
"integer",
R"""({
"type": "integer"
})""",
R"""(
root ::= ("-"? ([0-9] | [1-9] [0-9]*)) space
space ::= " "?
)"""
});
test({
SUCCESS,
"string const",
R"""({
"const": "foo"
})""",
R"""(
root ::= "\"foo\""
space ::= " "?
)"""
});
test({
FAILURE,
"non-string const",
R"""({
"const": 123
})""",
""
});
test({
FAILURE,
"non-string enum",
R"""({
"enum": [123]
})""",
""
});
test({
SUCCESS,
"tuple1",
R"""({
"prefixItems": [{ "type": "string" }]
})""",
R"""(
root ::= "[" space string "]" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"tuple2",
R"""({
"prefixItems": [{ "type": "string" }, { "type": "number" }]
})""",
R"""(
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "[" space string "," space number "]" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"number",
R"""({
"type": "number"
})""",
R"""(
root ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
space ::= " "?
)"""
});
test({
SUCCESS,
"minItems",
R"""({
"items": {
"type": "boolean"
},
"minItems": 2
})""",
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space boolean ( "," space boolean )( "," space boolean )* "]" space
space ::= " "?
)"""
});
test({
SUCCESS,
"maxItems 1",
R"""({
"items": {
"type": "boolean"
},
"maxItems": 1
})""",
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space ( boolean )? "]" space
space ::= " "?
)"""
});
test({
SUCCESS,
"maxItems 2",
R"""({
"items": {
"type": "boolean"
},
"maxItems": 2
})""",
R"""(
boolean ::= ("true" | "false") space
root ::= "[" space ( boolean ( "," space boolean )? )? "]" space
space ::= " "?
)"""
});
test({
SUCCESS,
"min + maxItems",
R"""({
"items": {
"type": ["number", "integer"]
},
"minItems": 3,
"maxItems": 5
})""",
R"""(
integer ::= ("-"? ([0-9] | [1-9] [0-9]*)) space
item ::= number | integer
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "[" space item ( "," space item )( "," space item )( "," space item )?( "," space item )? "]" space
space ::= " "?
)"""
});
test({
SUCCESS,
"simple regexp",
R"""({
"type": "string",
"pattern": "^abc?d*efg+(hij)?kl$"
})""",
R"""(
root ::= "\"" "ab" "c"? "d"* "ef" "g"+ ("hij")? "kl" "\"" space
space ::= " "?
)"""
});
test({
SUCCESS,
"regexp escapes",
R"""({
"type": "string",
"pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
})""",
R"""(
root ::= "\"" "[]{}()|+*?" "\"" space
space ::= " "?
)"""
});
test({
SUCCESS,
"regexp quote",
R"""({
"type": "string",
"pattern": "^\"$"
})""",
R"""(
root ::= "\"" "\"" "\"" space
space ::= " "?
)"""
});
test({
SUCCESS,
"regexp",
R"""({
"type": "string",
"pattern": "^(\\([0-9]{1,3}\\))?[0-9]{3}-[0-9]{4} and...$"
})""",
R"""(
dot ::= [\U00000000-\x09\x0B\x0C\x0E-\U0010FFFF]
root ::= "\"" ("(" root-1 root-1? root-1? ")")? root-1 root-1 root-1 "-" root-1 root-1 root-1 root-1 " and" dot dot dot "\"" space
root-1 ::= [0-9]
space ::= " "?
)"""
});
test({
SUCCESS,
"required props",
R"""({
"type": "object",
"properties": {
"a": {
"type": "string"
},
"b": {
"type": "string"
}
},
"required": [
"a",
"b"
],
"additionalProperties": false,
"definitions": {}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
b-kv ::= "\"b\"" space ":" space string
root ::= "{" space a-kv "," space b-kv "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"1 optional prop",
R"""({
"properties": {
"a": {
"type": "string"
}
},
"additionalProperties": false
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
root ::= "{" space (a-kv )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"N optional props",
R"""({
"properties": {
"a": {"type": "string"},
"b": {"type": "string"},
"c": {"type": "string"}
},
"additionalProperties": false
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
a-rest ::= ( "," space b-kv )? b-rest
b-kv ::= "\"b\"" space ":" space string
b-rest ::= ( "," space c-kv )?
c-kv ::= "\"c\"" space ":" space string
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"required + optional props",
R"""({
"properties": {
"a": {"type": "string"},
"b": {"type": "string"},
"c": {"type": "string"},
"d": {"type": "string"}
},
"required": ["a", "b"],
"additionalProperties": false
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
b-kv ::= "\"b\"" space ":" space string
c-kv ::= "\"c\"" space ":" space string
c-rest ::= ( "," space d-kv )?
d-kv ::= "\"d\"" space ":" space string
root ::= "{" space a-kv "," space b-kv ( "," space ( c-kv c-rest | d-kv ) )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"additional props",
R"""({
"type": "object",
"additionalProperties": {"type": "array", "items": {"type": "number"}}
})""",
R"""(
additional-kv ::= string ":" space additional-value
additional-kvs ::= additional-kv ( "," space additional-kv )*
additional-value ::= "[" space ( number ( "," space number )* )? "]" space
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space (additional-kvs )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"additional props (true)",
R"""({
"type": "object",
"additionalProperties": true
})""",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
null ::= "null" space
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
value ::= object | array | string | number | boolean
)"""
});
test({
SUCCESS,
"additional props (implicit)",
R"""({
"type": "object"
})""",
R"""(
array ::= "[" space ( value ("," space value)* )? "]" space
boolean ::= ("true" | "false") space
null ::= "null" space
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
root ::= object
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
value ::= object | array | string | number | boolean
)"""
});
test({
SUCCESS,
"empty w/o additional props",
R"""({
"type": "object",
"additionalProperties": false
})""",
R"""(
root ::= "{" space "}" space
space ::= " "?
)"""
});
test({
SUCCESS,
"required + additional props",
R"""({
"type": "object",
"properties": {
"a": {"type": "number"}
},
"required": ["a"],
"additionalProperties": {"type": "string"}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
additional-kv ::= string ":" space string
additional-kvs ::= additional-kv ( "," space additional-kv )*
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space a-kv ( "," space ( additional-kvs ) )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"optional + additional props",
R"""({
"type": "object",
"properties": {
"a": {"type": "number"}
},
"additionalProperties": {"type": "number"}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
a-rest ::= additional-kvs
additional-kv ::= string ":" space number
additional-kvs ::= additional-kv ( "," space additional-kv )*
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space (a-kv a-rest | additional-kvs )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"required + optional + additional props",
R"""({
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"}
},
"required": ["a"],
"additionalProperties": {"type": "number"}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
additional-kv ::= string ":" space number
additional-kvs ::= additional-kv ( "," space additional-kv )*
b-kv ::= "\"b\"" space ":" space number
b-rest ::= additional-kvs
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space a-kv ( "," space ( b-kv b-rest | additional-kvs ) )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"top-level $ref",
R"""({
"$ref": "#/definitions/MyType",
"definitions": {
"MyType": {
"type": "object",
"properties": {
"a": {
"type": "string"
}
},
"required": [
"a"
],
"additionalProperties": false
}
}
})""",
R"""(
MyType ::= "{" space MyType-a-kv "}" space
MyType-a-kv ::= "\"a\"" space ":" space string
root ::= MyType
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)"""
});
test({
SUCCESS,
"anyOf",
R"""({
"anyOf": [
{"$ref": "#/definitions/foo"},
{"$ref": "#/definitions/bar"}
],
"definitions": {
"foo": {
"properties": {"a": {"type": "number"}}
},
"bar": {
"properties": {"b": {"type": "number"}}
}
},
"type": "object"
})""",
R"""(
alternative-0 ::= foo
alternative-1 ::= bar
bar ::= "{" space (bar-b-kv )? "}" space
bar-b-kv ::= "\"b\"" space ":" space number
foo ::= "{" space (foo-a-kv )? "}" space
foo-a-kv ::= "\"a\"" space ":" space number
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= alternative-0 | alternative-1
space ::= " "?
)"""
});
test({
SUCCESS,
"mix of allOf, anyOf and $ref (similar to https://json.schemastore.org/tsconfig.json)",
R"""({
"allOf": [
{"$ref": "#/definitions/foo"},
{"$ref": "#/definitions/bar"},
{
"anyOf": [
{"$ref": "#/definitions/baz"},
{"$ref": "#/definitions/bam"}
]
}
],
"definitions": {
"foo": {
"properties": {"a": {"type": "number"}}
},
"bar": {
"properties": {"b": {"type": "number"}}
},
"bam": {
"properties": {"c": {"type": "number"}}
},
"baz": {
"properties": {"d": {"type": "number"}}
}
},
"type": "object"
})""",
R"""(
a-kv ::= "\"a\"" space ":" space number
b-kv ::= "\"b\"" space ":" space number
c-kv ::= "\"c\"" space ":" space number
d-kv ::= "\"d\"" space ":" space number
d-rest ::= ( "," space c-kv )?
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
space ::= " "?
)"""
});
test({
SUCCESS,
"conflicting names",
R"""({
"type": "object",
"properties": {
"number": {
"type": "object",
"properties": {
"number": {
"type": "object",
"properties": {
"root": {
"type": "number"
}
},
"required": [
"root"
],
"additionalProperties": false
}
},
"required": [
"number"
],
"additionalProperties": false
}
},
"required": [
"number"
],
"additionalProperties": false,
"definitions": {}
})""",
R"""(
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
number- ::= "{" space number-number-kv "}" space
number-kv ::= "\"number\"" space ":" space number-
number-number ::= "{" space number-number-root-kv "}" space
number-number-kv ::= "\"number\"" space ":" space number-number
number-number-root-kv ::= "\"root\"" space ":" space number
root ::= "{" space number-kv "}" space
space ::= " "?
)"""
});
}
int main() {
test_all("C++", [](const TestCase & tc) {
try {
tc.verify(json_schema_to_grammar(nlohmann::json::parse(tc.schema)));
tc.verify_status(SUCCESS);
} catch (const std::runtime_error & ex) {
fprintf(stderr, "Error: %s\n", ex.what());
tc.verify_status(FAILURE);
}
});
test_all("Python", [](const TestCase & tc) {
write("test-json-schema-input.tmp", tc.schema);
tc.verify_status(std::system(
"python ./examples/json-schema-to-grammar.py test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
tc.verify(read("test-grammar-output.tmp"));
});
test_all("JavaScript", [](const TestCase & tc) {
write("test-json-schema-input.tmp", tc.schema);
tc.verify_status(std::system(
"node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
tc.verify(read("test-grammar-output.tmp"));
});
test_all("Check Expectations Validity", [](const TestCase & tc) {
if (tc.expected_status == SUCCESS) {
tc.verify_expectation_parseable();
}
});
}