diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 2cbddfa75..3d24d736b 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -1055,12 +1055,16 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
     return true;
 }

-int clip_n_pos(struct clip_ctx * ctx) {
+int clip_n_mmproj_embd(struct clip_ctx * ctx) {
+    return ctx->vision_model.mm_2_b->ne[0];
+}
+
+int clip_n_patches(struct clip_ctx * ctx) {
     auto & params = ctx->vision_model.hparams;

     return (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
 }

 size_t clip_embd_nbytes(struct clip_ctx * ctx) {
-    return clip_n_pos(ctx) * ctx->vision_model.mm_2_b->ne[0] * sizeof(float);
+    return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
 }
diff --git a/examples/llava/clip.h b/examples/llava/clip.h
index c651332d6..3d7261e29 100644
--- a/examples/llava/clip.h
+++ b/examples/llava/clip.h
@@ -25,7 +25,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
 void clip_free(struct clip_ctx * ctx);

 size_t clip_embd_nbytes(struct clip_ctx * ctx);
-int clip_n_pos(struct clip_ctx * ctx);
+int clip_n_patches(struct clip_ctx * ctx);
+int clip_n_mmproj_embd(struct clip_ctx * ctx);

 // RGB uint8 image
 struct clip_image_u8 {
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index d2716e046..7dfdc8d3b 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -49,7 +49,8 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    int n_img_pos = clip_n_pos(ctx_clip);
+    int n_img_pos  = clip_n_patches(ctx_clip);
+    int n_img_embd = clip_n_mmproj_embd(ctx_clip);
     float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));

     if (!image_embd) {
@@ -90,6 +91,19 @@
         return 1;
     }

+    // make sure that the correct mmproj was used, i.e., compare apples to apples
+    int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
+    if (n_img_embd != n_llama_embd) {
+        printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd);
+
+        llama_free(ctx_llama);
+        llama_free_model(model);
+        llama_backend_free();
+        free(image_embd);
+
+        return 1;
+    }
+
     // process the prompt
     // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
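
For reference, a minimal sketch (not part of the patch) of how a caller is expected to combine the two new getters; `ctx_clip` and `ctx_llama` are the contexts from the example above, and error handling is elided:

```c
// Sketch: size the image-embedding buffer from the two new getters.
// clip_embd_nbytes(ctx_clip) now computes exactly this product.
int n_img_pos  = clip_n_patches(ctx_clip);      // how many positions the image occupies
int n_img_embd = clip_n_mmproj_embd(ctx_clip);  // width of each projected embedding

float * image_embd = (float *) malloc(n_img_pos * n_img_embd * sizeof(float));

// the projector's output width must match the LLaMA embedding width;
// a mismatch means the mmproj file belongs to a different model
if (n_img_embd != llama_n_embd(llama_get_model(ctx_llama))) {
    // wrong mmproj: clean up and bail out, as the patch does above
}
```

Splitting the old `clip_n_pos` into `clip_n_patches` and `clip_n_mmproj_embd` makes the projector's output width queryable on its own, which is exactly what the new mmproj sanity check needs.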