From 65ec518d4120bc25425204d5834991ab9bca0639 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 13 Feb 2024 20:22:28 +0200 Subject: [PATCH] llava : fix compile warnings --- examples/llava/clip.cpp | 109 ++++++++++++++++++++++++--------------- examples/llava/clip.h | 27 ++++------ examples/llava/llava.cpp | 72 +++++++++++++------------- 3 files changed, 112 insertions(+), 96 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index ad12bd8c5..2baceda5d 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -30,6 +30,7 @@ #include #include #include +#include // #define CLIP_DEBUG_FUNCTIONS @@ -242,7 +243,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { } } -static void print_tensor_info(const ggml_tensor* tensor, const char* prefix = "") { +static void print_tensor_info(const ggml_tensor * tensor, const char * prefix = "") { size_t tensor_size = ggml_nbytes(tensor); printf("%s: n_dims = %d, name = %s, tensor_size=%zu, shape:[%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "], type = %s\n", prefix, ggml_n_dims(tensor), tensor->name, tensor_size, @@ -263,6 +264,24 @@ static projector_type clip_projector_type_from_string(const std::string & name) // clip layers // +struct clip_hparams { + int32_t image_size; + int32_t patch_size; + int32_t hidden_size; + int32_t n_intermediate; + int32_t projection_dim; + int32_t n_head; + int32_t n_layer; + + float eps; + + char mm_patch_merge_type[32]="flat"; // spatial_unpad or flat (default) + + int32_t image_grid_pinpoints[32]; + int32_t image_crop_resolution; + +}; + struct clip_layer { // attention struct ggml_tensor * k_w; @@ -292,7 +311,7 @@ struct clip_layer { }; struct clip_vision_model { - struct clip_vision_hparams hparams; + struct clip_hparams hparams; // embeddings struct ggml_tensor * class_embedding; @@ -376,10 +395,6 @@ struct clip_ctx { ggml_allocr * compute_alloc = NULL; }; -const struct clip_vision_hparams clip_get_vision_hparams(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams; -} - static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) { if (!ctx->has_vision_encoder) { printf("This gguf file seems to have no vision encoder\n"); @@ -392,7 +407,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 const int image_size = hparams.image_size; const int patch_size = hparams.patch_size; const int num_patches = ((image_size / patch_size) * (image_size / patch_size)); - const int num_patches_per_side = image_size / patch_size; + const int num_patches_per_side = image_size / patch_size; GGML_UNUSED(num_patches_per_side); const int num_positions = num_patches + 1; const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; @@ -1292,7 +1307,7 @@ inline float lerp(float s, float e, float t) { return s + (e - s) * t; } // Bilinear resize function -void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) { +static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) { dst.nx = target_width; dst.ny = target_height; dst.buf.resize(3 * target_width * target_height); @@ -1327,7 +1342,7 @@ void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_wi } // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not -void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) { +static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) { dst->nx = src->nx; dst->ny = src->ny; dst->buf.resize(src->buf.size()); @@ -1338,12 +1353,11 @@ void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, co } } -inline float clip(float x, float lower, float upper) -{ +inline float clip(float x, float lower, float upper) { return std::max(lower, std::min(x, upper)); } -bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) -{ + +static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) { const int nx = img.nx; const int ny = img.ny; @@ -1351,11 +1365,10 @@ bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_wid dst.ny = target_height; dst.buf.resize(3 * target_width * target_height); - int a, b, c, d, index; - float Ca, Cb, Cc; + float Cc; float C[5]; float d0, d2, d3, a0, a1, a2, a3; - int i, j, k, ii, jj; + int i, j, k, jj; int x, y; float dx, dy; float tx, ty; @@ -1363,39 +1376,29 @@ bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_wid tx = (float)nx / (float)target_width; ty = (float)ny / (float)target_height; - float scale = std::max(tx, ty); - // Bicubic interpolation; adapted from ViT.cpp, inspired from : // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 // -> https://en.wikipedia.org/wiki/Bicubic_interpolation - for (i = 0; i < target_height; i++) - { - for (j = 0; j < target_width; j++) - { + for (i = 0; i < target_height; i++) { + for (j = 0; j < target_width; j++) { x = (int)(tx * j); y = (int)(ty * i); dx = tx * j - x; dy = ty * i - y; - index = (y * nx + x) * 3; - a = (y * nx + (x + 1)) * 3; - b = ((y + 1) * nx + x) * 3; - c = ((y + 1) * nx + (x + 1)) * 3; - - for (k = 0; k < 3; k++) - { - for (jj = 0; jj <= 3; jj++) - { + for (k = 0; k < 3; k++) { + for (jj = 0; jj <= 3; jj++) { d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; - a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; - a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; d0 = C[0] - C[1]; @@ -1403,8 +1406,8 @@ bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_wid d3 = C[3] - C[1]; a0 = C[1]; a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; - a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; - a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); @@ -1418,7 +1421,7 @@ bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_wid } // llava-1.6 type of resize_and_pad (black) -void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair& target_resolution) { +static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair& target_resolution) { int target_width = target_resolution.first; int target_height = target_resolution.second; @@ -1467,7 +1470,7 @@ void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_outpu * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. * @return The best fit resolution in the format (width, height). */ -static std::pair select_best_resolution(const std::pair& original_size, const std::vector>& possible_resolutions) { +static std::pair select_best_resolution(const std::pair & original_size, const std::vector> & possible_resolutions) { int original_width = original_size.first; int original_height = original_size.second; std::pair best_fit; @@ -1494,7 +1497,7 @@ static std::pair select_best_resolution(const std::pair& ori } -std::vector divide_to_patches_u8(const clip_image_u8& image, int patch_size) { +static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { std::vector patches; int width = image.nx; int height = image.ny; @@ -1531,7 +1534,7 @@ void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found -bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch& res_imgs ) { +bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch & res_imgs ) { bool pad_to_square = true; if (!ctx->has_vision_encoder) { printf("This gguf file seems to have no vision encoder\n"); @@ -1710,6 +1713,30 @@ void clip_free(clip_ctx * ctx) { delete ctx; } +size_t clip_embd_nbytes(const struct clip_ctx * ctx) { + return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float); +} + +int32_t clip_image_size(const struct clip_ctx * ctx) { + return ctx->vision_model.hparams.image_size; +} + +int32_t clip_patch_size(const struct clip_ctx * ctx) { + return ctx->vision_model.hparams.patch_size; +} + +int32_t clip_hidden_size(const struct clip_ctx * ctx) { + return ctx->vision_model.hparams.hidden_size; +} + +const char * clip_patch_merge_type(const struct clip_ctx * ctx) { + return ctx->vision_model.hparams.mm_patch_merge_type; +} + +const int32_t * clip_image_grid(const struct clip_ctx * ctx) { + return ctx->vision_model.hparams.image_grid_pinpoints; +} + bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { if (!ctx->has_vision_encoder) { printf("This gguf file seems to have no vision encoder\n"); @@ -1973,7 +2000,3 @@ int clip_n_patches(const struct clip_ctx * ctx) { } return n_patches; } - -size_t clip_embd_nbytes(const struct clip_ctx * ctx) { - return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float); -} diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 2d1858bbd..5e0b5c64b 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -24,25 +24,7 @@ struct clip_ctx; extern "C" { #endif -struct clip_vision_hparams { - int32_t image_size; - int32_t patch_size; - int32_t hidden_size; - int32_t n_intermediate; - int32_t projection_dim; - int32_t n_head; - int32_t n_layer; - - float eps; - - char mm_patch_merge_type[32]="flat"; // spatial_unpad or flat (default) - int32_t image_grid_pinpoints[32]; - int32_t image_crop_resolution; - -}; - struct clip_ctx; -CLIP_API const struct clip_vision_hparams clip_get_vision_hparams(const struct clip_ctx * ctx); CLIP_API struct clip_ctx * clip_model_load(const char * fname, int verbosity); CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity); @@ -51,6 +33,15 @@ CLIP_API void clip_free(struct clip_ctx * ctx); CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx); +CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx); +CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx); +CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx); + +// TODO: should be enum, not string +CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); + +CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); + CLIP_API int clip_n_patches (const struct clip_ctx * ctx); CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx); diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 9f955e2ae..ea956ac00 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -2,14 +2,13 @@ #include "common.h" #include "llama.h" #include "llava.h" +#include "base64.hpp" #include #include #include #include -#include "base64.hpp" - // RGB uint8 image struct clip_image_u8 { int nx; @@ -35,8 +34,9 @@ struct clip_image_f32 { * @return The best fit resolution in the format (width, height). */ static std::pair select_best_resolution(const std::pair& original_size, const std::vector>& possible_resolutions) { - int original_width = original_size.first; + int original_width = original_size.first; int original_height = original_size.second; + std::pair best_fit; int max_effective_resolution = 0; int min_wasted_resolution = std::numeric_limits::max(); @@ -45,7 +45,7 @@ static std::pair select_best_resolution(const std::pair& ori int width = resolution.first; int height = resolution.second; float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); - int downscaled_width = static_cast(original_width * scale); + int downscaled_width = static_cast(original_width * scale); int downscaled_height = static_cast(original_height * scale); int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); int wasted_resolution = (width * height) - effective_resolution; @@ -59,6 +59,7 @@ static std::pair select_best_resolution(const std::pair& ori return best_fit; } + /** * @brief Get the anyres image grid shape object * @@ -67,7 +68,7 @@ static std::pair select_best_resolution(const std::pair& ori * @param image_patch_size * @return */ -struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair& image_size, const std::vector>& grid_pinpoints, int image_patch_size) { +static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair & image_size, const std::vector> & grid_pinpoints, int image_patch_size) { /** Conversion from gguf flat array to vector: std::vector> possible_resolutions; @@ -79,22 +80,26 @@ struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) { - struct temp_model { - struct ggml_tensor *newline; + struct { + struct ggml_tensor * newline; struct ggml_context * ctx; } model; - auto & vparams = clip_get_vision_hparams(ctx_clip); - auto num_patches_per_side = vparams.image_size / vparams.patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches) - int num_patches_width = grid_shape.first; // grid 1-4 + const int32_t image_size = clip_image_size(ctx_clip); + const int32_t patch_size = clip_patch_size(ctx_clip); + + int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches) + + int num_patches_width = grid_shape.first; // grid 1-4 int num_patches_height = grid_shape.second; // grid 1-4 + const size_t num_images = num_patches_width + num_patches_height + 1; // TODO: size calculation is not calculated - it's only tens of MB size_t ctx_size = 0; + { ctx_size += clip_embd_nbytes(ctx_clip) * num_images * 8; // image_features ctx_size += 1024*1024 * ggml_type_size(GGML_TYPE_F32); @@ -105,6 +110,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false, // NOTE: this should be false when using the legacy API }; + // Python reference code for full unpad: /* base_image_feature = image_feature[0] @@ -138,17 +144,15 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector */ model.ctx = ggml_init(params); - ggml_context *ctx_noalloc = ggml_init({2048, NULL, true}); - ggml_tensor *newline_tmp = clip_get_newline_tensor(ctx_clip); + ggml_tensor * newline_tmp = clip_get_newline_tensor(ctx_clip); model.newline = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, newline_tmp->ne[0]); if (newline_tmp->backend != GGML_BACKEND_CPU) { if (newline_tmp->buffer == NULL) { printf("newline_tmp tensor buffer is NULL\n"); } ggml_backend_tensor_get(newline_tmp, model.newline->data, 0, ggml_nbytes(newline_tmp)); - } else - { + } else { model.newline->data = newline_tmp->data; if (model.newline->data == NULL) { printf("newline_tmp tensor data is NULL\n"); @@ -158,8 +162,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4 // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false); // fill it with the image embeddings, ignoring the base - for (int i = 1; i < num_images; i++) - { + for (size_t i = 1; i < num_images; i++) { size_t offset = (i-1) * clip_embd_nbytes(ctx_clip); memcpy((uint8_t *)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip)); } @@ -222,10 +225,10 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli } const int64_t t_img_enc_start_us = ggml_time_us(); - auto & vparams = clip_get_vision_hparams(ctx_clip); - if (strcmp(vparams.mm_patch_merge_type, "spatial_unpad") != 0) - { + const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); + + if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding *n_img_pos = clip_n_patches(ctx_clip); bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096 @@ -235,41 +238,43 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli return false; } - } else - { + } else { // spatial_unpad llava-1.6 type embedding // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working std::vector image_embd_v; image_embd_v.resize(img_res_v.size); - for (int i = 0; i < img_res_v.size; i++) - { + for (size_t i = 0; i < img_res_v.size; i++) { image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 - bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside + const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside if (!encoded) { - fprintf(stderr, "Unable to encode image - spatial_unpad - subimage %d of %d\n", i+1, (int)img_res_v.size); + fprintf(stderr, "Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); return false; } } const int64_t t_img_enc_batch_us = ggml_time_us(); printf("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + const int32_t * image_grid = clip_image_grid(ctx_clip); std::vector> grid_pinpoints; - for (int i = 0; i < 32 && vparams.image_grid_pinpoints[i] != 0; i+=2) { - grid_pinpoints.push_back({vparams.image_grid_pinpoints[i], vparams.image_grid_pinpoints[i+1]}); + for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) { + grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); } + // free all img_res_v - not needed anymore delete[] img_res_v.data; img_res_v.size = 0; img_res_v.data = nullptr; - struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, vparams.image_size); + + const int32_t image_size = clip_image_size(ctx_clip); + + struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size); int n_img_pos_out; clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out); *n_img_pos = n_img_pos_out; - for (int i = 0; i < image_embd_v.size(); i++) - { + for (size_t i = 0; i < image_embd_v.size(); i++) { free(image_embd_v[i]); } image_embd_v.clear(); @@ -278,10 +283,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli // clip_image_u8 * tmp = clip_image_u8_init(); // clip_image_convert_f32_to_u8(*image_feature, *tmp); // clip_image_save_to_bmp(*tmp, "image_feature.bmp"); - } - printf("%s: image embedding created: %d tokens\n", __func__, *n_img_pos); + printf("%s: image embedding created: %d tokens\n", __func__, *n_img_pos); const int64_t t_img_enc_end_us = ggml_time_us(); float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; @@ -291,8 +295,6 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli return true; } - - bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { // make sure that the correct mmproj was used, i.e., compare apples to apples int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));