llava : fix compile warnings
This commit is contained in:
parent
a2848854a4
commit
65ec518d41
3 changed files with 112 additions and 96 deletions
|
@ -30,6 +30,7 @@
|
|||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <cinttypes>
|
||||
#include <limits>
|
||||
|
||||
// #define CLIP_DEBUG_FUNCTIONS
|
||||
|
||||
|
@ -263,6 +264,24 @@ static projector_type clip_projector_type_from_string(const std::string & name)
|
|||
// clip layers
|
||||
//
|
||||
|
||||
struct clip_hparams {
|
||||
int32_t image_size;
|
||||
int32_t patch_size;
|
||||
int32_t hidden_size;
|
||||
int32_t n_intermediate;
|
||||
int32_t projection_dim;
|
||||
int32_t n_head;
|
||||
int32_t n_layer;
|
||||
|
||||
float eps;
|
||||
|
||||
char mm_patch_merge_type[32]="flat"; // spatial_unpad or flat (default)
|
||||
|
||||
int32_t image_grid_pinpoints[32];
|
||||
int32_t image_crop_resolution;
|
||||
|
||||
};
|
||||
|
||||
struct clip_layer {
|
||||
// attention
|
||||
struct ggml_tensor * k_w;
|
||||
|
@ -292,7 +311,7 @@ struct clip_layer {
|
|||
};
|
||||
|
||||
struct clip_vision_model {
|
||||
struct clip_vision_hparams hparams;
|
||||
struct clip_hparams hparams;
|
||||
|
||||
// embeddings
|
||||
struct ggml_tensor * class_embedding;
|
||||
|
@ -376,10 +395,6 @@ struct clip_ctx {
|
|||
ggml_allocr * compute_alloc = NULL;
|
||||
};
|
||||
|
||||
const struct clip_vision_hparams clip_get_vision_hparams(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams;
|
||||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
|
||||
if (!ctx->has_vision_encoder) {
|
||||
printf("This gguf file seems to have no vision encoder\n");
|
||||
|
@ -392,7 +407,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
const int image_size = hparams.image_size;
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
|
||||
const int num_patches_per_side = image_size / patch_size;
|
||||
const int num_patches_per_side = image_size / patch_size; GGML_UNUSED(num_patches_per_side);
|
||||
const int num_positions = num_patches + 1;
|
||||
const int hidden_size = hparams.hidden_size;
|
||||
const int n_head = hparams.n_head;
|
||||
|
@ -1292,7 +1307,7 @@ inline float lerp(float s, float e, float t) {
|
|||
return s + (e - s) * t;
|
||||
}
|
||||
// Bilinear resize function
|
||||
void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
|
||||
static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
|
||||
dst.nx = target_width;
|
||||
dst.ny = target_height;
|
||||
dst.buf.resize(3 * target_width * target_height);
|
||||
|
@ -1327,7 +1342,7 @@ void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_wi
|
|||
}
|
||||
|
||||
// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
|
||||
void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) {
|
||||
static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) {
|
||||
dst->nx = src->nx;
|
||||
dst->ny = src->ny;
|
||||
dst->buf.resize(src->buf.size());
|
||||
|
@ -1338,12 +1353,11 @@ void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, co
|
|||
}
|
||||
}
|
||||
|
||||
inline float clip(float x, float lower, float upper)
|
||||
{
|
||||
inline float clip(float x, float lower, float upper) {
|
||||
return std::max(lower, std::min(x, upper));
|
||||
}
|
||||
bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height)
|
||||
{
|
||||
|
||||
static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) {
|
||||
const int nx = img.nx;
|
||||
const int ny = img.ny;
|
||||
|
||||
|
@ -1351,11 +1365,10 @@ bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_wid
|
|||
dst.ny = target_height;
|
||||
dst.buf.resize(3 * target_width * target_height);
|
||||
|
||||
int a, b, c, d, index;
|
||||
float Ca, Cb, Cc;
|
||||
float Cc;
|
||||
float C[5];
|
||||
float d0, d2, d3, a0, a1, a2, a3;
|
||||
int i, j, k, ii, jj;
|
||||
int i, j, k, jj;
|
||||
int x, y;
|
||||
float dx, dy;
|
||||
float tx, ty;
|
||||
|
@ -1363,31 +1376,20 @@ bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_wid
|
|||
tx = (float)nx / (float)target_width;
|
||||
ty = (float)ny / (float)target_height;
|
||||
|
||||
float scale = std::max(tx, ty);
|
||||
|
||||
// Bicubic interpolation; adapted from ViT.cpp, inspired from :
|
||||
// -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
|
||||
// -> https://en.wikipedia.org/wiki/Bicubic_interpolation
|
||||
|
||||
for (i = 0; i < target_height; i++)
|
||||
{
|
||||
for (j = 0; j < target_width; j++)
|
||||
{
|
||||
for (i = 0; i < target_height; i++) {
|
||||
for (j = 0; j < target_width; j++) {
|
||||
x = (int)(tx * j);
|
||||
y = (int)(ty * i);
|
||||
|
||||
dx = tx * j - x;
|
||||
dy = ty * i - y;
|
||||
|
||||
index = (y * nx + x) * 3;
|
||||
a = (y * nx + (x + 1)) * 3;
|
||||
b = ((y + 1) * nx + x) * 3;
|
||||
c = ((y + 1) * nx + (x + 1)) * 3;
|
||||
|
||||
for (k = 0; k < 3; k++)
|
||||
{
|
||||
for (jj = 0; jj <= 3; jj++)
|
||||
{
|
||||
for (k = 0; k < 3; k++) {
|
||||
for (jj = 0; jj <= 3; jj++) {
|
||||
d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
|
||||
d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
|
||||
d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
|
||||
|
@ -1396,6 +1398,7 @@ bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_wid
|
|||
a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
|
||||
a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2;
|
||||
a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3;
|
||||
|
||||
C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
|
||||
|
||||
d0 = C[0] - C[1];
|
||||
|
@ -1418,7 +1421,7 @@ bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_wid
|
|||
}
|
||||
|
||||
// llava-1.6 type of resize_and_pad (black)
|
||||
void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair<int, int>& target_resolution) {
|
||||
static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair<int, int>& target_resolution) {
|
||||
int target_width = target_resolution.first;
|
||||
int target_height = target_resolution.second;
|
||||
|
||||
|
@ -1494,7 +1497,7 @@ static std::pair<int, int> select_best_resolution(const std::pair<int, int>& ori
|
|||
}
|
||||
|
||||
|
||||
std::vector<clip_image_u8*> divide_to_patches_u8(const clip_image_u8& image, int patch_size) {
|
||||
static std::vector<clip_image_u8*> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
|
||||
std::vector<clip_image_u8*> patches;
|
||||
int width = image.nx;
|
||||
int height = image.ny;
|
||||
|
@ -1710,6 +1713,30 @@ void clip_free(clip_ctx * ctx) {
|
|||
delete ctx;
|
||||
}
|
||||
|
||||
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
|
||||
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
||||
int32_t clip_image_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.image_size;
|
||||
}
|
||||
|
||||
int32_t clip_patch_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.patch_size;
|
||||
}
|
||||
|
||||
int32_t clip_hidden_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.hidden_size;
|
||||
}
|
||||
|
||||
const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.mm_patch_merge_type;
|
||||
}
|
||||
|
||||
const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.image_grid_pinpoints;
|
||||
}
|
||||
|
||||
bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
|
||||
if (!ctx->has_vision_encoder) {
|
||||
printf("This gguf file seems to have no vision encoder\n");
|
||||
|
@ -1973,7 +2000,3 @@ int clip_n_patches(const struct clip_ctx * ctx) {
|
|||
}
|
||||
return n_patches;
|
||||
}
|
||||
|
||||
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
|
||||
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
|
|
@ -24,25 +24,7 @@ struct clip_ctx;
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct clip_vision_hparams {
|
||||
int32_t image_size;
|
||||
int32_t patch_size;
|
||||
int32_t hidden_size;
|
||||
int32_t n_intermediate;
|
||||
int32_t projection_dim;
|
||||
int32_t n_head;
|
||||
int32_t n_layer;
|
||||
|
||||
float eps;
|
||||
|
||||
char mm_patch_merge_type[32]="flat"; // spatial_unpad or flat (default)
|
||||
int32_t image_grid_pinpoints[32];
|
||||
int32_t image_crop_resolution;
|
||||
|
||||
};
|
||||
|
||||
struct clip_ctx;
|
||||
CLIP_API const struct clip_vision_hparams clip_get_vision_hparams(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API struct clip_ctx * clip_model_load(const char * fname, int verbosity);
|
||||
CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity);
|
||||
|
@ -51,6 +33,15 @@ CLIP_API void clip_free(struct clip_ctx * ctx);
|
|||
|
||||
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
|
||||
|
||||
// TODO: should be enum, not string
|
||||
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
||||
|
||||
|
|
|
@ -2,14 +2,13 @@
|
|||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "llava.h"
|
||||
#include "base64.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
|
||||
#include "base64.hpp"
|
||||
|
||||
// RGB uint8 image
|
||||
struct clip_image_u8 {
|
||||
int nx;
|
||||
|
@ -37,6 +36,7 @@ struct clip_image_f32 {
|
|||
static std::pair<int, int> select_best_resolution(const std::pair<int, int>& original_size, const std::vector<std::pair<int, int>>& possible_resolutions) {
|
||||
int original_width = original_size.first;
|
||||
int original_height = original_size.second;
|
||||
|
||||
std::pair<int, int> best_fit;
|
||||
int max_effective_resolution = 0;
|
||||
int min_wasted_resolution = std::numeric_limits<int>::max();
|
||||
|
@ -59,6 +59,7 @@ static std::pair<int, int> select_best_resolution(const std::pair<int, int>& ori
|
|||
|
||||
return best_fit;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the anyres image grid shape object
|
||||
*
|
||||
|
@ -67,7 +68,7 @@ static std::pair<int, int> select_best_resolution(const std::pair<int, int>& ori
|
|||
* @param image_patch_size
|
||||
* @return <int, int>
|
||||
*/
|
||||
struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<int, int>& image_size, const std::vector<std::pair<int, int>>& grid_pinpoints, int image_patch_size) {
|
||||
static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<int, int> & image_size, const std::vector<std::pair<int, int>> & grid_pinpoints, int image_patch_size) {
|
||||
/**
|
||||
Conversion from gguf flat array to vector:
|
||||
std::vector<std::pair<int, int>> possible_resolutions;
|
||||
|
@ -79,22 +80,26 @@ struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair<int, in
|
|||
return {best_resolution.first / image_patch_size, best_resolution.second / image_patch_size};
|
||||
}
|
||||
|
||||
|
||||
// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out)
|
||||
static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *> & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) {
|
||||
struct temp_model {
|
||||
struct {
|
||||
struct ggml_tensor * newline;
|
||||
struct ggml_context * ctx;
|
||||
} model;
|
||||
|
||||
auto & vparams = clip_get_vision_hparams(ctx_clip);
|
||||
auto num_patches_per_side = vparams.image_size / vparams.patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches)
|
||||
const int32_t image_size = clip_image_size(ctx_clip);
|
||||
const int32_t patch_size = clip_patch_size(ctx_clip);
|
||||
|
||||
int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches)
|
||||
|
||||
int num_patches_width = grid_shape.first; // grid 1-4
|
||||
int num_patches_height = grid_shape.second; // grid 1-4
|
||||
|
||||
const size_t num_images = num_patches_width + num_patches_height + 1;
|
||||
|
||||
// TODO: size calculation is not calculated - it's only tens of MB
|
||||
size_t ctx_size = 0;
|
||||
|
||||
{
|
||||
ctx_size += clip_embd_nbytes(ctx_clip) * num_images * 8; // image_features
|
||||
ctx_size += 1024*1024 * ggml_type_size(GGML_TYPE_F32);
|
||||
|
@ -105,6 +110,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
|
|||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ false, // NOTE: this should be false when using the legacy API
|
||||
};
|
||||
|
||||
// Python reference code for full unpad:
|
||||
/*
|
||||
base_image_feature = image_feature[0]
|
||||
|
@ -138,7 +144,6 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
|
|||
*/
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
ggml_context *ctx_noalloc = ggml_init({2048, NULL, true});
|
||||
|
||||
ggml_tensor * newline_tmp = clip_get_newline_tensor(ctx_clip);
|
||||
model.newline = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, newline_tmp->ne[0]);
|
||||
|
@ -147,8 +152,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
|
|||
printf("newline_tmp tensor buffer is NULL\n");
|
||||
}
|
||||
ggml_backend_tensor_get(newline_tmp, model.newline->data, 0, ggml_nbytes(newline_tmp));
|
||||
} else
|
||||
{
|
||||
} else {
|
||||
model.newline->data = newline_tmp->data;
|
||||
if (model.newline->data == NULL) {
|
||||
printf("newline_tmp tensor data is NULL\n");
|
||||
|
@ -158,8 +162,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
|
|||
struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4
|
||||
// ggml_tensor_printf(image_features,"image_features",__LINE__,false,false);
|
||||
// fill it with the image embeddings, ignoring the base
|
||||
for (int i = 1; i < num_images; i++)
|
||||
{
|
||||
for (size_t i = 1; i < num_images; i++) {
|
||||
size_t offset = (i-1) * clip_embd_nbytes(ctx_clip);
|
||||
memcpy((uint8_t *)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip));
|
||||
}
|
||||
|
@ -222,10 +225,10 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
|||
}
|
||||
|
||||
const int64_t t_img_enc_start_us = ggml_time_us();
|
||||
auto & vparams = clip_get_vision_hparams(ctx_clip);
|
||||
|
||||
if (strcmp(vparams.mm_patch_merge_type, "spatial_unpad") != 0)
|
||||
{
|
||||
const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);
|
||||
|
||||
if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
|
||||
// flat / default llava-1.5 type embedding
|
||||
*n_img_pos = clip_n_patches(ctx_clip);
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096
|
||||
|
@ -235,41 +238,43 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
|||
|
||||
return false;
|
||||
}
|
||||
} else
|
||||
{
|
||||
} else {
|
||||
// spatial_unpad llava-1.6 type embedding
|
||||
// TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working
|
||||
std::vector<float *> image_embd_v;
|
||||
image_embd_v.resize(img_res_v.size);
|
||||
for (int i = 0; i < img_res_v.size; i++)
|
||||
{
|
||||
for (size_t i = 0; i < img_res_v.size; i++) {
|
||||
image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
|
||||
const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside
|
||||
if (!encoded) {
|
||||
fprintf(stderr, "Unable to encode image - spatial_unpad - subimage %d of %d\n", i+1, (int)img_res_v.size);
|
||||
fprintf(stderr, "Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
const int64_t t_img_enc_batch_us = ggml_time_us();
|
||||
printf("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0);
|
||||
|
||||
const int32_t * image_grid = clip_image_grid(ctx_clip);
|
||||
|
||||
std::vector<std::pair<int, int>> grid_pinpoints;
|
||||
for (int i = 0; i < 32 && vparams.image_grid_pinpoints[i] != 0; i+=2) {
|
||||
grid_pinpoints.push_back({vparams.image_grid_pinpoints[i], vparams.image_grid_pinpoints[i+1]});
|
||||
for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) {
|
||||
grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
|
||||
}
|
||||
|
||||
// free all img_res_v - not needed anymore
|
||||
delete[] img_res_v.data;
|
||||
img_res_v.size = 0;
|
||||
img_res_v.data = nullptr;
|
||||
struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, vparams.image_size);
|
||||
|
||||
const int32_t image_size = clip_image_size(ctx_clip);
|
||||
|
||||
struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size);
|
||||
|
||||
int n_img_pos_out;
|
||||
clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out);
|
||||
*n_img_pos = n_img_pos_out;
|
||||
|
||||
for (int i = 0; i < image_embd_v.size(); i++)
|
||||
{
|
||||
for (size_t i = 0; i < image_embd_v.size(); i++) {
|
||||
free(image_embd_v[i]);
|
||||
}
|
||||
image_embd_v.clear();
|
||||
|
@ -278,10 +283,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
|||
// clip_image_u8 * tmp = clip_image_u8_init();
|
||||
// clip_image_convert_f32_to_u8(*image_feature, *tmp);
|
||||
// clip_image_save_to_bmp(*tmp, "image_feature.bmp");
|
||||
|
||||
}
|
||||
printf("%s: image embedding created: %d tokens\n", __func__, *n_img_pos);
|
||||
|
||||
printf("%s: image embedding created: %d tokens\n", __func__, *n_img_pos);
|
||||
|
||||
const int64_t t_img_enc_end_us = ggml_time_us();
|
||||
float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
|
||||
|
@ -291,8 +295,6 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
|||
return true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
|
||||
// make sure that the correct mmproj was used, i.e., compare apples to apples
|
||||
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue