llama : add Qwen2VL support + multimodal RoPE (#10361)
* Barebone Qwen2VL LLM convertor * Add Qwen2VL cli entrypoint * [WIP] add qwen2vl arch * Verify m-rope output * Add vl-rope/2d-rope support for qwen2vl ViT * update qwen2vl cli tool * update 5D tensor op workaround * [WIP] qwen2vl vision model * make batch and clip utils compatible with qwen2vl * [WIP] create inference workflow, gguf convert script but fix * correcting vision-rope behavior, add the missing last layer back to ViT * add arg parser to qwen2vl_surgery * replace variable size array with vector * cuda-gdb cmake preset * add fp32 mrope, vision rope kernel * add fp16 support for qwen2vl and m-rope * add `GGML_ROPE_TYPE_MROPE`, `GGML_ROPE_TYPE_VISION` * fix rope op mode switching, out dated func args * update `llama_hparams` * update to keep up stream changes * resolve linter, test errors * add makefile entry, update speical image padding token * add mrope unit test, fix few compiler warnings * rename `mrope` related function, params * minor updates on debug util, bug fixs * add `m-rope` testcase to `test-backend-ops` * Apply suggestions from code review Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * fix traililng whitespce * store `llama_hparams.rope_sections` with fixed size array * update position id tensor size check in GGML_OP_ROPE * minor updates * update `ggml_backend_*_supports_op` of unsupported backends * remote old `rope_section` compare operator --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
56eea0781c
commit
ba1cb19cdd
24 changed files with 1873 additions and 114 deletions
|
@ -43,3 +43,10 @@ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-minicpmv-cli)
|
|||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-qwen2vl-cli)
|
||||
add_executable(${TARGET} qwen2vl-cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-qwen2vl-cli)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
|
|
@ -102,7 +102,9 @@ static std::string format(const char * fmt, ...) {
|
|||
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
|
||||
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
|
||||
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
|
||||
#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger"
|
||||
#define KEY_USE_GELU "clip.use_gelu"
|
||||
#define KEY_USE_SILU "clip.use_silu"
|
||||
#define KEY_N_EMBD "clip.%s.embedding_length"
|
||||
#define KEY_N_FF "clip.%s.feed_forward_length"
|
||||
#define KEY_N_BLOCK "clip.%s.block_count"
|
||||
|
@ -129,7 +131,8 @@ static std::string format(const char * fmt, ...) {
|
|||
#define TN_TOKEN_EMBD "%s.token_embd.weight"
|
||||
#define TN_POS_EMBD "%s.position_embd.weight"
|
||||
#define TN_CLASS_EMBD "v.class_embd"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
|
||||
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
|
||||
#define TN_PATCH_BIAS "v.patch_embd.bias"
|
||||
#define TN_ATTN_K "%s.blk.%d.attn_k.%s"
|
||||
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
|
||||
|
@ -163,6 +166,7 @@ enum projector_type {
|
|||
PROJECTOR_TYPE_LDP,
|
||||
PROJECTOR_TYPE_LDPV2,
|
||||
PROJECTOR_TYPE_RESAMPLER,
|
||||
PROJECTOR_TYPE_MERGER,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@ -171,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
|||
{ PROJECTOR_TYPE_LDP, "ldp" },
|
||||
{ PROJECTOR_TYPE_LDPV2, "ldpv2"},
|
||||
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
|
||||
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
|
||||
};
|
||||
|
||||
|
||||
|
@ -463,7 +468,8 @@ struct clip_vision_model {
|
|||
|
||||
// embeddings
|
||||
struct ggml_tensor * class_embedding;
|
||||
struct ggml_tensor * patch_embeddings;
|
||||
struct ggml_tensor * patch_embeddings_0;
|
||||
struct ggml_tensor * patch_embeddings_1; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
|
||||
struct ggml_tensor * patch_bias;
|
||||
struct ggml_tensor * position_embeddings;
|
||||
|
||||
|
@ -553,6 +559,7 @@ struct clip_ctx {
|
|||
bool has_vision_encoder = false;
|
||||
bool has_llava_projector = false;
|
||||
bool has_minicpmv_projector = false;
|
||||
bool has_qwen2vl_merger = false;
|
||||
int minicpmv_version = 2;
|
||||
|
||||
struct clip_vision_model vision_model;
|
||||
|
@ -561,6 +568,7 @@ struct clip_ctx {
|
|||
float image_mean[3];
|
||||
float image_std[3];
|
||||
bool use_gelu = false;
|
||||
bool use_silu = false;
|
||||
int32_t ftype = 1;
|
||||
|
||||
bool has_class_embedding = true;
|
||||
|
@ -606,14 +614,26 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
image_size_height = imgs->data->ny;
|
||||
}
|
||||
}
|
||||
else if (ctx->has_qwen2vl_merger) {
|
||||
// use the image's native resolution when image is avaible
|
||||
if (is_inf) {
|
||||
// if (imgs->data->nx && imgs->data->ny) {
|
||||
image_size_width = imgs->data->nx;
|
||||
image_size_height = imgs->data->ny;
|
||||
}
|
||||
}
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||
const int patches_w = image_size_width / patch_size;
|
||||
const int patches_h = image_size_height / patch_size;
|
||||
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
|
||||
const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
|
||||
const int hidden_size = hparams.hidden_size;
|
||||
const int n_head = hparams.n_head;
|
||||
const int d_head = hidden_size / n_head;
|
||||
int n_layer = hparams.n_layer;
|
||||
const float eps = hparams.eps;
|
||||
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||
|
||||
const int batch_size = imgs->size;
|
||||
|
||||
|
@ -634,10 +654,30 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
|
||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
|
||||
inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
|
||||
GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
|
||||
|
||||
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_add(ctx0, inp, inp_1);
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b]
|
||||
inp = ggml_reshape_4d(
|
||||
ctx0, inp,
|
||||
hidden_size * 2, patches_w / 2, patches_h, batch_size);
|
||||
inp = ggml_reshape_4d(
|
||||
ctx0, inp,
|
||||
hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
|
||||
inp = ggml_reshape_3d(
|
||||
ctx0, inp,
|
||||
hidden_size, patches_w * patches_h, batch_size);
|
||||
}
|
||||
else {
|
||||
inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
|
||||
inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
|
||||
}
|
||||
|
||||
if (ctx->has_patch_bias) {
|
||||
// inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
|
||||
|
@ -659,12 +699,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
|
||||
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
|
||||
ggml_set_name(positions, "positions");
|
||||
ggml_set_input(positions);
|
||||
|
||||
embeddings =
|
||||
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
|
||||
if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding
|
||||
embeddings =
|
||||
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
|
||||
}
|
||||
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
int pos_w = image_size_width/patch_size;
|
||||
|
@ -688,7 +730,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
}
|
||||
|
||||
// loop over layers
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
|
||||
// TODO: figure out why we doing thing in this way ???
|
||||
n_layer += 1;
|
||||
}
|
||||
for (int il = 0; il < n_layer - 1; il++) {
|
||||
|
@ -710,8 +753,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
struct ggml_tensor * Q =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
|
||||
|
||||
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
|
||||
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
Q = ggml_rope_multi(
|
||||
ctx0, Q, positions, nullptr,
|
||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||
}
|
||||
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
|
||||
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
|
||||
Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
|
||||
|
||||
|
@ -719,6 +767,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
|
||||
|
||||
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
K = ggml_rope_multi(
|
||||
ctx0, K, positions, nullptr,
|
||||
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
|
||||
}
|
||||
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
|
||||
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
|
||||
|
||||
|
@ -758,6 +811,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
|
||||
if (ctx->use_gelu) {
|
||||
cur = ggml_gelu_inplace(ctx0, cur);
|
||||
} else if (ctx->use_silu) {
|
||||
cur = ggml_silu_inplace(ctx0, cur);
|
||||
} else {
|
||||
cur = ggml_gelu_quick_inplace(ctx0, cur);
|
||||
}
|
||||
|
@ -769,6 +824,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
cur = ggml_add(ctx0, embeddings, cur);
|
||||
|
||||
embeddings = cur;
|
||||
|
||||
}
|
||||
|
||||
// post-layernorm
|
||||
|
@ -1030,6 +1086,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
|
||||
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||
|
||||
// GELU activation
|
||||
embeddings = ggml_gelu(ctx0, embeddings);
|
||||
|
||||
// Second linear layer
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
|
||||
}
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
|
@ -1206,6 +1275,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
|
||||
}
|
||||
|
||||
idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER);
|
||||
if (idx != -1) {
|
||||
new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx);
|
||||
}
|
||||
// GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
|
||||
|
||||
GGML_ASSERT(new_clip->has_vision_encoder);
|
||||
|
@ -1214,6 +1287,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
idx = get_key_idx(ctx, KEY_USE_GELU);
|
||||
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
|
||||
|
||||
try {
|
||||
idx = get_key_idx(ctx, KEY_USE_SILU);
|
||||
new_clip->use_silu = gguf_get_val_bool(ctx, idx);
|
||||
} catch (std::runtime_error & /*e*/) {
|
||||
new_clip->use_silu = false;
|
||||
}
|
||||
|
||||
if (verbosity >= 1) {
|
||||
LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder);
|
||||
LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
|
||||
|
@ -1389,11 +1469,16 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
}
|
||||
|
||||
try {
|
||||
vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
|
||||
} catch(const std::exception& /*e*/) {
|
||||
LOG_ERR("%s: failed to load vision model tensors\n", __func__);
|
||||
}
|
||||
try {
|
||||
vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
|
||||
} catch(const std::exception& /*e*/) {
|
||||
new_clip->has_qwen2vl_merger = false;
|
||||
}
|
||||
|
||||
// LLaVA projection
|
||||
if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||
|
@ -1481,6 +1566,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight"));
|
||||
vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias"));
|
||||
}
|
||||
else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
|
||||
vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
|
||||
vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
|
||||
}
|
||||
else {
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
|
@ -1519,6 +1610,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
|
||||
clip_image_f32_batch batch;
|
||||
batch.size = 1;
|
||||
batch.data = nullptr;
|
||||
ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false);
|
||||
ggml_gallocr_reserve(new_clip->compute_alloc, gf);
|
||||
size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
|
||||
|
@ -1532,6 +1624,10 @@ void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size
|
|||
ctx_clip->load_image_size = load_image_size;
|
||||
}
|
||||
|
||||
struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
|
||||
return ctx_clip->load_image_size;
|
||||
}
|
||||
|
||||
struct clip_image_size * clip_image_size_init() {
|
||||
struct clip_image_size * load_image_size = new struct clip_image_size();
|
||||
load_image_size->width = 448;
|
||||
|
@ -1984,6 +2080,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
|||
}
|
||||
return true;
|
||||
}
|
||||
else if (ctx->has_qwen2vl_merger) {
|
||||
clip_image_u8 * resized = clip_image_u8_init();
|
||||
auto patch_size = clip_patch_size(ctx) * 2;
|
||||
int nx = ceil((float)img->nx / patch_size) * patch_size;
|
||||
int ny = ceil((float)img->ny / patch_size) * patch_size;
|
||||
bicubic_resize(*img, *resized, nx, ny);
|
||||
|
||||
res_imgs->data = new clip_image_f32[1];
|
||||
// clip_image_f32 * res = clip_image_f32_init();
|
||||
normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std);
|
||||
// res_imgs->data[0] = *res;
|
||||
res_imgs->size = 1;
|
||||
|
||||
// clip_image_f32_free(res);
|
||||
clip_image_u8_free(resized);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pad_to_square = true;
|
||||
if (!ctx->has_vision_encoder) {
|
||||
|
@ -2173,6 +2286,13 @@ size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
|
|||
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
||||
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
|
||||
clip_image_f32 img;
|
||||
img.nx = img_w;
|
||||
img.ny = img_h;
|
||||
return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
||||
int32_t clip_image_size(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.image_size;
|
||||
}
|
||||
|
@ -2194,6 +2314,13 @@ const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
|
|||
}
|
||||
|
||||
int clip_n_patches(const struct clip_ctx * ctx) {
|
||||
clip_image_f32 img;
|
||||
img.nx = ctx->vision_model.hparams.image_size;
|
||||
img.ny = ctx->vision_model.hparams.image_size;
|
||||
return clip_n_patches_by_img(ctx, &img);
|
||||
}
|
||||
|
||||
int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
|
||||
const auto & params = ctx->vision_model.hparams;
|
||||
|
||||
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
|
||||
|
@ -2207,6 +2334,11 @@ int clip_n_patches(const struct clip_ctx * ctx) {
|
|||
else if (ctx->minicpmv_version == 3) {
|
||||
n_patches = 64;
|
||||
}
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
int patch_size = params.patch_size * 2;
|
||||
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
|
||||
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
|
||||
n_patches = x_patch * y_patch;
|
||||
}
|
||||
|
||||
return n_patches;
|
||||
|
@ -2335,7 +2467,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
const int image_size = hparams.image_size;
|
||||
int image_size_width = image_size;
|
||||
int image_size_height = image_size;
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) {
|
||||
image_size_width = imgs->data[0].nx;
|
||||
image_size_height = imgs->data[0].ny;
|
||||
}
|
||||
|
@ -2355,7 +2487,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
for (size_t i = 0; i < imgs->size; i++) {
|
||||
const int nx = imgs->data[i].nx;
|
||||
const int ny = imgs->data[i].ny;
|
||||
if (!ctx->has_minicpmv_projector) {
|
||||
if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
|
||||
GGML_ASSERT(nx == image_size && ny == image_size);
|
||||
}
|
||||
|
||||
|
@ -2413,9 +2545,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
|
||||
|
||||
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
|
||||
for(int i=0;i<pos_w * pos_h;++i){
|
||||
for(int j=0;j<embed_dim;++j){
|
||||
pos_embed_data[i*embed_dim+j]=pos_embed_t[i][j];
|
||||
for(int i=0;i < pos_w * pos_h; ++i){
|
||||
for(int j=0; j < embed_dim; ++j){
|
||||
pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2435,7 +2567,34 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
}
|
||||
}
|
||||
|
||||
{
|
||||
if (ctx->has_qwen2vl_merger) {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
||||
const int pw = image_size_width / patch_size;
|
||||
const int ph = image_size_height / patch_size;
|
||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||
|
||||
int ptr = 0;
|
||||
for (int y = 0; y < ph; y+=2)
|
||||
{
|
||||
for (int x = 0; x < pw; x+=2)
|
||||
{
|
||||
for (int dy = 0; dy < 2; dy++) {
|
||||
for (int dx = 0; dx < 2; dx++) {
|
||||
positions_data[ptr] = y + dy;
|
||||
positions_data[num_patches + ptr] = x + dx;
|
||||
positions_data[num_patches * 2 + ptr] = y + dy;
|
||||
positions_data[num_patches * 3 + ptr] = x + dx;
|
||||
ptr++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
}
|
||||
else {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||
|
@ -2444,16 +2603,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
}
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
}
|
||||
|
||||
{
|
||||
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
||||
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||
for (int i = 0; i < num_patches; i++) {
|
||||
patches_data[i] = i + 1;
|
||||
{
|
||||
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
|
||||
int* patches_data = (int*)malloc(ggml_nbytes(patches));
|
||||
for (int i = 0; i < num_patches; i++) {
|
||||
patches_data[i] = i + 1;
|
||||
}
|
||||
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
|
||||
free(patches_data);
|
||||
}
|
||||
ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
|
||||
free(patches_data);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2626,6 +2785,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
return 3584;
|
||||
}
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
return ctx->vision_model.mm_1_b->ne[0];
|
||||
}
|
||||
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
|
@ -2637,3 +2799,21 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) {
|
|||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
|
||||
return ctx->has_qwen2vl_merger;
|
||||
}
|
||||
|
||||
|
||||
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
|
||||
clip_image_f32 clip_img;
|
||||
clip_img.buf.resize(h * w * 3);
|
||||
for (int i = 0; i < h*w*3; i++)
|
||||
{
|
||||
clip_img.buf[i] = img[i];
|
||||
}
|
||||
clip_img.nx = w;
|
||||
clip_img.ny = h;
|
||||
clip_image_encode(ctx, n_threads, &clip_img, vec);
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -45,6 +45,7 @@ CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity
|
|||
CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
|
||||
|
||||
CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
|
||||
CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
|
||||
|
@ -55,11 +56,13 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
|
|||
|
||||
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
|
||||
CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
|
||||
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
|
||||
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
|
||||
|
||||
CLIP_API struct clip_image_size * clip_image_size_init();
|
||||
CLIP_API struct clip_image_u8 * clip_image_u8_init ();
|
||||
|
@ -86,6 +89,9 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
|
|||
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
|
||||
|
||||
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
|
||||
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
|
||||
|
||||
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -259,25 +259,33 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
|||
|
||||
const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip);
|
||||
|
||||
if (clip_is_minicpmv(ctx_clip)) {
|
||||
if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) {
|
||||
std::vector<float *> image_embd_v;
|
||||
image_embd_v.resize(img_res_v.size);
|
||||
struct clip_image_size * load_image_size = clip_image_size_init();
|
||||
|
||||
for (size_t i = 0; i < img_res_v.size; i++) {
|
||||
const int64_t t_img_enc_step_start_us = ggml_time_us();
|
||||
image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
||||
image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
|
||||
int patch_size=14;
|
||||
load_image_size->width = img_res_v.data[i].nx;
|
||||
load_image_size->height = img_res_v.data[i].ny;
|
||||
clip_add_load_image_size(ctx_clip, load_image_size);
|
||||
|
||||
bool encoded = false;
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
if (clip_is_qwen2vl(ctx_clip)) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
|
||||
}
|
||||
else {
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (!encoded) {
|
||||
LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
|
||||
return false;
|
||||
|
@ -290,8 +298,11 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
|||
|
||||
int n_img_pos_out = 0;
|
||||
for (size_t i = 0; i < image_embd_v.size(); i++) {
|
||||
std::memcpy(image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), image_embd_v[i], clip_embd_nbytes(ctx_clip));
|
||||
n_img_pos_out += clip_n_patches(ctx_clip);
|
||||
std::memcpy(
|
||||
image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip),
|
||||
image_embd_v[i],
|
||||
clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny));
|
||||
n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]);
|
||||
}
|
||||
*n_img_pos = n_img_pos_out;
|
||||
for (size_t i = 0; i < image_embd_v.size(); i++) {
|
||||
|
@ -387,7 +398,13 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
|
|||
if (clip_is_minicpmv(ctx_clip)) {
|
||||
num_max_patches = 10;
|
||||
}
|
||||
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model
|
||||
float * image_embd;
|
||||
if (clip_is_qwen2vl(ctx_clip)) {
|
||||
// qwen2vl don't split image into chunks, so `num_max_patches` is not needed.
|
||||
image_embd = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img->nx, img->ny));
|
||||
} else {
|
||||
image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model
|
||||
}
|
||||
if (!image_embd) {
|
||||
LOG_ERR("Unable to allocate memory for image embeddings\n");
|
||||
return false;
|
||||
|
|
158
examples/llava/qwen2_vl_surgery.py
Normal file
158
examples/llava/qwen2_vl_surgery.py
Normal file
|
@ -0,0 +1,158 @@
|
|||
import argparse
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from gguf import *
|
||||
from transformers import (
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2VLProcessor,
|
||||
AutoProcessor,
|
||||
Qwen2VLConfig
|
||||
)
|
||||
|
||||
|
||||
VISION = "clip.vision"
|
||||
|
||||
|
||||
def k(raw_key: str, arch: str) -> str:
|
||||
return raw_key.format(arch=arch)
|
||||
|
||||
|
||||
def to_gguf_name(name: str) -> str:
|
||||
og = name
|
||||
name = name.replace("text_model", "t").replace("vision_model", "v")
|
||||
name = name.replace("blocks", "blk").replace("embeddings.", "")
|
||||
name = name.replace("attn.", "attn_")
|
||||
name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
|
||||
# name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
|
||||
name = name.replace("norm1", "ln1").replace("norm2", "ln2")
|
||||
name = name.replace("merger.mlp", 'mm')
|
||||
print(f"[to_gguf_name] {og} --> {name}")
|
||||
return name
|
||||
|
||||
|
||||
def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]:
|
||||
vision_model = qwen2vl.visual
|
||||
tensor_map = {}
|
||||
for name, ten in vision_model.state_dict().items():
|
||||
ten = ten.numpy()
|
||||
if 'qkv' in name:
|
||||
if ten.ndim == 2: # weight
|
||||
c3, _ = ten.shape
|
||||
else: # bias
|
||||
c3 = ten.shape[0]
|
||||
assert c3 % 3 == 0
|
||||
c = c3 // 3
|
||||
wq = ten[:c]
|
||||
wk = ten[c: c * 2]
|
||||
wv = ten[c * 2:]
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
|
||||
elif 'merger' in name:
|
||||
if name.endswith("ln_q.weight"):
|
||||
tensor_map['v.post_ln.weight'] = ten
|
||||
elif name.endswith("ln_q.bias"):
|
||||
tensor_map['v.post_ln.bias'] = ten
|
||||
else:
|
||||
# "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
|
||||
tensor_map[to_gguf_name(name)] = ten
|
||||
elif 'patch_embed.proj.weight' in name:
|
||||
# NOTE: split Conv3D into Conv2Ds
|
||||
c1, c2, kt, kh, kw = ten.shape
|
||||
assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
|
||||
tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
|
||||
tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
|
||||
else:
|
||||
tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
|
||||
|
||||
for new_name, ten in tensor_map.items():
|
||||
if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
|
||||
tensor_map[new_name] = ten.astype(np.float32)
|
||||
else:
|
||||
tensor_map[new_name] = ten.astype(dtype)
|
||||
tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder
|
||||
return tensor_map
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.data_type == 'fp32':
|
||||
dtype = torch.float32
|
||||
np_dtype = np.float32
|
||||
ftype = 0
|
||||
elif args.data_type == 'fp16':
|
||||
dtype = torch.float32
|
||||
np_dtype = np.float16
|
||||
ftype = 1
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
model_name = args.model_name
|
||||
print("model_name: ", model_name)
|
||||
qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map="cpu"
|
||||
)
|
||||
cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType]
|
||||
vcfg = cfg.vision_config
|
||||
|
||||
if os.path.isdir(model_name):
|
||||
if model_name.endswith(os.sep):
|
||||
model_name = model_name[:-1]
|
||||
model_name = os.path.basename(model_name)
|
||||
fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf"
|
||||
|
||||
fout = GGUFWriter(path=fname_out, arch="clip")
|
||||
fout.add_description("image encoder for Qwen2VL")
|
||||
|
||||
fout.add_file_type(ftype)
|
||||
fout.add_bool("clip.has_text_encoder", False)
|
||||
fout.add_bool("clip.has_vision_encoder", True)
|
||||
fout.add_bool("clip.has_qwen2vl_merger", True)
|
||||
fout.add_string("clip.projector_type", "qwen2vl_merger")
|
||||
|
||||
print(cfg.vision_config)
|
||||
if 'silu' in cfg.vision_config.hidden_act.lower():
|
||||
fout.add_bool("clip.use_silu", True)
|
||||
fout.add_bool("clip.use_gelu", False)
|
||||
elif 'gelu' in cfg.vision_config.hidden_act.lower():
|
||||
fout.add_bool("clip.use_silu", False)
|
||||
fout.add_bool("clip.use_gelu", 'quick' not in cfg.vision_config.hidden_act.lower())
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
tensor_map = find_vision_tensors(qwen2vl, np_dtype)
|
||||
for name, data in tensor_map.items():
|
||||
fout.add_tensor(name, data)
|
||||
|
||||
fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
|
||||
fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2)
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
|
||||
fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
|
||||
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
|
||||
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
|
||||
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
|
||||
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # not sure what this does, put 0 here as a placeholder
|
||||
fout.add_name(model_name)
|
||||
"""
|
||||
HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
|
||||
it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
|
||||
"""
|
||||
|
||||
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
|
||||
fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue]
|
||||
fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue]
|
||||
|
||||
fout.write_header_to_file()
|
||||
fout.write_kv_data_to_file()
|
||||
fout.write_tensors_to_file()
|
||||
fout.close()
|
||||
print("save model as: ", fname_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
|
||||
parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
581
examples/llava/qwen2vl-cli.cpp
Normal file
581
examples/llava/qwen2vl-cli.cpp
Normal file
|
@ -0,0 +1,581 @@
|
|||
#include "arg.h"
|
||||
#include "base64.hpp"
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "clip.h"
|
||||
#include "llava.h"
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
#include "ggml-cuda.h"
|
||||
#endif
|
||||
#ifdef NDEBUG
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
#endif
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
|
||||
static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
|
||||
int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) {
|
||||
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||
const int patch_size = 14 * 2;
|
||||
const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
|
||||
const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
|
||||
auto img_tokens = image_embed->n_image_pos;
|
||||
// llama_pos mrope_pos[img_tokens * 4];
|
||||
std::vector<llama_pos> mrope_pos;
|
||||
mrope_pos.resize(img_tokens * 4);
|
||||
|
||||
for (int y = 0; y < ph; y++)
|
||||
{
|
||||
for (int x = 0; x < pw; x++)
|
||||
{
|
||||
int i = y * pw + x;
|
||||
mrope_pos[i] = *st_pos_id;
|
||||
mrope_pos[i + img_tokens] = *st_pos_id + y;
|
||||
mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
|
||||
mrope_pos[i + img_tokens * 3] = 0;
|
||||
}
|
||||
}
|
||||
*st_pos_id += std::max(pw, ph);
|
||||
|
||||
int processed = 0;
|
||||
std::vector<llama_pos> batch_mrope_pos;
|
||||
batch_mrope_pos.resize(img_tokens * 4);
|
||||
|
||||
for (int i = 0; i < img_tokens; i += n_batch) {
|
||||
int n_eval = img_tokens - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
|
||||
// llama_pos batch_mrope_pos[n_eval * 4];
|
||||
std::fill(batch_mrope_pos.begin(), batch_mrope_pos.end(), 0);
|
||||
memcpy(batch_mrope_pos.data(), &mrope_pos[processed], n_eval * sizeof(llama_pos));
|
||||
memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos));
|
||||
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
|
||||
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
|
||||
|
||||
llama_batch batch = {
|
||||
int32_t(n_eval), // n_tokens
|
||||
nullptr, // token
|
||||
(image_embed->embed+i*n_embd), // embed
|
||||
batch_mrope_pos.data(), // pos
|
||||
nullptr, // n_seq_id
|
||||
nullptr, // seq_id
|
||||
nullptr, // logits
|
||||
};
|
||||
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
*n_past += n_eval;
|
||||
processed += n_eval;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past, int * st_pos_id) {
|
||||
int N = (int) tokens.size();
|
||||
std::vector<llama_pos> pos;
|
||||
for (int i = 0; i < N; i += n_batch) {
|
||||
int n_eval = (int) tokens.size() - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
auto batch = llama_batch_get_one(&tokens[i], n_eval);
|
||||
// TODO: add mrope pos ids somewhere else
|
||||
pos.resize(batch.n_tokens * 4);
|
||||
std::fill(pos.begin(), pos.end(), 0);
|
||||
for (int j = 0; j < batch.n_tokens * 3; j ++) {
|
||||
pos[j] = *st_pos_id + (j % batch.n_tokens);
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
|
||||
return false;
|
||||
}
|
||||
*n_past += n_eval;
|
||||
*st_pos_id += n_eval;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past, int * st_pos_id) {
|
||||
std::vector<llama_token> tokens;
|
||||
tokens.push_back(id);
|
||||
return eval_tokens(ctx_llama, tokens, 1, n_past, st_pos_id);
|
||||
}
|
||||
|
||||
static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, int * st_pos_id, bool add_bos){
|
||||
std::string str2 = str;
|
||||
std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
|
||||
eval_tokens(ctx_llama, embd_inp, n_batch, n_past, st_pos_id);
|
||||
return true;
|
||||
}
|
||||
|
||||
static const char * sample(struct common_sampler * smpl,
|
||||
struct llama_context * ctx_llama,
|
||||
int * n_past, int * st_pos_id) {
|
||||
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
|
||||
common_sampler_accept(smpl, id, true);
|
||||
static std::string ret;
|
||||
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
|
||||
ret = "</s>";
|
||||
} else {
|
||||
ret = common_token_to_piece(ctx_llama, id);
|
||||
}
|
||||
eval_id(ctx_llama, id, n_past, st_pos_id);
|
||||
return ret.c_str();
|
||||
}
|
||||
|
||||
static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
|
||||
static const char* IMG_BASE64_TAG_END = "\">";
|
||||
|
||||
static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
|
||||
begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
|
||||
end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
|
||||
}
|
||||
|
||||
static bool prompt_contains_image(const std::string& prompt) {
|
||||
size_t begin, end;
|
||||
find_image_tag_in_prompt(prompt, begin, end);
|
||||
return (begin != std::string::npos);
|
||||
}
|
||||
|
||||
// replaces the base64 image tag in the prompt with `replacement`
|
||||
static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
|
||||
size_t img_base64_str_start, img_base64_str_end;
|
||||
find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
|
||||
if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
|
||||
LOG_ERR("%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
|
||||
auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
|
||||
auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
|
||||
|
||||
auto required_bytes = base64::required_encode_size(base64_str.size());
|
||||
auto img_bytes = std::vector<unsigned char>(required_bytes);
|
||||
base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
|
||||
|
||||
auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
|
||||
if (!embed) {
|
||||
LOG_ERR("%s: could not load image from base64 string.\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return embed;
|
||||
}
|
||||
|
||||
static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
|
||||
size_t begin, end;
|
||||
find_image_tag_in_prompt(prompt, begin, end);
|
||||
if (begin == std::string::npos || end == std::string::npos) {
|
||||
return prompt;
|
||||
}
|
||||
auto pre = prompt.substr(0, begin);
|
||||
auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
|
||||
return pre + replacement + post;
|
||||
}
|
||||
|
||||
struct llava_context {
|
||||
struct clip_ctx * ctx_clip = NULL;
|
||||
struct llama_context * ctx_llama = NULL;
|
||||
struct llama_model * model = NULL;
|
||||
};
|
||||
|
||||
static void print_usage(int, char ** argv) {
|
||||
LOG("\n example usage:\n");
|
||||
LOG("\n %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
|
||||
LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
|
||||
}
|
||||
|
||||
static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) {
|
||||
|
||||
// load and preprocess the image
|
||||
llava_image_embed * embed = NULL;
|
||||
auto prompt = params->prompt;
|
||||
if (prompt_contains_image(prompt)) {
|
||||
if (!params->image.empty()) {
|
||||
LOG_INF("using base64 encoded image instead of command line image path\n");
|
||||
}
|
||||
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt);
|
||||
if (!embed) {
|
||||
LOG_ERR("%s: can't load image from prompt\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
params->prompt = remove_image_from_prompt(prompt);
|
||||
} else {
|
||||
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str());
|
||||
if (!embed) {
|
||||
fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
return embed;
|
||||
}
|
||||
|
||||
static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) {
|
||||
int n_past = 0;
|
||||
int cur_pos_id = 0;
|
||||
|
||||
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
|
||||
|
||||
std::string system_prompt, user_prompt;
|
||||
size_t image_pos = prompt.find("<|vision_start|>");
|
||||
if (image_pos != std::string::npos) {
|
||||
// new templating mode: Provide the full prompt including system message and use <image> as a placeholder for the image
|
||||
system_prompt = prompt.substr(0, image_pos);
|
||||
user_prompt = prompt.substr(image_pos + std::string("<|vision_pad|>").length());
|
||||
LOG_INF("system_prompt: %s\n", system_prompt.c_str());
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
LOG_INF("user_prompt: %s\n", user_prompt.c_str());
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// llava-1.5 native mode
|
||||
system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|>";
|
||||
user_prompt = "<|vision_end|>" + prompt + "<|im_end|>\n<|im_start|>assistant\n";
|
||||
if (params->verbose_prompt) {
|
||||
auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, true);
|
||||
if (image_embed != nullptr) {
|
||||
auto image_size = clip_get_load_image_size(ctx_llava->ctx_clip);
|
||||
qwen2vl_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past, &cur_pos_id, image_size);
|
||||
}
|
||||
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, false);
|
||||
|
||||
// generate the response
|
||||
|
||||
LOG("\n");
|
||||
|
||||
struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
|
||||
if (!smpl) {
|
||||
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string response = "";
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past, &cur_pos_id);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0) break;
|
||||
if (strstr(tmp, "###")) break; // Yi-VL behavior
|
||||
LOG("%s", tmp);
|
||||
if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
|
||||
if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
|
||||
if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
common_sampler_free(smpl);
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
static struct llama_model * llava_init(common_params * params) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params->numa);
|
||||
|
||||
llama_model_params model_params = common_model_params_to_llama(*params);
|
||||
|
||||
llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n" , __func__);
|
||||
return NULL;
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
|
||||
const char * clip_path = params->mmproj.c_str();
|
||||
|
||||
auto prompt = params->prompt;
|
||||
if (prompt.empty()) {
|
||||
prompt = "describe the image in detail.";
|
||||
}
|
||||
|
||||
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
|
||||
|
||||
|
||||
llama_context_params ctx_params = common_context_params_to_llama(*params);
|
||||
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
|
||||
|
||||
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
if (ctx_llama == NULL) {
|
||||
LOG_ERR("%s: failed to create the llama_context\n" , __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
|
||||
|
||||
ctx_llava->ctx_llama = ctx_llama;
|
||||
ctx_llava->ctx_clip = ctx_clip;
|
||||
ctx_llava->model = model;
|
||||
return ctx_llava;
|
||||
}
|
||||
|
||||
static void llava_free(struct llava_context * ctx_llava) {
|
||||
if (ctx_llava->ctx_clip) {
|
||||
clip_free(ctx_llava->ctx_clip);
|
||||
ctx_llava->ctx_clip = NULL;
|
||||
}
|
||||
|
||||
llama_free(ctx_llava->ctx_llama);
|
||||
llama_free_model(ctx_llava->model);
|
||||
llama_backend_free();
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
|
||||
static void debug_test_mrope_2d() {
|
||||
// 1. Initialize backend
|
||||
ggml_backend_t backend = NULL;
|
||||
std::string backend_name = "";
|
||||
#ifdef GGML_USE_CUDA
|
||||
fprintf(stderr, "%s: using CUDA backend\n", __func__);
|
||||
backend = ggml_backend_cuda_init(0); // init device 0
|
||||
backend_name = "cuda";
|
||||
if (!backend) {
|
||||
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
|
||||
}
|
||||
#endif
|
||||
// if there aren't GPU Backends fallback to CPU backend
|
||||
if (!backend) {
|
||||
backend = ggml_backend_cpu_init();
|
||||
backend_name = "cpu";
|
||||
}
|
||||
|
||||
// Calculate the size needed to allocate
|
||||
size_t ctx_size = 0;
|
||||
ctx_size += 2 * ggml_tensor_overhead(); // tensors
|
||||
// no need to allocate anything else!
|
||||
|
||||
// 2. Allocate `ggml_context` to store tensor data
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors()
|
||||
};
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
|
||||
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 12, 30);
|
||||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
|
||||
struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 30 * 4);
|
||||
ggml_set_name(pos, "pos");
|
||||
ggml_set_input(pos);
|
||||
|
||||
std::vector<float> dummy_q;
|
||||
dummy_q.resize(128 * 12 * 30);
|
||||
std::fill(dummy_q.begin(), dummy_q.end(), 0.1);
|
||||
// memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));
|
||||
|
||||
std::vector<int> pos_id;
|
||||
pos_id.resize(30 * 4);
|
||||
for (int i = 0; i < 30; i ++) {
|
||||
pos_id[i] = i;
|
||||
pos_id[i + 30] = i + 10;
|
||||
pos_id[i + 60] = i + 20;
|
||||
pos_id[i + 90] = i + 30;
|
||||
}
|
||||
int sections[4] = {32, 32, 0, 0};
|
||||
|
||||
// 4. Allocate a `ggml_backend_buffer` to store all tensors
|
||||
ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
|
||||
|
||||
// 5. Copy tensor data from main memory (RAM) to backend buffer
|
||||
ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw));
|
||||
ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos));
|
||||
|
||||
// 6. Create a `ggml_cgraph` for mul_mat operation
|
||||
struct ggml_cgraph * gf = NULL;
|
||||
struct ggml_context * ctx_cgraph = NULL;
|
||||
|
||||
// create a temporally context to build the graph
|
||||
struct ggml_init_params params0 = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
|
||||
};
|
||||
ctx_cgraph = ggml_init(params0);
|
||||
gf = ggml_new_graph(ctx_cgraph);
|
||||
|
||||
struct ggml_tensor * result0 = ggml_rope_multi(
|
||||
ctx_cgraph, inp_raw, pos, nullptr,
|
||||
128/2, sections, LLAMA_ROPE_TYPE_VISION, 32768, 1000000, 1,
|
||||
0, 1, 32, 1);
|
||||
|
||||
// Add "result" tensor and all of its dependencies to the cgraph
|
||||
ggml_build_forward_expand(gf, result0);
|
||||
|
||||
// 7. Create a `ggml_gallocr` for cgraph computation
|
||||
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
|
||||
ggml_gallocr_alloc_graph(allocr, gf);
|
||||
|
||||
// 9. Run the computation
|
||||
int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading
|
||||
if (ggml_backend_is_cpu(backend)) {
|
||||
ggml_backend_cpu_set_n_threads(backend, n_threads);
|
||||
}
|
||||
ggml_backend_graph_compute(backend, gf);
|
||||
|
||||
// 10. Retrieve results (output tensors)
|
||||
// in this example, output tensor is always the last tensor in the graph
|
||||
struct ggml_tensor * result = result0;
|
||||
// struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1];
|
||||
float * result_data = (float *)malloc(ggml_nbytes(result));
|
||||
// because the tensor data is stored in device buffer, we need to copy it back to RAM
|
||||
ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result));
|
||||
const std::string bin_file = "mrope_2d_" + backend_name +".bin";
|
||||
std::ofstream outFile(bin_file, std::ios::binary);
|
||||
|
||||
if (outFile.is_open()) {
|
||||
outFile.write(reinterpret_cast<const char*>(result_data), ggml_nbytes(result));
|
||||
outFile.close();
|
||||
std::cout << "Data successfully written to " + bin_file << std::endl;
|
||||
} else {
|
||||
std::cerr << "Error opening file!" << std::endl;
|
||||
}
|
||||
|
||||
free(result_data);
|
||||
// 11. Free memory and exit
|
||||
ggml_free(ctx_cgraph);
|
||||
ggml_gallocr_free(allocr);
|
||||
ggml_free(ctx);
|
||||
ggml_backend_buffer_free(buffer);
|
||||
ggml_backend_free(backend);
|
||||
}
|
||||
|
||||
static void debug_dump_img_embed(struct llava_context * ctx_llava) {
|
||||
int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama));
|
||||
int ne = n_embd * 4;
|
||||
float vals[56 * 56 * 3];
|
||||
// float embd[ne];
|
||||
std::vector<float> embd;
|
||||
embd.resize(ne);
|
||||
|
||||
for (int i = 0; i < 56*56; i++)
|
||||
{
|
||||
for (int c = 0; c < 3; c++)
|
||||
vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
|
||||
}
|
||||
|
||||
clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd.data());
|
||||
|
||||
std::ofstream outFile("img_embed.bin", std::ios::binary);
|
||||
if (outFile.is_open()) {
|
||||
outFile.write(reinterpret_cast<const char*>(embd.data()), ne * sizeof(float));
|
||||
|
||||
outFile.close();
|
||||
std::cout << "Data successfully written to mrope.bin" << std::endl;
|
||||
} else {
|
||||
std::cerr << "Error opening file!" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
|
||||
print_usage(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto * model = llava_init(¶ms);
|
||||
if (model == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (prompt_contains_image(params.prompt)) {
|
||||
auto * ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
auto * image_embed = load_image(ctx_llava, ¶ms, "");
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
#ifndef NDEBUG
|
||||
} else if (params.image[0].empty()) {
|
||||
auto ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
debug_test_mrope_2d();
|
||||
debug_dump_img_embed(ctx_llava);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
#endif
|
||||
} else {
|
||||
for (auto & image : params.image) {
|
||||
auto * ctx_llava = llava_init_context(¶ms, model);
|
||||
|
||||
auto * image_embed = load_image(ctx_llava, ¶ms, image);
|
||||
if (!image_embed) {
|
||||
LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_perf_context_print(ctx_llava->ctx_llama);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
}
|
||||
}
|
||||
|
||||
llama_free_model(model);
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue