LLaVA image encoder is working. Will combine with llama.
This commit is contained in:
parent 0f0e7c6480
commit 7e9120f7b1
3 changed files with 105 additions and 63 deletions
@@ -1,5 +1,6 @@
#include "clip.h"
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char ** argv) {
    const char * model_path = argv[1];

@@ -8,14 +9,20 @@ int main(int argc, char ** argv) {

    auto ctx_clip = clip_model_load(model_path, 1);
    clip_image_u8 img;
    //clip_tokens tokens;
    //clip_tokenize(ctx_clip, text, &tokens);
    //float vec[512];
    //clip_text_encode(ctx_clip, 4, &tokens, vec, false);
    clip_image_f32 img_res;
    clip_image_load_from_file(img_path, &img);
    clip_image_preprocess(ctx_clip, &img, &img_res);
    float * vec = (float *)malloc(4096 * 257 * sizeof(float));
    clip_image_encode(ctx_clip, 4, &img_res, vec, false);

    /*
    float score;
    clip_compare_text_and_image(ctx_clip, 4, text, &img, &score);
    printf("score: %f\n", score);
    */

    clip_free(ctx_clip);
    free(vec);


    return 0;
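A rough usage sketch of the updated example (mine, not part of the commit), assuming the clip.h calls shown above and assuming clip_model_load returns NULL and clip_image_load_from_file returns false on failure; it only adds argument and error checks around the same call sequence:

#include "clip.h"
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char ** argv) {
    if (argc < 3) {
        fprintf(stderr, "usage: %s <model.gguf> <image>\n", argv[0]);
        return 1;
    }
    const char * model_path = argv[1];
    const char * img_path   = argv[2];

    auto ctx_clip = clip_model_load(model_path, 1);
    if (!ctx_clip) {
        fprintf(stderr, "failed to load model from %s\n", model_path);
        return 1;
    }

    clip_image_u8  img;
    clip_image_f32 img_res;
    if (!clip_image_load_from_file(img_path, &img)) {
        fprintf(stderr, "failed to load image from %s\n", img_path);
        clip_free(ctx_clip);
        return 1;
    }
    clip_image_preprocess(ctx_clip, &img, &img_res);

    // 257 positions x 4096 floats per image, as in the example above
    float * vec = (float *)malloc(4096 * 257 * sizeof(float));
    clip_image_encode(ctx_clip, 4, &img_res, vec, false);

    free(vec);
    clip_free(ctx_clip);
    return 0;
}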
@@ -43,6 +43,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_DESCRIPTION "general.description"
#define KEY_HAS_TEXT_ENC "clip.has_text_encoder"
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
#define KEY_USE_GELU "clip.use_gelu"
#define KEY_N_EMBD "clip.%s.embedding_length"
#define KEY_N_FF "clip.%s.feed_forward_length"
@@ -77,6 +78,7 @@ static std::string format(const char * fmt, ...) {
#define TN_LN_POST "%s.post_ln.%s"
#define TN_TEXT_PROJ "text_projection.weight"
#define TN_VIS_PROJ "visual_projection.weight"
#define TN_LLAVA_PROJ "llava_projector.%s"

//
// utilities to get data from a gguf file
@@ -221,6 +223,10 @@ struct clip_vision_model {
    struct ggml_tensor * post_ln_b;

    struct ggml_tensor * projection;

    // LLaVA projection
    struct ggml_tensor * llava_proj_w;
    struct ggml_tensor * llava_proj_b;
};

// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
@@ -240,6 +246,7 @@ struct clip_buffer {
struct clip_ctx {
    bool has_text_encoder = false;
    bool has_vision_encoder = false;
    bool has_llava_projector = false;
    struct clip_text_model text_model;
    struct clip_vision_model vision_model;
    struct clip_vocab vocab;
@@ -270,16 +277,17 @@ size_t get_mem_req_by_size(struct clip_ctx * ctx) {
            if (vision_hparams->patch_size == 32) { // patch size = 32
                return 96 * mb;
            } else { // patch size = 16
                return 256 * mb;
                return 128 * mb;
            }
        case 197: // base or large, text-only
            return 16 * mb;
            return 96 * mb;
        case 589: // large, two-tower
        case 392: // large, vision-only
            if (n_positions == 257) { // input image size = 224
                return 60 * mb;
        case 375: // large, LLaVA encoder
            if (vision_hparams->image_size == 224) { // input image size = 224
                return 1200 * mb;
            } else { // input image size = 336
                return 96 * mb;
                return 1800 * mb;
            }
        case 909: // huge, two-tower
        case 520: // huge, vision-only
@@ -313,6 +321,7 @@ size_t get_scr_buf_req_by_size(struct clip_ctx * ctx) {
            return 32 * mb;
        case 589:
        case 392:
        case 377:
            if (n_positions <= 257) {
                return 96 * mb;
            } else {
@@ -406,12 +415,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
    idx = get_key_idx(ctx, KEY_HAS_VIS_ENC);
    new_clip->has_vision_encoder = gguf_get_val_bool(ctx, idx);

    idx = gguf_find_key(ctx, KEY_HAS_LLAVA_PROJ);
    if (idx != -1) {
        new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx);
    }

    idx = get_key_idx(ctx, KEY_USE_GELU);
    new_clip->use_gelu = gguf_get_val_bool(ctx, idx);

    if (verbosity >= 1) {
        printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
        printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
        printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
        printf("%s: model size: %.2f MB\n", __func__, (ctx_size / 1024.0 / 1024.0));
        printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
    }
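A short sketch of the optional-key pattern used above, assuming the standard gguf API in ggml (gguf_find_key returns -1 when a key is absent, while get_key_idx is used above for keys that must exist). This is why CLIP GGUF files written before this commit, without clip.has_llava_projector, should still load with the projector simply disabled:

// Sketch: probe an optional boolean key instead of requiring it.
bool has_llava_projector = false;
const int key_idx = gguf_find_key(ctx, "clip.has_llava_projector");
if (key_idx != -1) {
    has_llava_projector = gguf_get_val_bool(ctx, key_idx);
}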
@@ -556,10 +571,14 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
    vision_model.class_embedding = get_tensor(new_clip->ctx, TN_CLASS_EMBD);
    vision_model.position_embeddings = get_tensor(new_clip->ctx, format(TN_POS_EMBD, "v"));
    vision_model.pre_ln_w = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "weight"));
    vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
    vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
    vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
    vision_model.projection = get_tensor(new_clip->ctx, TN_VIS_PROJ);
    vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
    if (new_clip->has_llava_projector) {
        vision_model.llava_proj_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "weight"));
        vision_model.llava_proj_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "bias"));
    } else {
        vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
        vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
        vision_model.projection = get_tensor(new_clip->ctx, TN_VIS_PROJ);
    }
    vision_model.layers.resize(hparams.n_layer);
    for (int il = 0; il < hparams.n_layer; ++il) {
        auto & layer = vision_model.layers[il];
@@ -1004,8 +1023,9 @@ bool clip_text_encode(const clip_ctx * ctx, const int n_threads, const clip_toke
        cplan.work_data = (uint8_t *)malloc(cplan.work_size);
    }
    ggml_graph_compute(&gf, &cplan);
    */
    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
    */

    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

    // print
#ifdef CLIP_DEBUG
@@ -1053,11 +1073,12 @@ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
    printf("used_mem = %zu\n", ggml_used_mem(ctx0));
#endif
    memcpy(vec, ggml_get_data_f32(embeddings), sizeof(float) * projection_dim);
    /*

    /*
    if (cplan.work_size != 0) {
        free(cplan.work_data);
    }
    */
    */

    ggml_free(ctx0);
@@ -1254,41 +1275,50 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
        embeddings = cur;
    }

    // get the output of cls token, e.g., 0th index
    struct ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
    for (int b = 0; b < batch_size; b++) {
        ggml_set_i32_1d(cls, b, b * num_positions);
    }
    embeddings = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, embeddings, hidden_size, num_positions * batch_size), cls);

    // post-layernorm
    {
        embeddings = ggml_norm(ctx0, embeddings, eps);

        embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.post_ln_w, embeddings), embeddings),
                              ggml_repeat(ctx0, model.post_ln_b, embeddings));
    }

    //ggml_set_scratch(ctx0, {0, 0, nullptr});

    // final visual projection
    embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);

    // normalize output embeddings
    struct ggml_tensor * output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);

    for (int b = 0; b < batch_size; b++) {
        struct ggml_tensor * embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
        if (normalize) {
            ggml_tensor * length = ggml_sqrt(ctx0, ggml_sum(ctx0, ggml_sqr(ctx0, embedding)));
            embedding = ggml_scale_inplace(ctx0, embedding, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
    struct ggml_tensor * output = NULL;
    if (ctx->has_llava_projector) {
        output = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
        embeddings = ggml_mul_mat(ctx0, model.llava_proj_w, embeddings);
        output = ggml_add(ctx0, ggml_repeat(ctx0, model.llava_proj_b, embeddings), embeddings);
    } else {
        // get the output of cls token, e.g., 0th index
        struct ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
        for (int b = 0; b < batch_size; b++) {
            ggml_set_i32_1d(cls, b, b * num_positions);
        }
        embeddings = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, embeddings, hidden_size, num_positions * batch_size), cls);

        // post-layernorm
        {
            embeddings = ggml_norm(ctx0, embeddings, eps);

            embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.post_ln_w, embeddings), embeddings),
                                  ggml_repeat(ctx0, model.post_ln_b, embeddings));
        }

        // final visual projection
        embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);

        // normalize output embeddings
        output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);

        for (int b = 0; b < batch_size; b++) {
            struct ggml_tensor * embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
            if (normalize) {
                ggml_tensor * length = ggml_sqrt(ctx0, ggml_sum(ctx0, ggml_sqr(ctx0, embedding)));
                embedding = ggml_scale_inplace(ctx0, embedding, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
            }
            output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
        }
    output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
    }
    ggml_set_name(output, "check");

    // run the computation
    ggml_build_forward_expand(&gf, output);

    /*
    ggml_cplan cplan = ggml_graph_plan(&gf, n_threads);
    cplan.work_size *= batch_size;
@@ -1296,8 +1326,9 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
    cplan.work_data = (uint8_t *)malloc(cplan.work_size);
    }
    ggml_graph_compute(&gf, &cplan);
    */
    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
    */

    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

    // print
#ifdef CLIP_DEBUG
@@ -1347,11 +1378,12 @@ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
#endif

    memcpy(vec, ggml_get_data_f32(output), sizeof(float) * projection_dim * batch_size);
    /*

    /*
    if (cplan.work_size != 0) {
        free(cplan.work_data);
    }
    */
    */

    ggml_free(ctx0);
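A small arithmetic sketch of where the 4096 * 257 buffer in the example comes from; this is my reading of the code above, not part of the commit, and the concrete numbers (ViT-L/14 at 224x224, LLaMA-7B embedding size) are assumptions. With a patch size of 14 there are 16 x 16 = 256 patch positions plus one class token, and the LLaVA projector maps each position to a 4096-dimensional vector, so the caller needs room for 257 x 4096 floats per image:

#include <stdio.h>
#include <stdlib.h>

int main() {
    // Hypothetical numbers for a ViT-L/14 tower at 224x224 feeding LLaMA-7B.
    const int image_size   = 224;
    const int patch_size   = 14;
    const int n_patches    = (image_size / patch_size) * (image_size / patch_size); // 16 * 16 = 256
    const int n_positions  = n_patches + 1;  // +1 for the class token -> 257
    const int n_llama_embd = 4096;           // LLaVA projector output dim (LLaMA embedding size)

    // same size as the malloc in the example: 4096 * 257 floats per image
    const size_t n_floats = (size_t) n_llama_embd * n_positions;
    float * vec = (float *) malloc(n_floats * sizeof(float));
    printf("output buffer: %zu floats (%zu bytes)\n", n_floats, n_floats * sizeof(float));
    free(vec);
    return 0;
}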
@@ -10,9 +10,11 @@ from transformers import CLIPModel, CLIPProcessor
TEXT = "clip.text"
VISION = "clip.vision"


def k(raw_key: str, arch: str) -> str:
    return raw_key.format(arch=arch)


def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool:
    if name in (
        "logit_scale",
@@ -20,22 +22,23 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b
        "vision_model.embeddings.position_ids",
    ):
        return True

    if name == "visual_projection.weight" and has_llava:
    if has_llava and name in ["visual_projection.weight", "vision_model.post_layernorm.weight", "vision_model.post_layernorm.bias"]:
        return True

    if name.startswith("v") and not has_vision:
        return True

    if name.startswith("t") and not has_text:
        return True

    return False


def get_tensor_name(name: str) -> str:
    if "projection" in name:
        return name

    return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
@@ -64,11 +67,14 @@ def bytes_to_unicode():
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


ap = argparse.ArgumentParser(prog="convert_hf_to_gguf.py")
ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
ap.add_argument("--text-only", action="store_true", required=False, help="Save a text-only model. It can't be used to encode images")
ap.add_argument("--vision-only", action="store_true", required=False, help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--text-only", action="store_true", required=False,
                help="Save a text-only model. It can't be used to encode images")
ap.add_argument("--vision-only", action="store_true", required=False,
                help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--llava-projector", help="Path to projector.pt file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
@@ -76,7 +82,7 @@ ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Defaul

args = ap.parse_args()


if args.text_only and args.vision_only:
    print("--text-only and --image-only arguments cannot be specified at the same time.")
    exit(1)
@@ -91,7 +97,7 @@ dir_model = args.model_dir
with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)
    tokens = [key for key in vocab]


with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
    config = json.load(f)
    v_hparams = config["vision_config"]
@@ -108,7 +114,7 @@ ftype = 1
if args.use_f32:
    ftype = 0


model = CLIPModel.from_pretrained(dir_model)
processor = CLIPProcessor.from_pretrained(dir_model)
@@ -182,8 +188,6 @@ use_gelu = v_hparams["hidden_act"] == "gelu"
fout.add_bool("clip.use_gelu", use_gelu)


if has_llava_projector:
    model.vision_model.encoder.layers.pop(-1)
    projector = torch.load(args.llava_projector)
@@ -203,7 +207,7 @@ for name, data in list_vars.items():

    name = get_tensor_name(name)
    data = data.squeeze().numpy()

    n_dims = len(data.shape)

    # ftype == 0 -> float32, ftype == 1 -> float16
@@ -229,8 +233,7 @@ for name, data in list_vars.items():

    print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
    fout.add_tensor(name, data)


fout.write_header_to_file()
fout.write_kv_data_to_file()
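One detail worth keeping in sync between this script and clip.cpp (a sketch of my reading, not part of the commit): whatever names the script writes for the projector tensors must resolve to the strings the loader builds from TN_LLAVA_PROJ. The exact names the script writes are not shown in the hunks above, so the snippet below only demonstrates the loader-side expectation.

#include <stdio.h>

// Same format string as the TN_LLAVA_PROJ define added to clip.cpp above.
#define TN_LLAVA_PROJ "llava_projector.%s"

int main() {
    char name[64];
    // clip_model_load() requests these two tensors when has_llava_projector is set
    snprintf(name, sizeof(name), TN_LLAVA_PROJ, "weight");
    printf("%s\n", name); // llava_projector.weight
    snprintf(name, sizeof(name), TN_LLAVA_PROJ, "bias");
    printf("%s\n", name); // llava_projector.bias
    return 0;
}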