From 0209d39526db18501c5141b109b431f42b359b89 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 13 Oct 2023 10:33:07 +0200 Subject: [PATCH] wip llava python bindings compatibility --- examples/llava/CMakeLists.txt | 3 +++ examples/llava/clip.cpp | 32 ++++++++++++++++------- examples/llava/clip.h | 1 + examples/llava/llava.cpp | 49 ++++++++++++++++++++--------------- examples/llava/test-llava.cpp | 6 +++++ 5 files changed, 61 insertions(+), 30 deletions(-) create mode 100644 examples/llava/test-llava.cpp diff --git a/examples/llava/CMakeLists.txt b/examples/llava/CMakeLists.txt index d02e6ab46..d04dcc5c5 100644 --- a/examples/llava/CMakeLists.txt +++ b/examples/llava/CMakeLists.txt @@ -18,3 +18,6 @@ target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) add_dependencies(${TARGET} BUILD_INFO) endif() + +unset(TARGET) +llama_build_and_test_executable(test-llava.cpp) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index f4258b34d..5bb2e4c37 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -682,25 +682,39 @@ clip_image_u8 * make_clip_image_u8() { return new clip_image_u8(); } clip_image_f32 * make_clip_image_f32() { return new clip_image_f32(); } -bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { - int nx, ny, nc; - auto data = stbi_load(fname, &nx, &ny, &nc, 3); - if (!data) { - fprintf(stderr, "%s: failed to load '%s'\n", __func__, fname); - return false; - } - +static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) { img->nx = nx; img->ny = ny; img->size = nx * ny * 3; img->data = new uint8_t[img->size](); memcpy(img->data, data, img->size); +} +bool clip_image_load_from_bytes(const unsigned char * bytes, int bytes_length, clip_image_u8 * img) { + int nx, ny, nc; + auto data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3); + if (!data) { + fprintf(stderr, "%s: failed to decode image bytes\n", __func__); + return false; + } + 
build_clip_img_from_data(data, nx, ny, img); stbi_image_free(data); - return true; } +bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { + int nx, ny, nc; + auto data = stbi_load(fname, &nx, &ny, &nc, 3); + if (!data) { + fprintf(stderr, "%s: failed to load image '%s'\n", __func__, fname); + return false; + } + build_clip_img_from_data(data, nx, ny, img); + stbi_image_free(data); + return true; +} + + // normalize: x = (x - mean) / std // TODO: implement bicubic interpolation instead of linear. bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res, const bool pad2square) { diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 3d7261e29..c0b53d0b8 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -58,6 +58,7 @@ struct clip_image_f32_batch { struct clip_image_u8 * make_clip_image_u8(); struct clip_image_f32 * make_clip_image_f32(); bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); +bool clip_image_load_from_bytes(const unsigned char * bytes, int bytes_length, clip_image_u8 * img); bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square); bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec); diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 14dacc780..c55d4f165 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -12,6 +12,28 @@ static void show_additional_info(int /*argc*/, char ** argv) { printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n"); } +static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_embd, int * n_img_pos, float * t_img_enc_ms) { + clip_image_f32 img_res; + if (!clip_image_preprocess(ctx_clip, img, &img_res, /*pad2square =*/ true)) { + 
fprintf(stderr, "%s: unable to preprocess image\n", __func__); + + return false; + } + + *n_img_pos = clip_n_patches(ctx_clip); + *n_img_embd = clip_n_mmproj_embd(ctx_clip); + + const int64_t t_img_enc_start_us = ggml_time_us(); + if (!clip_image_encode(ctx_clip, n_threads, &img_res, image_embd)) { + fprintf(stderr, "Unable to encode image\n"); + + return false; + } + const int64_t t_img_enc_end_us = ggml_time_us(); + *t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; + return true; +} + int main(int argc, char ** argv) { ggml_time_init(); @@ -39,40 +61,27 @@ int main(int argc, char ** argv) { // load and preprocess the image clip_image_u8 img; - clip_image_f32 img_res; if (!clip_image_load_from_file(img_path, &img)) { fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path); - clip_free(ctx_clip); return 1; } - if (!clip_image_preprocess(ctx_clip, &img, &img_res, /*pad2square =*/ true)) { - fprintf(stderr, "%s: unable to preprocess %s\n", __func__, img_path); - - clip_free(ctx_clip); - return 1; - } - - int n_img_pos = clip_n_patches(ctx_clip); - int n_img_embd = clip_n_mmproj_embd(ctx_clip); - float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)); - if (!image_embd) { fprintf(stderr, "Unable to allocate memory for image embeddings\n"); - return 1; } - const int64_t t_img_enc_start_us = ggml_time_us(); - if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) { - fprintf(stderr, "Unable to encode image\n"); - + int n_img_embd; + int n_img_pos; + float t_img_enc_ms; + if (!encode_image_with_clip(ctx_clip, params.n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) { + fprintf(stderr, "%s: cannot encode image, aborting\n", __func__); + clip_free(ctx_clip); return 1; } - const int64_t t_img_enc_end_us = ggml_time_us(); // we get the embeddings, free up the memory required for CLIP clip_free(ctx_clip); @@ -140,8 +149,6 @@ int main(int argc, char ** argv) { printf("\n"); { - const float 
t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; - printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos); } diff --git a/examples/llava/test-llava.cpp b/examples/llava/test-llava.cpp new file mode 100644 index 000000000..6e8a01367 --- /dev/null +++ b/examples/llava/test-llava.cpp @@ -0,0 +1,6 @@ +#include <stdio.h> + +int main(int argc, char ** argv) { + printf("dummy llava test\n"); + return 0; +}