Add embeddings scale to clip_ctx to rescale final image embeddings

This commit is contained in:
Andrei Betlen 2024-10-01 06:12:31 -04:00
parent dd34db2636
commit 9aecd38a8d

View file

@ -104,6 +104,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_PROJ_TYPE "clip.projector_type"
// metadata key for the optional scale factor applied to the final image embeddings
#define KEY_EMBD_SCALE "clip.embeddings_scale"
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@ -548,6 +549,7 @@ struct clip_ctx {
float image_mean[3];
float image_std[3];
// multiplier applied to the final image embeddings when the graph is built;
// the default of 1.0f leaves the embeddings unchanged
float embeddings_scale = 1.0f;
bool use_gelu = false;
int32_t ftype = 1;
@ -1021,6 +1023,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
}
}
// rescale the final embeddings by the context's embeddings_scale;
// a scale of exactly 1.0f (the default) would be a no-op, so the ggml_scale call is skipped
if (ctx->embeddings_scale != 1.0f) {
embeddings = ggml_scale(ctx0, embeddings, ctx->embeddings_scale);
}
// build the graph
ggml_build_forward_expand(gf, embeddings);
@ -1322,6 +1328,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
new_clip->image_std[i] = std_data[i];
}
// the embeddings scale is optional: if get_f32 throws (presumably because the
// key is absent from the model file -- TODO confirm), fall back to 1.0f (no rescaling)
// NOTE(review): this also hides any other std::exception raised while reading the key
try {
new_clip->embeddings_scale = get_f32(ctx, KEY_EMBD_SCALE);
} catch (const std::exception& /*e*/) {
new_clip->embeddings_scale = 1.0f;
}
if (verbosity >= 2) {
LOG_INF("\n%s: vision model hparams\n", __func__);
LOG_INF("image_size %d\n", hparams.image_size);