Merge 262000fa4d into d7b31a9d84
commit 72ce3e5833
6 changed files with 209 additions and 37 deletions

@@ -101,8 +101,27 @@ python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow
```

**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)

**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)

**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way to handle the LLM model conversion is to load the model in transformers and export only the LLM from the llava next model.

```python
import os
import transformers

model_path = ...
llm_export_path = ...

tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
model = transformers.AutoModelForImageTextToText.from_pretrained(model_path)

tokenizer.save_pretrained(llm_export_path)
model.language_model.save_pretrained(llm_export_path)
```

Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures.
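For example, assuming the LLM was exported to `llm_export_path` as in the snippet above, a conversion run might look something like the following (the input path and output file name are placeholders, not part of this change):

```sh
# hypothetical paths/names; point the script at the exported LLM directory
python ./convert_hf_to_gguf.py /path/to/llm_export_path --outfile llm-f16.gguf --outtype f16
```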

## llava-cli templating and llava-1.6 prompting

llava-1.5 models all use the same vicuna prompt; here you can just add your image question like `-p "Provide a full description."`
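As a rough sketch of such an invocation (model, projector, and image file names below are placeholders, and the binary name may differ depending on your build, e.g. `llama-llava-cli`):

```sh
# hypothetical file names; -c 4096 follows the llava-1.6 context-size note above
./llama-llava-cli -m ggml-model-f16.gguf --mmproj mmproj-model-f16.gguf \
    --image some-image.jpg -c 4096 -p "Provide a full description."
```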

@@ -120,6 +120,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_VISION_FEATURE_LAYER "clip.vision.feature_layer"

#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"

@@ -444,8 +445,9 @@ struct clip_hparams {

char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)

int32_t image_grid_pinpoints[32];
int32_t image_grid_pinpoints[64];
int32_t image_crop_resolution;
int32_t vision_feature_layer[4];
};

struct clip_layer {

@@ -752,13 +754,36 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
}

// loop over layers
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
// Check to see we have 1+ set vision feature layers set; otherwise it's the last layer
std::vector<struct ggml_tensor *> embedding_stack;
bool has_feature_layers = ctx->vision_model.hparams.vision_feature_layer[0] > 0;
// Determine how many encoder layers we need to process; if we have explicit vision feature
// layers, only process what we need, otherwise process all of the visual encoder layers.
int max_feature_layer = -1;
if(has_feature_layers) {
for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) {
if(ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx] > max_feature_layer) {
max_feature_layer = ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx];
}
}
}
for (int il = 0; il < n_layer - 1; il++) {
if(max_feature_layer < 0) {
max_feature_layer = n_layer;
}

// loop over layers
for (int il = 0; il < max_feature_layer; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states

// If this is an embedding feature layer, save the output.
// NOTE: 0 index here refers to the input to the encoder.
for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) {
if (il == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) {
embedding_stack.push_back(embeddings);
break;
}
}

//const size_t nb_q_w = model.layers[il].q_w->nb[0];

// layernorm1

@@ -846,17 +871,32 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
cur = ggml_add(ctx0, embeddings, cur);

embeddings = cur;

}

// post-layernorm
if (ctx->has_post_norm) {
if (ctx->has_post_norm && max_feature_layer == n_layer) {
embeddings = ggml_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "post_ln");

embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
}

// final layer is a vision feature layer
for(int vf_layer_idx = 0; vf_layer_idx < 4; vf_layer_idx++) {
if (n_layer == ctx->vision_model.hparams.vision_feature_layer[vf_layer_idx]) {
embedding_stack.push_back(embeddings);
break;
}
}

// If feature layers are explicitly set, stack them (if we have multiple)
if(has_feature_layers && embedding_stack.size() > 0) {
embeddings = embedding_stack.at(0);
for(unsigned long i=1; i < embedding_stack.size(); i++) {
embeddings = ggml_concat(ctx0, embeddings, embedding_stack.at(i), 0);
}
}

// llava projector
if (ctx->has_llava_projector) {
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);

@@ -1443,15 +1483,35 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
int n = gguf_get_arr_n(ctx, idx);
const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx);
for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) {
for (int i = 0; i < 64 && i < n && pinpoints[i] != 0; ++i) {
hparams.image_grid_pinpoints[i] = pinpoints[i];
}
if (n < 32)
if (n < 64)
hparams.image_grid_pinpoints[n] = 0;
} catch (std::runtime_error & /*e*/) {
hparams.image_grid_pinpoints[0]=0;
}

// Load the vision feature layer indices if they are explicitly provided;
// if multiple vision feature layers are present, the values will be concatenated
// to form the final visual features.
// NOTE: gguf conversions should standardize the values of the vision feature layer to uints,
// since we use -1 as an unset value here.
try {
int idx = get_key_idx(ctx, KEY_VISION_FEATURE_LAYER);
int n = gguf_get_arr_n(ctx, idx);

const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx);

for (int i = 0; i < 4 && i < n && vision_feature_layer[i] != 0; ++i) {
hparams.vision_feature_layer[i] = vision_feature_layer[i];
}
if (n < 4)
hparams.vision_feature_layer[n] = -1;
} catch (std::runtime_error & /*e*/) {
hparams.vision_feature_layer[0] = -1;
}

try {
int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);
strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx));

@@ -1489,10 +1549,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]);
LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]);
LOG_INF("v_image_grid_pinpoints: ");
for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
for (int i = 0; i < 64 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
LOG_INF("%d ", hparams.image_grid_pinpoints[i]);
}
LOG_INF("\n");
LOG_INF("v_vision_feature_layer: ");
for(int i = 0; i < 4 && (hparams.vision_feature_layer[i] > 0); i++) {
LOG_INF("%d ", hparams.vision_feature_layer[i]);
}
LOG_INF("\n");
LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type);

}

@@ -2238,7 +2303,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
if (params.image_grid_pinpoints[0] != 0) {
// "spatial_unpad" with "anyres" processing for llava-1.6
std::vector<std::pair<int, int>> possible_resolutions;
for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
for (int i = 0; i < 64 && params.image_grid_pinpoints[i] != 0; i+=2) {
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
}
std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);

@@ -6,7 +6,7 @@ import re
import torch
import numpy as np
from gguf import *
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, SiglipVisionModel

TEXT = "clip.text"
VISION = "clip.vision"

@@ -37,6 +37,18 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b


def get_tensor_name(name: str) -> str:
# Standardize the transformers llava next keys for
# image newline / mm projector with the classes in haotian-liu LLaVA
if name == "image_newline":
return "model.image_newline"
if name.startswith("multi_modal_projector"):
name = name.replace("multi_modal_projector", "mm")
if "linear_1" in name:
name = name.replace("linear_1", "0")
if "linear_2" in name:
name = name.replace("linear_2", "2")
return name

if "projection" in name:
return name
if "mm_projector" in name:

@@ -83,8 +95,14 @@ ap.add_argument("--vision-only", action="store_true", required=False,
help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--clip-model-is-vision", action="store_true", required=False,
help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
ap.add_argument("--clip-model-is-openclip", action="store_true", required=False,

# Selectable visual encoders that are compatible with this script
encoder_group = ap.add_mutually_exclusive_group()
encoder_group.add_argument("--clip-model-is-openclip", action="store_true", required=False,
help="The clip model is from openclip (for ViT-SO400M type))")
encoder_group.add_argument("--clip-model-is-siglip", action="store_true", required=False,
help="the visual encoder is Siglip.")

ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)

@@ -109,7 +127,12 @@ if args.use_f32:
# output in the same directory as the model if output_dir is None
dir_model = args.model_dir

if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
if (
args.clip_model_is_vision or
not os.path.exists(dir_model + "/vocab.json") or
args.clip_model_is_openclip or
args.clip_model_is_siglip
):
vocab = None
tokens = None
else:

@@ -137,7 +160,10 @@ ftype = 1
if args.use_f32:
ftype = 0

if args.clip_model_is_vision or args.clip_model_is_openclip:
if args.clip_model_is_siglip:
model = SiglipVisionModel.from_pretrained(dir_model)
processor = None
elif args.clip_model_is_vision or args.clip_model_is_openclip:
model = CLIPVisionModel.from_pretrained(dir_model)
processor = None
else:

@@ -187,26 +213,67 @@ else:
if has_text_encoder:
assert t_hparams is not None
assert tokens is not None
if args.clip_model_is_siglip:
text_projection_dim = 0
else:
text_projection_dim = t_hparams.get("projection_dim", config["projection_dim"])
# text_model hparams
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"])
fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"]))
fout.add_uint32("clip.text.projection_dim", text_projection_dim)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"])
fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"])
fout.add_token_list(tokens)


def get_unsigned_vision_feature_layers(v_hparams):
"""
Determine the vision feature layer(s) for the llava model, which are indices into the
hidden states of the visual encoder. Note that the hidden states array generally takes the
form:

[<emb input>, <output of enc block 0>, ... <output of enc block num_hidden_layers>]

so positive feature indices should be offset as n+1 to get the output of encoder block n.
We convert all vision feature layers to unsigned ints so that -1 can be used in the model
as an unset value. If no vision feature layer is found, we leave it unset.
"""
num_hidden_layers = v_hparams["num_hidden_layers"]
to_uint = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
feature_layers_key = None
# Key used for llava models in transformers
if "vision_feature_layer" in config:
feature_layers_key = "vision_feature_layer"
# Key used for llava models in the original format
elif "mm_vision_select_layer" in config:
feature_layers_key = "mm_vision_select_layer"
if feature_layers_key is not None:
feature_layers = config[feature_layers_key]
if isinstance(feature_layers, int):
feature_layers = [feature_layers]
return [to_uint(feature_layer) for feature_layer in feature_layers]

if has_vision_encoder:
# vision_model hparams
feature_layers = get_unsigned_vision_feature_layers(v_hparams)

# Siglip does not have a visual projector; set projection dim to 0
if args.clip_model_is_siglip:
visual_projection_dim = 0
else:
visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"])

# set vision_model hparams
fout.add_uint32("clip.vision.image_size", v_hparams["image_size"])
fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"])
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"])
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"])
fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"]))
fout.add_uint32("clip.vision.projection_dim", visual_projection_dim)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
block_count = v_hparams["num_hidden_layers"]
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
# /**
# "image_grid_pinpoints": [

@@ -258,7 +325,8 @@ if has_vision_encoder:
fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"])
if "mm_projector_type" in v_hparams:
fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"])

if feature_layers:
fout.add_array("clip.vision.feature_layer", feature_layers)

if processor is not None:
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue]

@@ -274,7 +342,6 @@ fout.add_bool("clip.use_gelu", use_gelu)


if has_llava_projector:
model.vision_model.encoder.layers.pop(-1)
projector = torch.load(args.llava_projector)
for name, data in projector.items():
name = get_tensor_name(name)

@@ -355,7 +355,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
const int32_t * image_grid = clip_image_grid(ctx_clip);

std::vector<std::pair<int, int>> grid_pinpoints;
for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) {
for (int i = 0; i < 64 && image_grid[i] != 0; i += 2) {
grid_pinpoints.push_back({image_grid[i], image_grid[i+1]});
}

@@ -405,10 +405,8 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx *
}

bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
int num_max_patches = 6;
if (clip_is_minicpmv(ctx_clip)) {
num_max_patches = 10;
}
// Minicpmv / granite vision use 10 patches
int num_max_patches = 10;
if (clip_is_glm(ctx_clip)) {
num_max_patches = 1;
}

@@ -33,6 +33,33 @@ def save_model(model, file_path, file_type):
else:
torch.save(model, file_path)

# Helpers to match weight names from specific components or
# determine if a saved shard contains that component
def is_vision_tower(weight_name):
return (
weight_name.startswith("model.vision_tower") or
weight_name.startswith("vit.") or
weight_name.startswith("vision_tower")
)

def is_newline(weight_name):
return (
weight_name.startswith("model.image_newline") or
weight_name.startswith("image_newline")
)

def is_mm_projector(weight_name):
return (
weight_name.startswith("model.mm_projector") or
weight_name.startswith("vision_proj.") or
weight_name.startswith("multi_modal_projector")
)

def newline_criteria(checkpoint):
return any(is_newline(k) for k in checkpoint.keys())

def proj_criteria(checkpoint):
return any(is_mm_projector(k) for k in checkpoint.keys())

# Adapted function to clean vision tower from checkpoint
def clean_vision_tower_from_checkpoint(checkpoint_path):

@@ -40,7 +67,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path):
# file_type = 'pytorch'
model_path = os.path.dirname(checkpoint_path)
print(f"Searching for vision tower tensors in {checkpoint_path}")
clip_tensors = [k for k, v in checkpoint.items() if (k.startswith("model.vision_tower") or k.startswith("vit."))]
clip_tensors = [k for k, v in checkpoint.items() if is_vision_tower(k)]

if len(clip_tensors) > 0:
print(f"Found {len(clip_tensors)} tensors to extract from {checkpoint_path}")

@@ -84,12 +111,6 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector):

return newline_checkpoint_path, projector_checkpoint_path

def newline_criteria(checkpoint):
return any(k.startswith("model.image_newline") for k in checkpoint.keys())

def proj_criteria(checkpoint):
return any(k.startswith("model.mm_projector") or k.startswith("vision_proj.") for k in checkpoint.keys())


# Command-line interface setup
ap = argparse.ArgumentParser()

@@ -123,14 +144,14 @@ first_checkpoint = None
if newline_checkpoint_path is not None:
print(f"Taking newline from {newline_checkpoint_path}")
first_checkpoint, file_type = load_model(newline_checkpoint_path)
first_mm_tensors = [k for k, v in first_checkpoint.items() if k.startswith("model.image_newline")]
first_mm_tensors = [k for k, v in first_checkpoint.items() if is_newline(k)]

# Load the checkpoint
mm_tensors = []
last_checkpoint = None
if projector_checkpoint_path is not None:
last_checkpoint, file_type = load_model(projector_checkpoint_path)
mm_tensors = [k for k, v in last_checkpoint.items() if k.startswith("model.mm_projector") or k.startswith("vision_proj.")]
mm_tensors = [k for k, v in last_checkpoint.items() if is_mm_projector(k)]

if len(mm_tensors) == 0:
if last_checkpoint is not None:

@@ -155,5 +176,5 @@ if len(projector) > 0:
save_model(projector, f"{args.model}/llava.projector", 'pytorch')

print("Done!")
print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.")
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")

@@ -8509,7 +8509,9 @@ static void ggml_compute_forward_get_rows_f32(
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

GGML_ASSERT(i01 >= 0 && i01 < ne01);
// Copying this out for a bit while investigating due to issues like:
// https://github.com/ggerganov/llama.cpp/issues/10157
// GGML_ASSERT(i01 >= 0 && i01 < ne01);

ggml_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),