diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index b3de2d73b..9539b0c75 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -104,6 +104,7 @@ static std::string format(const char * fmt, ...) { #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" #define KEY_PROJ_TYPE "clip.projector_type" +#define KEY_EMBD_SCALE "clip.embeddings_scale" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" @@ -548,6 +549,7 @@ struct clip_ctx { float image_mean[3]; float image_std[3]; + float embeddings_scale = 1.0f; bool use_gelu = false; int32_t ftype = 1; @@ -1021,6 +1023,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } } + if (ctx->embeddings_scale != 1.0f) { + embeddings = ggml_scale(ctx0, embeddings, ctx->embeddings_scale); + } + // build the graph ggml_build_forward_expand(gf, embeddings); @@ -1322,6 +1328,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->image_std[i] = std_data[i]; } + try { + new_clip->embeddings_scale = get_f32(ctx, KEY_EMBD_SCALE); + } catch (const std::exception& /*e*/) { + new_clip->embeddings_scale = 1.0f; + } + if (verbosity >= 2) { LOG_INF("\n%s: vision model hparams\n", __func__); LOG_INF("image_size %d\n", hparams.image_size);