LLaVA is working e2e; memory allocation still needs optimization + cleanup

M. Yusuf Sarıgöz 2023-10-08 01:15:13 +03:00
parent d37ed4750f
commit 8690f425ec
4 changed files with 78 additions and 92 deletions

View file

@@ -78,7 +78,7 @@ static std::string format(const char * fmt, ...) {
 #define TN_LN_POST "%s.post_ln.%s"
 #define TN_TEXT_PROJ "text_projection.weight"
 #define TN_VIS_PROJ "visual_projection.weight"
-#define TN_LLAVA_PROJ "llava_projector.%s"
+#define TN_LLAVA_PROJ "mm.%d.%s"

 //
 // utilities to get data from a gguf file
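Note: the renamed pattern indexes the projector tensors by their position inside the original model.mm_projector module, so each layer can be addressed individually. A small illustration of how the pattern expands (Python, names only; the C++ side goes through format()):

```python
# Illustration only: how the new "mm.%d.%s" pattern expands into GGUF tensor names.
TN_LLAVA_PROJ = "mm.%d.%s"

for idx in (0, 2):                        # index 1 is the GELU in between and has no tensors
    for suffix in ("weight", "bias"):
        print(TN_LLAVA_PROJ % (idx, suffix))
# -> mm.0.weight, mm.0.bias, mm.2.weight, mm.2.bias
```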
@@ -225,8 +225,10 @@ struct clip_vision_model {
     struct ggml_tensor * projection;

     // LLaVA projection
-    struct ggml_tensor * llava_proj_w;
-    struct ggml_tensor * llava_proj_b;
+    struct ggml_tensor * mm_0_w;
+    struct ggml_tensor * mm_0_b;
+    struct ggml_tensor * mm_2_w;
+    struct ggml_tensor * mm_2_b;
 };

 // Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
@@ -283,11 +285,11 @@ size_t get_mem_req_by_size(struct clip_ctx * ctx) {
         return 96 * mb;
     case 589: // large, two-tower
     case 392: // large, vision-only
-    case 375: // large, LLaVA encoder
+    case 377: // large, LLaVA encoder
         if (vision_hparams->image_size == 224) { // input image size = 224
             return 1200 * mb;
         } else { // input image size = 336
-            return 1800 * mb;
+            return 2900 * mb;
         }
     case 909: // huge, two-tower
     case 520: // huge, vision-only
@@ -572,8 +574,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         vision_model.position_embeddings = get_tensor(new_clip->ctx, format(TN_POS_EMBD, "v"));
         vision_model.pre_ln_w = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "weight"));
         vision_model.pre_ln_b = get_tensor(new_clip->ctx, format(TN_LN_PRE, "v", "bias"));
         if (new_clip->has_llava_projector) {
-            vision_model.llava_proj_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "weight"));
-            vision_model.llava_proj_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, "bias"));
+            vision_model.mm_0_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 0, "weight"));
+            vision_model.mm_0_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 0, "bias"));
+            vision_model.mm_2_w = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 2, "weight"));
+            vision_model.mm_2_b = get_tensor(new_clip->ctx, format(TN_LLAVA_PROJ, 2, "bias"));
         } else {
             vision_model.post_ln_w = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "weight"));
             vision_model.post_ln_b = get_tensor(new_clip->ctx, format(TN_LN_POST, "v", "bias"));
@@ -1278,20 +1282,26 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
         embeddings = cur;
     }

     //ggml_set_scratch(ctx0, {0, 0, nullptr});
-    struct ggml_tensor * output = NULL;
     if (ctx->has_llava_projector) {
-        output = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
-        embeddings = ggml_mul_mat(ctx0, model.llava_proj_w, embeddings);
-        output = ggml_add(ctx0, ggml_repeat(ctx0, model.llava_proj_b, embeddings), embeddings);
-        output = ggml_reshape_2d(ctx0, output, output->ne[0], output->ne[1]);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);

         struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
         for (int i = 0; i < num_patches; ++i) {
             ggml_set_i32_1d(patches, i, i+1);
         }

-        output = ggml_get_rows(ctx0, output, patches);
+        embeddings = ggml_get_rows(ctx0, embeddings, patches);
+
+        // mm projection 0
+        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+        embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_0_b, embeddings), embeddings);
+
+        embeddings = ggml_gelu(ctx0, embeddings);
+
+        embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+        embeddings = ggml_add(ctx0, ggml_repeat(ctx0, model.mm_2_b, embeddings), embeddings);
+        ggml_set_name(embeddings, "check");
     } else {
         // get the output of cls token, e.g., 0th index
         struct ggml_tensor * cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, batch_size);
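The rewritten branch drops the CLS row (the `patches` index tensor holds 1..num_patches) and then applies the two-layer projector. A hedged PyTorch-style sketch of the same computation; the 1024/4096 widths are assumptions for CLIP ViT-L/14 plus a 7B LLaMA, not values taken from this diff:

```python
# Sketch of the new ggml graph section (get_rows -> mm.0 -> GELU -> mm.2).
# Dimensions are assumptions: 1024 = CLIP ViT-L/14 hidden size, 4096 = LLaMA n_embd.
import torch
import torch.nn as nn

mm_projector = nn.Sequential(
    nn.Linear(1024, 4096),  # mm.0.weight / mm.0.bias
    nn.GELU(),              # index 1: activation only, no tensors
    nn.Linear(4096, 4096),  # mm.2.weight / mm.2.bias
)

hidden_states = torch.randn(1 + 576, 1024)  # CLS token + 24x24 patch tokens
patch_tokens  = hidden_states[1:]           # like ggml_get_rows with patches = 1..576
image_embd    = mm_projector(patch_tokens)  # (576, 4096): one LLaMA-sized vector per patch
```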
@@ -1312,7 +1322,7 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
         embeddings = ggml_mul_mat(ctx0, model.projection, embeddings);

         // normalize output embeddings
-        output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);
+        struct ggml_tensor * output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, projection_dim, batch_size);

         for (int b = 0; b < batch_size; b++) {
             struct ggml_tensor * embedding = ggml_get_rows(ctx0, embeddings, ggml_new_i32(ctx0, b));
@@ -1322,11 +1332,13 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
             }
             output = ggml_acc(ctx0, output, embedding, output->nb[1], output->nb[2], output->nb[3], b * ggml_nbytes(embedding));
         }
+
+        embeddings = output;
     }

-    ggml_set_name(output, "check");
+    //ggml_set_name(embeddings, "check");

     // run the computation
-    ggml_build_forward_expand(&gf, output);
+    ggml_build_forward_expand(&gf, embeddings);

     /*
     ggml_cplan cplan = ggml_graph_plan(&gf, n_threads);
@@ -1386,7 +1398,7 @@ bool clip_image_batch_encode(const clip_ctx * ctx, const int n_threads, const cl
     printf("used_mem = %zu\n", ggml_used_mem(ctx0));
 #endif

-    memcpy(vec, ggml_get_data_f32(output), sizeof(float) * projection_dim * batch_size);
+    memcpy(vec, ggml_get_data_f32(embeddings), ggml_nbytes(embeddings));

     /*
     if (cplan.work_size != 0) {
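Copying ggml_nbytes(embeddings) means the caller's buffer now has to hold one 4096-dim vector per patch rather than projection_dim * batch_size floats; a quick back-of-the-envelope check (576 and 4096 are the numbers used by the example program later in this commit; the patch math assumes ViT-L/14 at 336px):

```python
# Rough size of the buffer the caller must provide per image.
# Assumptions: 336px input, 14px patches (ViT-L/14), n_embd = 4096, float32 output.
n_patches = (336 // 14) ** 2               # 576 patch tokens
n_embd    = 4096
n_bytes   = n_patches * n_embd * 4         # 9437184 bytes
print(n_patches, n_bytes / (1024 * 1024))  # 576 patches, 9.0 MiB
```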

View file

@@ -39,6 +39,9 @@ def get_tensor_name(name: str) -> str:
     if "projection" in name:
         return name

+    if "mm_projector" in name:
+        return name.replace("model.mm_projector", "mm")
+
     return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")
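A quick check of the new branch (the input key names are assumptions, following the model.mm_projector.<idx>.<suffix> layout extracted by the surgery script):

```python
# Illustration of the added mm_projector renaming branch (input names are assumptions).
def rename(name: str) -> str:
    if "mm_projector" in name:
        return name.replace("model.mm_projector", "mm")
    return name

print(rename("model.mm_projector.0.weight"))  # -> mm.0.weight
print(rename("model.mm_projector.2.bias"))    # -> mm.2.bias
```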
@@ -75,7 +78,7 @@ ap.add_argument("--text-only", action="store_true", required=False,
                 help="Save a text-only model. It can't be used to encode images")
 ap.add_argument("--vision-only", action="store_true", required=False,
                 help="Save a vision-only model. It can't be used to encode texts")
-ap.add_argument("--llava-projector", help="Path to projector.pt file. If specified, save an image encoder for LLaVA models.")
+ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
 ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
 ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
 ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
@@ -138,7 +141,7 @@ else:
 output_dir = args.output_dir if args.output_dir is not None else dir_model
 os.makedirs(output_dir, exist_ok=True)
 output_prefix = os.path.basename(output_dir).replace("ggml_", "")
-fname_out = os.path.join(output_dir, f"{output_prefix}_ggml-{fname_middle}model-{ftype_str[ftype]}.gguf")
+fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf")
 fout = GGUFWriter(path=fname_out, arch="clip")

 fout.add_bool("clip.has_text_encoder", has_text_encoder)
@@ -191,15 +194,19 @@ fout.add_bool("clip.use_gelu", use_gelu)

 if has_llava_projector:
     model.vision_model.encoder.layers.pop(-1)
     projector = torch.load(args.llava_projector)
-    weight = projector["model.mm_projector.weight"].cpu().squeeze().float().numpy().astype(np.float16)
-    bias = projector['model.mm_projector.bias'].cpu().squeeze().float().numpy().astype(np.float32)
-    fout.add_tensor("llava_projector.weight", weight)
-    fout.add_tensor("llava_projector.bias", bias)
+    for name, data in projector.items():
+        name = get_tensor_name(name)
+        if data.ndim == 2:
+            data = data.squeeze().numpy().astype(np.float16)
+        else:
+            data = data.squeeze().numpy().astype(np.float32)
+        fout.add_tensor(name, data)
     print("Projector tensors added\n")

-list_vars = model.state_dict()
-for name, data in list_vars.items():
+state_dict = model.state_dict()
+for name, data in state_dict.items():
     if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
         # we don't need this
         print(f"skipping parameter: {name}")

View file

@@ -7,12 +7,11 @@
 #include "llama.h"

-static bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int * n_past) {
+static bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
     int n_embd = llama_n_embd(llama_get_model(ctx_llama));
-    int n_batch = N; // params.n_batch;

-    for (int i = 0; i < (int) N; i += n_batch) {
-        int n_eval = (int) N - i;
+    for (int i = 0; i < N; i += n_batch) {
+        int n_eval = N - i;
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
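Passing n_batch in explicitly means the image embeddings are fed to llama in n_batch-sized slices instead of one oversized batch. The slicing logic, sketched in Python:

```python
# Slicing behaviour of eval_image_embd, sketched in Python.
# N = 576 image tokens; n_batch = 512 is an assumed value, not taken from this diff.
def slices(N: int, n_batch: int):
    i = 0
    while i < N:
        n_eval = min(n_batch, N - i)
        yield i, n_eval
        i += n_batch

print(list(slices(576, 512)))  # [(0, 512), (512, 64)]
```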
@@ -161,16 +160,16 @@ int main(int argc, char ** argv) {
     }

     if (params.prompt.empty()) {
-        params.prompt = "user: describe the image in detail.\nassistant:";
+        params.prompt = "describe the image in detail.";
     }

-    auto ctx_clip = clip_model_load(clip_path, 1);
+    auto ctx_clip = clip_model_load(clip_path, 3);

     clip_image_u8 img;
     clip_image_f32 img_res;
     clip_image_load_from_file(img_path, &img);
     clip_image_preprocess(ctx_clip, &img, &img_res);

-    float * vec = (float *)malloc(4096 * 256 * sizeof(float));
+    float * vec = (float *)malloc(4096 * 576 * sizeof(float));
     clip_image_encode(ctx_clip, params.n_threads, &img_res, vec, false);
     clip_free(ctx_clip);
@@ -198,9 +197,10 @@ clip_free(ctx_clip);
     int n_past = 0;
     int max_tgt_len = 256;

-    //eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
-    eval_image_embd(ctx_llama, vec, 256, &n_past);
-    //eval_string(ctx_llama, "assistant:", params.n_batch, &n_past);
+    eval_string(ctx_llama, "user: ", params.n_batch, &n_past);
+    eval_image_embd(ctx_llama, vec, 576, params.n_batch, &n_past);
+    eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
+    eval_string(ctx_llama, "\nassistant:", params.n_batch, &n_past);

     printf("n_past = %d\n", n_past);
     const char* tmp;
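With the user prompt split out of the hard-coded string, the context is now assembled as user tag, image embeddings, user prompt, assistant tag. A sketch of the resulting layout (the <image x576> placeholder stands for the 576 embedding positions evaluated by eval_image_embd, not literal text):

```python
# Effective context layout after this change; <image x576> is only a placeholder.
prompt = "describe the image in detail."
layout = "user: " + "<image x576>" + prompt + "\nassistant:"
print(layout)
```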

View file

@@ -1,63 +1,30 @@
 import argparse
-from llava.model import LlavaLlamaForCausalLM
-from transformers import AutoTokenizer
-from peft import PeftModel
+import glob
+import os
 import torch

-dtype = torch.bfloat16
-
 ap = argparse.ArgumentParser()
-ap.add_argument("-m", "--model", help="Path to LLaVA RLHF model")
-ap.add_argument("-o", "--output", help="Output directory to save the merged file")
+ap.add_argument("-m", "--model", help="Path to LLaVA v1.5 model")
 args = ap.parse_args()

-model_path = f"{args.model}/sft_model"
-lora_path = f"{args.model}/rlhf_lora_adapter_model"
-save_path = args.output
-
-model = LlavaLlamaForCausalLM.from_pretrained(
-    model_path,
-    device_map={"": "cuda:0"},
-    torch_dtype=dtype,
-)
-
-model = PeftModel.from_pretrained(
-    model,
-    lora_path,
-)
-
-model = model.merge_and_unload()
-
-model.save_pretrained(save_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-tokenizer.save_pretrained(save_path)
-
-del model
-del tokenizer
-
-# Load the checkpoint
-checkpoint = torch.load(f"{save_path}/pytorch_model-00002-of-00002.bin")
-
-# Extract the tensors we want
-mm_projector_weight = checkpoint['model.mm_projector.weight']
-mm_projector_bias = checkpoint['model.mm_projector.bias']
-
-# Remove the tensors from the checkpoint
-del checkpoint['model.mm_projector.weight']
-del checkpoint['model.mm_projector.bias']
-
-# Create a dictionary with the original names as keys
-mm_projector = {
-    'model.mm_projector.weight': mm_projector_weight,
-    'model.mm_projector.bias': mm_projector_bias
-}
-
-# Save the combined dictionary using torch.save
-torch.save(mm_projector, "projector.pt")
-
-# Save the rest of the model with the same original name
-torch.save(checkpoint, "./llava-7b-rlhf-merged/pytorch_model-00002-of-00002.bin")
-
-Print("Operation complete!")
+# find the model part that includes the multimodal projector weights
+path = sorted(glob.glob(f"{args.model}/pytorch_model*.bin"))[-1]
+checkpoint = torch.load(path)
+
+# get a list of mm tensor names
+mm_tensors = [k for k, v in checkpoint.items() if k.startswith("model.mm_projector")]
+
+# store these tensors in a new dictionary and torch.save them
+projector = {name: checkpoint[name] for name in mm_tensors}
+torch.save(projector, f"{args.model}/llava.projector")
+
+# remove these tensors from the checkpoint and save it again
+for name in mm_tensors:
+    del checkpoint[name]
+torch.save(checkpoint, path)
+
+print("Done!")
+print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
+print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.")