correcting vision-rope behavior, add the missing last layer back to ViT
This commit is contained in:
parent
bcd49f5984
commit
023f0076e0
6 changed files with 155 additions and 55 deletions
|
@ -623,13 +623,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||||
const int patches_w = image_size_width / patch_size;
|
const int patches_w = image_size_width / patch_size;
|
||||||
const int patches_h = image_size_height / patch_size;
|
const int patches_h = image_size_height / patch_size;
|
||||||
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
|
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
|
||||||
const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 3 : num_positions;
|
const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
|
||||||
const int hidden_size = hparams.hidden_size;
|
const int hidden_size = hparams.hidden_size;
|
||||||
const int n_head = hparams.n_head;
|
const int n_head = hparams.n_head;
|
||||||
const int d_head = hidden_size / n_head;
|
const int d_head = hidden_size / n_head;
|
||||||
int n_layer = hparams.n_layer;
|
int n_layer = hparams.n_layer;
|
||||||
const float eps = hparams.eps;
|
const float eps = hparams.eps;
|
||||||
int mrope_sections[3] = {d_head/4, d_head/4, 0};
|
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
|
||||||
|
|
||||||
const int batch_size = imgs->size;
|
const int batch_size = imgs->size;
|
||||||
|
|
||||||
|
@ -734,7 +734,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop over layers
|
// loop over layers
|
||||||
if (ctx->has_minicpmv_projector) {
|
if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
|
||||||
|
// TODO: figure out why we doing thing in this way ???
|
||||||
n_layer += 1;
|
n_layer += 1;
|
||||||
}
|
}
|
||||||
for (int il = 0; il < n_layer - 1; il++) {
|
for (int il = 0; il < n_layer - 1; il++) {
|
||||||
|
@ -829,6 +830,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||||
embeddings = cur;
|
embeddings = cur;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_build_forward_expand(gf, embeddings);
|
// ggml_build_forward_expand(gf, embeddings);
|
||||||
// ggml_free(ctx0);
|
// ggml_free(ctx0);
|
||||||
|
|
||||||
|
@ -2583,16 +2585,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||||
const int ph = image_size_height / patch_size;
|
const int ph = image_size_height / patch_size;
|
||||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||||
|
|
||||||
int ptr = -1;
|
int ptr = 0;
|
||||||
for (size_t y = 0; y < ph; y+=2)
|
for (size_t y = 0; y < ph; y+=2)
|
||||||
{
|
{
|
||||||
for (size_t x = 0; x < pw; x+=2)
|
for (size_t x = 0; x < pw; x+=2)
|
||||||
{
|
{
|
||||||
for (size_t dy = 0; dy < 2; dy++) {
|
for (size_t dy = 0; dy < 2; dy++) {
|
||||||
for (size_t dx = 0; dx < 2; dx++) {
|
for (size_t dx = 0; dx < 2; dx++) {
|
||||||
positions_data[ptr++] = y + dy;
|
positions_data[ptr] = y + dy;
|
||||||
positions_data[num_patches + ptr] = x + dx;
|
positions_data[num_patches + ptr] = x + dx;
|
||||||
positions_data[num_patches * 2 + ptr] = 0;
|
positions_data[num_patches * 2 + ptr] = y + dy;
|
||||||
|
positions_data[num_patches * 3 + ptr] = x + dx;
|
||||||
|
ptr++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2824,4 +2828,8 @@ bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, i
|
||||||
// ctx->vision_model.hparams.image_size = h;
|
// ctx->vision_model.hparams.image_size = h;
|
||||||
clip_image_encode(ctx, n_threads, &clip_img, vec);
|
clip_image_encode(ctx, n_threads, &clip_img, vec);
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void tmp_clip_set_layers (struct clip_ctx * ctx, int layers) {
|
||||||
|
ctx->vision_model.hparams.n_layer = layers;
|
||||||
}
|
}
|
|
@ -92,6 +92,8 @@ CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
|
||||||
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
|
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
|
||||||
|
|
||||||
CLIP_API bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
CLIP_API bool tmp_clip_image_encode (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
||||||
|
CLIP_API void tmp_clip_set_layers (struct clip_ctx * ctx, int layers);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -24,7 +24,8 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
|
||||||
const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
|
const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
|
||||||
const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
|
const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
|
||||||
auto img_tokens = image_embed->n_image_pos;
|
auto img_tokens = image_embed->n_image_pos;
|
||||||
llama_pos mrope_pos[img_tokens * 3];
|
llama_pos mrope_pos[img_tokens * 4];
|
||||||
|
|
||||||
for (size_t y = 0; y < ph; y++)
|
for (size_t y = 0; y < ph; y++)
|
||||||
{
|
{
|
||||||
for (size_t x = 0; x < pw; x++)
|
for (size_t x = 0; x < pw; x++)
|
||||||
|
@ -33,6 +34,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
|
||||||
mrope_pos[i] = *st_pos_id;
|
mrope_pos[i] = *st_pos_id;
|
||||||
mrope_pos[i + img_tokens] = *st_pos_id + y;
|
mrope_pos[i + img_tokens] = *st_pos_id + y;
|
||||||
mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
|
mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
|
||||||
|
mrope_pos[i + img_tokens * 3] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*st_pos_id += std::max(pw, ph);
|
*st_pos_id += std::max(pw, ph);
|
||||||
|
@ -44,10 +46,11 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
|
||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_pos batch_mrope_pos[n_eval * 3];
|
llama_pos batch_mrope_pos[n_eval * 4];
|
||||||
memcpy(batch_mrope_pos, &mrope_pos[processed], n_eval * sizeof(llama_pos));
|
memcpy(batch_mrope_pos, &mrope_pos[processed], n_eval * sizeof(llama_pos));
|
||||||
memcpy(&batch_mrope_pos[n_eval], &mrope_pos[img_tokens + processed], n_eval * sizeof(llama_pos));
|
memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos));
|
||||||
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
|
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
|
||||||
|
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
|
||||||
|
|
||||||
llama_batch batch = {
|
llama_batch batch = {
|
||||||
int32_t(n_eval), // n_tokens
|
int32_t(n_eval), // n_tokens
|
||||||
|
@ -82,7 +85,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
|
||||||
}
|
}
|
||||||
auto batch = llama_batch_get_one(&tokens[i], n_eval, *n_past, 0);
|
auto batch = llama_batch_get_one(&tokens[i], n_eval, *n_past, 0);
|
||||||
// TODO: add mrope pos ids somewhere else
|
// TODO: add mrope pos ids somewhere else
|
||||||
pos.resize(batch.n_tokens * 3);
|
pos.resize(batch.n_tokens * 4);
|
||||||
|
std::fill(pos.begin(), pos.end(), 0);
|
||||||
for (int j = 0; j < batch.n_tokens * 3; j ++) {
|
for (int j = 0; j < batch.n_tokens * 3; j ++) {
|
||||||
pos[j] = *st_pos_id + (j % batch.n_tokens);
|
pos[j] = *st_pos_id + (j % batch.n_tokens);
|
||||||
}
|
}
|
||||||
|
@ -670,18 +674,22 @@ static void tmp_test_mrope_2d(struct llava_context * ctx_llava, gpt_params * par
|
||||||
std::fill(dummy_q.begin(), dummy_q.end(), 0.1);
|
std::fill(dummy_q.begin(), dummy_q.end(), 0.1);
|
||||||
memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));
|
memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw));
|
||||||
|
|
||||||
struct ggml_tensor * pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 30 * 3);
|
struct ggml_tensor * pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 30 * 4);
|
||||||
ggml_set_name(pos, "pos");
|
ggml_set_name(pos, "pos");
|
||||||
ggml_set_input(pos);
|
ggml_set_input(pos);
|
||||||
|
|
||||||
std::vector<int> pos_id;
|
std::vector<int> pos_id;
|
||||||
pos_id.resize(90);
|
pos_id.resize(30 * 4);
|
||||||
for (int i = 0; i < 30; i ++) pos_id[i] = i;
|
for (int i = 0; i < 30; i ++) {
|
||||||
for (int i = 30; i < 60; i ++) pos_id[i] = i - 30;
|
pos_id[i] = i;
|
||||||
for (int i = 60; i < 90; i ++) pos_id[i] = i - 0;
|
pos_id[i + 30] = i + 10;
|
||||||
memcpy(pos->data, pos_id.data(), 90 * ggml_element_size(pos));
|
pos_id[i + 60] = i + 10;
|
||||||
|
pos_id[i + 90] = i + 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
memcpy(pos->data, pos_id.data(), 30 * 4 * ggml_element_size(pos));
|
||||||
|
|
||||||
int sections[3] = {32, 32, 0};
|
int sections[4] = {32, 32, 32, 32};
|
||||||
auto encode = ggml_mrope_ext(
|
auto encode = ggml_mrope_ext(
|
||||||
ctx0, inp_raw, pos, nullptr,
|
ctx0, inp_raw, pos, nullptr,
|
||||||
128/2, sections, LLAMA_ROPE_TYPE_NEOX, 32768, 1000000, 1,
|
128/2, sections, LLAMA_ROPE_TYPE_NEOX, 32768, 1000000, 1,
|
||||||
|
@ -717,15 +725,16 @@ static void tmp_dump_img_embed(struct llava_context * ctx_llava, gpt_params * pa
|
||||||
int ne = n_embd * 4;
|
int ne = n_embd * 4;
|
||||||
float vals[56 * 56 * 3];
|
float vals[56 * 56 * 3];
|
||||||
float embd[ne];
|
float embd[ne];
|
||||||
for (int i = 0; i < 3*56*56; i++)
|
// for (int i = 0; i < 3*56*56; i++)
|
||||||
{
|
|
||||||
vals[i] = 0.1;
|
|
||||||
}
|
|
||||||
// for (int i = 0; i < 56*56; i++)
|
|
||||||
// {
|
// {
|
||||||
// for (int c = 0; c < 3; c++)
|
// vals[i] = 0.1;
|
||||||
// vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
|
|
||||||
// }
|
// }
|
||||||
|
for (int i = 0; i < 56*56; i++)
|
||||||
|
{
|
||||||
|
for (int c = 0; c < 3; c++)
|
||||||
|
vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
|
||||||
|
}
|
||||||
|
|
||||||
// auto param = &ctx_llava->ctx_clip->vision_model.hparams;
|
// auto param = &ctx_llava->ctx_clip->vision_model.hparams;
|
||||||
tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);
|
tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);
|
||||||
|
|
||||||
|
@ -760,25 +769,30 @@ static void tmp_dump_img_embed_from_file(struct llava_context * ctx_llava, gpt_p
|
||||||
}
|
}
|
||||||
|
|
||||||
static void tmp_dump_img_mid_embed(struct llava_context * ctx_llava, gpt_params * params) {
|
static void tmp_dump_img_mid_embed(struct llava_context * ctx_llava, gpt_params * params) {
|
||||||
|
int layers = 2;
|
||||||
// auto * image_embed = load_image(ctx_llava, params, "/home/ron/Downloads/gguf/dog.jpeg");
|
// auto * image_embed = load_image(ctx_llava, params, "/home/ron/Downloads/gguf/dog.jpeg");
|
||||||
int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama));
|
int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama));
|
||||||
// int ne = n_embd * image_embed->n_image_pos;
|
// int ne = n_embd * image_embed->n_image_pos;
|
||||||
int ne = 1280 * 4 * 4;
|
int ne = 1280 * 4 * 4;
|
||||||
float vals[56 * 56 * 3];
|
float vals[56 * 56 * 3];
|
||||||
float embd[ne];
|
float embd[ne];
|
||||||
for (int i = 0; i < 3*56*56; i++)
|
|
||||||
{
|
// for (int i = 0; i < 3*56*56; i++)
|
||||||
vals[i] = 0.1;
|
|
||||||
}
|
|
||||||
// for (int i = 0; i < 56*56; i++)
|
|
||||||
// {
|
// {
|
||||||
// for (int c = 0; c < 3; c++)
|
// vals[i] = 0.5;
|
||||||
// vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
|
|
||||||
// }
|
// }
|
||||||
|
for (int i = 0; i < 56*56; i++)
|
||||||
|
{
|
||||||
|
for (int c = 0; c < 3; c++)
|
||||||
|
vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56);
|
||||||
|
}
|
||||||
// auto param = &ctx_llava->ctx_clip->vision_model.hparams;
|
// auto param = &ctx_llava->ctx_clip->vision_model.hparams;
|
||||||
|
|
||||||
|
|
||||||
|
// tmp_clip_set_layers(ctx_llava->ctx_clip, layers);
|
||||||
tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);
|
tmp_clip_image_encode(ctx_llava->ctx_clip, 16, vals, 56, 56, embd);
|
||||||
|
|
||||||
std::ofstream outFile("img_layer_1_embed.bin", std::ios::binary);
|
std::ofstream outFile("img_layer_" + std::to_string(layers) + "_embed.bin", std::ios::binary);
|
||||||
if (outFile.is_open()) {
|
if (outFile.is_open()) {
|
||||||
outFile.write(reinterpret_cast<const char*>(embd), ne * sizeof(float));
|
outFile.write(reinterpret_cast<const char*>(embd), ne * sizeof(float));
|
||||||
|
|
||||||
|
@ -819,6 +833,33 @@ static void tmp_dump_patch_embed(struct llava_context * ctx_llava, gpt_params *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static llava_image_embed * tmp_load_img_embed() {
|
||||||
|
std::ifstream inputFile("/home/ron/Projects/llm2vec/hf_img_embed_f.bin", std::ios::binary);
|
||||||
|
|
||||||
|
if (!inputFile) {
|
||||||
|
std::cerr << "Could not open the file!" << std::endl;
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the size of the file
|
||||||
|
inputFile.seekg(0, std::ios::end);
|
||||||
|
std::streamsize fileSize = inputFile.tellg();
|
||||||
|
inputFile.seekg(0, std::ios::beg);
|
||||||
|
|
||||||
|
static llava_image_embed * result = (llava_image_embed*)malloc(sizeof(llava_image_embed));
|
||||||
|
result->embed = (float*)malloc(fileSize);
|
||||||
|
result->n_image_pos = 24 * 36 /4;
|
||||||
|
|
||||||
|
// Assuming the binary file contains floating-point numbers (float)
|
||||||
|
std::size_t numElements = fileSize / sizeof(float);
|
||||||
|
inputFile.read(reinterpret_cast<char*>(result->embed), fileSize);
|
||||||
|
inputFile.close();
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
-----------------------------------------------------------------------------------------------------------------
|
-----------------------------------------------------------------------------------------------------------------
|
||||||
*/
|
*/
|
||||||
|
@ -861,6 +902,16 @@ int main(int argc, char ** argv) {
|
||||||
// This section is for testing LLM parts of the model during development phase!
|
// This section is for testing LLM parts of the model during development phase!
|
||||||
auto ctx_llava = llava_init_context(¶ms, model);
|
auto ctx_llava = llava_init_context(¶ms, model);
|
||||||
|
|
||||||
|
// {
|
||||||
|
// auto img_embed = tmp_load_img_embed();
|
||||||
|
// struct clip_image_size * load_image_size = clip_image_size_init();
|
||||||
|
// load_image_size->height = 336;
|
||||||
|
// load_image_size->width = 504;
|
||||||
|
// clip_add_load_image_size(ctx_llava->ctx_clip, load_image_size);
|
||||||
|
// process_prompt(ctx_llava, img_embed, ¶ms, params.prompt);
|
||||||
|
// llava_image_embed_free(img_embed);
|
||||||
|
// }
|
||||||
|
|
||||||
// process the prompt
|
// process the prompt
|
||||||
tmp_dump_img_embed(ctx_llava, ¶ms);
|
tmp_dump_img_embed(ctx_llava, ¶ms);
|
||||||
// tmp_dump_img_embed_from_file(ctx_llava, ¶ms);
|
// tmp_dump_img_embed_from_file(ctx_llava, ¶ms);
|
||||||
|
|
|
@ -1451,7 +1451,7 @@ extern "C" {
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int sections[3],
|
int sections[4],
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx_orig,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
|
|
|
@ -3559,7 +3559,7 @@ struct ggml_tensor * ggml_mrope_ext(
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
struct ggml_tensor * c,
|
struct ggml_tensor * c,
|
||||||
int n_dims,
|
int n_dims,
|
||||||
int sections[3],
|
int sections[4],
|
||||||
int mode,
|
int mode,
|
||||||
int n_ctx_orig,
|
int n_ctx_orig,
|
||||||
float freq_base,
|
float freq_base,
|
||||||
|
@ -3573,7 +3573,7 @@ struct ggml_tensor * ggml_mrope_ext(
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_vector(b));
|
GGML_ASSERT(ggml_is_vector(b));
|
||||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||||
GGML_ASSERT(a->ne[2] * 3 == b->ne[0]);
|
GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
|
||||||
|
|
||||||
if (c) {
|
if (c) {
|
||||||
GGML_ASSERT(c->type == GGML_TYPE_F32);
|
GGML_ASSERT(c->type == GGML_TYPE_F32);
|
||||||
|
@ -3588,14 +3588,14 @@ struct ggml_tensor * ggml_mrope_ext(
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
int32_t params[11 + 3] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
|
||||||
memcpy(params + 5, &freq_base, sizeof(float));
|
memcpy(params + 5, &freq_base, sizeof(float));
|
||||||
memcpy(params + 6, &freq_scale, sizeof(float));
|
memcpy(params + 6, &freq_scale, sizeof(float));
|
||||||
memcpy(params + 7, &ext_factor, sizeof(float));
|
memcpy(params + 7, &ext_factor, sizeof(float));
|
||||||
memcpy(params + 8, &attn_factor, sizeof(float));
|
memcpy(params + 8, &attn_factor, sizeof(float));
|
||||||
memcpy(params + 9, &beta_fast, sizeof(float));
|
memcpy(params + 9, &beta_fast, sizeof(float));
|
||||||
memcpy(params + 10, &beta_slow, sizeof(float));
|
memcpy(params + 10, &beta_slow, sizeof(float));
|
||||||
memcpy(¶ms[11], sections, sizeof(int)*3);
|
memcpy(¶ms[11], sections, sizeof(int)*4);
|
||||||
// memcpy(params + 11, sections, sizeof(int)*3);
|
// memcpy(params + 11, sections, sizeof(int)*3);
|
||||||
ggml_set_op_params(result, params, sizeof(params));
|
ggml_set_op_params(result, params, sizeof(params));
|
||||||
|
|
||||||
|
@ -11238,15 +11238,17 @@ static void ggml_rope_cache_init(
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_mrope_cache_init(
|
static void ggml_mrope_cache_init(
|
||||||
float theta_base_t, float theta_base_h, float theta_base_w, int sections[3], bool indep_sects,
|
float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[3], bool indep_sects,
|
||||||
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
|
||||||
float * cache, float sin_sign, float theta_scale) {
|
float * cache, float sin_sign, float theta_scale) {
|
||||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||||
float theta_t = theta_base_t;
|
float theta_t = theta_base_t;
|
||||||
float theta_h = theta_base_h;
|
float theta_h = theta_base_h;
|
||||||
float theta_w = theta_base_w;
|
float theta_w = theta_base_w;
|
||||||
int sect_dims = sections[0] + sections[1] + sections[2];
|
float theta_e = theta_base_e; // extra position id for vision encoder
|
||||||
int prev_sector = -1;
|
int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
|
||||||
|
int sec_w = sections[1] + sections[0];
|
||||||
|
GGML_ASSERT(sect_dims <= ne0);
|
||||||
|
|
||||||
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||||
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
|
||||||
|
@ -11262,15 +11264,21 @@ static void ggml_mrope_cache_init(
|
||||||
else if (sector == sections[1]) {
|
else if (sector == sections[1]) {
|
||||||
theta_w = theta_base_w;
|
theta_w = theta_base_w;
|
||||||
}
|
}
|
||||||
|
else if (sector == sections[2]) {
|
||||||
|
theta_e = theta_base_e;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float theta = theta_t;
|
float theta = theta_t;
|
||||||
if (sector < sections[1] + sections[0] && sector >= sections[0]) {
|
if (sector >= sections[0] && sector < sec_w) {
|
||||||
theta = theta_h;
|
theta = theta_h;
|
||||||
}
|
}
|
||||||
else if (sector >= sections[1] + sections[0]) {
|
else if (sector >= sec_w && sector < sec_w + sections[2]) {
|
||||||
theta = theta_w;
|
theta = theta_w;
|
||||||
}
|
}
|
||||||
|
else if (sector >= sec_w + sections[2]) {
|
||||||
|
theta = theta_e;
|
||||||
|
}
|
||||||
|
|
||||||
rope_yarn(
|
rope_yarn(
|
||||||
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
|
||||||
|
@ -11280,7 +11288,7 @@ static void ggml_mrope_cache_init(
|
||||||
theta_t *= theta_scale;
|
theta_t *= theta_scale;
|
||||||
theta_w *= theta_scale;
|
theta_w *= theta_scale;
|
||||||
theta_h *= theta_scale;
|
theta_h *= theta_scale;
|
||||||
prev_sector = sector;
|
theta_e *= theta_scale;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11304,7 +11312,7 @@ static void ggml_compute_forward_rope_f32(
|
||||||
const struct ggml_tensor * src2 = dst->src[2];
|
const struct ggml_tensor * src2 = dst->src[2];
|
||||||
|
|
||||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||||
int sections[3];
|
int sections[4];
|
||||||
|
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
|
@ -11318,7 +11326,7 @@ static void ggml_compute_forward_rope_f32(
|
||||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int) * 3);
|
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
|
||||||
|
|
||||||
GGML_TENSOR_UNARY_OP_LOCALS
|
GGML_TENSOR_UNARY_OP_LOCALS
|
||||||
|
|
||||||
|
@ -11352,6 +11360,11 @@ static void ggml_compute_forward_rope_f32(
|
||||||
|
|
||||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||||
const bool is_mrope = sections[0] > 0 || sections[1] > 0 || sections[2] > 0;
|
const bool is_mrope = sections[0] > 0 || sections[1] > 0 || sections[2] > 0;
|
||||||
|
const bool is_vision = is_mrope && sections[3] > 0;
|
||||||
|
|
||||||
|
if (is_vision) {
|
||||||
|
GGML_ASSERT(n_dims == ne0/2);
|
||||||
|
}
|
||||||
|
|
||||||
const float * freq_factors = NULL;
|
const float * freq_factors = NULL;
|
||||||
if (src2 != NULL) {
|
if (src2 != NULL) {
|
||||||
|
@ -11379,8 +11392,9 @@ static void ggml_compute_forward_rope_f32(
|
||||||
const int64_t p_t = pos[i2];
|
const int64_t p_t = pos[i2];
|
||||||
const int64_t p_h = pos[i2 + ne2];
|
const int64_t p_h = pos[i2 + ne2];
|
||||||
const int64_t p_w = pos[i2 + ne2 * 2];
|
const int64_t p_w = pos[i2 + ne2 * 2];
|
||||||
|
const int64_t p_e = pos[i2 + ne2 * 3];
|
||||||
ggml_mrope_cache_init(
|
ggml_mrope_cache_init(
|
||||||
p_t, p_h, p_w, sections, sections[2] == 0,
|
p_t, p_h, p_w, p_e, sections, sections[3] != 0,
|
||||||
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11402,6 +11416,22 @@ static void ggml_compute_forward_rope_f32(
|
||||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||||
}
|
}
|
||||||
|
} else if (is_vision){
|
||||||
|
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||||
|
const int64_t ic = i0/2;
|
||||||
|
|
||||||
|
const float cos_theta = cache[i0 + 0];
|
||||||
|
const float sin_theta = cache[i0 + 1];
|
||||||
|
|
||||||
|
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||||
|
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||||
|
|
||||||
|
const float x0 = src[0];
|
||||||
|
const float x1 = src[n_dims];
|
||||||
|
|
||||||
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
|
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
||||||
const int64_t ic = i0/2;
|
const int64_t ic = i0/2;
|
||||||
|
@ -11420,12 +11450,21 @@ static void ggml_compute_forward_rope_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_mrope) {
|
if (is_vision) {
|
||||||
// fill the remain channels by repeating 0~n_dims channel
|
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
||||||
for (int64_t i0 = n_dims; i0 < ne0; i0 ++) {
|
const int64_t ic = i0/2;
|
||||||
float * dst_data_0 = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
|
||||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
const float cos_theta = cache[i0 + 0];
|
||||||
dst_data[0] = dst_data_0[i0 % n_dims];
|
const float sin_theta = cache[i0 + 1];
|
||||||
|
|
||||||
|
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||||
|
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||||
|
|
||||||
|
const float x0 = src[0];
|
||||||
|
const float x1 = src[n_dims];
|
||||||
|
|
||||||
|
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||||
|
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
|
@ -12525,14 +12525,14 @@ struct llm_build_context {
|
||||||
|
|
||||||
// inp_pos - contains the positions
|
// inp_pos - contains the positions
|
||||||
// struct ggml_tensor * inp_pos = build_inp_pos();
|
// struct ggml_tensor * inp_pos = build_inp_pos();
|
||||||
lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 3);
|
lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4);
|
||||||
cb(lctx.inp_pos, "inp_pos", -1);
|
cb(lctx.inp_pos, "inp_pos", -1);
|
||||||
ggml_set_input(lctx.inp_pos);
|
ggml_set_input(lctx.inp_pos);
|
||||||
struct ggml_tensor * inp_pos = lctx.inp_pos;
|
struct ggml_tensor * inp_pos = lctx.inp_pos;
|
||||||
|
|
||||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
int sections[3] = {16, 24, 24}; // TODO: move this into gguf model file.
|
int sections[4] = {16, 24, 24, 0}; // TODO: move this into gguf model file.
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
struct ggml_tensor * inpSA = inpL;
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
@ -16878,7 +16878,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_QWEN2VL:
|
case LLM_ARCH_QWEN2VL:
|
||||||
{
|
{
|
||||||
lctx.n_pos_per_token = 3;
|
lctx.n_pos_per_token = 4;
|
||||||
result = llm.build_qwen2vl();
|
result = llm.build_qwen2vl();
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_QWEN2MOE:
|
case LLM_ARCH_QWEN2MOE:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue