From 023f0076e07961d655c8d114e2e54426fc363101 Mon Sep 17 00:00:00 2001
From: HimariO
Date: Sun, 20 Oct 2024 21:42:53 +0800
Subject: [PATCH] correcting vision-rope behavior, add the missing last layer
 back to ViT

---
 examples/llava/clip.cpp        |  20 +++++--
 examples/llava/clip.h          |   2 +
 examples/llava/qwen2vl-cli.cpp | 103 ++++++++++++++++++++++++---------
 ggml/include/ggml.h            |   2 +-
 ggml/src/ggml.c                |  77 ++++++++++++++++++------
 src/llama.cpp                  |   6 +-
 6 files changed, 155 insertions(+), 55 deletions(-)

diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 0b648c5b7..2f302d935 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -623,13 +623,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     const int patches_w = image_size_width / patch_size;
     const int patches_h = image_size_height / patch_size;
     const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
-    const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 3 : num_positions;
+    const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
     const int hidden_size = hparams.hidden_size;
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
     int n_layer = hparams.n_layer;
     const float eps = hparams.eps;
-    int mrope_sections[3] = {d_head/4, d_head/4, 0};
+    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};

     const int batch_size = imgs->size;

@@ -734,7 +734,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     }

     // loop over layers
-    if (ctx->has_minicpmv_projector) {
+    if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
+        // TODO: figure out why we are doing things this way ???
         n_layer += 1;
     }
     for (int il = 0; il < n_layer - 1; il++) {
@@ -829,6 +830,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = cur;
     }

+
     // ggml_build_forward_expand(gf, embeddings);
     // ggml_free(ctx0);

@@ -2583,16 +2585,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         const int ph = image_size_height / patch_size;
         int* positions_data = (int*)malloc(ggml_nbytes(positions));

-        int ptr = -1;
+        int ptr = 0;
         for (size_t y = 0; y < ph; y+=2) {
             for (size_t x = 0; x < pw; x+=2) {
                 for (size_t dy = 0; dy < 2; dy++) {
                     for (size_t dx = 0; dx < 2; dx++) {
-                        positions_data[ptr++]                 = y + dy;
+                        positions_data[ptr]                   = y + dy;
                         positions_data[num_patches + ptr]     = x + dx;
-                        positions_data[num_patches * 2 + ptr] = 0;
+                        positions_data[num_patches * 2 + ptr] = y + dy;
+                        positions_data[num_patches * 3 + ptr] = x + dx;
+                        ptr++;
                     }
                 }
             }
         }
@@ -2824,4 +2828,8 @@ bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, i
     // ctx->vision_model.hparams.image_size = h;
     clip_image_encode(ctx, n_threads, &clip_img, vec);
     return true;
+}
+
+void tmp_clip_set_layers (struct clip_ctx * ctx, int layers) {
+    ctx->vision_model.hparams.n_layer = layers;
 }
\ No newline at end of file
diff --git a/examples/llava/clip.h b/examples/llava/clip.h
index 9f75c67db..750a0438e 100644
--- a/examples/llava/clip.h
+++ b/examples/llava/clip.h
@@ -92,6 +92,8 @@ CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
 CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);

 CLIP_API bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
+CLIP_API void tmp_clip_set_layers (struct clip_ctx * ctx, int layers);
+
 #ifdef __cplusplus
 }
 #endif
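Note on the clip.cpp hunks above: the Qwen2-VL merger path now allocates four position ids per patch instead of three, splits the head dimension into four equal M-RoPE sections (d_head/4 each), and fills the two extra streams with a repeat of the (y, x) patch coordinates. A minimal standalone sketch of the layout the rewritten loop produces; the helper name and the std::vector return type are illustrative, not part of the patch:

    #include <vector>

    // Four position-id streams, each num_patches long; patches are visited in
    // 2x2 merge blocks and each patch stores (y, x, y, x) across the streams.
    static std::vector<int> build_vision_pos_ids(int pw, int ph) {
        const int num_patches = pw * ph;
        std::vector<int> pos(num_patches * 4);
        int ptr = 0;
        for (int y = 0; y < ph; y += 2) {
            for (int x = 0; x < pw; x += 2) {
                for (int dy = 0; dy < 2; dy++) {
                    for (int dx = 0; dx < 2; dx++) {
                        pos[ptr]                   = y + dy; // section 0
                        pos[num_patches     + ptr] = x + dx; // section 1
                        pos[num_patches * 2 + ptr] = y + dy; // section 2
                        pos[num_patches * 3 + ptr] = x + dx; // section 3
                        ptr++;
                    }
                }
            }
        }
        return pos;
    }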
diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp
index 1d35bf858..cfd6673b7 100644
--- a/examples/llava/qwen2vl-cli.cpp
+++ b/examples/llava/qwen2vl-cli.cpp
@@ -24,7 +24,8 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
     const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
     const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
     auto img_tokens = image_embed->n_image_pos;
-    llama_pos mrope_pos[img_tokens * 3];
+    llama_pos mrope_pos[img_tokens * 4];
+
     for (size_t y = 0; y < ph; y++)
     {
         for (size_t x = 0; x < pw; x++)
@@ -33,6 +34,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
             mrope_pos[i] = *st_pos_id;
             mrope_pos[i + img_tokens] = *st_pos_id + y;
             mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
+            mrope_pos[i + img_tokens * 3] = 0;
         }
     }
     *st_pos_id += std::max(pw, ph);
@@ -44,10 +46,11 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
             n_eval = n_batch;
         }

-        llama_pos batch_mrope_pos[n_eval * 3];
+        llama_pos batch_mrope_pos[n_eval * 4];
         memcpy(batch_mrope_pos, &mrope_pos[processed], n_eval * sizeof(llama_pos));
-        memcpy(&batch_mrope_pos[n_eval], &mrope_pos[img_tokens + processed], n_eval * sizeof(llama_pos));
+        memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos));
         memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
+        memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));

         llama_batch batch = {
             int32_t(n_eval),                // n_tokens
@@ -82,7 +85,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token
[...]
     memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));

-    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 30 * 3);
+    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 30 * 4);
     ggml_set_name(pos, "pos");
     ggml_set_input(pos);
     std::vector<int> pos_id;
-    pos_id.resize(90);
-    for (int i = 0; i < 30; i ++) pos_id[i] = i;
-    for (int i = 30; i < 60; i ++) pos_id[i] = i - 30;
-    for (int i = 60; i < 90; i ++) pos_id[i] = i - 0;
-    memcpy(pos->data, pos_id.data(), 90 * ggml_element_size(pos));
+    pos_id.resize(30 * 4);
+    for (int i = 0; i < 30; i ++) {
+        pos_id[i] = i;
+        pos_id[i + 30] = i + 10;
+        pos_id[i + 60] = i + 10;
+        pos_id[i + 90] = i + 10;
+    }
+
+    memcpy(pos->data, pos_id.data(), 30 * 4 * ggml_element_size(pos));

-    int sections[3] = {32, 32, 0};
+    int sections[4] = {32, 32, 32, 32};
     auto encode = ggml_mrope_ext(
         ctx0, inp_raw, pos, nullptr,
         128/2, sections, LLAMA_ROPE_TYPE_NEOX,
         32768, 1000000, 1,
@@ -717,15 +725,16 @@ static void tmp_dump_img_embed(struct llava_context * ctx_llava, gpt_params * pa
     int ne = n_embd * 4;
     float vals[56 * 56 * 3];
     float embd[ne];
-    for (int i = 0; i < 3*56*56; i++)
-    {
-        vals[i] = 0.1;
-    }
-    // for (int i = 0; i < 56*56; i++)
+    // for (int i = 0; i < 3*56*56; i++)
     // {
-    //     for (int c = 0; c < 3; c++)
-    //         vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+    //     vals[i] = 0.1;
     // }
+    for (int i = 0; i < 56*56; i++)
+    {
+        for (int c = 0; c < 3; c++)
+            vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+    }
+
     // auto param = &ctx_llava->ctx_clip->vision_model.hparams;

     tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);
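Note on the qwen2vl_eval_image_embed hunks above: the decoder side now also carries four position streams per token; the fourth stream is simply zeroed for image tokens, because only the vision encoder's RoPE consumes it. A short sketch of the resulting layout for a ph x pw grid of image tokens; the helper and the std::vector wrapper are illustrative (llama_pos is int32_t in llama.h):

    #include <cstdint>
    #include <vector>

    typedef int32_t llama_pos; // stand-in for the llama.h typedef

    // st_pos_id is the running text position at which the image is spliced in.
    static std::vector<llama_pos> image_mrope_pos(int ph, int pw, llama_pos st_pos_id) {
        const int img_tokens = ph * pw;
        std::vector<llama_pos> mrope_pos(img_tokens * 4);
        for (int y = 0; y < ph; y++) {
            for (int x = 0; x < pw; x++) {
                const int i = y * pw + x;
                mrope_pos[i]                  = st_pos_id;     // temporal: constant per image
                mrope_pos[i + img_tokens]     = st_pos_id + y; // height
                mrope_pos[i + img_tokens * 2] = st_pos_id + x; // width
                mrope_pos[i + img_tokens * 3] = 0;             // extra stream, unused by the LLM
            }
        }
        return mrope_pos;
    }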
@@ -760,25 +769,30 @@ static void tmp_dump_img_embed_from_file(struct llava_context * ctx_llava, gpt_p
 }

 static void tmp_dump_img_mid_embed(struct llava_context * ctx_llava, gpt_params * params) {
+    int layers = 2;
     // auto * image_embed = load_image(ctx_llava, params, "/home/ron/Downloads/gguf/dog.jpeg");
     int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama));
     // int ne = n_embd * image_embed->n_image_pos;
     int ne = 1280 * 4 * 4;
     float vals[56 * 56 * 3];
     float embd[ne];
-    for (int i = 0; i < 3*56*56; i++)
-    {
-        vals[i] = 0.1;
-    }
-    // for (int i = 0; i < 56*56; i++)
+
+    // for (int i = 0; i < 3*56*56; i++)
     // {
-    //     for (int c = 0; c < 3; c++)
-    //         vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+    //     vals[i] = 0.5;
     // }
+    for (int i = 0; i < 56*56; i++)
+    {
+        for (int c = 0; c < 3; c++)
+            vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+    }
     // auto param = &ctx_llava->ctx_clip->vision_model.hparams;
+
+
+    // tmp_clip_set_layers(ctx_llava->ctx_clip, layers);
     tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);

-    std::ofstream outFile("img_layer_1_embed.bin", std::ios::binary);
+    std::ofstream outFile("img_layer_" + std::to_string(layers) + "_embed.bin", std::ios::binary);
     if (outFile.is_open()) {
         outFile.write(reinterpret_cast<const char *>(embd), ne * sizeof(float));
@@ -819,6 +833,33 @@ static void tmp_dump_patch_embed(struct llava_context * ctx_llava, gpt_params *
     }
 }

+
+static llava_image_embed * tmp_load_img_embed() {
+    std::ifstream inputFile("/home/ron/Projects/llm2vec/hf_img_embed_f.bin", std::ios::binary);
+
+    if (!inputFile) {
+        std::cerr << "Could not open the file!" << std::endl;
+        return NULL;
+    }
+
+    // Determine the size of the file
+    inputFile.seekg(0, std::ios::end);
+    std::streamsize fileSize = inputFile.tellg();
+    inputFile.seekg(0, std::ios::beg);
+
+    static llava_image_embed * result = (llava_image_embed*)malloc(sizeof(llava_image_embed));
+    result->embed = (float*)malloc(fileSize);
+    result->n_image_pos = 24 * 36 / 4;
+
+    // Assuming the binary file contains floating-point numbers (float)
+    std::size_t numElements = fileSize / sizeof(float);
+    inputFile.read(reinterpret_cast<char *>(result->embed), fileSize);
+    inputFile.close();
+
+    return result;
+}
+
+
 /* ----------------------------------------------------------------------------------------------------------------- */
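Note on tmp_load_img_embed above: n_image_pos is hard-coded to 24 * 36 / 4 because Qwen2-VL's merger fuses every 2x2 block of ViT patches into one LLM token. A sketch of that arithmetic, assuming the model's 14-pixel patch size and the 336 x 504 image used in the commented-out test block in main() below:

    // 336 x 504 pixels -> 24 x 36 patches -> 864 / 4 = 216 LLM tokens
    static int n_image_tokens(int img_h, int img_w, int patch_size /* 14 for Qwen2-VL */) {
        const int ph = img_h / patch_size; // 336 / 14 = 24
        const int pw = img_w / patch_size; // 504 / 14 = 36
        return (ph * pw) / 4;              // 2x2 merge
    }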
@@ -861,6 +902,16 @@ int main(int argc, char ** argv) {
         // This section is for testing LLM parts of the model during development phase!
         auto ctx_llava = llava_init_context(&params, model);

+        // {
+        //     auto img_embed = tmp_load_img_embed();
+        //     struct clip_image_size * load_image_size = clip_image_size_init();
+        //     load_image_size->height = 336;
+        //     load_image_size->width = 504;
+        //     clip_add_load_image_size(ctx_llava->ctx_clip, load_image_size);
+        //     process_prompt(ctx_llava, img_embed, &params, params.prompt);
+        //     llava_image_embed_free(img_embed);
+        // }
+
         // process the prompt
         tmp_dump_img_embed(ctx_llava, &params);
         // tmp_dump_img_embed_from_file(ctx_llava, &params);
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index ff833a4fc..5d01181dc 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1451,7 +1451,7 @@ extern "C" {
             struct ggml_tensor * b,
             struct ggml_tensor * c,
             int n_dims,
-            int sections[3],
+            int sections[4],
             int mode,
             int n_ctx_orig,
             float freq_base,
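Note on the ggml.h change above: ggml_mrope_ext now takes four section sizes, and (per the assert added in ggml.c below) the position tensor must hold 4 ids per token. A sketch of a caller, loosely mirroring the vision test harness in qwen2vl-cli.cpp; the helper name and the trailing YaRN parameters are assumptions (the usual llama.cpp defaults), not taken from the patch:

    #include "ggml.h"

    // q:   tensor whose rows get rotated; pos: I32 tensor with 4 ids per token,
    // so pos->ne[0] == q->ne[2] * 4 (enforced by the new assert in ggml.c).
    static struct ggml_tensor * apply_mrope(struct ggml_context * ctx0,
                                            struct ggml_tensor * q,
                                            struct ggml_tensor * pos) {
        int sections[4] = {32, 32, 32, 32}; // per-stream rotary dims; sum must stay <= row size
        return ggml_mrope_ext(
            ctx0, q, pos, /*freq_factors*/ nullptr,
            /*n_dims*/ 64, sections, GGML_ROPE_TYPE_NEOX,
            /*n_ctx_orig*/ 32768, /*freq_base*/ 1000000.0f, /*freq_scale*/ 1.0f,
            /*ext_factor*/ 0.0f, /*attn_factor*/ 1.0f, /*beta_fast*/ 32.0f, /*beta_slow*/ 1.0f);
    }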
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 98a1110d7..e589d6552 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3559,7 +3559,7 @@ struct ggml_tensor * ggml_mrope_ext(
     struct ggml_tensor * b,
     struct ggml_tensor * c,
     int n_dims,
-    int sections[3],
+    int sections[4],
     int mode,
     int n_ctx_orig,
     float freq_base,
@@ -3573,7 +3573,7 @@ struct ggml_tensor * ggml_mrope_ext(
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
-    GGML_ASSERT(a->ne[2] * 3 == b->ne[0]);
+    GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token

     if (c) {
         GGML_ASSERT(c->type == GGML_TYPE_F32);
@@ -3588,14 +3588,14 @@ struct ggml_tensor * ggml_mrope_ext(

     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

-    int32_t params[11 + 3] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+    int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
     memcpy(params +  5, &freq_base,    sizeof(float));
     memcpy(params +  6, &freq_scale,   sizeof(float));
     memcpy(params +  7, &ext_factor,   sizeof(float));
     memcpy(params +  8, &attn_factor,  sizeof(float));
     memcpy(params +  9, &beta_fast,    sizeof(float));
     memcpy(params + 10, &beta_slow,    sizeof(float));
-    memcpy(&params[11], sections, sizeof(int)*3);
+    memcpy(&params[11], sections, sizeof(int)*4);
     // memcpy(params + 11, sections, sizeof(int)*3);
     ggml_set_op_params(result, params, sizeof(params));
@@ -11238,15 +11238,17 @@ static void ggml_rope_cache_init(
 }

 static void ggml_mrope_cache_init(
-    float theta_base_t, float theta_base_h, float theta_base_w, int sections[3], bool indep_sects,
+    float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
     float * cache, float sin_sign, float theta_scale) {
     // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
     float theta_t = theta_base_t;
     float theta_h = theta_base_h;
     float theta_w = theta_base_w;
-    int sect_dims = sections[0] + sections[1] + sections[2];
-    int prev_sector = -1;
+    float theta_e = theta_base_e; // extra position id for vision encoder
+    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+    int sec_w = sections[1] + sections[0];
+    GGML_ASSERT(sect_dims <= ne0);

     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
         const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
@@ -11262,15 +11264,21 @@ static void ggml_mrope_cache_init(
             else if (sector == sections[1]) {
                 theta_w = theta_base_w;
             }
+            else if (sector == sections[2]) {
+                theta_e = theta_base_e;
+            }
         }

         float theta = theta_t;
-        if (sector < sections[1] + sections[0] && sector >= sections[0]) {
+        if (sector >= sections[0] && sector < sec_w) {
             theta = theta_h;
         }
-        else if (sector >= sections[1] + sections[0]) {
+        else if (sector >= sec_w && sector < sec_w + sections[2]) {
             theta = theta_w;
         }
+        else if (sector >= sec_w + sections[2]) {
+            theta = theta_e;
+        }

         rope_yarn(
             theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
@@ -11280,7 +11288,7 @@ static void ggml_mrope_cache_init(
         theta_t *= theta_scale;
         theta_w *= theta_scale;
         theta_h *= theta_scale;
-        prev_sector = sector;
+        theta_e *= theta_scale;
     }
 }

@@ -11304,7 +11312,7 @@ static void ggml_compute_forward_rope_f32(
     const struct ggml_tensor * src2 = dst->src[2];

     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[3];
+    int sections[4];

     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -11318,7 +11326,7 @@ static void ggml_compute_forward_rope_f32(
     memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int) * 3);
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int) * 4);

     GGML_TENSOR_UNARY_OP_LOCALS

@@ -11352,6 +11360,11 @@ static void ggml_compute_forward_rope_f32(

     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = sections[0] > 0 || sections[1] > 0 || sections[2] > 0;
+    const bool is_vision = is_mrope && sections[3] > 0;
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }

     const float * freq_factors = NULL;
     if (src2 != NULL) {
@@ -11379,8 +11392,9 @@ static void ggml_compute_forward_rope_f32(
                 const int64_t p_t = pos[i2];
                 const int64_t p_h = pos[i2 + ne2];
                 const int64_t p_w = pos[i2 + ne2 * 2];
+                const int64_t p_e = pos[i2 + ne2 * 3];
                 ggml_mrope_cache_init(
-                    p_t, p_h, p_w, sections, sections[2] == 0,
+                    p_t, p_h, p_w, p_e, sections, sections[3] != 0,
                     freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }
@@ -11402,6 +11416,22 @@ static void ggml_compute_forward_rope_f32(
                         dst_data[0] = x0*cos_theta - x1*sin_theta;
                         dst_data[1] = x0*sin_theta + x1*cos_theta;
                     }
+                } else if (is_vision) {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims];
+
+                        dst_data[0]      = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
+                    }
                 } else {
                     for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                         const int64_t ic = i0/2;
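Note on the cache-init hunks above: with sections = {s0, s1, s2, s3}, each rotary pair k = i0/2 is assigned one of the four position streams by its sector index. The selection reduces to the sketch below; the helper name is illustrative, and the real loop additionally resets each theta at section boundaries when indep_sects is set and multiplies all four thetas by theta_scale after every pair:

    static float pick_theta(int k, const int sections[4],
                            float theta_t, float theta_h, float theta_w, float theta_e) {
        const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
        const int sector    = k % sect_dims;
        const int sec_w     = sections[0] + sections[1];
        if (sector < sections[0])          return theta_t; // temporal
        if (sector < sec_w)                return theta_h; // height
        if (sector < sec_w + sections[2])  return theta_w; // width
        return theta_e;                                    // extra (vision)
    }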
@@ -11420,12 +11450,21 @@ static void ggml_compute_forward_rope_f32(
                     }
                 }

-                if (is_mrope) {
-                    // fill the remain channels by repeating 0~n_dims channel
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 ++) {
-                        float * dst_data_0 = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
-                        float * dst_data   = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-                        dst_data[0] = dst_data_0[i0 % n_dims];
+                if (is_vision) {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims];
+
+                        dst_data[0]      = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
                     }
                 } else {
diff --git a/src/llama.cpp b/src/llama.cpp
index aa09dc98c..fdc93dcd6 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -12525,14 +12525,14 @@ struct llm_build_context {

         // inp_pos - contains the positions
         // struct ggml_tensor * inp_pos = build_inp_pos();
-        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 3);
+        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4);
        cb(lctx.inp_pos, "inp_pos", -1);
         ggml_set_input(lctx.inp_pos);
         struct ggml_tensor * inp_pos = lctx.inp_pos;

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-        int sections[3] = {16, 24, 24}; // TODO: move this into gguf model file.
+        int sections[4] = {16, 24, 24, 0}; // TODO: move this into gguf model file.

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -16878,7 +16878,7 @@ static struct ggml_cgraph * llama_build_graph(
             } break;
         case LLM_ARCH_QWEN2VL:
             {
-                lctx.n_pos_per_token = 3;
+                lctx.n_pos_per_token = 4;
                 result = llm.build_qwen2vl();
             } break;
         case LLM_ARCH_QWEN2MOE:
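Note on the is_vision branches above: in vision mode n_dims == ne0/2, and the two loops together rotate every channel pair (ic, ic + n_dims) across the whole head dimension, instead of rotating the first half and copying it into the second as the old is_mrope fill did. Condensed into one standalone sketch over a single row; the helper name is illustrative:

    // src, dst: rows of ne0 = 2 * n_dims floats; cache: ne0 interleaved cos/sin values.
    static void rope_vision_row(const float * src, float * dst,
                                const float * cache, int n_dims) {
        for (int i0 = 0; i0 < 2 * n_dims; i0 += 2) {
            const int   ic        = i0 / 2;
            const float cos_theta = cache[i0 + 0];
            const float sin_theta = cache[i0 + 1];
            const float x0 = src[ic];
            const float x1 = src[ic + n_dims];
            dst[ic]          = x0 * cos_theta - x1 * sin_theta;
            dst[ic + n_dims] = x0 * sin_theta + x1 * cos_theta;
        }
    }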