From 6854ad4057e682fbcc747c75a1d2670a7110ef51 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Mon, 30 Sep 2024 17:35:04 +0200
Subject: [PATCH] img pre processing

---
 convert_hf_to_gguf.py       |   5 +
 gguf-py/gguf/constants.py   |   5 +
 gguf-py/gguf/gguf_writer.py |  10 +
 src/llama-vision.cpp        | 491 ++++++++++++++++++++++++++++++++++++
 src/llama-vision.h          |  18 +-
 src/llama.cpp               |  61 +++--
 6 files changed, 564 insertions(+), 26 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index e0880511a..0340c138a 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1584,8 +1584,13 @@ class LlamaModel(Model):
         self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
         self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
         # TODO: should not hardcode these, but they are currently missing from config.json
+        self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
         self.gguf_writer.add_vision_clip_max_position_embeddings(577)
         self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
+        default_image_mean = [0.48145466, 0.4578275, 0.40821073]
+        default_image_std  = [0.26862954, 0.26130258, 0.27577711]
+        self.gguf_writer.add_vision_clip_image_mean(default_image_mean)
+        self.gguf_writer.add_vision_clip_image_std(default_image_std)
 
     @staticmethod
     def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index f4ebd8f90..b83dc311a 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -196,6 +196,7 @@ class Keys:
             PROJECTION_DIM    = "vision.clip.projection_dim"
             USE_GELU          = "vision.clip.use_gelu"
             MAX_POS_EMBEDDING = "vision.clip.max_position_embeddings"
+            PROJECTOR_TYPE    = "vision.clip.projector_type"
             HEAD_COUNT        = "vision.clip.attention.head_count"
             LAYERNORM_EPS     = "vision.clip.attention.layer_norm_epsilon"
 
@@ -1425,6 +1426,10 @@ class PoolingType(IntEnum):
     CLS  = 2
 
 
+class CLIPProjectorType(Enum):
+    MLP = 'mlp'
+
+
 class GGMLQuantizationType(IntEnum):
     F32  = 0
     F16  = 1
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 2828f0a80..e44ef9a1d 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -26,6 +26,7 @@ from .constants import (
     RopeScalingType,
     PoolingType,
     TokenType,
+    CLIPProjectorType,
 )
 
 from .quants import quant_shape_from_byte_shape
@@ -844,9 +845,18 @@ class GGUFWriter:
     def add_vision_clip_max_position_embeddings(self, value: int) -> None:
         self.add_uint32(Keys.Vision.Clip.MAX_POS_EMBEDDING, value)
 
+    def add_vision_clip_projector_type(self, value: CLIPProjectorType) -> None:
+        self.add_string(Keys.Vision.Clip.PROJECTOR_TYPE, value.value)
+
     def add_vision_clip_layer_norm_epsilon(self, value: float) -> None:
         self.add_float32(Keys.Vision.Clip.LAYERNORM_EPS, value)
 
+    def add_vision_clip_image_mean(self, value: Sequence[float]) -> None:
+        self.add_array(Keys.Vision.IMAGE_MEAN, value)
+
+    def add_vision_clip_image_std(self, value: Sequence[float]) -> None:
+        self.add_array(Keys.Vision.IMAGE_STD, value)
+
     def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
         if not isinstance(value, str):
             template_default = None
diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp
index ac507babb..75fdfc703 100644
--- a/src/llama-vision.cpp
+++ b/src/llama-vision.cpp
@@ -1,5 +1,496 @@
 #include "llama.h"
 #include "llama-vision.h"
+#include "llama-impl.h"
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <utility>
+#include <vector>
+
+struct clip_image_size {
+    int width;
+    int height;
+};
+
+// RGB uint8 image
+// Memory layout: RGBRGBRGB...
+struct clip_image_u8 {
+    int nx;
+    int ny;
+    std::vector<uint8_t> buf;
+    clip_image_u8() {}
+    clip_image_u8(const llama_img & img) {
+        nx = img.nx;
+        ny = img.ny;
+        buf.resize(nx*ny*3);
+        memcpy(buf.data(), img.data, buf.size());
+    }
+};
+
+struct clip_image_u8_batch {
+    struct clip_image_u8 * data;
+    size_t size;
+};
+
+// RGB float32 image (NHWC)
+// Memory layout: RGBRGBRGB...
+struct clip_image_f32 {
+    int nx;
+    int ny;
+    std::vector<float> buf;
+};
+
+using clip_image_f32_batch = std::vector<clip_image_f32>;
+using clip_image_f8_batch  = std::vector<clip_image_u8>;
+
+int32_t clip_image_encode      (const clip_context & ctx, const clip_image_f32       & img,  std::vector<float> & output);
+int32_t clip_image_batch_encode(const clip_context & ctx, const clip_image_f32_batch & imgs, std::vector<float> & output);
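+
+// Intended flow (illustrative sketch, not exercised yet by this patch):
+// expand one u8 input image into a batch of normalized f32 crops, then
+// encode the batch into embeddings. Assuming a loaded clip_context `ctx`
+// and a raw RGB888 input `img`:
+//
+//     clip_image_u8        img_u8(img);
+//     clip_image_f32_batch crops;
+//     if (clip_image_preprocess(ctx, img_u8, crops)) {
+//         std::vector<float> embd;
+//         clip_image_batch_encode(ctx, crops, embd);
+//     }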
+
+/**
+ * Selects the best resolution from a list of possible resolutions based on the original size.
+ *
+ * @param original_size The original size of the image in the format (width, height).
+ * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+ * @return The best fit resolution in the format (width, height).
+ */
+static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector<clip_image_size> & possible_resolutions) {
+    int original_width  = original_size.width;
+    int original_height = original_size.height;
+
+    clip_image_size best_fit;
+    int max_effective_resolution = 0;
+    int min_wasted_resolution    = std::numeric_limits<int>::max();
+
+    for (const auto & resolution : possible_resolutions) {
+        int width  = resolution.width;
+        int height = resolution.height;
+        float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
+        int downscaled_width  = static_cast<int>(original_width  * scale);
+        int downscaled_height = static_cast<int>(original_height * scale);
+        int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
+        int wasted_resolution = (width * height) - effective_resolution;
+        // LOG_DBG("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
+        if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
+            max_effective_resolution = effective_resolution;
+            min_wasted_resolution    = wasted_resolution;
+            best_fit = resolution;
+        }
+    }
+
+    return best_fit;
+}
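+
+// Worked example (illustrative, using the llava-1.6 grid pinpoints
+// {336x672, 672x336, 672x672, 1008x336, 336x1008}): for an 800x600 input,
+// 672x672 gives scale = min(672/800, 672/600) = 0.84, i.e. a 672x504
+// downscale with effective resolution 338688 px, the highest of the five
+// candidates, so it is selected; 672x336 keeps more width (448x336 =
+// 150528 px effective) but still loses to the square grid here.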
+
+static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
+    auto clip = [](int x, int lower, int upper) -> int {
+        return std::max(lower, std::min(x, upper));
+    };
+
+    const int nx = img.nx;
+    const int ny = img.ny;
+
+    dst.nx = target_width;
+    dst.ny = target_height;
+    dst.buf.resize(3 * target_width * target_height);
+
+    float Cc;
+    float C[5];
+    float d0, d2, d3, a0, a1, a2, a3;
+    int i, j, k, jj;
+    int x, y;
+    float dx, dy;
+    float tx, ty;
+
+    tx = (float)nx / (float)target_width;
+    ty = (float)ny / (float)target_height;
+
+    // Bicubic interpolation; adapted from ViT.cpp, inspired from:
+    // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
+    // -> https://en.wikipedia.org/wiki/Bicubic_interpolation
+
+    for (i = 0; i < target_height; i++) {
+        for (j = 0; j < target_width; j++) {
+            x = (int)(tx * j);
+            y = (int)(ty * i);
+
+            dx = tx * j - x;
+            dy = ty * i - y;
+
+            for (k = 0; k < 3; k++) {
+                for (jj = 0; jj <= 3; jj++) {
+                    d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+
+                    a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                    a2 =  1.0 / 2 * d0 + 1.0 / 2 * d2;
+                    a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
+
+                    C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
+
+                    d0 = C[0] - C[1];
+                    d2 = C[2] - C[1];
+                    d3 = C[3] - C[1];
+                    a0 = C[1];
+                    a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                    a2 =  1.0 / 2 * d0 + 1.0 / 2 * d2;
+                    a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
+                    Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
+
+                    const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
+                    dst.buf[(i * target_width + j) * 3 + k] = Cc2;
+                }
+            }
+        }
+    }
+
+    return true;
+}
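+
+// Sanity check on the coefficients above (illustrative): with samples f(-1),
+// f(0), f(1), f(2) and d0 = f(-1)-f(0), d2 = f(1)-f(0), d3 = f(2)-f(0), the
+// cubic a0 + a1*t + a2*t^2 + a3*t^3 equals f(0) at t = 0, and at t = 1 the d0
+// terms (-1/3 + 1/2 - 1/6) and d3 terms (-1/6 + 1/6) cancel while the d2
+// terms (1 + 1/2 - 1/2) sum to 1, giving f(0) + d2 = f(1). The outer samples
+// only shape the slope, which is the standard cubic-convolution construction.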
+
+static std::vector<clip_image_u8> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
+    std::vector<clip_image_u8> patches;
+    int width  = image.nx;
+    int height = image.ny;
+    for (int i = 0; i < height; i += patch_size) {
+        for (int j = 0; j < width; j += patch_size) {
+            clip_image_u8 patch;
+            patch.nx = std::min(patch_size, width  - j);
+            patch.ny = std::min(patch_size, height - i);
+            patch.buf.resize(3 * patch.nx * patch.ny);
+            for (int y = 0; y < patch.ny; ++y) {
+                for (int x = 0; x < patch.nx; ++x) {
+                    for (int c = 0; c < 3; ++c) {
+                        patch.buf[3 * (y * patch.nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c];
+                    }
+                }
+            }
+            patches.push_back(std::move(patch));
+        }
+    }
+    return patches;
+}
+
+// llava-1.6 type of resize_and_pad (black)
+static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & image_output, const clip_image_size & target_resolution) {
+    int target_width  = target_resolution.width;
+    int target_height = target_resolution.height;
+
+    float scale_w = static_cast<float>(target_width)  / image.nx;
+    float scale_h = static_cast<float>(target_height) / image.ny;
+
+    int new_width, new_height;
+
+    if (scale_w < scale_h) {
+        new_width  = target_width;
+        new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height);
+    } else {
+        new_height = target_height;
+        new_width  = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
+    }
+
+    clip_image_u8 resized_image;
+    // bilinear_resize(image, resized_image, new_width, new_height);
+    bicubic_resize(image, resized_image, new_width, new_height);
+
+    clip_image_u8 padded_image;
+    padded_image.nx = target_width;
+    padded_image.ny = target_height;
+    padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black
+
+    // Calculate padding offsets
+    int pad_x = (target_width  - new_width)  / 2;
+    int pad_y = (target_height - new_height) / 2;
+
+    // Copy the resized image into the center of the padded buffer
+    for (int y = 0; y < new_height; ++y) {
+        for (int x = 0; x < new_width; ++x) {
+            for (int c = 0; c < 3; ++c) {
+                padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
+            }
+        }
+    }
+    image_output = std::move(padded_image);
+}
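+
+// Worked example (illustrative): a 1280x960 input with best_resolution
+// 672x672 is bicubic-resized to 672x504 (scale_w = 0.525 < scale_h = 0.7) and
+// letterboxed to 672x672; divide_to_patches_u8(temp, 336) then yields four
+// 336x336 patches in row-major order, and clip_image_preprocess below
+// prepends a plain 336x336 resize of the original image, for five f32 crops
+// in total.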
+
+static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const std::array<float, 3> & mean, const std::array<float, 3> & std) {
+    dst.nx = src.nx;
+    dst.ny = src.ny;
+    dst.buf.resize(src.buf.size());
+
+    for (size_t i = 0; i < src.buf.size(); ++i) {
+        int c = i % 3; // rgb
+        dst.buf[i] = (static_cast<float>(src.buf[i]) / 255.0f - mean[c]) / std[c];
+    }
+}
+
+// Returns the normalized f32 image for llava-1.5. For llava-1.6 (spatial_unpad
+// with anyres processing) it returns the batch of normalized image patches
+// instead. output_imgs is (re)allocated here; previous contents are replaced.
+bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img, clip_image_f32_batch & output_imgs) {
+    bool pad_to_square = true;
+    auto & params = ctx.model.hparams;
+    // The model config contains all we need to decide how to preprocess:
+    // here we automatically switch to the new llava-1.6 preprocessing
+    if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) {
+        pad_to_square = false;
+    }
+
+    // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
+    // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+
+    clip_image_u8 temp;
+    if (pad_to_square && img.nx != img.ny) {
+        int longer_side = std::max(img.nx, img.ny);
+        temp.nx = longer_side;
+        temp.ny = longer_side;
+        temp.buf.resize(3 * longer_side * longer_side);
+        const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255)
+
+        // fill with background color
+        for (size_t i = 0; i < temp.buf.size(); i++) {
+            temp.buf[i] = bc[i % 3];
+        }
+
+        // copy from the input image
+        for (int y = 0; y < img.ny; y++) {
+            for (int x = 0; x < img.nx; x++) {
+                const int i = 3 * (y * img.nx + x);
+                const int j = 3 * (y * temp.nx + x);
+                temp.buf[j]   = img.buf[i];
+                temp.buf[j+1] = img.buf[i+1];
+                temp.buf[j+2] = img.buf[i+2];
+            }
+        }
+    } else {
+        if (params.image_grid_pinpoints[0] != 0) {
+            // "spatial_unpad" with "anyres" processing for llava-1.6
+            std::vector<clip_image_size> possible_resolutions;
+            for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) {
+                clip_image_size s;
+                s.width  = params.image_grid_pinpoints[i];
+                s.height = params.image_grid_pinpoints[i+1];
+                possible_resolutions.push_back(s);
+            }
+            clip_image_size best_resolution = select_best_resolution({img.nx, img.ny}, possible_resolutions);
+            // clip_image_save_to_bmp(img, "input.bmp");
+            resize_and_pad_image(img, temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6
+            // clip_image_save_to_bmp(temp, "resized.bmp");
+
+            std::vector<clip_image_u8> patches = divide_to_patches_u8(temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
+
+            clip_image_u8 image_original_resize;
+            // bilinear_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
+            bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
+            patches.insert(patches.begin(), image_original_resize);
+            // clip_image_f32_batch_init(patches.size());
+            output_imgs.resize(patches.size());
+            int num = 0;
+            for (auto & patch : patches) {
+                normalize_image_u8_to_f32(patch, output_imgs[num], params.image_mean, params.image_std);
+                num++;
+            }
+            return true;
+        } else {
+            temp.nx = img.nx;
+            temp.ny = img.ny;
+            temp.buf.resize(img.buf.size());
+            memcpy(temp.buf.data(), img.buf.data(), temp.buf.size());
+        }
+    }
+
+    const int nx = temp.nx;
+    const int ny = temp.ny;
+    // clip_image_save_to_bmp(temp, "resized_vanilla.bmp");
+
+    const int nx2 = params.image_size;
+    const int ny2 = params.image_size;
+    clip_image_f32 res;
+    res.nx = nx2;
+    res.ny = ny2;
+    res.buf.resize(3 * nx2 * ny2);
+
+    const float scale = std::max(nx, ny) / (float)params.image_size;
+
+    const int nx3 = int(nx / scale + 0.5f);
+    const int ny3 = int(ny / scale + 0.5f);
+
+    const auto & m3 = params.image_mean; // {0.48145466f, 0.4578275f, 0.40821073f};
+    const auto & s3 = params.image_std;  // {0.26862954f, 0.26130258f, 0.27577711f};
+
+    for (int y = 0; y < ny3; y++) {
+        for (int x = 0; x < nx3; x++) {
+            for (int c = 0; c < 3; c++) {
+                // linear interpolation
+                const float sx = (x + 0.5f) * scale - 0.5f;
+                const float sy = (y + 0.5f) * scale - 0.5f;
+
+                const int x0 = std::max(0, (int)std::floor(sx));
+                const int y0 = std::max(0, (int)std::floor(sy));
+
+                const int x1 = std::min(x0 + 1, nx - 1);
+                const int y1 = std::min(y0 + 1, ny - 1);
+
+                const float dx = sx - x0;
+                const float dy = sy - y0;
+
+                const int j00 = 3 * (y0 * nx + x0) + c;
+                const int j01 = 3 * (y0 * nx + x1) + c;
+                const int j10 = 3 * (y1 * nx + x0) + c;
+                const int j11 = 3 * (y1 * nx + x1) + c;
+
+                const float v00 = temp.buf[j00];
+                const float v01 = temp.buf[j01];
+                const float v10 = temp.buf[j10];
+                const float v11 = temp.buf[j11];
+
+                const float v0 = v00 * (1.0f - dx) + v01 * dx;
+                const float v1 = v10 * (1.0f - dx) + v11 * dx;
+
+                const float v = v0 * (1.0f - dy) + v1 * dy;
+
+                const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);
+
+                const int i = 3 * (y * nx3 + x) + c;
+
+                res.buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c];
+            }
+        }
+    }
+
+    output_imgs.resize(1);
+    output_imgs[0] = std::move(res);
+
+    return true;
+}
+
+int clip_n_patches(const clip_context & ctx) {
+    auto & hparams = ctx.model.hparams;
+    int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size);
+    return n_patches;
+}
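+
+// Shape check (illustrative): with the llava defaults image_size = 336 and
+// patch_size = 14, clip_n_patches() returns (336/14)^2 = 24*24 = 576, which
+// lines up with the 577 max position embeddings written by the converter
+// (576 patches + 1 CLS token). For the normalization above, a white pixel
+// maps to (1.0 - 0.48145466) / 0.26862954 ~= 1.93 on the R channel with the
+// default CLIP mean/std.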
+
+static bool encode_image_with_clip(clip_context & ctx_clip, const llama_img & img) {
+    clip_image_u8        img_u8(img);
+    clip_image_f32_batch img_res_v;
+    std::vector<float>   image_embd; // output vectors
+    auto & hparams = ctx_clip.model.hparams;
+    int n_output;
+
+    if (!clip_image_preprocess(ctx_clip, img_u8, img_res_v)) {
+        LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__);
+        return false;
+    }
+
+    if (hparams.mm_patch_merge_type != MM_PATCH_MERGE_SPATIAL_UNPAD) {
+        // flat / default llava-1.5 type embedding
+        n_output = clip_n_patches(ctx_clip);
+        bool encoded = clip_image_encode(ctx_clip, img_res_v[0], image_embd);
+        if (!encoded) {
+            LLAMA_LOG_ERROR("Unable to encode image\n");
+            return false;
+        }
+        (void) n_output; // TODO: copy the n_output embeddings into ctx_clip.output
+    }
+    // TODO: handle MM_PATCH_MERGE_SPATIAL_UNPAD (llava-1.6) batches
+
+    return true;
+}
+
+int32_t clip_image_encode(const clip_context & ctx, const clip_image_f32 & img, std::vector<float> & output) {
+    clip_image_f32_batch imgs{img};
+    return clip_image_batch_encode(ctx, imgs, output);
+}
+
+int32_t clip_image_batch_encode(const clip_context & ctx, const clip_image_f32_batch & imgs, std::vector<float> & output) {
+    int batch_size = imgs.size();
+    // TODO: build and evaluate the encoder graph; preprocessing-only for now
+    (void) ctx;
+    (void) batch_size;
+    (void) output;
+    return 0;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////////////////////////
+// for debugging
+#ifndef NDEBUG
+
+#include <cstdint>
+#include <fstream>
+#include <string>
+#include <vector>
+
+// export clip_image_u8 to bmp file for debugging
+// https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c
+inline int bmp_export(const clip_image_u8 & img, const std::string & location) {
+    const uint32_t width  = img.nx;
+    const uint32_t height = img.ny;
+    const std::vector<uint8_t> & buffer = img.buf;
+    const bool hasAlphaChannel = false;
+
+    std::ofstream fout(location, std::ios::out | std::ios::binary);
+
+    if (fout.fail()) {
+        return 0;
+    }
+
+    // Padding
+    const uint8_t padding = hasAlphaChannel ? 0 : (4 - (width * 3) % 4) % 4;
+
+    // Bitmap file header
+    const char signature[2] = { 'B', 'M' };
+    const uint32_t fileSize = buffer.size() * sizeof(uint8_t) + padding * (height - 1) + 14 + 124;
+    const uint32_t offset   = 14 + 124;
+
+    // Bitmap information header (BITMAPV5HEADER)
+    const uint32_t DIBSize              = 124;
+    const int32_t  bitmapWidth          = width;
+    const int32_t  bitmapHeight         = height;
+    const uint16_t numPlanes            = 1;
+    const uint16_t bitsPerPixel         = (hasAlphaChannel) ? 32 : 24;
+    const uint32_t compressionMethod    = (hasAlphaChannel) ? 3 : 0; // BI_RGB = 0, BI_BITFIELDS = 3
+    const uint32_t bitmapSize           = buffer.size() * sizeof(uint8_t);
+    const int32_t  horizontalResolution = 2834;
+    const int32_t  verticalResolution   = 2834;
+    const uint32_t numColors            = 0;
+    const uint32_t impColorCount        = 0;
+    const uint32_t redBitmask           = (hasAlphaChannel) ? 0x0000FF00 : 0; // ARGB32 pixel format
+    const uint32_t greenBitmask         = (hasAlphaChannel) ? 0x00FF0000 : 0;
+    const uint32_t blueBitmask          = (hasAlphaChannel) ? 0xFF000000 : 0;
+    const uint32_t alphaBitmask         = (hasAlphaChannel) ? 0x000000FF : 0;
+
+    // Writing the file header and information header to the file
+    std::vector<uint8_t> header(offset, 0);
+    header[0] = signature[0];
+    header[1] = signature[1];
+
+#define BMP_HEADERS(i, variableName) header[i] = variableName; header[i+1] = variableName >> 8; header[i+2] = variableName >> 16; header[i+3] = variableName >> 24;
+
+    BMP_HEADERS(2,  fileSize);
+    BMP_HEADERS(6,  0);
+    BMP_HEADERS(10, offset);
+    BMP_HEADERS(14, DIBSize);
+    BMP_HEADERS(18, bitmapWidth);
+    BMP_HEADERS(22, bitmapHeight);
+
+    header[26] = (uint8_t)numPlanes;
+    header[27] = (uint8_t)(numPlanes >> 8);
+    header[28] = (uint8_t)bitsPerPixel;
+    header[29] = (uint8_t)(bitsPerPixel >> 8);
+
+    // note: pass the full-width values here; casting them to unsigned char
+    // first would truncate multi-byte fields before the macro shifts them
+    BMP_HEADERS(30, compressionMethod);
+    BMP_HEADERS(34, bitmapSize);
+    BMP_HEADERS(38, horizontalResolution);
+    BMP_HEADERS(42, verticalResolution);
+    BMP_HEADERS(46, numColors);
+    BMP_HEADERS(50, impColorCount);
+    BMP_HEADERS(54, redBitmask);
+    BMP_HEADERS(58, greenBitmask);
+    BMP_HEADERS(62, blueBitmask);
+    BMP_HEADERS(66, alphaBitmask);
+
+#undef BMP_HEADERS
+
+    fout.write((char *)header.data(), sizeof(uint8_t) * header.size());
+
+    // Writing the pixel array
+    const uint32_t bWidth = bitsPerPixel / 8 * width;
+
+    for (int i = height - 1; i >= 0; i--) {
+        std::vector<uint8_t> row(buffer.begin() + i * bWidth, buffer.begin() + i * bWidth + bWidth);
+        fout.write((char *)row.data(), row.size() * sizeof(uint8_t));
+        fout.seekp(padding * sizeof(uint8_t), std::ios::cur);
+    }
+
+    fout.close();
+    return 1;
+}
+
+#endif
+
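+// Usage note (illustrative): when debugging the preprocessor it can be handy
+// to dump intermediate crops with the helper above, mirroring the
+// commented-out clip_image_save_to_bmp calls in clip_image_preprocess:
+//
+//     bmp_export(temp, "resized.bmp"); // inspect the padded/resized image
+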
diff --git a/src/llama-vision.h b/src/llama-vision.h
index 5bf1673e5..e7404ea18 100644
--- a/src/llama-vision.h
+++ b/src/llama-vision.h
@@ -9,6 +9,10 @@ enum vision_arch {
     VISION_ARCH_UNKNOWN,
 };
 
+enum clip_projector_type {
+    CLIP_PROJECTOR_TYPE_MLP,
+};
+
 enum mm_patch_merge {
     MM_PATCH_MERGE_FLAT,
     MM_PATCH_MERGE_SPATIAL_UNPAD,
@@ -28,9 +32,13 @@ struct clip_hparams {
 
     float eps;
 
+    clip_projector_type proj_type = CLIP_PROJECTOR_TYPE_MLP;
     mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_FLAT;
 
-    int32_t image_grid_pinpoints[32];
+    std::array<float, 3> image_mean;
+    std::array<float, 3> image_std;
+
+    std::array<int32_t, 32> image_grid_pinpoints;
     int32_t image_crop_resolution;
 };
 
@@ -89,3 +97,11 @@ struct clip_vision_model {
 
     struct ggml_tensor * image_newline = NULL;
 };
+
+struct clip_context {
+    struct ggml_context * ctx_ggml;
+    clip_vision_model model;
+
+    int32_t n_output;
+    float * output;
+};
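+
+// Layout note (illustrative): image_grid_pinpoints holds (width, height)
+// pairs terminated by a zero entry, e.g. for llava-1.6:
+//
+//     std::array<int32_t, 32> pinpoints = {
+//         336, 672,  672, 336,  672, 672,  1008, 336,  336, 1008,  0, // end
+//     };
+//
+// clip_image_preprocess reads at most 16 pairs and stops at the first zero
+// width.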
diff --git a/src/llama.cpp b/src/llama.cpp
index 0eac03a51..2860c7094 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -400,8 +400,9 @@ enum llm_kv {
     LLM_KV_VISION_CLIP_PROJECTION_TYPE,
     LLM_KV_VISION_CLIP_PROJECTION_DIM,
     LLM_KV_VISION_CLIP_USE_GELU,
-    LLM_KV_VISION_CLIP_HEAD_COUNT,
     LLM_KV_VISION_CLIP_MAX_POS_EMBD,
+    LLM_KV_VISION_CLIP_PROJECTOR_TYPE,
+    LLM_KV_VISION_CLIP_HEAD_COUNT,
     LLM_KV_VISION_CLIP_LAYERNORM_EPS,
 };
 
@@ -526,6 +527,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_VISION_CLIP_PROJECTION_DIM,      "vision.clip.projection_dim" },
    { LLM_KV_VISION_CLIP_USE_GELU,            "vision.clip.use_gelu" },
    { LLM_KV_VISION_CLIP_MAX_POS_EMBD,        "vision.clip.max_position_embeddings" },
+   { LLM_KV_VISION_CLIP_PROJECTOR_TYPE,      "vision.clip.projector_type" },
    { LLM_KV_VISION_CLIP_HEAD_COUNT,          "vision.clip.attention.head_count" },
    { LLM_KV_VISION_CLIP_LAYERNORM_EPS,       "vision.clip.attention.layer_norm_epsilon" },
 };
 
@@ -5573,30 +5575,6 @@ static void llm_load_hparams(
         hparams.n_embd_head_v = 0;
     }
 
-    std::string vision_type;
-    ml.get_key(LLM_KV_VISION_TYPE, vision_type, false);
-    if (vision_type == "clip") {
-        hparams.has_vision = true;
-        ml.get_key(LLM_KV_VISION_IMAGE_SIZE, hparams.clip.image_size, true);
-        ml.get_key(LLM_KV_VISION_PATCH_SIZE, hparams.clip.patch_size, true);
-        ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, hparams.clip.hidden_size, true);
-        ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT, hparams.clip.n_layer, true);
-        ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, hparams.clip.n_intermediate, true);
-        ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, hparams.clip.n_head, true);
-        ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, hparams.clip.eps, true);
-        // TODO: add image_std
-        std::string arch;
-        ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true);
-        for (auto & it : VISION_ARCH_NAMES) {
-            if (arch == it.second) {
-                hparams.clip.arch = it.first;
-                break;
-            }
-        }
-    } else if (!vision_type.empty()) {
-        throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str()));
-    }
-
     // arch-specific KVs
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
@@ -6244,6 +6222,39 @@ static void llm_load_hparams(
         default: (void)0;
     }
 
+    // vision model
+    std::string vision_type;
+    ml.get_key(LLM_KV_VISION_TYPE, vision_type, false);
+    if (vision_type == "clip") {
+        hparams.has_vision = true;
+        std::string proj_type;
+        ml.get_key(LLM_KV_VISION_IMAGE_SIZE, hparams.clip.image_size, true);
+        ml.get_key(LLM_KV_VISION_PATCH_SIZE, hparams.clip.patch_size, true);
+        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, hparams.clip.image_mean, 3, true);
+        ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD,  hparams.clip.image_std,  3, true);
+        ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, hparams.clip.hidden_size, true);
+        ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT, hparams.clip.n_layer, true);
+        ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, hparams.clip.n_intermediate, true);
+        ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, hparams.clip.n_head, true);
+        ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, hparams.clip.eps, true);
+        ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true);
+        if (proj_type == "mlp") {
+            hparams.clip.proj_type = CLIP_PROJECTOR_TYPE_MLP;
+        } else {
+            throw std::runtime_error(format("unsupported clip projector type: %s", proj_type.c_str()));
+        }
+        std::string arch;
+        ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true);
+        for (auto & it : VISION_ARCH_NAMES) {
+            if (arch == it.second) {
+                hparams.clip.arch = it.first;
+                break;
+            }
+        }
+    } else if (!vision_type.empty()) {
+        throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str()));
+    }
+
+    // arch-specific CLIP hparams
+    switch (hparams.clip.arch) {
+        case VISION_ARCH_LLAVA: