Check if apples are compared to apples

This commit is contained in:
M. Yusuf Sarıgöz 2023-10-11 08:15:51 +03:00
parent f1564bb2eb
commit ab2158796f
3 changed files with 23 additions and 4 deletions

View file

@@ -1055,12 +1055,16 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
return true; return true;
} }
int clip_n_pos(struct clip_ctx * ctx) { int clip_n_mmproj_embd(struct clip_ctx * ctx) {
return ctx->vision_model.mm_2_b->ne[0];
}
int clip_n_patches(struct clip_ctx * ctx) {
auto & params = ctx->vision_model.hparams; auto & params = ctx->vision_model.hparams;
return (params.image_size / params.patch_size) * (params.image_size / params.patch_size); return (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
} }
size_t clip_embd_nbytes(struct clip_ctx * ctx) { size_t clip_embd_nbytes(struct clip_ctx * ctx) {
return clip_n_pos(ctx) * ctx->vision_model.mm_2_b->ne[0] * sizeof(float); return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
} }

View file

@@ -25,7 +25,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
void clip_free(struct clip_ctx * ctx); void clip_free(struct clip_ctx * ctx);
size_t clip_embd_nbytes(struct clip_ctx * ctx); size_t clip_embd_nbytes(struct clip_ctx * ctx);
int clip_n_pos(struct clip_ctx * ctx); int clip_n_patches(struct clip_ctx * ctx);
int clip_n_mmproj_embd(struct clip_ctx * ctx);
// RGB uint8 image // RGB uint8 image
struct clip_image_u8 { struct clip_image_u8 {

View file

@@ -49,7 +49,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
int n_img_pos = clip_n_pos(ctx_clip); int n_img_pos = clip_n_patches(ctx_clip);
int n_img_embd = clip_n_mmproj_embd(ctx_clip);
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)); float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
if (!image_embd) { if (!image_embd) {
@@ -90,6 +91,19 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
// make sure that the correct mmproj was used, i.e., compare apples to apples
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
if (n_img_embd != n_llama_embd) {
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd);
llama_free(ctx_llama);
llama_free_model(model);
llama_backend_free();
free(image_embd);
return 1;
}
// process the prompt // process the prompt
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:" // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"