diff --git a/examples/llava/clip-test.cpp b/examples/llava/clip-test.cpp
index fa224374c..61785a635 100644
--- a/examples/llava/clip-test.cpp
+++ b/examples/llava/clip-test.cpp
@@ -1,5 +1,6 @@
 #include "clip.h"
 #include <stdio.h>
+#include <stdlib.h>
 
 int main(int argc, char ** argv) {
     const char * model_path = argv[1];
@@ -8,14 +9,20 @@ int main(int argc, char ** argv) {
     auto ctx_clip = clip_model_load(model_path, 1);
 
     clip_image_u8 img;
-    //clip_tokens tokens;
-    //clip_tokenize(ctx_clip, text, &tokens);
-    //float vec[512];
-    //clip_text_encode(ctx_clip, 4, &tokens, vec, false);
+    clip_image_f32 img_res;
     clip_image_load_from_file(img_path, &img);
+    clip_image_preprocess(ctx_clip, &img, &img_res);
+    float * vec = (float *)malloc(4096 * 257 * sizeof(float));
+    clip_image_encode(ctx_clip, 4, &img_res, vec, false);
+
+    /*
     float score;
     clip_compare_text_and_image(ctx_clip, 4, text, &img, &score);
     printf("score: %f\n", score);
+    */
+
+    clip_free(ctx_clip);
+    free(vec);
 
     return 0;
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 25c109d27..6dd20a829 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -43,6 +43,7 @@ static std::string format(const char * fmt, ...) {
 #define KEY_DESCRIPTION "general.description"
 #define KEY_HAS_TEXT_ENC "clip.has_text_encoder"
 #define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
+#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
 #define KEY_USE_GELU "clip.use_gelu"
 #define KEY_N_EMBD "clip.%s.embedding_length"
 #define KEY_N_FF "clip.%s.feed_forward_length"
@@ -77,6 +78,7 @@ static std::string format(const char * fmt, ...) {
 #define TN_LN_POST "%s.post_ln.%s"
 #define TN_TEXT_PROJ "text_projection.weight"
 #define TN_VIS_PROJ "visual_projection.weight"
+#define TN_LLAVA_PROJ "llava_projector.%s"
 
 //
 // utilities to get data from a gguf file
@@ -221,6 +223,10 @@ struct clip_vision_model {
     struct ggml_tensor * post_ln_b;
 
     struct ggml_tensor * projection;
+
+    // LLaVA projection
+    struct ggml_tensor * llava_proj_w;
+    struct ggml_tensor * llava_proj_b;
 };
 
 // Replacement for std::vector that doesn't require zero-initialization.
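Note on the hard-coded 4096 * 257 buffer in clip-test.cpp above: for a ViT-L/14 vision tower at image_size = 224 there are (224 / 14)^2 = 256 patch positions plus the class token, i.e. 257 positions, and the LLaVA projector maps each one to a 4096-wide embedding (the width assumed by the test's malloc). A minimal sizing sketch under those assumptions follows; the helper name is illustrative and not part of the clip.h API.

#include <cstddef>
#include <cstdlib>

// Illustrative helper (not part of clip.h): how the 4096 * 257 float buffer in
// clip-test.cpp breaks down. Assumes image_size = 224, patch_size = 14 (ViT-L/14)
// and a 4096-wide LLaVA projection, as implied by the test program.
static float * alloc_llava_embd(int image_size, int patch_size, int n_proj_embd) {
    const int n_patches   = (image_size / patch_size) * (image_size / patch_size); // 16 * 16 = 256
    const int n_positions = n_patches + 1;                                         // + class token = 257
    return (float *) malloc((size_t) n_proj_embd * n_positions * sizeof(float));   // 4096 * 257 floats
}

// alloc_llava_embd(224, 14, 4096) allocates the same amount as the malloc in the test.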
@@ -240,6 +246,7 @@ struct clip_buffer {
 struct clip_ctx {
     bool has_text_encoder = false;
     bool has_vision_encoder = false;
+    bool has_llava_projector = false;
     struct clip_text_model text_model;
     struct clip_vision_model vision_model;
     struct clip_vocab vocab;
@@ -270,16 +277,17 @@ size_t get_mem_req_by_size(struct clip_ctx * ctx) {
         if (vision_hparams->patch_size == 32) { // patch size = 32
             return 96 * mb;
         } else { // patch size = 16
-            return 256 * mb;
+            return 128 * mb;
         }
     case 197: // base or large, text-only
-        return 16 * mb;
+        return 96 * mb;
     case 589: // large, two-tower
     case 392: // large, vision-only
-        if (n_positions == 257) { // input image size = 224
-            return 60 * mb;
+    case 375: // large, LLaVA encoder
+        if (vision_hparams->image_size == 224) { // input image size = 224
+            return 1200 * mb;
         } else { // input image size = 336
-            return 96 * mb;
+            return 1800 * mb;
         }
     case 909: // huge, two-tower
     case 520: // huge, vision-only
@@ -313,6 +321,7 @@ size_t get_scr_buf_req_by_size(struct clip_ctx * ctx) {
         return 32 * mb;
     case 589:
     case 392:
+    case 377:
         if (n_positions <= 257) {
             return 96 * mb;
         } else {
@@ -406,12 +415,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     idx = get_key_idx(ctx, KEY_HAS_VIS_ENC);
     new_clip->has_vision_encoder = gguf_get_val_bool(ctx, idx);
 
+    idx = gguf_find_key(ctx, KEY_HAS_LLAVA_PROJ);
+    if (idx != -1) {
+        new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx);
+    }
+
     idx = get_key_idx(ctx, KEY_USE_GELU);
     new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
 
     if (verbosity >= 1) {
         printf("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
         printf("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
+        printf("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector);
         printf("%s: model size: %.2f MB\n", __func__, (ctx_size / 1024.0 / 1024.0));
         printf("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
     }
@@ -556,10 +571,14 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         vision_model.class_embedding = get_tensor(new_clip->ctx, TN_CLASS_EMBD);
         vision_model.position_embeddings = get_tensor(new_clip->ctx, format(TN_POS_EMBD, "v"));
         vision_model.pre_ln_w = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "weight"));
-        vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
-        vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
-        vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
-        vision_model.projection = get_tensor(new_clip->ctx, TN_VIS_PROJ);
+        vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
+        if (new_clip->has_llava_projector) {
+            vision_model.llava_proj_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "weight"));
+            vision_model.llava_proj_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "bias"));
+        } else {
+            vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
+            vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
+            vision_model.projection = get_tensor(new_clip->ctx, TN_VIS_PROJ);
+        }
         vision_model.layers.resize(hparams.n_layer);
         for (int il = 0; il < hparams.n_layer; ++il) {
             auto & layer = vision_model.layers[il];
@@ -1004,8 +1023,9 @@ bool clip_text_encode(const clip_ctx * ctx, const int n_threads, const clip_toke
     cplan.work_data = (uint8_t *)malloc(cplan.work_size);
 }
 ggml_graph_compute(&gf, &cplan);
-*/
-ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    */
+
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
 // print
 #ifdef CLIP_DEBUG
@@ -1053,11 +1073,12 @@ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
     printf("used_mem = %zu\n", ggml_used_mem(ctx0));
 #endif
 
     memcpy(vec, ggml_get_data_f32(embeddings), sizeof(float) * projection_dim);
-/*
+
+    /*
     if (cplan.work_size != 0) {
         free(cplan.work_data);
     }
-*/
+    */
 
     ggml_free(ctx0);
 
@@ -1254,41 +1275,50 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
         embeddings = cur;
     }
 
-    // get the output of cls token, e.g., 0th index
-    struct ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
-    for (int b = 0; b < batch_size; b++) {
-        ggml_set_i32_1d(cls, b, b * num_positions);
-    }
-    embeddings = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, embeddings, hidden_size, num_positions * batch_size), cls);
-
-    // post-layernorm
-    {
-        embeddings = ggml_norm(ctx0, embeddings, eps);
-
-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.post_ln_w, embeddings), embeddings),
-                              ggml_repeat(ctx0, model.post_ln_b, embeddings));
-    }
     //ggml_set_scratch(ctx0, {0, 0, nullptr});
 
-    // final visual projection
-    embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);
-
-    // normalize output embeddings
-    struct ggml_tensor * output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);
-
-    for (int b = 0; b < batch_size; b++) {
-        struct ggml_tensor * embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
-        if (normalize) {
-            ggml_tensor * length = ggml_sqrt(ctx0, ggml_sum(ctx0, ggml_sqr(ctx0, embedding)));
-            embedding = ggml_scale_inplace(ctx0, embedding, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
+    struct ggml_tensor * output = NULL;
+    if (ctx->has_llava_projector) {
+        output = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+        embeddings = ggml_mul_mat(ctx0, model.llava_proj_w, embeddings);
+        output = ggml_add(ctx0, ggml_repeat(ctx0, model.llava_proj_b, embeddings), embeddings);
+    } else {
+        // get the output of cls token, e.g., 0th index
+        struct ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
+        for (int b = 0; b < batch_size; b++) {
+            ggml_set_i32_1d(cls, b, b * num_positions);
+        }
+        embeddings = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, embeddings, hidden_size, num_positions * batch_size), cls);
+
+        // post-layernorm
+        {
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.post_ln_w, embeddings), embeddings),
+                                  ggml_repeat(ctx0, model.post_ln_b, embeddings));
+        }
+
+        // final visual projection
+        embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);
+
+        // normalize output embeddings
+        output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);
+
+        for (int b = 0; b < batch_size; b++) {
+            struct ggml_tensor * embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
+            if (normalize) {
+                ggml_tensor * length = ggml_sqrt(ctx0, ggml_sum(ctx0, ggml_sqr(ctx0, embedding)));
+                embedding = ggml_scale_inplace(ctx0, embedding, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
+            }
+            output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
         }
-        output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
     }
 
     ggml_set_name(output, "check");
 
     // run the computation
     ggml_build_forward_expand(&gf, output);
+
 /*
 ggml_cplan cplan = ggml_graph_plan(&gf, n_threads);
 cplan.work_size *= batch_size;
@@ -1296,8 +1326,9 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
     cplan.work_data = (uint8_t *)malloc(cplan.work_size);
 }
 ggml_graph_compute(&gf, &cplan);
-*/
-ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+    */
+
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
 // print
 #ifdef CLIP_DEBUG
@@ -1347,11 +1378,12 @@ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 #endif
 
     memcpy(vec, ggml_get_data_f32(output), sizeof(float) * projection_dim * batch_size);
-/*
+
+    /*
     if (cplan.work_size != 0) {
         free(cplan.work_data);
     }
-*/
+    */
 
     ggml_free(ctx0);
 
diff --git a/examples/llava/convert_hf_to_gguf.py b/examples/llava/convert_hf_to_gguf.py
index 2d1a47cd2..f6d3ca406 100644
--- a/examples/llava/convert_hf_to_gguf.py
+++ b/examples/llava/convert_hf_to_gguf.py
@@ -10,9 +10,11 @@ from transformers import CLIPModel, CLIPProcessor
 TEXT = "clip.text"
 VISION = "clip.vision"
 
+
 def k(raw_key: str, arch: str) -> str:
     return raw_key.format(arch=arch)
 
+
 def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool:
     if name in (
         "logit_scale",
@@ -20,22 +22,23 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b
         "vision_model.embeddings.position_ids",
     ):
         return True
-
-    if name == "visual_projection.weight" and has_llava:
+
+    if has_llava and name in ["visual_projection.weight", "vision_model.post_layernorm.weight", "vision_model.post_layernorm.bias"]:
         return True
-
+
     if name.startswith("v") and not has_vision:
         return True
-
+
     if name.startswith("t") and not has_text:
         return True
-
+
     return False
+
 
 def get_tensor_name(name: str) -> str:
     if "projection" in name:
         return name
-
+
     return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
@@ -64,11 +67,14 @@ def bytes_to_unicode():
     cs = [chr(n) for n in cs]
     return dict(zip(bs, cs))
 
+
 ap = argparse.ArgumentParser(prog="convert_hf_to_gguf.py")
 ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True)
 ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16")
-ap.add_argument("--text-only", action="store_true", required=False, help="Save a text-only model. It can't be used to encode images")
-ap.add_argument("--vision-only", action="store_true", required=False, help="Save a vision-only model. It can't be used to encode texts")
+ap.add_argument("--text-only", action="store_true", required=False,
+                help="Save a text-only model. It can't be used to encode images")
+ap.add_argument("--vision-only", action="store_true", required=False,
+                help="Save a vision-only model. It can't be used to encode texts")
 ap.add_argument("--llava-projector", help="Path to projector.pt file. If specified, save an image encoder for LLaVA models.")
 ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
 ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
@@ -76,7 +82,7 @@ ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Defaul
 args = ap.parse_args()
-
+
 if args.text_only and args.vision_only:
     print("--text-only and --image-only arguments cannot be specified at the same time.")
     exit(1)
@@ -91,7 +97,7 @@ dir_model = args.model_dir
 with open(dir_model + "/vocab.json", "r", encoding="utf-8") as f:
     vocab = json.load(f)
     tokens = [key for key in vocab]
-
+
 with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
     config = json.load(f)
     v_hparams = config["vision_config"]
@@ -108,7 +114,7 @@ ftype = 1
 if args.use_f32:
     ftype = 0
-
+
 model = CLIPModel.from_pretrained(dir_model)
 processor = CLIPProcessor.from_pretrained(dir_model)
@@ -182,8 +188,6 @@ use_gelu = v_hparams["hidden_act"] == "gelu"
 fout.add_bool("clip.use_gelu", use_gelu)
 
-
-
 if has_llava_projector:
     model.vision_model.encoder.layers.pop(-1)
     projector = torch.load(args.llava_projector)
@@ -203,7 +207,7 @@ for name, data in list_vars.items():
     name = get_tensor_name(name)
     data = data.squeeze().numpy()
-
+
     n_dims = len(data.shape)
 
     # ftype == 0 -> float32, ftype == 1 -> float16
@@ -229,8 +233,7 @@ for name, data in list_vars.items():
     print(f"{name} - {ftype_str[ftype_cur]} - shape = {data.shape}")
 
     fout.add_tensor(name, data)
-
-
+
 fout.write_header_to_file()
 fout.write_kv_data_to_file()
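For reference, the new has_llava_projector branch in clip_image_batch_encode reduces to one affine map applied to every ViT position, out[p] = W * h[p] + b, built above with ggml_mul_mat on llava_proj_w followed by a ggml_repeat/ggml_add of llava_proj_b. Below is a plain loop-nest sketch of that math with dimensions assumed from the diff (1024-dim hidden states for the ViT-L/14 tower, 4096-dim output, 257 positions); it illustrates what the graph computes and is not code from the patch.

// Illustrative reference (not part of clip.cpp): the per-position projection that the
// has_llava_projector branch builds as a ggml graph. Assumed sizes: n_hidden = 1024,
// n_proj = 4096, n_pos = 257 per image; proj_w is row-major [n_proj][n_hidden].
static void llava_project_reference(const float * hidden,  // [n_pos * n_hidden]
                                     const float * proj_w,  // [n_proj * n_hidden]
                                     const float * proj_b,  // [n_proj]
                                     float * out,           // [n_pos * n_proj]
                                     int n_pos, int n_hidden, int n_proj) {
    for (int p = 0; p < n_pos; p++) {
        for (int j = 0; j < n_proj; j++) {
            float acc = proj_b[j];                          // bias, broadcast over positions
            for (int k = 0; k < n_hidden; k++) {
                acc += proj_w[j * n_hidden + k] * hidden[p * n_hidden + k];
            }
            out[p * n_proj + j] = acc;
        }
    }
}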