diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 73bca4600..fcbc8f16c 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -89,7 +89,7 @@ static std::string format(const char * fmt, ...) {
 // utilities to get data from a gguf file
 //
 
-int get_key_idx(const gguf_context * ctx, const char * key) {
+static int get_key_idx(const gguf_context * ctx, const char * key) {
     int i = gguf_find_key(ctx, key);
     if (i == -1) {
         fprintf(stderr, "key %s not found in file\n", key);
@@ -99,19 +99,19 @@ int get_key_idx(const gguf_context * ctx, const char * key) {
     return i;
 }
 
-const uint32_t get_u32(const gguf_context * ctx, std::string key) {
+static const uint32_t get_u32(const gguf_context * ctx, std::string key) {
     const int i = get_key_idx(ctx, key.c_str());
 
     return gguf_get_val_u32(ctx, i);
 }
 
-const float get_f32(const gguf_context * ctx, std::string key) {
+static const float get_f32(const gguf_context * ctx, std::string key) {
     const int i = get_key_idx(ctx, key.c_str());
 
     return gguf_get_val_f32(ctx, i);
 }
 
-struct ggml_tensor * get_tensor(struct ggml_context * ctx, std::string name) {
+static struct ggml_tensor * get_tensor(struct ggml_context * ctx, std::string name) {
     struct ggml_tensor * cur = ggml_get_tensor(ctx, name.c_str());
     if (!cur) {
         printf("unable to find tensor %s\n", name.c_str());
@@ -121,7 +121,7 @@ struct ggml_tensor * get_tensor(struct ggml_context * ctx, std::string name) {
     return cur;
 }
 
-std::string get_ftype(int ftype) {
+static std::string get_ftype(int ftype) {
     switch (ftype) {
     case 0:
         return "f32";
@@ -231,20 +231,13 @@ struct clip_ctx {
     int32_t ftype = 1;
     struct ggml_context * ctx;
     struct gguf_context * ctx_gguf;
-    //struct clip_buffer buf_compute;
-    // reusable buffer for `struct ggml_graph_plan.work_data`
-    std::vector<uint8_t> work_buffer;
-
-    // memory buffers used to evaluate the model
+    // memory buffers to evaluate the model
     clip_buffer buf_compute;
-
     clip_buffer buf_alloc;
     ggml_allocr * alloc = NULL;
-
 };
 
-
 static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_image_f32_batch * imgs) {
     if (!ctx->has_vision_encoder) {
@@ -436,7 +429,8 @@ if (!ggml_allocr_is_measure(ctx->alloc)) {
         embeddings = cur;
     }
 
-    if (ctx->has_llava_projector) {
+    // llava projector
+    {
         embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
 
         struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
@@ -457,8 +451,6 @@ if (!ggml_allocr_is_measure(ctx->alloc)) {
 
         embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
         embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_2_b, embeddings), embeddings);
-
-        ggml_set_name(embeddings, "check");
     }
 
     // build the graph
@@ -551,6 +543,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     }
 
     GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
+    GGML_ASSERT(new_clip->has_vision_encoder);
+    GGML_ASSERT(!new_clip->has_text_encoder);
 
     idx = get_key_idx(ctx, KEY_USE_GELU);
     new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
@@ -643,16 +637,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     vision_model.class_embedding = get_tensor(new_clip->ctx, TN_CLASS_EMBD);
     vision_model.position_embeddings = get_tensor(new_clip->ctx, format(TN_POS_EMBD, "v"));
     vision_model.pre_ln_w = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "weight"));
-    vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
-    if (new_clip->has_llava_projector) {
-        vision_model.mm_0_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 0, "weight"));
-        vision_model.mm_0_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 0, "bias"));
-        vision_model.mm_2_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 2, "weight"));
-        vision_model.mm_2_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 2, "bias"));
-    } else {
-        vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
-        vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
-        vision_model.projection = get_tensor(new_clip->ctx, TN_VIS_PROJ);
-    }
+    vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
+    vision_model.mm_0_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 0, "weight"));
+    vision_model.mm_0_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 0, "bias"));
+    vision_model.mm_2_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 2, "weight"));
+    vision_model.mm_2_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 2, "bias"));
+
     vision_model.layers.resize(hparams.n_layer);
     for (int il = 0; il < hparams.n_layer; ++il) {
         auto & layer = vision_model.layers[il];
@@ -861,7 +851,7 @@ void clip_free(clip_ctx * ctx) {
     delete ctx;
 }
 
-bool clip_image_encode(const clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec, const bool normalize) {
+bool clip_image_encode(const clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
     if (!ctx->has_vision_encoder) {
         printf("This gguf file seems to have no vision encoder\n");
         return false;
@@ -870,37 +860,25 @@ bool clip_image_encode(const clip_ctx * ctx, const int n_threads, clip_image_f32
     clip_image_f32_batch imgs{};
     imgs.size = 1;
     imgs.data = img;
-    return clip_image_batch_encode(ctx, n_threads, &imgs, vec, normalize);
+    return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
 }
 
-bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec,
-                             const bool normalize) {
+bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) {
     if (!ctx->has_vision_encoder) {
         printf("This gguf file seems to have no vision encoder\n");
         return false;
     }
 
-    const auto & model = ctx->vision_model;
-    const auto & hparams = model.hparams;
-
-    const int image_size = hparams.image_size;
-    const int patch_size = hparams.patch_size;
-    const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
-    const int num_positions = num_patches + 1;
-    const int hidden_size = hparams.hidden_size;
-    const int n_head = hparams.n_head;
-    const int d_head = hidden_size / n_head;
-    const int n_layer = hparams.n_layer;
-    const int n_intermediate = hparams.n_intermediate;
-    const int projection_dim = hparams.projection_dim;
-    const float eps = hparams.eps;
     int batch_size = imgs->size;
     if(ctx->has_llava_projector) {
         GGML_ASSERT(batch_size == 1); // TODO: support multiple images
     }
 
+    // reset alloc buffer to clean the memory from previous invocations
     ggml_allocr_reset(ctx->alloc);
+
+    // build the inference graph
     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
     ggml_allocr_alloc_graph(ctx->alloc, gf);
 
@@ -911,7 +889,10 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
 
     ggml_graph_compute(gf, &plan);
 
+    // the last node is the embedding tensor
     struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
+
+    // copy the embeddings to the location passed by the user
     memcpy(vec, ggml_get_data_f32(embeddings), ggml_nbytes(embeddings));
 
     if (plan.work_size > 0) {
@@ -921,7 +902,6 @@ struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
     return true;
 }
 
-/*
 
 bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) {
     ggml_type type = GGML_TYPE_Q4_1;
@@ -1106,6 +1086,9 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
 
     return true;
 }
-*/
-
-struct clip_vision_hparams * clip_get_vision_hparams(struct clip_ctx * ctx) { return &ctx->vision_model.hparams; }
+size_t clip_embd_nbytes(struct clip_ctx * ctx) {
+    auto & params = ctx->vision_model.hparams;
+
+    return (params.image_size / params.patch_size) * (params.image_size / params.patch_size) * 4096 * sizeof(float);
+}
diff --git a/examples/llava/clip.h b/examples/llava/clip.h
index 18fe3da83..ea93f19e7 100644
--- a/examples/llava/clip.h
+++ b/examples/llava/clip.h
@@ -9,17 +9,6 @@ struct clip_ctx;
 extern "C" {
 #endif
 
-struct clip_text_hparams {
-    int32_t n_vocab;
-    int32_t num_positions;
-    int32_t hidden_size;
-    int32_t n_intermediate;
-    int32_t projection_dim;
-    int32_t n_head;
-    int32_t n_layer;
-    float eps;
-};
-
 struct clip_vision_hparams {
     int32_t image_size;
     int32_t patch_size;
@@ -31,18 +20,11 @@ struct clip_vision_hparams {
     float eps;
 };
 
-typedef int32_t clip_vocab_id;
-struct clip_tokens {
-    clip_vocab_id * data;
-    size_t size;
-};
-
 struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
 
 void clip_free(struct clip_ctx * ctx);
 
-struct clip_text_hparams * clip_get_text_hparams(struct clip_ctx * ctx);
-struct clip_vision_hparams * clip_get_vision_hparams(struct clip_ctx * ctx);
+size_t clip_embd_nbytes(struct clip_ctx * ctx);
 
 // RGB uint8 image
 struct clip_image_u8 {
@@ -71,31 +53,16 @@ struct clip_image_f32_batch {
     size_t size;
 };
 
-bool clip_tokenize(const struct clip_ctx * ctx, const char * text, struct clip_tokens * tokens);
-
 struct clip_image_u8 * make_clip_image_u8();
 struct clip_image_f32 * make_clip_image_f32();
 bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
 bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res);
-
-bool clip_text_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_tokens * tokens, float * vec,
-                      const bool normalize);
-bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec,
-                       const bool normalize);
+bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
 
 void clip_image_batch_preprocess(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_u8_batch * img_inputs,
                                  struct clip_image_f32_batch * imgs_resized);
 
 bool clip_image_batch_encode(const struct clip_ctx * ctx, const int n_threads, const struct clip_image_f32_batch * imgs,
-                             float * vec, const bool normalize);
-
-// bool image_normalize(const clip_image_u8 *img, clip_image_f32 *res);
-
-bool clip_compare_text_and_image(const struct clip_ctx * ctx, const int n_threads, const char * text,
-                                 const struct clip_image_u8 * image, float * score);
-float clip_similarity_score(const float * vec1, const float * vec2, const int vec_dim);
-bool softmax_with_sorting(float * arr, const int length, float * sorted_scores, int * indices);
-bool clip_zero_shot_label_image(struct clip_ctx * ctx, const int n_threads, const struct clip_image_u8 * input_img,
-                                const char ** labels, const size_t n_labels, float * scores, int * indices);
+                             float * vec);
 
 bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype);
 
diff --git a/examples/llava/llava-utils.h b/examples/llava/llava-utils.h
new file mode 100644
index 000000000..3434b528e
--- /dev/null
+++ b/examples/llava/llava-utils.h
@@ -0,0 +1,141 @@
+// this one and the clip lib will eventually be merged into a single lib, let's keep it this way for now
+#include <cstdio>
+#include <cstdlib>
+#include <vector>
+
+#include "common.h"
+#include "llama.h"
+
+bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
+    int n_embd = llama_n_embd(llama_get_model(ctx_llama));
+
+    for (int i = 0; i < N; i += n_batch) {
+        int n_eval = N - i;
+        if (n_eval > n_batch) {
+            n_eval = n_batch;
+        }
+        llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
+        if (llama_decode(ctx_llama, batch)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return false;
+        }
+        *n_past += n_eval;
+    }
+    return true;
+}
+
+bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
+    int N = (int) tokens.size();
+    for (int i = 0; i < N; i += n_batch) {
+        int n_eval = (int) tokens.size() - i;
+        if (n_eval > n_batch) {
+            n_eval = n_batch;
+        }
+        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return false;
+        }
+        *n_past += n_eval;
+    }
+    return true;
+}
+
+bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
+    std::vector<llama_token> tokens;
+    tokens.push_back(id);
+    return eval_tokens(ctx_llama, tokens, 1, n_past);
+}
+
+bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past){
+    std::string str2 = str;
+    std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, true);
+    eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
+    return true;
+}
+
+llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
+    // out of user input, sample next token
+    const float temp = params.temp;
+    const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.top_k;
+    const float top_p = params.top_p;
+    const float tfs_z = params.tfs_z;
+    const float typical_p = params.typical_p;
+    // const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
+    // const float repeat_penalty = params.repeat_penalty;
+    // const float alpha_presence = params.presence_penalty;
+    // const float alpha_frequency = params.frequency_penalty;
+    const int mirostat = params.mirostat;
+    const float mirostat_tau = params.mirostat_tau;
+    const float mirostat_eta = params.mirostat_eta;
+    // const bool penalize_nl = params.penalize_nl;
+
+    llama_token id = 0;
+    {
+        auto logits = llama_get_logits(ctx_llama);
+        auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
+
+        // Apply params.logit_bias map
+        for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+            logits[it->first] += it->second;
+        }
+
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+
+        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+        // TODO: Apply penalties
+        // float nl_logit = logits[llama_token_nl(ctx)];
+        // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
+        // llama_sample_repetition_penalty(ctx, &candidates_p,
+        //      last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+        //      last_n_repeat, repeat_penalty);
+        // llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+        //      last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+        //      last_n_repeat, alpha_frequency, alpha_presence);
+        // if (!penalize_nl) {
+        //     logits[llama_token_nl(ctx)] = nl_logit;
+        // }
+
+        if (temp <= 0) {
+            // Greedy sampling
+            id = llama_sample_token_greedy(ctx_llama, &candidates_p);
+        } else {
+            if (mirostat == 1) {
+                static float mirostat_mu = 2.0f * mirostat_tau;
+                const int mirostat_m = 100;
+                llama_sample_temp(ctx_llama, &candidates_p, temp);
+                id = llama_sample_token_mirostat(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+            } else if (mirostat == 2) {
+                static float mirostat_mu = 2.0f * mirostat_tau;
+                llama_sample_temp(ctx_llama, &candidates_p, temp);
+                id = llama_sample_token_mirostat_v2(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+            } else {
+                // Temperature sampling
+                llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
+                llama_sample_tail_free(ctx_llama, &candidates_p, tfs_z, 1);
+                llama_sample_typical(ctx_llama, &candidates_p, typical_p, 1);
+                llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
+                llama_sample_temp(ctx_llama, &candidates_p, temp);
+                id = llama_sample_token(ctx_llama, &candidates_p);
+            }
+        }
+    }
+
+    return id;
+}
+
+const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
+    int id = sample_id(ctx_llama, params);
+    static std::string ret;
+    if (id == llama_token_eos(ctx_llama)) {
+        ret = "</s>";
+    } else {
+        ret = llama_token_to_piece(ctx_llama, id);
+    }
+    eval_id(ctx_llama, id, n_past);
+    return ret.c_str();
+}
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index 37a3734b3..9dc0c4ee7 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -3,149 +3,17 @@
 #include <cstdio>
 
 #include "clip.h"
+#include "llava-utils.h"
 #include "common.h"
 #include "llama.h"
 
-static bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
-    int n_embd = llama_n_embd(llama_get_model(ctx_llama));
-
-    for (int i = 0; i < N; i += n_batch) {
-        int n_eval = N - i;
-        if (n_eval > n_batch) {
-            n_eval = n_batch;
-        }
-        llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
-        if (llama_decode(ctx_llama, batch)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return false;
-        }
-        *n_past += n_eval;
-    }
-    return true;
-}
-
-static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int N, int * n_past) {
-    int n_batch = N;
-    for (int i = 0; i < (int) tokens.size(); i += n_batch) {
-        int n_eval = (int) tokens.size() - i;
-        if (n_eval > n_batch) {
-            n_eval = n_batch;
-        }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return false;
-        }
-        *n_past += n_eval;
-    }
-    return true;
-}
-
-static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
-    std::vector<llama_token> tokens;
-    tokens.push_back(id);
-    return eval_tokens(ctx_llama, tokens, 1, n_past);
-}
-
-static bool eval_string(struct llama_context * ctx_llama, const char* str, int N, int * n_past){
-    std::string str2 = str;
-    std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, true);
-    eval_tokens(ctx_llama, embd_inp, N, n_past);
-    return true;
-}
-
-static llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
-    // out of user input, sample next token
-    const float temp = params.temp;
-    const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.top_k;
-    const float top_p = params.top_p;
-    const float tfs_z = params.tfs_z;
-    const float typical_p = params.typical_p;
-    // const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
-    // const float repeat_penalty = params.repeat_penalty;
-    // const float alpha_presence = params.presence_penalty;
-    // const float alpha_frequency = params.frequency_penalty;
-    const int mirostat = params.mirostat;
-    const float mirostat_tau = params.mirostat_tau;
-    const float mirostat_eta = params.mirostat_eta;
-    // const bool penalize_nl = params.penalize_nl;
-
-    llama_token id = 0;
-    {
-        auto logits = llama_get_logits(ctx_llama);
-        auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
-
-        // Apply params.logit_bias map
-        for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-            logits[it->first] += it->second;
-        }
-
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-        }
-
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-        // TODO: Apply penalties
-        // float nl_logit = logits[llama_token_nl(ctx)];
-        // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
-        // llama_sample_repetition_penalty(ctx, &candidates_p,
-        //      last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-        //      last_n_repeat, repeat_penalty);
-        // llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
-        //      last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-        //      last_n_repeat, alpha_frequency, alpha_presence);
-        // if (!penalize_nl) {
-        //     logits[llama_token_nl(ctx)] = nl_logit;
-        // }
-
-        if (temp <= 0) {
-            // Greedy sampling
-            id = llama_sample_token_greedy(ctx_llama, &candidates_p);
-        } else {
-            if (mirostat == 1) {
-                static float mirostat_mu = 2.0f * mirostat_tau;
-                const int mirostat_m = 100;
-                llama_sample_temp(ctx_llama, &candidates_p, temp);
-                id = llama_sample_token_mirostat(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-            } else if (mirostat == 2) {
-                static float mirostat_mu = 2.0f * mirostat_tau;
-                llama_sample_temp(ctx_llama, &candidates_p, temp);
-                id = llama_sample_token_mirostat_v2(ctx_llama, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-            } else {
-                // Temperature sampling
-                llama_sample_top_k(ctx_llama, &candidates_p, top_k, 1);
-                llama_sample_tail_free(ctx_llama, &candidates_p, tfs_z, 1);
-                llama_sample_typical(ctx_llama, &candidates_p, typical_p, 1);
-                llama_sample_top_p(ctx_llama, &candidates_p, top_p, 1);
-                llama_sample_temp(ctx_llama, &candidates_p, temp);
-                id = llama_sample_token(ctx_llama, &candidates_p);
-            }
-        }
-    }
-
-    return id;
-}
-
-const char * sample(struct llama_context * ctx_llama, gpt_params & params, int * n_past) {
-    int id = sample_id(ctx_llama, params);
-    static std::string ret;
-    if (id == llama_token_eos(ctx_llama)) {
-        ret = "</s>";
-    } else {
-        ret = llama_token_to_piece(ctx_llama, id);
-    }
-    eval_id(ctx_llama, id, n_past);
-    return ret.c_str();
-}
-
 int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc < 4) {
-        printf("usage: %s <path/to/llama/model> <path/to/clip/model> <path/to/image> [a text prompt]\n", argv[0]);
+        printf("usage: %s <path/to/llama/model> <path/to/clip/model> <path/to/image> [a text prompt]\n", argv[0]);
+        return 1;
     }
 
     params.model = argv[1];
@@ -160,13 +28,28 @@ int main(int argc, char ** argv) {
         params.prompt = "describe the image in detail.";
     }
 
-    auto ctx_clip = clip_model_load(clip_path, 1);
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
+
+    // load and preprocess the image
     clip_image_u8 img;
     clip_image_f32 img_res;
     clip_image_load_from_file(img_path, &img);
     clip_image_preprocess(ctx_clip, &img, &img_res);
-    float * vec = (float *)malloc(4096 * 576 * sizeof(float));
-    clip_image_encode(ctx_clip, params.n_threads, &img_res, vec, false);
+
+    float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
+    if (!image_embd) {
+        fprintf(stderr, "Unable to allocate memory for CLIP embeddings\n");
+
+        return 1;
+    }
+
+    if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) {
+        fprintf(stderr, "Unable to encode image\n");
+
+        return 1;
+    }
+
+    // we have the embeddings, free up the memory used by CLIP
     clip_free(ctx_clip);
 
     llama_backend_init(params.numa);
@@ -191,13 +74,17 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // process the prompt
+    // llava chat format is "user: <image>\n<prompt>\nassistant:"
+
     int n_past = 0;
     int max_tgt_len = 256;
     eval_string(ctx_llama, "user: ", params.n_batch, &n_past);
-    eval_image_embd(ctx_llama, vec, 576, params.n_batch, &n_past);
+    eval_image_embd(ctx_llama, image_embd, /*n_pos_image=*/ 576, params.n_batch, &n_past);
     eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
     eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);
-printf("n_past = %d\n", n_past);
+
+    // generate the response
 
     const char* tmp;
     for (int i=0; i
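
For reference, a minimal sketch (not part of the patch) of how the refactored API fits together end to end. It assumes an already-initialized llama_context; the helper name encode_image_and_eval is hypothetical, and the 576 image positions match the hard-coded LLaVA-1.5 value used by the example above.

// sketch: encode one image with the new clip API and feed it to llama
// assumptions: ctx_llama is already created; paths come from the caller
#include <cstdio>
#include <cstdlib>

#include "clip.h"
#include "llava-utils.h"

static bool encode_image_and_eval(struct llama_context * ctx_llama, const char * clip_path,
                                  const char * img_path, int n_threads, int n_batch, int * n_past) {
    struct clip_ctx * ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);

    clip_image_u8 img;
    clip_image_f32 img_res;
    if (!clip_image_load_from_file(img_path, &img)) {
        fprintf(stderr, "failed to load %s\n", img_path);
        clip_free(ctx_clip);
        return false;
    }
    clip_image_preprocess(ctx_clip, &img, &img_res);

    // clip_embd_nbytes() replaces the old hard-coded 4096 * 576 * sizeof(float)
    float * image_embd = (float *) malloc(clip_embd_nbytes(ctx_clip));
    if (!image_embd || !clip_image_encode(ctx_clip, n_threads, &img_res, image_embd)) {
        fprintf(stderr, "failed to encode %s\n", img_path);
        free(image_embd);
        clip_free(ctx_clip);
        return false;
    }

    // the embeddings are copied out, so CLIP can be freed before generation
    clip_free(ctx_clip);

    // each of the 576 patch embeddings occupies one position in the llama context
    const bool ok = eval_image_embd(ctx_llama, image_embd, /*n_pos_image=*/ 576, n_batch, n_past);
    free(image_embd);

    return ok;
}

Note that clip_embd_nbytes() still hard-codes the 4096 output dimension of the llava projector, so the buffer size only tracks the patch count derived from image_size and patch_size; a later cleanup could read the dimension from the projector tensors instead.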