diff --git a/Makefile b/Makefile
index 6b3416b8b..200aafd3c 100644
--- a/Makefile
+++ b/Makefile
@@ -1399,7 +1399,7 @@ llama-minicpmv-cli: examples/llava/minicpmv-cli.cpp \
 	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual
 
-llama-qwen2vl-cli: examples/llava/minicpmv-cli.cpp \
+llama-qwen2vl-cli: examples/llava/qwen2vl-cli.cpp \
 	examples/llava/llava.cpp \
 	examples/llava/llava.h \
 	examples/llava/clip.cpp \
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 863d86ea4..33d0c5af3 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -759,7 +759,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
 
         if (ctx->has_qwen2vl_merger) {
-            Q = ggml_mrope_ext(
+            Q = ggml_rope_multi(
                 ctx0, Q, positions, nullptr,
                 d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
         }
@@ -772,7 +772,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
 
         if (ctx->has_qwen2vl_merger) {
-            K = ggml_mrope_ext(
+            K = ggml_rope_multi(
                 ctx0, K, positions, nullptr,
                 d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
         }
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index e1c620e15..3525a920f 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1447,7 +1447,7 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);
 
-    GGML_API struct ggml_tensor * ggml_mrope_ext(
+    GGML_API struct ggml_tensor * ggml_rope_multi(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index fb9fcff67..ca03f85b7 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -9317,7 +9317,7 @@ static void ggml_compute_forward_rope_f32(
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, multimodal rotary position embedding
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
index b09ecac19..e1a5361b5 100644
--- a/ggml/src/ggml-cuda/rope.cu
+++ b/ggml/src/ggml-cuda/rope.cu
@@ -114,7 +114,7 @@ static __global__ void rope_neox(
 }
 
 template <typename T, bool has_ff>
-static __global__ void rope_mrope(
+static __global__ void rope_multi(
         const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
         float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -261,7 +261,7 @@ static void rope_neox_cuda(
 }
 
 template <typename T>
-static void rope_mrope_cuda(
+static void rope_multi_cuda(
     const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
@@ -272,12 +272,12 @@ static void rope_mrope_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
     if (freq_factors == nullptr) {
-        rope_mrope<T, false><<<block_nums, block_dims, 0, stream>>>(
+        rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                 theta_scale, freq_factors, sections
                 );
     } else {
-        rope_mrope<T, true><<<block_nums, block_dims, 0, stream>>>(
+        rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
                 theta_scale, freq_factors, sections
                 );
     }
@@ -339,20 +339,20 @@ static void rope_neox_cuda_f32(
     rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
-static void rope_mrope_cuda_f16(
+static void rope_multi_cuda_f16(
     const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
 ) {
 
-    rope_mrope_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+    rope_multi_cuda<half>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
 }
 
-static void rope_mrope_cuda_f32(
+static void rope_multi_cuda_f32(
     const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
     float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
 ) {
 
-    rope_mrope_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+    rope_multi_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
 }
 
 static void rope_vision_cuda_f16(
@@ -455,12 +455,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         }
     } else if (is_mrope && !is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_mrope_cuda_f32(
+            rope_multi_cuda_f32(
                 (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base,
                 ext_factor, attn_factor, corr_dims, freq_factors, sections, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_mrope_cuda_f16(
+            rope_multi_cuda_f16(
                 (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base,
                 ext_factor, attn_factor, corr_dims, freq_factors, sections, stream
             );
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 008022441..26428fc5b 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3548,7 +3548,7 @@ struct ggml_tensor * ggml_rope(
     );
 }
 
-struct ggml_tensor * ggml_mrope_ext(
+struct ggml_tensor * ggml_rope_multi(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -3563,7 +3563,7 @@ struct ggml_tensor * ggml_mrope_ext(
         float                 attn_factor,
        float                 beta_fast,
         float                 beta_slow) {
-
+    // Multimodal Rotary Position Embedding
     GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
 
     GGML_ASSERT(ggml_is_vector(b));
diff --git a/src/llama.cpp b/src/llama.cpp
index d7deaffe0..31eb0a634 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2436,7 +2436,7 @@ struct llama_hparams {
     float rope_freq_scale_train;
     uint32_t n_ctx_orig_yarn;
     float rope_yarn_log_mul;
-    std::array<int, 4> rope_mrope_sections;
+    std::array<int, 4> rope_sections;
 
     // for State Space Models
     uint32_t ssm_d_conv = 0;
@@ -2493,7 +2493,7 @@ struct llama_hparams {
 
         if (this->rope_finetuned != other.rope_finetuned) return true;
         if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
-        if (this->rope_mrope_sections != other.rope_mrope_sections) return true;
+        if (this->rope_sections != other.rope_sections) return true;
 
         if (this->ssm_d_conv != other.ssm_d_conv) return true;
         if (this->ssm_d_inner != other.ssm_d_inner) return true;
@@ -5717,8 +5717,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_QWEN2VL:
             {
-                std::fill(hparams.rope_mrope_sections.begin(), hparams.rope_mrope_sections.end(), 0);
-                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_mrope_sections, 4, true);
+                std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true);
             }
             // fall through
         case LLM_ARCH_QWEN2:
@@ -12543,7 +12543,7 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
         int sections[4];
-        std::copy(hparams.rope_mrope_sections.begin(), hparams.rope_mrope_sections.end(), sections);
+        std::copy(hparams.rope_sections.begin(), hparams.rope_sections.end(), sections);
 
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
@@ -12572,7 +12572,7 @@ struct llm_build_context {
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_mrope_ext(
+                Qcur = ggml_rope_multi(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -12580,7 +12580,7 @@ struct llm_build_context {
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_mrope_ext(
+                Kcur = ggml_rope_multi(
                     ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
diff --git a/tests/test-rope.cpp b/tests/test-rope.cpp
index b138ffb25..b54e3b21e 100644
--- a/tests/test-rope.cpp
+++ b/tests/test-rope.cpp
@@ -193,20 +193,20 @@ int main(int /*argc*/, const char ** /*argv*/) {
             // [[100, 101, 102, ..., 172],
             //  [101, 102, 103, ..., 173],
             //  [102, 103, 104, ..., 174]]
-            r0 = ggml_mrope_ext(
+            r0 = ggml_rope_multi(
                 ctx0, x, p0, nullptr,
                 n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
 
             // [[-67, -67, -67, ..., -67]
             //  [-67, -67, -67, ..., -67]
             //  [-67, -67, -67, ..., -67]]
-            r1 = ggml_mrope_ext(
+            r1 = ggml_rope_multi(
                 ctx0, r0, p1, nullptr,
                 n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
 
             // [[33, 34, 35, ..., 105]
             //  [34, 35, 36, ..., 106]
             //  [35, 36, 37, ..., 107]]
-            r2 = ggml_mrope_ext(
+            r2 = ggml_rope_multi(
                 ctx0, x, p2, nullptr,
                 n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
         }