correct vision-rope behavior, add the missing last layer back to ViT

commit 023f0076e0
parent bcd49f5984
HimariO committed 2024-10-20 21:42:53 +08:00

6 changed files with 155 additions and 55 deletions

examples/llava/clip.cpp

@@ -623,13 +623,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     const int patches_w = image_size_width / patch_size;
     const int patches_h = image_size_height / patch_size;
     const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
-    const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 3 : num_positions;
+    const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
     const int hidden_size = hparams.hidden_size;
     const int n_head = hparams.n_head;
     const int d_head = hidden_size / n_head;
     int n_layer = hparams.n_layer;
     const float eps = hparams.eps;
-    int mrope_sections[3] = {d_head/4, d_head/4, 0};
+    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
     const int batch_size = imgs->size;
@@ -734,7 +734,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     }

     // loop over layers
-    if (ctx->has_minicpmv_projector) {
+    if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
+        // TODO: figure out why we doing thing in this way ???
         n_layer += 1;
     }
     for (int il = 0; il < n_layer - 1; il++) {
@@ -829,6 +830,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = cur;
     }

+    // ggml_build_forward_expand(gf, embeddings);
     // ggml_free(ctx0);
@@ -2583,16 +2585,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         const int ph = image_size_height / patch_size;
         int* positions_data = (int*)malloc(ggml_nbytes(positions));

-        int ptr = -1;
+        int ptr = 0;
         for (size_t y = 0; y < ph; y+=2)
         {
             for (size_t x = 0; x < pw; x+=2)
             {
                 for (size_t dy = 0; dy < 2; dy++) {
                     for (size_t dx = 0; dx < 2; dx++) {
-                        positions_data[ptr++] = y + dy;
+                        positions_data[ptr] = y + dy;
                         positions_data[num_patches + ptr] = x + dx;
-                        positions_data[num_patches * 2 + ptr] = 0;
+                        positions_data[num_patches * 2 + ptr] = y + dy;
+                        positions_data[num_patches * 3 + ptr] = x + dx;
+                        ptr++;
                     }
                 }
             }
@@ -2824,4 +2828,8 @@ bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, i
     // ctx->vision_model.hparams.image_size = h;
     clip_image_encode(ctx, n_threads, &clip_img, vec);
     return true;
 }
+
+void tmp_clip_set_layers (struct clip_ctx * ctx, int layers) {
+    ctx->vision_model.hparams.n_layer = layers;
+}
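With these changes the vision tower gets four position ids per patch instead of three, written as four planes of num_patches entries in the 2x2-merger traversal order: (row, column, row, column). Below is a minimal standalone sketch of the layout the new loop produces; the grid size is an assumed example, not from the commit:

    // Sketch of the 4-plane position-id layout for a pw x ph patch grid.
    // pw and ph are illustrative assumptions.
    #include <cstdio>
    #include <vector>

    int main() {
        const int pw = 4, ph = 4;
        const int num_patches = pw * ph;
        std::vector<int> positions(num_patches * 4);
        int ptr = 0;
        for (int y = 0; y < ph; y += 2)
            for (int x = 0; x < pw; x += 2)
                for (int dy = 0; dy < 2; dy++)
                    for (int dx = 0; dx < 2; dx++) {
                        positions[ptr]                   = y + dy; // plane 0: row
                        positions[num_patches + ptr]     = x + dx; // plane 1: column
                        positions[num_patches * 2 + ptr] = y + dy; // plane 2: row again
                        positions[num_patches * 3 + ptr] = x + dx; // plane 3: column again
                        ptr++;
                    }
        for (int i = 0; i < num_patches; i++)
            printf("token %2d: (%d, %d, %d, %d)\n", i,
                   positions[i], positions[num_patches + i],
                   positions[num_patches * 2 + i], positions[num_patches * 3 + i]);
    }

Note the traversal order: tokens are emitted per 2x2 block, matching the merger that fuses four patches into one embedding, so the four patches of a block sit next to each other in the position buffer.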

examples/llava/clip.h

@@ -92,6 +92,8 @@ CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
 CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);

 CLIP_API bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
+CLIP_API void tmp_clip_set_layers (struct clip_ctx * ctx, int layers);
+
 #ifdef __cplusplus
 }
 #endif

examples/llava/qwen2vl-cli.cpp

@@ -24,7 +24,8 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
     const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
     const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
     auto img_tokens = image_embed->n_image_pos;
-    llama_pos mrope_pos[img_tokens * 3];
+    llama_pos mrope_pos[img_tokens * 4];
     for (size_t y = 0; y < ph; y++)
     {
         for (size_t x = 0; x < pw; x++)
@@ -33,6 +34,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
             mrope_pos[i] = *st_pos_id;
             mrope_pos[i + img_tokens] = *st_pos_id + y;
             mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
+            mrope_pos[i + img_tokens * 3] = 0;
         }
     }
     *st_pos_id += std::max(pw, ph);
@@ -44,10 +46,11 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
             n_eval = n_batch;
         }

-        llama_pos batch_mrope_pos[n_eval * 3];
+        llama_pos batch_mrope_pos[n_eval * 4];
         memcpy(batch_mrope_pos, &mrope_pos[processed], n_eval * sizeof(llama_pos));
-        memcpy(&batch_mrope_pos[n_eval], &mrope_pos[img_tokens + processed], n_eval * sizeof(llama_pos));
+        memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos));
         memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
+        memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));

         llama_batch batch = {
             int32_t(n_eval), // n_tokens
@@ -82,7 +85,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         }
         auto batch = llama_batch_get_one(&tokens[i], n_eval, *n_past, 0);
         // TODO: add mrope pos ids somewhere else
-        pos.resize(batch.n_tokens * 3);
+        pos.resize(batch.n_tokens * 4);
+        std::fill(pos.begin(), pos.end(), 0);
         for (int j = 0; j < batch.n_tokens * 3; j ++) {
             pos[j] = *st_pos_id + (j % batch.n_tokens);
         }
@@ -670,18 +674,22 @@ static void tmp_test_mrope_2d(struct llava_context * ctx_llava, gpt_params * par
     std::fill(dummy_q.begin(), dummy_q.end(), 0.1);
     memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));

-    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 30 * 3);
+    struct ggml_tensor * pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 30 * 4);
     ggml_set_name(pos, "pos");
     ggml_set_input(pos);
     std::vector<int> pos_id;
-    pos_id.resize(90);
-    for (int i = 0; i < 30; i ++) pos_id[i] = i;
-    for (int i = 30; i < 60; i ++) pos_id[i] = i - 30;
-    for (int i = 60; i < 90; i ++) pos_id[i] = i - 0;
-    memcpy(pos->data, pos_id.data(), 90 * ggml_element_size(pos));
+    pos_id.resize(30 * 4);
+    for (int i = 0; i < 30; i ++) {
+        pos_id[i] = i;
+        pos_id[i + 30] = i + 10;
+        pos_id[i + 60] = i + 10;
+        pos_id[i + 90] = i + 10;
+    }
+    memcpy(pos->data, pos_id.data(), 30 * 4 * ggml_element_size(pos));

-    int sections[3] = {32, 32, 0};
+    int sections[4] = {32, 32, 32, 32};
     auto encode = ggml_mrope_ext(
         ctx0, inp_raw, pos, nullptr,
         128/2, sections, LLAMA_ROPE_TYPE_NEOX, 32768, 1000000, 1,
@@ -717,15 +725,16 @@ static void tmp_dump_img_embed(struct llava_context * ctx_llava, gpt_params * pa
     int ne = n_embd * 4;
     float vals[56 * 56 * 3];
     float embd[ne];
-    for (int i = 0; i < 3*56*56; i++)
-    {
-        vals[i] = 0.1;
-    }
-    // for (int i = 0; i < 56*56; i++)
+    // for (int i = 0; i < 3*56*56; i++)
     // {
-    //     for (int c = 0; c < 3; c++)
-    //         vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+    //     vals[i] = 0.1;
     // }
+    for (int i = 0; i < 56*56; i++)
+    {
+        for (int c = 0; c < 3; c++)
+            vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+    }

     // auto param = &ctx_llava->ctx_clip->vision_model.hparams;
     tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);
@@ -760,25 +769,30 @@ static void tmp_dump_img_embed_from_file(struct llava_context * ctx_llava, gpt_p
 }

 static void tmp_dump_img_mid_embed(struct llava_context * ctx_llava, gpt_params * params) {
+    int layers = 2;
     // auto * image_embed = load_image(ctx_llava, params, "/home/ron/Downloads/gguf/dog.jpeg");
     int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama));
     // int ne = n_embd * image_embed->n_image_pos;
     int ne = 1280 * 4 * 4;
     float vals[56 * 56 * 3];
     float embd[ne];
-    for (int i = 0; i < 3*56*56; i++)
-    {
-        vals[i] = 0.1;
-    }
-    // for (int i = 0; i < 56*56; i++)
+
+    // for (int i = 0; i < 3*56*56; i++)
     // {
-    //     for (int c = 0; c < 3; c++)
-    //         vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+    //     vals[i] = 0.5;
     // }
+    for (int i = 0; i < 56*56; i++)
+    {
+        for (int c = 0; c < 3; c++)
+            vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
+    }

     // auto param = &ctx_llava->ctx_clip->vision_model.hparams;
+    // tmp_clip_set_layers(ctx_llava->ctx_clip, layers);
     tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);

-    std::ofstream outFile("img_layer_1_embed.bin", std::ios::binary);
+    std::ofstream outFile("img_layer_" + std::to_string(layers) + "_embed.bin", std::ios::binary);
     if (outFile.is_open()) {
         outFile.write(reinterpret_cast<const char*>(embd), ne * sizeof(float));
@@ -819,6 +833,33 @@ static void tmp_dump_patch_embed(struct llava_context * ctx_llava, gpt_params *
     }
 }

+static llava_image_embed * tmp_load_img_embed() {
+    std::ifstream inputFile("/home/ron/Projects/llm2vec/hf_img_embed_f.bin", std::ios::binary);
+    if (!inputFile) {
+        std::cerr << "Could not open the file!" << std::endl;
+        return NULL;
+    }
+
+    // Determine the size of the file
+    inputFile.seekg(0, std::ios::end);
+    std::streamsize fileSize = inputFile.tellg();
+    inputFile.seekg(0, std::ios::beg);
+
+    static llava_image_embed * result = (llava_image_embed*)malloc(sizeof(llava_image_embed));
+    result->embed = (float*)malloc(fileSize);
+    result->n_image_pos = 24 * 36 / 4;
+
+    // Assuming the binary file contains floating-point numbers (float)
+    std::size_t numElements = fileSize / sizeof(float);
+
+    inputFile.read(reinterpret_cast<char*>(result->embed), fileSize);
+    inputFile.close();
+    return result;
+}
+
 /*
 -----------------------------------------------------------------------------------------------------------------
 */
@@ -861,6 +902,16 @@ int main(int argc, char ** argv) {
         // This section is for testing LLM parts of the model during development phase!
         auto ctx_llava = llava_init_context(&params, model);

+        // {
+        //     auto img_embed = tmp_load_img_embed();
+        //     struct clip_image_size * load_image_size = clip_image_size_init();
+        //     load_image_size->height = 336;
+        //     load_image_size->width = 504;
+        //     clip_add_load_image_size(ctx_llava->ctx_clip, load_image_size);
+        //     process_prompt(ctx_llava, img_embed, &params, params.prompt);
+        //     llava_image_embed_free(img_embed);
+        // }
+
         // process the prompt
         tmp_dump_img_embed(ctx_llava, &params);
         // tmp_dump_img_embed_from_file(ctx_llava, &params);
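On the LLM side, each image token now carries four llama_pos planes: a constant temporal id, the row offset, the column offset, and a zeroed fourth plane that only the vision encoder's extra section uses. A small self-contained sketch of the layout built by qwen2vl_eval_image_embed, where ph, pw, and st_pos_id are assumed example values:

    // Sketch of the 4-plane M-RoPE position buffer for image tokens.
    #include <cstdio>
    #include <vector>

    int main() {
        const int ph = 2, pw = 3;     // merged patch grid, assumed
        const int st_pos_id = 5;      // position of the first image token, assumed
        const int img_tokens = ph * pw;
        std::vector<int> mrope_pos(img_tokens * 4, 0);
        for (int y = 0; y < ph; y++)
            for (int x = 0; x < pw; x++) {
                const int i = y * pw + x;
                mrope_pos[i]                  = st_pos_id;     // temporal
                mrope_pos[i + img_tokens]     = st_pos_id + y; // height
                mrope_pos[i + img_tokens * 2] = st_pos_id + x; // width
                mrope_pos[i + img_tokens * 3] = 0;             // unused 4th plane
            }
        for (int i = 0; i < img_tokens; i++)
            printf("token %d: (%d, %d, %d, %d)\n", i,
                   mrope_pos[i], mrope_pos[i + img_tokens],
                   mrope_pos[i + img_tokens * 2], mrope_pos[i + img_tokens * 3]);
    }

Plain text tokens fill all active planes with the same running position id, which is why eval_tokens can zero the buffer and then write the first three planes identically.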

ggml/include/ggml.h

@@ -1451,7 +1451,7 @@ extern "C" {
             struct ggml_tensor * b,
             struct ggml_tensor * c,
             int n_dims,
-            int sections[3],
+            int sections[4],
             int mode,
             int n_ctx_orig,
             float freq_base,

ggml/src/ggml.c

@@ -3559,7 +3559,7 @@ struct ggml_tensor * ggml_mrope_ext(
         struct ggml_tensor * b,
         struct ggml_tensor * c,
         int n_dims,
-        int sections[3],
+        int sections[4],
         int mode,
         int n_ctx_orig,
         float freq_base,
@@ -3573,7 +3573,7 @@ struct ggml_tensor * ggml_mrope_ext(
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
-    GGML_ASSERT(a->ne[2] * 3 == b->ne[0]);
+    GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token

     if (c) {
         GGML_ASSERT(c->type == GGML_TYPE_F32);
@@ -3588,14 +3588,14 @@ struct ggml_tensor * ggml_mrope_ext(
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

-    int32_t params[11 + 3] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+    int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
     memcpy(params + 5, &freq_base, sizeof(float));
     memcpy(params + 6, &freq_scale, sizeof(float));
     memcpy(params + 7, &ext_factor, sizeof(float));
     memcpy(params + 8, &attn_factor, sizeof(float));
     memcpy(params + 9, &beta_fast, sizeof(float));
     memcpy(params + 10, &beta_slow, sizeof(float));
-    memcpy(&params[11], sections, sizeof(int)*3);
+    memcpy(&params[11], sections, sizeof(int)*4);
     // memcpy(params + 11, sections, sizeof(int)*3);

     ggml_set_op_params(result, params, sizeof(params));
@@ -11238,15 +11238,17 @@ static void ggml_rope_cache_init(
 }

 static void ggml_mrope_cache_init(
-    float theta_base_t, float theta_base_h, float theta_base_w, int sections[3], bool indep_sects,
+    float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[3], bool indep_sects,
     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
     float * cache, float sin_sign, float theta_scale) {
     // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
     float theta_t = theta_base_t;
     float theta_h = theta_base_h;
     float theta_w = theta_base_w;
-    int sect_dims = sections[0] + sections[1] + sections[2];
-    int prev_sector = -1;
+    float theta_e = theta_base_e; // extra position id for vision encoder
+    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+    int sec_w = sections[1] + sections[0];
+    GGML_ASSERT(sect_dims <= ne0);

     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
         const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
@@ -11262,15 +11264,21 @@ static void ggml_mrope_cache_init(
             else if (sector == sections[1]) {
                 theta_w = theta_base_w;
             }
+            else if (sector == sections[2]) {
+                theta_e = theta_base_e;
+            }
         }

         float theta = theta_t;
-        if (sector < sections[1] + sections[0] && sector >= sections[0]) {
+        if (sector >= sections[0] && sector < sec_w) {
             theta = theta_h;
         }
-        else if (sector >= sections[1] + sections[0]) {
+        else if (sector >= sec_w && sector < sec_w + sections[2]) {
             theta = theta_w;
         }
+        else if (sector >= sec_w + sections[2]) {
+            theta = theta_e;
+        }

         rope_yarn(
             theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
@@ -11280,7 +11288,7 @@ static void ggml_mrope_cache_init(
         theta_t *= theta_scale;
         theta_w *= theta_scale;
         theta_h *= theta_scale;
-        prev_sector = sector;
+        theta_e *= theta_scale;
     }
 }
@@ -11304,7 +11312,7 @@ static void ggml_compute_forward_rope_f32(
     const struct ggml_tensor * src2 = dst->src[2];

     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[3];
+    int sections[4];

     //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -11318,7 +11326,7 @@ static void ggml_compute_forward_rope_f32(
     memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 3);
+    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 4);

     GGML_TENSOR_UNARY_OP_LOCALS
@@ -11352,6 +11360,11 @@ static void ggml_compute_forward_rope_f32(
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = sections[0] > 0 || sections[1] > 0 || sections[2] > 0;
+    const bool is_vision = is_mrope && sections[3] > 0;
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }

     const float * freq_factors = NULL;
     if (src2 != NULL) {
@@ -11379,8 +11392,9 @@ static void ggml_compute_forward_rope_f32(
             const int64_t p_t = pos[i2];
             const int64_t p_h = pos[i2 + ne2];
             const int64_t p_w = pos[i2 + ne2 * 2];
+            const int64_t p_e = pos[i2 + ne2 * 3];
             ggml_mrope_cache_init(
-                p_t, p_h, p_w, sections, sections[2] == 0,
+                p_t, p_h, p_w, p_e, sections, sections[3] != 0,
                 freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
         }
@@ -11402,6 +11416,22 @@ static void ggml_compute_forward_rope_f32(
                     dst_data[0] = x0*cos_theta - x1*sin_theta;
                     dst_data[1] = x0*sin_theta + x1*cos_theta;
                 }
+            } else if (is_vision) {
+                for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                    const int64_t ic = i0/2;
+
+                    const float cos_theta = cache[i0 + 0];
+                    const float sin_theta = cache[i0 + 1];
+
+                    const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                    float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+                    const float x0 = src[0];
+                    const float x1 = src[n_dims];
+
+                    dst_data[0] = x0*cos_theta - x1*sin_theta;
+                    dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
+                }
             } else {
                 for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
                     const int64_t ic = i0/2;
@@ -11420,12 +11450,21 @@ static void ggml_compute_forward_rope_f32(
                 }
             }

-            if (is_mrope) {
-                // fill the remain channels by repeating 0~n_dims channel
-                for (int64_t i0 = n_dims; i0 < ne0; i0 ++) {
-                    float * dst_data_0 = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
-                    float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                    dst_data[0] = dst_data_0[i0 % n_dims];
+            if (is_vision) {
+                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                    const int64_t ic = i0/2;
+
+                    const float cos_theta = cache[i0 + 0];
+                    const float sin_theta = cache[i0 + 1];
+
+                    const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                    float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+
+                    const float x0 = src[0];
+                    const float x1 = src[n_dims];
+
+                    dst_data[0] = x0*cos_theta - x1*sin_theta;
+                    dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
                 }
             }
             else {
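The reworked ggml_mrope_cache_init maps each even channel index to a sector, sector = (i0/2) % sect_dims, and the sector selects which positional stream (temporal, height, width, or the new extra stream) rotates that channel pair; with the fourth section set to zero the extra stream is never reached. A standalone sketch of that mapping under the text-side sections {16, 24, 24, 0}, where ne0 = 128 is an assumed head size:

    // Sketch: which positional stream rotates each channel pair.
    #include <cstdio>

    int main() {
        const int sections[4] = {16, 24, 24, 0};
        const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
        const int sec_w = sections[0] + sections[1];
        const int ne0 = 128; // rotated dimensions per head, assumed
        for (int i0 = 0; i0 < ne0; i0 += 2) {
            const int sector = (i0 / 2) % sect_dims;
            const char * stream = sector < sections[0]         ? "t"
                                : sector < sec_w               ? "h"
                                : sector < sec_w + sections[2] ? "w"
                                :                                "e";
            printf("i0=%3d sector=%2d stream=%s\n", i0, sector, stream);
        }
    }

In the vision path the rotation also changes shape: instead of pairing neighbours (x[ic], x[ic+1]), each channel ic is paired with its counterpart n_dims away (x[ic], x[ic+n_dims]), which is why is_vision asserts n_dims == ne0/2 and why the tail loop now rotates the second half instead of copying the first half into it.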

src/llama.cpp

@@ -12525,14 +12525,14 @@ struct llm_build_context {
         // inp_pos - contains the positions
         // struct ggml_tensor * inp_pos = build_inp_pos();
-        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 3);
+        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4);
         cb(lctx.inp_pos, "inp_pos", -1);
         ggml_set_input(lctx.inp_pos);
         struct ggml_tensor * inp_pos = lctx.inp_pos;

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
-        int sections[3] = {16, 24, 24}; // TODO: move this into gguf model file.
+        int sections[4] = {16, 24, 24, 0}; // TODO: move this into gguf model file.

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -16878,7 +16878,7 @@ static struct ggml_cgraph * llama_build_graph(
             } break;
         case LLM_ARCH_QWEN2VL:
             {
-                lctx.n_pos_per_token = 3;
+                lctx.n_pos_per_token = 4;
                 result = llm.build_qwen2vl();
             } break;
         case LLM_ARCH_QWEN2MOE:
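Note that the hard-coded Qwen2-VL sections keep the fourth entry at zero, so sections[3] > 0 stays false and the text decoder never takes the is_vision path; the three active sections together cover all 64 channel-pair sectors of a 128-dimension head. A quick sanity check of that arithmetic, with d_head = 128 taken as an assumed model value:

    // Sketch: the LLM-side sections must cover exactly n_dims/2 sectors.
    #include <cassert>
    #include <cstdio>

    int main() {
        const int d_head = 128;                  // Qwen2-VL head size, assumed
        const int sections[4] = {16, 24, 24, 0}; // as hard-coded in the diff
        const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
        assert(sect_dims == d_head / 2);         // one entry per rotated channel pair
        printf("sect_dims = %d, channel pairs = %d\n", sect_dims, d_head / 2);
    }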