diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 5c9a6a6f1..87cb1a28a 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -78,7 +78,7 @@ static std::string format(const char * fmt, ...) {
 #define TN_LN_POST "%s.post_ln.%s"
 #define TN_TEXT_PROJ "text_projection.weight"
 #define TN_VIS_PROJ "visual_projection.weight"
-#define TN_LLAVA_PROJ "llava_projector.%s"
+#define TN_LLAVA_PROJ "mm.%d.%s"
 
 //
 // utilities to get data from a gguf file
@@ -225,8 +225,10 @@ struct clip_vision_model {
     struct ggml_tensor * projection;
 
     // LLaVA projection
-    struct ggml_tensor * llava_proj_w;
-    struct ggml_tensor * llava_proj_b;
+    struct ggml_tensor * mm_0_w;
+    struct ggml_tensor * mm_0_b;
+    struct ggml_tensor * mm_2_w;
+    struct ggml_tensor * mm_2_b;
 };
 
 // Replacement for std::vector that doesn't require zero-initialization.
@@ -283,11 +285,11 @@ size_t get_mem_req_by_size(struct clip_ctx * ctx) {
             return 96 * mb;
         case 589: // large, two-tower
         case 392: // large, vision-only
-        case 375: // large, LLaVA encoder
+        case 377: // large, LLaVA encoder
            if (vision_hparams->image_size == 224) { // input image size = 224
                 return 1200 * mb;
             } else { // input image size = 336
-                return 1800 * mb;
+                return 2900 * mb;
            }
         case 909: // huge, two-tower
         case 520: // huge, vision-only
@@ -572,8 +574,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         vision_model.position_embeddings = get_tensor(new_clip->ctx, format(TN_POS_EMBD, "v"));
         vision_model.pre_ln_w = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "weight"));
         vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
         if (new_clip->has_llava_projector) {
-            vision_model.llava_proj_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "weight"));
-            vision_model.llava_proj_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "bias"));
+            vision_model.mm_0_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 0, "weight"));
+            vision_model.mm_0_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 0, "bias"));
+            vision_model.mm_2_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 2, "weight"));
+            vision_model.mm_2_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 2, "bias"));
         } else {
             vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
             vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
@@ -1278,20 +1282,26 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
         embeddings = cur;
     }
 
-    //ggml_set_scratch(ctx0, {0, 0, nullptr});
-    struct ggml_tensor * output = NULL;
     if (ctx->has_llava_projector) {
-        output = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
-        embeddings = ggml_mul_mat(ctx0, model.llava_proj_w, embeddings);
-        output = ggml_add(ctx0, ggml_repeat(ctx0, model.llava_proj_b, embeddings), embeddings);
-        output = ggml_reshape_2d(ctx0, output, output->ne[0], output->ne[1]);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
 
         struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
         for (int i = 0; i < num_patches; ++i) {
             ggml_set_i32_1d(patches, i, i+1);
         }
 
-        output = ggml_get_rows(ctx0, output, patches);
+        embeddings = ggml_get_rows(ctx0, embeddings, patches);
+
+        // mm projection 0
+        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+        embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_0_b, embeddings), embeddings);
+
+        embeddings = ggml_gelu(ctx0, embeddings);
+
+        embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+        embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_2_b, embeddings), embeddings);
+
+        ggml_set_name(embeddings, "check");
     } else {
         // get the output of cls token, e.g., 0th index
         struct ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
@@ -1312,7 +1322,7 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
         embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);
 
         // normalize output embeddings
-        output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);
+        struct ggml_tensor * output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);
 
         for (int b = 0; b < batch_size; b++) {
             struct ggml_tensor * embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
@@ -1322,11 +1332,13 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
             }
             output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
         }
+
+        embeddings = output;
     }
 
-    ggml_set_name(output, "check");
+    //ggml_set_name(embeddings, "check");
     // run the computation
-    ggml_build_forward_expand(&gf, output);
+    ggml_build_forward_expand(&gf, embeddings);
 
     /*
     ggml_cplan cplan = ggml_graph_plan(&gf, n_threads);
@@ -1386,7 +1398,7 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
     printf("used_mem = %zu\n", ggml_used_mem(ctx0));
 #endif
 
-    memcpy(vec, ggml_get_data_f32(output), sizeof(float) * projection_dim * batch_size);
+    memcpy(vec, ggml_get_data_f32(embeddings), ggml_nbytes(embeddings));
 
     /*
     if (cplan.work_size != 0) {
diff --git a/examples/llava/convert_hf_to_gguf.py b/examples/llava/convert_hf_to_gguf.py
index f6d3ca406..2f5eef199 100644
--- a/examples/llava/convert_hf_to_gguf.py
+++ b/examples/llava/convert_hf_to_gguf.py
@@ -39,6 +39,9 @@ def get_tensor_name(name: str) -> str:
     if "projection" in name:
         return name
 
+    if "mm_projector" in name:
+        return name.replace("model.mm_projector", "mm")
+
     return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
 
 
@@ -75,7 +78,7 @@ ap.add_argument("--text-only", action="store_true", required=False,
                 help="Save a text-only model. It can't be used to encode images")
 ap.add_argument("--vision-only", action="store_true", required=False,
                 help="Save a vision-only model. It can't be used to encode texts")
-ap.add_argument("--llava-projector", help="Path to projector.pt file. If specified, save an image encoder for LLaVA models.")
+ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
 ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
 ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
 ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
@@ -138,7 +141,7 @@ else:
 output_dir = args.output_dir if args.output_dir is not None else dir_model
 os.makedirs(output_dir, exist_ok=True)
 output_prefix = os.path.basename(output_dir).replace("ggml_", "")
-fname_out = os.path.join(output_dir, f"{output_prefix}_ggml-{fname_middle}model-{ftype_str[ftype]}.gguf")
+fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
 fout = GGUFWriter(path=fname_out, arch="clip")
 
 fout.add_bool("clip.has_text_encoder", has_text_encoder)
@@ -191,15 +194,19 @@ fout.add_bool("clip.use_gelu", use_gelu)
 if has_llava_projector:
     model.vision_model.encoder.layers.pop(-1)
     projector = torch.load(args.llava_projector)
-    weight = projector["model.mm_projector.weight"].cpu().squeeze().float().numpy().astype(np.float16)
-    bias = projector['model.mm_projector.bias'].cpu().squeeze().float().numpy().astype(np.float32)
-    fout.add_tensor("llava_projector.weight", weight)
-    fout.add_tensor("llava_projector.bias", bias)
+    for name, data in projector.items():
+        name = get_tensor_name(name)
+        if data.ndim == 2:
+            data = data.squeeze().numpy().astype(np.float16)
+        else:
+            data = data.squeeze().numpy().astype(np.float32)
+
+        fout.add_tensor(name, data)
+    print("Projector tensors added\n")
 
-
-list_vars = model.state_dict()
-for name, data in list_vars.items():
+state_dict = model.state_dict()
+for name, data in state_dict.items():
     if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
         # we don't need this
         print(f"skipping parameter: {name}")
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index 4d40f9e27..104b05cad 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -7,12 +7,11 @@
 
 #include "llama.h"
 
-static bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int * n_past) {
+static bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
     int n_embd = llama_n_embd(llama_get_model(ctx_llama));
-    int n_batch = N; // params.n_batch;
-
-    for (int i = 0; i < (int) N; i += n_batch) {
-        int n_eval = (int) N - i;
+
+    for (int i = 0; i < N; i += n_batch) {
+        int n_eval = N - i;
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
@@ -161,18 +160,18 @@ int main(int argc, char ** argv) {
     }
 
     if (params.prompt.empty()) {
-        params.prompt = "user: describe the image in detail.\nassistant:";
+        params.prompt = "describe the image in detail.";
     }
-
-    auto ctx_clip = clip_model_load(clip_path, 1);
+
+    auto ctx_clip = clip_model_load(clip_path, 3);
 
     clip_image_u8 img;
     clip_image_f32 img_res;
     clip_image_load_from_file(img_path, &img);
     clip_image_preprocess(ctx_clip, &img, &img_res);
 
-    float * vec = (float *)malloc(4096 * 256 * sizeof(float));
+    float * vec = (float *)malloc(4096 * 576 * sizeof(float));
     clip_image_encode(ctx_clip, params.n_threads, &img_res, vec, false);
-clip_free(ctx_clip);
+    clip_free(ctx_clip);
 
     llama_backend_init(params.numa);
@@ -198,9 +197,10 @@ clip_free(ctx_clip);
     int n_past = 0;
     int max_tgt_len = 256;
 
-    //eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
-    eval_image_embd(ctx_llama, vec, 256, &n_past);
-//eval_string(ctx_llama, "assistant:", params.n_batch, &n_past);
+    eval_string(ctx_llama, "user: ", params.n_batch, &n_past);
+    eval_image_embd(ctx_llama, vec, 576, params.n_batch, &n_past);
+    eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
+    eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);
 printf("n_past = %d\n", n_past);
 
     const char* tmp;
@@ -220,4 +220,4 @@ printf("n_past = %d\n", n_past);
 
     free(vec);
     return 0;
-}
\ No newline at end of file
+}
diff --git a/examples/llava/llava_surgery.py b/examples/llava/llava_surgery.py
index a97cc06ec..26294d9bd 100644
--- a/examples/llava/llava_surgery.py
+++ b/examples/llava/llava_surgery.py
@@ -1,63 +1,30 @@
 import argparse
-from llava.model import LlavaLlamaForCausalLM
-from transformers import AutoTokenizer
-from peft import PeftModel
+import glob
+import os
 
 import torch
 
-dtype = torch.bfloat16
 
 ap = argparse.ArgumentParser()
-ap.add_argument("-m", "--model", help="Path to LLaVA RLHF model")
-ap.add_argument("-o", "--output", help="Output directory to save the merged file")
+ap.add_argument("-m", "--model", help="Path to LLaVA v1.5 model")
 args = ap.parse_args()
 
-model_path = f"{args.model}/sft_model"
-lora_path = f"{args.model}/rlhf_lora_adapter_model"
-save_path = args.output
+# find the model part that includes the multimodal projector weights
+path = sorted(glob.glob(f"{args.model}/pytorch_model*.bin"))[-1]
+checkpoint = torch.load(path)
 
-model = LlavaLlamaForCausalLM.from_pretrained(
-    model_path,
-    device_map={"": "cuda:0"},
-    torch_dtype=dtype,
-)
-model = PeftModel.from_pretrained(
-    model,
-    lora_path,
-)
+# get a list of mm tensor names
+mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_projector")]
 
+# store these tensors in a new dictionary and torch.save them
+projector = {name: checkpoint[name] for name in mm_tensors}
+torch.save(projector, f"{args.model}/llava.projector")
 
-model = model.merge_and_unload()
+# remove these tensors from the checkpoint and save it again
+for name in mm_tensors:
+    del checkpoint[name]
 
-model.save_pretrained(save_path)
+torch.save(checkpoint, path)
 
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-tokenizer.save_pretrained(save_path)
-
-del model
-del tokenizer
-
-
-# Load the checkpoint
-checkpoint = torch.load(f"{save_path}/pytorch_model-00002-of-00002.bin")
-
-# Extract the tensors we want
-mm_projector_weight = checkpoint['model.mm_projector.weight']
-mm_projector_bias = checkpoint['model.mm_projector.bias']
-
-# Remove the tensors from the checkpoint
-del checkpoint['model.mm_projector.weight']
-del checkpoint['model.mm_projector.bias']
-
-# Create a dictionary with the original names as keys
-mm_projector = {
-    'model.mm_projector.weight': mm_projector_weight,
-    'model.mm_projector.bias': mm_projector_bias
-}
-
-# Save the combined dictionary using torch.save
-torch.save(mm_projector, "projector.pt")
-
-# Save the rest of the model with the same original name
-torch.save(checkpoint, "./llava-7b-rlhf-merged/pytorch_model-00002-of-00002.bin")
-
-Print("Operation complete!")
+print("Done!")
+print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
+print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")