diff --git a/include/llama.h b/include/llama.h index 298b8d1bc..6049d2382 100644 --- a/include/llama.h +++ b/include/llama.h @@ -229,6 +229,16 @@ extern "C" { bool sorted; } llama_token_data_array; + // represent an RGB image + // size of data must be equal to 3*nx*ny + typedef struct llama_vision_bitmap { + uint32_t nx; + uint32_t ny; + unsigned char * data; + } llama_vision_bitmap; + + struct llama_vision_patches; + typedef bool (*llama_progress_callback)(float progress, void * user_data); // Input data for llama_decode diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp new file mode 100644 index 000000000..87a33c181 --- /dev/null +++ b/src/llama-vision.cpp @@ -0,0 +1,841 @@ +#include "llama.h" +#include "llama-vision.h" +#include "llama-impl.h" + +#include // memcpy +#include +#include + +#ifndef NDEBUG +// for debugging +#include +#include +#include + +// export clip_image_u8 to bmp file for debugging +// https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c +struct clip_image_size; +static int bmp_export(const struct clip_image_u8 &img, const std::string &location); +#endif + +struct clip_image_size { + int width; + int height; +}; + +// RGB uint8 image +// Memory layout: RGBRGBRGB... +struct clip_image_u8 { + int nx; + int ny; + std::vector buf; + clip_image_u8() {} + clip_image_u8(const llama_vision_bitmap & bmp) { + nx = bmp.nx; + ny = bmp.ny; + buf.resize(nx*ny*3); + memcpy(buf.data(), bmp.data, buf.size()); + } +}; + +struct clip_image_u8_batch { + struct clip_image_u8 * data; + size_t size; +}; + +static int clip_n_patches(const clip_context & ctx) { + auto & hparams = ctx.model->hparams; + int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size); + return n_patches; +} + +int clip_n_mmproj_embd(const clip_context & ctx) { + if (ctx.model->hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) { + return ctx.model->mm_2_b->ne[0]; + } else { + GGML_ASSERT(false && "invalid proj type"); + } +} + +/** + * Selects the best resolution from a list of possible resolutions based on the original size. + * + * @param original_size The original size of the image in the format (width, height). + * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + * @return The best fit resolution in the format (width, height). + */ +static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector& possible_resolutions) { + int original_width = original_size.width; + int original_height = original_size.height; + + clip_image_size best_fit; + int max_effective_resolution = 0; + int min_wasted_resolution = std::numeric_limits::max(); + + for (const auto& resolution : possible_resolutions) { + int width = resolution.width; + int height = resolution.height; + float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); + int downscaled_width = static_cast(original_width * scale); + int downscaled_height = static_cast(original_height * scale); + int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); + int wasted_resolution = (width * height) - effective_resolution; + // LOG_DBG("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); + if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { + max_effective_resolution = effective_resolution; + min_wasted_resolution = wasted_resolution; + best_fit = resolution; + } + } + + return best_fit; +} + +static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { + auto clip = [](int x, int lower, int upper) -> int { + return std::max(lower, std::min(x, upper)); + }; + + const int nx = img.nx; + const int ny = img.ny; + + dst.nx = target_width; + dst.ny = target_height; + dst.buf.resize(3 * target_width * target_height); + + float Cc; + float C[5]; + float d0, d2, d3, a0, a1, a2, a3; + int i, j, k, jj; + int x, y; + float dx, dy; + float tx, ty; + + tx = (float)nx / (float)target_width; + ty = (float)ny / (float)target_height; + + // Bicubic interpolation; adapted from ViT.cpp, inspired from : + // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 + // -> https://en.wikipedia.org/wiki/Bicubic_interpolation + + for (i = 0; i < target_height; i++) { + for (j = 0; j < target_width; j++) { + x = (int)(tx * j); + y = (int)(ty * i); + + dx = tx * j - x; + dy = ty * i - y; + + for (k = 0; k < 3; k++) { + for (jj = 0; jj <= 3; jj++) { + d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + + C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; + + d0 = C[0] - C[1]; + d2 = C[2] - C[1]; + d3 = C[3] - C[1]; + a0 = C[1]; + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; + + const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); + dst.buf[(i * target_width + j) * 3 + k] = float(Cc2); + } + } + } + } + + return true; +} + +static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { + std::vector patches; + int width = image.nx; + int height = image.ny; + for (int i = 0; i < height; i += patch_size) { + for (int j = 0; j < width; j += patch_size) { + clip_image_u8 patch; + patch.nx = std::min(patch_size, width - j); + patch.ny = std::min(patch_size, height - i); + patch.buf.resize(3 * patch.nx * patch.ny); + for (int y = 0; y < patch.ny; ++y) { + for (int x = 0; x < patch.nx; ++x) { + for (int c = 0; c < 3; ++c) { + patch.buf[3 * (y * patch.nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c]; + } + } + } + patches.push_back(patch); + } + } + return patches; +} + +// llava-1.6 type of resize_and_pad (black) +static clip_image_u8 resize_and_pad_image(const clip_image_u8 & image, const clip_image_size & target_resolution) { + int target_width = target_resolution.width; + int target_height = target_resolution.height; + + float scale_w = static_cast(target_width) / image.nx; + float scale_h = static_cast(target_height) / image.ny; + + int new_width, new_height; + + if (scale_w < scale_h) { + new_width = target_width; + new_height = std::min(static_cast(std::ceil(image.ny * scale_w)), target_height); + } else { + new_height = target_height; + new_width = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width); + } + + clip_image_u8 resized_image; + // bilinear_resize(image, resized_image, new_width, new_height); + bicubic_resize(image, resized_image, new_width, new_height); + + clip_image_u8 padded_image; + padded_image.nx = target_width; + padded_image.ny = target_height; + padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black + + // Calculate padding offsets + int pad_x = (target_width - new_width) / 2; + int pad_y = (target_height - new_height) / 2; + + // Copy the resized image into the center of the padded buffer + for (int y = 0; y < new_height; ++y) { + for (int x = 0; x < new_width; ++x) { + for (int c = 0; c < 3; ++c) { + padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c]; + } + } + } + return padded_image; +} + +static void normalize_image_u8_to_f32(const clip_image_u8 & src, std::vector & dst, const std::array & mean, const std::array & std) { + dst.resize(src.buf.size()); + + for (size_t i = 0; i < src.buf.size(); ++i) { + int c = i % 3; // rgb + dst[i] = (static_cast(src.buf[i]) / 255.0f - mean[c]) / std[c]; + } +} + +// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector +// res_imgs memory is being allocated here, previous allocations will be freed if found +static llama_vision_patches clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img) { + bool pad_to_square = true; + auto & params = ctx.model->hparams; + // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing + if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) { + pad_to_square = false; + } + + llama_vision_patches output_imgs; + output_imgs.px = clip_n_patches(ctx); + output_imgs.py = clip_n_patches(ctx); + output_imgs.n_px = params.image_size / output_imgs.px; + output_imgs.n_py = params.image_size / output_imgs.py; + + // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) + // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 + + clip_image_u8 temp; + if (pad_to_square && img.nx != img.ny) { + // if the image is not square, pad it to a square + int longer_side = std::max(img.nx, img.ny); + temp.nx = longer_side; + temp.ny = longer_side; + temp.buf.resize(3 * longer_side * longer_side); + const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255) + + // fill with background color + for (size_t i = 0; i < temp.buf.size(); i++) { + temp.buf[i] = bc[i % 3]; + } + + // copy from the input image + for (int y = 0; y < img.ny; y++) { + for (int x = 0; x < img.nx; x++) { + const int i = 3 * (y * img.nx + x); + const int j = 3 * (y * temp.nx + x); + temp.buf[j] = img.buf[i]; + temp.buf[j+1] = img.buf[i+1]; + temp.buf[j+2] = img.buf[i+2]; + } + } + } else if (params.image_grid_pinpoints[0] != 0) { + // "spatial_unpad" with "anyres" processing for llava-1.6 + std::vector possible_resolutions; + for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) { + clip_image_size s; + s.width = params.image_grid_pinpoints[i]; + s.height = params.image_grid_pinpoints[i+1]; + possible_resolutions.push_back(s); + } + clip_image_size best_resolution = select_best_resolution({img.nx, img.ny}, possible_resolutions); + // clip_image_save_to_bmp(*img, "input.bmp"); + temp = resize_and_pad_image(img, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 + // clip_image_save_to_bmp(*temp, "resized.bmp"); + + std::vector patches = divide_to_patches_u8(temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) + + clip_image_u8 image_original_resize; + // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + patches.insert(patches.begin(), image_original_resize); + // clip_image_f32_batch_init(patches.size()); + output_imgs.buf.resize(patches.size()); + int num = 0; + for (auto & patch : patches) { + normalize_image_u8_to_f32(patch, output_imgs.buf[num], params.image_mean, params.image_std); + num++; + } + return output_imgs; + } else { + temp.nx = img.nx; + temp.ny = img.ny; + temp.buf.resize(img.buf.size()); + memcpy(temp.buf.data(), img.buf.data(), temp.buf.size()); + } + + const int nx = temp.nx; + const int ny = temp.ny; + // bmp_export(temp, "resized_vanilla.bmp"); + + const int nx2 = params.image_size; + const int ny2 = params.image_size; + std::vector res; + res.resize(3 * nx2 * ny2); + + const float scale = std::max(nx, ny) / (float)params.image_size; + + const int nx3 = int(nx / scale + 0.5f); + const int ny3 = int(ny / scale + 0.5f); + + const auto & m3 = params.image_mean; // {0.48145466f, 0.4578275f, 0.40821073f}; + const auto & s3 = params.image_std; // {0.26862954f, 0.26130258f, 0.27577711f}; + + for (int y = 0; y < ny3; y++) { + for (int x = 0; x < nx3; x++) { + for (int c = 0; c < 3; c++) { + // linear interpolation + const float sx = (x + 0.5f) * scale - 0.5f; + const float sy = (y + 0.5f) * scale - 0.5f; + + const int x0 = std::max(0, (int)std::floor(sx)); + const int y0 = std::max(0, (int)std::floor(sy)); + + const int x1 = std::min(x0 + 1, nx - 1); + const int y1 = std::min(y0 + 1, ny - 1); + + const float dx = sx - x0; + const float dy = sy - y0; + + const int j00 = 3 * (y0 * nx + x0) + c; + const int j01 = 3 * (y0 * nx + x1) + c; + const int j10 = 3 * (y1 * nx + x0) + c; + const int j11 = 3 * (y1 * nx + x1) + c; + + const float v00 = temp.buf[j00]; + const float v01 = temp.buf[j01]; + const float v10 = temp.buf[j10]; + const float v11 = temp.buf[j11]; + + const float v0 = v00 * (1.0f - dx) + v01 * dx; + const float v1 = v10 * (1.0f - dx) + v11 * dx; + + const float v = v0 * (1.0f - dy) + v1 * dy; + + const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f); + + const int i = 3 * (y * nx3 + x) + c; + + res[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c]; + } + } + } + + output_imgs.buf.resize(1); + output_imgs.buf[0] = std::move(res); + + return output_imgs; +} + +static ggml_cgraph * clip_image_build_graph(clip_context & ctx, int batch_size, clip_image_size & image_size) { + auto & model = *ctx.model; + auto & hparams = ctx.model->hparams; + + const int hidden_size = hparams.hidden_size; + const int n_head = hparams.n_head; + const int d_head = hidden_size / n_head; + const int patch_size = hparams.patch_size; + const float eps = hparams.eps; + const int num_patches = ((image_size.width / patch_size) * (image_size.height / patch_size)); + const int num_positions = num_patches + (model.class_embedding ? 1 : 0); + + LLAMA_LOG_DEBUG("%s: num_patches = %d\n", __func__, num_patches); + + struct ggml_init_params params = { + /*.mem_size =*/ ctx.buf_compute_meta.size(), + /*.mem_buffer =*/ ctx.buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + // input + struct ggml_tensor * embeddings; + { + struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size.width, image_size.height, 3, batch_size); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + } + // auto * ne = inp->ne; printf("%d %d %d %d\n", ne[0], ne[1], ne[2], ne[3]); + + embeddings = inp; + if (model.class_embedding) { + embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); + ggml_set_name(embeddings, "embeddings"); + ggml_set_input(embeddings); + embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); + embeddings = ggml_acc(ctx0, embeddings, inp, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); + } + + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + embeddings = ggml_add(ctx0, + embeddings, + ggml_get_rows(ctx0, model.position_embeddings, positions)); + } + + // pre-layernorm + if (model.pre_norm_w) { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "pre_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_norm_w), model.pre_norm_b); + } + + // loop over layers + for (int il = 0; il < (int)hparams.n_layer + hparams.select_layer; il++) { + struct ggml_tensor * cur = embeddings; + + // layernorm1 + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].norm_in_w), + model.layers[il].norm_in_b); + } + + // self-attention + { + + struct ggml_tensor * Q = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].q_w, cur), + model.layers[il].q_b); + + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * K = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].k_w, cur), + model.layers[il].k_b); + + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * V = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].v_w, cur), + model.layers[il].v_b); + + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_inplace(ctx0, KQ); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); + } + + // attention output + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].output_w, cur), model.layers[il].output_b); + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, embeddings); + + embeddings = cur; // embeddings = residual, cur = hidden_states + + // layernorm2 + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].norm_out_w), + model.layers[il].norm_out_b); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ffn_up_b); + + if (hparams.use_gelu) { + cur = ggml_gelu_inplace(ctx0, cur); + } else { + cur = ggml_gelu_quick_inplace(ctx0, cur); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ffn_down_b); + + // residual 2 + cur = ggml_add(ctx0, embeddings, cur); + + embeddings = cur; + } + + // post-layernorm + if (model.post_norm_w) { + embeddings = ggml_norm(ctx0, embeddings, eps); + ggml_set_name(embeddings, "post_ln"); + + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_norm_w), model.post_norm_b); + } + + // llava projector + { + embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); + + struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches); + ggml_set_name(patches, "patches"); + ggml_set_input(patches); + + // shape [1, 576, 1024] + // ne is whcn, ne = [1024, 576, 1, 1] + embeddings = ggml_get_rows(ctx0, embeddings, patches); + + if (hparams.proj_type == CLIP_PROJECTOR_TYPE_MLP) { + embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); + + embeddings = ggml_gelu(ctx0, embeddings); + embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + } else { + GGML_ASSERT(false && "unsupported proj type"); + } + } + + // build the graph + ggml_build_forward_expand(gf, embeddings); + ggml_free(ctx0); + return gf; +} + +static int32_t clip_image_batch_encode(clip_context & ctx, const clip_image_f32_batch & imgs, std::vector & output) { + int batch_size = imgs.size(); + auto & model = *ctx.model; + auto & hparams = ctx.model->hparams; + + if (hparams.arch == VISION_ARCH_LLAVA) { + GGML_ASSERT(batch_size == 1); // TODO: support multiple images + } + + clip_image_size image_size{(int)hparams.image_size, (int)hparams.image_size}; + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size.width / patch_size) * (image_size.height / patch_size)); + const int num_positions = num_patches + (model.class_embedding ? 1 : 0); + + LLAMA_LOG_DEBUG("%s: image_size = %d\n", __func__, hparams.image_size); + LLAMA_LOG_DEBUG("%s: num_positions = %d\n", __func__, num_positions); + + // build the inference graph + ggml_cgraph * gf = clip_image_build_graph(ctx, batch_size, image_size); + + // alloc memory for graph + bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf); + if (!ok) { + LLAMA_LOG_ERROR("failed to alloc memory for graph\n"); + return -1; + } + + // set raw input + { + struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); + float * data = (float *)malloc(ggml_nbytes(inp_raw)); + + for (int i = 0; i < batch_size; i++) { + const int nx = imgs[i].nx; + const int ny = imgs[i].ny; + const int n = nx * ny; + + for (int b = 0; b < batch_size; b++) { + for (int k = 0; k < 3; k++) { + for (int y = 0; y < ny; y++) { + for (int x = 0; x < nx; x++) { + data[(b * 3 * n) + k * n + y * nx + x] = imgs[b].buf[3 * (y * nx + x) + k]; + } + } + } + } + } + ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); + free(data); + } + + if (model.class_embedding) { + struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); + + void* zero_mem = malloc(ggml_nbytes(embeddings)); + memset(zero_mem, 0, ggml_nbytes(embeddings)); + ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); + free(zero_mem); + } + + { + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + + int* positions_data = (int*)malloc(ggml_nbytes(positions)); + for (int i = 0; i < num_positions; i++) { + positions_data[i] = i; + } + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + free(positions_data); + } + + { + struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); + int* patches_data = (int*)malloc(ggml_nbytes(patches)); + for (int i = 0; i < num_patches; i++) { + patches_data[i] = i + 1; + } + ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); + free(patches_data); + } + + // compute + ggml_backend_sched_graph_compute_async(ctx.sched, gf); + + // the last node is the embedding tensor + struct ggml_tensor * embeddings = ggml_graph_node(gf, -1); + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(ctx.sched, embeddings); + + // copy the embeddings to the location passed by the user + size_t out_nbytes = clip_n_patches(ctx)*clip_n_mmproj_embd(ctx)*sizeof(float); + GGML_ASSERT(out_nbytes == ggml_nbytes(embeddings)); + output.resize(out_nbytes); + ggml_backend_tensor_get_async(backend_embd, embeddings, output.data(), 0, ggml_nbytes(embeddings)); + + ggml_backend_sched_synchronize(ctx.sched); + + return 0; +} + +static int32_t clip_image_encode(clip_context & ctx, const clip_image_f32 & img, std::vector & output) { + clip_image_f32_batch imgs{img}; + return clip_image_batch_encode(ctx, imgs, output); +} + +static int32_t encode_image_with_clip(clip_context & ctx, const llama_img img, std::vector & output_embd) { + clip_image_u8 img_u8(img); + clip_image_f32_batch img_res_v; + auto & hparams = ctx.model->hparams; + // bmp_export(img_u8, "test_inp.bmp"); + + if (!clip_image_preprocess(ctx, img_u8, img_res_v)) { + LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__); + return -2; + } + + switch (hparams.mm_patch_merge_type) { + case MM_PATCH_MERGE_FLAT: + { + // flat / default llava-1.5 type embedding + // n_output = clip_n_patches(ctx); + int32_t encoded = clip_image_encode(ctx, img_res_v[0], output_embd); + if (encoded != 0) { + LLAMA_LOG_ERROR("Unable to encode image\n"); + return encoded; + } + } break; + case MM_PATCH_MERGE_SPATIAL_UNPAD: + { + // TODO: support llava-1.6 + (void)0; + } break; + default: + GGML_ASSERT(false && "unsupported mm_patch_merge_type"); + } + + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////// +// public API + +int32_t llama_encode_vision_internal(clip_context & ctx, llama_batch_img * batch) { + if (batch->n_imgs == 0) { + return 0; + } + + // TODO: batching is not working atm, should be fixed later + const int n_embd = clip_n_mmproj_embd(ctx); + const int n_tokens_per_img = clip_n_patches(ctx); + const int n_pos = n_tokens_per_img*batch->n_imgs; + + ctx.out_embd.resize(n_embd*n_pos); + ctx.out_pos.resize(n_pos); + + for (int i = 0; i < batch->n_imgs; i++) { + std::vector output_single; + int32_t status = encode_image_with_clip(ctx, *batch->imgs[i], output_single); + if (status != 0) { + return status; + } + // copy output embeddings to result + for (int k = 0; k < n_embd*n_tokens_per_img; k++) { + ctx.out_embd[n_embd*n_tokens_per_img*i + k] = output_single[k]; + } + // fill position for all output tokens + for (int p = 0; p < n_tokens_per_img; p++) { + ctx.out_pos[n_tokens_per_img*i + p] = batch->pos[i] + p; + } + } + + return 0; +} + +void llama_vision_clear_output(clip_context & ctx) { + ctx.out_embd.clear(); + ctx.out_pos.clear(); +} + +//////////////////////////////////////////////////////////////////////////////////////// +// for debugging +#ifndef NDEBUG + +static int bmp_export(const struct clip_image_u8 &img, const std::string &location) { + const uint32_t width = img.nx; + const uint32_t height = img.ny; + // swap red and blue channel + std::vector buffer(width*height*3); + for (uint32_t y = 0; y < height; y++) { + for (uint32_t x = 0; x < width; x++) { + size_t base = x*3 + y*3*width; + buffer[base+2] = img.buf[base]; + buffer[base+1] = img.buf[base+1]; + buffer[base] = img.buf[base+2]; + } + } + const bool hasAlphaChannel = false; + + std::ofstream fout(location, std::ios::out | std::ios::binary); + + if (fout.fail()) { + return 0; + } + + //Padding + const uint8_t padding = hasAlphaChannel ? 0 : (4 - (width * 3) % 4) % 4; + + //Bitmap file header. + const char signature[2] = { 'B', 'M' }; + const uint32_t fileSize = buffer.size() * sizeof(uint8_t) + padding * (height - 1) + 14 + 124; + const uint32_t offset = 14 + 124; + + //Bitmap information header file + const uint32_t DIBSize = 124; + const int32_t bitmapWidth = width; + const int32_t bitmapHeight = height; + const uint16_t numPlanes = 1; + const uint16_t bitsPerPixel = (hasAlphaChannel) ? 32 : 24; + const uint32_t compressionMethod = (hasAlphaChannel) ? 3 : 0; //BI_RGB = 0, BI_BITFIELDS = 3 + const uint32_t bitmapSize = buffer.size() * sizeof(uint8_t); + const int32_t horizontalResolution = 2834; + const int32_t verticalResolution = 2834; + const uint32_t numColors = 0; + const uint32_t impColorCount = 0; + const uint32_t redBitmask = (hasAlphaChannel) ? 0x0000FF00 : 0; //ARGB32 pixel format + const uint32_t greenBitmask = (hasAlphaChannel) ? 0x00FF0000 : 0; + const uint32_t blueBitmask = (hasAlphaChannel) ? 0xFF000000 : 0; + const uint32_t alphaBitmask = (hasAlphaChannel) ? 0x000000FF : 0; + + //Writing the file header and information header to the file + std::vector header(offset, 0); + header[0] = signature[0]; + header[1] = signature[1]; + +#define BMP_HEADERS(i, variableName) header[i] = variableName; header[i+1] = variableName >> 8; header[i+2] = variableName >> 16; header[i+3] = variableName >> 24; + + BMP_HEADERS(2, fileSize); + BMP_HEADERS(6, 0); + BMP_HEADERS(10, offset); + BMP_HEADERS(14, DIBSize); + BMP_HEADERS(18, bitmapWidth); + BMP_HEADERS(22, bitmapHeight); + + header[26] = (uint8_t)numPlanes; + header[27] = (uint8_t)(numPlanes >> 8); + header[28] = (uint8_t)bitsPerPixel; + header[29] = (uint8_t)(bitsPerPixel >> 8); + + BMP_HEADERS(30, compressionMethod); + BMP_HEADERS(34, (unsigned char)bitmapSize); + BMP_HEADERS(38, (unsigned char)horizontalResolution); + BMP_HEADERS(42, (unsigned char)verticalResolution); + BMP_HEADERS(46, (unsigned char)numColors); + BMP_HEADERS(50, (unsigned char)impColorCount); + BMP_HEADERS(54, (unsigned char)redBitmask); + BMP_HEADERS(58, (unsigned char)greenBitmask); + BMP_HEADERS(62, (unsigned char)blueBitmask); + BMP_HEADERS(66, alphaBitmask); + +#undef BMP_HEADERS + + fout.write((char *)header.data(), sizeof(uint8_t) * header.size()); + + //Writing the pixel array + const uint32_t bWidth = bitsPerPixel / 8 * width; + + for (int i = height - 1; i >= 0; i--) { + std::vector row(buffer.begin() + i * bWidth, buffer.begin() + i * bWidth + bWidth); + fout.write((char *)row.data(), row.size() * sizeof(uint8_t)); + fout.seekp(padding * sizeof(uint8_t), std::ios::cur); + } + + fout.close(); + return 1; +} + +#endif + diff --git a/src/llama-vision.h b/src/llama-vision.h new file mode 100644 index 000000000..d7c922d99 --- /dev/null +++ b/src/llama-vision.h @@ -0,0 +1,151 @@ +#pragma once + +#include "ggml.h" +#include "llama.h" + +#include +#include + +enum vision_arch { + VISION_ARCH_UNKNOWN, + VISION_ARCH_LLAVA, +}; + +enum clip_projector_type { + CLIP_PROJECTOR_TYPE_UNKNOWN, + CLIP_PROJECTOR_TYPE_MLP, +}; + +enum mm_patch_merge { + MM_PATCH_MERGE_UNKNOWN, + MM_PATCH_MERGE_FLAT, + MM_PATCH_MERGE_SPATIAL_UNPAD, +}; + +struct clip_hparams { + vision_arch arch = VISION_ARCH_UNKNOWN; + + uint32_t image_size; + uint32_t patch_size; + uint32_t hidden_size; + uint32_t n_intermediate; + uint32_t projection_dim; + uint32_t n_head; + uint32_t n_layer; + uint32_t max_pos_embd; + int32_t select_layer = 0; + bool use_gelu = false; + + float eps; + + clip_projector_type proj_type = CLIP_PROJECTOR_TYPE_UNKNOWN; + mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_FLAT; + + std::array image_mean; + std::array image_std; + + std::array image_grid_pinpoints; + int32_t image_crop_resolution; +}; + +struct clip_layer { + // attention + struct ggml_tensor * k_w = NULL; + struct ggml_tensor * k_b = NULL; + struct ggml_tensor * q_w = NULL; + struct ggml_tensor * q_b = NULL; + struct ggml_tensor * v_w = NULL; + struct ggml_tensor * v_b = NULL; + + struct ggml_tensor * output_w = NULL; + struct ggml_tensor * output_b = NULL; + + // layernorm 1 + struct ggml_tensor * norm_in_w = NULL; + struct ggml_tensor * norm_in_b = NULL; + + // ff + struct ggml_tensor * ffn_up_w = NULL; + struct ggml_tensor * ffn_up_b = NULL; + + struct ggml_tensor * ffn_down_w = NULL; + struct ggml_tensor * ffn_down_b = NULL; + + // layernorm 2 + struct ggml_tensor * norm_out_w = NULL; + struct ggml_tensor * norm_out_b = NULL; +}; + +struct clip_vision_model { + struct clip_hparams hparams; + + // embeddings + struct ggml_tensor * class_embedding = NULL; + struct ggml_tensor * patch_embeddings = NULL; + struct ggml_tensor * patch_bias = NULL; + struct ggml_tensor * position_embeddings = NULL; + + struct ggml_tensor * pre_norm_w = NULL; + struct ggml_tensor * pre_norm_b = NULL; + + std::vector layers; + + struct ggml_tensor * post_norm_w = NULL; + struct ggml_tensor * post_norm_b = NULL; + + struct ggml_tensor * projection = NULL; + + // LLaVA projection + struct ggml_tensor * mm_1_w = NULL; + struct ggml_tensor * mm_1_b = NULL; + struct ggml_tensor * mm_2_w = NULL; + struct ggml_tensor * mm_2_b = NULL; + + struct ggml_tensor * image_newline = NULL; +}; + +struct clip_context { + // memory buffers used to evaluate the model + std::vector buf_compute_meta; + ggml_backend_sched_t sched = nullptr; + + const clip_vision_model * model; + + // temporary output data, to be picked up by llama_decode() + std::vector out_embd; // size == n_tokens * n_embd + std::vector out_pos; // position of each token +}; + +struct llama_vision_patches { + uint32_t px; // size of patch + uint32_t py; // size of patch + size_t n_px; // number of patches in x direction + size_t n_py; // number of patches in y direction + // RGB float32 image (NHWC) + // Memory layout: RGBRGBRGB... + std::vector> buf; // preprocessed image data +}; + +mm_patch_merge mm_patch_merge_from_name(std::string & name) { + if (name == "flat") { + return MM_PATCH_MERGE_FLAT; + } else if (name == "spatial_unpad") { + return MM_PATCH_MERGE_SPATIAL_UNPAD; + } + return MM_PATCH_MERGE_UNKNOWN; +} + +clip_projector_type clip_projector_type_from_name(std::string & name) { + if (name == "mlp") { + return CLIP_PROJECTOR_TYPE_MLP; + } + return CLIP_PROJECTOR_TYPE_UNKNOWN; +} + +llama_vision_patches * llama_vision_patches_init(llama_vision_bitmap * bmp); +void llama_vision_patches_free(llama_vision_patches * p); + +int32_t llama_vision_encode_impl(clip_context & ctx, llama_vision_patches * p); + +// dimension of the output embeddings, must be equal to n_embd of language model +int clip_n_mmproj_embd(const clip_context & ctx);