LLaVA image encoder is working. Will combine with llama.
commit 7e9120f7b1
parent 0f0e7c6480

3 changed files with 105 additions and 63 deletions
@@ -1,5 +1,6 @@
 #include "clip.h"
 #include <stdio.h>
+#include <stdlib.h>
 
 int main(int argc, char ** argv) {
     const char * model_path = argv[1];
@@ -8,14 +9,20 @@ int main(int argc, char ** argv) {
 
     auto ctx_clip = clip_model_load(model_path, 1);
     clip_image_u8 img;
-    //clip_tokens tokens;
-    //clip_tokenize(ctx_clip, text, &tokens);
-    //float vec[512];
-    //clip_text_encode(ctx_clip, 4, &tokens, vec, false);
+    clip_image_f32 img_res;
     clip_image_load_from_file(img_path, &img);
+    clip_image_preprocess(ctx_clip, &img, &img_res);
+    float * vec = (float *)malloc(4096 * 257 * sizeof(float));
+    clip_image_encode(ctx_clip, 4, &img_res, vec, false);
 
+    /*
     float score;
     clip_compare_text_and_image(ctx_clip, 4, text, &img, &score);
     printf("score: %f\n", score);
+    */
+
+    clip_free(ctx_clip);
+    free(vec);
 
     return 0;
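For reference, a self-contained sketch of what the updated example does end to end. This is not the full file from the commit: the argc check and taking the image path from argv[2] are assumptions (the hunks above use img_path without showing where it is set). The 4096 x 257 buffer matches the malloc in the diff, i.e. one 4096-dim vector per ViT position (256 patches plus the class token for a 224-pixel, patch-14 input).

    // Sketch only: same API calls as the diff above; argv[2] and the usage
    // message are assumptions, not part of the original file.
    #include "clip.h"
    #include <stdio.h>
    #include <stdlib.h>

    int main(int argc, char ** argv) {
        if (argc < 3) {
            fprintf(stderr, "usage: %s <model.gguf> <image>\n", argv[0]);
            return 1;
        }
        const char * model_path = argv[1];
        const char * img_path   = argv[2];

        auto ctx_clip = clip_model_load(model_path, 1);

        clip_image_u8  img;
        clip_image_f32 img_res;
        clip_image_load_from_file(img_path, &img);
        clip_image_preprocess(ctx_clip, &img, &img_res);

        // one 4096-dim embedding per ViT position (256 patches + CLS = 257)
        float * vec = (float *)malloc(4096 * 257 * sizeof(float));
        clip_image_encode(ctx_clip, 4, &img_res, vec, false);

        clip_free(ctx_clip);
        free(vec);
        return 0;
    }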
@@ -43,6 +43,7 @@ static std::string format(const char * fmt, ...) {
 #define KEY_DESCRIPTION "general.description"
 #define KEY_HAS_TEXT_ENC "clip.has_text_encoder"
 #define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
+#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
 #define KEY_USE_GELU "clip.use_gelu"
 #define KEY_N_EMBD "clip.%s.embedding_length"
 #define KEY_N_FF "clip.%s.feed_forward_length"
@@ -77,6 +78,7 @@ static std::string format(const char * fmt, ...) {
 #define TN_LN_POST "%s.post_ln.%s"
 #define TN_TEXT_PROJ "text_projection.weight"
 #define TN_VIS_PROJ "visual_projection.weight"
+#define TN_LLAVA_PROJ "llava_projector.%s"
 
 //
 // utilities to get data from a gguf file
@@ -221,6 +223,10 @@ struct clip_vision_model {
     struct ggml_tensor * post_ln_b;
 
     struct ggml_tensor * projection;
+
+    // LLaVA projection
+    struct ggml_tensor * llava_proj_w;
+    struct ggml_tensor * llava_proj_b;
 };
 
 // Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
@@ -240,6 +246,7 @@ struct clip_buffer {
 struct clip_ctx {
     bool has_text_encoder = false;
     bool has_vision_encoder = false;
+    bool has_llava_projector = false;
     struct clip_text_model text_model;
     struct clip_vision_model vision_model;
     struct clip_vocab vocab;
@@ -270,16 +277,17 @@ size_t get_mem_req_by_size(struct clip_ctx * ctx) {
         if (vision_hparams->patch_size == 32) { // patch size = 32
             return 96 * mb;
         } else { // patch size = 16
-            return 256 * mb;
+            return 128 * mb;
         }
     case 197: // base or large, text-only
-        return 16 * mb;
+        return 96 * mb;
     case 589: // large, two-tower
     case 392: // large, vision-only
-        if (n_positions == 257) { // input image size = 224
-            return 60 * mb;
+    case 375: // large, LLaVA encoder
+        if (vision_hparams->image_size == 224) { // input image size = 224
+            return 1200 * mb;
         } else { // input image size = 336
-            return 96 * mb;
+            return 1800 * mb;
         }
     case 909: // huge, two-tower
     case 520: // huge, vision-only
@@ -313,6 +321,7 @@ size_t get_scr_buf_req_by_size(struct clip_ctx * ctx) {
         return 32 * mb;
     case 589:
     case 392:
+    case 377:
        if (n_positions <= 257) {
            return 96 * mb;
        } else {
@@ -406,12 +415,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     idx = get_key_idx(ctx, KEY_HAS_VIS_ENC);
     new_clip->has_vision_encoder = gguf_get_val_bool(ctx, idx);
 
+    idx = gguf_find_key(ctx, KEY_HAS_LLAVA_PROJ);
+    if (idx != -1) {
+        new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx);
+    }
+
     idx = get_key_idx(ctx, KEY_USE_GELU);
     new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
 
     if (verbosity >= 1) {
         printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
         printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
+        printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
         printf("%s: model size: %.2f MB\n", __func__, (ctx_size / 1024.0 / 1024.0));
         printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
     }
@@ -556,10 +571,14 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         vision_model.class_embedding = get_tensor(new_clip->ctx, TN_CLASS_EMBD);
         vision_model.position_embeddings = get_tensor(new_clip->ctx, format(TN_POS_EMBD, "v"));
         vision_model.pre_ln_w = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "weight"));
         vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
+        if (new_clip->has_llava_projector) {
+            vision_model.llava_proj_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "weight"));
+            vision_model.llava_proj_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "bias"));
+        } else {
         vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
         vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
         vision_model.projection = get_tensor(new_clip->ctx, TN_VIS_PROJ);
+        }
         vision_model.layers.resize(hparams.n_layer);
         for (int il = 0; il < hparams.n_layer; ++il) {
             auto & layer = vision_model.layers[il];
@@ -1004,8 +1023,9 @@ bool clip_text_encode(const clip_ctx * ctx, const int n_threads, const clip_toke
         cplan.work_data = (uint8_t *)malloc(cplan.work_size);
     }
     ggml_graph_compute(&gf, &cplan);
     */
-ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     // print
 #ifdef CLIP_DEBUG
@@ -1053,11 +1073,12 @@ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
     printf("used_mem = %zu\n", ggml_used_mem(ctx0));
 #endif
     memcpy(vec, ggml_get_data_f32(embeddings), sizeof(float) * projection_dim);
-/*
+
+    /*
     if (cplan.work_size != 0) {
         free(cplan.work_data);
     }
     */
 
     ggml_free(ctx0);
 
@@ -1254,6 +1275,15 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
         embeddings = cur;
     }
 
+
+    //ggml_set_scratch(ctx0, {0, 0, nullptr});
+
+    struct ggml_tensor * output = NULL;
+    if (ctx->has_llava_projector) {
+        output = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+        embeddings = ggml_mul_mat(ctx0, model.llava_proj_w, embeddings);
+        output = ggml_add(ctx0, ggml_repeat(ctx0, model.llava_proj_b, embeddings), embeddings);
+    } else {
     // get the output of cls token, e.g., 0th index
     struct ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
     for (int b = 0; b < batch_size; b++) {
@@ -1269,13 +1299,11 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
                               ggml_repeat(ctx0, model.post_ln_b, embeddings));
     }
 
-    //ggml_set_scratch(ctx0, {0, 0, nullptr});
-
     // final visual projection
     embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);
 
     // normalize output embeddings
-    struct ggml_tensor * output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);
+    output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);
 
     for (int b = 0; b < batch_size; b++) {
         struct ggml_tensor * embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
@@ -1285,10 +1313,12 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
         }
         output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
     }
+    }
+
     ggml_set_name(output, "check");
 
     // run the computation
     ggml_build_forward_expand(&gf, output);
 
     /*
     ggml_cplan cplan = ggml_graph_plan(&gf, n_threads);
     cplan.work_size *= batch_size;
@@ -1296,8 +1326,9 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
         cplan.work_data = (uint8_t *)malloc(cplan.work_size);
     }
     ggml_graph_compute(&gf, &cplan);
     */
-ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     // print
 #ifdef CLIP_DEBUG
@@ -1347,11 +1378,12 @@ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 #endif
 
     memcpy(vec, ggml_get_data_f32(output), sizeof(float) * projection_dim * batch_size);
-/*
+
+    /*
     if (cplan.work_size != 0) {
         free(cplan.work_data);
    }
    */
 
    ggml_free(ctx0);
 
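The projector added to clip_image_batch_encode above is a single linear layer applied to every ViT position: in LLaVA mode the graph skips the class-token pooling and CLIP's visual_projection, multiplies the full embedding matrix by llava_proj_w, and adds llava_proj_b, producing one LLaMA-sized vector per position. Below is a minimal, framework-free sketch of that computation; the concrete shapes (257 x 1024 in, 4096 out) are assumptions for a ViT-L/14 encoder at 224 px and a 7B LLaMA, and in practice come from the GGUF header.

    // Plain-C sketch of the LLaVA projection that the ggml graph builds above.
    //   embeddings: [num_positions][hidden_size]   e.g. 257 x 1024 (assumed)
    //   proj_w:     [n_llama_embd][hidden_size]    e.g. 4096 x 1024 (assumed)
    //   proj_b:     [n_llama_embd]
    //   out:        [num_positions][n_llama_embd]  one image "token" per position
    static void llava_project(const float * embeddings,
                              const float * proj_w,
                              const float * proj_b,
                              float * out,
                              int num_positions, int hidden_size, int n_llama_embd) {
        for (int p = 0; p < num_positions; ++p) {
            for (int e = 0; e < n_llama_embd; ++e) {
                float acc = proj_b[e];
                for (int h = 0; h < hidden_size; ++h) {
                    acc += proj_w[e * hidden_size + h] * embeddings[p * hidden_size + h];
                }
                out[p * n_llama_embd + e] = acc;
            }
        }
    }

This is the buffer the example program allocates as 4096 * 257 floats: the projected positions are what will later be spliced into the LLaMA token stream.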
@@ -10,9 +10,11 @@ from transformers import CLIPModel, CLIPProcessor
 TEXT = "clip.text"
 VISION = "clip.vision"
 
+
 def k(raw_key: str, arch: str) -> str:
     return raw_key.format(arch=arch)
 
+
 def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool:
     if name in (
         "logit_scale",
@@ -21,7 +23,7 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b
     ):
         return True
 
-    if name == "visual_projection.weight" and has_llava:
+    if has_llava and name in ["visual_projection.weight", "vision_model.post_layernorm.weight", "vision_model.post_layernorm.bias"]:
         return True
 
     if name.startswith("v") and not has_vision:
@@ -32,6 +34,7 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b
 
     return False
 
+
 def get_tensor_name(name: str) -> str:
     if "projection" in name:
         return name
@@ -64,11 +67,14 @@ def bytes_to_unicode():
     cs = [chr(n) for n in cs]
     return dict(zip(bs, cs))
 
 
 ap = argparse.ArgumentParser(prog="convert_hf_to_gguf.py")
 ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
 ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
-ap.add_argument("--text-only", action="store_true", required=False, help="Save a text-only model. It can't be used to encode images")
-ap.add_argument("--vision-only", action="store_true", required=False, help="Save a vision-only model. It can't be used to encode texts")
+ap.add_argument("--text-only", action="store_true", required=False,
+                help="Save a text-only model. It can't be used to encode images")
+ap.add_argument("--vision-only", action="store_true", required=False,
+                help="Save a vision-only model. It can't be used to encode texts")
 ap.add_argument("--llava-projector", help="Path to projector.pt file. If specified, save an image encoder for LLaVA models.")
 ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
 ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
@@ -182,8 +188,6 @@ use_gelu = v_hparams["hidden_act"] == "gelu"
 fout.add_bool("clip.use_gelu", use_gelu)
 
-
-
 
 if has_llava_projector:
     model.vision_model.encoder.layers.pop(-1)
     projector = torch.load(args.llava_projector)
@@ -231,7 +235,6 @@ for name, data in list_vars.items():
     fout.add_tensor(name, data)
 
-
 
 fout.write_header_to_file()
 fout.write_kv_data_to_file()
 fout.write_tensors_to_file()
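On the conversion side, the changes above do three things when a projector is requested: visual_projection.weight and the post_layernorm tensors are now skipped (the LLaVA encoder path in clip.cpp does not load them), the last ViT encoder layer is dropped with model.vision_model.encoder.layers.pop(-1) so the exported features come from the penultimate layer, and the projection weights are read from the checkpoint passed via --llava-projector with torch.load. Assuming the script is invoked under the name it passes to argparse, a typical call would look like: python convert_hf_to_gguf.py -m <hf-clip-model-dir> --llava-projector <projector.pt>, where both paths are placeholders.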