Check if apples are compared to apples
This commit is contained in:
parent
f1564bb2eb
commit
ab2158796f
3 changed files with 23 additions and 4 deletions
|
@@ -1055,12 +1055,16 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int clip_n_pos(struct clip_ctx * ctx) {
|
int clip_n_mmproj_embd(struct clip_ctx * ctx) {
|
||||||
|
return ctx->vision_model.mm_2_b->ne[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
int clip_n_patches(struct clip_ctx * ctx) {
|
||||||
auto & params = ctx->vision_model.hparams;
|
auto & params = ctx->vision_model.hparams;
|
||||||
|
|
||||||
return (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
|
return (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t clip_embd_nbytes(struct clip_ctx * ctx) {
|
size_t clip_embd_nbytes(struct clip_ctx * ctx) {
|
||||||
return clip_n_pos(ctx) * ctx->vision_model.mm_2_b->ne[0] * sizeof(float);
|
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||||
}
|
}
|
||||||
|
|
|
@@ -25,7 +25,8 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
|
||||||
void clip_free(struct clip_ctx * ctx);
|
void clip_free(struct clip_ctx * ctx);
|
||||||
|
|
||||||
size_t clip_embd_nbytes(struct clip_ctx * ctx);
|
size_t clip_embd_nbytes(struct clip_ctx * ctx);
|
||||||
int clip_n_pos(struct clip_ctx * ctx);
|
int clip_n_patches(struct clip_ctx * ctx);
|
||||||
|
int clip_n_mmproj_embd(struct clip_ctx * ctx);
|
||||||
|
|
||||||
// RGB uint8 image
|
// RGB uint8 image
|
||||||
struct clip_image_u8 {
|
struct clip_image_u8 {
|
||||||
|
|
|
@@ -49,7 +49,8 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_img_pos = clip_n_pos(ctx_clip);
|
int n_img_pos = clip_n_patches(ctx_clip);
|
||||||
|
int n_img_embd = clip_n_mmproj_embd(ctx_clip);
|
||||||
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
||||||
|
|
||||||
if (!image_embd) {
|
if (!image_embd) {
|
||||||
|
@@ -90,6 +91,19 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// make sure that the correct mmproj was used, i.e., compare apples to apples
|
||||||
|
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||||
|
if (n_img_embd != n_llama_embd) {
|
||||||
|
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd);
|
||||||
|
|
||||||
|
llama_free(ctx_llama);
|
||||||
|
llama_free_model(model);
|
||||||
|
llama_backend_free();
|
||||||
|
free(image_embd);
|
||||||
|
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
// process the prompt
|
// process the prompt
|
||||||
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
|
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue