Add embeddings scale to clip_ctx to rescale final image embeddings
This commit is contained in:
parent
dd34db2636
commit
9aecd38a8d
1 changed files with 12 additions and 0 deletions
|
@ -104,6 +104,7 @@ static std::string format(const char * fmt, ...) {
|
||||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||||
|
#define KEY_EMBD_SCALE "clip.embeddings_scale"
|
||||||
|
|
||||||
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
||||||
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
|
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
|
||||||
|
@ -548,6 +549,7 @@ struct clip_ctx {
|
||||||
|
|
||||||
float image_mean[3];
|
float image_mean[3];
|
||||||
float image_std[3];
|
float image_std[3];
|
||||||
|
float embeddings_scale = 1.0f;
|
||||||
bool use_gelu = false;
|
bool use_gelu = false;
|
||||||
int32_t ftype = 1;
|
int32_t ftype = 1;
|
||||||
|
|
||||||
|
@ -1021,6 +1023,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ctx->embeddings_scale != 1.0f) {
|
||||||
|
embeddings = ggml_scale(ctx0, embeddings, ctx->embeddings_scale);
|
||||||
|
}
|
||||||
|
|
||||||
// build the graph
|
// build the graph
|
||||||
ggml_build_forward_expand(gf, embeddings);
|
ggml_build_forward_expand(gf, embeddings);
|
||||||
|
|
||||||
|
@ -1322,6 +1328,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||||
new_clip->image_std[i] = std_data[i];
|
new_clip->image_std[i] = std_data[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
new_clip->embeddings_scale = get_f32(ctx, KEY_EMBD_SCALE);
|
||||||
|
} catch (const std::exception& /*e*/) {
|
||||||
|
new_clip->embeddings_scale = 1.0f;
|
||||||
|
}
|
||||||
|
|
||||||
if (verbosity >= 2) {
|
if (verbosity >= 2) {
|
||||||
LOG_INF("\n%s: vision model hparams\n", __func__);
|
LOG_INF("\n%s: vision model hparams\n", __func__);
|
||||||
LOG_INF("image_size %d\n", hparams.image_size);
|
LOG_INF("image_size %d\n", hparams.image_size);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue