From c6932085febdd3f4794bf058e39afbe5dee6d952 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sat, 14 Oct 2023 11:51:33 +0200 Subject: [PATCH] refactor image load out of llava init --- examples/llava/llava-utils.h | 2 + examples/llava/llava.cpp | 83 +++++++++++++++++------------------- examples/llava/llava.h | 6 +-- 3 files changed, 45 insertions(+), 46 deletions(-) diff --git a/examples/llava/llava-utils.h b/examples/llava/llava-utils.h index db8af6d6c..b794c39cc 100644 --- a/examples/llava/llava-utils.h +++ b/examples/llava/llava-utils.h @@ -5,6 +5,8 @@ #include "common.h" #include "llama.h" +#include "base64.hpp" + #include #include #include diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 7781d4222..cf19a2f78 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -15,7 +15,8 @@ static void show_additional_info(int /*argc*/, char ** argv) { printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n"); } -static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_embd, int * n_img_pos, float * t_img_enc_ms) { +static bool encode_image_with_clip(llava_context * ctx_llava, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_embd, int * n_img_pos) { + auto ctx_clip = ctx_llava->ctx_clip; clip_image_f32 img_res; if (!clip_image_preprocess(ctx_clip, img, &img_res, /*pad2square =*/ true)) { fprintf(stderr, "%s: unable to preprocess image\n", __func__); @@ -26,6 +27,14 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli *n_img_pos = clip_n_patches(ctx_clip); *n_img_embd = clip_n_mmproj_embd(ctx_clip); + // make sure that the correct mmproj was used, i.e., compare apples to apples + int n_llama_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama)); + if (*n_img_embd != n_llama_embd) { + printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA 
(%d). Make sure that you use the correct mmproj file.\n", __func__, *n_img_embd, n_llama_embd); + + return false; + } + const int64_t t_img_enc_start_us = ggml_time_us(); if (!clip_image_encode(ctx_clip, n_threads, &img_res, image_embd)) { fprintf(stderr, "Unable to encode image\n"); @@ -33,12 +42,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli return false; } const int64_t t_img_enc_end_us = ggml_time_us(); - *t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; + float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; + + { + printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos); + } + return true; } -bool llava_build_img_embed(struct llava_context * ctx_llava, const clip_image_u8 * img) { +static bool llava_build_img_embed(struct llava_context * ctx_llava, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_image_pos_out) { + auto ctx_clip = ctx_llava->ctx_clip; float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)); if (!image_embd) { fprintf(stderr, "Unable to allocate memory for image embeddings\n"); @@ -46,24 +61,22 @@ bool llava_build_img_embed(struct llava_context * ctx_llava, const clip_image_u8 return false; } + int n_image_pos; int n_img_embd; - int n_img_pos; - float t_img_enc_ms; - if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) { + if (!encode_image_with_clip(ctx_llava, n_threads, img, image_embd, &n_img_embd, &n_image_pos)) { fprintf(stderr, "%s: cannot encode image, aborting\n", __func__); free(image_embd); return false; } - - ctx_llava->image_embd = image_embd; - return true; + *image_embd_out = image_embd; + *n_image_pos_out = n_image_pos; + return true; } struct llava_context * llava_init(gpt_params * params) { const char * clip_path = params->mmproj.c_str(); - const char * img_path = params->image.c_str(); 
auto prompt = params->prompt; if (prompt.empty()) { @@ -94,55 +107,36 @@ struct llava_context * llava_init(gpt_params * params) { return NULL; } - // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); - if (n_img_embd != n_llama_embd) { - printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd); - - llama_free(ctx_llama); - llama_free_model(model); - llama_backend_free(); - free(image_embd); - - return NULL; - } - - { - printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos); - } - auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); ctx_llava->ctx_llama = ctx_llama; ctx_llava->ctx_clip = ctx_clip; ctx_llava->model = model; - ctx_llava->image_embd = image_embd; - ctx_llava->n_img_pos = n_img_pos; return ctx_llava; } void llava_free(struct llava_context * ctx_llava) { if (ctx_llava->ctx_clip) { - clip_free(ctx_clip); + clip_free(ctx_llava->ctx_clip); ctx_llava->ctx_clip = NULL; } llama_free(ctx_llava->ctx_llama); llama_free_model(ctx_llava->model); llama_backend_free(); - free(ctx_llava->image_embd); } -void llava_process_prompt(struct llava_context * ctx_llava, gpt_params * params, const char * prompt) { +static void llava_process_prompt(struct llava_context * ctx_llava, float * image_embd, int n_img_pos, gpt_params * params, const char * prompt) { int n_past = 0; const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; + // llava chat format is "USER: \n\nASSISTANT:" // GG: are we sure that the should be a trailing whitespace at the end of this string? eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past); - eval_image_embd(ctx_llava->ctx_llama, ctx_llava->image_embd, ctx_llava->n_img_pos, params->n_batch, &n_past); + eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past); eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past); eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past); @@ -186,31 +180,34 @@ int main(int argc, char ** argv) { // load and preprocess the image clip_image_u8 img; + auto prompt = params.prompt; if (prompt_contains_image(prompt)) { - if (img_path) { + if (!params.image.empty()) { printf("using base64 encoded image instead of command line image path\n"); } if (!get_image_from_prompt(prompt, &img)) { fprintf(stderr, "%s: can't load image from prompt\n", __func__); - clip_free(ctx_clip); - return NULL; + llava_free(ctx_llava); + return 1; } prompt = remove_image_from_prompt(prompt); } else { - if (!clip_image_load_from_file(img_path, &img)) { - fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path); - clip_free(ctx_clip); - return NULL; + if (!clip_image_load_from_file(params.image.c_str(), &img)) { + fprintf(stderr, "%s: is %s really an image file?\n", __func__, params.image.c_str()); + llava_free(ctx_llava); + return 1; } } - llava_build_img_embed(ctx_llava, &img); + float * image_embd; + int n_image_pos; + llava_build_img_embed(ctx_llava, params.n_threads, &img, &image_embd, &n_image_pos); // process the prompt - // llava chat format is "USER: \n\nASSISTANT:" - llava_process_prompt(ctx_llava, ¶ms, params.prompt.c_str()); + llava_process_prompt(ctx_llava, image_embd, n_image_pos, ¶ms, params.prompt.c_str()); llama_print_timings(ctx_llava->ctx_llama); + free(image_embd); llava_free(ctx_llava); return 0; } diff --git a/examples/llava/llava.h b/examples/llava/llava.h index 4f229a08c..ddbcc8d43 100644 --- a/examples/llava/llava.h +++ 
b/examples/llava/llava.h @@ -14,14 +14,14 @@ struct llava_context { struct llama_context * ctx_llama = NULL; struct llama_model * model = NULL; - int n_img_pos = 0; - float * image_embd = NULL; +// int n_img_pos = 0; +// float * image_embd = NULL; }; struct llava_context * llava_init(gpt_params * params); void llava_free(struct llava_context * ctx_llava); -void llava_process_prompt(struct llava_context * ctx_llava, gpt_params * params, const char * prompt); +//void llava_process_prompt(struct llava_context * ctx_llava, gpt_params * params, const char * prompt); #ifdef __cplusplus