diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 87cb1a28a..d6ab40158 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -1,3 +1,7 @@
+// NOTE: This is modified from clip.cpp only for LLaVA,
+// so there might still be unnecessary artifacts hanging around;
+// I'll gradually clean and extend it
+
 #include
 #include
 #include
@@ -13,6 +17,7 @@
 #include "clip.h"
 #include "ggml.h"
+#include "ggml-alloc.h"
 
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"
@@ -144,21 +149,6 @@ std::string get_ftype(int ftype) {
     }
 }
 
-//
-// Vocab utils
-//
-
-struct clip_vocab {
-    using id = clip_vocab_id;
-    using token = std::string;
-
-    std::map<token, id> token_to_id;
-    std::map<id, token> id_to_token;
-    std::vector<std::string> special_tokens;
-
-    // void add_special_token(const std::string & token);
-};
-
 //
 // clip layers
 //
@@ -191,21 +181,6 @@ struct clip_layer {
     struct ggml_tensor * ln_2_b;
 };
 
-struct clip_text_model {
-    struct clip_text_hparams hparams;
-
-    // embeddings
-    struct ggml_tensor * token_embeddings;
-    struct ggml_tensor * position_embeddings;
-
-    std::vector<clip_layer> layers;
-
-    struct ggml_tensor * post_ln_w;
-    struct ggml_tensor * post_ln_b;
-
-    struct ggml_tensor * projection;
-};
-
 struct clip_vision_model {
     struct clip_vision_hparams hparams;
 
@@ -249,96 +224,251 @@ struct clip_ctx {
     bool has_text_encoder = false;
     bool has_vision_encoder = false;
     bool has_llava_projector = false;
-    struct clip_text_model text_model;
     struct clip_vision_model vision_model;
-    struct clip_vocab vocab;
     float image_mean[3];
     float image_std[3];
     bool use_gelu = false;
     int32_t ftype = 1;
     struct ggml_context * ctx;
     struct gguf_context * ctx_gguf;
-    struct clip_buffer buf_compute;
+    //struct clip_buffer buf_compute;
+
+    // reusable buffer for `struct ggml_graph_plan.work_data`
+    std::vector<uint8_t> work_buffer;
+
+    // memory buffers used to evaluate the model
+    clip_buffer buf_compute;
+
+    clip_buffer buf_alloc;
+    ggml_allocr * alloc = NULL;
+
 };
 
-//
-// memory allocation and management
-//
-// utility function for a workaround until https://github.com/ggerganov/ggml/issues/260 is resolved
-// after that, remove this and use the mechanism implemented in GGML directly
-size_t get_mem_req_by_size(struct clip_ctx * ctx) {
-    size_t mb = 1024 * 1024;
-    const int n_tensors = gguf_get_n_tensors(ctx->ctx_gguf);
-    const auto & vision_hparams = clip_get_vision_hparams(ctx);
-    const int n_positions =
-        ctx->has_vision_encoder ? vision_hparams->image_size * vision_hparams->image_size / vision_hparams->patch_size + 1 : 77;
-    switch (n_tensors) {
-    case 397: // base, two-tower
-    case 200: // base, vision-only
-        if (vision_hparams->patch_size == 32) { // patch size = 32
-            return 96 * mb;
-        } else { // patch size = 16
-            return 128 * mb;
-        }
-    case 197: // base or large, text-only
-        return 96 * mb;
-    case 589: // large, two-tower
-    case 392: // large, vision-only
-    case 377: // large, LLaVA encoder
-        if (vision_hparams->image_size == 224) { // input image size = 224
-            return 1200 * mb;
-        } else { // input image size = 336
-            return 2900 * mb;
-        }
-    case 909: // huge, two-tower
-    case 520: // huge, vision-only
-        return 232 * mb;
-    case 389: // huge, text-only
-        return 120 * mb;
-    default:
-        fprintf(stderr, "%s: Unrecognized number of tensors: %d. Check if you pass the correct model file\n", __func__,
-                n_tensors);
-        exit(1);
+static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_image_f32_batch * imgs) {
+
+    if (!ctx->has_vision_encoder) {
+        printf("This gguf file seems to have no vision encoder\n");
+        return nullptr;
     }
+
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int image_size = hparams.image_size;
+    const int patch_size = hparams.patch_size;
+    const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
+    const int num_positions = num_patches + 1;
+    const int hidden_size = hparams.hidden_size;
+    const int n_head = hparams.n_head;
+    const int d_head = hidden_size / n_head;
+    const int n_layer = hparams.n_layer;
+    const int n_intermediate = hparams.n_intermediate;
+    const int projection_dim = hparams.projection_dim;
+    const float eps = hparams.eps;
+    int batch_size = imgs->size;
+    if(ctx->has_llava_projector) {
+        GGML_ASSERT(batch_size == 1);
+    }
+
+    auto & buf_compute = ctx->buf_compute;
+
+    struct ggml_init_params params = {
+        .mem_size = buf_compute.size,
+        .mem_buffer = buf_compute.data,
+        .no_alloc = false,
+    };
+
+    params.no_alloc = true;
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
+    ggml_allocr_alloc(ctx->alloc, inp_raw);
+
+    if (!ggml_allocr_is_measure(ctx->alloc)) {
+        float * data = (float *)ggml_get_data(inp_raw);
+
+        for (int b = 0; b < imgs->size; b++) {
+            const int nx = imgs->data[b].nx;
+            const int ny = imgs->data[b].ny;
+            GGML_ASSERT(nx == image_size && ny == image_size);
+
+            const int n = nx * ny;
+
+            for (int b = 0; b < batch_size; b++) {
+                for (int k = 0; k < 3; k++) {
+                    for (int y = 0; y < ny; y++) {
+                        for (int x = 0; x < nx; x++) {
+                            data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].data[3 * (y * nx + x) + k];
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+    inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+    inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
+
+    // concat class_embeddings and patch_embeddings
+    struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+    ggml_allocr_alloc(ctx->alloc, embeddings);
+    if (!ggml_allocr_is_measure(ctx->alloc)) {
+        ggml_set_zero(embeddings);
+    }
+
+    struct ggml_tensor * temp = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size);
+    ggml_allocr_alloc(ctx->alloc, temp);
+
+    embeddings = ggml_acc(ctx0, embeddings, ggml_repeat(ctx0, model.class_embedding, temp), embeddings->nb[1],
+                          embeddings->nb[2], embeddings->nb[3], 0);
+    embeddings =
+        ggml_acc(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+
+    struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
+    ggml_allocr_alloc(ctx->alloc, positions);
+    if (!ggml_allocr_is_measure(ctx->alloc)) {
+        for (int i = 0; i < num_positions; i++) {
+            ggml_set_i32_1d(positions, i, i);
+        }
+    }
+
+    embeddings =
+        ggml_add(ctx0, embeddings, ggml_repeat(ctx0, ggml_get_rows(ctx0, model.position_embeddings, positions), embeddings));
+
+    // pre-layernorm
+    {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0,
model.pre_ln_w, embeddings), embeddings), + ggml_repeat(ctx0, model.pre_ln_b, embeddings)); + } + +struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); +ggml_allocr_alloc(ctx->alloc, KQ_scale); +if (!ggml_allocr_is_measure(ctx->alloc)) { + ggml_set_f32(KQ_scale, 1.0f / sqrt((float)d_head)); } -size_t get_scr_buf_req_by_size(struct clip_ctx * ctx) { - size_t mb = 1024 * 1024; + // loop over layers + for (int il = 0; il < n_layer - 1; il++) { + struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states - const int n_tensors = gguf_get_n_tensors(ctx->ctx_gguf); - const auto & vision_hparams = clip_get_vision_hparams(ctx); - const int n_positions = - ctx->has_vision_encoder ? vision_hparams->image_size * vision_hparams->image_size / vision_hparams->patch_size + 1 : 77; + const size_t nb_q_w = model.layers[il].q_w->nb[0]; - switch (n_tensors) { - case 397: - case 200: - if (n_positions <= 50) { - return 32 * mb; - } else { - return 96 * mb; + // layernorm1 + { + cur = ggml_norm(ctx0, cur, eps); + + cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_1_w, cur), cur), + ggml_repeat(ctx0, model.layers[il].ln_1_b, cur)); } - case 197: - return 32 * mb; - case 589: - case 392: - case 377: - if (n_positions <= 257) { - return 96 * mb; - } else { - return 192 * mb; + + // self-attention + { + + struct ggml_tensor * Q = + ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur)); + + Q = ggml_scale_inplace(ctx0, Q, KQ_scale); + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * K = + ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, cur), ggml_mul_mat(ctx0, model.layers[il].k_w, cur)); + + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * V = + ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, cur), ggml_mul_mat(ctx0, model.layers[il].v_w, cur)); + + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_inplace(ctx0, KQ); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); + KQV = ggml_cont(ctx0, ggml_permute(ctx0, KQV, 0, 2, 1, 3)); + + cur = ggml_cpy(ctx0, KQV, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size)); } - case 909: - case 520: - return 144 * mb; - case 389: - return 60 * mb; - default: - fprintf(stderr, "%s: Unrecognized number of tensors: %d. 
Check if you pass the correct model file\n", __func__, - n_tensors); - exit(1); + + // attention output + cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].o_b, cur), ggml_mul_mat(ctx0, model.layers[il].o_w, cur)); + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, embeddings); + + embeddings = cur; // embeddings = residual, cur = hidden_states + + // layernorm2 + { + cur = ggml_norm(ctx0, cur, eps); + + cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_2_w, cur), cur), + ggml_repeat(ctx0, model.layers[il].ln_2_b, cur)); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); + cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].ff_i_b, cur), cur); + + if (ctx->use_gelu) { + cur = ggml_gelu_inplace(ctx0, cur); + } else { + cur = ggml_gelu_quick_inplace(ctx0, cur); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); + cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].ff_o_b, cur), cur); + + // residual 2 + cur = ggml_add(ctx0, embeddings, cur); + + embeddings = cur; } + + if (ctx->has_llava_projector) { + embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); + + struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches); + ggml_allocr_alloc(ctx->alloc, patches); + if (!ggml_allocr_is_measure(ctx->alloc)) { + for (int i = 0; i < num_patches; ++i) { + ggml_set_i32_1d(patches, i, i+1); + } + } + + embeddings = ggml_get_rows(ctx0, embeddings, patches); + + // mm projection 0 + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_0_b, embeddings), embeddings); + + embeddings = ggml_gelu(ctx0, embeddings); + + embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); + embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_2_b, embeddings), embeddings); + + ggml_set_name(embeddings, "check"); + } + + // build the graph + ggml_build_forward_expand(gf, embeddings); + + ggml_free(ctx0); + + return gf; } // read and create ggml_context containing the tensors and their data @@ -422,6 +552,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx); } + GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search + idx = get_key_idx(ctx, KEY_USE_GELU); new_clip->use_gelu = gguf_get_val_bool(ctx, idx); @@ -477,66 +609,6 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { fin.close(); } - // text model - if (new_clip->has_text_encoder) { - // load text model - auto & text_model = new_clip->text_model; - auto & hparams = text_model.hparams; - hparams.hidden_size = get_u32(ctx, format(KEY_N_EMBD, "text")); - hparams.n_head = get_u32(ctx, format(KEY_N_HEAD, "text")); - hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "text")); - hparams.n_layer = get_u32(ctx, format(KEY_N_BLOCK, "text")); - hparams.num_positions = get_u32(ctx, KEY_N_POSITIONS); - hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "text")); - hparams.eps = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "text")); - - const int idx_tokens = get_key_idx(ctx, KEY_TOKENS); - hparams.n_vocab = gguf_get_arr_n(ctx, idx_tokens); - auto & vocab = new_clip->vocab; - for (int id = 0; id < hparams.n_vocab; ++id) { - const std::string token = gguf_get_arr_str(ctx, idx_tokens, id); - vocab.id_to_token[id] = token; - vocab.token_to_id[token] = id; - } - - if (verbosity >= 2) { 
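// ---
// [editorial aside] The loader above reads flags and hyperparameters out of the GGUF
// key-value store via get_key_idx()/gguf_get_val_*(). A minimal self-contained sketch
// of that access pattern follows; the key name "some.flag" and the helper name
// gguf_kv_example are placeholders, but the gguf_* calls and the gguf_init_params
// layout are the real ggml API.

#include "ggml.h"

#include <cstdio>

static bool gguf_kv_example(const char * fname) {
    struct gguf_init_params params = {
        /*.no_alloc =*/ true, // metadata only; do not allocate tensor data
        /*.ctx      =*/ NULL, // no ggml_context needed for KV access
    };

    struct gguf_context * ctx = gguf_init_from_file(fname, params);
    if (!ctx) {
        return false; // not a valid GGUF file
    }

    // gguf_find_key returns -1 when the key is missing, so keys can be probed safely
    const int idx = gguf_find_key(ctx, "some.flag");
    if (idx >= 0) {
        printf("some.flag = %d\n", gguf_get_val_bool(ctx, idx));
    }

    gguf_free(ctx);
    return true;
}
// ---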
- printf("\n%s: text model hparams\n", __func__); - printf("n_vocab %d\n", hparams.n_vocab); - printf("num_positions %d\n", hparams.num_positions); - printf("t_hidden_size %d\n", hparams.hidden_size); - printf("t_n_intermediate %d\n", hparams.n_intermediate); - printf("t_projection_dim %d\n", hparams.projection_dim); - printf("t_n_head %d\n", hparams.n_head); - printf("t_n_layer %d\n", hparams.n_layer); - } - - text_model.token_embeddings = get_tensor(new_clip->ctx, format(TN_TOKEN_EMBD, "t")); - text_model.position_embeddings = get_tensor(new_clip->ctx, format(TN_POS_EMBD, "t")); - text_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "t", "weight")); - text_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "t", "bias")); - text_model.projection = get_tensor(new_clip->ctx, TN_TEXT_PROJ); - text_model.layers.resize(hparams.n_layer); - for (int il = 0; il < hparams.n_layer; ++il) { - auto & layer = text_model.layers[il]; - layer.k_w = get_tensor(new_clip->ctx, format(TN_ATTN_K, "t", il, "weight")); - layer.q_w = get_tensor(new_clip->ctx, format(TN_ATTN_Q, "t", il, "weight")); - layer.v_w = get_tensor(new_clip->ctx, format(TN_ATTN_V, "t", il, "weight")); - layer.o_w = get_tensor(new_clip->ctx, format(TN_ATTN_OUTPUT, "t", il, "weight")); - layer.ln_1_w = get_tensor(new_clip->ctx, format(TN_LN_1, "t", il, "weight")); - layer.ln_2_w = get_tensor(new_clip->ctx, format(TN_LN_2, "t", il, "weight")); - layer.ff_i_w = get_tensor(new_clip->ctx, format(TN_FFN_DOWN, "t", il, "weight")); - layer.ff_o_w = get_tensor(new_clip->ctx, format(TN_FFN_UP, "t", il, "weight")); - layer.k_b = get_tensor(new_clip->ctx, format(TN_ATTN_K, "t", il, "bias")); - layer.q_b = get_tensor(new_clip->ctx, format(TN_ATTN_Q, "t", il, "bias")); - layer.v_b = get_tensor(new_clip->ctx, format(TN_ATTN_V, "t", il, "bias")); - layer.o_b = get_tensor(new_clip->ctx, format(TN_ATTN_OUTPUT, "t", il, "bias")); - layer.ln_1_b = get_tensor(new_clip->ctx, format(TN_LN_1, "t", il, "bias")); - layer.ln_2_b = get_tensor(new_clip->ctx, format(TN_LN_2, "t", il, "bias")); - layer.ff_i_b = get_tensor(new_clip->ctx, format(TN_FFN_DOWN, "t", il, "bias")); - layer.ff_o_b = get_tensor(new_clip->ctx, format(TN_FFN_UP, "t", il, "bias")); - } - } - // vision model if (new_clip->has_vision_encoder) { // load vision model @@ -608,99 +680,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { ggml_free(meta); new_clip->ctx_gguf = ctx; + +// measure mem requirement and allocate + { + static const size_t tensor_alignment = 32; + new_clip->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); + new_clip->alloc = ggml_allocr_new_measure(tensor_alignment); + clip_image_f32_batch batch; + batch.size = 1; + ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch); + size_t alloc_size = ggml_allocr_alloc_graph(new_clip->alloc, gf) + tensor_alignment; + ggml_allocr_free(new_clip->alloc); + new_clip->buf_alloc.resize(alloc_size); + new_clip->alloc = ggml_allocr_new(new_clip->buf_alloc.data, new_clip->buf_alloc.size, tensor_alignment); - const size_t mem_req = get_mem_req_by_size(new_clip); - new_clip->buf_compute.resize(mem_req); - if (verbosity >= 1) { - printf("\n%s: %zu MB of memory allocated\n", __func__, mem_req / 1024 / 1024); + printf("%s: total allocated memory: %.2f MB\n", __func__, (new_clip->buf_compute.size + alloc_size)/1024.0/1024.0); } return new_clip; } -bool clip_tokenize(const clip_ctx * ctx, const char * text, struct clip_tokens * tokens) { - if 
(!ctx->has_text_encoder) { - printf("This GGUF file seems to have no text encoder\n"); - return false; - } - - std::vector words; - - // first split the text into words - { - std::string str = text; - std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; - - // Generate the subpattern from the special_tokens vector if it's not empty - if (!ctx->vocab.special_tokens.empty()) { - std::string special_tokens_subpattern; - for (const auto & token : ctx->vocab.special_tokens) { - if (!special_tokens_subpattern.empty()) { - special_tokens_subpattern += "|"; - } - special_tokens_subpattern += token; - } - - // Modify the regex pattern with the generated special tokens subpattern - pat = special_tokens_subpattern + "|" + pat; - } - - std::regex re(pat); - std::smatch m; - - while (std::regex_search(str, m, re)) { - for (auto x : m) { - words.push_back(x); - } - str = m.suffix(); - } - } - - std::vector v_tokens; - v_tokens.push_back(49406); // startoftext - - for (const auto & word : words) { - // feel lucky? let's try if it's a full word - std::string full_word = ""; - if (word.find(" ") == 0) // starts_with for C++11 - { - full_word += word.substr(1); - } else { - full_word += word; - } - full_word += ""; - auto wit = ctx->vocab.token_to_id.find(full_word); - if (wit != ctx->vocab.token_to_id.end()) { - v_tokens.push_back(wit->second); - continue; - } - - for (int i = 0; i < word.size();) { - for (int j = word.size() - 1; j >= i; j--) { - auto cand = word.substr(i, j - i + 1); - auto it = ctx->vocab.token_to_id.find(cand); - if (it != ctx->vocab.token_to_id.end()) { // word.substr(i, j-i+1) in vocab - v_tokens.push_back(it->second); - i = j + 1; - break; - } else if (j == i) { // word.substr(i, 1) has no matching - fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data()); - i++; - } - } - } - } - - v_tokens.push_back(49407); // endoftext - - tokens->size = v_tokens.size(); - - tokens->data = new int[v_tokens.size()]; - std::copy(v_tokens.begin(), v_tokens.end(), tokens->data); - - return true; -} - clip_image_u8 * make_clip_image_u8() { return new clip_image_u8(); } clip_image_f32 * make_clip_image_f32() { return new clip_image_f32(); } @@ -864,231 +863,6 @@ void clip_free(clip_ctx * ctx) { delete ctx; } -bool clip_text_encode(const clip_ctx * ctx, const int n_threads, const clip_tokens * tokens, float * vec, - const bool normalize) { - if (!ctx->has_text_encoder) { - printf("This GGUF file seems to have no text encoder\n"); - return false; - } - - const auto & model = ctx->text_model; - const auto & hparams = model.hparams; - const size_t N = tokens->size; - - const int n_vocab = hparams.n_vocab; - const int num_positions = hparams.num_positions; - const int hidden_size = hparams.hidden_size; - const int n_head = hparams.n_head; - const int d_head = hidden_size / n_head; - const int n_layer = hparams.n_layer; - const int n_intermediate = hparams.n_intermediate; - const int projection_dim = hparams.projection_dim; - const float eps = hparams.eps; - - auto & buf_compute = ctx->buf_compute; - - struct ggml_init_params params = { - .mem_size = buf_compute.size, - .mem_buffer = buf_compute.data, - .no_alloc = false, - }; - - struct ggml_context * ctx0 = ggml_init(params); - struct ggml_cgraph gf = {}; - - //static size_t scr0_size = get_scr_buf_req_by_size((struct clip_ctx *)ctx); - //static void * scr0 = malloc(scr0_size); - - struct ggml_tensor * input_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - 
memcpy(input_ids->data, tokens->data, N * ggml_element_size(input_ids)); - - struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - for (int i = 0; i < N; i++) { - ggml_set_i32_1d(positions, i, i); - } - - struct ggml_tensor * embeddings = ggml_get_rows(ctx0, model.token_embeddings, input_ids); - - embeddings = ggml_add(ctx0, ggml_get_rows(ctx0, model.position_embeddings, positions), embeddings); - - // loop over layers - for (int il = 0; il < n_layer; il++) { - struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states - - //ggml_set_scratch(ctx0, {0, scr0_size, scr0}); - - // layernorm1 - { - cur = ggml_norm(ctx0, cur, eps); - - cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_1_w, cur), cur), - ggml_repeat(ctx0, model.layers[il].ln_1_b, cur)); - } - - // self-attention - { - struct ggml_tensor * Q = - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur)); - - Q = ggml_scale_inplace(ctx0, Q, ggml_new_f32(ctx0, 1.0f / sqrt((float)d_head))); - Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, N, 1); - Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); - Q = ggml_reshape_3d(ctx0, Q, d_head, N, n_head); - - struct ggml_tensor * K = - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, cur), ggml_mul_mat(ctx0, model.layers[il].k_w, cur)); - - K = ggml_reshape_4d(ctx0, K, d_head, n_head, N, 1); - K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); - K = ggml_reshape_3d(ctx0, K, d_head, N, n_head); - - struct ggml_tensor * V = - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, cur), ggml_mul_mat(ctx0, model.layers[il].v_w, cur)); - V = ggml_reshape_4d(ctx0, V, d_head, n_head, N, 1); - V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); - V = ggml_reshape_3d(ctx0, V, N, d_head, n_head); - - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_diag_mask_inf_inplace(ctx0, KQ, 0); // causal masking - KQ = ggml_soft_max_inplace(ctx0, KQ); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_reshape_4d(ctx0, KQV, d_head, N, n_head, 1); - KQV = ggml_cont(ctx0, ggml_permute(ctx0, KQV, 0, 2, 1, 3)); - - cur = ggml_cpy(ctx0, KQV, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_size, N)); - } - - // attention output - cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].o_b, cur), ggml_mul_mat(ctx0, model.layers[il].o_w, cur)); - - // re-add the layer input, e.g., residual - cur = ggml_add(ctx0, cur, embeddings); - - embeddings = cur; // embeddings = residual, cur = hidden_states - - // layernorm2 - { - cur = ggml_norm(ctx0, cur, eps); - - cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_2_w, cur), cur), - ggml_repeat(ctx0, model.layers[il].ln_2_b, cur)); - } - - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].ff_i_b, cur), cur); - - if (ctx->use_gelu) { - cur = ggml_gelu_inplace(ctx0, cur); - } else { - cur = ggml_gelu_quick_inplace(ctx0, cur); - } - - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].ff_o_b, cur), cur); - - // residual 2 - cur = ggml_add(ctx0, embeddings, cur); - - embeddings = cur; - } - - // final -layer_norm - { - embeddings = ggml_norm(ctx0, embeddings, eps); - - embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.post_ln_w, embeddings), embeddings), - ggml_repeat(ctx0, model.post_ln_b, embeddings)); - } - - // get the output of eot token, 
e.g., last index - struct ggml_tensor * eot = ggml_new_i32(ctx0, N - 1); - embeddings = ggml_get_rows(ctx0, embeddings, eot); - - //ggml_set_scratch(ctx0, {0, 0, nullptr}); - - // text projection - embeddings = ggml_mul_mat(ctx0, model.projection, embeddings); - - // normalize output embeddings - if (normalize) { - ggml_tensor * length = ggml_sqrt(ctx0, ggml_sum(ctx0, ggml_sqr(ctx0, embeddings))); - embeddings = ggml_scale_inplace(ctx0, embeddings, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length)); - } - - ggml_set_name(embeddings, "check"); - - // run the computation - - ggml_build_forward_expand(&gf, embeddings); - /* - ggml_cplan cplan = ggml_graph_plan(&gf, n_threads); - if (cplan.work_size != 0) { - cplan.work_data = (uint8_t *)malloc(cplan.work_size); - } - ggml_graph_compute(&gf, &cplan); - */ - - ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - -// print -#ifdef CLIP_DEBUG - { - auto print_t_f32 = [&](struct ggml_tensor * t) { - float * data = (float *)t->data; - printf("dtype: f32, dims: %jd %jd %jd %jd, nb: %jd %jd %jd %jd\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3], t->nb[0], - t->nb[1], t->nb[2], t->nb[3]); - printf("data: "); - for (int i = 0; i < std::min((int)t->ne[0], 20); i++) { - printf("%f ", data[i]); - } - - // printf("\n\n"); - double sum = 0.0; - for (int i = 0; i < ggml_nelements(t); i++) { - sum += data[i]; - } - printf("sum: %f\n", sum); - }; - - auto print_t_f16 = [&](struct ggml_tensor * t) { - ggml_fp16_t * data = (ggml_fp16_t *)t->data; - printf("dtype: f16, dims: %jd %jd %jd %jd\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]); - printf("data: "); - for (int i = 0; i < std::min((int)t->ne[0], 10); i++) { - printf("%f ", ggml_fp16_to_fp32(data[i])); - } - printf("\n\n"); - double sum = 0.0; - for (int i = 0; i < ggml_nelements(t); i++) { - sum += ggml_fp16_to_fp32(data[i]); - } - printf("sum: %f\n", sum); - }; - - auto * t = ggml_get_tensor(ctx0, "check"); - if (t->type == GGML_TYPE_F32) { - print_t_f32(t); - } else { - print_t_f16(t); - } - } - - printf("used_mem = %zu\n", ggml_used_mem(ctx0)); -#endif - memcpy(vec, ggml_get_data_f32(embeddings), sizeof(float) * projection_dim); - - /* - if (cplan.work_size != 0) { - free(cplan.work_data); - } - */ - - ggml_free(ctx0); - - return true; -} - bool clip_image_encode(const clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec, const bool normalize) { if (!ctx->has_vision_encoder) { printf("This gguf file seems to have no vision encoder\n"); @@ -1125,428 +899,31 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl const float eps = hparams.eps; int batch_size = imgs->size; if(ctx->has_llava_projector) { - GGML_ASSERT(batch_size == 1); + GGML_ASSERT(batch_size == 1); // TODO: support multiple images } - auto & buf_compute = ctx->buf_compute; + ggml_allocr_reset(ctx->alloc); + ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); + ggml_allocr_alloc_graph(ctx->alloc, gf); - struct ggml_init_params params = { - .mem_size = buf_compute.size, - .mem_buffer = buf_compute.data, - .no_alloc = false, - }; - - struct ggml_context * ctx0 = ggml_init(params); - struct ggml_cgraph gf = {}; - - //static size_t scr0_size = get_scr_buf_req_by_size((struct clip_ctx *)ctx); - //static void * scr0 = malloc(scr0_size); - - struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size); - - { - float * data = (float *)ggml_get_data(inp_raw); - - for (int b = 0; b < imgs->size; b++) { - const int nx = imgs->data[b].nx; - const int ny = 
imgs->data[b].ny; - GGML_ASSERT(nx == image_size && ny == image_size); - - const int n = nx * ny; - - for (int b = 0; b < batch_size; b++) { - for (int k = 0; k < 3; k++) { - for (int y = 0; y < ny; y++) { - for (int x = 0; x < nx; x++) { - data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].data[3 * (y * nx + x) + k]; - } - } - } - } - } + struct ggml_cplan plan = ggml_graph_plan(gf, n_threads); + if (plan.work_size > 0) { + plan.work_data = (uint8_t *)malloc(plan.work_size); } - struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - - inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); - - // concat class_embeddings and patch_embeddings - struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); - - ggml_set_zero(embeddings); - struct ggml_tensor * temp = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, 1, batch_size); - - embeddings = ggml_acc(ctx0, embeddings, ggml_repeat(ctx0, model.class_embedding, temp), embeddings->nb[1], - embeddings->nb[2], embeddings->nb[3], 0); - embeddings = - ggml_acc(ctx0, embeddings, inp, embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); - - struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); - for (int i = 0; i < num_positions; i++) { - ggml_set_i32_1d(positions, i, i); - } - - embeddings = - ggml_add(ctx0, embeddings, ggml_repeat(ctx0, ggml_get_rows(ctx0, model.position_embeddings, positions), embeddings)); - - // pre-layernorm - { - embeddings = ggml_norm(ctx0, embeddings, eps); - - embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.pre_ln_w, embeddings), embeddings), - ggml_repeat(ctx0, model.pre_ln_b, embeddings)); - } - - // loop over layers - for (int il = 0; il < n_layer - 1; il++) { - struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states - - const size_t nb_q_w = model.layers[il].q_w->nb[0]; - - //ggml_set_scratch(ctx0, {0, scr0_size, scr0}); - - // layernorm1 - { - cur = ggml_norm(ctx0, cur, eps); - - cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_1_w, cur), cur), - ggml_repeat(ctx0, model.layers[il].ln_1_b, cur)); - } - - // self-attention - { - - struct ggml_tensor * Q = - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur)); - - Q = ggml_scale_inplace(ctx0, Q, ggml_new_f32(ctx0, 1.0f / sqrt((float)d_head))); - Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); - Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); - Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); - - struct ggml_tensor * K = - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].k_b, cur), ggml_mul_mat(ctx0, model.layers[il].k_w, cur)); - - K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); - K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); - K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); - - struct ggml_tensor * V = - ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].v_b, cur), ggml_mul_mat(ctx0, model.layers[il].v_w, cur)); - - V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); - V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); - V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); - - struct ggml_tensor * KQ = 
ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_inplace(ctx0, KQ); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); - KQV = ggml_cont(ctx0, ggml_permute(ctx0, KQV, 0, 2, 1, 3)); - - cur = ggml_cpy(ctx0, KQV, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size)); - } - - // attention output - cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].o_b, cur), ggml_mul_mat(ctx0, model.layers[il].o_w, cur)); - - // re-add the layer input, e.g., residual - cur = ggml_add(ctx0, cur, embeddings); - - embeddings = cur; // embeddings = residual, cur = hidden_states - - // layernorm2 - { - cur = ggml_norm(ctx0, cur, eps); - - cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_2_w, cur), cur), - ggml_repeat(ctx0, model.layers[il].ln_2_b, cur)); - } - - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].ff_i_b, cur), cur); - - if (ctx->use_gelu) { - cur = ggml_gelu_inplace(ctx0, cur); - } else { - cur = ggml_gelu_quick_inplace(ctx0, cur); - } - - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].ff_o_b, cur), cur); - - // residual 2 - cur = ggml_add(ctx0, embeddings, cur); - - embeddings = cur; - } - - //ggml_set_scratch(ctx0, {0, 0, nullptr}); - - if (ctx->has_llava_projector) { - embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); - struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches); - for (int i = 0; i < num_patches; ++i) { - ggml_set_i32_1d(patches, i, i+1); - } - embeddings = ggml_get_rows(ctx0, embeddings, patches); - - // mm projection 0 - embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); - embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_0_b, embeddings), embeddings); - - embeddings = ggml_gelu(ctx0, embeddings); - - embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); - embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_2_b, embeddings), embeddings); - - ggml_set_name(embeddings, "check"); - } else { - // get the output of cls token, e.g., 0th index - struct ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size); - for (int b = 0; b < batch_size; b++) { - ggml_set_i32_1d(cls, b, b * num_positions); - } - embeddings = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, embeddings, hidden_size, num_positions * batch_size), cls); - - // post-layernorm - { - embeddings = ggml_norm(ctx0, embeddings, eps); - - embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.post_ln_w, embeddings), embeddings), - ggml_repeat(ctx0, model.post_ln_b, embeddings)); - } - - // final visual projection - embeddings = ggml_mul_mat(ctx0, model.projection, embeddings); - - // normalize output embeddings - struct ggml_tensor * output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size); - - for (int b = 0; b < batch_size; b++) { - struct ggml_tensor * embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b)); - if (normalize) { - ggml_tensor * length = ggml_sqrt(ctx0, ggml_sum(ctx0, ggml_sqr(ctx0, embedding))); - embedding = ggml_scale_inplace(ctx0, embedding, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length)); - } - output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding)); - } - - embeddings = output; - } - //ggml_set_name(embeddings, "check"); - - // run the computation 
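// ---
// [editorial aside] The removed path here finished with ggml_graph_compute_with_ctx();
// the replacement in the new clip_image_batch_encode uses an explicit ggml_cplan plus a
// caller-provided work buffer. A minimal sketch of that pattern, assuming an
// already-built graph `gf`; compute_graph_example is an illustrative helper, not part
// of this patch.

#include "ggml.h"

#include <cstdint>
#include <vector>

static void compute_graph_example(struct ggml_cgraph * gf, int n_threads) {
    // ask ggml how much scratch memory this graph needs at this thread count
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);

    // provide work memory only when the plan actually requires it
    std::vector<uint8_t> work;
    if (plan.work_size > 0) {
        work.resize(plan.work_size);
        plan.work_data = work.data();
    }

    ggml_graph_compute(gf, &plan);
    // `work` frees itself here; the patch mallocs/frees per call instead, and also adds
    // a clip_ctx::work_buffer member intended for reusing this allocation across calls
}
// ---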
- ggml_build_forward_expand(&gf, embeddings); - - /* - ggml_cplan cplan = ggml_graph_plan(&gf, n_threads); - cplan.work_size *= batch_size; - if (cplan.work_size != 0) { - cplan.work_data = (uint8_t *)malloc(cplan.work_size); - } - ggml_graph_compute(&gf, &cplan); - */ - - ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - -// print -#ifdef CLIP_DEBUG - { - auto print_t_f32 = [&](struct ggml_tensor * t) { - float * data = (float *)t->data; - printf("dtype: f32, dims: %jd %jd %jd %jd, nb: %jd %jd %jd %jd\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3], t->nb[0], - t->nb[1], t->nb[2], t->nb[3]); - printf("data: "); - for (int i = 0; i < std::min((int)t->ne[0], 20); i++) { - printf("%f ", data[i]); - } - - // printf("\n\n"); - double sum = 0.0; - for (int i = 0; i < ggml_nelements(t); i++) { - sum += data[i]; - } - printf("sum: %f\n", sum); - }; - - auto print_t_f16 = [&](struct ggml_tensor * t) { - ggml_fp16_t * data = (ggml_fp16_t *)t->data; - printf("dtype: f16, dims: %jd %jd %jd %jd\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]); - printf("data: "); - for (int i = 0; i < std::min((int)t->ne[0], 10); i++) { - printf("%f ", ggml_fp16_to_fp32(data[i])); - } - printf("\n\n"); - double sum = 0.0; - for (int i = 0; i < ggml_nelements(t); i++) { - sum += ggml_fp16_to_fp32(data[i]); - } - printf("sum: %f\n", sum); - }; - - auto * t = ggml_get_tensor(ctx0, "check"); - // auto t = inp_raw; - if (t->type == GGML_TYPE_F32) { - print_t_f32(t); - } else { - print_t_f16(t); - } - } - - printf("used_mem = %zu\n", ggml_used_mem(ctx0)); -#endif + ggml_graph_compute(gf, &plan); +struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1]; memcpy(vec, ggml_get_data_f32(embeddings), ggml_nbytes(embeddings)); - /* - if (cplan.work_size != 0) { - free(cplan.work_data); + if (plan.work_size > 0) { + free(plan.work_data); } - */ - - ggml_free(ctx0); - - return true; -} - -float clip_similarity_score(const float * vec1, const float * vec2, const int vec_dim) { - float dot_product = 0.0; - for (int i = 0; i < vec_dim; i++) { - dot_product += vec1[i] * vec2[i]; - } - - return dot_product; -} - -bool clip_compare_text_and_image(const clip_ctx * ctx, const int n_threads, const char * text, const clip_image_u8 * image, - float * score) { - if (!(ctx->has_text_encoder && ctx->has_vision_encoder)) { - printf("clip_compare_text_and_image function can only be used with two-tower models\n"); - return false; - } - - // prepare image and text vectors - const int projection_dim = ctx->vision_model.hparams.projection_dim; - float img_vec[projection_dim]; - float txt_vec[projection_dim]; - - // tokenize and encode text - clip_tokens tokens; - if (!clip_tokenize(ctx, text, &tokens)) { - return false; - } - - if (!clip_text_encode(ctx, n_threads, &tokens, txt_vec, true)) { - return false; - } - - // preprocess and encode image - clip_image_f32 img_res; - - if (!clip_image_preprocess(ctx, image, &img_res)) { - return false; - } - - if (!clip_image_encode(ctx, n_threads, &img_res, img_vec, true)) { - return false; - } - - // compute similarity - *score = clip_similarity_score(img_vec, txt_vec, projection_dim); - - return true; -} - -typedef struct { - float score; - int index; -} ScoreIndexPair; - -int compare_scores(const void * a, const void * b) { - const ScoreIndexPair * pair1 = (const ScoreIndexPair *)a; - const ScoreIndexPair * pair2 = (const ScoreIndexPair *)b; - - if (pair1->score < pair2->score) { - return 1; - } else if (pair1->score > pair2->score) { - return -1; - } else { - return 0; - } -} - -bool softmax_with_sorting(float 
* arr, const int length, float * sorted_scores, int * indices) { - ScoreIndexPair * score_index_pairs = (ScoreIndexPair *)malloc(length * sizeof(ScoreIndexPair)); - if (!score_index_pairs) { - return false; - } - - // Calculate softmax probabilities - - double sum = 0.0; - for (int i = 0; i < length; i++) { - arr[i] = exp(arr[i]) + 1e-9; - sum += arr[i]; - } - - for (int i = 0; i < length; i++) { - arr[i] /= sum; - score_index_pairs[i].score = arr[i]; - score_index_pairs[i].index = i; - } - - // Sort scores in descending order - qsort(score_index_pairs, length, sizeof(ScoreIndexPair), compare_scores); - - // Copy sorted scores and indices to the respective arrays - for (int i = 0; i < length; i++) { - sorted_scores[i] = score_index_pairs[i].score; - indices[i] = score_index_pairs[i].index; - } - - free(score_index_pairs); - return true; -} - -bool clip_zero_shot_label_image(struct clip_ctx * ctx, const int n_threads, const struct clip_image_u8 * input_img, - const char ** labels, const size_t n_labels, float * scores, int * indices) { - if (!(ctx->has_text_encoder && ctx->has_vision_encoder)) { - printf("clip_zero_shot_label_image function can only be used with two-tower models\n"); - return false; - } - - // load the image - clip_image_f32 img_res; - - const int vec_dim = clip_get_vision_hparams(ctx)->projection_dim; - - clip_image_preprocess(ctx, input_img, &img_res); - - float img_vec[vec_dim]; - if (!clip_image_encode(ctx, n_threads, &img_res, img_vec, false)) { - return false; - } - - // encode texts and compute similarities - float txt_vec[vec_dim]; - float similarities[n_labels]; - - for (int i = 0; i < n_labels; i++) { - const auto & text = labels[i]; - clip_tokens tokens; - clip_tokenize(ctx, text, &tokens); - clip_text_encode(ctx, n_threads, &tokens, txt_vec, false); - similarities[i] = clip_similarity_score(img_vec, txt_vec, vec_dim); - } - - // apply softmax and sort scores - softmax_with_sorting(similarities, n_labels, scores, indices); - + return true; } +/* bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) { ggml_type type = GGML_TYPE_Q4_1; @@ -1731,6 +1108,6 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i return true; } +*/ -struct clip_text_hparams * clip_get_text_hparams(struct clip_ctx * ctx) { return &ctx->text_model.hparams; } struct clip_vision_hparams * clip_get_vision_hparams(struct clip_ctx * ctx) { return &ctx->vision_model.hparams; }
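// ---
// [editorial aside] The central allocation change in this patch is the two-pass
// ggml-alloc pattern in clip_model_load(): build the graph once against a "measure"
// allocator to learn the peak size, then rebuild against a real allocator backed by a
// buffer of exactly that size. A compact sketch under the same assumptions follows;
// build_graph_fn_t and measure_then_alloc are illustrative names, while the
// ggml_allocr_* calls are the real ggml-alloc API.

#include "ggml.h"
#include "ggml-alloc.h"

#include <cstdint>
#include <vector>

// stands in for a builder like clip_image_build_graph() that creates tensors in a
// no_alloc ggml_context, routes them through ggml_allocr_alloc(), and guards data
// writes with !ggml_allocr_is_measure()
typedef struct ggml_cgraph * (*build_graph_fn_t)(struct ggml_allocr * alloc);

static struct ggml_allocr * measure_then_alloc(build_graph_fn_t build_graph, std::vector<uint8_t> & buf) {
    static const size_t tensor_alignment = 32;

    // pass 1: measure — no real memory is written, only offsets are tracked
    struct ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);
    struct ggml_cgraph * gf = build_graph(alloc);
    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
    ggml_allocr_free(alloc);

    // pass 2: size the real buffer exactly and return a normal allocator over it
    buf.resize(alloc_size);
    return ggml_allocr_new(buf.data(), buf.size(), tensor_alignment);
}

// At encode time the patch then resets this allocator and rebuilds the graph
// (ggml_allocr_reset + clip_image_build_graph + ggml_allocr_alloc_graph), which is why
// clip_image_build_graph must tolerate both the measure and the real allocator.
// ---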