From 3e3357fd77bcf5bd8cfcc53ca53d4a5532e67e1b Mon Sep 17 00:00:00 2001 From: tc-mb <157115220+tc-mb@users.noreply.github.com> Date: Wed, 22 Jan 2025 15:35:48 +0800 Subject: [PATCH 1/2] llava : support Minicpm-omni (#11289) * init * add readme * update readme * no use make * update readme * update fix code * fix editorconfig-checker * no change convert py * use clip_image_u8_free --- examples/llava/README-minicpmo2.6.md | 46 +++++++++++++++++++ examples/llava/clip.cpp | 29 +++++++++++- examples/llava/llava.cpp | 13 ++---- examples/llava/minicpmv-cli.cpp | 10 +++- .../minicpmv-convert-image-encoder-to-gguf.py | 15 ++++-- examples/llava/minicpmv-surgery.py | 2 +- 6 files changed, 100 insertions(+), 15 deletions(-) create mode 100644 examples/llava/README-minicpmo2.6.md diff --git a/examples/llava/README-minicpmo2.6.md b/examples/llava/README-minicpmo2.6.md new file mode 100644 index 000000000..8713a43d6 --- /dev/null +++ b/examples/llava/README-minicpmo2.6.md @@ -0,0 +1,46 @@ +## MiniCPM-o 2.6 +Currently, this readme only supports minicpm-omni's image capabilities, and we will update the full-mode support as soon as possible. + +### Prepare models and code + +Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from huggingface to "MiniCPM-o-2_6" folder. + +Clone llama.cpp: +```bash +git clone git@github.com:OpenBMB/llama.cpp.git +cd llama.cpp +git checkout minicpm-omni +``` + +### Usage of MiniCPM-o 2.6 + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) by us) + +```bash +python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6 +python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4 +python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model + +# quantize int4 version +./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + +Build llama.cpp using `CMake`: +https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md + +```bash +cmake -B build +cmake --build build --config Release +``` + +Inference on Linux or Mac +``` +# run f16 version +./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run quantized int4 version +./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# or run in interactive mode +./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i +``` diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 7a8a3156b..24073c5a9 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -718,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 else if (ctx->minicpmv_version == 3) { pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); } + else if (ctx->minicpmv_version == 4) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); + } ggml_set_name(pos_embed, "pos_embed"); ggml_set_input(pos_embed); } @@ -1053,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 n_head = hidden_size/d_head; num_query = 64; } + else if (ctx->minicpmv_version == 4) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); @@ -2041,6 +2049,7 @@ static std::vector> uhd_slice_image(const clip_imag images[images.size()-1].push_back(patch); } } + clip_image_u8_free(refine_image); } return images; } @@ -2079,6 +2088,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli clip_image_f32_free(res); } } + for (size_t i = 0; i < imgs.size(); ++i) { + for (size_t j = 0; j < imgs[i].size(); ++j) { + if (imgs[i][j] != nullptr) { + clip_image_u8_free(imgs[i][j]); + } + } + } return true; } else if (ctx->has_qwen2vl_merger) { @@ -2335,6 +2351,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i else if (ctx->minicpmv_version == 3) { n_patches = 64; } + else if (ctx->minicpmv_version == 4) { + n_patches = 64; + } } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { int patch_size = params.patch_size * 2; int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); @@ -2514,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); - int bucket_coords_h[70]; - int bucket_coords_w[70]; + int bucket_coords_h[1024]; + int bucket_coords_w[1024]; for (int i = 0; i < pos_h; i++){ bucket_coords_h[i] = std::floor(70.0*i/pos_h); } @@ -2543,6 +2562,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima else if (ctx->minicpmv_version == 3) { embed_dim = 3584; } + else if (ctx->minicpmv_version == 4) { + embed_dim = 3584; + } auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); @@ -2786,6 +2808,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { else if (ctx->minicpmv_version == 3) { return 3584; } + else if (ctx->minicpmv_version == 4) { + return 3584; + } } if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { return ctx->vision_model.mm_1_b->ne[0]; diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index c598caf3d..2cac7933d 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector return true; } -static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) { +static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) { int width = image->nx; int height = image->ny; int num_patches = (height / patch_size) * (width / patch_size); @@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); } else { - int has_minicpmv_projector = clip_is_minicpmv(ctx_clip); - if (has_minicpmv_projector == 2) { - encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); - } - else if (has_minicpmv_projector == 3) { - encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); - } + encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); } if (!encoded) { @@ -313,6 +307,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli load_image_size->height = img->ny; clip_add_load_image_size(ctx_clip, load_image_size); LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); + delete[] img_res_v.data; + img_res_v.size = 0; + img_res_v.data = nullptr; } else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 38c44e130..53d902d61 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -140,6 +140,9 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e else if (has_minicpmv_projector == 3) { system_prompt = "<|im_start|>user\n"; } + else if (has_minicpmv_projector == 4) { + system_prompt = "<|im_start|>user\n"; + } LOG_INF("%s: image token past: %d\n", __func__, n_past); eval_string(ctx_llava->ctx_llama, (system_prompt+"").c_str(), params->n_batch, &n_past, false); process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); @@ -227,6 +230,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm else if (has_minicpmv_projector == 3) { user_prompt = "<|im_start|>user\n" + prompt; } + else if (has_minicpmv_projector == 4) { + user_prompt = "<|im_start|>user\n" + prompt; + } } eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); @@ -236,6 +242,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm else if (has_minicpmv_projector == 3) { eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); } + else if (has_minicpmv_projector == 4) { + eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); + } // generate the response @@ -308,7 +317,6 @@ int main(int argc, char ** argv) { const auto * tmp = llama_loop(ctx_llava, smpl, n_past); response += tmp; if (strcmp(tmp, "") == 0) break; - if (strstr(tmp, "###")) break; // Yi-VL behavior printf("%s", tmp);// mistral llava-1.6 if (strstr(response.c_str(), "")) break; // minicpm-v fflush(stdout); diff --git a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py index ea773742a..9b196757f 100644 --- a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py +++ b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py @@ -501,7 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073] default_image_std = [0.26862954, 0.26130258, 0.27577711] ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) -ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2) +ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2) # with proper args = ap.parse_args() @@ -545,12 +545,19 @@ if args.use_f32: minicpmv_version = args.minicpmv_version emb_dim = 4096 +block_count = 26 if minicpmv_version == 1: emb_dim = 2304 + block_count = 26 elif minicpmv_version == 2: emb_dim = 4096 + block_count = 27 elif minicpmv_version == 3: emb_dim = 3584 + block_count = 27 +elif minicpmv_version == 4: + emb_dim = 3584 + block_count = 27 default_vision_config = { "hidden_size": 1152, @@ -567,6 +574,9 @@ model = Idefics2VisionTransformer(vision_config) if minicpmv_version == 3: vision_config = SiglipVisionConfig(**default_vision_config) model = SiglipVisionTransformer(vision_config) +elif minicpmv_version == 4: + vision_config = SiglipVisionConfig(**default_vision_config) + model = SiglipVisionTransformer(vision_config) processor = None # if model.attn_pool is not None: @@ -587,7 +597,7 @@ elif args.minicpmv_projector is not None: fname_middle = "mmproj-" has_text_encoder = False has_minicpmv_projector = True - minicpmv_version = 3 + minicpmv_version = 4 elif args.vision_only: fname_middle = "vision-" has_text_encoder = False @@ -625,7 +635,6 @@ if has_vision_encoder: fout.add_uint32("clip.vision.projection_dim", 0) fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16) fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) - block_count = 26 fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count) if processor is not None: diff --git a/examples/llava/minicpmv-surgery.py b/examples/llava/minicpmv-surgery.py index 748ff5c57..ba8211658 100644 --- a/examples/llava/minicpmv-surgery.py +++ b/examples/llava/minicpmv-surgery.py @@ -8,7 +8,7 @@ ap.add_argument("-m", "--model", help="Path to MiniCPM-V model") args = ap.parse_args() # find the model part that includes the the multimodal projector weights -model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True) +model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16) checkpoint = model.state_dict() # get a list of mm tensor names From a94f3b2727e97eb6c904006eb786960c069282bc Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 22 Jan 2025 09:51:44 +0000 Subject: [PATCH 2/2] `common`: utils to split / join / repeat strings (from json converter) (#11342) * Factor string_join, string_split, string_repeat into common * json: refactor to surface a versatile builder * Update common.cpp --- common/common.cpp | 42 +++++++++++++ common/common.h | 4 ++ common/json-schema-to-grammar.cpp | 98 +++++++++++-------------------- common/json-schema-to-grammar.h | 10 +++- 4 files changed, 90 insertions(+), 64 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 727ab0a10..6dea8e3d2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -484,6 +484,48 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } +std::string string_join(const std::vector & values, const std::string & separator) { + std::ostringstream result; + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) { + result << separator; + } + result << values[i]; + } + return result.str(); +} + +std::vector string_split(const std::string & str, const std::string & delimiter) { + std::vector parts; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + parts.push_back(str.substr(start, end - start)); + start = end + delimiter.length(); + end = str.find(delimiter, start); + } + + parts.push_back(str.substr(start)); + + return parts; +} + +std::string string_repeat(const std::string & str, size_t n) { + if (n == 0) { + return ""; + } + + std::string result; + result.reserve(str.length() * n); + + for (size_t i = 0; i < n; ++i) { + result += str; + } + + return result; +} + std::string string_from(bool value) { return value ? "true" : "false"; } diff --git a/common/common.h b/common/common.h index 7c9d73ce1..571260372 100644 --- a/common/common.h +++ b/common/common.h @@ -429,6 +429,10 @@ std::string string_format(const char * fmt, ...); std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); +std::string string_join(const std::vector & values, const std::string & separator); +std::vector string_split(const std::string & str, const std::string & delimiter); +std::string string_repeat(const std::string & str, size_t n); + void string_replace_all(std::string & s, const std::string & search, const std::string & replace); template diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index dadc18c8b..4d426b6bd 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -1,4 +1,6 @@ #include "json-schema-to-grammar.h" +#include "common.h" + #include #include #include @@ -11,11 +13,6 @@ using json = nlohmann::ordered_json; -template -static std::string join(Iterator begin, Iterator end, const std::string & separator); - -static std::string repeat(const std::string & str, size_t n); - static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { auto has_max = max_items != std::numeric_limits::max(); @@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & if (sub_len > 0) { auto from_sub = from.substr(i + 1); auto to_sub = to.substr(i + 1); - auto sub_zeros = repeat("0", sub_len); - auto sub_nines = repeat("9", sub_len); + auto sub_zeros = string_repeat("0", sub_len); + auto sub_nines = string_repeat("9", sub_len); auto to_reached = false; out << "("; @@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & auto max_digits = max_s.length(); for (auto digits = min_digits; digits < max_digits; digits++) { - uniform_range(min_s, repeat("9", digits)); - min_s = "1" + repeat("0", digits); + uniform_range(min_s, string_repeat("9", digits)); + min_s = "1" + string_repeat("0", digits); out << " | "; } uniform_range(min_s, max_s); @@ -318,49 +315,6 @@ std::unordered_map GRAMMAR_LITERAL_ESCAPES = { std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; -template -std::string join(Iterator begin, Iterator end, const std::string & separator) { - std::ostringstream result; - if (begin != end) { - result << *begin; - for (Iterator it = begin + 1; it != end; ++it) { - result << separator << *it; - } - } - return result.str(); -} - -static std::vector split(const std::string & str, const std::string & delimiter) { - std::vector tokens; - size_t start = 0; - size_t end = str.find(delimiter); - - while (end != std::string::npos) { - tokens.push_back(str.substr(start, end - start)); - start = end + delimiter.length(); - end = str.find(delimiter, start); - } - - tokens.push_back(str.substr(start)); - - return tokens; -} - -static std::string repeat(const std::string & str, size_t n) { - if (n == 0) { - return ""; - } - - std::string result; - result.reserve(str.length() * n); - - for (size_t i = 0; i < n; ++i) { - result += str; - } - - return result; -} - static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function & replacement) { std::smatch match; std::string result; @@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) { class SchemaConverter { private: + friend std::string build_grammar(const std::function & cb); std::function _fetch_json; bool _dotall; std::map _rules; @@ -418,7 +373,7 @@ private: for (size_t i = 0; i < alt_schemas.size(); i++) { rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); } - return join(rules.begin(), rules.end(), " | "); + return string_join(rules, " | "); } std::string _visit_pattern(const std::string & pattern, const std::string & name) { @@ -481,7 +436,7 @@ private: for (const auto & item : ret) { results.push_back(to_rule(item)); } - return std::make_pair(join(results.begin(), results.end(), " "), false); + return std::make_pair(string_join(results, " "), false); }; while (i < length) { @@ -539,7 +494,7 @@ private: } curly_brackets += '}'; i++; - auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); + auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); int min_times = 0; int max_times = std::numeric_limits::max(); try { @@ -854,7 +809,7 @@ public: return; } std::string pointer = ref.substr(ref.find('#') + 1); - std::vector tokens = split(pointer, "/"); + std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; if (target.is_null() || !target.contains(sel)) { @@ -905,7 +860,7 @@ public: for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { @@ -1019,10 +974,10 @@ public: void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n")); + throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { - fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str()); + fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); } } @@ -1036,10 +991,27 @@ public: }; std::string json_schema_to_grammar(const json & schema) { - SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); - auto copy = schema; - converter.resolve_refs(copy, "input"); - converter.visit(copy, ""); + return build_grammar([&](const llama_grammar_builder & callbacks) { + auto copy = schema; + callbacks.resolve_refs(copy); + callbacks.add_schema("", copy); + }); +} + +std::string build_grammar(const std::function & cb) { + SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false); + llama_grammar_builder builder { + /* .add_rule = */ [&](const std::string & name, const std::string & rule) { + return converter._add_rule(name, rule); + }, + /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) { + return converter.visit(schema, name == "root" ? "" : name); + }, + /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) { + converter.resolve_refs(schema, ""); + } + }; + cb(builder); converter.check_errors(); return converter.format_grammar(); } diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 41623b346..4f43ab3a5 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -5,4 +5,12 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -std::string json_schema_to_grammar(const nlohmann::ordered_json& schema); +std::string json_schema_to_grammar(const nlohmann::ordered_json & schema); + +struct llama_grammar_builder { + std::function add_rule; + std::function add_schema; + std::function resolve_refs; +}; + +std::string build_grammar(const std::function & cb);