From efeaeaf79fe855312b18f68c3760d727d42c9bbf Mon Sep 17 00:00:00 2001
From: farris
Date: Sun, 2 Jun 2024 18:09:36 -0700
Subject: [PATCH 1/2] publish branch

---
 examples/llava/README.md               | 54 ++++++++++++++++++
 examples/llava/llava-surgery-v2.py     | 41 ++++++++++---
 examples/llava/phi3-weight-transfer.py | 79 ++++++++++++++++++++++++++
 ggml-metal.m                           |  8 +--
 ggml.c                                 |  4 +-
 5 files changed, 173 insertions(+), 13 deletions(-)
 create mode 100644 examples/llava/phi3-weight-transfer.py

diff --git a/examples/llava/README.md b/examples/llava/README.md
index 8d1ae5270..645dc8c3e 100644
--- a/examples/llava/README.md
+++ b/examples/llava/README.md
@@ -103,6 +103,59 @@ python ./examples/convert-legacy-llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow
 **note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)
 **note** llava-1.6 greatly benefits from batched prompt processing (defaults work)
 
+## Phi-3-Vision-128K-Instruct gguf conversion
+1) Set up a working directory for PHI3V and the PHI3-instruct base model and clone both into it (alternatively, copy the snapshots from your local Hugging Face cache into these directories)
+
+```console
+mkdir phi3-fun
+cd phi3-fun
+
+mkdir phi3-base
+git clone https://huggingface.co/microsoft/Phi-3-mini-128k-instruct phi3-base
+
+mkdir phi3-vision
+git clone https://huggingface.co/microsoft/Phi-3-vision-128k-instruct phi3-vision
+
+```
+
+2) Use `llava-surgery-v2.py` to extract the CLIP model and the multimodal projector from PHI3V:
+```console
+python examples/llava/llava-surgery-v2.py -C -m phi3-fun/phi3-vision/
+```
+- You will find a `llava.projector` and a `llava.clip` file in your model directory
+
+4) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory:
+```console
+// under phi3-fun/phi-vision dir
+mkdir vit
+cp llava.clip vit/pytorch_model.bin
+cp llava.projector vit/
+curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json
+```
+
+5) Create the visual gguf model:
+```console
+python examples/llava/convert-image-encoder-to-gguf.py -m phi3-fun/phi3-vision/vit --llava-projector phi3-fun/phi3-vision/vit/llava.projector --output-dir phi3-fun/phi3-vision/vit --clip-model-is-vision
+```
+
+6) Extract the language-modelling part of PHI3V (everything except CLIP) and copy its weights onto a plain PHI3-instruct model
+
+```console
+python examples/llava/phi3-weight-transfer.py --phi3-instruct-base-path phi3-fun/phi3-base --phi3v-base-path phi3-fun/phi3-vision
+```
+
+7) Convert the result to a normal gguf
+(First delete the original safetensors shards from `phi3-fun/phi3-base` so that only the transferred weights are picked up)
+```console
+python convert-hf-to-gguf.py phi3-fun/phi3-base
+```
+
+8) Invoke
+(recompile llama.cpp first)
+```console
+./llava-cli -m phi3-fun/phi3-base/ggml-model-f16.gguf --mmproj phi3-fun/phi3-vision/vit/mmproj-model-f16.gguf --image IMAGE -c 4096 --temp .1 -p "PROMPT"
+```
+
 ## llava-cli templating and llava-1.6 prompting
 
 llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."`
@@ -137,3 +190,4 @@ Alternatively just pay notice to how many "tokens" have been used for your promp
 - [x] Support non-CPU backend for the image encoding part.
 - [ ] Support different sampling methods.
 - [ ] Support more model variants.
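A quick way to sanity-check step 2 of the Phi-3-Vision walkthrough above is to list what `llava-surgery-v2.py` wrote into `llava.projector`. A minimal sketch, assuming the `phi3-fun/` layout from step 1 and that the file is a plain PyTorch-saved dict of tensors (which is what the surgery script produces):

```python
import torch

# llava.projector is written with torch.save() by llava-surgery-v2.py,
# so it loads back as a plain {name: tensor} dict.
proj = torch.load("phi3-fun/phi3-vision/llava.projector", map_location="cpu")

for name, tensor in proj.items():
    print(f"{name:25s} {tuple(tensor.shape)}")
```

After the implicit PHI3V conversion in the surgery script, the keys should follow the LLaVA `mm.N.weight` / `mm.N.bias` schema plus a 1-D `model.image_newline`; if they still carry `model.vision_embed_tokens.*` prefixes, the conversion branch did not trigger.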
+ diff --git a/examples/llava/llava-surgery-v2.py b/examples/llava/llava-surgery-v2.py index eb56d6988..9bcbb02ed 100644 --- a/examples/llava/llava-surgery-v2.py +++ b/examples/llava/llava-surgery-v2.py @@ -38,7 +38,9 @@ def clean_vision_tower_from_checkpoint(checkpoint_path): # file_type = 'pytorch' model_path = os.path.dirname(checkpoint_path) print(f"Searching for vision tower tensors in {checkpoint_path}") - clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))] + clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_embed_tokens.img_processor.vision_model") or \ + (k.startswith("model.vision_tower")) or \ + (k.startswith("vit.")))] if len(clip_tensors) > 0: print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}") @@ -83,10 +85,13 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): return newline_checkpoint_path, projector_checkpoint_path def newline_criteria(checkpoint): - return any(k.startswith("model.image_newline") for k in checkpoint.keys()) + return any(k.startswith("model.vision_embed_tokens.sub_GN") or \ + k.startswith("model.image_newline") for k in checkpoint.keys()) def proj_criteria(checkpoint): - return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys()) + return any(k.startswith("model.vision_embed_tokens.img_projection") or \ + k.startswith("vision_proj.") or \ + k.startswith("model.mm_projector") for k in checkpoint.keys()) # Command-line interface setup @@ -121,14 +126,16 @@ first_checkpoint = None if newline_checkpoint_path is not None: print(f"Taking newline from {newline_checkpoint_path}") first_checkpoint, file_type = load_model(newline_checkpoint_path) - first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")] + first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.vision_embed_tokens.sub_GN") or k.startswith("model.image_newline")] # Load the checkpoint mm_tensors = [] last_checkpoint = None if projector_checkpoint_path is not None: last_checkpoint, file_type = load_model(projector_checkpoint_path) - mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")] + mm_tensors = [k for k, v in last_checkpoint.items() if (k.startswith("model.vision_embed_tokens.img_projection")) or \ + (k.startswith("vision_proj.")) or \ + (k.startswith("model.mm_projector"))] if len(mm_tensors) == 0: if last_checkpoint is not None: @@ -144,8 +151,28 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.") projector = {} for name in mm_tensors: projector[name] = last_checkpoint[name].float() -for name in first_mm_tensors: - projector[name] = first_checkpoint[name].float() + +def rename_keys(d, prefix): + new_dict = {} + for key, value in d.items(): + parts = key.split('.') + new_key = f"{prefix}.{parts[-2]}.{parts[-1]}" + new_dict[new_key] = value + return new_dict + +if list(projector.keys())[0].startswith("mm") is False: + + print("-------------------------------") + print("PHI3V clip implicit conversion") + print("-------------------------------") + + projector = rename_keys(projector, "mm") + + for name in first_mm_tensors: + projector["model.image_newline"] = first_checkpoint[name].float()[0, 0, 0, :] + + print("Updated projector keys to match LLAVA clip schema") + print(projector) if len(projector) > 0: save_model(projector, 
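# A concrete example of the implicit conversion performed above (key names as matched by the
# startswith() filters earlier in this script; the "mm" prefix is the LLaVA projector schema
# that convert-image-encoder-to-gguf.py expects, and the 0/2 indices correspond to the
# mm_0/mm_2 tensors clip.cpp loads):
#   model.vision_embed_tokens.img_projection.0.weight -> mm.0.weight
#   model.vision_embed_tokens.img_projection.0.bias   -> mm.0.bias
#   model.vision_embed_tokens.img_projection.2.weight -> mm.2.weight
#   model.vision_embed_tokens.img_projection.2.bias   -> mm.2.bias
# while model.vision_embed_tokens.sub_GN is sliced ([0, 0, 0, :]) and stored as
# model.image_newline.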
f"{args.model}/llava.projector", 'pytorch') diff --git a/examples/llava/phi3-weight-transfer.py b/examples/llava/phi3-weight-transfer.py new file mode 100644 index 000000000..18737934c --- /dev/null +++ b/examples/llava/phi3-weight-transfer.py @@ -0,0 +1,79 @@ +import argparse +import json +import os + +import torch +from safetensors.torch import save_file +from transformers import AutoModelForCausalLM + + +def main(args): + + # https://stackoverflow.com/questions/67689219/copy-one-layers-weights-from-one-huggingface-bert-model-to-another + + phi3_vision = AutoModelForCausalLM.from_pretrained(args.phi3v_base_path,\ + device_map="auto",\ + trust_remote_code=True,\ + torch_dtype=torch.float16,\ + _attn_implementation='eager') + + print("PHI3 VISION LOADED IN MEMORY") + + phi3_base = AutoModelForCausalLM.from_pretrained(args.phi3_instruct_base_path,\ + device_map="auto",\ + trust_remote_code=True,\ + torch_dtype=torch.float16,\ + _attn_implementation='eager') + + print("PHI3 BASE LOADED IN MEMORY") + + phi3_vision_layers = dict(phi3_vision.named_parameters()) + phi3_base_layers = dict(phi3_base.named_parameters()) + + parts = list(set(phi3_vision_layers.keys()) & set(phi3_base_layers.keys())) + + print("----------------------------------------------------") + print("before transfer") + print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"] == \ + dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]) + print("----------------------------------------------------") + + for part in parts: + phi3_base_layers[part].data.copy_(phi3_vision_layers[part].data) + # target # source + + print("----------------------------------------------------") + print("after transfer") + print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"] == \ + dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]) + print("----------------------------------------------------") + + # save updated model weights + outfile = "phi3-instruct-vision-weight-transfer.safetensors" + outpath = os.path.join(args.phi3_instruct_base_path, outfile) + save_file(phi3_base_layers, outpath) + print(f"updates .safetensors saved to {outpath}") + + # update safetensors index config + weight_index_path = os.path.join(args.phi3_instruct_base_path, "model.safetensors.index.json") + + with open(weight_index_path, "r") as f: + index_data = json.load(f) + + for k,v in index_data["weight_map"].items(): + if v != "phi3-instruct-vision-weight-transfer.safetensors": + index_data["weight_map"][k] = outfile + + with open(weight_index_path, "w") as f: + json.dump(index_data, f) + + print(f"hf saftensor mapping updated!") + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description="script to copy weights from PHI3V language model to PHI3-instruct") + + parser.add_argument("--phi3-instruct-base-path", type=str, default="microsoft/Phi-3-mini-128k-instruct", help="model path or model card for PHI3-instruct") + parser.add_argument("--phi3v-base-path", type=str, default="microsoft/Phi-3-vision-128k-instruct", help="model path or model card for PHI3V") + + main(parser.parse_args()) diff --git a/ggml-metal.m b/ggml-metal.m index fddc44f78..74b53c4e4 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -779,7 +779,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_LEAKY_RELU: return true; case GGML_OP_FLASH_ATTN_EXT: - if (op->src[1]->type != GGML_TYPE_F16) { +if (op->src[1]->type != GGML_TYPE_F16) { 
return false; } if (op->src[2]->type != GGML_TYPE_F16) { @@ -1523,10 +1523,10 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_MUL_MAT: { - GGML_ASSERT(ne00 == ne10); + // GGML_ASSERT(ne00 == ne10); - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); + // GGML_ASSERT(ne12 % ne02 == 0); + // GGML_ASSERT(ne13 % ne03 == 0); const uint r2 = ne12/ne02; const uint r3 = ne13/ne03; diff --git a/ggml.c b/ggml.c index f479dc3e1..a255061d5 100644 --- a/ggml.c +++ b/ggml.c @@ -5290,8 +5290,8 @@ struct ggml_tensor * ggml_mul_mat( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - GGML_ASSERT(ggml_can_mul_mat(a, b)); - GGML_ASSERT(!ggml_is_transposed(a)); + // GGML_ASSERT(ggml_can_mul_mat(a, b)); + // GGML_ASSERT(!ggml_is_transposed(a)); bool is_node = false; From 5175117a0944a55299e22a327da97b0d20c29a80 Mon Sep 17 00:00:00 2001 From: farris Date: Wed, 5 Jun 2024 11:55:41 -0700 Subject: [PATCH 2/2] add phi3v projection handling in clip.cpp --- examples/llava/README.md | 4 +- examples/llava/clip.cpp | 28 ++++++-- .../llava/convert-image-encoder-to-gguf.py | 69 +++++++++---------- examples/llava/phi3-weight-transfer.py | 37 +++++----- ggml-metal.m | 8 +-- ggml.c | 4 +- 6 files changed, 85 insertions(+), 65 deletions(-) diff --git a/examples/llava/README.md b/examples/llava/README.md index 645dc8c3e..f4820b9b9 100644 --- a/examples/llava/README.md +++ b/examples/llava/README.md @@ -127,11 +127,12 @@ python examples/llava/llava-surgery-v2.py -C -m phi3-fun/phi3-vision/ 4) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory: ```console // under phi3-fun/phi-vision dir -mkdir vit +mkdir vit cp llava.clip vit/pytorch_model.bin cp llava.projector vit/ curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json ``` +set `mm_projector_type` -> `mlp_phi` in `config.json` 5) Create the visual gguf model: ```console @@ -151,7 +152,6 @@ python convert-hf-to-gguf.py phi3-fun/phi3-base ``` 8) Invoke -(recompile llama.cpp first) ```console ./llava-cli -m phi3-fun/phi3-base/ggml-model-f16.gguf --mmproj phi3-fun/phi3-vision/vit/mmproj-model-f16.gguf --image IMAGE -c 4096 --temp .1 -p "PROMPT" ``` diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 95fbe3d02..3fdfe45b9 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -130,12 +130,14 @@ enum projector_type { PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDPV2, PROJECTOR_TYPE_UNKNOWN, + PROJECTOR_TYPE_MLP_PHI }; static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_MLP, "mlp" }, { PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDPV2, "ldpv2"}, + { PROJECTOR_TYPE_MLP_PHI, "mlp_phi" } }; @@ -698,8 +700,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // ne is whcn, ne = [1024, 576, 1, 1] embeddings = ggml_get_rows(ctx0, embeddings, patches); - // print_tensor_info(embeddings, "embeddings"); - // llava projector if (ctx->proj_type == PROJECTOR_TYPE_MLP) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); @@ -709,7 +709,24 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); - } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { + } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_PHI) { + // needs to be reworked, see 
https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py + // line 204 onwards + struct ggml_tensor * embeddings_ = embeddings; + // [1024, 576, 1, 1] -> [4096, 576, 1, 1] + embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0); + embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0); + embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0); + + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); + + embeddings = ggml_gelu(ctx0, embeddings); + embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); + + } + else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); @@ -1208,7 +1225,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } // LLaVA projection - if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) { + if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM || new_clip->proj_type == PROJECTOR_TYPE_MLP_PHI) { vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); try { @@ -2069,6 +2086,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { if (ctx->proj_type == PROJECTOR_TYPE_MLP) { return ctx->vision_model.mm_2_b->ne[0]; } + if (ctx->proj_type == PROJECTOR_TYPE_MLP_PHI) { + return ctx->vision_model.mm_2_b->ne[0]; + } if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { return ctx->vision_model.mm_3_b->ne[0]; } diff --git a/examples/llava/convert-image-encoder-to-gguf.py b/examples/llava/convert-image-encoder-to-gguf.py index b00bf7c6d..766b65e78 100644 --- a/examples/llava/convert-image-encoder-to-gguf.py +++ b/examples/llava/convert-image-encoder-to-gguf.py @@ -86,7 +86,7 @@ ap.add_argument("--clip-model-is-vision", action="store_true", required=False, ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, help="The clip model is from openclip (for ViT-SO400M type))") ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") -ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp") +ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2", "mlp_phi"], default="mlp_phi") ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. 
Default is the original model directory", default=None) # Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711 # Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5 @@ -206,39 +206,39 @@ if has_vision_encoder: fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"]) block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) - # /** - # "image_grid_pinpoints": [ - # [ - # 336, - # 672 - # ], - # [ - # 672, - # 336 - # ], - # [ - # 672, - # 672 - # ], - # [ - # 1008, - # 336 - # ], - # [ - # 336, - # 1008 - # ] - # ], - # Flattened: - # [ - # 336, 672, - # 672, 336, - # 672, 672, - # 1008, 336, - # 336, 1008 - # ] - # * - # */ + # /** + # "image_grid_pinpoints": [ + # [ + # 336, + # 672 + # ], + # [ + # 672, + # 336 + # ], + # [ + # 672, + # 672 + # ], + # [ + # 1008, + # 336 + # ], + # [ + # 336, + # 1008 + # ] + # ], + # Flattened: + # [ + # 336, 672, + # 672, 336, + # 672, 672, + # 1008, 336, + # 336, 1008 + # ] + # * + # */ if "image_grid_pinpoints" in v_hparams: # flatten it image_grid_pinpoints = [] @@ -257,7 +257,6 @@ if has_vision_encoder: if "mm_projector_type" in v_hparams: fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"]) - if processor is not None: image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std diff --git a/examples/llava/phi3-weight-transfer.py b/examples/llava/phi3-weight-transfer.py index 18737934c..136639d83 100644 --- a/examples/llava/phi3-weight-transfer.py +++ b/examples/llava/phi3-weight-transfer.py @@ -11,19 +11,19 @@ def main(args): # https://stackoverflow.com/questions/67689219/copy-one-layers-weights-from-one-huggingface-bert-model-to-another - phi3_vision = AutoModelForCausalLM.from_pretrained(args.phi3v_base_path,\ - device_map="auto",\ - trust_remote_code=True,\ - torch_dtype=torch.float16,\ - _attn_implementation='eager') + phi3_vision = AutoModelForCausalLM.from_pretrained(args.phi3v_base_path, + device_map="auto", + trust_remote_code=True, + torch_dtype=torch.float16, + _attn_implementation='eager') print("PHI3 VISION LOADED IN MEMORY") - phi3_base = AutoModelForCausalLM.from_pretrained(args.phi3_instruct_base_path,\ - device_map="auto",\ - trust_remote_code=True,\ - torch_dtype=torch.float16,\ - _attn_implementation='eager') + phi3_base = AutoModelForCausalLM.from_pretrained(args.phi3_instruct_base_path, + device_map="auto", + trust_remote_code=True, + torch_dtype=torch.float16, + _attn_implementation='eager') print("PHI3 BASE LOADED IN MEMORY") @@ -34,21 +34,21 @@ def main(args): print("----------------------------------------------------") print("before transfer") - print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"] == \ - dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]) + print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"] + == dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]) print("----------------------------------------------------") for part in parts: - phi3_base_layers[part].data.copy_(phi3_vision_layers[part].data) + phi3_base_layers[part].data.copy_(phi3_vision_layers[part].data) # target # 
source print("----------------------------------------------------") print("after transfer") - print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"] == \ - dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]) + print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"] + == dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]) print("----------------------------------------------------") - # save updated model weights + # save updated model weights outfile = "phi3-instruct-vision-weight-transfer.safetensors" outpath = os.path.join(args.phi3_instruct_base_path, outfile) save_file(phi3_base_layers, outpath) @@ -59,7 +59,7 @@ def main(args): with open(weight_index_path, "r") as f: index_data = json.load(f) - + for k,v in index_data["weight_map"].items(): if v != "phi3-instruct-vision-weight-transfer.safetensors": index_data["weight_map"][k] = outfile @@ -69,8 +69,9 @@ def main(args): print(f"hf saftensor mapping updated!") + if __name__ == '__main__': - + parser = argparse.ArgumentParser(description="script to copy weights from PHI3V language model to PHI3-instruct") parser.add_argument("--phi3-instruct-base-path", type=str, default="microsoft/Phi-3-mini-128k-instruct", help="model path or model card for PHI3-instruct") diff --git a/ggml-metal.m b/ggml-metal.m index 74b53c4e4..fddc44f78 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -779,7 +779,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_LEAKY_RELU: return true; case GGML_OP_FLASH_ATTN_EXT: -if (op->src[1]->type != GGML_TYPE_F16) { + if (op->src[1]->type != GGML_TYPE_F16) { return false; } if (op->src[2]->type != GGML_TYPE_F16) { @@ -1523,10 +1523,10 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_MUL_MAT: { - // GGML_ASSERT(ne00 == ne10); + GGML_ASSERT(ne00 == ne10); - // GGML_ASSERT(ne12 % ne02 == 0); - // GGML_ASSERT(ne13 % ne03 == 0); + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); const uint r2 = ne12/ne02; const uint r3 = ne13/ne03; diff --git a/ggml.c b/ggml.c index a255061d5..f479dc3e1 100644 --- a/ggml.c +++ b/ggml.c @@ -5290,8 +5290,8 @@ struct ggml_tensor * ggml_mul_mat( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - // GGML_ASSERT(ggml_can_mul_mat(a, b)); - // GGML_ASSERT(!ggml_is_transposed(a)); + GGML_ASSERT(ggml_can_mul_mat(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); bool is_node = false;
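
For reference, the `PROJECTOR_TYPE_MLP_PHI` graph added to clip.cpp above widens the 1024-d CLIP patch features to the projector's 4096-d input by concatenating each embedding with itself four times, and the accompanying comment flags this as a stand-in for the HD transform in the upstream `image_embedding_phi3_v.py`. A rough NumPy sketch of the difference, assuming a single 336-px crop (a 24×24 grid of 1024-d patch features) and my reading of the upstream reference; it ignores the multi-crop / global-image handling and the `sub_GN`/`glb_GN` newline tokens:

```python
import numpy as np

# 576 CLIP patch embeddings (a 24x24 grid), 1024 features each. clip.cpp sees this as a
# [1024, 576, 1, 1] tensor after ggml_get_rows; here it is kept as [patches, features].
feat = np.random.rand(576, 1024).astype(np.float32)

# What the patch currently does: repeat the same 1024-d vector four times -> (576, 4096).
placeholder = np.concatenate([feat, feat, feat, feat], axis=-1)

# What the upstream HD transform appears to do: concatenate the four *distinct* vectors of
# each 2x2 patch neighbourhood, halving the grid in both directions -> (144, 4096).
grid = feat.reshape(24, 24, 1024)
merged = (grid.reshape(12, 2, 12, 2, 1024)   # split the grid into 2x2 blocks
              .transpose(0, 2, 1, 3, 4)      # -> (12, 12, 2, 2, 1024)
              .reshape(12 * 12, 4 * 1024))   # flatten each block's four feature vectors

print(placeholder.shape, merged.shape)  # (576, 4096) (144, 4096)
```

The placeholder keeps all 576 tokens but feeds the projector duplicated features, whereas the reference merge yields 144 tokens that each combine four neighbouring patches, which is presumably what the "needs to be reworked" comment refers to.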