Merge 5175117a09 into f09b7cb609
5 changed files with 226 additions and 46 deletions

@@ -103,6 +103,59 @@ python ./examples/convert-legacy-llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow

**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)

**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)

## Phi-3-Vision-128K-Instruct gguf conversion

1) Set up a working directory for PHI3V and PHI3 instruct and clone both models into it as `phi3-base` and `phi3-vision`. (It's easiest to cd into your local hf cache and copy the models from there.)

```console
mkdir phi3-fun
cd phi3-fun

git clone https://huggingface.co/microsoft/Phi-3-mini-128k-instruct phi3-base

git clone https://huggingface.co/microsoft/Phi-3-vision-128k-instruct phi3-vision
```

2) Use `llava-surgery-v2.py` to extract the CLIP part from PHI3V:

```console
python examples/llava/llava-surgery-v2.py -C -m phi3-fun/phi3-vision/
```

- you will find a llava.projector and a llava.clip file in your model directory (a quick way to inspect them is sketched below)
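
If you want a quick sanity check of what the surgery produced, the extracted projector is a plain PyTorch state dict and can be inspected directly. A minimal sketch, assuming `torch` is installed and the paths from step 1 were used:

```python
import torch

# llava.projector is written by llava-surgery-v2.py as a regular PyTorch checkpoint
proj = torch.load("phi3-fun/phi3-vision/llava.projector", map_location="cpu")

# after the implicit PHI3V conversion the keys should follow the "mm.*" / "model.image_newline" schema
for name, tensor in proj.items():
    print(name, tuple(tensor.shape), tensor.dtype)
```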

3) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory:

```console
# under the phi3-fun/phi3-vision dir
mkdir vit
cp llava.clip vit/pytorch_model.bin
cp llava.projector vit/
curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json
```

Set `mm_projector_type` to `mlp_phi` in `vit/config.json`; this is what selects the new Phi-3 projector when the image encoder is converted. A small sketch of that edit follows.
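
A minimal way to apply that edit programmatically (assuming the key sits at the top level of the downloaded `config_vit.json`, which is where the converter looks it up when `--clip-model-is-vision` is set):

```python
import json

cfg_path = "phi3-fun/phi3-vision/vit/config.json"

with open(cfg_path) as f:
    cfg = json.load(f)

# select the Phi-3 projector ("mlp_phi") instead of the default LLaVA MLP
cfg["mm_projector_type"] = "mlp_phi"

with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)
```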

4) Create the visual gguf model:

```console
python examples/llava/convert-image-encoder-to-gguf.py -m phi3-fun/phi3-vision/vit --llava-projector phi3-fun/phi3-vision/vit/llava.projector --output-dir phi3-fun/phi3-vision/vit --clip-model-is-vision
```

5) Extract the language-modelling part of PHI3V (everything except CLIP) and assign the weights to a normal PHI3 model:

```console
python examples/llava/phi3-weight-transfer.py --phi3-instruct-base-path phi3-fun/phi3-base --phi3v-base-path phi3-fun/phi3-vision
```

6) Convert this to a normal gguf. (First delete the old safetensors shards from this directory; the updated `model.safetensors.index.json` now points every tensor at the transferred weights file.)

```console
python convert-hf-to-gguf.py phi3-fun/phi3-base
```

7) Invoke:

```console
./llava-cli -m phi3-fun/phi3-base/ggml-model-f16.gguf --mmproj phi3-fun/phi3-vision/vit/mmproj-model-f16.gguf --image IMAGE -c 4096 --temp .1 -p "PROMPT"
```

## llava-cli templating and llava-1.6 prompting

llava-1.5 models all use the same vicuna prompt; here you can just add your image question like `-p "Provide a full description."`

@@ -137,3 +190,4 @@ Alternatively just pay notice to how many "tokens" have been used for your promp

- [x] Support non-CPU backend for the image encoding part.
- [ ] Support different sampling methods.
- [ ] Support more model variants.

@@ -130,12 +130,14 @@ enum projector_type {
    PROJECTOR_TYPE_LDP,
    PROJECTOR_TYPE_LDPV2,
    PROJECTOR_TYPE_UNKNOWN,
+   PROJECTOR_TYPE_MLP_PHI
};

static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
    { PROJECTOR_TYPE_MLP, "mlp" },
    { PROJECTOR_TYPE_LDP, "ldp" },
    { PROJECTOR_TYPE_LDPV2, "ldpv2"},
+   { PROJECTOR_TYPE_MLP_PHI, "mlp_phi" }
};

@@ -698,8 +700,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
    // ne is whcn, ne = [1024, 576, 1, 1]
    embeddings = ggml_get_rows(ctx0, embeddings, patches);

-   // print_tensor_info(embeddings, "embeddings");
-
    // llava projector
    if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);

@@ -709,7 +709,24 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
        embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
        embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);

-   } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+   } else if (ctx->proj_type == PROJECTOR_TYPE_MLP_PHI) {
+       // needs to be reworked, see https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
+       // line 204 onwards
+       struct ggml_tensor * embeddings_ = embeddings;
+       // [1024, 576, 1, 1] -> [4096, 576, 1, 1]
+       embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0);
+       embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0);
+       embeddings = ggml_concat(ctx0, embeddings, embeddings_, 0);
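+       // NOTE: the three self-concats simply repeat each 1024-d patch embedding four times along dim 0
+       // (4 x 1024 = 4096); the reference implementation linked above appears to build the 4096-d input
+       // from neighbouring patches instead, hence the rework note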
+
+       embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+       embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+       embeddings = ggml_gelu(ctx0, embeddings);
+       embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+       embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+
+   }
+   else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
        // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);

@@ -1208,7 +1225,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
        }

        // LLaVA projection
-       if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+       if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM || new_clip->proj_type == PROJECTOR_TYPE_MLP_PHI) {
            vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
            vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
            try {

@@ -2069,6 +2086,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
    if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
        return ctx->vision_model.mm_2_b->ne[0];
    }
+   if (ctx->proj_type == PROJECTOR_TYPE_MLP_PHI) {
+       return ctx->vision_model.mm_2_b->ne[0];
+   }
    if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
        return ctx->vision_model.mm_3_b->ne[0];
    }

@@ -86,7 +86,7 @@ ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,
                help="The clip model is from openclip (for ViT-SO400M type))")
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
-ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
+ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2", "mlp_phi"], default="mlp_phi")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5

@@ -206,39 +206,39 @@ if has_vision_encoder:
    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
    block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
    # /**
    #  "image_grid_pinpoints": [
    #    [336, 672],
    #    [672, 336],
    #    [672, 672],
    #    [1008, 336],
    #    [336, 1008]
    #  ],
    #  Flattened:
    #  [
    #    336, 672,
    #    672, 336,
    #    672, 672,
    #    1008, 336,
    #    336, 1008
    #  ]
    # */
    if "image_grid_pinpoints" in v_hparams:
        # flatten it
        image_grid_pinpoints = []

@@ -257,7 +257,6 @@ if has_vision_encoder:
    if "mm_projector_type" in v_hparams:
        fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])

-
    if processor is not None:
        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std

@@ -38,7 +38,9 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
    # file_type = 'pytorch'
    model_path = os.path.dirname(checkpoint_path)
    print(f"Searching for vision tower tensors in {checkpoint_path}")
-   clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
+   clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_embed_tokens.img_processor.vision_model") or \
+                   (k.startswith("model.vision_tower")) or \
+                   (k.startswith("vit.")))]

    if len(clip_tensors) > 0:
        print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")

@@ -83,10 +85,13 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):
    return newline_checkpoint_path, projector_checkpoint_path

def newline_criteria(checkpoint):
-   return any(k.startswith("model.image_newline") for k in checkpoint.keys())
+   return any(k.startswith("model.vision_embed_tokens.sub_GN") or \
+              k.startswith("model.image_newline") for k in checkpoint.keys())

def proj_criteria(checkpoint):
-   return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())
+   return any(k.startswith("model.vision_embed_tokens.img_projection") or \
+              k.startswith("vision_proj.") or \
+              k.startswith("model.mm_projector") for k in checkpoint.keys())


# Command-line interface setup

@@ -121,14 +126,16 @@ first_checkpoint = None
if newline_checkpoint_path is not None:
    print(f"Taking newline from {newline_checkpoint_path}")
    first_checkpoint, file_type = load_model(newline_checkpoint_path)
-   first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
+   first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.vision_embed_tokens.sub_GN") or k.startswith("model.image_newline")]

# Load the checkpoint
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
    last_checkpoint, file_type = load_model(projector_checkpoint_path)
-   mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
+   mm_tensors = [k for k, v in last_checkpoint.items() if (k.startswith("model.vision_embed_tokens.img_projection")) or \
+                 (k.startswith("vision_proj.")) or \
+                 (k.startswith("model.mm_projector"))]

if len(mm_tensors) == 0:
    if last_checkpoint is not None:

@@ -144,8 +151,28 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
projector = {}
for name in mm_tensors:
    projector[name] = last_checkpoint[name].float()
-for name in first_mm_tensors:
-    projector[name] = first_checkpoint[name].float()
+
+def rename_keys(d, prefix):
+    new_dict = {}
+    for key, value in d.items():
+        parts = key.split('.')
+        new_key = f"{prefix}.{parts[-2]}.{parts[-1]}"
+        new_dict[new_key] = value
+    return new_dict
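+# e.g. a key like "model.vision_embed_tokens.img_projection.0.weight" becomes "mm.0.weight",
+# i.e. the PHI3V projector keys are renamed to the LLaVA projector ("mm.*") schema used below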
+
+if list(projector.keys())[0].startswith("mm") is False:
+
+    print("-------------------------------")
+    print("PHI3V clip implicit conversion")
+    print("-------------------------------")
+
+    projector = rename_keys(projector, "mm")
+
+for name in first_mm_tensors:
+    projector["model.image_newline"] = first_checkpoint[name].float()[0, 0, 0, :]
+
+print("Updated projector keys to match LLAVA clip schema")
+print(projector)

if len(projector) > 0:
    save_model(projector, f"{args.model}/llava.projector", 'pytorch')


examples/llava/phi3-weight-transfer.py (new file, 80 lines)
@@ -0,0 +1,80 @@
import argparse
import json
import os

import torch
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM


def main(args):

    # https://stackoverflow.com/questions/67689219/copy-one-layers-weights-from-one-huggingface-bert-model-to-another

    phi3_vision = AutoModelForCausalLM.from_pretrained(args.phi3v_base_path,
                                                       device_map="auto",
                                                       trust_remote_code=True,
                                                       torch_dtype=torch.float16,
                                                       _attn_implementation='eager')

    print("PHI3 VISION LOADED IN MEMORY")

    phi3_base = AutoModelForCausalLM.from_pretrained(args.phi3_instruct_base_path,
                                                     device_map="auto",
                                                     trust_remote_code=True,
                                                     torch_dtype=torch.float16,
                                                     _attn_implementation='eager')

    print("PHI3 BASE LOADED IN MEMORY")

    phi3_vision_layers = dict(phi3_vision.named_parameters())
    phi3_base_layers = dict(phi3_base.named_parameters())

    parts = list(set(phi3_vision_layers.keys()) & set(phi3_base_layers.keys()))
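    # intersecting the two key sets keeps only the shared language-model parameters;
    # PHI3V-only tensors (the vision tower and projector) have no counterpart in the base model and are skipped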

    print("----------------------------------------------------")
    print("before transfer")
    print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]
          == dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"])
    print("----------------------------------------------------")

    for part in parts:
        phi3_base_layers[part].data.copy_(phi3_vision_layers[part].data)
        # target                          # source

    print("----------------------------------------------------")
    print("after transfer")
    print(dict(phi3_vision.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"]
          == dict(phi3_base.named_parameters())["model.layers.19.mlp.gate_up_proj.weight"])
    print("----------------------------------------------------")

    # save updated model weights
    outfile = "phi3-instruct-vision-weight-transfer.safetensors"
    outpath = os.path.join(args.phi3_instruct_base_path, outfile)
    save_file(phi3_base_layers, outpath)
    print(f"updated .safetensors saved to {outpath}")

    # update safetensors index config
    weight_index_path = os.path.join(args.phi3_instruct_base_path, "model.safetensors.index.json")

    with open(weight_index_path, "r") as f:
        index_data = json.load(f)

    # point every entry in the weight map at the single transferred shard
    for k, v in index_data["weight_map"].items():
        if v != "phi3-instruct-vision-weight-transfer.safetensors":
            index_data["weight_map"][k] = outfile

    with open(weight_index_path, "w") as f:
        json.dump(index_data, f)

    print("hf safetensors mapping updated!")


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="script to copy weights from PHI3V language model to PHI3-instruct")

    parser.add_argument("--phi3-instruct-base-path", type=str, default="microsoft/Phi-3-mini-128k-instruct", help="model path or model card for PHI3-instruct")
    parser.add_argument("--phi3v-base-path", type=str, default="microsoft/Phi-3-vision-128k-instruct", help="model path or model card for PHI3V")

    main(parser.parse_args())