From a81ba7519305ea10c0b8f087cc61adb017bd3905 Mon Sep 17 00:00:00 2001 From: Yutong Dai Date: Tue, 3 Sep 2024 18:39:46 +0000 Subject: [PATCH] patch mergeing + masking done --- .gitignore | 2 + Makefile | 9 + examples/xgenmm/CMakeLists.txt | 6 + examples/xgenmm/clip.cpp | 2 +- examples/xgenmm/debug.py | 39 +- examples/xgenmm/playground.ipynb | 21 + examples/xgenmm/run_cli.sh | 9 + .../xgenmm/test_anyres_handle_patches.cpp | 55 +-- examples/xgenmm/test_patch_ops.cpp | 401 ++++++++++++++++++ examples/xgenmm/xgenmm.cpp | 121 +++++- 10 files changed, 610 insertions(+), 55 deletions(-) create mode 100644 examples/xgenmm/run_cli.sh create mode 100644 examples/xgenmm/test_patch_ops.cpp diff --git a/.gitignore b/.gitignore index 99a22c4a1..68ac89b3c 100644 --- a/.gitignore +++ b/.gitignore @@ -153,3 +153,5 @@ examples/xgenmm copy/imgs/image_res_3.csv examples/xgenmm copy/imgs/image_res_4.csv examples/xgenmm copy/imgs/image-1d100e9-1.jpg examples/xgenmm copy/imgs/image-1d100e9.jpg +examples/xgenmm/imgs/4patches_embeddings.pt +examples/xgenmm/imgs/attention_mask_4patchhes.pt diff --git a/Makefile b/Makefile index e13e7fb5c..48d85f755 100644 --- a/Makefile +++ b/Makefile @@ -21,6 +21,7 @@ BUILD_TARGETS = \ llama-llava-cli \ llama-minicpmv-cli\ xgenmm-cli\ + test_anyres_handle_patches\ llama-lookahead \ llama-lookup \ llama-lookup-create \ @@ -1482,6 +1483,14 @@ xgenmm-cli: examples/xgenmm/xgenmm-cli.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual +test_anyres_handle_patches: examples/xgenmm/test_anyres_handle_patches.cpp \ + examples/xgenmm/xgenmm.cpp \ + examples/xgenmm/xgenmm.h \ + examples/xgenmm/clip.cpp \ + examples/xgenmm/clip.h \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual + ifeq ($(UNAME_S),Darwin) swift: examples/batched.swift (cd examples/batched.swift; make build) diff --git a/examples/xgenmm/CMakeLists.txt b/examples/xgenmm/CMakeLists.txt index 2d7d81588..a9229e7e2 100644 --- a/examples/xgenmm/CMakeLists.txt +++ b/examples/xgenmm/CMakeLists.txt @@ -44,6 +44,12 @@ install(TARGETS test_anyres_handle_patches RUNTIME) target_link_libraries(test_anyres_handle_patches PRIVATE common xgenmm ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(xgenmm PRIVATE cxx_std_11) +set(TARGET test_patch_ops) +add_executable(test_patch_ops test_patch_ops.cpp) +install(TARGETS test_patch_ops RUNTIME) +target_link_libraries(test_patch_ops PRIVATE common xgenmm ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(xgenmm PRIVATE cxx_std_11) + # not implemented yet # set(TARGET xgenmm-cli) diff --git a/examples/xgenmm/clip.cpp b/examples/xgenmm/clip.cpp index 6a806963b..0680f4f41 100644 --- a/examples/xgenmm/clip.cpp +++ b/examples/xgenmm/clip.cpp @@ -485,7 +485,7 @@ struct clip_vision_model { struct ggml_tensor * projection; - // LLaVA projection + // LLaVA projecclip_image_encodeion struct ggml_tensor * mm_0_w = NULL; struct ggml_tensor * mm_0_b = NULL; struct ggml_tensor * mm_2_w = NULL; diff --git a/examples/xgenmm/debug.py b/examples/xgenmm/debug.py index 9a503a42c..72bc99b3f 100644 --- a/examples/xgenmm/debug.py +++ b/examples/xgenmm/debug.py @@ -1,15 +1,32 @@ -from torchvision.transforms import Resize -from torchvision.transforms import InterpolationMode -from PIL import Image +# from torchvision.transforms import Resize +# from torchvision.transforms import InterpolationMode +# from PIL import Image +# import numpy as np + +# n_px = 384 +# resize_func = Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC, 
antialias=True)
+
+# img_dir = "./imgs"
+# image_path_1 = f'{img_dir}/image-1d100e9-1.jpg'
+# image_path_2 = f'{img_dir}/image-1d100e9.jpg'
+# image_1 = Image.open(image_path_1).convert('RGB')
+# image_2 = Image.open(image_path_2).convert('RGB')
+
+# print(np.asarray(resize_func(image_2))[:5, :10, 0])
+
+
+import gguf
 import numpy as np
+import torch
 
-n_px = 384
-resize_func = Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC, antialias=True)
+patches_embeddings = torch.load('./imgs/4patches_embeddings.pt').numpy()
+print(f'4patches_embeddings:{patches_embeddings.shape}\n')
+print(patches_embeddings[1:,:,:])
 
-img_dir = "./imgs"
-image_path_1 = f'{img_dir}/image-1d100e9-1.jpg'
-image_path_2 = f'{img_dir}/image-1d100e9.jpg'
-image_1 = Image.open(image_path_1).convert('RGB')
-image_2 = Image.open(image_path_2).convert('RGB')
 
-print(np.asarray(resize_func(image_2))[:5, :10, 0])
\ No newline at end of file
+# gguf_writer = gguf.GGUFWriter(path='./imgs/4patches_embeddings.gguf', arch='4patches_embeddings')
+# gguf_writer.add_tensor("data", patches_embeddings)
+# gguf_writer.write_header_to_file()
+# gguf_writer.write_kv_data_to_file()
+# gguf_writer.write_tensors_to_file()
+# gguf_writer.close()
\ No newline at end of file
diff --git a/examples/xgenmm/playground.ipynb b/examples/xgenmm/playground.ipynb
index 38feb1e95..09c6a58cf 100644
--- a/examples/xgenmm/playground.ipynb
+++ b/examples/xgenmm/playground.ipynb
@@ -5,6 +5,27 @@
    "metadata": {},
    "source": []
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# check mask"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def csv_to_tensor(filename, axis=0):\n",
+    "    matrix = np.loadtxt(filename, delimiter=',')\n",
+    "    return matrix\n",
+    "\n",
+    "filename = 'imgs/attention_mask_4patchhes.csv'\n",
+    "patches = csv_to_tensor(filename)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
diff --git a/examples/xgenmm/run_cli.sh b/examples/xgenmm/run_cli.sh
new file mode 100644
index 000000000..cbfe57740
--- /dev/null
+++ b/examples/xgenmm/run_cli.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+make xgenmm-cli
+
+./xgenmm-cli -m /export/share/llamacpp_models/MiniCPM-Llama3-V-2_5/ggml-model-Q4_K_M.gguf \
+    --mmproj /export/share/yutong/xgenmm/llamacpp_wd/siglip_kosmos_phi3_4k_instruct/gguf_test/mmproj-model-f32.gguf \
+    -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 \
+    --image /export/home/llama.cpp/examples/xgenmm/imgs/image-1d100e9-1.jpg \
+    -p "What is in the image?"
\ No newline at end of file diff --git a/examples/xgenmm/test_anyres_handle_patches.cpp b/examples/xgenmm/test_anyres_handle_patches.cpp index d6bbad46a..b778d356e 100644 --- a/examples/xgenmm/test_anyres_handle_patches.cpp +++ b/examples/xgenmm/test_anyres_handle_patches.cpp @@ -535,8 +535,8 @@ int main(){ part of: llava_image_embed_make_with_filename */ - const char* image_path = "/export/home/llama.cpp/examples/xgenmm/imgs/image-1d100e9.jpg"; // Porcelain - // const char* image_path = "/export/home/llama.cpp/examples/xgenmm/imgs/image-1d100e9-1.jpg"; + // const char* image_path = "/export/home/llama.cpp/examples/xgenmm/imgs/image-1d100e9.jpg"; // Porcelain + const char* image_path = "/export/home/llama.cpp/examples/xgenmm/imgs/image-1d100e9-1.jpg"; unsigned char* image_bytes; long image_bytes_length; auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length); @@ -618,31 +618,36 @@ int main(){ std::vector image_embd_v; image_embd_v.resize(img_res_v.size); printf("image_embd_v.size():%d\n", image_embd_v.size()); - for (size_t i = 0; i < img_res_v.size; i++) - { - printf("encode patch %d\n", i); - const int nx = img_res_v.data[i].nx; - const int ny = img_res_v.data[i].ny; - const int vec_len = img_res_v.data[i].buf.size(); - printf(" i:%d | nx:%d | ny:%d | vec len:%d\n", i, nx, ny, vec_len); // 384^2 * 3(channel) = 442368 - auto start = std::chrono::high_resolution_clock::now(); - image_embd_v[i] = - (float*)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 - const bool encoded = clip_image_encode( - ctx_clip, 1, &img_res_v.data[i], - image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside - if (!encoded) - { - LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int)i + 1, (int)img_res_v.size); - return false; - } - auto end = std::chrono::high_resolution_clock::now(); - std::chrono::duration duration = end - start; - std::cout << " Wall time: " << duration.count() << " seconds" << std::endl; - } + // for (size_t i = 0; i < img_res_v.size; i++) + // { + // printf("encode patch %d\n", i); + // const int nx = img_res_v.data[i].nx; + // const int ny = img_res_v.data[i].ny; + // const int vec_len = img_res_v.data[i].buf.size(); + // printf(" i:%d | nx:%d | ny:%d | vec len:%d\n", i, nx, ny, vec_len); // 384^2 * 3(channel) = 442368 + // auto start = std::chrono::high_resolution_clock::now(); + // image_embd_v[i] = + // (float*)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 + // const bool encoded = clip_image_encode( + // ctx_clip, 1, &img_res_v.data[i], + // image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside + // if (!encoded) + // { + // LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int)i + 1, (int)img_res_v.size); + // return false; + // } + // auto end = std::chrono::high_resolution_clock::now(); + // std::chrono::duration duration = end - start; + // std::cout << " Wall time: " << duration.count() << " seconds" << std::endl; + // for (int j = 0; j < 5; j++) + // { + // printf(" %.4f ", image_embd_v[i][j]); + // } + // printf("\n"); + // } // handle patches goes here - + return 0; } diff --git a/examples/xgenmm/test_patch_ops.cpp b/examples/xgenmm/test_patch_ops.cpp new file mode 100644 index 000000000..eb20236c7 --- /dev/null +++ b/examples/xgenmm/test_patch_ops.cpp @@ -0,0 +1,401 @@ +#include +#include +#include +#include +#include +#include +#include 
+#include +#include + +#include "ggml.h" + + + +void print_tensor(ggml_tensor* tensor, const char* name = "", int verbosity = 0) +{ + if (tensor->ne[2] == 1) + { + printf("---> %s: (%ld, %ld)\n", name, tensor->ne[0], tensor->ne[1]); + } + else if (ggml_is_3d(tensor)) + { + printf("---> %s: (%ld, %ld, %ld)\n", name, tensor->ne[0], tensor->ne[1], tensor->ne[2]); + } + else + { + printf("---> %s: (%ld, %ld, %ld, %ld)\n", name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } + if (verbosity == 1) + { + printf("*********************************************************************\n"); + if (tensor->ne[2] == 1) + { + const float* mat = (float*)tensor->data; + int dim0 = tensor->ne[1]; + int dim1 = tensor->ne[0]; + if (dim0 < 6 && dim1 < 6) + { + for (int i = 0; i < dim0; i++) + { + for (int j = 0; j < dim1; j++) + { + printf("%+.4f ", mat[i * dim1 + j]); + } + printf("\n"); + } + printf("\n"); + } + else + { + for (int i = 0; i < std::min(dim0, 3); i++) + { + for (int j = 0; j < std::min(dim1, 3); j++) + { + printf("%+.4f ", mat[i * dim1 + j]); + } + printf("... "); + for (int j = dim1 - 3; j < dim1; j++) + { + printf("%+.4f ", mat[i * dim1 + j]); + } + printf("\n"); + } + if (dim0 > 3) + { + printf("...................... omit ......................\n"); + for (int i = dim0 - 3; i < dim0; i++) + { + for (int j = 0; j < std::min(dim1, 3); j++) + { + printf("%+.4f ", mat[i * dim1 + j]); + } + printf("... "); + for (int j = dim1 - 3; j < dim1; j++) + { + printf("%+.4f ", mat[i * dim1 + j]); + } + printf("\n"); + } + } + } + } + else if (ggml_is_3d(tensor)) + { + const float* data = (float*)tensor->data; + int dim0 = tensor->ne[2]; + int dim1 = tensor->ne[1]; + int dim2 = tensor->ne[0]; + if (dim0 < 6 && dim1 < 6 && dim2 < 6) + { + for (int i = 0; i < dim0; i++) + { + printf("dim0 = %d\n", i); + for (int j = 0; j < dim1; j++) + { + for (int k = 0; k < dim2; k++) + { + printf("%+.4f ", data[i * dim1 * dim2 + j * dim2 + k]); + } + printf("\n"); + } + printf("\n"); + } + printf("\n"); + } + else + { + for (int i = 0; i < std::min(dim0, 4); i++) + { + printf("dim0 = %d\n", i); + for (int j = 0; j < std::min(dim1, 3); j++) + { + for (int k = 0; k < std::min(dim2, 3); k++) + { + printf("%+.4f ", data[i * dim1 * dim2 + j * dim2 + k]); + } + printf("... "); + for (int k = dim2 - 3; k < dim2; k++) + { + printf("%+.4f ", data[i * dim1 * dim2 + j * dim2 + k]); + } + printf("\n"); + } + printf("........................ omit .....................\n"); + for (int j = dim1 - 3; j < dim1; j++) + { + for (int k = 0; k < std::min(dim2, 3); k++) + { + printf("%+.4f ", data[i * dim1 * dim2 + j * dim2 + k]); + } + printf("... "); + for (int k = dim2 - 3; k < dim2; k++) + { + printf("%+.4f ", data[i * dim1 * dim2 + j * dim2 + k]); + } + printf("\n"); + } + printf("---------------------------------------------------\n"); + } + printf("\n"); + } + } + } + printf("*********************************************************************\n"); + printf("\n"); +} + +void tensor_to_csv(ggml_tensor* tensor, const char* filename) +{ + std::ofstream outFile(filename); + if (!outFile.is_open()) + { + std::cerr << "Error opening file!" 
<< std::endl; + } + + const float* mat = (float*)tensor->data; + int dim0 = tensor->ne[1]; + int dim1 = tensor->ne[0]; + + { + for (int i = 0; i < dim0; i++) + { + for (int j = 0; j < dim1; j++) + { + outFile << float(mat[i * dim1 + j]); + if (j < dim1 - 1) + { + outFile << ","; + } + } + outFile << std::endl; + } + } + outFile.close(); + printf("file saved to %s\n", filename); +} + +struct tensor_from_gguf +{ + struct ggml_tensor* data; + struct ggml_context* ctx; +}; + +bool load_tensor_from_file(const char* filename, tensor_from_gguf& tensor) +{ + struct gguf_init_params params = { + /*.no_alloc =*/false, + /*.ctx =*/&tensor.ctx, + }; + gguf_context* ctx = gguf_init_from_file(filename, params); + if (!ctx) + { + fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__); + return false; + } + tensor.data = ggml_get_tensor(tensor.ctx, "data"); + + return true; +} + +int main(){ + tensor_from_gguf tensor; + std::string filename = "../examples/xgenmm/imgs/4patches_embeddings.gguf"; + bool is_successful = load_tensor_from_file(filename.c_str(), tensor); + if (!is_successful) + { + fprintf(stderr, "%s: load_tensor_from_file() failed\n", __func__); + return 1; + } + + ggml_tensor* patch_embeds = tensor.data; + // print_tensor(patch_embeds, "patch_embeds", 1); + + /* + hardcoded values + */ + int original_width = 955; + int original_height = 289; + int num_images = 4; // 3 patches + 1 base + int32_t num_patches_per_side = 384 / 14; + int num_patches_width = 3; //grid_shape.first + int num_patches_height = 1; // grid_shape.second + + + + size_t size_ele = ggml_type_size(GGML_TYPE_F32); + + struct + { + struct ggml_context* ctx; + } model; + + + // TODO: size calculation is not calculated - it's only tens of MB + size_t ctx_size = 0; + + { + ctx_size += + num_patches_per_side * num_patches_per_side * 1152 * sizeof(float) * num_images * 8; // image_features + ctx_size += 1024 * 1024 * ggml_type_size(GGML_TYPE_F32); + } + + struct ggml_init_params params + { + /*.mem_size =*/ctx_size, + /*.mem_buffer =*/NULL, + /*.no_alloc =*/false, // NOTE: this should be false when using the legacy API + }; + + model.ctx = ggml_init(params); + + + + // FIXME: hardcoded for the patch size and vit embedding size + struct ggml_tensor* image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 1152, 729, num_images - 1); + struct ggml_tensor* base_image_feature = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 1152, 729, 1); + // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false); + // fill it with the image embeddings, ignoring the base + // for (size_t i = 1; i < num_images; i++) + // { + // size_t offset = (i - 1) * 729 * 1152 * sizeof(float); + // // size_t offset = (i - 1) * clip_embd_nbytes(ctx_clip); + // // memcpy((uint8_t*)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip)); + // } + + int dim0 = num_images - 1; + int dim1 = num_patches_per_side * num_patches_per_side; + int dim2 = 1152; + float* patch_embeds_data = (float*)patch_embeds->data; + float* image_features_data = (float*)image_features->data; + float* base_image_feature_data = (float*)base_image_feature->data; + for (int i=0; i < dim0; i++) + { + for (int j=0; j < dim1; j++) + { + for (int k=0; k < dim2; k++) + { + image_features_data[i * dim1 * dim2 + j * dim2 + k] = + patch_embeds_data[(i + 1) * dim1 * dim2 + j * dim2 + k]; + if (i == 0) + { + base_image_feature_data[j * dim2 + k] = patch_embeds_data[j * dim2 + k]; + } + } + } + } + // print_tensor(image_features, "image_features", 1); + + + 
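    // Note on the merge below (description of the existing ops, based on the printed shapes):
    // image_features is (1152, 27*27, num_images-1) in ggml order: ne0 = ViT embedding dim,
    // ne1 = 729 tokens per sub-patch, ne2 = sub-patches. The view_4d reinterprets it as
    // (27*1152, 27, num_patches_width, num_patches_height), i.e. one token-row per ne0 slot;
    // the permute + cont interleaves token-rows from the horizontal sub-patches, and the
    // 2d view flattens everything back to (1152, 27*27*num_patches), so the tokens are laid
    // out as one (27*num_patches_height) x (27*num_patches_width) spatial grid - the same
    // grid the attention mask further down is built over. The base image feature is then
    // concatenated back in front along dim 2.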
struct ggml_tensor* image_features_patchview = ggml_view_4d( + model.ctx, image_features, num_patches_per_side * 1152, num_patches_per_side, + num_patches_width, num_patches_height, size_ele * num_patches_per_side * 1152, + size_ele * num_patches_per_side * 1152 * num_patches_per_side, + size_ele * num_patches_per_side * 1152 * num_patches_per_side * num_patches_width, 0); + print_tensor(image_features_patchview, "image_features_patchview", 0); // (27 * 1152, 27, 3, 1) + struct ggml_tensor* permuted_cont = + ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3)); + + print_tensor(permuted_cont, "permuted_cont", 0); // (27 * 1152, 3, 27, 1) + struct ggml_tensor* flatten = + ggml_view_2d(model.ctx, permuted_cont, 1152, + num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, + size_ele * 1152, 0); + + print_tensor(flatten, "flatten", 0); // (1152, 27 * 27 * 3) + + // struct ggml_tensor* tensor_3d = + // ggml_view_3d(model.ctx, flatten, + // 1152, // ne0 + // num_patches_per_side * num_patches_per_side, // ne1 + // num_patches_width * num_patches_height, // ne2 = num_patches_width * num_patches_height, + // size_ele * num_patches_width * num_patches_height, // nb1 = sizeof(float) × ne2, + // size_ele * num_patches_width * num_patches_height * num_patches_per_side * + // num_patches_per_side, // nb2 = sizeof(float)×ne1×ne2 + // 0); + struct ggml_tensor* tensor_3d = + ggml_reshape_3d(model.ctx, flatten, + 1152, + num_patches_per_side * num_patches_per_side, + num_patches_width * num_patches_height); + tensor_3d = ggml_cont(model.ctx, tensor_3d); + tensor_3d = ggml_concat(model.ctx, base_image_feature, tensor_3d, 2); + struct ggml_cgraph* gf = ggml_new_graph(model.ctx); + ggml_build_forward_expand(gf, tensor_3d); + ggml_graph_compute_with_ctx(model.ctx, gf, 1); + struct ggml_tensor* result = gf->nodes[gf->n_nodes - 1]; + + print_tensor(result, "result", 1); // (1152, 27 * 27, 3) + + struct + { + struct ggml_context* ctx; + } mask; + + // TODO: size calculation is not calculated - it's only tens of MB + ctx_size = 0; + + { + ctx_size += + num_patches_per_side * num_patches_width * num_patches_per_side * num_patches_height * sizeof(float) * 2; + } + + params = + { + /*.mem_size =*/ctx_size, + /*.mem_buffer =*/NULL, + /*.no_alloc =*/false, // NOTE: this should be false when using the legacy API + }; + mask.ctx = ggml_init(params); + int current_height = num_patches_per_side * num_patches_height; + int current_width = num_patches_per_side * num_patches_width; + float original_aspect_ratio = (float)original_width / (float)original_height; + float current_aspect_ratio = (float)current_width / (float)current_height; + printf("original_height: %d, original_width: %d, original_aspect_ratio: %.2f\n", original_height, original_width, + original_aspect_ratio); + printf("current_height: %d, current_width: %d, current_aspect_ratio: %.2f\n", current_height, current_width, + current_aspect_ratio); + + float scale_factor = 1.0; + struct ggml_tensor* attention_mask = ggml_new_tensor_2d(mask.ctx, GGML_TYPE_F32, current_width, current_height); + if (original_aspect_ratio > current_aspect_ratio){ + scale_factor = (float)current_width / (float)original_width; + int new_height = int(original_height * scale_factor); + int padding = (current_height - new_height) / 2; + // printf("new_height: %d, padding: %d\n", new_height, padding); + float* attention_mask_data = (float*)attention_mask->data; + for (int i = 0; i < current_height; i++){ + for (int j = 0; j < 
current_width; j++){
                if (i < padding || i >= padding + new_height){
                    attention_mask_data[i * current_width + j] = 0.0;
                } else {
                    attention_mask_data[i * current_width + j] = 1.0;
                }
            }
        }
    }else{
        scale_factor = (float)current_height / (float)original_height;
        int new_width = int(original_width * scale_factor);
        int padding = (current_width - new_width) / 2;
        float* attention_mask_data = (float*)attention_mask->data;
        for (int i = 0; i < current_height; i++){
            for (int j = 0; j < current_width; j++){
                if (j < padding || j >= padding + new_width){
                    attention_mask_data[i * current_width + j] = 0.0;
                } else {
                    attention_mask_data[i * current_width + j] = 1.0;
                }
            }
        }
    }

    print_tensor(attention_mask, "attention_mask", 1);
    tensor_to_csv(attention_mask, "/export/home/llama.cpp/examples/xgenmm/imgs/attention_mask_4patchhes.csv");
    ggml_free(model.ctx);
    ggml_free(mask.ctx);
    ggml_free(tensor.ctx);
    return 0;
}


// make test_patch_ops && ./bin/test_patch_ops
\ No newline at end of file
diff --git a/examples/xgenmm/xgenmm.cpp b/examples/xgenmm/xgenmm.cpp
index 04805c450..00878d974 100644
--- a/examples/xgenmm/xgenmm.cpp
+++ b/examples/xgenmm/xgenmm.cpp
@@ -343,31 +343,116 @@ static bool encode_image_with_clip(clip_ctx *ctx_clip, int n_threads, const clip
     }
     else if (clip_is_xgenmm(ctx_clip))
     {
-        // xgenmm embedding
-        *n_img_pos = clip_n_patches(ctx_clip);
-        bool encoded =
-            clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd);  // image_embd shape is 729 x
-        delete[] img_res_v.data;
-        if (!encoded)
+        // spatial_unpad llava-1.6 type embedding
+        // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a
+        // solution to quickly get batching working
+        std::vector<float *> image_embd_v;
+        image_embd_v.resize(img_res_v.size);
+        for (size_t i = 0; i < img_res_v.size; i++)
         {
-            LOG_TEE("Unable to encode image\n");
-
-            return false;
+            image_embd_v[i] =
+                (float *)malloc(clip_embd_nbytes(ctx_clip));  // 576 patches * 4096 embeddings * 4 bytes = 9437184
+            const bool encoded = clip_image_encode(
+                ctx_clip, n_threads, &img_res_v.data[i],
+                image_embd_v[i]);  // image data is in 3x336x336 format and will be converted to 336x336x3 inside
+            if (!encoded)
+            {
+                LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int)i + 1,
+                        (int)img_res_v.size);
+                return false;
+            }
+            for (int j = 0; j < 5; j++)
+            {
+                printf(" %.4f ", image_embd_v[i][j]);
+            }
+            printf("\n");
         }
+        const int64_t t_img_enc_batch_us = ggml_time_us();
+        LOG_TEE("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size,
+                (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
+
+        const int32_t *image_grid = clip_image_grid(ctx_clip);
+
+        std::vector<std::pair<int, int>> grid_pinpoints;
+        for (int i = 0; i < 32 && image_grid[i] != 0; i += 2)
+        {
+            grid_pinpoints.push_back({image_grid[i], image_grid[i + 1]});
+        }
+
+        // free all img_res_v - not needed anymore
+        delete[] img_res_v.data;
+        img_res_v.size = 0;
+        img_res_v.data = nullptr;
+
+        const int32_t image_size = clip_image_size(ctx_clip);
+
+        struct clip_image_grid_shape grid_shape =
+            get_anyres_image_grid_shape({img->nx, img->ny}, grid_pinpoints, image_size);
+
+        int n_img_pos_out;
+        clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
+        *n_img_pos = n_img_pos_out;
+
+        for (size_t i = 0; i < image_embd_v.size(); i++)
+        {
+            free(image_embd_v[i]);
+        }
+        image_embd_v.clear();
+
+        // debug image/segment/normalization content:
+        // clip_image_u8 * tmp = clip_image_u8_init();
+        // clip_image_convert_f32_to_u8(*image_feature, *tmp);
+        // clip_image_save_to_bmp(*tmp, "image_feature.bmp");
     }
     else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0)
     {
-        // flat / default llava-1.5 type embedding
-        *n_img_pos = clip_n_patches(ctx_clip);
-        bool encoded =
-            clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd);  // image_embd shape is 576 x 4096
-        delete[] img_res_v.data;
-        if (!encoded)
+        std::vector<float *> image_embd_v;
+        image_embd_v.resize(img_res_v.size);
+        for (size_t i = 0; i < img_res_v.size; i++)
         {
-            LOG_TEE("Unable to encode image\n");
-
-            return false;
+            image_embd_v[i] =
+                (float *)malloc(clip_embd_nbytes(ctx_clip));  // 576 patches * 4096 embeddings * 4 bytes = 9437184
+            const bool encoded = clip_image_encode(
+                ctx_clip, n_threads, &img_res_v.data[i],
+                image_embd_v[i]);  // image data is in 3x336x336 format and will be converted to 336x336x3 inside
+            if (!encoded)
+            {
+                LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int)i + 1,
+                        (int)img_res_v.size);
+                return false;
+            }
         }
+        const int64_t t_img_enc_batch_us = ggml_time_us();
+        LOG_TEE("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size,
+                (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
+
+        const int32_t *image_grid = clip_image_grid(ctx_clip);
+
+        std::vector<std::pair<int, int>> grid_pinpoints;
+        for (int i = 0; i < 32 && image_grid[i] != 0; i += 2)
+        {
+            grid_pinpoints.push_back({image_grid[i], image_grid[i + 1]});
+        }
+
+        // free all img_res_v - not needed anymore
+        delete[] img_res_v.data;
+        img_res_v.size = 0;
+        img_res_v.data = nullptr;
+
+        const int32_t image_size = clip_image_size(ctx_clip);
+
+        struct clip_image_grid_shape grid_shape =
+            get_anyres_image_grid_shape({img->nx, img->ny}, grid_pinpoints, image_size);
+
+        int n_img_pos_out;
+        clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
+        *n_img_pos = n_img_pos_out;
+
+        for (size_t i = 0; i < image_embd_v.size(); i++)
+        {
+            free(image_embd_v[i]);
+        }
+        image_embd_v.clear();
     }
     else
     {
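A quick way to sanity-check the mask that test_patch_ops.cpp writes out is to compare it against the PyTorch reference dump in imgs/. This is only a sketch: it assumes imgs/attention_mask_4patchhes.pt stores a single float mask with the same height x width as the CSV (adjust the reshape if the saved layout differs).

# sketch: compare the C++ mask (CSV) with the PyTorch reference (.pt)
import numpy as np
import torch

cpp_mask = np.loadtxt('imgs/attention_mask_4patchhes.csv', delimiter=',')
ref_mask = torch.load('imgs/attention_mask_4patchhes.pt').float().numpy().reshape(cpp_mask.shape)

print('cpp:', cpp_mask.shape, 'ref:', ref_mask.shape)
print('max abs diff:', np.abs(cpp_mask - ref_mask).max())
assert np.allclose(cpp_mask, ref_mask), 'masks disagree'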