From 3a722678690952ed922a4dde8693a00882f1890a Mon Sep 17 00:00:00 2001 From: John Date: Tue, 13 Feb 2024 00:29:17 +0100 Subject: [PATCH] moved llava functions to llava.cpp, made clip.h C compatible API, replaced vector style functions with pointers, added a debug define to remove functions from compilation while not needed --- examples/llava/clip.cpp | 141 ++++++++++++++----------------------- examples/llava/clip.h | 27 +------ examples/llava/llava.cpp | 136 ++++++++++++++++++++++++++--------- examples/server/server.cpp | 29 ++++++-- 4 files changed, 184 insertions(+), 149 deletions(-) diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 60d8e8e80..ad12bd8c5 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -31,6 +31,25 @@ #include #include +// #define CLIP_DEBUG_FUNCTIONS + +// RGB uint8 image +struct clip_image_u8 { + int nx; + int ny; + + std::vector buf; +}; + +// RGB float32 image (NHWC) +// Memory layout: RGBRGBRGB... +struct clip_image_f32 { + int nx; + int ny; + + std::vector buf; +}; + static std::string format(const char * fmt, ...) { va_list ap; va_list ap2; @@ -961,10 +980,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS); int n = gguf_get_arr_n(ctx, idx); const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx); - for (int i = 0; i < 32 && pinpoints[i] != 0; ++i) { + for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) { hparams.image_grid_pinpoints[i] = pinpoints[i]; } - hparams.image_grid_pinpoints[n] = 0; + if (n < 32) + hparams.image_grid_pinpoints[n] = 0; } catch (std::runtime_error & e) { hparams.image_grid_pinpoints[0]=0; } @@ -1170,7 +1190,7 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length return true; } - +#ifdef CLIP_DEBUG_FUNCTIONS void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) { std::ofstream file(filename, std::ios::binary); if (!file.is_open()) { @@ -1265,6 +1285,7 @@ void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filenam file.close(); } +#endif // Linear interpolation between two points inline float lerp(float s, float e, float t) { @@ -1305,41 +1326,8 @@ void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_wi } } -// for replication purposes `.to(model.device, dtype=torch.float16)` -// converts a float to half precision and back to float -float simulateFloat16Precision(float value) { - // Convert float32 to float16 - uint32_t f32 = *reinterpret_cast(&value); - uint32_t sign = (f32 >> 16) & 0x8000; // Top bit (sign bit) - uint32_t exponent = ((f32 >> 23) & 0xFF) - 112; // Adjust bias (112 is bias of float16, 127 is bias of float32) - uint32_t mantissa = (f32 >> 13) & 0x3FF; // Keep top 10 bits (10 bits of precision in float16, 23 in float32) - - // Handle overflow/underflow - if ((f32 & 0x7FFFFFFF) > 0x477FE000) { // Not representable - exponent = 0x1F; - mantissa = 0; - } else if ((f32 & 0x7FFFFFFF) < 0x38800000) { // Too small for normal half precision - exponent = 0; - mantissa = 0; - } - - uint16_t f16 = sign | (exponent << 10) | mantissa; - - // Convert back to float32 - uint32_t sign32 = (f16 & 0x8000) << 16; - uint32_t exponent32 = ((f16 >> 10) & 0x1F); - uint32_t mantissa32 = (f16 & 0x3FF) << 13; - - // Adjust bias back - exponent32 = exponent32 == 0 ? 0 : exponent32 + 112; - - uint32_t f32Result = sign32 | (exponent32 << 23) | mantissa32; - float result = *reinterpret_cast(&f32Result); - - return result; -} -// Normalize image to float32 - supports float16 replication as in pytorch .to(model.device, dtype=torch.float16) -void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3], bool replicate_float16) { +// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not +void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) { dst->nx = src->nx; dst->ny = src->ny; dst->buf.resize(src->buf.size()); @@ -1347,12 +1335,9 @@ void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, co for (size_t i = 0; i < src->buf.size(); ++i) { int c = i % 3; // rgb dst->buf[i] = (static_cast(src->buf[i]) / 255.0f - mean[c]) / std[c]; - - if (replicate_float16) { - dst->buf[i] = simulateFloat16Precision(dst->buf[i]); - } } } + inline float clip(float x, float lower, float upper) { return std::max(lower, std::min(x, upper)); @@ -1471,7 +1456,6 @@ void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_outpu } } } - image_output = std::move(padded_image); } @@ -1533,7 +1517,7 @@ std::vector divide_to_patches_u8(const clip_image_u8& image, int return patches; } - +#ifdef CLIP_DEBUG_FUNCTIONS // debug function to convert f32 to u8 void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) { dst.nx = src.nx; @@ -1543,32 +1527,12 @@ void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) dst.buf[i] = static_cast(std::min(std::max(int(src.buf[i] * 255.0f), 0), 255)); } } +#endif -/** - * @brief Get the anyres image grid shape object - * - * @param image_size - * @param grid_pinpoints - * @param image_patch_size - * @return - */ -struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair& image_size, const std::vector>& grid_pinpoints, int image_patch_size) { - /** - Conversion from gguf flat array to vector: - std::vector> possible_resolutions; - for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { - possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); - } - */ - auto best_resolution = select_best_resolution(image_size, grid_pinpoints); - return {best_resolution.first / image_patch_size, best_resolution.second / image_patch_size}; -} - - -// normalize: x = (x - mean) / std -// TODO: implement bicubic interpolation instead of linear. -// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patche tensors as a vector -bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, std::vector& res_tensor, bool pad2square) { +// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector +// res_imgs memory is being allocated here, previous allocations will be freed if found +bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch& res_imgs ) { + bool pad_to_square = true; if (!ctx->has_vision_encoder) { printf("This gguf file seems to have no vision encoder\n"); return false; @@ -1576,23 +1540,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, std auto & params = ctx->vision_model.hparams; // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing if (strcmp(params.mm_patch_merge_type, "spatial_unpad") == 0) { - pad2square = false; - } else { - // pad2square = true; // todo: consider automatic decisions on that options for all models + pad_to_square = false; } - // free the previous res_tensor - if (res_tensor.size() > 0) { - for (size_t i = 0; i < res_tensor.size(); i++) { - clip_image_f32_free(res_tensor[i]); + // free the previous res_imgs if any set + if (res_imgs.size > 0 && res_imgs.size < 100) { + for (size_t i = 0; i < res_imgs.size; i++) { + clip_image_f32_free(&(res_imgs.data[i])); } - res_tensor.clear(); + delete[] res_imgs.data; } + res_imgs.data = nullptr; + res_imgs.size = 0; // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily - if (pad2square && img->nx != img->ny) { + if (pad_to_square && img->nx != img->ny) { int longer_side = std::max(img->nx, img->ny); temp->nx = longer_side; temp->ny = longer_side; @@ -1636,18 +1600,18 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, std // } std::vector patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) - // fprintf(stderr, "patches: %d, %d\n", patches.size(), params.image_size); clip_image_u8 *image_original_resize = clip_image_u8_init(); - // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square ? - bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square ? + // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square patches.insert(patches.begin(), image_original_resize); - - res_tensor.clear(); + // clip_image_f32_batch_init(patches.size()); + res_imgs.size = patches.size(); + res_imgs.data = new clip_image_f32[res_imgs.size]; + int num=0; for (auto& patch : patches) { - clip_image_f32 *temp_image_f32 = clip_image_f32_init(); - normalize_image_u8_to_f32(patch, temp_image_f32, ctx->image_mean, ctx->image_std, false); // set to true for pytorch fp16 value replication - res_tensor.push_back(temp_image_f32); + normalize_image_u8_to_f32(patch, &res_imgs.data[num], ctx->image_mean, ctx->image_std); + num++; } for (size_t i = 0; i < patches.size(); i++) { @@ -1732,7 +1696,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, std // clip_image_save_to_bmp(*temp2, "resized_normalized_f32_vanilla.bmp"); // clip_image_u8_free(temp2); // } - res_tensor.push_back(res); + // res_imgs.push_back(res); + res_imgs.size = 1; + res_imgs.data = new clip_image_f32[res_imgs.size]; + res_imgs.data[0] = std::move(*res); return true; } diff --git a/examples/llava/clip.h b/examples/llava/clip.h index c1981bb5d..2d1858bbd 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -3,8 +3,6 @@ #include #include -#include -#include #ifdef LLAMA_SHARED # if defined(_WIN32) && !defined(__MINGW32__) @@ -56,24 +54,6 @@ CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx); CLIP_API int clip_n_patches (const struct clip_ctx * ctx); CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx); -// RGB uint8 image -CLIP_API struct clip_image_u8 { - int nx; - int ny; - - std::vector buf; -}; - -// RGB float32 image (NHWC) -// Memory layout: RGBRGBRGB... - CLIP_API struct clip_image_f32 { - int nx; - int ny; - - std::vector buf; -}; - - struct clip_image_u8_batch { struct clip_image_u8 * data; size_t size; @@ -95,14 +75,11 @@ CLIP_API void clip_image_u8_free (struct clip_image_u8 * img); CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); -CLIP_API void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filename); -CLIP_API void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst); /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); -/** preprocess img and store the result in res_tensor, pad2square may be overriden to false depending on model configuration */ -CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, std::vector& res_tensor, bool pad2square); -CLIP_API struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair& image_size, const std::vector>& grid_pinpoints, int image_patch_size); +/** preprocess img and store the result in res_imgs, pad_to_square may be overriden to false depending on model configuration */ +CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch& res_imgs ); CLIP_API struct ggml_tensor *clip_get_newline_tensor(const struct clip_ctx * ctx); CLIP_API bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec); diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index ff99a688e..699fd256a 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -10,8 +10,78 @@ #include "base64.hpp" +// RGB uint8 image +struct clip_image_u8 { + int nx; + int ny; + + std::vector buf; +}; + +// RGB float32 image (NHWC) +// Memory layout: RGBRGBRGB... +struct clip_image_f32 { + int nx; + int ny; + + std::vector buf; +}; + +/** + * Selects the best resolution from a list of possible resolutions based on the original size. + * + * @param original_size The original size of the image in the format (width, height). + * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + * @return The best fit resolution in the format (width, height). + */ +static std::pair select_best_resolution(const std::pair& original_size, const std::vector>& possible_resolutions) { + int original_width = original_size.first; + int original_height = original_size.second; + std::pair best_fit; + int max_effective_resolution = 0; + int min_wasted_resolution = std::numeric_limits::max(); + + for (const auto& resolution : possible_resolutions) { + int width = resolution.first; + int height = resolution.second; + float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); + int downscaled_width = static_cast(original_width * scale); + int downscaled_height = static_cast(original_height * scale); + int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); + int wasted_resolution = (width * height) - effective_resolution; + // fprintf(stderr, "resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); + if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { + max_effective_resolution = effective_resolution; + min_wasted_resolution = wasted_resolution; + best_fit = resolution; + } + } + + return best_fit; +} +/** + * @brief Get the anyres image grid shape object + * + * @param image_size + * @param grid_pinpoints + * @param image_patch_size + * @return + */ +struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair& image_size, const std::vector>& grid_pinpoints, int image_patch_size) { + /** + Conversion from gguf flat array to vector: + std::vector> possible_resolutions; + for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { + possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); + } + */ + auto best_resolution = select_best_resolution(image_size, grid_pinpoints); + return {best_resolution.first / image_patch_size, best_resolution.second / image_patch_size}; +} + + // Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out) -static bool handle_patches(clip_ctx * ctx_clip, std::vector & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) { +static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) { struct temp_model { struct ggml_tensor *newline; struct ggml_context * ctx; @@ -21,11 +91,12 @@ static bool handle_patches(clip_ctx * ctx_clip, std::vector & image_emb auto num_patches_per_side = vparams.image_size / vparams.patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches) int num_patches_width = grid_shape.first; // grid 1-4 int num_patches_height = grid_shape.second; // grid 1-4 + const size_t num_images = num_patches_width + num_patches_height + 1; // TODO: size calculation is not calculated - it's only tens of MB size_t ctx_size = 0; { - ctx_size += clip_embd_nbytes(ctx_clip) * image_embd_v.size() * 8; // image_features + ctx_size += clip_embd_nbytes(ctx_clip) * num_images * 8; // image_features ctx_size += 1024*1024 * ggml_type_size(GGML_TYPE_F32); } @@ -84,10 +155,10 @@ static bool handle_patches(clip_ctx * ctx_clip, std::vector & image_emb } } - struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), image_embd_v.size() - 1); // example: 4096 x 576 x 4 + struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4 // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false); // fill it with the image embeddings, ignoring the base - for (int i = 1; i < image_embd_v.size(); i++) + for (int i = 1; i < num_images; i++) { size_t offset = (i-1) * clip_embd_nbytes(ctx_clip); memcpy((uint8_t *)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip)); @@ -106,6 +177,15 @@ static bool handle_patches(clip_ctx * ctx_clip, std::vector & image_emb size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side * num_patches_width, 0); // ggml_tensor_printf(image_features_patchview,"image_features_patchview",__LINE__,false,false); struct ggml_tensor *permuted_cont = ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3)); + /** + At the end of each row we have to add the row_end embeddings, which are the same as the newline embeddings + image_feature = torch.cat(( + image_feature, + self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) + ), dim=-1) + * + */ + // ggml_tensor_printf(permuted_cont,"permuted_cont",__LINE__,false,false); struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0); // ggml_tensor_printf(flatten,"flatten",__LINE__,false,false); @@ -115,7 +195,7 @@ static bool handle_patches(clip_ctx * ctx_clip, std::vector & image_emb memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context // append without newline tokens (default behavior in llava_arch when not using unpad ): - memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (image_embd_v.size()-1)); // grid patches + memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches *n_img_pos_out = static_cast(result->ne[1]+clip_n_patches(ctx_clip)); // Debug: Test single segments @@ -131,37 +211,25 @@ static bool handle_patches(clip_ctx * ctx_clip, std::vector & image_emb static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { - std::vector img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 - if (!clip_image_preprocess(ctx_clip, img, img_res_v, /*pad2square =*/ true)) { + // std::vector img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 + clip_image_f32_batch img_res_v; + img_res_v.size = 0; + img_res_v.data = nullptr; + if (!clip_image_preprocess(ctx_clip, img, img_res_v)) { fprintf(stderr, "%s: unable to preprocess image\n", __func__); - for (auto img_res : img_res_v) { - clip_image_f32_free(img_res); - } + delete[] img_res_v.data; return false; } const int64_t t_img_enc_start_us = ggml_time_us(); auto & vparams = clip_get_vision_hparams(ctx_clip); - // DEBUG print the "shape" and the first 10 rows and 10 cols of img_res_v in exp format - // for (int i = 0; i < img_res_v.size(); i++) - // { - // printf("img_res_v[%d] shape: %d x %d\n", i, img_res_v[i]->nx, img_res_v[i]->ny); - // for (int j = 0; j < 10; j++) - // { - // for (int k = 0; k < 10; k++) - // { - // printf("%e ", img_res_v[i]->buf[j*img_res_v[i]->ny + k]); - // } - // printf("\n"); - // } - // } if (strcmp(vparams.mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding *n_img_pos = clip_n_patches(ctx_clip); - bool encoded = clip_image_encode(ctx_clip, n_threads, img_res_v[0], image_embd); // image_embd shape is 576 x 4096 - clip_image_f32_free(img_res_v[0]); + bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096 + delete[] img_res_v.data; if (!encoded) { fprintf(stderr, "Unable to encode image\n"); @@ -172,30 +240,32 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli // spatial_unpad llava-1.6 type embedding // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working std::vector image_embd_v; - image_embd_v.resize(img_res_v.size()); - for (int i = 0; i < img_res_v.size(); i++) + image_embd_v.resize(img_res_v.size); + for (int i = 0; i < img_res_v.size; i++) { image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 - bool encoded = clip_image_encode(ctx_clip, n_threads, img_res_v[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside - clip_image_f32_free(img_res_v[i]); + bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside if (!encoded) { - fprintf(stderr, "Unable to encode image - spatial_unpad - subimage %d of %d\n", i+1, (int)img_res_v.size()); + fprintf(stderr, "Unable to encode image - spatial_unpad - subimage %d of %d\n", i+1, (int)img_res_v.size); return false; } } const int64_t t_img_enc_batch_us = ggml_time_us(); - printf("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size(), (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + printf("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); std::vector> grid_pinpoints; for (int i = 0; i < 32 && vparams.image_grid_pinpoints[i] != 0; i+=2) { grid_pinpoints.push_back({vparams.image_grid_pinpoints[i], vparams.image_grid_pinpoints[i+1]}); } - img_res_v.clear(); + // free all img_res_v - not needed anymore + delete[] img_res_v.data; + img_res_v.size = 0; + img_res_v.data = nullptr; struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, vparams.image_size); int n_img_pos_out; - handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out); + clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out); *n_img_pos = n_img_pos_out; for (int i = 0; i < image_embd_v.size(); i++) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 353bd8976..9148f6ca2 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -31,6 +31,23 @@ using json = nlohmann::json; +// RGB uint8 image +struct clip_image_u8 { + int nx; + int ny; + + std::vector buf; +}; + +// RGB float32 image (NHWC) +// Memory layout: RGBRGBRGB... +struct clip_image_f32 { + int nx; + int ny; + + std::vector buf; +}; + struct server_params { std::string hostname = "127.0.0.1"; @@ -943,14 +960,17 @@ struct llama_server_context { continue; } - std::vector img_res_v; - if (!clip_image_preprocess(clp_ctx, img.img_data, img_res_v, /*pad2square =*/ true)) + clip_image_f32_batch img_res_v; + img_res_v.size = 0; + img_res_v.data = nullptr; + if (!clip_image_preprocess(clp_ctx, img.img_data, img_res_v)) { LOG_TEE("Error processing the given image"); clip_free(clp_ctx); + delete[] img_res_v.data; return false; } - clip_image_f32 * img_res = img_res_v[0]; + clip_image_f32 * img_res = &img_res_v.data[0]; img.image_tokens = clip_n_patches(clp_ctx); img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx)); if (!img.image_embedding) @@ -965,7 +985,8 @@ struct llama_server_context LOG_TEE("Unable to encode image\n"); return false; } - clip_image_f32_free(img_res); + // clip_image_f32_free(img_res); + delete[] img_res_v.data; img.request_encode_image = false; }