Merge branch 'master' into sycl_readme_update
This commit is contained in:
commit
a80267a110
40 changed files with 12162 additions and 6703 deletions
114
.github/workflows/build.yml
vendored
114
.github/workflows/build.yml
vendored
|
@ -21,6 +21,118 @@ env:
|
|||
GGML_N_THREADS: 1
|
||||
|
||||
jobs:
|
||||
macOS-latest-cmake-arm64:
|
||||
runs-on: macos-14
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_METAL_EMBED_LIBRARY=ON ..
|
||||
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest -L main --verbose --timeout 900
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
shell: bash
|
||||
run: |
|
||||
BUILD_NUMBER="$(git rev-list --count HEAD)"
|
||||
SHORT_HASH="$(git rev-parse --short=7 HEAD)"
|
||||
if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
|
||||
echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
|
||||
else
|
||||
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
|
||||
echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/*
|
||||
|
||||
- name: Upload artifacts
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
path: |
|
||||
llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip
|
||||
|
||||
macOS-latest-cmake-x64:
|
||||
runs-on: macos-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_METAL_EMBED_LIBRARY=ON ..
|
||||
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest -L main --verbose --timeout 900
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
shell: bash
|
||||
run: |
|
||||
BUILD_NUMBER="$(git rev-list --count HEAD)"
|
||||
SHORT_HASH="$(git rev-parse --short=7 HEAD)"
|
||||
if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then
|
||||
echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT
|
||||
else
|
||||
SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-')
|
||||
echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/*
|
||||
|
||||
- name: Upload artifacts
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
path: |
|
||||
llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip
|
||||
|
||||
ubuntu-focal-make:
|
||||
runs-on: ubuntu-20.04
|
||||
|
||||
|
@ -748,6 +860,8 @@ jobs:
|
|||
- macOS-latest-cmake
|
||||
- windows-latest-cmake
|
||||
- windows-latest-cmake-cublas
|
||||
- macOS-latest-cmake-arm64
|
||||
- macOS-latest-cmake-x64
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -11,7 +11,10 @@
|
|||
*.gcda
|
||||
*.dot
|
||||
*.bat
|
||||
*.tmp
|
||||
*.metallib
|
||||
*.etag
|
||||
*.lastModified
|
||||
.DS_Store
|
||||
.build/
|
||||
.cache/
|
||||
|
|
12
Makefile
12
Makefile
|
@ -9,7 +9,8 @@ TEST_TARGETS = \
|
|||
tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt \
|
||||
tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama \
|
||||
tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope \
|
||||
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease
|
||||
tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease \
|
||||
tests/test-json-schema-to-grammar
|
||||
|
||||
# Code coverage output files
|
||||
COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
|
||||
|
@ -666,6 +667,9 @@ console.o: common/console.cpp common/console.h
|
|||
grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
json-schema-to-grammar.o: common/json-schema-to-grammar.cpp common/json-schema-to-grammar.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
train.o: common/train.cpp common/train.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
|
@ -745,7 +749,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
|
|||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
|
||||
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp json-schema-to-grammar.o common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
||||
|
||||
|
@ -865,6 +869,10 @@ tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS)
|
|||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp json-schema-to-grammar.o ggml.o llama.o grammar-parser.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
||||
tests/test-grad0: tests/test-grad0.cpp ggml.o $(OBJS)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||
|
|
|
@ -122,6 +122,7 @@ pub fn build(b: *std.build.Builder) !void {
|
|||
const console = make.obj("console", "common/console.cpp");
|
||||
const sampling = make.obj("sampling", "common/sampling.cpp");
|
||||
const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp");
|
||||
const json_schema_to_grammar = make.obj("json-schema-to-grammar", "common/json-schema-to-grammar.cpp");
|
||||
const train = make.obj("train", "common/train.cpp");
|
||||
const clip = make.obj("clip", "examples/llava/clip.cpp");
|
||||
const llava = make.obj("llava", "examples/llava/llava.cpp");
|
||||
|
@ -133,7 +134,7 @@ pub fn build(b: *std.build.Builder) !void {
|
|||
_ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train });
|
||||
_ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, train });
|
||||
|
||||
const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, sampling, grammar_parser, clip, llava });
|
||||
const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, common, buildinfo, sampling, grammar_parser, json_schema_to_grammar, clip, llava });
|
||||
if (server.target.isWindows()) {
|
||||
server.linkSystemLibrary("ws2_32");
|
||||
}
|
||||
|
|
|
@ -47,6 +47,8 @@ if (BUILD_SHARED_LIBS)
|
|||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
set(TARGET json-schema-to-grammar)
|
||||
add_library(${TARGET} OBJECT json-schema-to-grammar.cpp json-schema-to-grammar.h)
|
||||
|
||||
set(TARGET common)
|
||||
|
||||
|
@ -60,6 +62,7 @@ add_library(${TARGET} STATIC
|
|||
console.cpp
|
||||
grammar-parser.h
|
||||
grammar-parser.cpp
|
||||
json.hpp
|
||||
train.h
|
||||
train.cpp
|
||||
)
|
||||
|
|
|
@ -1590,6 +1590,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
|
|||
if (s == "q4_1") {
|
||||
return GGML_TYPE_Q4_1;
|
||||
}
|
||||
if (s == "iq4_nl") {
|
||||
return GGML_TYPE_IQ4_NL;
|
||||
}
|
||||
if (s == "q5_0") {
|
||||
return GGML_TYPE_Q5_0;
|
||||
}
|
||||
|
|
725
common/json-schema-to-grammar.cpp
Normal file
725
common/json-schema-to-grammar.cpp
Normal file
|
@ -0,0 +1,725 @@
|
|||
#include "json-schema-to-grammar.h"
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
const std::string SPACE_RULE = "\" \"?";
|
||||
|
||||
std::unordered_map<std::string, std::string> PRIMITIVE_RULES = {
|
||||
{"boolean", "(\"true\" | \"false\") space"},
|
||||
{"number", "(\"-\"? ([0-9] | [1-9] [0-9]*)) (\".\" [0-9]+)? ([eE] [-+]? [0-9]+)? space"},
|
||||
{"integer", "(\"-\"? ([0-9] | [1-9] [0-9]*)) space"},
|
||||
{"value", "object | array | string | number | boolean"},
|
||||
{"object", "\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space"},
|
||||
{"array", "\"[\" space ( value (\",\" space value)* )? \"]\" space"},
|
||||
{"uuid", "\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
|
||||
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
|
||||
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
|
||||
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "
|
||||
"\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space"},
|
||||
{"string", " \"\\\"\" (\n"
|
||||
" [^\"\\\\] |\n"
|
||||
" \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])\n"
|
||||
" )* \"\\\"\" space"},
|
||||
{"null", "\"null\" space"}
|
||||
};
|
||||
std::vector<std::string> OBJECT_RULE_NAMES = {"object", "array", "string", "number", "boolean", "null", "value"};
|
||||
|
||||
std::unordered_map<std::string, std::string> DATE_RULES = {
|
||||
{"date", "[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )"},
|
||||
{"time", "([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9] [0-9] [0-9] )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )"},
|
||||
{"date-time", "date \"T\" time"},
|
||||
{"date-string", "\"\\\"\" date \"\\\"\" space"},
|
||||
{"time-string", "\"\\\"\" time \"\\\"\" space"},
|
||||
{"date-time-string", "\"\\\"\" date-time \"\\\"\" space"}
|
||||
};
|
||||
|
||||
static bool is_reserved_name(const std::string & name) {
|
||||
static std::unordered_set<std::string> RESERVED_NAMES;
|
||||
if (RESERVED_NAMES.empty()) {
|
||||
RESERVED_NAMES.insert("root");
|
||||
for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first);
|
||||
for (const auto &p : DATE_RULES) RESERVED_NAMES.insert(p.first);
|
||||
}
|
||||
return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
|
||||
}
|
||||
|
||||
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
|
||||
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
|
||||
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
|
||||
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
|
||||
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
|
||||
};
|
||||
|
||||
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
|
||||
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
|
||||
|
||||
template <typename Iterator>
|
||||
std::string join(Iterator begin, Iterator end, const std::string & separator) {
|
||||
std::ostringstream result;
|
||||
if (begin != end) {
|
||||
result << *begin;
|
||||
for (Iterator it = begin + 1; it != end; ++it) {
|
||||
result << separator << *it;
|
||||
}
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
|
||||
std::vector<std::string> tokens;
|
||||
size_t start = 0;
|
||||
size_t end = str.find(delimiter);
|
||||
|
||||
while (end != std::string::npos) {
|
||||
tokens.push_back(str.substr(start, end - start));
|
||||
start = end + delimiter.length();
|
||||
end = str.find(delimiter, start);
|
||||
}
|
||||
|
||||
tokens.push_back(str.substr(start));
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
static std::string repeat(const std::string & str, size_t n) {
|
||||
if (n == 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string result;
|
||||
result.reserve(str.length() * n);
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
result += str;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
|
||||
std::smatch match;
|
||||
std::string result;
|
||||
|
||||
std::string::const_iterator searchStart(input.cbegin());
|
||||
std::string::const_iterator searchEnd(input.cend());
|
||||
|
||||
while (std::regex_search(searchStart, searchEnd, match, regex)) {
|
||||
result.append(searchStart, searchStart + match.position());
|
||||
result.append(replacement(match));
|
||||
searchStart = match.suffix().first;
|
||||
}
|
||||
|
||||
result.append(searchStart, searchEnd);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string format_literal(const std::string & literal) {
|
||||
std::string escaped = replacePattern(json(literal).dump(), GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
|
||||
char c = match.str()[0];
|
||||
return GRAMMAR_LITERAL_ESCAPES.at(c);
|
||||
});
|
||||
return "\"" + escaped + "\"";
|
||||
}
|
||||
|
||||
|
||||
class SchemaConverter {
|
||||
private:
|
||||
std::function<json(const std::string &)> _fetch_json;
|
||||
bool _dotall;
|
||||
std::map<std::string, std::string> _rules;
|
||||
std::unordered_map<std::string, nlohmann::json> _refs;
|
||||
std::unordered_set<std::string> _refs_being_resolved;
|
||||
std::vector<std::string> _errors;
|
||||
std::vector<std::string> _warnings;
|
||||
|
||||
std::string _add_rule(const std::string & name, const std::string & rule) {
|
||||
std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
|
||||
if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
|
||||
_rules[esc_name] = rule;
|
||||
return esc_name;
|
||||
} else {
|
||||
int i = 0;
|
||||
while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
|
||||
i++;
|
||||
}
|
||||
std::string key = esc_name + std::to_string(i);
|
||||
_rules[key] = rule;
|
||||
return key;
|
||||
}
|
||||
}
|
||||
|
||||
std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
|
||||
std::vector<std::string> rules;
|
||||
for (size_t i = 0; i < alt_schemas.size(); i++) {
|
||||
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
|
||||
}
|
||||
return join(rules.begin(), rules.end(), " | ");
|
||||
}
|
||||
|
||||
std::string _visit_pattern(const std::string & pattern, const std::string & name) {
|
||||
if (!(pattern.front() == '^' && pattern.back() == '$')) {
|
||||
_errors.push_back("Pattern must start with '^' and end with '$'");
|
||||
return "";
|
||||
}
|
||||
std::string sub_pattern = pattern.substr(1, pattern.length() - 2);
|
||||
std::unordered_map<std::string, std::string> sub_rule_ids;
|
||||
|
||||
size_t i = 0;
|
||||
size_t length = sub_pattern.length();
|
||||
|
||||
using literal_or_rule = std::pair<std::string, bool>;
|
||||
auto to_rule = [&](const literal_or_rule & ls) {
|
||||
auto is_literal = ls.second;
|
||||
auto s = ls.first;
|
||||
return is_literal ? "\"" + s + "\"" : s;
|
||||
};
|
||||
std::function<literal_or_rule()> transform = [&]() -> literal_or_rule {
|
||||
size_t start = i;
|
||||
std::vector<literal_or_rule> seq;
|
||||
|
||||
auto get_dot = [&]() {
|
||||
std::string rule;
|
||||
if (_dotall) {
|
||||
rule = "[\\U00000000-\\U0010FFFF]";
|
||||
} else {
|
||||
rule = "[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]";
|
||||
}
|
||||
return _add_rule("dot", rule);
|
||||
};
|
||||
|
||||
// Joins the sequence, merging consecutive literals together.
|
||||
auto join_seq = [&]() {
|
||||
std::vector<literal_or_rule> ret;
|
||||
|
||||
std::string literal;
|
||||
auto flush_literal = [&]() {
|
||||
if (literal.empty()) {
|
||||
return false;
|
||||
}
|
||||
ret.push_back(std::make_pair(literal, true));
|
||||
literal.clear();
|
||||
return true;
|
||||
};
|
||||
|
||||
for (const auto & item : seq) {
|
||||
auto is_literal = item.second;
|
||||
if (is_literal) {
|
||||
literal += item.first;
|
||||
} else {
|
||||
flush_literal();
|
||||
ret.push_back(item);
|
||||
}
|
||||
}
|
||||
flush_literal();
|
||||
|
||||
std::vector<std::string> results;
|
||||
for (const auto & item : ret) {
|
||||
results.push_back(to_rule(item));
|
||||
}
|
||||
return std::make_pair(join(results.begin(), results.end(), " "), false);
|
||||
};
|
||||
|
||||
while (i < length) {
|
||||
char c = sub_pattern[i];
|
||||
if (c == '.') {
|
||||
seq.push_back(std::make_pair(get_dot(), false));
|
||||
i++;
|
||||
} else if (c == '(') {
|
||||
i++;
|
||||
if (i < length) {
|
||||
if (sub_pattern[i] == '?') {
|
||||
_warnings.push_back("Unsupported pattern syntax");
|
||||
}
|
||||
}
|
||||
seq.push_back(std::make_pair("(" + to_rule(transform()) + ")", false));
|
||||
} else if (c == ')') {
|
||||
i++;
|
||||
if (start > 0 && sub_pattern[start - 1] != '(') {
|
||||
_errors.push_back("Unbalanced parentheses");
|
||||
}
|
||||
return join_seq();
|
||||
} else if (c == '[') {
|
||||
std::string square_brackets = std::string(1, c);
|
||||
i++;
|
||||
while (i < length && sub_pattern[i] != ']') {
|
||||
if (sub_pattern[i] == '\\') {
|
||||
square_brackets += sub_pattern.substr(i, 2);
|
||||
i += 2;
|
||||
} else {
|
||||
square_brackets += sub_pattern[i];
|
||||
i++;
|
||||
}
|
||||
}
|
||||
if (i >= length) {
|
||||
_errors.push_back("Unbalanced square brackets");
|
||||
}
|
||||
square_brackets += ']';
|
||||
i++;
|
||||
seq.push_back(std::make_pair(square_brackets, false));
|
||||
} else if (c == '|') {
|
||||
seq.push_back(std::make_pair("|", false));
|
||||
i++;
|
||||
} else if (c == '*' || c == '+' || c == '?') {
|
||||
seq.back() = std::make_pair(to_rule(seq.back()) + c, false);
|
||||
i++;
|
||||
} else if (c == '{') {
|
||||
std::string curly_brackets = std::string(1, c);
|
||||
i++;
|
||||
while (i < length && sub_pattern[i] != '}') {
|
||||
curly_brackets += sub_pattern[i];
|
||||
i++;
|
||||
}
|
||||
if (i >= length) {
|
||||
_errors.push_back("Unbalanced curly brackets");
|
||||
}
|
||||
curly_brackets += '}';
|
||||
i++;
|
||||
auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
|
||||
int min_times = 0;
|
||||
int max_times = std::numeric_limits<int>::max();
|
||||
try {
|
||||
if (nums.size() == 1) {
|
||||
min_times = max_times = std::stoi(nums[0]);
|
||||
} else if (nums.size() != 2) {
|
||||
_errors.push_back("Wrong number of values in curly brackets");
|
||||
} else {
|
||||
if (!nums[0].empty()) {
|
||||
min_times = std::stoi(nums[0]);
|
||||
}
|
||||
if (!nums[1].empty()) {
|
||||
max_times = std::stoi(nums[1]);
|
||||
}
|
||||
}
|
||||
} catch (const std::invalid_argument & e) {
|
||||
_errors.push_back("Invalid number in curly brackets");
|
||||
return std::make_pair("", false);
|
||||
}
|
||||
auto &last = seq.back();
|
||||
auto &sub = last.first;
|
||||
auto sub_is_literal = last.second;
|
||||
|
||||
if (min_times == 0 && max_times == std::numeric_limits<int>::max()) {
|
||||
sub += "*";
|
||||
} else if (min_times == 0 && max_times == 1) {
|
||||
sub += "?";
|
||||
} else if (min_times == 1 && max_times == std::numeric_limits<int>::max()) {
|
||||
sub += "+";
|
||||
} else {
|
||||
if (!sub_is_literal) {
|
||||
std::string & sub_id = sub_rule_ids[sub];
|
||||
if (sub_id.empty()) {
|
||||
sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
|
||||
}
|
||||
sub = sub_id;
|
||||
}
|
||||
std::string result;
|
||||
if (sub_is_literal && min_times > 0) {
|
||||
result = "\"" + repeat(sub.substr(1, sub.length() - 2), min_times) + "\"";
|
||||
} else {
|
||||
for (int j = 0; j < min_times; j++) {
|
||||
if (j > 0) {
|
||||
result += " ";
|
||||
}
|
||||
result += sub;
|
||||
}
|
||||
}
|
||||
if (min_times > 0 && min_times < max_times) {
|
||||
result += " ";
|
||||
}
|
||||
if (max_times == std::numeric_limits<int>::max()) {
|
||||
result += sub + "*";
|
||||
} else {
|
||||
for (int j = min_times; j < max_times; j++) {
|
||||
if (j > min_times) {
|
||||
result += " ";
|
||||
}
|
||||
result += sub + "?";
|
||||
}
|
||||
}
|
||||
seq.back().first = result;
|
||||
seq.back().second = false;
|
||||
}
|
||||
} else {
|
||||
std::string literal;
|
||||
auto is_non_literal = [&](char c) {
|
||||
return NON_LITERAL_SET.find(c) != NON_LITERAL_SET.end();
|
||||
};
|
||||
while (i < length) {
|
||||
if (sub_pattern[i] == '\\' && i < length - 1) {
|
||||
char next = sub_pattern[i + 1];
|
||||
if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.find(next) != ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.end()) {
|
||||
i++;
|
||||
literal += sub_pattern[i];
|
||||
i++;
|
||||
} else {
|
||||
literal += sub_pattern.substr(i, 2);
|
||||
i += 2;
|
||||
}
|
||||
} else if (sub_pattern[i] == '"') {
|
||||
literal += "\\\"";
|
||||
i++;
|
||||
} else if (!is_non_literal(sub_pattern[i]) &&
|
||||
(i == length - 1 || literal.empty() || sub_pattern[i + 1] == '.' || !is_non_literal(sub_pattern[i + 1]))) {
|
||||
literal += sub_pattern[i];
|
||||
i++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!literal.empty()) {
|
||||
seq.push_back(std::make_pair(literal, true));
|
||||
}
|
||||
}
|
||||
}
|
||||
return join_seq();
|
||||
};
|
||||
return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
|
||||
}
|
||||
|
||||
std::string _resolve_ref(const std::string & ref) {
|
||||
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
|
||||
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
|
||||
_refs_being_resolved.insert(ref);
|
||||
json resolved = _refs[ref];
|
||||
ref_name = visit(resolved, ref_name);
|
||||
_refs_being_resolved.erase(ref);
|
||||
}
|
||||
return ref_name;
|
||||
}
|
||||
|
||||
std::string _build_object_rule(
|
||||
const std::vector<std::pair<std::string, json>> & properties,
|
||||
const std::unordered_set<std::string> & required,
|
||||
const std::string & name,
|
||||
const json & additional_properties)
|
||||
{
|
||||
std::vector<std::string> required_props;
|
||||
std::vector<std::string> optional_props;
|
||||
std::unordered_map<std::string, std::string> prop_kv_rule_names;
|
||||
for (const auto & kv : properties) {
|
||||
const auto &prop_name = kv.first;
|
||||
const auto &prop_schema = kv.second;
|
||||
|
||||
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
|
||||
prop_kv_rule_names[prop_name] = _add_rule(
|
||||
name + (name.empty() ? "" : "-") + prop_name + "-kv",
|
||||
format_literal(prop_name) + " space \":\" space " + prop_rule_name
|
||||
);
|
||||
if (required.find(prop_name) != required.end()) {
|
||||
required_props.push_back(prop_name);
|
||||
} else {
|
||||
optional_props.push_back(prop_name);
|
||||
}
|
||||
}
|
||||
if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get<bool>())) {
|
||||
std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
|
||||
std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value");
|
||||
std::string kv_rule = _add_rule(sub_name + "-kv", _add_rule("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
|
||||
prop_kv_rule_names["*"] = kv_rule;
|
||||
optional_props.push_back("*");
|
||||
}
|
||||
|
||||
std::string rule = "\"{\" space ";
|
||||
for (size_t i = 0; i < required_props.size(); i++) {
|
||||
if (i > 0) {
|
||||
rule += " \",\" space ";
|
||||
}
|
||||
rule += prop_kv_rule_names[required_props[i]];
|
||||
}
|
||||
|
||||
if (!optional_props.empty()) {
|
||||
rule += " (";
|
||||
if (!required_props.empty()) {
|
||||
rule += " \",\" space ( ";
|
||||
}
|
||||
|
||||
std::function<std::string(const std::vector<std::string> &, bool)> get_recursive_refs = [&](const std::vector<std::string> & ks, bool first_is_optional) {
|
||||
std::string res;
|
||||
if (ks.empty()) {
|
||||
return res;
|
||||
}
|
||||
std::string k = ks[0];
|
||||
std::string kv_rule_name = prop_kv_rule_names[k];
|
||||
if (k == "*") {
|
||||
res = _add_rule(
|
||||
name + (name.empty() ? "" : "-") + "additional-kvs",
|
||||
kv_rule_name + " ( \",\" space " + kv_rule_name + " )*"
|
||||
);
|
||||
} else if (first_is_optional) {
|
||||
res = "( \",\" space " + kv_rule_name + " )?";
|
||||
} else {
|
||||
res = kv_rule_name;
|
||||
}
|
||||
if (ks.size() > 1) {
|
||||
res += " " + _add_rule(
|
||||
name + (name.empty() ? "" : "-") + k + "-rest",
|
||||
get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
|
||||
);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < optional_props.size(); i++) {
|
||||
if (i > 0) {
|
||||
rule += " | ";
|
||||
}
|
||||
rule += get_recursive_refs(std::vector<std::string>(optional_props.begin() + i, optional_props.end()), false);
|
||||
}
|
||||
if (!required_props.empty()) {
|
||||
rule += " )";
|
||||
}
|
||||
rule += " )?";
|
||||
}
|
||||
|
||||
rule += " \"}\" space";
|
||||
|
||||
return rule;
|
||||
}
|
||||
|
||||
public:
|
||||
SchemaConverter(
|
||||
const std::function<json(const std::string &)> & fetch_json,
|
||||
bool dotall)
|
||||
: _fetch_json(fetch_json), _dotall(dotall)
|
||||
{
|
||||
_rules["space"] = SPACE_RULE;
|
||||
}
|
||||
|
||||
void resolve_refs(nlohmann::json & schema, const std::string & url) {
|
||||
/*
|
||||
* Resolves all $ref fields in the given schema, fetching any remote schemas,
|
||||
* replacing each $ref with absolute reference URL and populates _refs with the
|
||||
* respective referenced (sub)schema dictionaries.
|
||||
*/
|
||||
std::function<void(json &)> visit_refs = [&](json & n) {
|
||||
if (n.is_array()) {
|
||||
for (auto & x : n) {
|
||||
visit_refs(x);
|
||||
}
|
||||
} else if (n.is_object()) {
|
||||
if (n.contains("$ref")) {
|
||||
std::string ref = n["$ref"];
|
||||
if (_refs.find(ref) == _refs.end()) {
|
||||
json target;
|
||||
if (ref.find("https://") == 0) {
|
||||
std::string base_url = ref.substr(0, ref.find('#'));
|
||||
auto it = _refs.find(base_url);
|
||||
if (it != _refs.end()) {
|
||||
target = it->second;
|
||||
} else {
|
||||
// Fetch the referenced schema and resolve its refs
|
||||
auto referenced = _fetch_json(ref);
|
||||
resolve_refs(referenced, base_url);
|
||||
_refs[base_url] = referenced;
|
||||
}
|
||||
if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
|
||||
return;
|
||||
}
|
||||
} else if (ref.find("#/") == 0) {
|
||||
target = schema;
|
||||
n["$ref"] = url + ref;
|
||||
ref = url + ref;
|
||||
} else {
|
||||
_errors.push_back("Unsupported ref: " + ref);
|
||||
return;
|
||||
}
|
||||
std::string pointer = ref.substr(ref.find('#') + 1);
|
||||
std::vector<std::string> tokens = split(pointer, "/");
|
||||
for (size_t i = 1; i < tokens.size(); ++i) {
|
||||
std::string sel = tokens[i];
|
||||
if (target.is_null() || !target.contains(sel)) {
|
||||
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
||||
return;
|
||||
}
|
||||
target = target[sel];
|
||||
}
|
||||
_refs[ref] = target;
|
||||
}
|
||||
} else {
|
||||
for (auto & kv : n.items()) {
|
||||
visit_refs(kv.value());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
visit_refs(schema);
|
||||
}
|
||||
|
||||
std::string _generate_constant_rule(const json & value) {
|
||||
if (!value.is_string()) {
|
||||
_errors.push_back("Only std::string constants are supported, got " + value.dump());
|
||||
return "";
|
||||
}
|
||||
return format_literal(value.get<std::string>());
|
||||
}
|
||||
|
||||
std::string visit(const json & schema, const std::string & name) {
|
||||
json schema_type = schema.contains("type") ? schema["type"] : json();
|
||||
std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
|
||||
std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
|
||||
|
||||
if (schema.contains("$ref")) {
|
||||
return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
|
||||
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
|
||||
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
|
||||
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
|
||||
} else if (schema_type.is_array()) {
|
||||
std::vector<json> schema_types;
|
||||
for (const auto & t : schema_type) {
|
||||
schema_types.push_back({{"type", t}});
|
||||
}
|
||||
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
|
||||
} else if (schema.contains("const")) {
|
||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
|
||||
} else if (schema.contains("enum")) {
|
||||
std::vector<std::string> enum_values;
|
||||
for (const auto & v : schema["enum"]) {
|
||||
enum_values.push_back(_generate_constant_rule(v));
|
||||
}
|
||||
return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | "));
|
||||
} else if ((schema_type.is_null() || schema_type == "object")
|
||||
&& (schema.contains("properties") ||
|
||||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
|
||||
std::unordered_set<std::string> required;
|
||||
if (schema.contains("required") && schema["required"].is_array()) {
|
||||
for (const auto & item : schema["required"]) {
|
||||
if (item.is_string()) {
|
||||
required.insert(item.get<std::string>());
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<std::pair<std::string, json>> properties;
|
||||
if (schema.contains("properties")) {
|
||||
for (const auto & prop : schema["properties"].items()) {
|
||||
properties.emplace_back(prop.key(), prop.value());
|
||||
}
|
||||
}
|
||||
return _add_rule(rule_name,
|
||||
_build_object_rule(
|
||||
properties, required, name,
|
||||
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
|
||||
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
|
||||
std::unordered_set<std::string> required;
|
||||
std::vector<std::pair<std::string, json>> properties;
|
||||
std::string hybrid_name = name;
|
||||
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
|
||||
if (comp_schema.contains("$ref")) {
|
||||
add_component(_refs[comp_schema["$ref"]], is_required);
|
||||
} else if (comp_schema.contains("properties")) {
|
||||
for (const auto & prop : comp_schema["properties"].items()) {
|
||||
properties.emplace_back(prop.key(), prop.value());
|
||||
if (is_required) {
|
||||
required.insert(prop.key());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// todo warning
|
||||
}
|
||||
};
|
||||
for (auto & t : schema["allOf"]) {
|
||||
if (t.contains("anyOf")) {
|
||||
for (auto & tt : t["anyOf"]) {
|
||||
add_component(tt, false);
|
||||
}
|
||||
} else {
|
||||
add_component(t, true);
|
||||
}
|
||||
}
|
||||
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
||||
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
|
||||
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
|
||||
if (items.is_array()) {
|
||||
std::string rule = "\"[\" space ";
|
||||
for (size_t i = 0; i < items.size(); i++) {
|
||||
if (i > 0) {
|
||||
rule += " \",\" space ";
|
||||
}
|
||||
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
|
||||
}
|
||||
rule += " \"]\" space";
|
||||
return _add_rule(rule_name, rule);
|
||||
} else {
|
||||
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
|
||||
std::string list_item_operator = "( \",\" space " + item_rule_name + " )";
|
||||
std::string successive_items;
|
||||
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
|
||||
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
|
||||
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : -1;
|
||||
if (min_items > 0) {
|
||||
successive_items += repeat(list_item_operator, min_items - 1);
|
||||
min_items--;
|
||||
}
|
||||
if (max_items >= 0 && max_items > min_items) {
|
||||
successive_items += repeat(list_item_operator + "?", max_items - min_items - 1);
|
||||
} else {
|
||||
successive_items += list_item_operator + "*";
|
||||
}
|
||||
std::string rule;
|
||||
if (min_items == 0) {
|
||||
rule = "\"[\" space ( " + item_rule_name + " " + successive_items + " )? \"]\" space";
|
||||
} else {
|
||||
rule = "\"[\" space " + item_rule_name + " " + successive_items + " \"]\" space";
|
||||
}
|
||||
return _add_rule(rule_name, rule);
|
||||
}
|
||||
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
|
||||
return _visit_pattern(schema["pattern"], rule_name);
|
||||
} else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
|
||||
return _add_rule(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
|
||||
} else if ((schema_type.is_null() || schema_type == "string") && DATE_RULES.find(schema_format) != DATE_RULES.end()) {
|
||||
for (const auto & kv : DATE_RULES) {
|
||||
_add_rule(kv.first, kv.second);
|
||||
}
|
||||
return schema_format + "-string";
|
||||
} else if (schema.empty() || schema_type == "object") {
|
||||
for (const auto & n : OBJECT_RULE_NAMES) {
|
||||
_add_rule(n, PRIMITIVE_RULES.at(n));
|
||||
}
|
||||
return _add_rule(rule_name, "object");
|
||||
} else {
|
||||
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
|
||||
_errors.push_back("Unrecognized schema: " + schema.dump());
|
||||
return "";
|
||||
}
|
||||
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
||||
return _add_rule(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
|
||||
}
|
||||
}
|
||||
|
||||
void check_errors() {
|
||||
if (!_errors.empty()) {
|
||||
throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
|
||||
}
|
||||
if (!_warnings.empty()) {
|
||||
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
|
||||
}
|
||||
}
|
||||
|
||||
std::string format_grammar() {
|
||||
std::stringstream ss;
|
||||
for (const auto & kv : _rules) {
|
||||
ss << kv.first << " ::= " << kv.second << std::endl;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
std::string json_schema_to_grammar(const json & schema) {
|
||||
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
|
||||
auto copy = schema;
|
||||
converter.resolve_refs(copy, "input");
|
||||
converter.visit(copy, "");
|
||||
converter.check_errors();
|
||||
return converter.format_grammar();
|
||||
}
|
4
common/json-schema-to-grammar.h
Normal file
4
common/json-schema-to-grammar.h
Normal file
|
@ -0,0 +1,4 @@
|
|||
#pragma once
|
||||
#include "json.hpp"
|
||||
|
||||
std::string json_schema_to_grammar(const nlohmann::json& schema);
|
File diff suppressed because it is too large
Load diff
74
examples/json-schema-pydantic-example.py
Normal file
74
examples/json-schema-pydantic-example.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
# Usage:
|
||||
#! ./server -m some-model.gguf &
|
||||
#! pip install pydantic
|
||||
#! python json-schema-pydantic-example.py
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from annotated_types import MinLen
|
||||
from typing import Annotated, List, Optional
|
||||
import json, requests
|
||||
|
||||
if True:
|
||||
|
||||
def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1/chat/completions", messages, **kwargs):
|
||||
'''
|
||||
Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support
|
||||
(llama.cpp server, llama-cpp-python, Anyscale / Together...)
|
||||
|
||||
The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
|
||||
'''
|
||||
if response_model:
|
||||
type_adapter = TypeAdapter(response_model)
|
||||
schema = type_adapter.json_schema()
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": f"You respond in JSON format with the following schema: {json.dumps(schema, indent=2)}"
|
||||
}] + messages
|
||||
response_format={"type": "json_object", "schema": schema}
|
||||
|
||||
data = requests.post(endpoint, headers={"Content-Type": "application/json"},
|
||||
json=dict(messages=messages, response_format=response_format, **kwargs)).json()
|
||||
if 'error' in data:
|
||||
raise Exception(data['error']['message'])
|
||||
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
return type_adapter.validate_json(content) if type_adapter else content
|
||||
|
||||
else:
|
||||
|
||||
# This alternative branch uses Instructor + OpenAI client lib.
|
||||
# Instructor support streamed iterable responses, retry & more.
|
||||
# (see https://python.useinstructor.com/)
|
||||
#! pip install instructor openai
|
||||
import instructor, openai
|
||||
client = instructor.patch(
|
||||
openai.OpenAI(api_key="123", base_url="http://localhost:8080"),
|
||||
mode=instructor.Mode.JSON_SCHEMA)
|
||||
create_completion = client.chat.completions.create
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
class QAPair(BaseModel):
|
||||
question: str
|
||||
concise_answer: str
|
||||
justification: str
|
||||
|
||||
class PyramidalSummary(BaseModel):
|
||||
title: str
|
||||
summary: str
|
||||
question_answers: Annotated[List[QAPair], MinLen(2)]
|
||||
sub_sections: Optional[Annotated[List['PyramidalSummary'], MinLen(2)]]
|
||||
|
||||
print("# Summary\n", create_completion(
|
||||
model="...",
|
||||
response_model=PyramidalSummary,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
You are a highly efficient corporate document summarizer.
|
||||
Create a pyramidal summary of an imaginary internal document about our company processes
|
||||
(starting high-level, going down to each sub sections).
|
||||
Keep questions short, and answers even shorter (trivia / quizz style).
|
||||
"""
|
||||
}]))
|
|
@ -1,8 +1,10 @@
|
|||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Dict, List, Set, Tuple, Union
|
||||
|
||||
# whitespace is constrained to a single space char to prevent model "running away" in
|
||||
# whitespace. Also maybe improves generation quality?
|
||||
|
@ -12,22 +14,50 @@ PRIMITIVE_RULES = {
|
|||
'boolean': '("true" | "false") space',
|
||||
'number': '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
|
||||
'integer': '("-"? ([0-9] | [1-9] [0-9]*)) space',
|
||||
'value' : 'object | array | string | number | boolean',
|
||||
'object' : '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
|
||||
'array' : '"[" space ( value ("," space value)* )? "]" space',
|
||||
'uuid' : '"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space',
|
||||
'string': r''' "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space ''',
|
||||
)* "\"" space''',
|
||||
'null': '"null" space',
|
||||
}
|
||||
OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value']
|
||||
|
||||
# TODO: support "uri", "email" string formats
|
||||
DATE_RULES = {
|
||||
'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )',
|
||||
'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
|
||||
'date-time': 'date "T" time',
|
||||
'date-string': '"\\"" date "\\"" space',
|
||||
'time-string': '"\\"" time "\\"" space',
|
||||
'date-time-string': '"\\"" date-time "\\"" space',
|
||||
}
|
||||
|
||||
RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *DATE_RULES.keys()])
|
||||
|
||||
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
|
||||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
|
||||
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"'}
|
||||
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
|
||||
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
|
||||
|
||||
NON_LITERAL_SET = set('|.()[]{}*+?')
|
||||
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
|
||||
|
||||
DATE_PATTERN = '[0-9]{4}-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])'
|
||||
TIME_PATTERN = '([01][0-9]|2[0-3])(:[0-5][0-9]){2}(\\.[0-9]{1,3})?(Z|[+-](([01][0-9]|2[0-3]):[0-5][0-9]))' # Cap millisecond precision w/ 3 digits
|
||||
|
||||
class SchemaConverter:
|
||||
def __init__(self, prop_order):
|
||||
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
|
||||
self._prop_order = prop_order
|
||||
self._allow_fetch = allow_fetch
|
||||
self._dotall = dotall
|
||||
self._raw_pattern = raw_pattern
|
||||
self._rules = {'space': SPACE_RULE}
|
||||
self._refs = {}
|
||||
self._refs_being_resolved = set()
|
||||
|
||||
def _format_literal(self, literal):
|
||||
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||
|
@ -41,78 +71,421 @@ class SchemaConverter:
|
|||
key = esc_name
|
||||
else:
|
||||
i = 0
|
||||
while f'{esc_name}{i}' in self._rules:
|
||||
while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
|
||||
i += 1
|
||||
key = f'{esc_name}{i}'
|
||||
self._rules[key] = rule
|
||||
return key
|
||||
|
||||
def resolve_refs(self, schema: dict, url: str):
|
||||
'''
|
||||
Resolves all $ref fields in the given schema, fetching any remote schemas,
|
||||
replacing $ref with absolute reference URL and populating self._refs with the
|
||||
respective referenced (sub)schema dictionaries.
|
||||
'''
|
||||
def visit(n: dict):
|
||||
if isinstance(n, list):
|
||||
return [visit(x) for x in n]
|
||||
elif isinstance(n, dict):
|
||||
ref = n.get('$ref')
|
||||
if ref is not None and ref not in self._refs:
|
||||
if ref.startswith('https://'):
|
||||
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
|
||||
import requests
|
||||
|
||||
frag_split = ref.split('#')
|
||||
base_url = frag_split[0]
|
||||
|
||||
target = self._refs.get(base_url)
|
||||
if target is None:
|
||||
target = self.resolve_refs(requests.get(ref).json(), base_url)
|
||||
self._refs[base_url] = target
|
||||
|
||||
if len(frag_split) == 1 or frag_split[-1] == '':
|
||||
return target
|
||||
elif ref.startswith('#/'):
|
||||
target = schema
|
||||
ref = f'{url}{ref}'
|
||||
n['$ref'] = ref
|
||||
else:
|
||||
raise ValueError(f'Unsupported ref {ref}')
|
||||
|
||||
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel]
|
||||
|
||||
self._refs[ref] = target
|
||||
else:
|
||||
for v in n.values():
|
||||
visit(v)
|
||||
|
||||
return n
|
||||
return visit(schema)
|
||||
|
||||
def _generate_union_rule(self, name, alt_schemas):
|
||||
return ' | '.join((
|
||||
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
|
||||
for i, alt_schema in enumerate(alt_schemas)
|
||||
))
|
||||
|
||||
def _visit_pattern(self, pattern, name):
|
||||
'''
|
||||
Transforms a regular expression pattern into a GBNF rule.
|
||||
|
||||
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
|
||||
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
|
||||
|
||||
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
|
||||
|
||||
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
|
||||
we define sub-rules to keep the output lean.
|
||||
'''
|
||||
|
||||
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
|
||||
pattern = pattern[1:-1]
|
||||
sub_rule_ids = {}
|
||||
|
||||
i = 0
|
||||
length = len(pattern)
|
||||
|
||||
def to_rule(s: Tuple[str, bool]) -> str:
|
||||
(txt, is_literal) = s
|
||||
return "\"" + txt + "\"" if is_literal else txt
|
||||
|
||||
def transform() -> Tuple[str, bool]:
|
||||
'''
|
||||
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
|
||||
'''
|
||||
nonlocal i
|
||||
nonlocal pattern
|
||||
nonlocal sub_rule_ids
|
||||
|
||||
start = i
|
||||
# For each component of this sequence, store its string representation and whether it's a literal.
|
||||
# We only need a flat structure here to apply repetition operators to the last item, and
|
||||
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
|
||||
# (GBNF's syntax is luckily very close to regular expressions!)
|
||||
seq: list[Tuple[str, bool]] = []
|
||||
|
||||
def get_dot():
|
||||
if self._dotall:
|
||||
rule = '[\\U00000000-\\U0010FFFF]'
|
||||
else:
|
||||
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
|
||||
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]'
|
||||
return self._add_rule(f'dot', rule)
|
||||
|
||||
def join_seq():
|
||||
nonlocal seq
|
||||
ret = []
|
||||
for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
|
||||
if is_literal:
|
||||
ret.append((''.join(x[0] for x in g), True))
|
||||
else:
|
||||
ret.extend(g)
|
||||
if len(ret) == 1:
|
||||
return ret[0]
|
||||
return (' '.join(to_rule(x) for x in seq), False)
|
||||
|
||||
while i < length:
|
||||
c = pattern[i]
|
||||
if c == '.':
|
||||
seq.append((get_dot(), False))
|
||||
i += 1
|
||||
elif c == '(':
|
||||
i += 1
|
||||
if i < length:
|
||||
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
|
||||
seq.append((f'({to_rule(transform())})', False))
|
||||
elif c == ')':
|
||||
i += 1
|
||||
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
|
||||
return join_seq()
|
||||
elif c == '[':
|
||||
square_brackets = c
|
||||
i += 1
|
||||
while i < length and pattern[i] != ']':
|
||||
if pattern[i] == '\\':
|
||||
square_brackets += pattern[i:i+2]
|
||||
i += 2
|
||||
else:
|
||||
square_brackets += pattern[i]
|
||||
i += 1
|
||||
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
|
||||
square_brackets += ']'
|
||||
i += 1
|
||||
seq.append((square_brackets, False))
|
||||
elif c == '|':
|
||||
seq.append(('|', False))
|
||||
i += 1
|
||||
elif c in ('*', '+', '?'):
|
||||
seq[-1] = (to_rule(seq[-1]) + c, False)
|
||||
i += 1
|
||||
elif c == '{':
|
||||
curly_brackets = c
|
||||
i += 1
|
||||
while i < length and pattern[i] != '}':
|
||||
curly_brackets += pattern[i]
|
||||
i += 1
|
||||
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
|
||||
curly_brackets += '}'
|
||||
i += 1
|
||||
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
|
||||
min_times = 0
|
||||
max_times = None
|
||||
try:
|
||||
if len(nums) == 1:
|
||||
min_times = int(nums[0])
|
||||
max_times = min_times
|
||||
else:
|
||||
assert len(nums) == 2
|
||||
min_times = int(nums[0]) if nums[0] else 0
|
||||
max_times = int(nums[1]) if nums[1] else None
|
||||
except ValueError:
|
||||
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
|
||||
|
||||
(sub, sub_is_literal) = seq[-1]
|
||||
|
||||
if min_times == 0 and max_times is None:
|
||||
seq[-1] = (f'{sub}*', False)
|
||||
elif min_times == 0 and max_times == 1:
|
||||
seq[-1] = (f'{sub}?', False)
|
||||
elif min_times == 1 and max_times is None:
|
||||
seq[-1] = (f'{sub}+', False)
|
||||
else:
|
||||
if not sub_is_literal:
|
||||
id = sub_rule_ids.get(sub)
|
||||
if id is None:
|
||||
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
|
||||
sub_rule_ids[sub] = id
|
||||
sub = id
|
||||
|
||||
seq[-1] = (
|
||||
' '.join(
|
||||
([f'"{sub[1:-1] * min_times}"'] if sub_is_literal else [sub] * min_times) +
|
||||
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])),
|
||||
False
|
||||
)
|
||||
else:
|
||||
literal = ''
|
||||
while i < length:
|
||||
if pattern[i] == '\\' and i < length - 1:
|
||||
next = pattern[i + 1]
|
||||
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
|
||||
i += 1
|
||||
literal += pattern[i]
|
||||
i += 1
|
||||
else:
|
||||
literal += pattern[i:i+2]
|
||||
i += 2
|
||||
elif pattern[i] == '"' and not self._raw_pattern:
|
||||
literal += '\\"'
|
||||
i += 1
|
||||
elif pattern[i] not in NON_LITERAL_SET and \
|
||||
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
|
||||
literal += pattern[i]
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
if literal:
|
||||
seq.append((literal, True))
|
||||
|
||||
return join_seq()
|
||||
|
||||
return self._add_rule(
|
||||
name,
|
||||
to_rule(transform()) if self._raw_pattern \
|
||||
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
|
||||
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
ref_name = ref.split('/')[-1]
|
||||
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
||||
self._refs_being_resolved.add(ref)
|
||||
resolved = self._refs[ref]
|
||||
ref_name = self.visit(resolved, ref_name)
|
||||
self._refs_being_resolved.remove(ref)
|
||||
return ref_name
|
||||
|
||||
def _generate_constant_rule(self, value):
|
||||
assert isinstance(value, str), f'Only string constants are supported, got {value}'
|
||||
return self._format_literal(value)
|
||||
|
||||
def visit(self, schema, name):
|
||||
schema_type = schema.get('type')
|
||||
rule_name = name or 'root'
|
||||
schema_format = schema.get('format')
|
||||
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
|
||||
|
||||
if 'oneOf' in schema or 'anyOf' in schema:
|
||||
rule = ' | '.join((
|
||||
self.visit(alt_schema, f'{name}{"-" if name else ""}{i}')
|
||||
for i, alt_schema in enumerate(schema.get('oneOf') or schema['anyOf'])
|
||||
))
|
||||
return self._add_rule(rule_name, rule)
|
||||
if (ref := schema.get('$ref')) is not None:
|
||||
return self._add_rule(rule_name, self._resolve_ref(ref))
|
||||
|
||||
elif 'oneOf' in schema or 'anyOf' in schema:
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
|
||||
|
||||
elif isinstance(schema_type, list):
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))
|
||||
|
||||
elif 'const' in schema:
|
||||
return self._add_rule(rule_name, self._format_literal(schema['const']))
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
|
||||
|
||||
elif 'enum' in schema:
|
||||
rule = ' | '.join((self._format_literal(v) for v in schema['enum']))
|
||||
rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type == 'object' and 'properties' in schema:
|
||||
# TODO: `required` keyword
|
||||
prop_order = self._prop_order
|
||||
prop_pairs = sorted(
|
||||
schema['properties'].items(),
|
||||
# sort by position in prop_order (if specified) then by key
|
||||
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
|
||||
)
|
||||
elif schema_type in (None, 'object') and \
|
||||
('properties' in schema or \
|
||||
('additionalProperties' in schema and schema['additionalProperties'] is not True)):
|
||||
required = set(schema.get('required', []))
|
||||
properties = list(schema.get('properties', {}).items())
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
|
||||
|
||||
rule = '"{" space'
|
||||
for i, (prop_name, prop_schema) in enumerate(prop_pairs):
|
||||
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
|
||||
if i > 0:
|
||||
rule += ' "," space'
|
||||
rule += fr' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
|
||||
rule += ' "}" space'
|
||||
elif schema_type in (None, 'object') and 'allOf' in schema:
|
||||
required = set()
|
||||
properties = []
|
||||
hybrid_name = name
|
||||
def add_component(comp_schema, is_required):
|
||||
if (ref := comp_schema.get('$ref')) is not None:
|
||||
comp_schema = self._refs[ref]
|
||||
|
||||
return self._add_rule(rule_name, rule)
|
||||
if 'properties' in comp_schema:
|
||||
for prop_name, prop_schema in comp_schema['properties'].items():
|
||||
properties.append((prop_name, prop_schema))
|
||||
if is_required:
|
||||
required.add(prop_name)
|
||||
|
||||
elif schema_type == 'array' and 'items' in schema:
|
||||
# TODO `prefixItems` keyword
|
||||
item_rule_name = self.visit(schema['items'], f'{name}{"-" if name else ""}item')
|
||||
list_item_operator = f'("," space {item_rule_name})'
|
||||
for t in schema['allOf']:
|
||||
if 'anyOf' in t:
|
||||
for tt in t['anyOf']:
|
||||
add_component(tt, is_required=False)
|
||||
else:
|
||||
add_component(t, is_required=True)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
items = schema.get('items') or schema['prefixItems']
|
||||
if isinstance(items, list):
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
'"[" space ' +
|
||||
' "," space '.join(
|
||||
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
|
||||
for i, item in enumerate(items)) +
|
||||
' "]" space')
|
||||
else:
|
||||
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
|
||||
list_item_operator = f'( "," space {item_rule_name} )'
|
||||
successive_items = ""
|
||||
min_items = schema.get("minItems", 0)
|
||||
max_items = schema.get("maxItems")
|
||||
if min_items > 0:
|
||||
first_item = f"({item_rule_name})"
|
||||
successive_items = list_item_operator * (min_items - 1)
|
||||
min_items -= 1
|
||||
else:
|
||||
first_item = f"({item_rule_name})?"
|
||||
max_items = schema.get("maxItems")
|
||||
if max_items is not None and max_items > min_items:
|
||||
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
|
||||
else:
|
||||
successive_items += list_item_operator + "*"
|
||||
rule = f'"[" space {first_item} {successive_items} "]" space'
|
||||
if min_items == 0:
|
||||
rule = f'"[" space ( {item_rule_name} {successive_items} )? "]" space'
|
||||
else:
|
||||
rule = f'"[" space {item_rule_name} {successive_items} "]" space'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type in (None, 'string') and 'pattern' in schema:
|
||||
return self._visit_pattern(schema['pattern'], rule_name)
|
||||
|
||||
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
|
||||
return self._add_rule(
|
||||
'root' if rule_name == 'root' else schema_format,
|
||||
PRIMITIVE_RULES['uuid']
|
||||
)
|
||||
|
||||
elif schema_type in (None, 'string') and schema_format in DATE_RULES:
|
||||
for t, r in DATE_RULES.items():
|
||||
self._add_rule(t, r)
|
||||
return schema_format + '-string'
|
||||
|
||||
elif (schema_type == 'object') or (len(schema) == 0):
|
||||
for n in OBJECT_RULE_NAMES:
|
||||
self._add_rule(n, PRIMITIVE_RULES[n])
|
||||
return self._add_rule(rule_name, 'object')
|
||||
|
||||
else:
|
||||
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
|
||||
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
||||
return self._add_rule(
|
||||
'root' if rule_name == 'root' else schema_type,
|
||||
PRIMITIVE_RULES[schema_type]
|
||||
)
|
||||
|
||||
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
|
||||
prop_order = self._prop_order
|
||||
# sort by position in prop_order (if specified) then by original order
|
||||
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
|
||||
|
||||
prop_kv_rule_names = {}
|
||||
for prop_name, prop_schema in properties:
|
||||
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
|
||||
prop_kv_rule_names[prop_name] = self._add_rule(
|
||||
f'{name}{"-" if name else ""}{prop_name}-kv',
|
||||
fr'{self._format_literal(prop_name)} space ":" space {prop_rule_name}'
|
||||
)
|
||||
required_props = [k for k in sorted_props if k in required]
|
||||
optional_props = [k for k in sorted_props if k not in required]
|
||||
|
||||
if additional_properties == True or isinstance(additional_properties, dict):
|
||||
sub_name = f'{name}{"-" if name else ""}additional'
|
||||
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
|
||||
prop_kv_rule_names["*"] = self._add_rule(
|
||||
f'{sub_name}-kv',
|
||||
self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
|
||||
)
|
||||
optional_props.append("*")
|
||||
|
||||
rule = '"{" space '
|
||||
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
|
||||
|
||||
if optional_props:
|
||||
rule += ' ('
|
||||
if required_props:
|
||||
rule += ' "," space ( '
|
||||
|
||||
def get_recursive_refs(ks, first_is_optional):
|
||||
[k, *rest] = ks
|
||||
kv_rule_name = prop_kv_rule_names[k]
|
||||
if k == '*':
|
||||
res = self._add_rule(
|
||||
f'{name}{"-" if name else ""}additional-kvs',
|
||||
f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
|
||||
)
|
||||
elif first_is_optional:
|
||||
res = f'( "," space {kv_rule_name} )?'
|
||||
else:
|
||||
res = kv_rule_name
|
||||
if len(rest) > 0:
|
||||
res += ' ' + self._add_rule(
|
||||
f'{name}{"-" if name else ""}{k}-rest',
|
||||
get_recursive_refs(rest, first_is_optional=True)
|
||||
)
|
||||
return res
|
||||
|
||||
rule += ' | '.join(
|
||||
get_recursive_refs(optional_props[i:], first_is_optional=False)
|
||||
for i in range(len(optional_props))
|
||||
)
|
||||
if required_props:
|
||||
rule += ' )'
|
||||
rule += ' )?'
|
||||
|
||||
rule += ' "}" space'
|
||||
|
||||
return rule
|
||||
|
||||
def format_grammar(self):
|
||||
return '\n'.join((f'{name} ::= {rule}' for name, rule in self._rules.items()))
|
||||
return '\n'.join(
|
||||
f'{name} ::= {rule}'
|
||||
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
|
||||
)
|
||||
|
||||
|
||||
def main(args_in = None):
|
||||
|
@ -129,16 +502,47 @@ def main(args_in = None):
|
|||
type=lambda s: s.split(','),
|
||||
help='''
|
||||
comma-separated property names defining the order of precedence for object properties;
|
||||
properties not specified here are given lower precedence than those that are, and are
|
||||
sorted alphabetically
|
||||
properties not specified here are given lower precedence than those that are, and
|
||||
are kept in their original order from the schema. Required properties are always
|
||||
given precedence over optional properties.
|
||||
'''
|
||||
)
|
||||
parser.add_argument(
|
||||
'--allow-fetch',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Whether to allow fetching referenced schemas over HTTPS')
|
||||
parser.add_argument(
|
||||
'--dotall',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
|
||||
parser.add_argument(
|
||||
'--raw-pattern',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')
|
||||
|
||||
parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
|
||||
args = parser.parse_args(args_in)
|
||||
|
||||
schema = json.load(sys.stdin if args.schema == '-' else open(args.schema))
|
||||
prop_order = {name: idx for idx, name in enumerate(args.prop_order)}
|
||||
converter = SchemaConverter(prop_order)
|
||||
if args.schema.startswith('https://'):
|
||||
url = args.schema
|
||||
import requests
|
||||
schema = requests.get(url).json()
|
||||
elif args.schema == '-':
|
||||
url = 'stdin'
|
||||
schema = json.load(sys.stdin)
|
||||
else:
|
||||
url = f'file://{args.schema}'
|
||||
with open(args.schema) as f:
|
||||
schema = json.load(f)
|
||||
converter = SchemaConverter(
|
||||
prop_order={name: idx for idx, name in enumerate(args.prop_order)},
|
||||
allow_fetch=args.allow_fetch,
|
||||
dotall=args.dotall,
|
||||
raw_pattern=args.raw_pattern)
|
||||
schema = converter.resolve_refs(schema, url)
|
||||
converter.visit(schema, '')
|
||||
print(converter.format_grammar())
|
||||
|
||||
|
|
|
@ -249,6 +249,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
|
|||
if (s == "q5_1") {
|
||||
return GGML_TYPE_Q5_1;
|
||||
}
|
||||
if (s == "iq4_nl") {
|
||||
return GGML_TYPE_IQ4_NL;
|
||||
}
|
||||
|
||||
return GGML_TYPE_COUNT;
|
||||
}
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
# MobileVLM
|
||||
|
||||
Currently this implementation supports [MobileVLM-v1.7](https://huggingface.co/mtgv/MobileVLM-1.7B) variants.
|
||||
Currently this implementation supports [MobileVLM-1.7B](https://huggingface.co/mtgv/MobileVLM-1.7B) / [MobileVLM_V2-1.7B](https://huggingface.co/mtgv/MobileVLM_V2-1.7B) variants.
|
||||
|
||||
for more information, please go to [Meituan-AutoML/MobileVLM](https://github.com/Meituan-AutoML/MobileVLM)
|
||||
|
||||
The implementation is based on llava, and is compatible with llava and mobileVLM. The usage is basically same as llava.
|
||||
|
||||
Notice: The overall process of model inference for both **MobileVLM** and **MobileVLM_V2** models is the same, but the process of model conversion is a little different. Therefore, using MobiVLM as an example, the different conversion step will be shown.
|
||||
|
||||
## Usage
|
||||
Build with cmake or run `make llava-cli` to build it.
|
||||
|
||||
|
@ -34,7 +36,7 @@ git clone https://huggingface.co/openai/clip-vit-large-patch14-336
|
|||
python ./examples/llava/llava-surgery.py -m path/to/MobileVLM-1.7B
|
||||
```
|
||||
|
||||
3. Use `convert-image-encoder-to-gguf.py` with `--projector-type ldp` to convert the LLaVA image encoder to GGUF:
|
||||
3. Use `convert-image-encoder-to-gguf.py` with `--projector-type ldp` (for **V2** the arg is `--projector-type ldpv2`) to convert the LLaVA image encoder to GGUF:
|
||||
|
||||
```sh
|
||||
python ./examples/llava/convert-image-encoder-to-gguf \
|
||||
|
@ -44,6 +46,14 @@ python ./examples/llava/convert-image-encoder-to-gguf \
|
|||
--projector-type ldp
|
||||
```
|
||||
|
||||
```sh
|
||||
python ./examples/llava/convert-image-encoder-to-gguf \
|
||||
-m path/to/clip-vit-large-patch14-336 \
|
||||
--llava-projector path/to/MobileVLM-1.7B_V2/llava.projector \
|
||||
--output-dir path/to/MobileVLM-1.7B_V2 \
|
||||
--projector-type ldpv2
|
||||
```
|
||||
|
||||
4. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF:
|
||||
|
||||
```sh
|
||||
|
|
|
@ -119,6 +119,7 @@ static std::string format(const char * fmt, ...) {
|
|||
#define TN_LLAVA_PROJ "mm.%d.%s"
|
||||
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
|
||||
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
|
||||
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
|
||||
#define TN_IMAGE_NEWLINE "model.image_newline"
|
||||
|
||||
|
||||
|
@ -126,12 +127,14 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_MLP,
|
||||
PROJECTOR_TYPE_MLP_NORM,
|
||||
PROJECTOR_TYPE_LDP,
|
||||
PROJECTOR_TYPE_LDPV2,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_MLP, "mlp" },
|
||||
{ PROJECTOR_TYPE_LDP, "ldp" },
|
||||
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
|
||||
};
|
||||
|
||||
|
||||
|
@ -475,6 +478,14 @@ struct clip_vision_model {
|
|||
struct ggml_tensor * mm_model_block_2_block_2_0_w;
|
||||
struct ggml_tensor * mm_model_block_2_block_2_1_w;
|
||||
struct ggml_tensor * mm_model_block_2_block_2_1_b;
|
||||
|
||||
// MobileVLM_V2 projection
|
||||
struct ggml_tensor * mm_model_mlp_0_w;
|
||||
struct ggml_tensor * mm_model_mlp_0_b;
|
||||
struct ggml_tensor * mm_model_mlp_2_w;
|
||||
struct ggml_tensor * mm_model_mlp_2_b;
|
||||
struct ggml_tensor * mm_model_peg_0_w;
|
||||
struct ggml_tensor * mm_model_peg_0_b;
|
||||
};
|
||||
|
||||
struct clip_ctx {
|
||||
|
@ -807,6 +818,29 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
}
|
||||
embeddings = block_1;
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2)
|
||||
{
|
||||
int n_patch = 24;
|
||||
struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
|
||||
mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
|
||||
mlp_0 = ggml_gelu(ctx0, mlp_0);
|
||||
struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
|
||||
mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
|
||||
// mlp_2 ne = [2048, 576, 1, 1]
|
||||
// // AVG Pool Layer 2*2, strides = 2
|
||||
mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
|
||||
// mlp_2 ne = [576, 2048, 1, 1]
|
||||
mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
|
||||
// mlp_2 ne [24, 24, 2048, 1]
|
||||
mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
|
||||
// weight ne = [3, 3, 2048, 1]
|
||||
struct ggml_tensor * peg_0 = ggml_conv_depthwise_2d(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
|
||||
peg_0 = ggml_add(ctx0, peg_0, mlp_2);
|
||||
peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
|
||||
peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
|
||||
peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
|
||||
embeddings = peg_0;
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
@ -1177,7 +1211,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
vision_model.mm_model_block_2_block_2_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
|
||||
vision_model.mm_model_block_2_block_2_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
|
||||
vision_model.mm_model_block_2_block_2_1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
|
||||
} else {
|
||||
}
|
||||
else if (new_clip->proj_type == PROJECTOR_TYPE_LDPV2)
|
||||
{
|
||||
// MobilVLM_V2 projection
|
||||
vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "weight"));
|
||||
vision_model.mm_model_mlp_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "bias"));
|
||||
vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "weight"));
|
||||
vision_model.mm_model_mlp_2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "bias"));
|
||||
vision_model.mm_model_peg_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "weight"));
|
||||
vision_model.mm_model_peg_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "bias"));
|
||||
}
|
||||
else {
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
}
|
||||
|
@ -1966,6 +2011,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
|
||||
return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
|
||||
return ctx->vision_model.mm_model_peg_0_b->ne[0];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
|
||||
return ctx->vision_model.mm_2_b->ne[0];
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import argparse
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
@ -38,9 +39,11 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b
|
|||
def get_tensor_name(name: str) -> str:
|
||||
if "projection" in name:
|
||||
return name
|
||||
|
||||
if "mm_projector" in name:
|
||||
return name.replace("model.mm_projector", "mm")
|
||||
name = name.replace("model.mm_projector", "mm")
|
||||
name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1)
|
||||
name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1)
|
||||
return name
|
||||
|
||||
return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
|
||||
|
||||
|
@ -83,7 +86,7 @@ ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
|
|||
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
|
||||
help="The clip model is from openclip (for ViT-SO400M type))")
|
||||
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
|
||||
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp", choices=["mlp", "ldp"], default="mlp")
|
||||
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
|
||||
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
|
||||
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
|
||||
# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
|
||||
|
|
20
examples/regex-to-grammar.py
Normal file
20
examples/regex-to-grammar.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
import json, subprocess, sys, os
|
||||
|
||||
assert len(sys.argv) >= 2
|
||||
[_, pattern, *rest] = sys.argv
|
||||
|
||||
print(subprocess.check_output(
|
||||
[
|
||||
"python",
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.realpath(__file__)),
|
||||
"json-schema-to-grammar.py"),
|
||||
*rest,
|
||||
"-",
|
||||
"--raw-pattern",
|
||||
],
|
||||
text=True,
|
||||
input=json.dumps({
|
||||
"type": "string",
|
||||
"pattern": pattern,
|
||||
}, indent=2)))
|
|
@ -2,12 +2,16 @@ set(TARGET server)
|
|||
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
|
||||
option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
|
||||
add_executable(${TARGET}
|
||||
server.cpp
|
||||
utils.hpp
|
||||
httplib.h
|
||||
)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_compile_definitions(${TARGET} PRIVATE
|
||||
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
|
||||
)
|
||||
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PRIVATE common json-schema-to-grammar ${CMAKE_THREAD_LIBS_INIT})
|
||||
if (LLAMA_SERVER_SSL)
|
||||
find_package(OpenSSL REQUIRED)
|
||||
target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)
|
||||
|
|
|
@ -26,8 +26,9 @@ const propOrder = grammarJsonSchemaPropOrder
|
|||
|
||||
let grammar = null
|
||||
if (grammarJsonSchemaFile) {
|
||||
const schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8'))
|
||||
const converter = new SchemaConverter(propOrder)
|
||||
let schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8'))
|
||||
const converter = new SchemaConverter({prop_order: propOrder, allow_fetch: true})
|
||||
schema = await converter.resolveRefs(schema, grammarJsonSchemaFile)
|
||||
converter.visit(schema, '')
|
||||
grammar = converter.formatGrammar()
|
||||
}
|
||||
|
|
|
@ -483,4 +483,4 @@ unsigned char completion_js[] = {
|
|||
0x20, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f,
|
||||
0x73, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x3b, 0x0a, 0x7d, 0x0a
|
||||
};
|
||||
unsigned int completion_js_len = 5796;
|
||||
size_t completion_js_len = 5796;
|
||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -630,14 +630,16 @@
|
|||
|
||||
const grammarJsonSchemaPropOrder = signal('')
|
||||
const updateGrammarJsonSchemaPropOrder = (el) => grammarJsonSchemaPropOrder.value = el.target.value
|
||||
const convertJSONSchemaGrammar = () => {
|
||||
const convertJSONSchemaGrammar = async () => {
|
||||
try {
|
||||
const schema = JSON.parse(params.value.grammar)
|
||||
const converter = new SchemaConverter(
|
||||
grammarJsonSchemaPropOrder.value
|
||||
let schema = JSON.parse(params.value.grammar)
|
||||
const converter = new SchemaConverter({
|
||||
prop_order: grammarJsonSchemaPropOrder.value
|
||||
.split(',')
|
||||
.reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {})
|
||||
)
|
||||
.reduce((acc, cur, i) => ({ ...acc, [cur.trim()]: i }), {}),
|
||||
allow_fetch: true,
|
||||
})
|
||||
schema = await converter.resolveRefs(schema, 'input')
|
||||
converter.visit(schema, '')
|
||||
params.value = {
|
||||
...params.value,
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -1,25 +1,50 @@
|
|||
// WARNING: This file was ported from json-schema-to-grammar.py, please fix bugs / add features there first.
|
||||
const SPACE_RULE = '" "?';
|
||||
|
||||
const PRIMITIVE_RULES = {
|
||||
boolean: '("true" | "false") space',
|
||||
number: '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
|
||||
integer: '("-"? ([0-9] | [1-9] [0-9]*)) space',
|
||||
value: 'object | array | string | number | boolean',
|
||||
object: '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
|
||||
array: '"[" space ( value ("," space value)* )? "]" space',
|
||||
uuid: '"\\"" ' + [8, 4, 4, 4, 12].map(n => [...new Array(n)].map(_ => '[0-9a-fA-F]').join('')).join(' "-" ') + ' "\\"" space',
|
||||
string: ` "\\"" (
|
||||
[^"\\\\] |
|
||||
"\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\\"" space`,
|
||||
null: '"null" space',
|
||||
};
|
||||
const OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value'];
|
||||
|
||||
// TODO: support "uri", "email" string formats
|
||||
const DATE_RULES = {
|
||||
'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )',
|
||||
'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
|
||||
'date-time': 'date "T" time',
|
||||
'date-string': '"\\"" date "\\"" space',
|
||||
'time-string': '"\\"" time "\\"" space',
|
||||
'date-time-string': '"\\"" date-time "\\"" space',
|
||||
};
|
||||
|
||||
const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...DATE_RULES};
|
||||
|
||||
const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g;
|
||||
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g;
|
||||
const GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"'};
|
||||
const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g;
|
||||
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' };
|
||||
|
||||
const NON_LITERAL_SET = new Set('|.()[]{}*+?');
|
||||
const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('[]()|{}*+?');
|
||||
|
||||
export class SchemaConverter {
|
||||
constructor(propOrder) {
|
||||
this._propOrder = propOrder || {};
|
||||
this._rules = new Map();
|
||||
this._rules.set('space', SPACE_RULE);
|
||||
constructor(options) {
|
||||
this._propOrder = options.prop_order || {};
|
||||
this._allowFetch = options.allow_fetch || false;
|
||||
this._dotall = options.dotall || false;
|
||||
this._rules = {'space': SPACE_RULE};
|
||||
this._refs = {};
|
||||
this._refsBeingResolved = new Set();
|
||||
}
|
||||
|
||||
_formatLiteral(literal) {
|
||||
|
@ -30,83 +55,490 @@ export class SchemaConverter {
|
|||
return `"${escaped}"`;
|
||||
}
|
||||
|
||||
_formatRangeChar(literal) {
|
||||
return JSON.stringify(literal).slice(1, -1).replace(
|
||||
GRAMMAR_RANGE_LITERAL_ESCAPE_RE,
|
||||
m => GRAMMAR_LITERAL_ESCAPES[m]
|
||||
);
|
||||
}
|
||||
|
||||
_addRule(name, rule) {
|
||||
let escName = name.replace(INVALID_RULE_CHARS_RE, '-');
|
||||
let key = escName;
|
||||
|
||||
if (this._rules.has(escName)) {
|
||||
if (this._rules.get(escName) === rule) {
|
||||
if (escName in this._rules) {
|
||||
if (this._rules[escName] === rule) {
|
||||
return key;
|
||||
}
|
||||
|
||||
let i = 0;
|
||||
while (this._rules.has(`${escName}${i}`)) {
|
||||
while ((`${escName}${i}` in this._rules) && (this._rules[`${escName}${i}`] !== rule)) {
|
||||
i += 1;
|
||||
}
|
||||
key = `${escName}${i}`;
|
||||
}
|
||||
|
||||
this._rules.set(key, rule);
|
||||
this._rules[key] = rule;
|
||||
return key;
|
||||
}
|
||||
|
||||
async resolveRefs(schema, url) {
|
||||
const visit = async (n) => {
|
||||
if (Array.isArray(n)) {
|
||||
return Promise.all(n.map(visit));
|
||||
} else if (typeof n === 'object' && n !== null) {
|
||||
let ref = n.$ref;
|
||||
let target;
|
||||
if (ref !== undefined && !this._refs[ref]) {
|
||||
if (ref.startsWith('https://')) {
|
||||
if (!this._allowFetch) {
|
||||
throw new Error('Fetching remote schemas is not allowed (use --allow-fetch for force)');
|
||||
}
|
||||
const fetch = (await import('node-fetch')).default;
|
||||
|
||||
const fragSplit = ref.split('#');
|
||||
const baseUrl = fragSplit[0];
|
||||
|
||||
target = this._refs[baseUrl];
|
||||
if (!target) {
|
||||
target = await this.resolveRefs(await fetch(ref).then(res => res.json()), baseUrl);
|
||||
this._refs[baseUrl] = target;
|
||||
}
|
||||
|
||||
if (fragSplit.length === 1 || fragSplit[fragSplit.length - 1] === '') {
|
||||
return target;
|
||||
}
|
||||
} else if (ref.startsWith('#/')) {
|
||||
target = schema;
|
||||
ref = `${url}${ref}`;
|
||||
n.$ref = ref;
|
||||
} else {
|
||||
throw new Error(`Unsupported ref ${ref}`);
|
||||
}
|
||||
|
||||
const selectors = ref.split('#')[1].split('/').slice(1);
|
||||
for (const sel of selectors) {
|
||||
if (!target || !(sel in target)) {
|
||||
throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`);
|
||||
}
|
||||
target = target[sel];
|
||||
}
|
||||
|
||||
this._refs[ref] = target;
|
||||
} else {
|
||||
await Promise.all(Object.values(n).map(visit));
|
||||
}
|
||||
}
|
||||
|
||||
return n;
|
||||
};
|
||||
|
||||
return visit(schema);
|
||||
}
|
||||
|
||||
_generateUnionRule(name, altSchemas) {
|
||||
return altSchemas
|
||||
.map((altSchema, i) => this.visit(altSchema, `${name ?? ''}${name ? '-' : 'alternative-'}${i}`))
|
||||
.join(' | ');
|
||||
}
|
||||
|
||||
_visitPattern(pattern, name) {
|
||||
if (!pattern.startsWith('^') || !pattern.endsWith('$')) {
|
||||
throw new Error('Pattern must start with "^" and end with "$"');
|
||||
}
|
||||
pattern = pattern.slice(1, -1);
|
||||
const subRuleIds = {};
|
||||
|
||||
let i = 0;
|
||||
const length = pattern.length;
|
||||
|
||||
const getDot = () => {
|
||||
let rule;
|
||||
if (this._dotall) {
|
||||
rule = '[\\U00000000-\\U0010FFFF]';
|
||||
} else {
|
||||
// Accept any character... except \n and \r line break chars (\x0A and \xOD)
|
||||
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]';
|
||||
}
|
||||
return this._addRule('dot', rule);
|
||||
};
|
||||
|
||||
|
||||
const toRule = ([s, isLiteral]) => isLiteral ? "\"" + s + "\"" : s;
|
||||
|
||||
const transform = () => {
|
||||
const start = i;
|
||||
// For each component of this sequence, store its string representation and whether it's a literal.
|
||||
// We only need a flat structure here to apply repetition operators to the last item, and
|
||||
// to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
|
||||
// (GBNF's syntax is luckily very close to regular expressions!)
|
||||
const seq = [];
|
||||
|
||||
const joinSeq = () => {
|
||||
const ret = [];
|
||||
for (const [isLiteral, g] of groupBy(seq, x => x[1])) {
|
||||
if (isLiteral) {
|
||||
ret.push([[...g].map(x => x[0]).join(''), true]);
|
||||
} else {
|
||||
ret.push(...g);
|
||||
}
|
||||
}
|
||||
if (ret.length === 1) {
|
||||
return ret[0];
|
||||
}
|
||||
return [ret.map(x => toRule(x)).join(' '), false];
|
||||
};
|
||||
|
||||
while (i < length) {
|
||||
const c = pattern[i];
|
||||
if (c === '.') {
|
||||
seq.push([getDot(), false]);
|
||||
i += 1;
|
||||
} else if (c === '(') {
|
||||
i += 1;
|
||||
if (i < length) {
|
||||
if (pattern[i] === '?') {
|
||||
throw new Error(`Unsupported pattern syntax "${pattern[i]}" at index ${i} of /${pattern}/`);
|
||||
}
|
||||
}
|
||||
seq.push([`(${toRule(transform())})`, false]);
|
||||
} else if (c === ')') {
|
||||
i += 1;
|
||||
if (start <= 0 || pattern[start - 1] !== '(') {
|
||||
throw new Error(`Unbalanced parentheses; start = ${start}, i = ${i}, pattern = ${pattern}`);
|
||||
}
|
||||
return joinSeq();
|
||||
} else if (c === '[') {
|
||||
let squareBrackets = c;
|
||||
i += 1;
|
||||
while (i < length && pattern[i] !== ']') {
|
||||
if (pattern[i] === '\\') {
|
||||
squareBrackets += pattern.slice(i, i + 2);
|
||||
i += 2;
|
||||
} else {
|
||||
squareBrackets += pattern[i];
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
if (i >= length) {
|
||||
throw new Error(`Unbalanced square brackets; start = ${start}, i = ${i}, pattern = ${pattern}`);
|
||||
}
|
||||
squareBrackets += ']';
|
||||
i += 1;
|
||||
seq.push([squareBrackets, false]);
|
||||
} else if (c === '|') {
|
||||
seq.push(['|', false]);
|
||||
i += 1;
|
||||
} else if (c === '*' || c === '+' || c === '?') {
|
||||
seq[seq.length - 1] = [toRule(seq[seq.length - 1]) + c, false];
|
||||
i += 1;
|
||||
} else if (c === '{') {
|
||||
let curlyBrackets = c;
|
||||
i += 1;
|
||||
while (i < length && pattern[i] !== '}') {
|
||||
curlyBrackets += pattern[i];
|
||||
i += 1;
|
||||
}
|
||||
if (i >= length) {
|
||||
throw new Error(`Unbalanced curly brackets; start = ${start}, i = ${i}, pattern = ${pattern}`);
|
||||
}
|
||||
curlyBrackets += '}';
|
||||
i += 1;
|
||||
const nums = curlyBrackets.slice(1, -1).split(',').map(s => s.trim());
|
||||
let minTimes, maxTimes;
|
||||
if (nums.length === 1) {
|
||||
minTimes = parseInt(nums[0], 10);
|
||||
maxTimes = minTimes;
|
||||
} else {
|
||||
if (nums.length !== 2) {
|
||||
throw new Error(`Invalid quantifier ${curlyBrackets}`);
|
||||
}
|
||||
minTimes = nums[0] ? parseInt(nums[0], 10) : 0;
|
||||
maxTimes = nums[1] ? parseInt(nums[1], 10) : Infinity;
|
||||
}
|
||||
|
||||
let [sub, subIsLiteral] = seq[seq.length - 1];
|
||||
|
||||
if (minTimes === 0 && maxTimes === Infinity) {
|
||||
seq[seq.length - 1] = [`${sub}*`, false];
|
||||
} else if (minTimes === 0 && maxTimes === 1) {
|
||||
seq[seq.length - 1] = [`${sub}?`, false];
|
||||
} else if (minTimes === 1 && maxTimes === Infinity) {
|
||||
seq[seq.length - 1] = [`${sub}+`, false];
|
||||
} else {
|
||||
if (!subIsLiteral) {
|
||||
let id = subRuleIds[sub];
|
||||
if (id === undefined) {
|
||||
id = this._addRule(`${name}-${Object.keys(subRuleIds).length + 1}`, sub);
|
||||
subRuleIds[sub] = id;
|
||||
}
|
||||
sub = id;
|
||||
}
|
||||
|
||||
const repeatedSub = Array.from({ length: minTimes }, () => subIsLiteral ? `"${sub.slice(1, -1).repeat(minTimes)}"` : sub);
|
||||
const optionalSub = maxTimes !== undefined ? Array.from({ length: maxTimes - minTimes }, () => `${sub}?`) : [`${sub}*`];
|
||||
seq[seq.length - 1] = [repeatedSub.concat(optionalSub).join(' '), false];
|
||||
}
|
||||
} else {
|
||||
let literal = '';
|
||||
while (i < length) {
|
||||
if (pattern[i] === '\\' && i < length - 1) {
|
||||
const next = pattern[i + 1];
|
||||
if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.has(next)) {
|
||||
i += 1;
|
||||
literal += pattern[i];
|
||||
i += 1;
|
||||
} else {
|
||||
literal += pattern.slice(i, i + 2);
|
||||
i += 2;
|
||||
}
|
||||
} else if (pattern[i] === '"') {
|
||||
literal += '\\"';
|
||||
i += 1;
|
||||
} else if (!NON_LITERAL_SET.has(pattern[i]) &&
|
||||
(i === length - 1 || literal === '' || pattern[i + 1] === '.' || !NON_LITERAL_SET.has(pattern[i+1]))) {
|
||||
literal += pattern[i];
|
||||
i += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (literal !== '') {
|
||||
seq.push([literal, true]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return joinSeq();
|
||||
};
|
||||
|
||||
return this._addRule(name, "\"\\\"\" " + toRule(transform()) + " \"\\\"\" space")
|
||||
}
|
||||
|
||||
_resolveRef(ref) {
|
||||
let refName = ref.split('/').pop();
|
||||
if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) {
|
||||
this._refsBeingResolved.add(ref);
|
||||
const resolved = this._refs[ref];
|
||||
refName = this.visit(resolved, refName);
|
||||
this._refsBeingResolved.delete(ref);
|
||||
}
|
||||
return refName;
|
||||
}
|
||||
|
||||
_generateConstantRule(value) {
|
||||
if (typeof value !== 'string') {
|
||||
throw new Error('Only string constants are supported, got ' + JSON.stringify(value));
|
||||
}
|
||||
return this._formatLiteral(value);
|
||||
}
|
||||
|
||||
visit(schema, name) {
|
||||
const schemaType = schema.type;
|
||||
const ruleName = name || 'root';
|
||||
const schemaFormat = schema.format;
|
||||
const ruleName = name in RESERVED_NAMES ? name + '-' : name == '' ? 'root' : name;
|
||||
|
||||
if (schema.oneOf || schema.anyOf) {
|
||||
const rule = (schema.oneOf || schema.anyOf).map((altSchema, i) =>
|
||||
this.visit(altSchema, `${name}${name ? "-" : ""}${i}`)
|
||||
).join(' | ');
|
||||
|
||||
return this._addRule(ruleName, rule);
|
||||
const ref = schema.$ref;
|
||||
if (ref !== undefined) {
|
||||
return this._addRule(ruleName, this._resolveRef(ref));
|
||||
} else if (schema.oneOf || schema.anyOf) {
|
||||
return this._addRule(ruleName, this._generateUnionRule(name, schema.oneOf || schema.anyOf));
|
||||
} else if (Array.isArray(schemaType)) {
|
||||
return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({ type: t }))));
|
||||
} else if ('const' in schema) {
|
||||
return this._addRule(ruleName, this._formatLiteral(schema.const));
|
||||
} else if ('enum' in schema) {
|
||||
const rule = schema.enum.map(v => this._formatLiteral(v)).join(' | ');
|
||||
return this._addRule(ruleName, rule);
|
||||
} else if (schemaType === 'object' && 'properties' in schema) {
|
||||
// TODO: `required` keyword (from python implementation)
|
||||
const propOrder = this._propOrder;
|
||||
const propPairs = Object.entries(schema.properties).sort((a, b) => {
|
||||
// sort by position in prop_order (if specified) then by key
|
||||
const orderA = typeof propOrder[a[0]] === 'number' ? propOrder[a[0]] : Infinity;
|
||||
const orderB = typeof propOrder[b[0]] === 'number' ? propOrder[b[0]] : Infinity;
|
||||
return orderA - orderB || a[0].localeCompare(b[0]);
|
||||
});
|
||||
|
||||
let rule = '"{" space';
|
||||
propPairs.forEach(([propName, propSchema], i) => {
|
||||
const propRuleName = this.visit(propSchema, `${name}${name ? "-" : ""}${propName}`);
|
||||
if (i > 0) {
|
||||
rule += ' "," space';
|
||||
if (typeof schema.const !== 'string') {
|
||||
throw new Error('Only string constants are supported, got ' + JSON.stringify(schema.const));
|
||||
}
|
||||
return this._addRule(ruleName, this._generateConstantRule(schema.const));
|
||||
} else if ('enum' in schema) {
|
||||
const rule = schema.enum.map(v => this._generateConstantRule(v)).join(' | ');
|
||||
return this._addRule(ruleName, rule);
|
||||
} else if ((schemaType === undefined || schemaType === 'object') &&
|
||||
('properties' in schema ||
|
||||
('additionalProperties' in schema && schema.additionalProperties !== true))) {
|
||||
const required = new Set(schema.required || []);
|
||||
const properties = Object.entries(schema.properties ?? {});
|
||||
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
|
||||
} else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
|
||||
const required = new Set();
|
||||
const properties = [];
|
||||
const addComponent = (compSchema, isRequired) => {
|
||||
const ref = compSchema.$ref;
|
||||
if (ref !== undefined) {
|
||||
compSchema = this._refs[ref];
|
||||
}
|
||||
rule += ` ${this._formatLiteral(propName)} space ":" space ${propRuleName}`;
|
||||
});
|
||||
rule += ' "}" space';
|
||||
|
||||
return this._addRule(ruleName, rule);
|
||||
} else if (schemaType === 'array' && 'items' in schema) {
|
||||
// TODO `prefixItems` keyword (from python implementation)
|
||||
const itemRuleName = this.visit(schema.items, `${name}${name ? "-" : ""}item`);
|
||||
const rule = `"[" space (${itemRuleName} ("," space ${itemRuleName})*)? "]" space`;
|
||||
return this._addRule(ruleName, rule);
|
||||
if ('properties' in compSchema) {
|
||||
for (const [propName, propSchema] of Object.entries(compSchema.properties)) {
|
||||
properties.push([propName, propSchema]);
|
||||
if (isRequired) {
|
||||
required.add(propName);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (const t of schema.allOf) {
|
||||
if ('anyOf' in t) {
|
||||
for (const tt of t.anyOf) {
|
||||
addComponent(tt, false);
|
||||
}
|
||||
} else {
|
||||
if (!PRIMITIVE_RULES[schemaType]) {
|
||||
addComponent(t, true);
|
||||
}
|
||||
}
|
||||
|
||||
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, /* additionalProperties= */ false));
|
||||
} else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
|
||||
const items = schema.items ?? schema.prefixItems;
|
||||
if (Array.isArray(items)) {
|
||||
return this._addRule(
|
||||
ruleName,
|
||||
'"[" space ' +
|
||||
items.map((item, i) => this.visit(item, `${name ?? ''}${name ? '-' : ''}tuple-${i}`)).join(' "," space ') +
|
||||
' "]" space'
|
||||
);
|
||||
} else {
|
||||
const itemRuleName = this.visit(items, `${name ?? ''}${name ? '-' : ''}item`);
|
||||
const listItemOperator = `( "," space ${itemRuleName} )`;
|
||||
let successiveItems = '';
|
||||
let minItems = schema.minItems || 0;
|
||||
const maxItems = schema.maxItems;
|
||||
if (minItems > 0) {
|
||||
successiveItems = listItemOperator.repeat(minItems - 1);
|
||||
minItems--;
|
||||
}
|
||||
if (maxItems !== undefined && maxItems > minItems) {
|
||||
successiveItems += `${listItemOperator}?`.repeat(maxItems - minItems - 1);
|
||||
} else {
|
||||
successiveItems += `${listItemOperator}*`;
|
||||
}
|
||||
const rule = minItems === 0
|
||||
? `"[" space ( ${itemRuleName} ${successiveItems} )? "]" space`
|
||||
: `"[" space ${itemRuleName} ${successiveItems} "]" space`;
|
||||
return this._addRule(ruleName, rule);
|
||||
}
|
||||
} else if ((schemaType === undefined || schemaType === 'string') && 'pattern' in schema) {
|
||||
return this._visitPattern(schema.pattern, ruleName);
|
||||
} else if ((schemaType === undefined || schemaType === 'string') && /^uuid[1-5]?$/.test(schema.format || '')) {
|
||||
return this._addRule(
|
||||
ruleName === 'root' ? 'root' : schemaFormat,
|
||||
PRIMITIVE_RULES['uuid'])
|
||||
} else if ((schemaType === undefined || schemaType === 'string') && schema.format in DATE_RULES) {
|
||||
for (const [t, r] of Object.entries(DATE_RULES)) {
|
||||
this._addRule(t, r);
|
||||
}
|
||||
return schemaFormat + '-string';
|
||||
} else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) {
|
||||
for (const n of OBJECT_RULE_NAMES) {
|
||||
this._addRule(n, PRIMITIVE_RULES[n]);
|
||||
}
|
||||
return this._addRule(ruleName, 'object');
|
||||
} else {
|
||||
if (!(schemaType in PRIMITIVE_RULES)) {
|
||||
throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`);
|
||||
}
|
||||
return this._addRule(
|
||||
ruleName === 'root' ? 'root' : schemaType,
|
||||
PRIMITIVE_RULES[schemaType]
|
||||
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
||||
return this._addRule(ruleName === 'root' ? 'root' : schemaType, PRIMITIVE_RULES[schemaType]);
|
||||
}
|
||||
}
|
||||
|
||||
_buildObjectRule(properties, required, name, additionalProperties) {
|
||||
const propOrder = this._propOrder;
|
||||
// sort by position in prop_order (if specified) then by original order
|
||||
const sortedProps = properties.map(([k]) => k).sort((a, b) => {
|
||||
const orderA = propOrder[a] || Infinity;
|
||||
const orderB = propOrder[b] || Infinity;
|
||||
return orderA - orderB || properties.findIndex(([k]) => k === a) - properties.findIndex(([k]) => k === b);
|
||||
});
|
||||
|
||||
const propKvRuleNames = {};
|
||||
for (const [propName, propSchema] of properties) {
|
||||
const propRuleName = this.visit(propSchema, `${name ?? ''}${name ? '-' : ''}${propName}`);
|
||||
propKvRuleNames[propName] = this._addRule(
|
||||
`${name ?? ''}${name ? '-' : ''}${propName}-kv`,
|
||||
`${this._formatLiteral(propName)} space ":" space ${propRuleName}`
|
||||
);
|
||||
}
|
||||
const requiredProps = sortedProps.filter(k => required.has(k));
|
||||
const optionalProps = sortedProps.filter(k => !required.has(k));
|
||||
|
||||
if (typeof additionalProperties === 'object' || additionalProperties === true) {
|
||||
const subName = `${name ?? ''}${name ? '-' : ''}additional`;
|
||||
const valueRule = this.visit(additionalProperties === true ? {} : additionalProperties, `${subName}-value`);
|
||||
propKvRuleNames['*'] = this._addRule(
|
||||
`${subName}-kv`,
|
||||
`${this._addRule('string', PRIMITIVE_RULES['string'])} ":" space ${valueRule}`);
|
||||
optionalProps.push('*');
|
||||
}
|
||||
|
||||
let rule = '"{" space ';
|
||||
rule += requiredProps.map(k => propKvRuleNames[k]).join(' "," space ');
|
||||
|
||||
if (optionalProps.length > 0) {
|
||||
rule += ' (';
|
||||
if (requiredProps.length > 0) {
|
||||
rule += ' "," space ( ';
|
||||
}
|
||||
|
||||
const getRecursiveRefs = (ks, firstIsOptional) => {
|
||||
const [k, ...rest] = ks;
|
||||
const kvRuleName = propKvRuleNames[k];
|
||||
let res;
|
||||
if (k === '*') {
|
||||
res = this._addRule(
|
||||
`${name ?? ''}${name ? '-' : ''}additional-kvs`,
|
||||
`${kvRuleName} ( "," space ` + kvRuleName + ` )*`
|
||||
)
|
||||
} else if (firstIsOptional) {
|
||||
res = `( "," space ${kvRuleName} )?`;
|
||||
} else {
|
||||
res = kvRuleName;
|
||||
}
|
||||
if (rest.length > 0) {
|
||||
res += ' ' + this._addRule(
|
||||
`${name ?? ''}${name ? '-' : ''}${k}-rest`,
|
||||
getRecursiveRefs(rest, true)
|
||||
);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
rule += optionalProps.map((_, i) => getRecursiveRefs(optionalProps.slice(i), false)).join(' | ');
|
||||
if (requiredProps.length > 0) {
|
||||
rule += ' )';
|
||||
}
|
||||
rule += ' )?';
|
||||
}
|
||||
|
||||
rule += ' "}" space';
|
||||
|
||||
return rule;
|
||||
}
|
||||
|
||||
formatGrammar() {
|
||||
let grammar = '';
|
||||
this._rules.forEach((rule, name) => {
|
||||
for (const [name, rule] of Object.entries(this._rules).sort(([a], [b]) => a.localeCompare(b))) {
|
||||
grammar += `${name} ::= ${rule}\n`;
|
||||
});
|
||||
}
|
||||
return grammar;
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to group elements by a key function
|
||||
function* groupBy(iterable, keyFn) {
|
||||
let lastKey = null;
|
||||
let group = [];
|
||||
for (const element of iterable) {
|
||||
const key = keyFn(element);
|
||||
if (lastKey !== null && key !== lastKey) {
|
||||
yield [lastKey, group];
|
||||
group = [];
|
||||
}
|
||||
group.push(element);
|
||||
lastKey = key;
|
||||
}
|
||||
if (group.length > 0) {
|
||||
yield [lastKey, group];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "utils.hpp"
|
||||
|
||||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
|
@ -178,6 +179,7 @@ struct server_slot {
|
|||
llama_token sampled;
|
||||
struct llama_sampling_params sparams;
|
||||
llama_sampling_context * ctx_sampling = nullptr;
|
||||
json json_schema;
|
||||
|
||||
int32_t ga_i = 0; // group-attention state
|
||||
int32_t ga_n = 1; // group-attention factor
|
||||
|
@ -845,7 +847,17 @@ struct server_context {
|
|||
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
||||
slot.params.seed = json_value(data, "seed", default_params.seed);
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
slot.sparams.grammar = json_schema_to_grammar(schema);
|
||||
} catch (const std::exception & e) {
|
||||
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||
}
|
||||
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
|
||||
|
|
|
@ -37,6 +37,22 @@ Feature: Security
|
|||
| llama.cpp | no |
|
||||
| hackme | raised |
|
||||
|
||||
Scenario Outline: OAI Compatibility (invalid response formats)
|
||||
Given a system prompt test
|
||||
And a user prompt test
|
||||
And a response format <response_format>
|
||||
And a model test
|
||||
And 2 max tokens to predict
|
||||
And streaming is disabled
|
||||
Given an OAI compatible chat completions request with raised api error
|
||||
|
||||
Examples: Prompts
|
||||
| response_format |
|
||||
| {"type": "sound"} |
|
||||
| {"type": "json_object", "schema": 123} |
|
||||
| {"type": "json_object", "schema": {"type": 123}} |
|
||||
| {"type": "json_object", "schema": {"type": "hiccup"}} |
|
||||
|
||||
|
||||
Scenario Outline: CORS Options
|
||||
Given a user api key llama.cpp
|
||||
|
|
|
@ -70,6 +70,22 @@ Feature: llama.cpp server
|
|||
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird\|Annabyear)+ | -1 | 64 | enabled | |
|
||||
|
||||
|
||||
Scenario Outline: OAI Compatibility w/ response format
|
||||
Given a model test
|
||||
And a system prompt test
|
||||
And a user prompt test
|
||||
And a response format <response_format>
|
||||
And 10 max tokens to predict
|
||||
Given an OAI compatible chat completions request with no api error
|
||||
Then <n_predicted> tokens are predicted matching <re_content>
|
||||
|
||||
Examples: Prompts
|
||||
| response_format | n_predicted | re_content |
|
||||
| {"type": "json_object", "schema": {"const": "42"}} | 5 | "42" |
|
||||
| {"type": "json_object", "schema": {"items": [{"type": "integer"}]}} | 10 | \[ -300 \] |
|
||||
| {"type": "json_object"} | 10 | \{ " Jacky. |
|
||||
|
||||
|
||||
Scenario: Tokenize / Detokenize
|
||||
When tokenizing:
|
||||
"""
|
||||
|
|
|
@ -24,12 +24,16 @@ from prometheus_client import parser
|
|||
def step_server_config(context, server_fqdn, server_port):
|
||||
context.server_fqdn = server_fqdn
|
||||
context.server_port = int(server_port)
|
||||
context.n_gpu_layer = None
|
||||
if 'PORT' in os.environ:
|
||||
context.server_port = int(os.environ['PORT'])
|
||||
print(f"$PORT set, overriding server port with to {context.server_port}")
|
||||
if 'FQDN' in os.environ:
|
||||
context.server_fqdn = os.environ['FQDN']
|
||||
print(f"$FQDN set, overriding server fqdn with to {context.server_fqdn}")
|
||||
if 'N_GPU_LAYERS' in os.environ:
|
||||
context.n_gpu_layer = int(os.environ['N_GPU_LAYERS'])
|
||||
print(f"$N_GPU_LAYERS set, overriding n_gpu_layer with to {context.n_gpu_layer}")
|
||||
|
||||
context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
|
||||
|
||||
|
@ -41,7 +45,6 @@ def step_server_config(context, server_fqdn, server_port):
|
|||
context.n_ctx = None
|
||||
context.n_ga = None
|
||||
context.n_ga_w = None
|
||||
context.n_gpu_layer = None
|
||||
context.n_predict = None
|
||||
context.n_prompts = 0
|
||||
context.n_server_predict = None
|
||||
|
@ -56,6 +59,7 @@ def step_server_config(context, server_fqdn, server_port):
|
|||
context.seed = None
|
||||
context.server_seed = None
|
||||
context.user_api_key = None
|
||||
context.response_format = None
|
||||
|
||||
context.tasks_result = []
|
||||
context.concurrent_tasks = []
|
||||
|
@ -266,6 +270,11 @@ def step_max_tokens(context, max_tokens):
|
|||
context.n_predict = max_tokens
|
||||
|
||||
|
||||
@step('a response format {response_format}')
|
||||
def step_response_format(context, response_format):
|
||||
context.response_format = json.loads(response_format)
|
||||
|
||||
|
||||
@step('streaming is {enable_streaming}')
|
||||
def step_streaming(context, enable_streaming):
|
||||
context.enable_streaming = enable_streaming == 'enabled'
|
||||
|
@ -381,6 +390,9 @@ async def step_oai_chat_completions(context, api_error):
|
|||
enable_streaming=context.enable_streaming
|
||||
if hasattr(context, 'enable_streaming') else None,
|
||||
|
||||
response_format=context.response_format
|
||||
if hasattr(context, 'response_format') else None,
|
||||
|
||||
seed=await completions_seed(context),
|
||||
|
||||
user_api_key=context.user_api_key
|
||||
|
@ -440,6 +452,8 @@ async def step_oai_chat_completions(context):
|
|||
if hasattr(context, 'n_predict') else None,
|
||||
enable_streaming=context.enable_streaming
|
||||
if hasattr(context, 'enable_streaming') else None,
|
||||
response_format=context.response_format
|
||||
if hasattr(context, 'response_format') else None,
|
||||
seed=await completions_seed(context),
|
||||
user_api_key=context.user_api_key
|
||||
if hasattr(context, 'user_api_key') else None)
|
||||
|
@ -460,6 +474,8 @@ async def step_oai_chat_completions(context):
|
|||
if hasattr(context, 'n_predict') else None,
|
||||
enable_streaming=context.enable_streaming
|
||||
if hasattr(context, 'enable_streaming') else None,
|
||||
response_format=context.response_format
|
||||
if hasattr(context, 'response_format') else None,
|
||||
seed=context.seed
|
||||
if hasattr(context, 'seed') else
|
||||
context.server_seed
|
||||
|
@ -742,6 +758,7 @@ async def oai_chat_completions(user_prompt,
|
|||
model=None,
|
||||
n_predict=None,
|
||||
enable_streaming=None,
|
||||
response_format=None,
|
||||
seed=None,
|
||||
user_api_key=None,
|
||||
expect_api_error=None):
|
||||
|
@ -767,6 +784,8 @@ async def oai_chat_completions(user_prompt,
|
|||
"stream": enable_streaming,
|
||||
"seed": seed
|
||||
}
|
||||
if response_format is not None:
|
||||
payload['response_format'] = response_format
|
||||
completion_response = {
|
||||
'content': '',
|
||||
'timings': {
|
||||
|
@ -827,6 +846,7 @@ async def oai_chat_completions(user_prompt,
|
|||
model=model,
|
||||
max_tokens=n_predict,
|
||||
stream=enable_streaming,
|
||||
response_format=payload.get('response_format'),
|
||||
seed=seed
|
||||
)
|
||||
except openai.error.AuthenticationError as e:
|
||||
|
|
|
@ -373,10 +373,21 @@ static json oaicompat_completion_params_parse(
|
|||
llama_params["tfs_z"] = json_value(body, "tfs_z", default_sparams.tfs_z);
|
||||
llama_params["n_keep"] = json_value(body, "n_keep", 0);
|
||||
|
||||
if (body.count("grammar") != 0) {
|
||||
if (body.contains("grammar")) {
|
||||
llama_params["grammar"] = json_value(body, "grammar", json::object());
|
||||
}
|
||||
|
||||
if (body.contains("response_format")) {
|
||||
auto response_format = json_value(body, "response_format", json::object());
|
||||
if (response_format.contains("type")) {
|
||||
if (response_format["type"] == "json_object") {
|
||||
llama_params["json_schema"] = json_value(response_format, "schema", json::object());
|
||||
} else {
|
||||
throw std::runtime_error("response_format type not supported: " + response_format["type"].dump());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 'stop' field
|
||||
if (body.contains("stop") && body["stop"].is_string()) {
|
||||
llama_params["stop"] = json::array({body["stop"].get<std::string>()});
|
||||
|
|
28
examples/ts-type-to-grammar.sh
Executable file
28
examples/ts-type-to-grammar.sh
Executable file
|
@ -0,0 +1,28 @@
|
|||
#!/bin/bash
|
||||
#
|
||||
# ./examples/ts-type-to-grammar.sh "{a:string,b:string,c?:string}"
|
||||
# python examples/json-schema-to-grammar.py https://json.schemastore.org/tsconfig.json
|
||||
#
|
||||
set -euo pipefail
|
||||
|
||||
readonly type="$1"
|
||||
|
||||
# Create a temporary directory
|
||||
TMPDIR=""
|
||||
trap 'rm -fR "$TMPDIR"' EXIT
|
||||
TMPDIR=$(mktemp -d)
|
||||
|
||||
DTS_FILE="$TMPDIR/type.d.ts"
|
||||
SCHEMA_FILE="$TMPDIR/schema.json"
|
||||
|
||||
echo "export type MyType = $type" > "$DTS_FILE"
|
||||
|
||||
# This is a fork of typescript-json-schema, actively maintained as of March 2024:
|
||||
# https://github.com/vega/ts-json-schema-generator
|
||||
npx ts-json-schema-generator --unstable --no-top-ref --path "$DTS_FILE" --type MyType -e none > "$SCHEMA_FILE"
|
||||
|
||||
# Alternative, not actively maintained as of March 2024:
|
||||
# https://github.com/YousefED/typescript-json-schema
|
||||
# npx typescript-json-schema --defaultProps --required "$DTS_FILE" MyType | tee "$SCHEMA_FILE" >&2
|
||||
|
||||
./examples/json-schema-to-grammar.py "$SCHEMA_FILE"
|
3638
ggml-cuda.cu
3638
ggml-cuda.cu
File diff suppressed because it is too large
Load diff
18
ggml-metal.m
18
ggml-metal.m
|
@ -173,8 +173,9 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
||||
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
||||
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CONCAT,
|
||||
|
@ -598,8 +599,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
||||
|
@ -739,6 +741,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
@ -2436,8 +2441,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
|
||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
|
||||
//case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
||||
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
|
||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
};
|
||||
} break;
|
||||
|
|
240
ggml-metal.metal
240
ggml-metal.metal
|
@ -2388,6 +2388,242 @@ kernel void kernel_cpy_f32_q4_1(
|
|||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_q5_0(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
|
||||
|
||||
device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK5_0; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -16;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst_data[i00/QK5_0].d = d;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_0/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK5_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
||||
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
||||
|
||||
dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
||||
}
|
||||
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
dst_data[i00/QK5_0].qh[j] = qh8[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_q5_1(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
|
||||
|
||||
device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
|
||||
float max = src[0];
|
||||
float min = src[0];
|
||||
|
||||
for (int j = 1; j < QK5_1; j++) {
|
||||
const float v = src[j];
|
||||
min = v < min ? v : min;
|
||||
max = v > max ? v : max;
|
||||
}
|
||||
|
||||
const float d = (max - min) / 31;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst_data[i00/QK5_1].d = d;
|
||||
dst_data[i00/QK5_1].m = min;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_1/2; ++j) {
|
||||
const float x0 = (src[0 + j] - min)*id;
|
||||
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
||||
|
||||
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
||||
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
||||
|
||||
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
||||
}
|
||||
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
dst_data[i00/QK5_1].qh[j] = qh8[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline int best_index_int8(int n, constant float * val, float x) {
|
||||
if (x <= val[0]) return 0;
|
||||
if (x >= val[n-1]) return n-1;
|
||||
int ml = 0, mu = n-1;
|
||||
while (mu-ml > 1) {
|
||||
int mav = (ml+mu)/2;
|
||||
if (x < val[mav]) mu = mav; else ml = mav;
|
||||
}
|
||||
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
||||
}
|
||||
|
||||
constexpr constant static float kvalues_iq4nl_f[16] = {
|
||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
||||
};
|
||||
|
||||
kernel void kernel_cpy_f32_iq4_nl(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
|
||||
|
||||
device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_0; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / kvalues_iq4nl_f[0];
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int j = 0; j < QK4_NL/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK4_NL/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
||||
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
||||
|
||||
dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
|
||||
|
||||
const float v0 = kvalues_iq4nl_f[xi0];
|
||||
const float v1 = kvalues_iq4nl_f[xi1];
|
||||
const float w0 = src[0 + j]*src[0 + j];
|
||||
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
||||
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
||||
sumq2 += w0*v0*v0 + w1*v1*v1;
|
||||
|
||||
}
|
||||
|
||||
dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_concat(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
|
@ -4220,10 +4456,6 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
}
|
||||
}
|
||||
|
||||
constexpr constant static float kvalues_iq4nl_f[16] = {
|
||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
||||
};
|
||||
|
||||
void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
|
|
|
@ -11705,9 +11705,8 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
|
|||
ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
|
||||
float * scales, float * weight, uint8_t * L,
|
||||
const int8_t * values,
|
||||
const float * quant_weights) {
|
||||
|
||||
const int ntry = 7;
|
||||
const float * quant_weights,
|
||||
const int ntry) {
|
||||
|
||||
float sigma2 = 0;
|
||||
for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
|
||||
|
@ -11719,6 +11718,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
|
|||
float max_scale = 0, amax_scale = 0;
|
||||
for (int ib = 0; ib < super_block_size/block_size; ++ib) {
|
||||
const float * xb = x + ib*block_size;
|
||||
uint8_t * Lb = L + ib*block_size;
|
||||
if (quant_weights) {
|
||||
const float * qw = quant_weights + ib*block_size;
|
||||
for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
|
||||
|
@ -11736,12 +11736,13 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
|
|||
scales[ib] = 0;
|
||||
continue;
|
||||
}
|
||||
float d = -max/values[0];
|
||||
float d = ntry > 0 ? -max/values[0] : max/values[0];
|
||||
float id = 1/d;
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int j = 0; j < block_size; ++j) {
|
||||
float al = id*xb[j];
|
||||
int l = best_index_int8(16, values, al);
|
||||
Lb[j] = l;
|
||||
float q = values[l];
|
||||
float w = weight[j];
|
||||
sumqx += w*q*xb[j];
|
||||
|
@ -11796,11 +11797,13 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
|
|||
}
|
||||
} else {
|
||||
dh[0] = GGML_FP32_TO_FP16(scales[0]);
|
||||
if (ntry > 0) {
|
||||
float id = scales[0] ? 1/scales[0] : 0;
|
||||
for (int j = 0; j < super_block_size; ++j) {
|
||||
L[j] = best_index_int8(16, values, id*x[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < super_block_size/32; ++i) {
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
|
@ -11823,7 +11826,7 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
|
|||
for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||
const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
|
||||
quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
|
||||
&scale, weight, L, kvalues_iq4nl, qw);
|
||||
&scale, weight, L, kvalues_iq4nl, qw, 7);
|
||||
}
|
||||
src += n_per_row;
|
||||
qrow += nblock*sizeof(block_iq4_nl);
|
||||
|
@ -11832,14 +11835,23 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
|
|||
}
|
||||
|
||||
void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
|
||||
assert(k % QK4_NL == 0);
|
||||
block_iq4_nl * restrict y = vy;
|
||||
quantize_row_iq4_nl_reference(x, y, k);
|
||||
GGML_ASSERT(k%QK4_NL == 0);
|
||||
int nblock = k/QK4_NL;
|
||||
uint8_t L[QK4_NL];
|
||||
float weight[QK4_NL];
|
||||
uint16_t unused_h;
|
||||
uint8_t * unused_l = NULL;
|
||||
float scale;
|
||||
block_iq4_nl * iq4 = (block_iq4_nl *)vy;
|
||||
for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||
quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
|
||||
&scale, weight, L, kvalues_iq4nl, NULL, -1);
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
|
||||
assert(k % QK4_NL == 0);
|
||||
quantize_iq4_nl(x, y, 1, k, NULL);
|
||||
quantize_row_iq4_nl(x, y, k);
|
||||
}
|
||||
|
||||
size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow, int n_per_row, const float * quant_weights) {
|
||||
|
@ -11857,7 +11869,7 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow
|
|||
for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||
const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
|
||||
quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
|
||||
scales, weight, L, kvalues_iq4nl, qw);
|
||||
scales, weight, L, kvalues_iq4nl, qw, 7);
|
||||
}
|
||||
src += n_per_row;
|
||||
qrow += nblock*sizeof(block_iq4_xs);
|
||||
|
|
|
@ -977,8 +977,10 @@ namespace dpct
|
|||
static int convert_backend_index(std::string & backend) {
|
||||
if (backend == "ext_oneapi_level_zero:gpu") return 0;
|
||||
if (backend == "opencl:gpu") return 1;
|
||||
if (backend == "opencl:cpu") return 2;
|
||||
if (backend == "opencl:acc") return 3;
|
||||
if (backend == "ext_oneapi_cuda:gpu") return 2;
|
||||
if (backend == "ext_oneapi_hip:gpu") return 3;
|
||||
if (backend == "opencl:cpu") return 4;
|
||||
if (backend == "opencl:acc") return 5;
|
||||
printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
|
|
@ -1,66 +1,75 @@
|
|||
function(llama_build_executable source)
|
||||
# Builds and runs a test source file.
|
||||
# Optional args:
|
||||
# - NAME: name of the executable & test target (defaults to the source file name without extension)
|
||||
# - LABEL: label for the test (defaults to main)
|
||||
# - ARGS: arguments to pass to the test executable
|
||||
# - WORKING_DIRECTORY
|
||||
function(llama_test source)
|
||||
include(CMakeParseArguments)
|
||||
set(options)
|
||||
set(oneValueArgs NAME LABEL WORKING_DIRECTORY)
|
||||
set(multiValueArgs ARGS)
|
||||
cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
if (NOT DEFINED LLAMA_TEST_LABEL)
|
||||
set(LLAMA_TEST_LABEL "main")
|
||||
endif()
|
||||
if (NOT DEFINED LLAMA_TEST_WORKING_DIRECTORY)
|
||||
set(LLAMA_TEST_WORKING_DIRECTORY .)
|
||||
endif()
|
||||
if (DEFINED LLAMA_TEST_NAME)
|
||||
set(TEST_TARGET ${LLAMA_TEST_NAME})
|
||||
else()
|
||||
get_filename_component(TEST_TARGET ${source} NAME_WE)
|
||||
endif()
|
||||
|
||||
add_executable(${TEST_TARGET} ${source} get-model.cpp)
|
||||
install(TARGETS ${TEST_TARGET} RUNTIME)
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE common)
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE common json-schema-to-grammar)
|
||||
add_test(
|
||||
NAME ${TEST_TARGET}
|
||||
WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}
|
||||
COMMAND $<TARGET_FILE:${TEST_TARGET}>
|
||||
${LLAMA_TEST_ARGS})
|
||||
|
||||
set_property(TEST ${TEST_TARGET} PROPERTY LABELS ${LLAMA_TEST_LABEL})
|
||||
endfunction()
|
||||
|
||||
function(llama_test_executable name source)
|
||||
get_filename_component(TEST_TARGET ${source} NAME_WE)
|
||||
add_test(NAME ${name} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
|
||||
set_property(TEST ${name} PROPERTY LABELS "main")
|
||||
endfunction()
|
||||
# llama_test(test-double-float.cpp) # SLOW
|
||||
llama_test(test-quantize-fns.cpp)
|
||||
llama_test(test-quantize-perf.cpp)
|
||||
llama_test(test-sampling.cpp)
|
||||
llama_test(test-chat-template.cpp)
|
||||
|
||||
function(llama_build_and_test_executable source)
|
||||
llama_build_and_test_executable_with_label(${source} "main")
|
||||
endfunction()
|
||||
llama_test(test-tokenizer-0-llama.cpp NAME test-tokenizer-0-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
||||
llama_test(test-tokenizer-0-falcon.cpp NAME test-tokenizer-0-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
|
||||
|
||||
function(llama_build_and_test_executable_with_label source label)
|
||||
get_filename_component(TEST_TARGET ${source} NAME_WE)
|
||||
add_executable(${TEST_TARGET} ${source} get-model.cpp)
|
||||
install(TARGETS ${TEST_TARGET} RUNTIME)
|
||||
target_link_libraries(${TEST_TARGET} PRIVATE common)
|
||||
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
|
||||
set_property(TEST ${TEST_TARGET} PROPERTY LABELS ${label})
|
||||
endfunction()
|
||||
llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-llama ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
||||
llama_test(test-tokenizer-1-llama.cpp NAME test-tokenizer-1-baichuan ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
|
||||
|
||||
# llama_build_and_test_executable(test-double-float.cpp) # SLOW
|
||||
llama_build_and_test_executable(test-quantize-fns.cpp)
|
||||
llama_build_and_test_executable(test-quantize-perf.cpp)
|
||||
llama_build_and_test_executable(test-sampling.cpp)
|
||||
llama_build_and_test_executable(test-chat-template.cpp)
|
||||
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
|
||||
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-aquila ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
|
||||
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
|
||||
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-stablelm-3b-4e1t ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-stablelm-3b-4e1t.gguf)
|
||||
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt-neox ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
|
||||
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
|
||||
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
|
||||
llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-gpt2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt2.gguf)
|
||||
#llama_test(test-tokenizer-1-bpe.cpp NAME test-tokenizer-1-bloom ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bloom.gguf) # BIG
|
||||
|
||||
llama_build_executable(test-tokenizer-0-llama.cpp)
|
||||
llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
||||
llama_test(test-grammar-parser.cpp)
|
||||
llama_test(test-llama-grammar.cpp)
|
||||
llama_test(test-grad0.cpp)
|
||||
# llama_test(test-opt.cpp) # SLOW
|
||||
llama_test(test-backend-ops.cpp)
|
||||
|
||||
llama_build_executable(test-tokenizer-0-falcon.cpp)
|
||||
llama_test_executable (test-tokenizer-0-falcon test-tokenizer-0-falcon.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
|
||||
llama_test(test-rope.cpp)
|
||||
|
||||
llama_build_executable(test-tokenizer-1-llama.cpp)
|
||||
llama_test_executable (test-tokenizer-1-llama test-tokenizer-1-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
||||
llama_test_executable (test-tokenizer-1-baichuan test-tokenizer-1-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
|
||||
llama_test(test-model-load-cancel.cpp LABEL "model")
|
||||
llama_test(test-autorelease.cpp LABEL "model")
|
||||
|
||||
llama_build_executable(test-tokenizer-1-bpe.cpp)
|
||||
llama_test_executable (test-tokenizer-1-falcon test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
|
||||
llama_test_executable (test-tokenizer-1-aquila test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
|
||||
llama_test_executable (test-tokenizer-1-mpt test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
|
||||
llama_test_executable (test-tokenizer-1-stablelm-3b-4e1t test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-stablelm-3b-4e1t.gguf)
|
||||
llama_test_executable (test-tokenizer-1-gpt-neox test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
|
||||
llama_test_executable (test-tokenizer-1-refact test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
|
||||
llama_test_executable (test-tokenizer-1-starcoder test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
|
||||
llama_test_executable (test-tokenizer-1-gpt2 test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt2.gguf)
|
||||
# llama_test_executable (test-tokenizer-1-bloom test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bloom.gguf) # BIG
|
||||
|
||||
llama_build_and_test_executable(test-grammar-parser.cpp)
|
||||
llama_build_and_test_executable(test-llama-grammar.cpp)
|
||||
llama_build_and_test_executable(test-grad0.cpp)
|
||||
# llama_build_and_test_executable(test-opt.cpp) # SLOW
|
||||
llama_build_and_test_executable(test-backend-ops.cpp)
|
||||
|
||||
llama_build_and_test_executable(test-rope.cpp)
|
||||
|
||||
llama_build_and_test_executable_with_label(test-model-load-cancel.cpp "model")
|
||||
llama_build_and_test_executable_with_label(test-autorelease.cpp "model")
|
||||
llama_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||
target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server)
|
||||
|
||||
# dummy executable - not installed
|
||||
get_filename_component(TEST_TARGET test-c.c NAME_WE)
|
||||
|
|
10
tests/run-json-schema-to-grammar.mjs
Normal file
10
tests/run-json-schema-to-grammar.mjs
Normal file
|
@ -0,0 +1,10 @@
|
|||
import { readFileSync } from "fs"
|
||||
import { SchemaConverter } from "../examples/server/public/json-schema-to-grammar.mjs"
|
||||
|
||||
const [, , file] = process.argv
|
||||
const url = `file://${file}`
|
||||
let schema = JSON.parse(readFileSync(file, "utf8"));
|
||||
const converter = new SchemaConverter({})
|
||||
schema = await converter.resolveRefs(schema, url)
|
||||
converter.visit(schema, '')
|
||||
console.log(converter.formatGrammar())
|
824
tests/test-json-schema-to-grammar.cpp
Executable file
824
tests/test-json-schema-to-grammar.cpp
Executable file
|
@ -0,0 +1,824 @@
|
|||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <regex>
|
||||
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
static std::string trim(const std::string & source) {
|
||||
std::string s(source);
|
||||
s.erase(0,s.find_first_not_of(" \n\r\t"));
|
||||
s.erase(s.find_last_not_of(" \n\r\t")+1);
|
||||
return std::regex_replace(s, std::regex("(^|\n)[ \t]+"), "$1");
|
||||
}
|
||||
|
||||
enum TestCaseStatus {
|
||||
SUCCESS,
|
||||
FAILURE
|
||||
};
|
||||
|
||||
struct TestCase {
|
||||
TestCaseStatus expected_status;
|
||||
std::string name;
|
||||
std::string schema;
|
||||
std::string expected_grammar;
|
||||
|
||||
void _print_failure_header() const {
|
||||
fprintf(stderr, "#\n# Test '%s' failed.\n#\n%s\n", name.c_str(), schema.c_str());
|
||||
}
|
||||
void verify(const std::string & actual_grammar) const {
|
||||
if (trim(actual_grammar) != trim(expected_grammar)) {
|
||||
_print_failure_header();
|
||||
fprintf(stderr, "# EXPECTED:\n%s\n# ACTUAL:\n%s\n", expected_grammar.c_str(), actual_grammar.c_str());
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
void verify_expectation_parseable() const {
|
||||
try {
|
||||
auto state = grammar_parser::parse(expected_grammar.c_str());
|
||||
if (state.symbol_ids.find("root") == state.symbol_ids.end()) {
|
||||
throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar);
|
||||
}
|
||||
} catch (const std::runtime_error & ex) {
|
||||
_print_failure_header();
|
||||
fprintf(stderr, "# GRAMMAR ERROR: %s\n", ex.what());
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
void verify_status(TestCaseStatus status) const {
|
||||
if (status != expected_status) {
|
||||
_print_failure_header();
|
||||
fprintf(stderr, "# EXPECTED STATUS: %s\n", expected_status == SUCCESS ? "SUCCESS" : "FAILURE");
|
||||
fprintf(stderr, "# ACTUAL STATUS: %s\n", status == SUCCESS ? "SUCCESS" : "FAILURE");
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static void write(const std::string & file, const std::string & content) {
|
||||
std::ofstream f;
|
||||
f.open(file.c_str());
|
||||
f << content.c_str();
|
||||
f.close();
|
||||
}
|
||||
|
||||
static std::string read(const std::string & file) {
|
||||
std::ostringstream actuals;
|
||||
actuals << std::ifstream(file.c_str()).rdbuf();
|
||||
return actuals.str();
|
||||
}
|
||||
|
||||
static void test_all(const std::string & lang, std::function<void(const TestCase &)> runner) {
|
||||
fprintf(stderr, "#\n# Testing JSON schema conversion (%s)\n#\n", lang.c_str());
|
||||
auto test = [&](const TestCase & tc) {
|
||||
fprintf(stderr, "- %s%s\n", tc.name.c_str(), tc.expected_status == FAILURE ? " (failure expected)" : "");
|
||||
runner(tc);
|
||||
};
|
||||
|
||||
test({
|
||||
FAILURE,
|
||||
"unknown type",
|
||||
R"""({
|
||||
"type": "kaboom"
|
||||
})""",
|
||||
""
|
||||
});
|
||||
|
||||
test({
|
||||
FAILURE,
|
||||
"invalid type type",
|
||||
R"""({
|
||||
"type": 123
|
||||
})""",
|
||||
""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"empty schema (object)",
|
||||
"{}",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
null ::= "null" space
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= object
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
value ::= object | array | string | number | boolean
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"exotic formats",
|
||||
R"""({
|
||||
"items": [
|
||||
{ "format": "date" },
|
||||
{ "format": "uuid" },
|
||||
{ "format": "time" },
|
||||
{ "format": "date-time" }
|
||||
]
|
||||
})""",
|
||||
R"""(
|
||||
date ::= [0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )
|
||||
date-string ::= "\"" date "\"" space
|
||||
date-time ::= date "T" time
|
||||
date-time-string ::= "\"" date-time "\"" space
|
||||
root ::= "[" space date-string "," space uuid "," space time-string "," space date-time-string "]" space
|
||||
space ::= " "?
|
||||
time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
|
||||
time-string ::= "\"" time "\"" space
|
||||
uuid ::= "\"" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "-" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"string",
|
||||
R"""({
|
||||
"type": "string"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"boolean",
|
||||
R"""({
|
||||
"type": "boolean"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("true" | "false") space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"integer",
|
||||
R"""({
|
||||
"type": "integer"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-"? ([0-9] | [1-9] [0-9]*)) space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"string const",
|
||||
R"""({
|
||||
"const": "foo"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"foo\""
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
FAILURE,
|
||||
"non-string const",
|
||||
R"""({
|
||||
"const": 123
|
||||
})""",
|
||||
""
|
||||
});
|
||||
|
||||
test({
|
||||
FAILURE,
|
||||
"non-string enum",
|
||||
R"""({
|
||||
"enum": [123]
|
||||
})""",
|
||||
""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"tuple1",
|
||||
R"""({
|
||||
"prefixItems": [{ "type": "string" }]
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "[" space string "]" space
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"tuple2",
|
||||
R"""({
|
||||
"prefixItems": [{ "type": "string" }, { "type": "number" }]
|
||||
})""",
|
||||
R"""(
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
root ::= "[" space string "," space number "]" space
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"number",
|
||||
R"""({
|
||||
"type": "number"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"minItems",
|
||||
R"""({
|
||||
"items": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"minItems": 2
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space boolean ( "," space boolean )( "," space boolean )* "]" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"maxItems 1",
|
||||
R"""({
|
||||
"items": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"maxItems": 1
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space ( boolean )? "]" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"maxItems 2",
|
||||
R"""({
|
||||
"items": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"maxItems": 2
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space ( boolean ( "," space boolean )? )? "]" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"min + maxItems",
|
||||
R"""({
|
||||
"items": {
|
||||
"type": ["number", "integer"]
|
||||
},
|
||||
"minItems": 3,
|
||||
"maxItems": 5
|
||||
})""",
|
||||
R"""(
|
||||
integer ::= ("-"? ([0-9] | [1-9] [0-9]*)) space
|
||||
item ::= number | integer
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
root ::= "[" space item ( "," space item )( "," space item )( "," space item )?( "," space item )? "]" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"simple regexp",
|
||||
R"""({
|
||||
"type": "string",
|
||||
"pattern": "^abc?d*efg+(hij)?kl$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" "ab" "c"? "d"* "ef" "g"+ ("hij")? "kl" "\"" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"regexp escapes",
|
||||
R"""({
|
||||
"type": "string",
|
||||
"pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" "[]{}()|+*?" "\"" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"regexp quote",
|
||||
R"""({
|
||||
"type": "string",
|
||||
"pattern": "^\"$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" "\"" "\"" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"regexp",
|
||||
R"""({
|
||||
"type": "string",
|
||||
"pattern": "^(\\([0-9]{1,3}\\))?[0-9]{3}-[0-9]{4} and...$"
|
||||
})""",
|
||||
R"""(
|
||||
dot ::= [\U00000000-\x09\x0B\x0C\x0E-\U0010FFFF]
|
||||
root ::= "\"" ("(" root-1 root-1? root-1? ")")? root-1 root-1 root-1 "-" root-1 root-1 root-1 root-1 " and" dot dot dot "\"" space
|
||||
root-1 ::= [0-9]
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"required props",
|
||||
R"""({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "string"
|
||||
},
|
||||
"b": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"a",
|
||||
"b"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"definitions": {}
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space string
|
||||
b-kv ::= "\"b\"" space ":" space string
|
||||
root ::= "{" space a-kv "," space b-kv "}" space
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"1 optional prop",
|
||||
R"""({
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space string
|
||||
root ::= "{" space (a-kv )? "}" space
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"N optional props",
|
||||
R"""({
|
||||
"properties": {
|
||||
"a": {"type": "string"},
|
||||
"b": {"type": "string"},
|
||||
"c": {"type": "string"}
|
||||
},
|
||||
"additionalProperties": false
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space string
|
||||
a-rest ::= ( "," space b-kv )? b-rest
|
||||
b-kv ::= "\"b\"" space ":" space string
|
||||
b-rest ::= ( "," space c-kv )?
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"required + optional props",
|
||||
R"""({
|
||||
"properties": {
|
||||
"a": {"type": "string"},
|
||||
"b": {"type": "string"},
|
||||
"c": {"type": "string"},
|
||||
"d": {"type": "string"}
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"additionalProperties": false
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space string
|
||||
b-kv ::= "\"b\"" space ":" space string
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
c-rest ::= ( "," space d-kv )?
|
||||
d-kv ::= "\"d\"" space ":" space string
|
||||
root ::= "{" space a-kv "," space b-kv ( "," space ( c-kv c-rest | d-kv ) )? "}" space
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"additional props",
|
||||
R"""({
|
||||
"type": "object",
|
||||
"additionalProperties": {"type": "array", "items": {"type": "number"}}
|
||||
})""",
|
||||
R"""(
|
||||
additional-kv ::= string ":" space additional-value
|
||||
additional-kvs ::= additional-kv ( "," space additional-kv )*
|
||||
additional-value ::= "[" space ( number ( "," space number )* )? "]" space
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
root ::= "{" space (additional-kvs )? "}" space
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"additional props (true)",
|
||||
R"""({
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
null ::= "null" space
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= object
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
value ::= object | array | string | number | boolean
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"additional props (implicit)",
|
||||
R"""({
|
||||
"type": "object"
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
null ::= "null" space
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= object
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
value ::= object | array | string | number | boolean
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"empty w/o additional props",
|
||||
R"""({
|
||||
"type": "object",
|
||||
"additionalProperties": false
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "{" space "}" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"required + additional props",
|
||||
R"""({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"}
|
||||
},
|
||||
"required": ["a"],
|
||||
"additionalProperties": {"type": "string"}
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space number
|
||||
additional-kv ::= string ":" space string
|
||||
additional-kvs ::= additional-kv ( "," space additional-kv )*
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
root ::= "{" space a-kv ( "," space ( additional-kvs ) )? "}" space
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"optional + additional props",
|
||||
R"""({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"}
|
||||
},
|
||||
"additionalProperties": {"type": "number"}
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space number
|
||||
a-rest ::= additional-kvs
|
||||
additional-kv ::= string ":" space number
|
||||
additional-kvs ::= additional-kv ( "," space additional-kv )*
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
root ::= "{" space (a-kv a-rest | additional-kvs )? "}" space
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"required + optional + additional props",
|
||||
R"""({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number"},
|
||||
"b": {"type": "number"}
|
||||
},
|
||||
"required": ["a"],
|
||||
"additionalProperties": {"type": "number"}
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space number
|
||||
additional-kv ::= string ":" space number
|
||||
additional-kvs ::= additional-kv ( "," space additional-kv )*
|
||||
b-kv ::= "\"b\"" space ":" space number
|
||||
b-rest ::= additional-kvs
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
root ::= "{" space a-kv ( "," space ( b-kv b-rest | additional-kvs ) )? "}" space
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"top-level $ref",
|
||||
R"""({
|
||||
"$ref": "#/definitions/MyType",
|
||||
"definitions": {
|
||||
"MyType": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"a"
|
||||
],
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
||||
})""",
|
||||
R"""(
|
||||
MyType ::= "{" space MyType-a-kv "}" space
|
||||
MyType-a-kv ::= "\"a\"" space ":" space string
|
||||
root ::= MyType
|
||||
space ::= " "?
|
||||
string ::= "\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||
)* "\"" space
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"anyOf",
|
||||
R"""({
|
||||
"anyOf": [
|
||||
{"$ref": "#/definitions/foo"},
|
||||
{"$ref": "#/definitions/bar"}
|
||||
],
|
||||
"definitions": {
|
||||
"foo": {
|
||||
"properties": {"a": {"type": "number"}}
|
||||
},
|
||||
"bar": {
|
||||
"properties": {"b": {"type": "number"}}
|
||||
}
|
||||
},
|
||||
"type": "object"
|
||||
})""",
|
||||
R"""(
|
||||
alternative-0 ::= foo
|
||||
alternative-1 ::= bar
|
||||
bar ::= "{" space (bar-b-kv )? "}" space
|
||||
bar-b-kv ::= "\"b\"" space ":" space number
|
||||
foo ::= "{" space (foo-a-kv )? "}" space
|
||||
foo-a-kv ::= "\"a\"" space ":" space number
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
root ::= alternative-0 | alternative-1
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"mix of allOf, anyOf and $ref (similar to https://json.schemastore.org/tsconfig.json)",
|
||||
R"""({
|
||||
"allOf": [
|
||||
{"$ref": "#/definitions/foo"},
|
||||
{"$ref": "#/definitions/bar"},
|
||||
{
|
||||
"anyOf": [
|
||||
{"$ref": "#/definitions/baz"},
|
||||
{"$ref": "#/definitions/bam"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"definitions": {
|
||||
"foo": {
|
||||
"properties": {"a": {"type": "number"}}
|
||||
},
|
||||
"bar": {
|
||||
"properties": {"b": {"type": "number"}}
|
||||
},
|
||||
"bam": {
|
||||
"properties": {"c": {"type": "number"}}
|
||||
},
|
||||
"baz": {
|
||||
"properties": {"d": {"type": "number"}}
|
||||
}
|
||||
},
|
||||
"type": "object"
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space number
|
||||
b-kv ::= "\"b\"" space ":" space number
|
||||
c-kv ::= "\"c\"" space ":" space number
|
||||
d-kv ::= "\"d\"" space ":" space number
|
||||
d-rest ::= ( "," space c-kv )?
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"conflicting names",
|
||||
R"""({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"number": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"number": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"root": {
|
||||
"type": "number"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"root"
|
||||
],
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"number"
|
||||
],
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"number"
|
||||
],
|
||||
"additionalProperties": false,
|
||||
"definitions": {}
|
||||
})""",
|
||||
R"""(
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
|
||||
number- ::= "{" space number-number-kv "}" space
|
||||
number-kv ::= "\"number\"" space ":" space number-
|
||||
number-number ::= "{" space number-number-root-kv "}" space
|
||||
number-number-kv ::= "\"number\"" space ":" space number-number
|
||||
number-number-root-kv ::= "\"root\"" space ":" space number
|
||||
root ::= "{" space number-kv "}" space
|
||||
space ::= " "?
|
||||
)"""
|
||||
});
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_all("C++", [](const TestCase & tc) {
|
||||
try {
|
||||
tc.verify(json_schema_to_grammar(nlohmann::json::parse(tc.schema)));
|
||||
tc.verify_status(SUCCESS);
|
||||
} catch (const std::runtime_error & ex) {
|
||||
fprintf(stderr, "Error: %s\n", ex.what());
|
||||
tc.verify_status(FAILURE);
|
||||
}
|
||||
});
|
||||
test_all("Python", [](const TestCase & tc) {
|
||||
write("test-json-schema-input.tmp", tc.schema);
|
||||
tc.verify_status(std::system(
|
||||
"python ./examples/json-schema-to-grammar.py test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
|
||||
tc.verify(read("test-grammar-output.tmp"));
|
||||
});
|
||||
test_all("JavaScript", [](const TestCase & tc) {
|
||||
write("test-json-schema-input.tmp", tc.schema);
|
||||
tc.verify_status(std::system(
|
||||
"node ./tests/run-json-schema-to-grammar.mjs test-json-schema-input.tmp > test-grammar-output.tmp") == 0 ? SUCCESS : FAILURE);
|
||||
tc.verify(read("test-grammar-output.tmp"));
|
||||
});
|
||||
|
||||
test_all("Check Expectations Validity", [](const TestCase & tc) {
|
||||
if (tc.expected_status == SUCCESS) {
|
||||
tc.verify_expectation_parseable();
|
||||
}
|
||||
});
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue